/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp Source File
mx_flatmm_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <string>
8 
9 #include "ck_tile/core.hpp"
10 #include "ck_tile/ops/common.hpp"
11 
13 
14 namespace ck_tile {
15 
16 template <typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
17 struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_>
18 {
20 
// Workgroup size and persistent-kernel switch, both taken from the main-loop pipeline.
31  static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize;
32  static constexpr bool UsePersistentKernel = FlatmmPipeline::UsePersistentKernel;
33 
36  // Below type is actually accumulation data type - the output of block GEMM.
38 
// Per-XDL thread counts along M/N, derived from BlockGemmShape::WarpTile dims 0 and 1; the
// K count is 64 / MThreadPerXdl. These shape the scale tensor descriptors built in
// MakeGemmTensorViews.
// NOTE(review): the literal 64 presumably is the wavefront size — confirm before targeting
// wave32 hardware.
39  static constexpr int MThreadPerXdl = BlockGemmShape::WarpTile::at(number<0>{});
40  static constexpr int NThreadPerXdl = BlockGemmShape::WarpTile::at(number<1>{});
41  static constexpr int KThreadPerXdl = 64 / MThreadPerXdl;
42 
45 
// How many XDL tiles share one packed scale element along M/N/K. They size the scale "packs"
// in MakeGemmTensorViews and divide the scale-window origins in MakeGemmTileWindows.
46  static constexpr int MXdlPack = FlatmmPipeline::MXdlPack;
47  static constexpr int NXdlPack = FlatmmPipeline::NXdlPack;
48  static constexpr int KXdlPack = FlatmmPipeline::KXdlPack;
49 
// Number of extra D tensors consumed by the epilogue.
50  static constexpr index_t NumDTensor = DsDataType::size();
51 
// Compile-time indices into the 6-tuples returned by the MakeGemm* helpers:
// I0 = A, I1 = flat B, I2 = Ds, I3 = E, I4 = A scale, I5 = B scale.
52  static constexpr auto I0 = number<0>();
53  static constexpr auto I1 = number<1>();
54  static constexpr auto I2 = number<2>();
55  static constexpr auto I3 = number<3>();
56  static constexpr auto I4 = number<4>();
57  static constexpr auto I5 = number<5>();
58 
// Every D tensor needs both a layout and a data type.
59  static_assert(DsLayout::size() == DsDataType::size(),
60  "The size of DsLayout and DsDataType should be the same");
61  // using KernelArgs = FlatmmKernelArgs<DsLayout::size()>;
62 
63  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
64  {
65  // clang-format off
66  return concat('_', "mx_flatmm_gemm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
67  // clang-format on
68  }
69 
70  template <class ScaleM, class ScaleN>
71  CK_TILE_HOST static constexpr auto
72  GridSize(const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs)
73  {
74  if constexpr(UsePersistentKernel)
75  {
76  hipDeviceProp_t prop;
77  int deviceId = 0; // default device
78 
79  constexpr int block_size = MXFlatmmKernel::BlockSize().x;
80  int dync_smem_size = 0;
81  int maxActiveBlocksPerCU = 0;
82 
83  if(hipGetDeviceProperties(&prop, deviceId) != hipSuccess)
84  throw std::runtime_error(std::string("hipGetDeviceProperties failed: ") +
85  hipGetErrorName(hipGetLastError()));
86 
87  if(hipOccupancyMaxActiveBlocksPerMultiprocessor(
88  &maxActiveBlocksPerCU,
89  reinterpret_cast<void*>(
90  kentry<1, MXFlatmmKernel, remove_cvref_t<decltype(kargs)>>),
91  block_size,
92  dync_smem_size) != hipSuccess)
93  throw std::runtime_error(
94  std::string("hipOccupancyMaxActiveBlocksPerMultiprocessor failed: ") +
95  hipGetErrorName(hipGetLastError()));
96 
97  const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
98  const int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
99 
100  // std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
101  // << ", persistent_block_size: " << persistent_block_size
102  // << ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl;
103 
104  if(kargs.k_batch != 1)
105  throw std::runtime_error("Wrong! k_batch != 1 not supported in persistent kernel");
106  return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kargs.k_batch);
107  }
108  else
109  {
110  return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch);
111  }
112  }
113 
// Split-K slice description reused from the base kernel (`Underlying` — presumably the
// FlatmmKernel base; its alias is not shown here). This file reads splitted_k (the K extent
// of one slice) and a_k_split_offset / b_k_split_offset (the per-slice pointer offsets).
114  using SplitKBatchOffset = typename Underlying::SplitKBatchOffset;
115 
// Builds the global-memory tensor views for one tile computation. They are returned as a
// tuple in the order (A, flat B, Ds, E, A scale, B scale); the I0..I5 constants index it.
//  - A:      M x splitted_k (RowMajor) or splitted_k x M (otherwise), leading stride stride_A.
//  - flat B: kFlatN x kFlatK, row stride kFlatK.
//  - Ds:     one view per D tensor, M x N (RowMajor) or N x M (otherwise), stride stride_Ds[i].
//  - E:      M x N (RowMajor) or N x M (otherwise), created with DstInMemOp as its
//            in-memory operation.
//  - scales: int32_t views over kargs.scale_m_ptr / kargs.scale_n_ptr.
116  template <memory_operation_enum DstInMemOp = memory_operation_enum::set, class KernelArgs>
117  CK_TILE_DEVICE static auto
119  const BDataType* b_flat_ptr,
120  const std::array<const void*, NumDTensor>& ds_ptr,
121  EDataType* e_ptr,
122  const KernelArgs& kargs,
123  const SplitKBatchOffset& splitk_batch_offset)
124  {
// A covers only this split-K slice along K (splitted_k); a_ptr is expected to already point
// at the slice start.
125  const auto& a_tensor_view = [&]() {
126  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
127  {
128  return make_naive_tensor_view<address_space_enum::global>(
129  a_ptr,
130  make_tuple(kargs.M, splitk_batch_offset.splitted_k),
131  make_tuple(kargs.stride_A, 1),
132  number<FlatmmPipeline::GetVectorSizeA()>{},
133  number<1>{});
134  }
135  else
136  {
137  return make_naive_tensor_view<address_space_enum::global>(
138  a_ptr,
139  make_tuple(splitk_batch_offset.splitted_k, kargs.M),
140  make_tuple(kargs.stride_A, 1),
141  number<FlatmmPipeline::GetVectorSizeA()>{},
142  number<1>{});
143  }
144  }();
145 
// Flat-B geometry: rows of kFlatK = K * WarpTile::at(I1) elements, and
// kFlatN = N * K / kFlatK rows. For K > 0 this equals N / WarpTile::at(I1) (integer division).
// NOTE(review): `kargs.N * kargs.K` is evaluated in index_t (int32_t) before the division and
// can overflow for large N*K. `kargs.N / BlockGemmShape::WarpTile::at(I1)` gives the same
// quotient without the intermediate product.
146  index_t kFlatK = kargs.K * BlockGemmShape::WarpTile::at(I1);
147  index_t kFlatN = kargs.N * kargs.K / kFlatK;
148 
149  const auto& b_flat_tensor_view = [&]() {
150  return make_naive_tensor_view<address_space_enum::global>(
151  b_flat_ptr,
152  make_tuple(kFlatN, kFlatK),
153  make_tuple(kFlatK, 1),
154  number<FlatmmPipeline::GetVectorSizeB()>{},
155  number<1>{});
156  }();
157 
// One view per D tensor; ds_ptr[i] is type-erased and cast back to the i-th DsDataType.
158  const auto& ds_tensor_view = generate_tuple(
159  [&](auto i) {
160  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
161  using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
162  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
163  {
164  return make_naive_tensor_view<address_space_enum::global>(
165  static_cast<const DDataType_*>(ds_ptr[i]),
166  make_tuple(kargs.M, kargs.N),
167  make_tuple(kargs.stride_Ds[i], 1),
168  number<EpiloguePipeline::GetVectorSizeD(i)>{},
169  number<1>{});
170  }
171  else
172  {
173  return make_naive_tensor_view<address_space_enum::global>(
174  static_cast<const DDataType_*>(ds_ptr[i]),
175  make_tuple(kargs.N, kargs.M),
176  make_tuple(kargs.stride_Ds[i], 1),
177  number<EpiloguePipeline::GetVectorSizeD(i)>{},
178  number<1>{});
179  }
180  },
182 
// E view; the non-RowMajor branch uses a vector size of 1 (scalar writes).
183  // TODO: enable vector write for C in ColMajor
184  const auto& e_tensor_view = [&]() {
185  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
186  {
187  return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
188  e_ptr,
189  make_tuple(kargs.M, kargs.N),
190  make_tuple(kargs.stride_E, 1),
191  number<EpiloguePipeline::GetVectorSizeC()>{},
192  number<1>{});
193  }
194  else
195  {
196  return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
197  e_ptr,
198  make_tuple(kargs.N, kargs.M),
199  make_tuple(kargs.stride_E, 1),
200  number<1>{},
201  number<1>{});
202  }
203  }();
204 
205  auto scale_a = kargs.scale_m_ptr;
206  auto scale_b = kargs.scale_n_ptr;
207 
// One MX scale per BlockScaleSize (= 32) consecutive K elements. The value is hard-coded
// here and again in MakeGemmTileWindows, so keep the two in sync. The trailing comment
// suggests it was meant to come from the scale type's GranularityK.
// scale_packs_m / scale_packs_n round up, but scale_packs_k truncates.
// NOTE(review): presumably K must be a multiple of BlockScaleSize * KXdlPack * KThreadPerXdl
// — confirm; otherwise the trailing K scales are not covered by the scale views.
208  static constexpr int BlockScaleSize = 32; // decltype(scale_n)::GranularityK;
209  const auto&& scale_packs_m = integer_divide_ceil(kargs.M, (MXdlPack * MThreadPerXdl));
210  const auto&& scale_packs_n = integer_divide_ceil(kargs.N, (NXdlPack * NThreadPerXdl));
211  const auto&& scale_packs_k = kargs.K / BlockScaleSize / (KXdlPack * KThreadPerXdl);
212 
// Both scale views start from a packed descriptor of shape
// (packs along M or N, scale_packs_k, KThreadPerXdl, M/NThreadPerXdl). That descriptor is
// remapped by transform_tensor_descriptor, and the data is read as int32_t elements.
213  // A scale tensor view
214  const auto& scale_a_tensor_view = [&]() {
215  // Pack 2x2 e8m0 over M/K dimension into 1 int32_t to trigger dword width load
216  const auto scale_a_naive_desc = make_naive_tensor_descriptor_packed(
217  make_tuple(scale_packs_m, scale_packs_k, KThreadPerXdl, MThreadPerXdl));
218  const auto scale_a_desc = transform_tensor_descriptor(
219  scale_a_naive_desc,
224 
225  return make_tensor_view<address_space_enum::global>(
226  reinterpret_cast<const int32_t*>(scale_a.ptr), scale_a_desc);
227  }();
228 
229  // B scale tensor view
230  const auto& scale_b_tensor_view = [&]() {
231  const auto scale_b_navie_desc = make_naive_tensor_descriptor_packed(
232  make_tuple(scale_packs_n, scale_packs_k, KThreadPerXdl, NThreadPerXdl));
233  const auto scale_b_desc = transform_tensor_descriptor(
234  scale_b_navie_desc,
239 
240  return make_tensor_view<address_space_enum::global>(
241  reinterpret_cast<const int32_t*>(scale_b.ptr), scale_b_desc);
242  }();
243 
// Order must match the I0..I5 indices used by MakeGemmPadViews / MakeGemmTileWindows.
244  return make_tuple(a_tensor_view,
245  b_flat_tensor_view,
246  ds_tensor_view,
247  e_tensor_view,
248  scale_a_tensor_view,
249  scale_b_tensor_view);
250  }
251 
// Wraps the A view, every D view and the E view in pad_tensor_view, using a layout-dependent
// padding per tensor. The flat-B view (I1) and both scale views (I4, I5) pass through
// unchanged. Returns a 6-tuple in the same (A, flat B, Ds, E, A scale, B scale) order as the
// input.
252  template <typename TensorView>
253  CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
254  {
255  const auto& a_pad_view = [&]() {
256  const auto& a_tensor_view = views.at(I0);
257  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
258  {
259  return pad_tensor_view(a_tensor_view,
263  }
264  else
265  {
266  return pad_tensor_view(a_tensor_view,
270  }
271  }();
272 
// Flat B is not padded.
273  const auto& b_flat_tensor_view = views.at(I1);
274 
275  const auto& ds_pad_view = generate_tuple(
276  [&](auto i) {
277  const auto& d_tensor_view = views.at(I2);
278  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
279  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
280  {
281  return pad_tensor_view(d_tensor_view[i],
285  }
286  else
287  {
288  return pad_tensor_view(d_tensor_view[i],
292  }
293  },
295 
296  // TODO: enable vector write for C in ColMajor
297  const auto& e_pad_view = [&]() {
298  const auto& e_tensor_view = views.at(I3);
299  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
300  {
301  return pad_tensor_view(e_tensor_view,
305  }
306  else
307  {
308  return pad_tensor_view(e_tensor_view,
312  }
313  }();
314 
315  return make_tuple(
316  a_pad_view, b_flat_tensor_view, ds_pad_view, e_pad_view, views.at(I4), views.at(I5));
317  }
318 
// Opens one tile window per padded view at this workgroup's output tile and returns them in
// (A, flat B, Ds, E, A scale, B scale) order.
// i_m / i_n are element offsets of the tile origin along M and N; operator() passes
// iM * MPerBlock and iN * NPerBlock.
319  template <typename PadView>
320  CK_TILE_DEVICE static auto
321  MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
322  {
323  const auto& a_pad_view = views.at(I0);
324  const auto& b_flat_pad_view = views.at(I1);
325  const auto& ds_pad_view = views.at(I2);
326  const auto& e_pad_view = views.at(I3);
327 
// A window starts at (i_m, 0), or at (0, i_m) for the transposed layout; K always starts at 0.
328  const auto& a_block_window = [&]() {
329  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
330  {
331  return make_tile_window(a_pad_view,
334  {i_m, 0});
335  }
336  else
337  {
338  return make_tile_window(a_pad_view,
341  {0, i_m});
342  }
343  }();
344 
// Each flat-B row spans WarpTile::at(I1) columns of N, so the row origin is i_n divided by
// that extent.
345  const auto& b_flat_block_window =
346  make_tile_window(b_flat_pad_view,
349  {static_cast<int>(i_n / BlockGemmShape::WarpTile::at(I1)), 0});
350 
351  const auto ds_block_window = generate_tuple(
352  [&](auto i) {
353  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
354  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
355  {
356  return make_tile_window(ds_pad_view[i],
359  {i_m, i_n});
360  }
361  else
362  {
363  return make_tile_window(ds_pad_view[i],
366  {i_n, i_m});
367  }
368  },
370 
371  auto e_block_window = make_tile_window(
372  e_pad_view,
374  {i_m, i_n});
375 
// Must match BlockScaleSize in MakeGemmTensorViews (one scale per 32 K elements).
376  static constexpr int BlockScaleSize = 32;
377 
// Scale windows cover KPerBlock / (BlockScaleSize * KXdlPack) packed scale columns. Their
// row origin is the tile origin divided by the XDL pack factor along M (A) or N (B).
378  auto scale_a_block_window = make_tile_window(
379  views.at(I4),
381  number<TilePartitioner::KPerBlock / (BlockScaleSize * KXdlPack)>{}),
382  {i_m / MXdlPack, 0});
383 
384  auto scale_b_block_window = make_tile_window(
385  views.at(I5),
387  number<TilePartitioner::KPerBlock / (BlockScaleSize * KXdlPack)>{}),
388  {i_n / NXdlPack, 0});
389 
390  return make_tuple(a_block_window,
391  b_flat_block_window,
392  ds_block_window,
393  e_block_window,
394  scale_a_block_window,
395  scale_b_block_window);
396  }
397 
398  template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
399  CK_TILE_DEVICE static void
400  RunFlatmm(const ADataType* a_ptr,
401  const BDataType* b_flat_ptr,
402  const std::array<const void*, NumDTensor>& ds_ptr,
403  EDataType* e_ptr,
404  void* smem_ptr_ping,
405  void* smem_ptr_pong,
406  const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs,
407  const SplitKBatchOffset& splitk_batch_offset,
408  const index_t block_idx_m,
409  const index_t block_idx_n)
410  {
411  // Create Gemm tensor views, pad views and tile windows
412  const auto& gemm_tensor_views_tuple =
413  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
414  a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
415  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
416  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
417 
418  const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
419 
420  // Run GEMM cooperatively by whole workgroup.
421  const auto& a_block_window = gemm_tile_windows.at(I0);
422  const auto& b_flat_block_window = gemm_tile_windows.at(I1);
423  const auto& d_block_window = gemm_tile_windows.at(I2);
424  const auto& scale_a_block_window = gemm_tile_windows.at(I4);
425  const auto& scale_b_block_window = gemm_tile_windows.at(I5);
426 
427  static_assert(ScaleM::GranularityK == ScaleN::GranularityK // have the same granK
428  || ScaleM::GranularityMN == -1 // or ScaleA is disable
429  || ScaleN::GranularityMN == -1, // or ScaleB is disable
430  "ScaleM and ScaleN should have the same GranularityK");
431  constexpr bool DoEpiScale =
432  (ScaleM::GranularityMN != -1 && ScaleM::GranularityK == 0) || // per token
433  (ScaleN::GranularityMN != -1 && ScaleN::GranularityK == 0); // per channel
434 
435  auto a_block_window_with_distr =
436  ck_tile::make_tile_window(a_block_window.get_bottom_tensor_view(),
437  a_block_window.get_window_lengths(),
438  a_block_window.get_window_origin(),
439  FlatmmPipeline::GetADramTileDistribution());
440  const auto& c_block_tile = FlatmmPipeline{}(a_block_window_with_distr,
441  b_flat_block_window,
442  scale_a_block_window,
443  scale_b_block_window,
444  num_loop,
445  smem_ptr_ping,
446  smem_ptr_pong);
447 
448  // Run Epilogue Pipeline
449  if constexpr(DoEpiScale)
450  {
451  auto& c_block_window = gemm_tile_windows.at(I3);
452  EpiloguePipeline{}(c_block_window,
453  c_block_tile,
454  d_block_window,
455  smem_ptr_ping,
456  kargs.scale_m_ptr + block_idx_m,
457  kargs.scale_n_ptr + block_idx_n);
458  }
459  else if(UseDefaultScheduler || (get_warp_id() == 0))
460  {
461  // Run Epilogue Pipeline
462  auto& c_block_window = gemm_tile_windows.at(I3);
463  EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
464  }
465  }
466 
// Kernel entry point: computes the output tile(s) assigned to this workgroup.
// partition_idx defaults to blockIdx.x. With UsePersistentKernel the do/while loop is a
// grid-stride loop: the workgroup handles tiles partition_idx, partition_idx + gridDim.x, ...
// while they are below the total tile count. Otherwise the body runs exactly once.
467  template <class ScaleM, class ScaleN>
468  CK_TILE_DEVICE void operator()(FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()> kargs,
469  int partition_idx = blockIdx.x) const
470  {
471  int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
472 
473  do
474  {
// Map the linear tile index to (tile row, tile column), then to element offsets.
// amd_wave_read_first_lane presumably makes the offsets wave-uniform — confirm.
475  const auto [iM, iN] =
476  TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
477  const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
478  const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
479 
// NOTE(review): splitk_batch_offset depends only on kargs, yet it is rebuilt on every
// persistent iteration; it could be hoisted out of the loop.
480  const SplitKBatchOffset splitk_batch_offset(kargs);
// Base pointers for this split-K slice. The offsets are divided by APackedSize / BPackedSize,
// presumably to convert element offsets for packed (sub-byte) data types — confirm.
481  // options
482  const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr) +
483  splitk_batch_offset.a_k_split_offset / APackedSize;
484  const BDataType* b_flat_ptr = static_cast<const BDataType*>(kargs.b_ptr) +
485  splitk_batch_offset.b_k_split_offset / BPackedSize;
486  EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
487 
// Two statically sized LDS buffers used by RunFlatmm's main loop; the ping buffer is also
// handed to the epilogue.
488  // allocate LDS
489  __shared__ char smem_ptr_ping[Underlying::GetSmemPingSize()];
490  __shared__ char smem_ptr_pong[Underlying::GetSmemPongSize()];
491 
// NOTE(review): a non-dependent static_assert(false) in a discarded if-constexpr branch is
// only guaranteed well-formed since C++23 (P2593). Older compilers may reject it even when
// this branch is discarded — confirm with the toolchain in use.
492  if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
493  EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
495  {
// With a single wave group the default scheduler is used (UseDefaultScheduler = true).
496  constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
497  RunFlatmm<ScaleM, ScaleN, scheduler_type>(a_ptr,
498  b_flat_ptr,
499  kargs.ds_ptr,
500  e_ptr,
501  smem_ptr_ping,
502  smem_ptr_pong,
503  kargs,
504  splitk_batch_offset,
505  i_m,
506  i_n);
507  }
508  else
509  {
510  static_assert(false,
511  "Unimplemented: atomic_add with odd vector size for fp16/bf16");
512  }
513  partition_idx += gridDim.x;
514  } while(UsePersistentKernel && partition_idx < total_work_tile_cnt);
515  }
516 };
517 
518 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:35
__global__ void kentry(Args... args)
Definition: kernel_launch.hpp:22
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1615
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition: tensor_descriptor.hpp:203
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_descriptor_packed(const tuple< Lengths... > &lengths, number< GuaranteedLastDimensionVectorLength >=number<-1 >{})
Definition: tensor_descriptor.hpp:371
int32_t int32_t
Definition: integer.hpp:10
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE T min(T x)
Definition: math.hpp:210
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
__device__ X atomic_add(X *p_dst, const X &x)
Definition: flatmm_kernel.hpp:229
Definition: flatmm_kernel.hpp:249
static constexpr CK_TILE_HOST auto BlockSize()
Definition: flatmm_kernel.hpp:330
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemPingSize()
Definition: flatmm_kernel.hpp:352
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemPongSize()
Definition: flatmm_kernel.hpp:356
Definition: mx_flatmm_kernel.hpp:18
remove_cvref_t< typename FlatmmPipeline::CLayout > ELayout
Definition: mx_flatmm_kernel.hpp:28
static CK_TILE_HOST const std::string GetName()
Definition: mx_flatmm_kernel.hpp:63
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: mx_flatmm_kernel.hpp:25
static constexpr index_t NumDTensor
Definition: mx_flatmm_kernel.hpp:50
remove_cvref_t< typename EpiloguePipeline::DsLayout > DsLayout
Definition: mx_flatmm_kernel.hpp:29
remove_cvref_t< typename FlatmmPipeline::BLayout > BLayout
Definition: mx_flatmm_kernel.hpp:27
static constexpr int NThreadPerXdl
Definition: mx_flatmm_kernel.hpp:40
static constexpr int NXdlPack
Definition: mx_flatmm_kernel.hpp:47
static CK_TILE_DEVICE auto MakeGemmTensorViews(const ADataType *a_ptr, const BDataType *b_flat_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset)
Definition: mx_flatmm_kernel.hpp:118
static constexpr CK_TILE_HOST auto GridSize(const FlatmmKernelArgs< ScaleM, ScaleN, DsDataType::size()> &kargs)
Definition: mx_flatmm_kernel.hpp:72
static constexpr auto I2
Definition: mx_flatmm_kernel.hpp:54
remove_cvref_t< typename EpiloguePipeline::ODataType > EDataType
Definition: mx_flatmm_kernel.hpp:37
static constexpr bool UsePersistentKernel
Definition: mx_flatmm_kernel.hpp:32
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n)
Definition: mx_flatmm_kernel.hpp:321
static constexpr auto I4
Definition: mx_flatmm_kernel.hpp:56
static constexpr auto I1
Definition: mx_flatmm_kernel.hpp:53
static constexpr int MXdlPack
Definition: mx_flatmm_kernel.hpp:46
static constexpr auto I0
Definition: mx_flatmm_kernel.hpp:52
remove_cvref_t< typename FlatmmPipeline::ALayout > ALayout
Definition: mx_flatmm_kernel.hpp:26
remove_cvref_t< typename MXFlatmmPipeline_::BlockGemmShape > BlockGemmShape
Definition: mx_flatmm_kernel.hpp:24
static constexpr int MThreadPerXdl
Definition: mx_flatmm_kernel.hpp:39
remove_cvref_t< typename FlatmmPipeline::ADataType > ADataType
Definition: mx_flatmm_kernel.hpp:34
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: mx_flatmm_kernel.hpp:21
static constexpr auto I3
Definition: mx_flatmm_kernel.hpp:55
typename Underlying::SplitKBatchOffset SplitKBatchOffset
Definition: mx_flatmm_kernel.hpp:114
CK_TILE_DEVICE void operator()(FlatmmKernelArgs< ScaleM, ScaleN, DsDataType::size()> kargs, int partition_idx=blockIdx.x) const
Definition: mx_flatmm_kernel.hpp:468
remove_cvref_t< typename EpiloguePipeline::DsDataType > DsDataType
Definition: mx_flatmm_kernel.hpp:30
remove_cvref_t< typename FlatmmPipeline::BDataType > BDataType
Definition: mx_flatmm_kernel.hpp:35
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views)
Definition: mx_flatmm_kernel.hpp:253
static constexpr index_t KernelBlockSize
Definition: mx_flatmm_kernel.hpp:31
remove_cvref_t< MXFlatmmPipeline_ > FlatmmPipeline
Definition: mx_flatmm_kernel.hpp:22
static CK_TILE_DEVICE void RunFlatmm(const ADataType *a_ptr, const BDataType *b_flat_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, void *smem_ptr_ping, void *smem_ptr_pong, const FlatmmKernelArgs< ScaleM, ScaleN, DsDataType::size()> &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Definition: mx_flatmm_kernel.hpp:400
static constexpr int KXdlPack
Definition: mx_flatmm_kernel.hpp:48
static constexpr auto I5
Definition: mx_flatmm_kernel.hpp:57
static constexpr int KThreadPerXdl
Definition: mx_flatmm_kernel.hpp:41
static constexpr int BPackedSize
Definition: mx_flatmm_kernel.hpp:44
static constexpr int APackedSize
Definition: mx_flatmm_kernel.hpp:43
Definition: integral_constant.hpp:13
Definition: type_traits.hpp:115
Definition: numeric.hpp:81
Definition: sequence.hpp:49