mx_flatmm_kernel.hpp Source File

Composable Kernel: include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include <iostream>
#include <string>

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp"

namespace ck_tile {

template <typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_>
{
    using Underlying = FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_>;

    using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
    using MXFlatmmPipeline = remove_cvref_t<MXFlatmmPipeline_>;
    using BlockGemmShape   = remove_cvref_t<typename MXFlatmmPipeline_::BlockGemmShape>;
    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
    using ALayout          = remove_cvref_t<typename MXFlatmmPipeline::ALayout>;
    using BLayout          = remove_cvref_t<typename MXFlatmmPipeline::BLayout>;
    using ELayout          = remove_cvref_t<typename MXFlatmmPipeline::CLayout>;
    using DsLayout         = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
    using DsDataType       = remove_cvref_t<typename EpiloguePipeline::DsDataType>;

    static constexpr index_t KernelBlockSize  = MXFlatmmPipeline::BlockSize;
    static constexpr bool UsePersistentKernel = MXFlatmmPipeline::UsePersistentKernel;

    using ADataType = remove_cvref_t<typename MXFlatmmPipeline::ADataType>;
    using BDataType = remove_cvref_t<typename MXFlatmmPipeline::BDataType>;
    // Below type is actually the accumulation data type - the output of the block GEMM.
    using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;

    static constexpr int MThreadPerXdl = BlockGemmShape::WarpTile::at(number<0>{});
    static constexpr int NThreadPerXdl = BlockGemmShape::WarpTile::at(number<1>{});
    // A wavefront has 64 lanes: MThreadPerXdl lanes map to M, the rest cover K.
    static constexpr int KThreadPerXdl = 64 / MThreadPerXdl;

    static constexpr int APackedSize = numeric_traits<ADataType>::PackedSize;
    static constexpr int BPackedSize = numeric_traits<BDataType>::PackedSize;

    static constexpr int MXdlPack = MXFlatmmPipeline::MXdlPack;
    static constexpr int NXdlPack = MXFlatmmPipeline::NXdlPack;
    static constexpr int KXdlPack = MXFlatmmPipeline::KXdlPack;

    static constexpr index_t NumDTensor = DsDataType::size();

    static constexpr auto I0 = number<0>();
    static constexpr auto I1 = number<1>();
    static constexpr auto I2 = number<2>();
    static constexpr auto I3 = number<3>();
    static constexpr auto I4 = number<4>();
    static constexpr auto I5 = number<5>();

    static_assert(DsLayout::size() == DsDataType::size(),
                  "The size of DsLayout and DsDataType should be the same");
    // using KernelArgs = FlatmmKernelArgs<DsLayout::size()>;
    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
    {
        // clang-format off
        return concat('_', "mx_flatmm_gemm", gemm_prec_str<ADataType, BDataType>, MXFlatmmPipeline::GetName());
        // clang-format on
    }

    template <class ScaleM, class ScaleN>
    CK_TILE_HOST static constexpr auto
    GridSize(const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs)
    {
        if constexpr(UsePersistentKernel)
        {
            hipDeviceProp_t prop;
            int deviceId = 0; // default device

            constexpr int block_size = MXFlatmmKernel::BlockSize().x;
            int dync_smem_size       = 0;
            int maxActiveBlocksPerCU = 0;

            if(hipGetDeviceProperties(&prop, deviceId) != hipSuccess)
                throw std::runtime_error(std::string("hipGetDeviceProperties failed: ") +
                                         hipGetErrorName(hipGetLastError()));

            if(hipOccupancyMaxActiveBlocksPerMultiprocessor(
                   &maxActiveBlocksPerCU,
                   reinterpret_cast<void*>(
                       kentry<1, MXFlatmmKernel, remove_cvref_t<decltype(kargs)>>),
                   block_size,
                   dync_smem_size) != hipSuccess)
                throw std::runtime_error(
                    std::string("hipOccupancyMaxActiveBlocksPerMultiprocessor failed: ") +
                    hipGetErrorName(hipGetLastError()));

            // Launch at most one resident "wave" of blocks; each block then loops
            // over work tiles (see operator() below).
            const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
            const int total_work_tile_cnt   = TilePartitioner::GridSize(kargs.M, kargs.N);

            if(kargs.k_batch != 1)
                throw std::runtime_error("Wrong! k_batch != 1 not supported in persistent kernel");
            return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kargs.k_batch);
        }
        else
        {
            return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch);
        }
    }

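For the persistent path, GridSize launches at most one resident wave of blocks: the occupancy per CU times the CU count, capped by the number of work tiles. Below is a minimal host-side sketch of the same computation; persistent_grid_size and its parameters are illustrative names, not part of this header.

#include <hip/hip_runtime.h>
#include <algorithm>
#include <stdexcept>

// Illustrative only: mirrors the occupancy computation in GridSize() above.
inline int persistent_grid_size(const void* kernel_entry, int block_size, int work_tiles)
{
    hipDeviceProp_t prop;
    if(hipGetDeviceProperties(&prop, /*deviceId=*/0) != hipSuccess)
        throw std::runtime_error("hipGetDeviceProperties failed");

    int max_blocks_per_cu = 0;
    if(hipOccupancyMaxActiveBlocksPerMultiprocessor(
           &max_blocks_per_cu, kernel_entry, block_size, /*dynSharedMemPerBlk=*/0) != hipSuccess)
        throw std::runtime_error("hipOccupancyMaxActiveBlocksPerMultiprocessor failed");

    // One resident "wave" of blocks; never launch more blocks than work tiles.
    return std::min(prop.multiProcessorCount * max_blocks_per_cu, work_tiles);
}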
    using SplitKBatchOffset = typename Underlying::SplitKBatchOffset;

    template <typename KernelArgs>
    CK_TILE_DEVICE static auto MakeABlockWindow(const ADataType* a_ptr,
                                                const KernelArgs& kargs,
                                                const index_t k_size,
                                                const index_t block_idx_m)
    {
        // Step 1: Create tensor view
        const auto& a_tensor_view = [&]() {
            static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>,
                          "A tensor for mx must be RowMajor");
            return make_naive_tensor_view<address_space_enum::global>(
                a_ptr,
                make_tuple(kargs.M, k_size),
                make_tuple(kargs.stride_A, 1),
                number<MXFlatmmPipeline::GetVectorSizeA()>{},
                number<1>{});
        }();

        // Step 2: Create padded view
        const auto& a_pad_view = pad_tensor_view(
            a_tensor_view,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
            sequence<false, MXFlatmmPipeline::kPadK>{});

        // Step 3: Create tile window
        return make_tile_window(
            a_pad_view,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
            {block_idx_m, 0});
    }

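The three-step pattern (raw view, padded view, tile window) ultimately resolves block-local coordinates to global addresses. A standalone sketch of the row-major mapping follows; the names are hypothetical and this is not the actual tile-window machinery.

#include <cstddef>

// Hypothetical: how a block-local (local_m, local_k) coordinate resolves to a
// global element offset for row-major A with window origin {block_idx_m, 0}.
inline std::size_t a_global_offset(std::size_t block_idx_m,
                                   std::size_t local_m,
                                   std::size_t local_k,
                                   std::size_t stride_A) // elements per row
{
    return (block_idx_m + local_m) * stride_A + local_k;
}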
    template <typename KernelArgs>
    CK_TILE_DEVICE static auto MakeBFlatBlockWindow(const BDataType* b_flat_ptr,
                                                    const KernelArgs& kargs,
                                                    const index_t block_idx_n)
    {
        // Step 1: Create tensor view with special flat layout
        constexpr index_t kKPerBlock    = MXFlatmmPipeline::kKPerBlock;
        constexpr index_t kNWarpTile    = BlockGemmShape::WarpTile::at(I1);
        constexpr index_t flatKPerBlock = kKPerBlock * kNWarpTile;
        const index_t kFlatKBlocks      = kargs.K / kKPerBlock;
        const index_t kFlatN            = kargs.N / kNWarpTile;

        const auto& b_flat_tensor_view = [&]() {
            static_assert(flatKPerBlock % MXFlatmmPipeline::GetVectorSizeB() == 0,
                          "wrong! vector size for B tensor");
            auto&& naive_desc = make_naive_tensor_descriptor_packed(
                make_tuple(kFlatN, kFlatKBlocks, number<flatKPerBlock>{}));
            auto&& desc = transform_tensor_descriptor(
                naive_desc,
                make_tuple(make_pass_through_transform(kFlatN),
                           make_merge_transform_v3_division_mod(
                               make_tuple(kFlatKBlocks, number<flatKPerBlock>{}))),
                make_tuple(sequence<0>{}, sequence<1, 2>{}),
                make_tuple(sequence<0>{}, sequence<1>{}));
            return make_tensor_view<address_space_enum::global>(b_flat_ptr, desc);
        }();

        // Step 2: No padding for flat B
        // Step 3: Create tile window
        return make_tile_window(
            b_flat_tensor_view,
            make_tuple(number<TilePartitioner::NPerBlock / kNWarpTile>{},
                       number<flatKPerBlock>{}),
            {static_cast<int>(block_idx_n / BlockGemmShape::WarpTile::at(I1)), 0});
    }

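The flat B view assumes B has been pre-shuffled so that each kNWarpTile by kKPerBlock warp tile is stored contiguously; the descriptor then merges the (kFlatKBlocks, flatKPerBlock) dimensions into a single flat K axis. A standalone sketch of the resulting tile-level addressing under that assumption, with hypothetical names:

#include <cstddef>

// Hypothetical: element offset of the flat tile holding original column n and
// depth k, assuming tiles along flat-K are contiguous per flat-N row.
inline std::size_t b_flat_tile_offset(std::size_t n,          // original N index
                                      std::size_t k,          // original K index
                                      std::size_t kKPerBlock,
                                      std::size_t kNWarpTile,
                                      std::size_t K)          // full K extent
{
    const std::size_t flatKPerBlock = kKPerBlock * kNWarpTile; // elements per flat tile
    const std::size_t kFlatKBlocks  = K / kKPerBlock;          // tiles along K
    const std::size_t flat_n        = n / kNWarpTile;
    const std::size_t flat_k_block  = k / kKPerBlock;
    // Offset of the tile itself; the ordering inside a tile is pipeline-specific.
    return (flat_n * kFlatKBlocks + flat_k_block) * flatKPerBlock;
}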
    template <typename KernelArgs>
    CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
                                                 const KernelArgs& kargs,
                                                 const index_t block_idx_m,
                                                 const index_t block_idx_n)
    {
        // Step 1: Create tensor views
        const auto& ds_tensor_view = generate_tuple(
            [&](auto i) {
                using DiLayout   = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
                using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
                if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
                {
                    return make_naive_tensor_view<address_space_enum::global>(
                        static_cast<const DDataType_*>(ds_ptr[i]),
                        make_tuple(kargs.M, kargs.N),
                        make_tuple(kargs.stride_Ds[i], 1),
                        number<EpiloguePipeline::GetVectorSizeD(i)>{},
                        number<1>{});
                }
                else
                {
                    return make_naive_tensor_view<address_space_enum::global>(
                        static_cast<const DDataType_*>(ds_ptr[i]),
                        make_tuple(kargs.N, kargs.M),
                        make_tuple(kargs.stride_Ds[i], 1),
                        number<EpiloguePipeline::GetVectorSizeD(i)>{},
                        number<1>{});
                }
            },
            number<NumDTensor>{});

        // Step 2: Create padded views
        const auto& ds_pad_view = generate_tuple(
            [&](auto i) {
                using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
                if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
                {
                    return pad_tensor_view(ds_tensor_view[i],
                                           make_tuple(number<TilePartitioner::MPerBlock>{},
                                                      number<TilePartitioner::NPerBlock>{}),
                                           sequence<MXFlatmmPipeline::kPadM, MXFlatmmPipeline::kPadN>{});
                }
                else
                {
                    return pad_tensor_view(ds_tensor_view[i],
                                           make_tuple(number<TilePartitioner::NPerBlock>{},
                                                      number<TilePartitioner::MPerBlock>{}),
                                           sequence<MXFlatmmPipeline::kPadN, MXFlatmmPipeline::kPadM>{});
                }
            },
            number<NumDTensor>{});

        // Step 3: Create tile windows
        return generate_tuple(
            [&](auto i) {
                using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
                if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
                {
                    return make_tile_window(ds_pad_view[i],
                                            make_tuple(number<TilePartitioner::MPerBlock>{},
                                                       number<TilePartitioner::NPerBlock>{}),
                                            {block_idx_m, block_idx_n});
                }
                else
                {
                    return make_tile_window(ds_pad_view[i],
                                            make_tuple(number<TilePartitioner::NPerBlock>{},
                                                       number<TilePartitioner::MPerBlock>{}),
                                            {block_idx_n, block_idx_m});
                }
            },
            number<NumDTensor>{});
    }

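The two branches above differ only in coordinate order: a column-major D stores (row, col) transposed, so both the view extents and the window origin swap. A hypothetical helper illustrating the swap:

#include <cstddef>
#include <utility>

// Hypothetical: row-major keeps (m, n); column-major swaps to (n, m).
inline std::pair<std::size_t, std::size_t>
d_window_origin(bool row_major, std::size_t block_idx_m, std::size_t block_idx_n)
{
    return row_major ? std::make_pair(block_idx_m, block_idx_n)
                     : std::make_pair(block_idx_n, block_idx_m);
}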
    template <memory_operation_enum DstInMemOp = memory_operation_enum::set, typename KernelArgs>
    CK_TILE_DEVICE static auto MakeEBlockWindow(EDataType* e_ptr,
                                                const KernelArgs& kargs,
                                                const index_t block_idx_m,
                                                const index_t block_idx_n)
    {
        // Step 1: Create tensor view
        const auto& e_tensor_view = [&]() {
            if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
                    e_ptr,
                    make_tuple(kargs.M, kargs.N),
                    make_tuple(kargs.stride_E, 1),
                    number<EpiloguePipeline::GetVectorSizeC()>{},
                    number<1>{});
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
                    e_ptr,
                    make_tuple(kargs.N, kargs.M),
                    make_tuple(kargs.stride_E, 1),
                    number<1>{},
                    number<1>{});
            }
        }();

        // Step 2: Create padded view
        const auto& e_pad_view = [&]() {
            if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
            {
                return pad_tensor_view(e_tensor_view,
                                       make_tuple(number<TilePartitioner::MPerBlock>{},
                                                  number<TilePartitioner::NPerBlock>{}),
                                       sequence<MXFlatmmPipeline::kPadM, MXFlatmmPipeline::kPadN>{});
            }
            else
            {
                return pad_tensor_view(e_tensor_view,
                                       make_tuple(number<TilePartitioner::NPerBlock>{},
                                                  number<TilePartitioner::MPerBlock>{}),
                                       sequence<MXFlatmmPipeline::kPadN, MXFlatmmPipeline::kPadM>{});
            }
        }();

        // Step 3: Create tile window
        return make_tile_window(
            e_pad_view,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
            {block_idx_m, block_idx_n});
    }

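The DstInMemOp template parameter is what makes split-K correct: with k_batch > 1, several blocks write partial sums for the same (m, n) tile, so the epilogue must accumulate rather than overwrite. A plain C++20 analogue with hypothetical names (the GPU epilogue uses hardware global atomics, not std::atomic):

#include <atomic>

// Illustrative only; C++20 is assumed for atomic<float>::fetch_add.
inline void store_partial(std::atomic<float>& e, float partial, int k_batch)
{
    if(k_batch == 1)
        e.store(partial, std::memory_order_relaxed); // memory_operation_enum::set
    else
        e.fetch_add(partial, std::memory_order_relaxed); // memory_operation_enum::atomic_add
}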
    template <typename KernelArgs>
    CK_TILE_DEVICE static auto MakeScaleABlockWindow(const KernelArgs& kargs,
                                                     const index_t block_idx_m)
    {
        static constexpr int BlockScaleSize = 32;

        const auto&& scale_packs_m = integer_divide_ceil(kargs.M, (MXdlPack * MThreadPerXdl));
        const auto&& scale_packs_k = kargs.K / BlockScaleSize / (KXdlPack * KThreadPerXdl);

        // Step 1: Create tensor view
        const auto scale_a_naive_desc = make_naive_tensor_descriptor_packed(
            make_tuple(scale_packs_m, scale_packs_k, KThreadPerXdl, MThreadPerXdl));
        const auto scale_a_desc = transform_tensor_descriptor(
            scale_a_naive_desc,
            make_tuple(make_merge_transform(make_tuple(scale_packs_m, MThreadPerXdl)),
                       make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))),
            make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));

        const auto& scale_a_tensor_view = make_tensor_view<address_space_enum::global>(
            reinterpret_cast<const int32_t*>(kargs.scale_m_ptr.ptr), scale_a_desc);

        // Step 2: Create tile window
        return make_tile_window(
            scale_a_tensor_view,
            make_tuple(number<TilePartitioner::MPerBlock / MXdlPack>{},
                       number<TilePartitioner::KPerBlock / (BlockScaleSize * KXdlPack)>{}),
            {block_idx_m / MXdlPack, 0});
    }

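The scale window reads MX block scales through an int32 view, i.e. four 1-byte e8m0 exponents per load; BlockScaleSize == 32 means one scale per 32 K elements. A hypothetical decode sketch following the OCP MX convention (bias 127; the 0xFF NaN encoding is not handled here):

#include <cmath>
#include <cstdint>

// Hypothetical e8m0 decode: the byte stores only a biased exponent,
// value = 2^(e - 127).
inline float decode_e8m0(std::uint8_t e)
{
    return std::ldexp(1.0f, static_cast<int>(e) - 127);
}

// One int32 load carries four packed e8m0 scale bytes; idx in [0, 4).
inline float unpack_scale(std::uint32_t packed, int idx)
{
    return decode_e8m0(static_cast<std::uint8_t>(packed >> (8 * idx)));
}

The B-scale window below is built the same way, with the roles of M and N swapped.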
    template <typename KernelArgs>
    CK_TILE_DEVICE static auto MakeScaleBBlockWindow(const KernelArgs& kargs,
                                                     const index_t block_idx_n)
    {
        static constexpr int BlockScaleSize = 32;

        const auto&& scale_packs_n = integer_divide_ceil(kargs.N, (NXdlPack * NThreadPerXdl));
        const auto&& scale_packs_k = kargs.K / BlockScaleSize / (KXdlPack * KThreadPerXdl);

        // Step 1: Create tensor view
        const auto scale_b_naive_desc = make_naive_tensor_descriptor_packed(
            make_tuple(scale_packs_n, scale_packs_k, KThreadPerXdl, NThreadPerXdl));
        const auto scale_b_desc = transform_tensor_descriptor(
            scale_b_naive_desc,
            make_tuple(make_merge_transform(make_tuple(scale_packs_n, NThreadPerXdl)),
                       make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))),
            make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));

        const auto& scale_b_tensor_view = make_tensor_view<address_space_enum::global>(
            reinterpret_cast<const int32_t*>(kargs.scale_n_ptr.ptr), scale_b_desc);

        // Step 2: Create tile window
        return make_tile_window(
            scale_b_tensor_view,
            make_tuple(number<TilePartitioner::NPerBlock / NXdlPack>{},
                       number<TilePartitioner::KPerBlock / (BlockScaleSize * KXdlPack)>{}),
            {block_idx_n / NXdlPack, 0});
    }

    template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
    CK_TILE_DEVICE static void
    RunFlatmm(const ADataType* a_ptr,
              const BDataType* b_flat_ptr,
              const std::array<const void*, NumDTensor>& ds_ptr,
              EDataType* e_ptr,
              void* smem_ptr_ping,
              void* smem_ptr_pong,
              const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs,
              const SplitKBatchOffset& splitk_batch_offset,
              const index_t block_idx_m,
              const index_t block_idx_n)
    {
        // Create block windows using the specialized methods above
        const auto& a_block_window =
            MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
        const auto& b_flat_block_window  = MakeBFlatBlockWindow(b_flat_ptr, kargs, block_idx_n);
        const auto& ds_block_window      = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
        const auto& scale_a_block_window = MakeScaleABlockWindow(kargs, block_idx_m);
        const auto& scale_b_block_window = MakeScaleBBlockWindow(kargs, block_idx_n);

        const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);

        static_assert(ScaleM::GranularityK == ScaleN::GranularityK // have the same K granularity
                          || ScaleM::GranularityMN == -1           // or ScaleA is disabled
                          || ScaleN::GranularityMN == -1,          // or ScaleB is disabled
                      "ScaleM and ScaleN should have the same GranularityK");
        constexpr bool DoEpiScale =
            (ScaleM::GranularityMN != -1 && ScaleM::GranularityK == 0) || // per-token scale
            (ScaleN::GranularityMN != -1 && ScaleN::GranularityK == 0);   // per-channel scale

        const auto& c_block_tile = MXFlatmmPipeline{}(a_block_window,
                                                      b_flat_block_window,
                                                      scale_a_block_window,
                                                      scale_b_block_window,
                                                      num_loop,
                                                      smem_ptr_ping,
                                                      smem_ptr_pong);

        // Run the epilogue pipeline with split-k dispatch:
        // k_batch == 1 stores directly, k_batch > 1 accumulates via atomic adds.
        if constexpr(DoEpiScale)
        {
            if(kargs.k_batch == 1)
            {
                auto e_block_window = MakeEBlockWindow<memory_operation_enum::set>(
                    e_ptr, kargs, block_idx_m, block_idx_n);
                EpiloguePipeline{}(e_block_window,
                                   c_block_tile,
                                   ds_block_window,
                                   smem_ptr_ping,
                                   kargs.scale_m_ptr + block_idx_m,
                                   kargs.scale_n_ptr + block_idx_n);
            }
            else
            {
                auto e_block_window = MakeEBlockWindow<memory_operation_enum::atomic_add>(
                    e_ptr, kargs, block_idx_m, block_idx_n);
                EpiloguePipeline{}(e_block_window,
                                   c_block_tile,
                                   ds_block_window,
                                   smem_ptr_ping,
                                   kargs.scale_m_ptr + block_idx_m,
                                   kargs.scale_n_ptr + block_idx_n);
            }
        }
        else if(UseDefaultScheduler || (get_warp_id() == 0))
        {
            if(kargs.k_batch == 1)
            {
                auto e_block_window = MakeEBlockWindow<memory_operation_enum::set>(
                    e_ptr, kargs, block_idx_m, block_idx_n);
                EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_window, smem_ptr_ping);
            }
            else
            {
                auto e_block_window = MakeEBlockWindow<memory_operation_enum::atomic_add>(
                    e_ptr, kargs, block_idx_m, block_idx_n);
                EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_window, smem_ptr_ping);
            }
        }
    }

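RunFlatmm reads the scale configuration from two compile-time knobs: GranularityMN == -1 disables that side's scaling entirely, and GranularityK == 0 means a single scale per row (per-token) or per column (per-channel), which is applied in the epilogue rather than inside the K loop. A hypothetical compile-time restatement of that convention, inferred from the static_assert and DoEpiScale above:

// Illustrative only; not part of this header.
template <int GranularityMN, int GranularityK>
struct scale_mode
{
    static constexpr bool disabled    = (GranularityMN == -1); // no scale tensor on this side
    static constexpr bool in_epilogue = !disabled && (GranularityK == 0); // one scale per row/column
    static constexpr bool in_mainloop = !disabled && !in_epilogue; // blockwise MX scales in the K loop
};

static_assert(scale_mode<-1, 0>::disabled);    // scaling disabled
static_assert(scale_mode<1, 0>::in_epilogue);  // per-token / per-channel
static_assert(scale_mode<1, 32>::in_mainloop); // e.g. one e8m0 scale per 32 K elements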
    template <class ScaleM, class ScaleN>
    CK_TILE_DEVICE void operator()(FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()> kargs,
                                   int partition_idx = blockIdx.x) const
    {
        int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);

        do
        {
            const auto [iM, iN] =
                TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
            const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
            const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);

            const SplitKBatchOffset splitk_batch_offset(kargs);

            const auto a_ptr = static_cast<const ADataType*>(kargs.a_ptr) +
                               splitk_batch_offset.a_k_split_offset / APackedSize;
            const auto b_flat_ptr = static_cast<const BDataType*>(kargs.b_ptr) +
                                    splitk_batch_offset.b_k_split_offset / BPackedSize;
            EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);

            // allocate LDS
            __shared__ char smem_ptr_ping[Underlying::GetSmemPingSize()];
            __shared__ char smem_ptr_pong[Underlying::GetSmemPongSize()];

            // Use the default scheduler only when a single wave group is active.
            constexpr auto scheduler_type = (MXFlatmmPipeline::NumWaveGroups == 1);
            RunFlatmm<ScaleM, ScaleN, scheduler_type>(a_ptr,
                                                      b_flat_ptr,
                                                      kargs.ds_ptr,
                                                      e_ptr,
                                                      smem_ptr_ping,
                                                      smem_ptr_pong,
                                                      kargs,
                                                      splitk_batch_offset,
                                                      i_m,
                                                      i_n);
            partition_idx += gridDim.x;
        } while(UsePersistentKernel && partition_idx < total_work_tile_cnt);
    }
};

} // namespace ck_tile
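operator() implements a persistent kernel: when UsePersistentKernel is set, each block starts at its own blockIdx.x and strides by gridDim.x over the work tiles until all are consumed; otherwise the do/while body runs exactly once. A minimal host-side analogue of that loop, with a hypothetical process_tile callback:

// Illustrative only: plain analogue of the persistent do/while in operator().
template <class F>
void persistent_loop(int first_tile,   // blockIdx.x on the GPU
                     int grid_size,    // gridDim.x on the GPU
                     int total_tiles,  // TilePartitioner::GridSize(M, N)
                     bool use_persistent,
                     F&& process_tile)
{
    int tile = first_tile;
    do
    {
        process_tile(tile); // one MPerBlock x NPerBlock output tile
        tile += grid_size;  // grid-stride to the next unclaimed tile
    } while(use_persistent && tile < total_tiles);
}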