Composable Kernel: include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp Source File
fused_moegemm_pipeline_flatmm_uk.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

/*
This pipeline deals with a gemm (actually 2 gemms) where one operand is very small (token)
and the other very big (weight). The pipeline is therefore arranged so that all waves are
laid out along the gemm-N dim (gemm-M uses only 1 wave):

 <----- gemm-N ------>
 +----+----+----+----+
 | w0 | w1 | w2 | w3 | gemm-m
 +----+----+----+----+
*/
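// For example (illustrative numbers, not mandated by this header): with a 256-thread block and
// 64-lane waves, the four waves w0..w3 above each own one quarter of Block_N0 while sharing the
// same Block_M0 token rows, so the tiny token tile is reused by every wave and the large weight
// tile is streamed through in N-sliced pieces.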
template <typename Problem_, typename Policy_ = FusedMoeGemmPipelineFlatmmPolicy>
struct FusedMoeGemmPipeline_FlatmmUk
{
    using Problem = remove_cvref_t<Problem_>;
    using Policy  = remove_cvref_t<Policy_>;

    using BlockShape = typename Problem::BlockShape; // this is FusedMoeGemmShape

    using ADataType            = typename Problem::ADataType;
    using GDataType            = typename Problem::GDataType;
    using DDataType            = typename Problem::DDataType;
    using AccDataType          = typename Problem::AccDataType;
    using ODataType            = typename Problem::ODataType;
    using AScaleDataType       = typename Problem::AScaleDataType;
    using GScaleDataType       = typename Problem::GScaleDataType;
    using DScaleDataType       = typename Problem::DScaleDataType;
    using YSmoothScaleDataType = typename Problem::YSmoothScaleDataType;
    using TopkWeightDataType   = typename Problem::TopkWeightDataType;
    using IndexDataType        = typename Problem::IndexDataType;
    using YDataType            = typename Problem::YDataType;

    using Traits = typename Problem::Traits;

    static constexpr bool IsGateOnly          = Traits::IsGateOnly;
    static constexpr bool UseSmoothQuant      = Traits::UseSmoothQuant;
    static constexpr bool PadHiddenSize       = Traits::PadHiddenSize;
    static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;

    static constexpr index_t kAlignmentA = Policy::template GetAlignment_A<Problem>();
    static constexpr index_t kAlignmentG = Policy::template GetAlignment_G<Problem>();
    static constexpr index_t kAlignmentD = Policy::template GetAlignment_D<Problem>();
    static constexpr index_t kAlignmentO = Policy::template GetAlignment_O<Problem>();

    static constexpr index_t kBlockPerCu = []() {
        if constexpr(Problem::kBlockPerCu != -1)
            return Problem::kBlockPerCu;
        else
        {
            // minimize occupancy
            return 2;
        }
    }();

    static constexpr const char* name = "flatmm_uk";

    static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
    {
#if 1
        constexpr index_t smem_0 = Policy::template GetUK_0<Problem>().GetSmemSize();
        constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize();
        constexpr index_t smem_bridge =
            BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType);
        return max(smem_0 + smem_1, smem_bridge);
#else
        // kept here on purpose in case we hit a regression
        return 65536;
#endif
    }

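    // Rough sizing example (numbers assumed for illustration): with Block_M0 = 32, Block_N0 = 128
    // and a 2-byte YDataType, the bridge tile needs 32 * 128 * 2 = 8192 bytes of LDS. The
    // returned size is the larger of that and the combined micro-kernel footprint
    // (smem_0 + smem_1), since the bridge tile only lives between the two micro-kernels.
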
    // this is the thread-offset along row/col
    static CK_TILE_HOST_DEVICE auto GetACoord()
    {
        constexpr auto a_dist = Policy::template MakeGlobalTileDistribution_A<Problem>();
        const auto a_coord    = a_dist.calculate_index();
        return a_coord;
    }

    // this is the thread-offset along row/col
    static CK_TILE_HOST_DEVICE auto GetOCoord()
    {
        constexpr auto o_dist = Policy::template MakeOGlobalTileDistribution<Problem>();
        const auto o_coord    = o_dist.calculate_index();
        return o_coord;
    }

    constexpr CK_TILE_DEVICE auto GetNumRowCoords_A()
    {
        constexpr index_t KLans   = BlockShape::Block_K0 / kAlignmentA;
        constexpr index_t MLans   = BlockShape::BlockSize / KLans;
        constexpr index_t MRepeat = BlockShape::Block_M0 / MLans;

        return MRepeat;
    }

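    // Worked example with assumed numbers (not mandated by this header): Block_K0 = 128 and
    // kAlignmentA = 8 give KLans = 16 lanes per row; with BlockSize = 256 that leaves
    // MLans = 256 / 16 = 16 rows per pass, so Block_M0 = 32 yields MRepeat = 2 row coordinates
    // per thread.
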
    // TODO: properly support scatter/gather
    CK_TILE_DEVICE auto GetRowCoords_A(index_t base_offset)
    {
        constexpr index_t KLans   = BlockShape::Block_K0 / kAlignmentA;
        constexpr index_t MLans   = BlockShape::BlockSize / KLans;
        constexpr index_t MRepeat = BlockShape::Block_M0 / MLans;

        auto base_coord = threadIdx.x / KLans + base_offset;

        array<index_t, MRepeat> coords;
        static_for<0, MRepeat, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLans; });

        return coords;
    }

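    // The coordinates produced here index into the *sorted* token list (sorted_token_ids_ptr),
    // not into the token matrix itself; GetRowID below performs that second level of indirection.
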
    template <typename ROW_COORDS>
    CK_TILE_DEVICE auto GetRowID(const ROW_COORDS coords, const IndexDataType* sorted_token_ids_ptr)
    {
        constexpr index_t n_size = coords.size();

        array<index_t, n_size> row_ids;
        static_for<0, n_size, 1>{}([&](auto i) {
            row_ids.at(i) = sorted_token_ids_ptr[coords[i]];
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
            row_ids.at(i) &= 0xffffff;
#endif
        });

        return row_ids;
    }

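    // GetRowID translates positions in the sorted token list into actual token row indices.
    // When CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID is enabled, the sorting reference apparently
    // packs extra information into the upper bits of each id, so only the low 24 bits are kept
    // as the row index.
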
    template <typename ROW_COORDS>
    CK_TILE_DEVICE auto GetWeightScale(const ROW_COORDS coords,
                                       const TopkWeightDataType* sorted_weight_ptr)
    {
        constexpr index_t n_size = coords.size();

        array<TopkWeightDataType, n_size> w;
        static_for<0, n_size, 1>{}([&](auto i) { w.at(i) = sorted_weight_ptr[coords[i]]; });

        return w;
    }

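    // The gathered weights are the per-token top-k routing weights; they are later passed to the
    // second micro-kernel (see w_scale in operator()) so the final gemm output is scaled per row
    // before it is written out.
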
    // TODO: this row id is before the shuffle/atomic, need to use the acc distribution
    CK_TILE_DEVICE auto GetRowCoords_O(index_t base_offset)
    {
        constexpr index_t MLanes   = BlockShape::Warp_M1;
        constexpr index_t Repeat_M = BlockShape::Repeat_M1;

        auto base_coord = threadIdx.x % MLanes + base_offset;

        array<index_t, Repeat_M> coords;
        static_for<0, Repeat_M, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLanes; });

        return coords;
    }

    template <typename Karg>
    CK_TILE_DEVICE auto operator()(const Karg& kargs,
                                   CK_TILE_LDS_ADDR void* smem,
                                   index_t sorted_tile_id,
                                   index_t intermediate_tile_id)
    {
        constexpr index_t hidden_ratio_0 = IsGateOnly ? 1 : 2;
        ck_tile::index_t shared_intermediate_size_0 =
            kargs.intermediate_size * hidden_ratio_0; // total gate+up
        ck_tile::index_t shared_intermediate_size_1 = kargs.intermediate_size;

        // after weight shuffling, gate-only: [nr0, kr0, w0], gate+up: [nr0_gate + nr0_up, kr0, w0]

        index_t nr_0 = shared_intermediate_size_0 / BlockShape::Warp_N0; // divide N in W
        index_t kr_0 = kargs.hidden_size / BlockShape::Warp_K0;          // divide K in W
        index_t nr_1 = kargs.hidden_size / BlockShape::Warp_N1;
        index_t kr_1 = shared_intermediate_size_1 / BlockShape::Warp_K1;

        const IndexDataType expert_id = __builtin_amdgcn_readfirstlane(
            reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
        index_t expert_stride_0 = shared_intermediate_size_0 * kargs.hidden_size;
        index_t expert_stride_1 = shared_intermediate_size_1 * kargs.hidden_size;
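
        // Illustrative example (sizes assumed, not taken from this file): with hidden_size = 4096,
        // intermediate_size = 8192 and gate+up (hidden_ratio_0 = 2),
        // shared_intermediate_size_0 = 16384; if Warp_N0 = 32 and Warp_K0 = 64, then
        // nr_0 = 16384 / 32 = 512 and kr_0 = 4096 / 64 = 64, i.e. the shuffled weight is addressed
        // as nr * kr tiles of Warp_N0 x Warp_K0 elements each.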

        // nr*kr*w
        index_t interm_idx_nr0 = __builtin_amdgcn_readfirstlane(
            intermediate_tile_id *
            BlockShape::Block_Nr0); // intermediate_tile_id * Block_N / (N in W)

        index_t interm_idx_kr1 = __builtin_amdgcn_readfirstlane(
            intermediate_tile_id *
            BlockShape::Block_Kr1); // same intermediate slice, along the gemm-K dim of the 2nd gemm

        auto row_coords_a = GetRowCoords_A(sorted_tile_id * BlockShape::Block_M0);
        auto row_ids_a    = GetRowID(
            row_coords_a, reinterpret_cast<const IndexDataType*>(kargs.sorted_token_ids_ptr));
        auto a_coords = generate_tuple(
            [&](auto i) {
                return row_ids_a[i] * kargs.stride_token +
                       threadIdx.x % (BlockShape::Block_K0 / kAlignmentA) * kAlignmentA;
            },
            number<row_ids_a.size()>{});
        auto a_res =
            make_wave_buffer_resource(reinterpret_cast<const ADataType*>(kargs.a_ptr),
                                      kargs.num_tokens * kargs.stride_token * sizeof(ADataType));
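
        // a_res is a wave buffer resource (128-bit descriptor) spanning the whole token activation
        // matrix; a_coords above gathers one row offset per repeat, so the A loads behave like a
        // row-gather of the sorted tokens. Rows whose id points past num_tokens fall outside the
        // descriptor range and are presumably returned as zero by the hardware bounds check.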

        auto make_gu_win = [&](const auto* ptr_) {
            auto view_ = make_naive_tensor_view<address_space_enum::global>(
                ptr_,
                make_tuple(nr_0, kr_0, BlockShape::Block_W0),
                make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
                number<kAlignmentG>{},
                number<1>{});

            auto win_ = make_tile_window_linear_raw(
                view_,
                make_tuple(number<BlockShape::Block_Nr0>{},
                           number<BlockShape::Block_Kr0>{},
                           number<BlockShape::Block_W0>{}),
                {0, 0, 0},
                Policy::template MakeGlobalTileDistribution_G<Problem>());
            return win_;
        };

        const GDataType* gu_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
                                  static_cast<long_index_t>(expert_id) * expert_stride_0 +
                                  interm_idx_nr0 * kr_0 * BlockShape::Block_W0;

        auto g_win = make_gu_win(gu_ptr);
        // Note: gu swizzled, [nr_u+nr_g, kr, w], hence the base offset to "up" is just interm*hidden
        auto u_win = make_gu_win(gu_ptr + kargs.intermediate_size * kargs.hidden_size);

        auto g_res = g_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
        auto u_res = u_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
        auto g_coords = generate_tuple([&](auto i) { return g_win.cached_coords_[i].get_offset(); },
                                       number<decltype(g_win)::NumAccess_NonLinear>{});

        const auto d_win = [&]() {
            const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
                                     static_cast<long_index_t>(expert_id) * expert_stride_1 +
                                     interm_idx_kr1 * BlockShape::Block_W1;
            // note interm_idx_nr0 is along the gemm-k dim of the 2nd gemm

            const auto d_view_ = make_naive_tensor_view<address_space_enum::global>(
                d_ptr,
                make_tuple(nr_1, kr_1, BlockShape::Block_W1),
                make_tuple(kr_1 * BlockShape::Block_W1, BlockShape::Block_W1, 1),
                number<kAlignmentD>{},
                number<1>{});

            const auto d_window_ = make_tile_window_linear_raw(
                d_view_,
                make_tuple(number<BlockShape::Block_Nr1>{},
                           number<BlockShape::Block_Kr1>{},
                           number<BlockShape::Block_W1>{}),
                {0, 0, 0},
                Policy::template MakeGlobalTileDistribution_D<Problem>());
            return d_window_;
        }();
        auto d_res = d_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;

        // TODO: load D order is N0.K0...127, N64.K0...127, N0.K128...255, N64.K128...255
        //       block-k=512, block-n=128
        // wg |<----- W_ ----->|
        // Nr(2)*Nw(4)* Kr *Kr0(4)*Kr1(4) * [Kl(4)*Nl(16)*Kv(8)]->one issue
        // y p y y p p y
        // 1 2 0(imm)
        auto d_coords = [&]() {
            constexpr index_t Nr_          = 2;
            constexpr index_t Nw_          = 4;
            constexpr index_t Kr0_         = 4;
            constexpr index_t Kr1_         = 4;
            constexpr index_t Kl_          = 4;
            constexpr index_t Nl_          = 16;
            constexpr index_t Kv_          = 8;
            constexpr index_t W_           = Kl_ * Nl_ * Kv_;
            constexpr index_t num_offsets_ = Nr_ * Kr0_;
            index_t base_os_ = (threadIdx.x % 64) * Kv_ +
                               (threadIdx.x / 64) * shared_intermediate_size_1 * Nl_; // Kr0_ * Kr1_ * W_;
            return generate_tuple(
                [&](auto i) {
                    constexpr auto i_nr_  = number<i % Nr_>{};
                    constexpr auto i_kr0_ = number<i / Nr_>{};

                    return i_nr_ * shared_intermediate_size_1 * Nw_ * Nl_ + i_kr0_ * Kr1_ * W_ +
                           base_os_;
                },
                number<num_offsets_>{});
        }();

        auto o_coords = generate_tuple(
            [&](auto i) {
                return row_ids_a[i] * kargs.stride_token +
                       threadIdx.x % (BlockShape::Block_N1 / kAlignmentO) * kAlignmentO;
            },
            number<row_ids_a.size()>{});

        auto o_flags =
            generate_tuple([&](auto i) { return cmp_lt_to_exec(row_ids_a[i], kargs.num_tokens); },
                           number<row_ids_a.size()>{});
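
        // o_flags is expected to act as a per-row execution mask: cmp_lt_to_exec compares the
        // gathered token id against num_tokens, so rows that belong to padded slots in the sorted
        // token list are masked off when uk_1 scatters the final output.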

        auto bridge_sst_win = [&]() {
            constexpr auto desc_ = Policy::template MakeBridgeLdsStoreForUKDesc<Problem>();
            constexpr auto dist_ = Policy::template GetUK_0<Problem>().MakeCBlockDist();
            return make_tile_window_linear(make_tensor_view<address_space_enum::lds>(
                                               reinterpret_cast<YDataType*>(smem), desc_),
                                           desc_.get_lengths(),
                                           {0, 0},
                                           dist_);
        }();
        auto o_res =
            make_wave_buffer_resource(reinterpret_cast<const ODataType*>(kargs.o_ptr),
                                      kargs.num_tokens * kargs.stride_token * sizeof(ODataType));

        auto row_coords_o = GetRowCoords_O(sorted_tile_id * BlockShape::Block_M0);
        auto w_scale      = GetWeightScale(
            row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr));

        auto uk_0 = Policy::template GetUK_0<Problem>();

        auto y_pre = [&]() {
            if constexpr(IsGateOnly)
            {
                auto acc_0 = uk_0(a_res,
                                  a_coords,
                                  g_res,
                                  g_coords,
                                  smem,
                                  kargs.hidden_size,
                                  BlockShape::Block_K0, // tile offset for A matrix each unroll
                                  BlockShape::Block_Kr0 *
                                      BlockShape::Block_W0); // tile offset for B matrix each unroll

                sweep_tile(
                    acc_0,
                    [&](auto idx0, auto idx1) {
                        fp32x2_t v_{acc_0(idx0), acc_0(idx1)};
                        typename Problem::GateActivation{}(v_, v_);
                        acc_0(idx0) = v_.x;
                        acc_0(idx1) = v_.y;
                    },
                    sequence<1, 2>{});

                return cast_tile<YDataType>(acc_0);
            }
            else
            {
                uint32x8_t gu_res;
                gu_res[0] = g_res[0];
                gu_res[1] = g_res[1];
                gu_res[2] = g_res[2];
                gu_res[3] = g_res[3];
                gu_res[4] = u_res[0];
                gu_res[5] = u_res[1];
                gu_res[6] = u_res[2];
                gu_res[7] = u_res[3];

                auto acc_0 = uk_0(a_res,
                                  a_coords,
                                  gu_res,
                                  g_coords,
                                  smem,
                                  kargs.hidden_size,
                                  BlockShape::Block_K0,                          // tile offset for A matrix each unroll
                                  BlockShape::Block_Kr0 * BlockShape::Block_W0, // tile offset for B matrix each unroll
                                  bool_constant<true>{});

                sweep_tile(
                    acc_0.at(number<0>{}),
                    [&](auto idx0, auto idx1) {
                        fp32x2_t v_{acc_0.at(number<0>{})(idx0), acc_0.at(number<0>{})(idx1)};
                        typename Problem::GateActivation{}(v_, v_);
                        acc_0.at(number<0>{})(idx0) = v_.x;
                        acc_0.at(number<0>{})(idx1) = v_.y;
                    },
                    sequence<1, 2>{});

                auto reduced_acc_0 =
                    tile_elementwise_in([&](const auto& a_, const auto& b_) { return a_ * b_; },
                                        acc_0.at(number<0>{}),
                                        acc_0.at(number<1>{}));

                return cast_tile<YDataType>(reduced_acc_0);
            }
        }();
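
        // At this point y_pre holds the bridge activations of the first gemm: act(gate) in the
        // gate-only path, or act(gate) * up in the gate+up path, already cast to YDataType. It is
        // staged through LDS below so the second micro-kernel (uk_1) can consume it from shared
        // memory.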

        block_sync_lds();

        store_tile(bridge_sst_win, y_pre);
        block_sync_lds();

        auto uk_1 = Policy::template GetUK_1<Problem>();
        uk_1(d_res,
             d_coords,
             o_res,
             o_coords,
             o_flags,
             smem,
             kargs.hidden_size, // total n number
             w_scale,
             BlockShape::Block_Nr1 * kr_1 * BlockShape::Block_W1, // along N
             BlockShape::Block_N1);                               // along N
    }
};

} // namespace ck_tile
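
// Usage sketch (illustrative only; the kernel-side names below are assumptions, not part of this
// header). A fused-MoE kernel would typically instantiate the pipeline with a Problem type that
// carries the tile shape and data types, size its LDS from GetSmemSize(), and invoke it once per
// (sorted tile, intermediate tile) pair, e.g.:
//
//   using Pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmUk<MyMoeProblem>;
//   __shared__ char smem[Pipeline::GetSmemSize()];
//   Pipeline{}(kargs, smem, sorted_tile_id, intermediate_tile_id);
//
// where MyMoeProblem, kargs and the two tile ids come from the surrounding kernel.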