/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp Source File
mx_flatmm_kernel.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <string>
8 
9 #include "ck_tile/core.hpp"
10 #include "ck_tile/ops/common.hpp"
11 
13 
14 namespace ck_tile {
15 
// MXFlatmmKernel: flat-B GEMM kernel that, in addition to A, a pre-flattened B, optional D
// tensors and the output E, consumes per-block scale tensors for A and B
// (kargs.scale_m_ptr / kargs.scale_n_ptr, read as packed e8m0 values in int32_t words).
// It derives from FlatmmKernel and supplies its own grid sizing, tensor-view / pad-view /
// tile-window construction, and device entry point.
// NOTE(review): this listing was extracted from a Doxygen source page. The leading numbers
// are the original file's line numbers, and several original lines (e.g. 12, 19-30, 118,
// 148-152, 218-221, 476) are missing here, so this text is not compilable as-is.
// `Underlying`, the type aliases (ADataType, ...) and APackedSize/BPackedSize are declared
// on those missing lines; `Underlying` presumably names the FlatmmKernel base — confirm.
16 template <typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
17 struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_>
18 {
20 
    // Workgroup size and persistent-kernel mode are both taken from the pipeline.
31  static constexpr index_t KernelBlockSize = MXFlatmmPipeline::BlockSize;
32  static constexpr bool UsePersistentKernel = MXFlatmmPipeline::UsePersistentKernel;
33 
36  // Below type is actually accumulation data type - the output of block GEMM.
38 
    // Per-XDL thread layout. M/N come from the warp-tile extents. K is derived as
    // 64 / MThreadPerXdl (assumes a 64-lane wavefront — TODO confirm for all targets).
39  static constexpr int MThreadPerXdl = BlockGemmShape::WarpTile::at(number<0>{});
40  static constexpr int NThreadPerXdl = BlockGemmShape::WarpTile::at(number<1>{});
41  static constexpr int KThreadPerXdl = 64 / MThreadPerXdl;
42 
45 
    // Number of XDL tiles whose scales are packed together along M / N / K (from the pipeline).
46  static constexpr int MXdlPack = MXFlatmmPipeline::MXdlPack;
47  static constexpr int NXdlPack = MXFlatmmPipeline::NXdlPack;
48  static constexpr int KXdlPack = MXFlatmmPipeline::KXdlPack;
49 
50  static constexpr index_t NumDTensor = DsDataType::size();
51 
    // Compile-time indices into the 6-element view/window tuples:
    // I0 = A, I1 = flat B, I2 = Ds, I3 = E, I4 = A scales, I5 = B scales.
52  static constexpr auto I0 = number<0>();
53  static constexpr auto I1 = number<1>();
54  static constexpr auto I2 = number<2>();
55  static constexpr auto I3 = number<3>();
56  static constexpr auto I4 = number<4>();
57  static constexpr auto I5 = number<5>();
58 
59  static_assert(DsLayout::size() == DsDataType::size(),
60  "The size of DsLayout and DsDataType should be the same");
61  // using KernelArgs = FlatmmKernelArgs<DsLayout::size()>;
62 
    // Kernel name: "mx_flatmm_gemm", the A/B precision string and the pipeline name, joined by '_'.
63  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
64  {
65  // clang-format off
66  return concat('_', "mx_flatmm_gemm", gemm_prec_str<ADataType, BDataType>, MXFlatmmPipeline::GetName());
67  // clang-format on
68  }
69 
    // Host-side launch grid.
    // - Non-persistent: x = TilePartitioner::GridSize(M, N), z = k_batch.
    // - Persistent: x = min(CU count * max resident blocks per CU, tile count), taken from a HIP
    //   occupancy query with 0 bytes of dynamic LDS. Throws std::runtime_error if a HIP query
    //   fails or if k_batch != 1, so z is always 1 in this mode.
70  template <class ScaleM, class ScaleN>
71  CK_TILE_HOST static constexpr auto
72  GridSize(const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs)
73  {
74  if constexpr(UsePersistentKernel)
75  {
76  hipDeviceProp_t prop;
    // NOTE(review): always queries device 0, not the current device (hipGetDevice).
    // On multi-GPU hosts launching on another device, this may size the grid from the wrong
    // device's CU count — confirm this is intended.
77  int deviceId = 0; // default device
78 
79  constexpr int block_size = MXFlatmmKernel::BlockSize().x;
80  int dync_smem_size = 0;
81  int maxActiveBlocksPerCU = 0;
82 
    // NOTE(review): both checks below discard the hipError_t returned by the failing call and
    // build the message from hipGetLastError(). Capturing the returned status would make the
    // reported error unambiguous.
83  if(hipGetDeviceProperties(&prop, deviceId) != hipSuccess)
84  throw std::runtime_error(std::string("hipGetDeviceProperties failed: ") +
85  hipGetErrorName(hipGetLastError()));
86 
    // Occupancy is queried for the exact kentry<1, MXFlatmmKernel, KargsT> instantiation
    // used for launch.
87  if(hipOccupancyMaxActiveBlocksPerMultiprocessor(
88  &maxActiveBlocksPerCU,
89  reinterpret_cast<void*>(
90  kentry<1, MXFlatmmKernel, remove_cvref_t<decltype(kargs)>>),
91  block_size,
92  dync_smem_size) != hipSuccess)
93  throw std::runtime_error(
94  std::string("hipOccupancyMaxActiveBlocksPerMultiprocessor failed: ") +
95  hipGetErrorName(hipGetLastError()));
96 
97  const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
98  const int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
99 
100  // std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
101  // << ", persistent_block_size: " << persistent_block_size
102  // << ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl;
103 
104  if(kargs.k_batch != 1)
105  throw std::runtime_error("Wrong! k_batch != 1 not supported in persistent kernel");
    // Never launch more workgroups than there are output tiles.
106  return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kargs.k_batch);
107  }
108  else
109  {
110  return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch);
111  }
112  }
113 
114  using SplitKBatchOffset = typename Underlying::SplitKBatchOffset;
115 
    // MakeGemmTensorViews (the name/first-parameter line 118 is missing from this extraction):
    // builds the six global-memory tensor views used by the kernel and returns them as
    // (A, flat B, Ds, E, A scales, B scales).
    // DstInMemOp is forwarded to the E view as its memory operation.
116  template <memory_operation_enum DstInMemOp = memory_operation_enum::set, class KernelArgs>
117  CK_TILE_DEVICE static auto
119  const BDataType* b_flat_ptr,
120  const std::array<const void*, NumDTensor>& ds_ptr,
121  EDataType* e_ptr,
122  const KernelArgs& kargs,
123  const SplitKBatchOffset& splitk_batch_offset)
124  {
    // A: row-major M x splitted_k view with row stride stride_A.
125  const auto& a_tensor_view = [&]() {
126  static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>,
127  "A tensor for mx must be RowMajor");
128  return make_naive_tensor_view<address_space_enum::global>(
129  a_ptr,
130  make_tuple(kargs.M, splitk_batch_offset.splitted_k),
131  make_tuple(kargs.stride_A, 1),
132  number<MXFlatmmPipeline::GetVectorSizeA()>{},
133  number<1>{});
134  }();
135 
    // Flat B: a packed descriptor of shape (N / NWarpTile, K / kKPerBlock, kKPerBlock * NWarpTile),
    // then re-indexed by transform_tensor_descriptor (its transform list is on missing lines 148-152).
    // NOTE(review): kFlatKBlocks and kFlatN use truncating division, so K % kKPerBlock == 0 and
    // N % NWarpTile == 0 are presumably required by the caller. This view also spans the full
    // kargs.K, while A uses splitted_k — confirm this is intended when k_batch > 1.
136  constexpr index_t kKPerBlock = MXFlatmmPipeline::kKPerBlock;
137  constexpr index_t kNWarpTile = BlockGemmShape::WarpTile::at(I1);
138  constexpr index_t flatKPerBlock = kKPerBlock * kNWarpTile;
139  const index_t kFlatKBlocks = kargs.K / kKPerBlock;
140  const index_t kFlatN = kargs.N / kNWarpTile;
141  const auto& b_flat_tensor_view = [&]() {
142  static_assert(flatKPerBlock % MXFlatmmPipeline::GetVectorSizeB() == 0,
143  "wrong! vector size for B tensor");
144  auto&& naive_desc = make_naive_tensor_descriptor_packed(
145  make_tuple(kFlatN, kFlatKBlocks, number<flatKPerBlock>{}));
146  auto&& desc = transform_tensor_descriptor(
147  naive_desc,
150  make_tuple(kFlatKBlocks, number<flatKPerBlock>{}))),
153  return make_tensor_view<address_space_enum::global>(b_flat_ptr, desc);
154  }();
155 
    // Ds: one view per D tensor. Row-major -> (M, N); any other layout -> (N, M).
    // Both use row stride stride_Ds[i].
156  const auto& ds_tensor_view = generate_tuple(
157  [&](auto i) {
158  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
159  using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
160  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
161  {
162  return make_naive_tensor_view<address_space_enum::global>(
163  static_cast<const DDataType_*>(ds_ptr[i]),
164  make_tuple(kargs.M, kargs.N),
165  make_tuple(kargs.stride_Ds[i], 1),
166  number<EpiloguePipeline::GetVectorSizeD(i)>{},
167  number<1>{});
168  }
169  else
170  {
171  return make_naive_tensor_view<address_space_enum::global>(
172  static_cast<const DDataType_*>(ds_ptr[i]),
173  make_tuple(kargs.N, kargs.M),
174  make_tuple(kargs.stride_Ds[i], 1),
175  number<EpiloguePipeline::GetVectorSizeD(i)>{},
176  number<1>{});
177  }
178  },
180 
    // E (output): row-major uses the epilogue's C vector size. Otherwise it falls back to
    // scalar (vector size 1) access on an (N, M) view.
181  // TODO: enable vector write for C in ColMajor
182  const auto& e_tensor_view = [&]() {
183  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
184  {
185  return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
186  e_ptr,
187  make_tuple(kargs.M, kargs.N),
188  make_tuple(kargs.stride_E, 1),
189  number<EpiloguePipeline::GetVectorSizeC()>{},
190  number<1>{});
191  }
192  else
193  {
194  return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
195  e_ptr,
196  make_tuple(kargs.N, kargs.M),
197  make_tuple(kargs.stride_E, 1),
198  number<1>{},
199  number<1>{});
200  }
201  }();
202 
203  auto scale_a = kargs.scale_m_ptr;
204  auto scale_b = kargs.scale_n_ptr;
205 
    // Scale-block size along K is hard-coded to 32 elements per scale, rather than read
    // from the scale type's GranularityK (see the trailing comment on the next line).
206  static constexpr int BlockScaleSize = 32; // decltype(scale_n)::GranularityK;
    // Number of packed scale groups per dimension. M and N round up (integer_divide_ceil).
    // NOTE(review): K rounds down, so a K that is not a multiple of
    // 32 * KXdlPack * KThreadPerXdl silently drops its tail scales — confirm K is required
    // to be such a multiple. (`const auto&&` binds lifetime-extended temporaries; plain
    // `const auto` would be equivalent.)
207  const auto&& scale_packs_m = integer_divide_ceil(kargs.M, (MXdlPack * MThreadPerXdl));
208  const auto&& scale_packs_n = integer_divide_ceil(kargs.N, (NXdlPack * NThreadPerXdl));
209  const auto&& scale_packs_k = kargs.K / BlockScaleSize / (KXdlPack * KThreadPerXdl);
210 
    // A-scale view: packed (packs_m, packs_k, KThreadPerXdl, MThreadPerXdl) descriptor
    // over int32_t words, then transformed (transform list on missing lines 218-221).
211  // A scale tensor view
212  const auto& scale_a_tensor_view = [&]() {
213  // Pack 2x2 e8m0 over M/K dimension into 1 int32_t to trigger dword width load
214  const auto scale_a_naive_desc = make_naive_tensor_descriptor_packed(
215  make_tuple(scale_packs_m, scale_packs_k, KThreadPerXdl, MThreadPerXdl));
216  const auto scale_a_desc = transform_tensor_descriptor(
217  scale_a_naive_desc,
222 
223  return make_tensor_view<address_space_enum::global>(
224  reinterpret_cast<const int32_t*>(scale_a.ptr), scale_a_desc);
225  }();
226 
    // B-scale view: same construction as A, with N in place of M
    // (transform list on missing lines 233-236).
227  // B scale tensor view
228  const auto& scale_b_tensor_view = [&]() {
229  const auto scale_b_navie_desc = make_naive_tensor_descriptor_packed(
230  make_tuple(scale_packs_n, scale_packs_k, KThreadPerXdl, NThreadPerXdl));
231  const auto scale_b_desc = transform_tensor_descriptor(
232  scale_b_navie_desc,
237 
238  return make_tensor_view<address_space_enum::global>(
239  reinterpret_cast<const int32_t*>(scale_b.ptr), scale_b_desc);
240  }();
241 
242  return make_tuple(a_tensor_view,
243  b_flat_tensor_view,
244  ds_tensor_view,
245  e_tensor_view,
246  scale_a_tensor_view,
247  scale_b_tensor_view);
248  }
249 
    // Applies pad_tensor_view to A, each D, and E. The tile lengths and pad flags are on
    // lines missing from this extraction. Flat B and both scale views (I4, I5) are
    // forwarded unpadded, so the output tuple keeps the input order.
250  template <typename TensorView>
251  CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
252  {
253  const auto& a_pad_view = [&]() {
254  const auto& a_tensor_view = views.at(I0);
255  static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>,
256  "A tensor for mx must be RowMajor");
257  return pad_tensor_view(a_tensor_view,
261  }();
262 
263  const auto& b_flat_tensor_view = views.at(I1);
264 
265  const auto& ds_pad_view = generate_tuple(
266  [&](auto i) {
267  const auto& d_tensor_view = views.at(I2);
268  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
269  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
270  {
271  return pad_tensor_view(d_tensor_view[i],
275  }
276  else
277  {
278  return pad_tensor_view(d_tensor_view[i],
282  }
283  },
285 
286  // TODO vector write in for C in ColMajor
287  const auto& e_pad_view = [&]() {
288  const auto& e_tensor_view = views.at(I3);
289  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
290  {
291  return pad_tensor_view(e_tensor_view,
295  }
296  else
297  {
298  return pad_tensor_view(e_tensor_view,
302  }
303  }();
304 
305  return make_tuple(
306  a_pad_view, b_flat_tensor_view, ds_pad_view, e_pad_view, views.at(I4), views.at(I5));
307  }
308 
    // Creates the per-workgroup tile windows; i_m / i_n are the tile's element offsets in M / N.
    // Window origins:
    //   A               {i_m, 0}
    //   flat B          {i_n / NWarpTile, 0}
    //   row-major D, E  {i_m, i_n}
    //   other D         {i_n, i_m}
    //   A scales        {i_m / MXdlPack, 0}
    //   B scales        {i_n / NXdlPack, 0}
    // The scale windows span KPerBlock / (32 * KXdlPack) along K. Most window lengths are on
    // lines missing from this extraction.
309  template <typename PadView>
310  CK_TILE_DEVICE static auto
311  MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
312  {
313  const auto& a_pad_view = views.at(I0);
314  const auto& b_flat_pad_view = views.at(I1);
315  const auto& ds_pad_view = views.at(I2);
316  const auto& e_pad_view = views.at(I3);
317 
318  const auto& a_block_window = [&]() {
319  static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>,
320  "A tensor for mx must be RowMajor");
321  return make_tile_window(a_pad_view,
324  {i_m, 0});
325  }();
326 
327  const auto& b_flat_block_window =
328  make_tile_window(b_flat_pad_view,
331  {static_cast<int>(i_n / BlockGemmShape::WarpTile::at(I1)), 0});
332 
333  const auto ds_block_window = generate_tuple(
334  [&](auto i) {
335  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
336  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
337  {
338  return make_tile_window(ds_pad_view[i],
341  {i_m, i_n});
342  }
343  else
344  {
345  return make_tile_window(ds_pad_view[i],
348  {i_n, i_m});
349  }
350  },
352 
353  auto e_block_window = make_tile_window(
354  e_pad_view,
356  {i_m, i_n});
357 
    // Must stay in sync with BlockScaleSize in MakeGemmTensorViews.
358  static constexpr int BlockScaleSize = 32;
359 
360  auto scale_a_block_window = make_tile_window(
361  views.at(I4),
363  number<TilePartitioner::KPerBlock / (BlockScaleSize * KXdlPack)>{}),
364  {i_m / MXdlPack, 0});
365 
366  auto scale_b_block_window = make_tile_window(
367  views.at(I5),
369  number<TilePartitioner::KPerBlock / (BlockScaleSize * KXdlPack)>{}),
370  {i_n / NXdlPack, 0});
371 
372  return make_tuple(a_block_window,
373  b_flat_block_window,
374  ds_block_window,
375  e_block_window,
376  scale_a_block_window,
377  scale_b_block_window);
378  }
379 
    // Computes one output tile: builds views, pad views and windows, runs the MX pipeline
    // with the ping/pong LDS buffers, then runs the epilogue (which reuses smem_ptr_ping).
    // - block_idx_m / block_idx_n: element offsets of the tile (i_m / i_n from operator()).
    // - UseDefaultScheduler: when false, the non-scaled epilogue runs only on warp 0.
380  template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
381  CK_TILE_DEVICE static void
382  RunFlatmm(const ADataType* a_ptr,
383  const BDataType* b_flat_ptr,
384  const std::array<const void*, NumDTensor>& ds_ptr,
385  EDataType* e_ptr,
386  void* smem_ptr_ping,
387  void* smem_ptr_pong,
388  const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs,
389  const SplitKBatchOffset& splitk_batch_offset,
390  const index_t block_idx_m,
391  const index_t block_idx_n)
392  {
393  // Create Gemm tensor views, pad views and tile windows
394  const auto& gemm_tensor_views_tuple =
395  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
396  a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
397  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
398  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
399 
    // Number of main-loop iterations for this split-K slice.
400  const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
401 
402  // Run GEMM cooperatively by whole workgroup.
403  const auto& a_block_window = gemm_tile_windows.at(I0);
404  const auto& b_flat_block_window = gemm_tile_windows.at(I1);
405  const auto& d_block_window = gemm_tile_windows.at(I2);
406  const auto& scale_a_block_window = gemm_tile_windows.at(I4);
407  const auto& scale_b_block_window = gemm_tile_windows.at(I5);
408 
409  static_assert(ScaleM::GranularityK == ScaleN::GranularityK // have the same granK
410  || ScaleM::GranularityMN == -1 // or ScaleA is disable
411  || ScaleN::GranularityMN == -1, // or ScaleB is disable
412  "ScaleM and ScaleN should have the same GranularityK");
    // The epilogue applies scales when an enabled scale (GranularityMN != -1) has
    // GranularityK == 0.
413  constexpr bool DoEpiScale =
414  (ScaleM::GranularityMN != -1 && ScaleM::GranularityK == 0) || // per token
415  (ScaleN::GranularityMN != -1 && ScaleN::GranularityK == 0); // per channel
416 
    // Re-create the A window with the pipeline's DRAM tile distribution attached.
417  auto a_block_window_with_distr =
418  ck_tile::make_tile_window(a_block_window.get_bottom_tensor_view(),
419  a_block_window.get_window_lengths(),
420  a_block_window.get_window_origin(),
421  MXFlatmmPipeline::GetADramTileDistribution());
422  const auto& c_block_tile = MXFlatmmPipeline{}(a_block_window_with_distr,
423  b_flat_block_window,
424  scale_a_block_window,
425  scale_b_block_window,
426  num_loop,
427  smem_ptr_ping,
428  smem_ptr_pong);
429 
430  // Run Epilogue Pipeline
    // NOTE(review): in the DoEpiScale branch the epilogue runs on every warp regardless of
    // UseDefaultScheduler, whereas the branch below restricts it to warp 0 when
    // UseDefaultScheduler is false — confirm this asymmetry is intended.
431  if constexpr(DoEpiScale)
432  {
433  auto& c_block_window = gemm_tile_windows.at(I3);
    // Scale pointers are advanced by the tile's element offsets
    // (presumably so the epilogue sees this tile's first row/column scale).
434  EpiloguePipeline{}(c_block_window,
435  c_block_tile,
436  d_block_window,
437  smem_ptr_ping,
438  kargs.scale_m_ptr + block_idx_m,
439  kargs.scale_n_ptr + block_idx_n);
440  }
441  else if(UseDefaultScheduler || (get_warp_id() == 0))
442  {
443  // Run Epilogue Pipeline
444  auto& c_block_window = gemm_tile_windows.at(I3);
445  EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
446  }
447  }
448 
    // Device entry point. partition_idx starts at blockIdx.x.
    // - Non-persistent: the loop body runs once.
    // - Persistent: the index advances by gridDim.x until total_work_tile_cnt tiles are covered.
449  template <class ScaleM, class ScaleN>
450  CK_TILE_DEVICE void operator()(FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()> kargs,
451  int partition_idx = blockIdx.x) const
452  {
453  int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
454 
455  do
456  {
    // Tile origin in elements, read from the first lane
    // (presumably so the value is wave-uniform / scalar).
457  const auto [iM, iN] =
458  TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
459  const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
460  const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
461 
462  const SplitKBatchOffset splitk_batch_offset(kargs);
    // Split-K offsets are divided by APackedSize / BPackedSize before pointer arithmetic
    // (presumably the number of logical elements per stored value for packed types).
463  // options
464  const auto a_ptr = static_cast<const ADataType*>(kargs.a_ptr) +
465  splitk_batch_offset.a_k_split_offset / APackedSize;
466  const auto b_flat_ptr = static_cast<const BDataType*>(kargs.b_ptr) +
467  splitk_batch_offset.b_k_split_offset / BPackedSize;
468  EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
469 
    // Ping/pong LDS buffers, sized by the base kernel.
470  // allocate LDS
471  __shared__ char smem_ptr_ping[Underlying::GetSmemPingSize()];
472  __shared__ char smem_ptr_pong[Underlying::GetSmemPongSize()];
473 
    // Compile-time rejection of atomic_add epilogues with an odd C vector size
    // (a third condition is on missing line 476).
474  if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
475  EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
477  {
    // Bool passed as UseDefaultScheduler: true when the pipeline uses a single wave group.
478  constexpr auto scheduler_type = (MXFlatmmPipeline::NumWaveGroups == 1);
479  RunFlatmm<ScaleM, ScaleN, scheduler_type>(a_ptr,
480  b_flat_ptr,
481  kargs.ds_ptr,
482  e_ptr,
483  smem_ptr_ping,
484  smem_ptr_pong,
485  kargs,
486  splitk_batch_offset,
487  i_m,
488  i_n);
489  }
490  else
491  {
492  static_assert(false,
493  "Unimplemented: atomic_add with odd vector size for fp16/bf16");
494  }
495  partition_idx += gridDim.x;
496  } while(UsePersistentKernel && partition_idx < total_work_tile_cnt);
497  }
498 };
499 
500 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_HOST
Definition: config.hpp:44
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:35
__global__ void kentry(Args... args)
Definition: kernel_launch.hpp:22
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1690
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:545
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1633
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1684
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition: tensor_descriptor.hpp:203
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_descriptor_packed(const tuple< Lengths... > &lengths, number< GuaranteedLastDimensionVectorLength >=number<-1 >{})
Definition: tensor_descriptor.hpp:371
int32_t int32_t
Definition: integer.hpp:10
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE T min(T x)
Definition: math.hpp:210
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
__device__ X atomic_add(X *p_dst, const X &x)
Definition: flatmm_kernel.hpp:229
Definition: flatmm_kernel.hpp:249
static constexpr CK_TILE_HOST auto BlockSize()
Definition: flatmm_kernel.hpp:330
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemPingSize()
Definition: flatmm_kernel.hpp:352
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemPongSize()
Definition: flatmm_kernel.hpp:356
Definition: mx_flatmm_kernel.hpp:18
static CK_TILE_HOST const std::string GetName()
Definition: mx_flatmm_kernel.hpp:63
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: mx_flatmm_kernel.hpp:25
static constexpr index_t NumDTensor
Definition: mx_flatmm_kernel.hpp:50
remove_cvref_t< typename MXFlatmmPipeline::BLayout > BLayout
Definition: mx_flatmm_kernel.hpp:27
remove_cvref_t< typename EpiloguePipeline::DsLayout > DsLayout
Definition: mx_flatmm_kernel.hpp:29
static constexpr int NThreadPerXdl
Definition: mx_flatmm_kernel.hpp:40
static constexpr int NXdlPack
Definition: mx_flatmm_kernel.hpp:47
static CK_TILE_DEVICE auto MakeGemmTensorViews(const ADataType *a_ptr, const BDataType *b_flat_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset)
Definition: mx_flatmm_kernel.hpp:118
static constexpr CK_TILE_HOST auto GridSize(const FlatmmKernelArgs< ScaleM, ScaleN, DsDataType::size()> &kargs)
Definition: mx_flatmm_kernel.hpp:72
static constexpr auto I2
Definition: mx_flatmm_kernel.hpp:54
remove_cvref_t< typename EpiloguePipeline::ODataType > EDataType
Definition: mx_flatmm_kernel.hpp:37
static constexpr bool UsePersistentKernel
Definition: mx_flatmm_kernel.hpp:32
remove_cvref_t< typename MXFlatmmPipeline::CLayout > ELayout
Definition: mx_flatmm_kernel.hpp:28
remove_cvref_t< typename MXFlatmmPipeline::ALayout > ALayout
Definition: mx_flatmm_kernel.hpp:26
remove_cvref_t< typename MXFlatmmPipeline::BDataType > BDataType
Definition: mx_flatmm_kernel.hpp:35
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n)
Definition: mx_flatmm_kernel.hpp:311
static constexpr auto I4
Definition: mx_flatmm_kernel.hpp:56
remove_cvref_t< MXFlatmmPipeline_ > MXFlatmmPipeline
Definition: mx_flatmm_kernel.hpp:22
static constexpr auto I1
Definition: mx_flatmm_kernel.hpp:53
static constexpr int MXdlPack
Definition: mx_flatmm_kernel.hpp:46
static constexpr auto I0
Definition: mx_flatmm_kernel.hpp:52
remove_cvref_t< typename MXFlatmmPipeline_::BlockGemmShape > BlockGemmShape
Definition: mx_flatmm_kernel.hpp:24
static constexpr int MThreadPerXdl
Definition: mx_flatmm_kernel.hpp:39
remove_cvref_t< typename MXFlatmmPipeline::ADataType > ADataType
Definition: mx_flatmm_kernel.hpp:34
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: mx_flatmm_kernel.hpp:21
static constexpr auto I3
Definition: mx_flatmm_kernel.hpp:55
typename Underlying::SplitKBatchOffset SplitKBatchOffset
Definition: mx_flatmm_kernel.hpp:114
CK_TILE_DEVICE void operator()(FlatmmKernelArgs< ScaleM, ScaleN, DsDataType::size()> kargs, int partition_idx=blockIdx.x) const
Definition: mx_flatmm_kernel.hpp:450
remove_cvref_t< typename EpiloguePipeline::DsDataType > DsDataType
Definition: mx_flatmm_kernel.hpp:30
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views)
Definition: mx_flatmm_kernel.hpp:251
static constexpr index_t KernelBlockSize
Definition: mx_flatmm_kernel.hpp:31
static CK_TILE_DEVICE void RunFlatmm(const ADataType *a_ptr, const BDataType *b_flat_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, void *smem_ptr_ping, void *smem_ptr_pong, const FlatmmKernelArgs< ScaleM, ScaleN, DsDataType::size()> &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Definition: mx_flatmm_kernel.hpp:382
static constexpr int KXdlPack
Definition: mx_flatmm_kernel.hpp:48
static constexpr auto I5
Definition: mx_flatmm_kernel.hpp:57
static constexpr int KThreadPerXdl
Definition: mx_flatmm_kernel.hpp:41
static constexpr int BPackedSize
Definition: mx_flatmm_kernel.hpp:44
static constexpr int APackedSize
Definition: mx_flatmm_kernel.hpp:43
Definition: integral_constant.hpp:13
Definition: type_traits.hpp:115
Definition: numeric.hpp:81
Definition: sequence.hpp:49