/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp Source File
fused_moegemm_pipeline_flatmm_uk.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 namespace ck_tile {
11 
12 /*
13 This pipeline handles a gemm (actually 2 gemms) where one operand is very small (the tokens) and the other is very big (the weights).
14 The pipeline is therefore designed so that all waves are laid out along the gemm-N dim (gemm-M uses only 1 wave):
15 
16  <----- gemm-N ------>
17  +----+----+----+----+
18  | w0 | w1 | w2 | w3 | gemm-m
19  +----+----+----+----+
20 */
21 template <typename Problem_, typename Policy_ = FusedMoeGemmPipelineFlatmmPolicy>
23 {
26 
27  using BlockShape = typename Problem::BlockShape; // this is FusedMoeGemmShape
28 
29  using ADataType = typename Problem::ADataType;
30  using GDataType = typename Problem::GDataType;
31  using DDataType = typename Problem::DDataType;
32  using AccDataType = typename Problem::AccDataType;
33  using ODataType = typename Problem::ODataType;
34  using AScaleDataType = typename Problem::AScaleDataType;
35  using GScaleDataType = typename Problem::GScaleDataType;
36  using DScaleDataType = typename Problem::DScaleDataType;
37  using YSmoothScaleDataType = typename Problem::YSmoothScaleDataType;
38  using TopkWeightDataType = typename Problem::TopkWeightDataType;
39  using IndexDataType = typename Problem::IndexDataType;
40  using YDataType = typename Problem::YDataType;
41 
42  using Traits = typename Problem::Traits;
43 
44  static constexpr bool IsGateOnly = Traits::IsGateOnly;
45  static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant;
46  static constexpr bool PadHiddenSize = Traits::PadHiddenSize;
47  static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;
48 
49  static constexpr index_t kAlignmentA = Policy::template GetAlignment_A<Problem>();
50  static constexpr index_t kAlignmentG = Policy::template GetAlignment_G<Problem>();
51  static constexpr index_t kAlignmentD = Policy::template GetAlignment_D<Problem>();
52  static constexpr index_t kAlignmentO = Policy::template GetAlignment_O<Problem>();
53 
58 
59  static constexpr index_t kBlockPerCu = []() { // blocks per CU: Problem::kBlockPerCu when it is not -1, otherwise a fixed fallback of 2
60  if constexpr(Problem::kBlockPerCu != -1)
61  return Problem::kBlockPerCu;
62  else
63  {
64  // minimize occupancy (fallback used when Problem::kBlockPerCu == -1, i.e. not specified by the Problem)
65  return 2;
66  }
67  }();
68 
69  static constexpr const char* name = "flatmm_uk";
70 
72  {
73 #if 1
74  constexpr index_t smem_0 = Policy::template GetUK_0<Problem>().GetSmemSize();
75  constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize();
76  constexpr index_t smem_bridge =
77  BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType);
78  return max(smem_0 + smem_1, smem_bridge);
79 #else
80  // keep it here purposely in case we have regression
81  return 65536;
82 #endif
83  }
84 
85  // this is the thread-offset along row/col
87  {
88  constexpr auto a_dist = Policy::template MakeGlobalTileDistribution_A<Problem>();
89  const auto a_coord = a_dist.calculate_index();
90  return a_coord;
91  }
92 
93  // this is the thread-offset along row/col
95  {
96  constexpr auto o_dist = Policy::template MakeOGlobalTileDistribution<Problem>();
97  const auto o_coord = o_dist.calculate_index();
98  return o_coord;
99  }
100 
102  {
103  constexpr index_t KLans = BlockShape::Block_K0 / kAlignmentA;
104  constexpr index_t MLans = BlockShape::BlockSize / KLans;
105  constexpr index_t MRepeat = BlockShape::Block_M0 / MLans;
106 
107  return MRepeat;
108  }
109 
110  // TODO: properly support scatter/gather
112  {
113  constexpr index_t KLans = BlockShape::Block_K0 / kAlignmentA;
114  constexpr index_t MLans = BlockShape::BlockSize / KLans;
115  constexpr index_t MRepeat = BlockShape::Block_M0 / MLans;
116 
117  auto base_coord = threadIdx.x / KLans + base_offset;
118 
120  static_for<0, MRepeat, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLans; });
121 
122  return coords;
123  }
124 
125  template <typename ROW_COORDS> // ROW_COORDS: array-like type whose size() is usable in a constant expression
126  CK_TILE_DEVICE auto GetRowID(const ROW_COORDS coords, const IndexDataType* sorted_token_ids_ptr) // gathers sorted_token_ids_ptr[coords[i]] for every coord; returns array<index_t, coords.size()>
127  {
128  constexpr index_t n_size = coords.size(); // one row id per row coordinate
129 
130  array<index_t, n_size> row_ids;
131  static_for<0, n_size, 1>{}([&](auto i) {
132  row_ids.at(i) = sorted_token_ids_ptr[coords[i]]; // coords[i] indexes the sorted-token list (base_coord + i * MLans when built by GetRowCoords_A)
133 #if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
134  row_ids.at(i) &= 0xffffff; // keep only the low 24 bits; NOTE(review): presumably the mock sorting packs extra data in the high 8 bits -- confirm
135 #endif
136  });
137 
138  return row_ids;
139  }
140 
141  template <typename ROW_COORDS>
142  CK_TILE_DEVICE auto GetWeightScale(const ROW_COORDS coords,
143  const TopkWeightDataType* sorted_weight_ptr)
144  {
145  constexpr index_t n_size = coords.size();
146 
148  static_for<0, n_size, 1>{}([&](auto i) {
149  w.at(i) = sorted_weight_ptr[coords[i]]; // base_coord + i * MLans;
150  });
151 
152  return w;
153  }
154 
155  // TODO: this row id is computed before the shuffle/atomic; it needs to use the acc distribution instead
157  {
158  constexpr index_t MLanes = BlockShape::Warp_M1;
159  constexpr index_t Repeat_M = BlockShape::Repeat_M1;
160 
161  auto base_coord = threadIdx.x % MLanes + base_offset;
162 
164  static_for<0, Repeat_M, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLanes; });
165 
166  return coords;
167  }
168 
169  template <typename Karg>
170  CK_TILE_DEVICE auto operator()(const Karg& kargs,
171  CK_TILE_LDS_ADDR void* smem,
172  index_t sorted_tile_id,
173  index_t intermediate_tile_id)
174  {
175  constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
176  ck_tile::index_t shared_intermediate_size_0 =
177  kargs.intermediate_size * hidden_radio_0; // total gate+up
178  ck_tile::index_t shared_intermediate_size_1 = kargs.intermediate_size;
179 
180  // after weight shuffling, gate-only: [nr0, kr0, w0], gate+up: [nr0_gate + nr0_up, kr0, w0]
181 
182  index_t nr_0 = shared_intermediate_size_0 / BlockShape::Warp_N0; // divide N in W
183  index_t kr_0 = kargs.hidden_size / BlockShape::Warp_K0; // divide K in W
184  index_t nr_1 = kargs.hidden_size / BlockShape::Warp_N1;
185  index_t kr_1 = shared_intermediate_size_1 / BlockShape::Warp_K1;
186 
187  const IndexDataType expert_id = __builtin_amdgcn_readfirstlane(
188  reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
189  index_t expert_stride_0 = shared_intermediate_size_0 * kargs.hidden_size;
190  index_t expert_stride_1 = shared_intermediate_size_1 * kargs.hidden_size;
191 
192  // nr*kr*w
193  index_t interm_idx_nr0 = __builtin_amdgcn_readfirstlane(
194  intermediate_tile_id *
195  BlockShape::Block_Nr0); // intermediate_tile_id * Block_N / (N in W)
196 
197  index_t interm_idx_kr1 = __builtin_amdgcn_readfirstlane(
198  intermediate_tile_id *
199  BlockShape::Block_Kr1); // intermediate_tile_id * Block_N / (N in W)
200 
201  auto row_coords_a = GetRowCoords_A(sorted_tile_id * BlockShape::Block_M0);
202  auto row_ids_a = GetRowID(
203  row_coords_a, reinterpret_cast<const IndexDataType*>(kargs.sorted_token_ids_ptr));
204  auto a_coords = generate_tuple(
205  [&](auto i) {
206  return row_ids_a[i] * kargs.stride_token +
207  threadIdx.x % (BlockShape::Block_K0 / kAlignmentA) * kAlignmentA;
208  },
209  number<row_ids_a.size()>{});
210 
211  auto a_res =
212  make_wave_buffer_resource(reinterpret_cast<const ADataType*>(kargs.a_ptr),
213  kargs.num_tokens * kargs.stride_token * sizeof(ADataType));
214 
215  auto make_gu_win = [&](const auto* ptr_) {
216  auto view_ = make_naive_tensor_view<address_space_enum::global>(
217  ptr_,
219  make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
221  number<1>{});
222 
223  auto win_ = make_tile_window_linear_raw(
224  view_,
228  {0, 0, 0},
229  Policy::template MakeGlobalTileDistribution_G<Problem>(),
231  return win_;
232  };
233 
234  const GDataType* gu_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
235  static_cast<long_index_t>(expert_id) * expert_stride_0 +
236  interm_idx_nr0 * kr_0 * BlockShape::Block_W0;
237 
238  auto g_win = make_gu_win(gu_ptr);
239  // Note: gu is swizzled as [nr_g+nr_u, kr, w] (gate first, then up), hence the base offset to up is just interm*hidden
240  auto u_win = make_gu_win(gu_ptr + kargs.intermediate_size * kargs.hidden_size);
241 
242  auto g_res = g_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
243  auto u_res = u_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
244  auto g_coords = generate_tuple([&](auto i) { return g_win.cached_coords_[i].get_offset(); },
245  number<decltype(g_win)::NumAccess_NonLinear>{});
246 
247  const auto d_win = [&]() {
248  const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
249  static_cast<long_index_t>(expert_id) * expert_stride_1 +
250  interm_idx_kr1 * BlockShape::Block_W1;
251  // note: interm_idx_kr1 is along the gemm-k dim of the 2nd gemm
252 
253  const auto d_view_ = make_naive_tensor_view<address_space_enum::global>(
254  d_ptr,
255  make_tuple(nr_1, kr_1, BlockShape::Block_W1),
256  make_tuple(kr_1 * BlockShape::Block_W1, BlockShape::Block_W1, 1),
258  number<1>{});
259 
260  const auto d_window_ = make_tile_window_linear_raw(
261  d_view_,
265  {0, 0, 0},
266  Policy::template MakeGlobalTileDistribution_D<Problem>(),
268  return d_window_;
269  }();
270  auto d_res = d_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
271 
272  // TODO: load D order is N0.K0...127, N64.K0...127, N0.K128...255, N64.K128...255
273  // block-k=512, block-n=128
274  // wg |<----- W_ ----->|
275  // Nr(2)*Nw(4)* Kr *Kr0(4)*Kr1(4) * [Kl(4)*Nl(16)*Kv(8)]->one issue
276  // y p y y p p y
277  // 1 2 0(imm)
278  auto d_coords = [&]() {
279  constexpr index_t Nr_ = 2;
280  constexpr index_t Nw_ = 4;
281  constexpr index_t Kr0_ = 4;
282  constexpr index_t Kr1_ = 4;
283  constexpr index_t Kl_ = 4;
284  constexpr index_t Nl_ = 16;
285  constexpr index_t Kv_ = 8;
286  constexpr index_t W_ = Kl_ * Nl_ * Kv_;
287  constexpr index_t num_offsets_ = Nr_ * Kr0_;
288  index_t base_os_ = (threadIdx.x % 64) * Kv_ + (threadIdx.x / 64) *
289  shared_intermediate_size_1 *
290  Nl_; // Kr0_ * Kr1_ * W_;
291  return generate_tuple(
292  [&](auto i) {
293  constexpr auto i_nr_ = number<i % Nr_>{};
294  constexpr auto i_kr0_ = number<i / Nr_>{};
295 
296  return i_nr_ * shared_intermediate_size_1 * Nw_ * Nl_ + i_kr0_ * Kr1_ * W_ +
297  base_os_;
298  },
300  }();
301 
302  auto o_coords = generate_tuple(
303  [&](auto i) {
304  return row_ids_a[i] * kargs.stride_token +
305  threadIdx.x % (BlockShape::Block_N1 / kAlignmentO) * kAlignmentO;
306  },
307  number<row_ids_a.size()>{});
308 
309  auto o_flags =
310  generate_tuple([&](auto i) { return cmp_lt_to_exec(row_ids_a[i], kargs.num_tokens); },
311  number<row_ids_a.size()>{});
312 
313  auto bridge_sst_win = [&]() {
314  constexpr auto desc_ = Policy::template MakeBridgeLdsStoreForUKDesc<Problem>();
315  constexpr auto dist_ = Policy::template GetUK_0<Problem>().MakeCBlockDist();
316  return make_tile_window_linear(make_tensor_view<address_space_enum::lds>(
317  reinterpret_cast<YDataType*>(smem), desc_),
318  desc_.get_lengths(),
319  {0, 0},
320  dist_);
321  }();
322 
323  auto o_res =
324  make_wave_buffer_resource(reinterpret_cast<const ODataType*>(kargs.o_ptr),
325  kargs.num_tokens * kargs.stride_token * sizeof(ODataType));
326  auto row_coords_o = GetRowCoords_O(sorted_tile_id * BlockShape::Block_M0);
327  auto w_scale = GetWeightScale(
328  row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr));
329 
330  auto uk_0 = Policy::template GetUK_0<Problem>();
331 
332  auto y_pre = [&]() {
333  if constexpr(IsGateOnly)
334  {
335  auto acc_0 = uk_0(a_res,
336  a_coords,
337  g_res,
338  g_coords,
339  smem,
340  kargs.hidden_size,
341  BlockShape::Block_K0, // tile offset for B matrix each unroll
342  BlockShape::Block_Kr0 *
343  BlockShape::Block_W0); // tile offset for B matrix each unroll
344 
345  sweep_tile(
346  acc_0,
347  [&](auto idx0, auto idx1) {
348  fp32x2_t v_{acc_0(idx0), acc_0(idx1)};
349  typename Problem::GateActivation{}(v_, v_);
350  acc_0(idx0) = v_.x;
351  acc_0(idx1) = v_.y;
352  },
353  sequence<1, 2>{});
354 
355  return cast_tile<YDataType>(acc_0);
356  }
357  else
358  {
359  uint32x8_t gu_res;
360  gu_res[0] = g_res[0];
361  gu_res[1] = g_res[1];
362  gu_res[2] = g_res[2];
363  gu_res[3] = g_res[3];
364  gu_res[4] = u_res[0];
365  gu_res[5] = u_res[1];
366  gu_res[6] = u_res[2];
367  gu_res[7] = u_res[3];
368 
369  auto acc_0 = uk_0(a_res,
370  a_coords,
371  gu_res,
372  g_coords,
373  smem,
374  kargs.hidden_size,
375  BlockShape::Block_K0, // tile offset for B matrix each unroll
376  BlockShape::Block_Kr0 * BlockShape::Block_W0,
377  bool_constant<true>{}); // tile offset for B matrix each unroll
378 
379  sweep_tile(
380  acc_0.at(number<0>{}),
381  [&](auto idx0, auto idx1) {
382  fp32x2_t v_{acc_0.at(number<0>{})(idx0), acc_0.at(number<0>{})(idx1)};
383  typename Problem::GateActivation{}(v_, v_);
384  acc_0.at(number<0>{})(idx0) = v_.x;
385  acc_0.at(number<0>{})(idx1) = v_.y;
386  },
387  sequence<1, 2>{});
388 
389  auto reduced_acc_0 =
390  tile_elementwise_in([&](const auto& a_, const auto& b_) { return a_ * b_; },
391  acc_0.at(number<0>{}),
392  acc_0.at(number<1>{}));
393 
394  return cast_tile<YDataType>(reduced_acc_0);
395  }
396  }();
397 
398  block_sync_lds();
399 
400  store_tile(bridge_sst_win, y_pre);
401  block_sync_lds();
402 
403  auto uk_1 = Policy::template GetUK_1<Problem>();
404  uk_1(d_res,
405  d_coords,
406  o_res,
407  o_coords,
408  o_flags,
409  smem,
410  kargs.hidden_size, // total n number
411  w_scale,
412  BlockShape::Block_Nr1 * kr_1 * BlockShape::Block_W1, // along N
413  BlockShape::Block_N1); // along N
414  }
415 };
416 
417 } // namespace ck_tile
CK_TILE_DEVICE void block_sync_lds()
Definition: arch.hpp:190
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_LDS_ADDR
Definition: config.hpp:58
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
CK_TILE_DEVICE auto cmp_lt_to_exec(const X &x, const Y &y)
Definition: utility.hpp:133
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition: tile_elementwise.hpp:40
float fp32x2_t
Definition: pk_fp4.hpp:22
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
uint32_t uint32x8_t
Definition: vector_type.hpp:154
CK_TILE_DEVICE auto make_tile_window_linear_raw(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, LinearBottomDims_={})
Definition: tile_window_linear.hpp:1029
int64_t long_index_t
Definition: integer.hpp:11
constexpr CK_TILE_DEVICE auto make_tile_window_linear(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, LinearBottomDims_={})
Definition: tile_window_linear.hpp:993
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:23
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
constexpr CK_TILE_HOST_DEVICE void sweep_tile(const F &f, UnpacksPerXDim={})
Definition: sweep_tile.hpp:231
CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void *ptr, uint32_t size=0xffffffff)
Definition: amd_buffer_addressing.hpp:40
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:23
typename Problem::IndexDataType IndexDataType
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:39
typename Problem::ADataType ADataType
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:29
static constexpr index_t kAlignmentO
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:52
constexpr CK_TILE_DEVICE auto GetNumRowCoords_A()
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:101
typename Problem::DDataType DDataType
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:31
static constexpr bool PadIntermediateSize
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:47
CK_TILE_DEVICE auto operator()(const Karg &kargs, CK_TILE_LDS_ADDR void *smem, index_t sorted_tile_id, index_t intermediate_tile_id)
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:170
static constexpr const char * name
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:69
CK_TILE_DEVICE auto GetRowCoords_A(index_t base_offset)
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:111
static CK_TILE_HOST_DEVICE auto GetOCoord()
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:94
typename Problem::DScaleDataType DScaleDataType
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:36
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:71
static constexpr index_t kAlignmentD
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:51
static constexpr index_t kAlignmentA
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:49
static constexpr index_t GLD_B
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:56
static constexpr index_t kBlockPerCu
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:59
typename Problem::GScaleDataType GScaleDataType
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:35
typename Problem::TopkWeightDataType TopkWeightDataType
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:38
static constexpr bool IsGateOnly
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:44
typename Problem::ODataType ODataType
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:33
typename Problem::BlockShape BlockShape
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:27
static constexpr index_t kAlignmentG
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:50
static constexpr bool PadHiddenSize
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:46
static constexpr index_t GST_O
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:57
static constexpr index_t SLD_A
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:54
CK_TILE_DEVICE auto GetWeightScale(const ROW_COORDS coords, const TopkWeightDataType *sorted_weight_ptr)
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:142
static constexpr index_t GLD_A
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:55
remove_cvref_t< Problem_ > Problem
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:24
typename Problem::YSmoothScaleDataType YSmoothScaleDataType
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:37
CK_TILE_DEVICE auto GetRowID(const ROW_COORDS coords, const IndexDataType *sorted_token_ids_ptr)
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:126
typename Problem::GDataType GDataType
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:30
typename Problem::YDataType YDataType
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:40
CK_TILE_DEVICE auto GetRowCoords_O(index_t base_offset)
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:156
remove_cvref_t< Policy_ > Policy
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:25
static constexpr bool UseSmoothQuant
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:45
typename Problem::Traits Traits
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:42
typename Problem::AccDataType AccDataType
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:32
typename Problem::AScaleDataType AScaleDataType
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:34
static CK_TILE_HOST_DEVICE auto GetACoord()
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:86
A fixed-size array container similar to std::array with additional utilities.
Definition: array.hpp:43
constexpr CK_TILE_HOST_DEVICE auto & at(index_t i)
Definition: array.hpp:110
Definition: integral_constant.hpp:13
Definition: sequence.hpp:49
Definition: functional.hpp:43