flatmm_kernel.hpp
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <string>
8 
9 #include "ck_tile/core.hpp"
10 #include "ck_tile/ops/common.hpp"
12 
13 namespace ck_tile {
14 struct FlatmmProblem
15 {
16  CK_TILE_HOST FlatmmProblem() = default;
17  CK_TILE_HOST FlatmmProblem(
18  index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_)
19  : M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), stride_C(stride_C_)
20  {
21  }
22 
23  index_t M;
24  index_t N;
25  index_t K;
26  index_t stride_A;
27  index_t stride_B;
28  index_t stride_C;
29 };
30 
31 template <int SharedGranularityMN, int SharedGranularityK = 0, typename ScaleType_ = float>
32 struct FlatmmScalePointer
33 {
34  using ScaleType = ScaleType_;
35  static constexpr int GranularityMN = SharedGranularityMN;
36  static constexpr int GranularityK = SharedGranularityK;
37 
38  const ScaleType* ptr;
39 
40  CK_TILE_HOST_DEVICE FlatmmScalePointer() = default;
41  CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_) : ptr(ptr_) {}
42  CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_, [[maybe_unused]] index_t length_)
43  : ptr(ptr_)
44  {
45  }
46 
47  CK_TILE_HOST_DEVICE FlatmmScalePointer operator+(index_t offset) const
48  {
49  FlatmmScalePointer ret;
50  if constexpr(GranularityMN == 0)
51  {
52  ret.ptr = ptr + offset / GranularityK;
53  }
54  else
55  {
57  }
58  return ret;
59  }
60 
61  CK_TILE_HOST_DEVICE ScaleType operator[](index_t i) const = delete;
62 };
63 
64 template <int SharedGranularityMN, typename ScaleType_>
65 struct FlatmmScalePointer<SharedGranularityMN, 0, ScaleType_>
66 {
67  using ScaleType = ScaleType_;
68  static constexpr int GranularityMN = SharedGranularityMN;
69  static constexpr int GranularityK = 0;
70 
71  static_assert(GranularityMN != 0);
72 
73  const ScaleType* ptr;
74  index_t length;
75 
76  CK_TILE_HOST_DEVICE FlatmmScalePointer() = default;
77  CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_) : ptr(ptr_), length(1) {}
79  : ptr(ptr_), length(length_)
80  {
81  }
82 
83  CK_TILE_HOST_DEVICE FlatmmScalePointer operator+(index_t offset) const
84  {
85  FlatmmScalePointer ret;
86  if constexpr(GranularityMN == 1)
87  {
88  ret.ptr = ptr + offset;
89  ret.length = length - offset;
90  }
91  else
92  {
93  ret.ptr = ptr + offset / GranularityMN;
94  ret.length = length - offset / GranularityMN;
95  }
96  return ret;
97  }
98 
99  CK_TILE_HOST_DEVICE ScaleType operator[](index_t i) const
100  {
101  // with additional oob check
102  if constexpr(GranularityMN == 1)
103  return i < length ? ptr[i] : 0;
104  else
105  return i / GranularityMN < length ? ptr[i / GranularityMN] : 0;
106  }
107 };
108 
109 // shared granularityMN = -1 means no scale
110 template <typename ScaleType_>
111 struct FlatmmScalePointer<-1, 0, ScaleType_>
112 {
113  using ScaleType = ScaleType_;
114  static constexpr int GranularityMN = -1;
115  static constexpr int GranularityK = 0;
116 
117  const ScaleType* ptr = nullptr;
118 
119  constexpr CK_TILE_HOST_DEVICE FlatmmScalePointer() = default;
120  constexpr CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType*) {}
121  constexpr CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType*, index_t) {}
122 
123  constexpr CK_TILE_HOST_DEVICE FlatmmScalePointer operator+(index_t) const
124  {
125  return FlatmmScalePointer{};
126  }
127  constexpr CK_TILE_HOST_DEVICE ScaleType operator[](index_t) const
128  {
129  return 1; // always return 1, it doesn't change the result
130  }
131 };
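The three FlatmmScalePointer variants above select the scaling mode at compile time: GranularityMN == -1 disables scaling entirely (operator[] always returns 1), a positive GranularityMN with GranularityK == 0 reads one scale per GranularityMN rows/columns with an out-of-bounds guard, and a non-zero GranularityK additionally tiles the scales along K. A small illustrative sketch of the per-token and no-scale flavors (host-side test code, not part of this header):

    using PerTokenScale = ck_tile::FlatmmScalePointer<1>;   // one scale per M row
    using NoScale       = ck_tile::FlatmmScalePointer<-1>;  // scaling disabled

    float scales[4] = {0.5f, 1.0f, 2.0f, 4.0f};
    PerTokenScale sm{scales, 4};   // pointer + number of valid scales
    auto sm2 = sm + 2;             // advance by two rows; remaining length becomes 2
    // sm2[0] == 2.0f, sm2[1] == 4.0f, sm2[2] == 0 (past the stored length)

    NoScale none{};                // none[i] == 1 for any i, so results are unchanged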
132 
133 template <index_t NumDTensor = 0>
134 struct BaseFlatmmHostArgs
135 {
136  CK_TILE_HOST BaseFlatmmHostArgs() = default;
137  CK_TILE_HOST BaseFlatmmHostArgs(const void* a_ptr_,
138  const void* b_ptr_,
139  const std::array<const void*, NumDTensor>& ds_ptr_,
140  void* e_ptr_,
141  index_t k_batch_,
142  index_t M_,
143  index_t N_,
144  index_t K_,
145  index_t stride_A_,
146  index_t stride_B_,
147  const std::array<index_t, NumDTensor>& stride_Ds_,
148  index_t stride_E_)
149  : a_ptr(a_ptr_),
150  b_ptr(b_ptr_),
151  ds_ptr(ds_ptr_),
152  e_ptr(e_ptr_),
153  M(M_),
154  N(N_),
155  K(K_),
156  stride_A(stride_A_),
157  stride_B(stride_B_),
158  stride_Ds(stride_Ds_),
159  stride_E(stride_E_),
160  k_batch(k_batch_)
161  {
162  }
163 
164  const void* a_ptr;
165  const void* b_ptr;
166  const std::array<const void*, NumDTensor> ds_ptr;
167  union
168  {
169  void* e_ptr;
170  void* c_ptr;
171  };
172  index_t M;
173  index_t N;
174  index_t K;
175  index_t stride_A;
176  index_t stride_B;
177  const std::array<index_t, NumDTensor> stride_Ds;
178  union
179  {
180  index_t stride_E;
181  index_t stride_C;
182  };
183 
184  index_t k_batch;
185 };
186 template <class ScaleM = FlatmmScalePointer<-1>,
187  class ScaleN = FlatmmScalePointer<-1>,
188  index_t NumDTensor = 0>
189 struct ScaleFlatmmHostArgs : public BaseFlatmmHostArgs<NumDTensor>
190 {
191  CK_TILE_HOST ScaleFlatmmHostArgs() = default;
192  CK_TILE_HOST ScaleFlatmmHostArgs(const void* a_ptr_,
193  const void* b_shuffle_ptr_,
194  const std::array<const void*, NumDTensor>& ds_ptr_,
195  void* c_ptr_,
196  index_t k_batch_,
197  index_t M_,
198  index_t N_,
199  index_t K_,
200  index_t stride_A_,
201  index_t stride_B_,
202  const std::array<index_t, NumDTensor>& stride_Ds_,
203  index_t stride_C_,
204  ScaleM scale_m_ = nullptr,
205  ScaleN scale_n_ = nullptr)
206  : BaseFlatmmHostArgs(a_ptr_,
207  b_shuffle_ptr_,
208  ds_ptr_,
209  c_ptr_,
210  k_batch_,
211  M_,
212  N_,
213  K_,
214  stride_A_,
215  stride_B_,
216  stride_Ds_,
217  stride_C_),
218  scale_m(scale_m_),
219  scale_n(scale_n_)
220  {
221  }
222  ScaleM scale_m = nullptr;
223  ScaleN scale_n = nullptr;
224 };
225 
226 template <int NumberTensor = 0>
229 
230 template <class ScaleM, class ScaleN, index_t NumDTensor = 0>
231 struct FlatmmKernelArgs
232 {
233  const void* a_ptr;
234  // const void* b_shuffle_ptr;
235  const void* b_ptr;
236  const std::array<const void*, NumDTensor> ds_ptr;
237  void* e_ptr;
238  index_t M;
239  index_t N;
240  index_t K;
241  index_t stride_A;
242  index_t stride_B;
243  std::array<index_t, NumDTensor> stride_Ds;
244  index_t stride_E;
245  index_t k_batch;
246  ScaleM scale_m_ptr = nullptr;
247  ScaleN scale_n_ptr = nullptr;
248 };
249 
250 template <typename TilePartitioner_, typename FlatmmPipeline_, typename EpiloguePipeline_>
251 struct FlatmmKernel
252 {
253  using TilePartitioner = remove_cvref_t<TilePartitioner_>;
254  using FlatmmPipeline = remove_cvref_t<FlatmmPipeline_>;
256  using BlockGemmShape = remove_cvref_t<typename FlatmmPipeline::BlockGemmShape>;
257  using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
258  using ALayout = remove_cvref_t<typename FlatmmPipeline::ALayout>;
259  using BLayout = remove_cvref_t<typename FlatmmPipeline::BLayout>;
260  using ELayout = remove_cvref_t<typename FlatmmPipeline::CLayout>;
261  using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
262  using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
263  static constexpr index_t kBlockSize = FlatmmPipeline::BlockSize;
264  static constexpr bool UsePersistentKernel = FlatmmPipeline::UsePersistentKernel;
265 
266  using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
267  using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
268  // Below type is actually accumulation data type - the output of block GEMM.
269  using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
270 
271  static constexpr index_t NumDTensor = DsDataType::size();
272 
273  static constexpr auto I0 = number<0>();
274  static constexpr auto I1 = number<1>();
275  static constexpr auto I2 = number<2>();
276  static constexpr auto I3 = number<3>();
277 
278  static_assert(DsLayout::size() == DsDataType::size(),
279  "The size of DsLayout and DsDataType should be the same");
280  // using KernelArgs = FlatmmKernelArgs<DsLayout::size()>;
281 
282  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
283  {
284  // clang-format off
285  return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
286  // clang-format on
287  }
288 
289  CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
290  {
291  assert(!UsePersistentKernel);
292  return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
293  }
294 
295  template <class ScaleM, class ScaleN>
296  CK_TILE_HOST static constexpr auto
297  GridSize(const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs)
298  {
299  if constexpr(UsePersistentKernel)
300  {
301  hipDeviceProp_t prop;
302  int deviceId = 0; // default device
303 
304  constexpr int block_size = FlatmmKernel::BlockSize().x;
305  int dync_smem_size = 0;
306  int maxActiveBlocksPerCU = 0;
307 
308  [[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);
309 
310  e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
311  &maxActiveBlocksPerCU,
312  reinterpret_cast<void*>(
313  kentry<1, FlatmmKernel, FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>>),
314  block_size,
315  dync_smem_size);
316 
317  const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
318  const int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
319 
320  // std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
321  // << ", persistent_block_size: " << persistent_block_size
322  // << ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl;
323 
324  assert(kargs.k_batch == 1);
325  return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kargs.k_batch);
326  }
327  else
328  {
329  return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch);
330  }
331  }
332 
333  CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
334 
335  template <class ScaleM, class ScaleN>
336  CK_TILE_HOST static constexpr FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>
337  MakeKernelArgs(const ScaleFlatmmHostArgs<ScaleM, ScaleN, DsDataType::size()>& hostArgs)
338  {
339  return {hostArgs.a_ptr,
340  hostArgs.b_ptr,
341  hostArgs.ds_ptr,
342  hostArgs.e_ptr,
343  hostArgs.M,
344  hostArgs.N,
345  hostArgs.K,
346  hostArgs.stride_A,
347  hostArgs.stride_B,
348  hostArgs.stride_Ds,
349  hostArgs.stride_E,
350  hostArgs.k_batch,
351  hostArgs.scale_m,
352  hostArgs.scale_n};
353  }
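GridSize() above has two modes: the non-persistent overload launches one workgroup per output tile (with k_batch workgroups along Z), while the persistent overload sizes the grid from the CU count times the measured occupancy and lets operator() loop over tiles. A minimal host-side sketch of how the pieces compose (illustrative only; MyFlatmmKernel and the device pointers are assumptions, not part of this header):

    // Assumes MyFlatmmKernel = FlatmmKernel<SomeTilePartitioner, SomePipeline, SomeEpilogue>
    // with no D tensors, and a_dev / b_shuffle_dev / c_dev pointing to device memory.
    void prepare_flatmm(const void* a_dev, const void* b_shuffle_dev, void* c_dev,
                        ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K,
                        ck_tile::index_t stride_A, ck_tile::index_t stride_B,
                        ck_tile::index_t stride_C)
    {
        using Kernel = MyFlatmmKernel;

        ck_tile::ScaleFlatmmHostArgs<> host_args(a_dev, b_shuffle_dev,
                                                 {},            // no D tensor pointers
                                                 c_dev,
                                                 /*k_batch*/ 1,
                                                 M, N, K,
                                                 stride_A, stride_B,
                                                 {},            // no D tensor strides
                                                 stride_C);

        auto kargs = Kernel::MakeKernelArgs(host_args);
        if(!Kernel::IsSupportedArgument(kargs))
            return; // unsupported shape / padding / vector-size combination

        [[maybe_unused]] const dim3 grids  = Kernel::GridSize(kargs); // persistent-aware overload
        [[maybe_unused]] const dim3 blocks = Kernel::BlockSize();
        // launch through kentry<...> or ck_tile's launch helpers (see the sketch at the end of the file)
    }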
354 
355  CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPingSize()
356  {
357  return max(FlatmmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
358  }
359  CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPongSize()
360  {
361  return FlatmmPipeline::GetSmemSize();
362  }
363 
364  struct SplitKBatchOffset
365  {
366  template <class KernelArgs>
367  __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
368  {
369  constexpr auto N1 = BlockGemmShape::WarpTile::at(number<1>{});
370  constexpr auto K1 = BlockGemmShape::WarpTile::at(number<2>{});
371  const index_t K_t = kargs.k_batch * K1;
372  const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
373 
374  if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
375  {
376  a_k_split_offset = k_id * KRead;
377  }
378  else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
379  {
380  a_k_split_offset = k_id * KRead * kargs.stride_A;
381  }
382 
383  if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
384  {
385  b_k_split_offset = k_id * KRead * kargs.stride_B * N1;
386  }
387  else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
388  {
389  b_k_split_offset = k_id * KRead * N1;
390  }
391 
392  if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
393  {
394  splitted_k = KRead;
395  }
396  else
397  {
398  splitted_k = kargs.K - KRead * (kargs.k_batch - 1);
399  }
400  }
401 
402  index_t a_k_split_offset;
403  index_t b_k_split_offset;
404  index_t splitted_k;
405  };
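Worked example of the split-K bookkeeping above (values chosen for illustration): with K = 1000, k_batch = 4 and a warp-tile K1 = 32, K_t = 4 * 32 = 128 and KRead = ceil(1000 / 128) * 32 = 256, so batches 0..2 each consume splitted_k = 256 elements of K while the last batch takes the remainder 1000 - 3 * 256 = 232; the A/B offsets advance by k_id * KRead, scaled by the stride appropriate for each tensor's layout.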
406 
407  template <class KernelArgs>
408  CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs)
409  {
410  if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
412  {
413  if(kargs.k_batch != 1)
414  {
415  std::cerr << "Conditions not met for Kbatch >1 !" << std::endl;
416  return false;
417  }
418  }
419  if constexpr(UsePersistentKernel)
420  {
421  if(kargs.k_batch != 1)
422  {
423  std::cerr << "Persistent mode doesn't support Kbatch >1 !" << std::endl;
424  return false;
425  }
426  }
427 
428  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
429  {
430  if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false)
431  {
432  std::cerr << "Can't support K that is not a multiple of KPerBlock"
433  " without padding!"
434  << std::endl;
435  return false;
436  }
437  if(kargs.K % FlatmmPipeline::GetVectorSizeA() != 0)
438  {
439  std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl;
440  return false;
441  }
442  }
443  else
444  {
445  if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
446  {
447  std::cerr << "Can't support M that is not a multiple of MPerBlock"
448  " without padding!"
449  << std::endl;
450  return false;
451  }
452  if(kargs.M % FlatmmPipeline::GetVectorSizeA() != 0)
453  {
454  std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl;
455  return false;
456  }
457  }
458 
459  if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
460  {
461  if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
462  {
463  std::cerr << "Can't support N that is not a multiple of NPerBlock"
464  " without padding!"
465  << std::endl;
466  return false;
467  }
468  if(kargs.N % FlatmmPipeline::GetVectorSizeB() != 0)
469  {
470  std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl;
471  return false;
472  }
473  }
474  else
475  {
476  if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false)
477  {
478  std::cerr << "Can't support K that is not a multiple of KPerBlock"
479  " without padding!"
480  << std::endl;
481  return false;
482  }
483  if(kargs.K % FlatmmPipeline::GetVectorSizeB() != 0)
484  {
485  std::cerr << "K is not a multiple of vector load size for B tensor!" << std::endl;
486  return false;
487  }
488  }
489 
490  bool DTensorIsValid = true;
491  static_for<0, NumDTensor, 1>{}([&](auto index) {
492  using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
493  if(std::is_same_v<DiLayout, ELayout> == false)
494  {
495  DTensorIsValid = false;
496  }
497  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
498  {
499  if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
500  {
501  CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of "
502  "NPerBlock without padding!");
503  DTensorIsValid = false;
504  }
505  if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0)
506  {
507  CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!");
508  DTensorIsValid = false;
509  }
510  }
511  else
512  {
513  if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
514  {
515  CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of "
516  "MPerBlock without padding!");
517 
518  DTensorIsValid = false;
519  }
520  if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0)
521  {
522  CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!");
523  DTensorIsValid = false;
524  }
525  }
526  });
527 
528  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
529  {
530  if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
531  {
532  std::cerr << "Can't support N that is not a multiple of NPerBlock"
533  " without padding!"
534  << std::endl;
535  return false;
536  }
537  if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
538  {
539  std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
540  return false;
541  }
542  }
543  else
544  {
545  if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
546  {
547  std::cerr << "Can't support M that is not a multiple of MPerBlock"
548  " without padding!"
549  << std::endl;
550  return false;
551  }
552  if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
553  {
554  std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
555  return false;
556  }
557  }
558  return DTensorIsValid;
559  }
560 
561  template <typename KernelArgs>
562  CK_TILE_DEVICE static auto MakeABlockWindow(const ADataType* a_ptr,
563  const KernelArgs& kargs,
564  const index_t k_size,
565  const index_t block_idx_m)
566  {
567  // Step 1: Create tensor view
568  const auto& a_tensor_view = [&]() {
569  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
570  {
571  return make_naive_tensor_view<address_space_enum::global>(
572  a_ptr,
573  make_tuple(kargs.M, k_size),
574  make_tuple(kargs.stride_A, 1),
575  number<FlatmmPipeline::GetVectorSizeA()>{},
576  number<1>{});
577  }
578  else
579  {
580  return make_naive_tensor_view<address_space_enum::global>(
581  a_ptr,
582  make_tuple(k_size, kargs.M),
583  make_tuple(kargs.stride_A, 1),
584  number<FlatmmPipeline::GetVectorSizeA()>{},
585  number<1>{});
586  }
587  }();
588 
589  // Step 2: Create padded view
590  const auto& a_pad_view = [&]() {
591  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
592  {
593  return pad_tensor_view(a_tensor_view,
597  }
598  else
599  {
600  return pad_tensor_view(a_tensor_view,
604  }
605  }();
606 
607  // Step 3: Create tile window
608  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
609  {
610  return make_tile_window(a_pad_view,
613  {block_idx_m, 0});
614  }
615  else
616  {
617  return make_tile_window(a_pad_view,
620  {0, block_idx_m});
621  }
622  }
623 
624  template <typename KernelArgs>
625  CK_TILE_DEVICE static auto MakeBFlatBlockWindow(const BDataType* b_flat_ptr,
626  const KernelArgs& kargs,
627  const index_t block_idx_n)
628  {
629  // Step 1: Create tensor view
630  index_t kFlatK =
631  FlatmmPipeline::flatKPerWarp * (kargs.K / BlockGemmShape::WarpTile::at(I2));
632  index_t kFlatN = kargs.N * kargs.K / kFlatK;
633 
634  const auto& b_flat_tensor_view = make_naive_tensor_view<address_space_enum::global>(
635  b_flat_ptr,
636  make_tuple(kFlatN, kFlatK),
637  make_tuple(kFlatK, 1),
638  number<FlatmmPipeline::GetVectorSizeB()>{},
639  number<1>{});
640 
641  // Step 2: No padding needed for b_flat
642  // Step 3: Create tile window
643  return make_tile_window(
644  b_flat_tensor_view,
647  {static_cast<int>(block_idx_n / BlockGemmShape::WarpTile::at(I1)), 0});
648  }
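To see what the flat-B view above describes, assume flatKPerWarp equals the number of B elements in one warp tile, i.e. WarpTile N (I1) times WarpTile K (I2) (an assumption about the pipeline, not stated in this file). Then kFlatK = WarpTile_N * K and kFlatN = N / WarpTile_N, so each row of the flat view holds the pre-shuffled B data for one warp-tile-wide column strip across the whole K dimension, and the window origin block_idx_n / WarpTile_N selects the strip for this block. For example, with N = 512, K = 256 and a 16x32 (N x K) warp tile, flatKPerWarp = 512, kFlatK = 16 * 256 = 4096 and kFlatN = 512 * 256 / 4096 = 32.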
649 
650  template <typename KernelArgs>
651  CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
652  const KernelArgs& kargs,
653  const index_t block_idx_m,
654  const index_t block_idx_n)
655  {
656  // Step 1: Create tensor views
657  const auto& ds_tensor_view = generate_tuple(
658  [&](auto i) {
659  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
660  using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
661  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
662  {
663  return make_naive_tensor_view<address_space_enum::global>(
664  static_cast<const DDataType_*>(ds_ptr[i]),
665  make_tuple(kargs.M, kargs.N),
666  make_tuple(kargs.stride_Ds[i], 1),
667  number<EpiloguePipeline::GetVectorSizeD(i)>{},
668  number<1>{});
669  }
670  else
671  {
672  return make_naive_tensor_view<address_space_enum::global>(
673  static_cast<const DDataType_*>(ds_ptr[i]),
674  make_tuple(kargs.N, kargs.M),
675  make_tuple(kargs.stride_Ds[i], 1),
676  number<EpiloguePipeline::GetVectorSizeD(i)>{},
677  number<1>{});
678  }
679  },
680  number<NumDTensor>{});
681 
682  // Step 2: Create padded views
683  const auto& ds_pad_view = generate_tuple(
684  [&](auto i) {
685  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
686  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
687  {
688  return pad_tensor_view(ds_tensor_view[i],
692  }
693  else
694  {
695  return pad_tensor_view(ds_tensor_view[i],
699  }
700  },
701  number<NumDTensor>{});
702 
703  // Step 3: Create tile windows
704  return generate_tuple(
705  [&](auto i) {
706  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
707  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
708  {
709  return make_tile_window(ds_pad_view[i],
712  {block_idx_m, block_idx_n});
713  }
714  else
715  {
716  return make_tile_window(ds_pad_view[i],
719  {block_idx_n, block_idx_m});
720  }
721  },
722  number<NumDTensor>{});
723  }
724 
725  template <memory_operation_enum DstInMemOp = memory_operation_enum::set, typename KernelArgs>
726  CK_TILE_DEVICE static auto MakeEBlockWindow(EDataType* e_ptr,
727  const KernelArgs& kargs,
728  const index_t block_idx_m,
729  const index_t block_idx_n)
730  {
731  // Step 1: Create tensor view
732  const auto& e_tensor_view = [&]() {
733  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
734  {
735  return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
736  e_ptr,
737  make_tuple(kargs.M, kargs.N),
738  make_tuple(kargs.stride_E, 1),
739  number<EpiloguePipeline::GetVectorSizeC()>{},
740  number<1>{});
741  }
742  else
743  {
744  return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
745  e_ptr,
746  make_tuple(kargs.N, kargs.M),
747  make_tuple(kargs.stride_E, 1),
748  number<1>{},
749  number<1>{});
750  }
751  }();
752 
753  // Step 2: Create padded view
754  const auto& e_pad_view = [&]() {
755  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
756  {
757  return pad_tensor_view(e_tensor_view,
761  }
762  else
763  {
764  return pad_tensor_view(e_tensor_view,
768  }
769  }();
770 
771  // Step 3: Create tile window
772  return make_tile_window(
773  e_pad_view,
775  {block_idx_m, block_idx_n});
776  }
777 
778  template <typename KernelArgs>
779  CK_TILE_DEVICE static auto MakeScaleMWindow(const KernelArgs& kargs,
780  const SplitKBatchOffset& splitk_batch_offset,
781  const index_t block_idx_m)
782  {
783  constexpr int ScaleGranularityM = decltype(kargs.scale_m_ptr)::GranularityMN;
784  constexpr int ScaleGranularityKA = decltype(kargs.scale_m_ptr)::GranularityK;
785 
786  auto scale_stride_m = ScaleGranularityM == 0 ? 0 // per-tensor scale
787  : 1; // per-token scale
788 
789  // Step 1: Create tensor view
790  const auto scale_m_view = make_naive_tensor_view<address_space_enum::global>(
791  kargs.scale_m_ptr.ptr,
792  make_tuple(kargs.M / ScaleGranularityM,
793  ScaleGranularityKA == 0
794  ? 1
795  : (splitk_batch_offset.splitted_k / ScaleGranularityKA)),
796  make_tuple(scale_stride_m, 0),
797  number < ScaleGranularityM == 1 ? FlatmmPipeline::GetVectorSizeA() : 1 > {},
798  number<1>{});
799 
800  // Step 2: Create tile window
801  return make_tile_window(scale_m_view,
803  number < ScaleGranularityKA == 0
804  ? TilePartitioner::NPerBlock
805  : TilePartitioner::KPerBlock > {}),
806  {block_idx_m, 0});
807  }
808 
809  template <typename KernelArgs>
810  CK_TILE_DEVICE static auto MakeScaleNWindow(const KernelArgs& kargs,
811  const SplitKBatchOffset& splitk_batch_offset,
812  const index_t block_idx_n)
813  {
814  constexpr int ScaleGranularityN = decltype(kargs.scale_n_ptr)::GranularityMN;
815  constexpr int ScaleGranularityKB = decltype(kargs.scale_n_ptr)::GranularityK;
816 
817  auto scale_stride_n = ScaleGranularityN == 0 ? 0 // per-tensor scale
818  : 1; // per-channel scale
819 
820  // Step 1: Create tensor view
821  const auto scale_n_view = make_naive_tensor_view<address_space_enum::global>(
822  kargs.scale_n_ptr.ptr,
823  make_tuple(
824  ScaleGranularityKB == 0 ? 1 : (splitk_batch_offset.splitted_k / ScaleGranularityKB),
825  kargs.N / ScaleGranularityN),
826  make_tuple(0, scale_stride_n),
827  number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {},
828  number<1>{});
829 
830  // Step 2: Create tile window
831  return make_tile_window(scale_n_view,
832  make_tuple(number < ScaleGranularityKB == 0
833  ? TilePartitioner::MPerBlock
834  : TilePartitioner::KPerBlock > {},
836  {0, block_idx_n});
837  }
838 
839  template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
840  CK_TILE_DEVICE static void
841  RunFlatmm(const ADataType* a_ptr,
842  const BDataType* b_flat_ptr,
843  const std::array<const void*, NumDTensor>& ds_ptr,
844  EDataType* e_ptr,
845  void* smem_ptr_ping,
846  void* smem_ptr_pong,
847  const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs,
848  const SplitKBatchOffset& splitk_batch_offset,
849  const index_t block_idx_m,
850  const index_t block_idx_n)
851  {
852  // Create block windows using specialized methods
853  const auto& a_block_window =
854  MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
855  const auto& b_flat_block_window = MakeBFlatBlockWindow(b_flat_ptr, kargs, block_idx_n);
856  const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
857  const auto& scale_m_window = MakeScaleMWindow(kargs, splitk_batch_offset, block_idx_m);
858  const auto& scale_n_window = MakeScaleNWindow(kargs, splitk_batch_offset, block_idx_n);
859 
860  const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
861 
862  // Run GEMM cooperatively by whole workgroup.
863  const auto& c_block_tile = FlatmmPipeline{}.template operator()(
864  a_block_window, b_flat_block_window, num_loop, smem_ptr_ping, smem_ptr_pong);
865 
866  // Run Epilogue Pipeline with k_batch dispatching
867  if constexpr(ScaleM::GranularityMN != -1 || ScaleN::GranularityMN != -1)
868  {
869  if(kargs.k_batch == 1)
870  {
871  auto e_block_window = MakeEBlockWindow<memory_operation_enum::set>(
872  e_ptr, kargs, block_idx_m, block_idx_n);
873  EpiloguePipeline{}
874  .template operator()<decltype(e_block_window),
875  decltype(c_block_tile),
876  decltype(ds_block_window)>(e_block_window,
877  c_block_tile,
878  ds_block_window,
879  smem_ptr_ping,
880  scale_m_window,
881  scale_n_window);
882  }
883  else
884  {
885  auto e_block_window = MakeEBlockWindow<memory_operation_enum::atomic_add>(
886  e_ptr, kargs, block_idx_m, block_idx_n);
887  EpiloguePipeline{}
888  .template operator()<decltype(e_block_window),
889  decltype(c_block_tile),
890  decltype(ds_block_window)>(e_block_window,
891  c_block_tile,
892  ds_block_window,
893  smem_ptr_ping,
894  scale_m_window,
895  scale_n_window);
896  }
897  }
898  else if(UseDefaultScheduler || (get_warp_id() == 0))
899  {
900  if(kargs.k_batch == 1)
901  {
902  auto e_block_window = MakeEBlockWindow<memory_operation_enum::set>(
903  e_ptr, kargs, block_idx_m, block_idx_n);
904  EpiloguePipeline{}
905  .template operator()<decltype(e_block_window),
906  decltype(c_block_tile),
907  decltype(ds_block_window)>(
908  e_block_window, c_block_tile, ds_block_window, smem_ptr_ping);
909  }
910  else
911  {
912  auto e_block_window = MakeEBlockWindow<memory_operation_enum::atomic_add>(
913  e_ptr, kargs, block_idx_m, block_idx_n);
914  EpiloguePipeline{}
915  .template operator()<decltype(e_block_window),
916  decltype(c_block_tile),
917  decltype(ds_block_window)>(
918  e_block_window, c_block_tile, ds_block_window, smem_ptr_ping);
919  }
920  }
921  }
922 
923  template <class ScaleM, class ScaleN>
924  CK_TILE_DEVICE void operator()(FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()> kargs,
925  int partition_idx = blockIdx.x) const
926  {
927  int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
928 
929  do
930  {
931  const auto [iM, iN] =
932  TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
933  const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
934  const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
935 
936  const SplitKBatchOffset splitk_batch_offset(kargs);
937  // options
938  const ADataType* a_ptr =
939  static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
940  const BDataType* b_flat_ptr =
941  static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
942  EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
943 
944  // allocate LDS
945  __shared__ char smem_ptr_ping[GetSmemPingSize()];
946  __shared__ char smem_ptr_pong[GetSmemPongSize()];
947 
948  if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
950  {
951  constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
952  RunFlatmm<ScaleM, ScaleN, scheduler_type>(a_ptr,
953  b_flat_ptr,
954  kargs.ds_ptr,
955  e_ptr,
956  smem_ptr_ping,
957  smem_ptr_pong,
958  kargs,
959  splitk_batch_offset,
960  i_m,
961  i_n);
962  }
963  partition_idx += gridDim.x;
964  } while(UsePersistentKernel && partition_idx < total_work_tile_cnt);
965  }
966 };
967 
968 } // namespace ck_tile
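For completeness, a hedged launch sketch built only from the entry points this header exposes (GridSize, BlockSize, MakeKernelArgs) and the kentry template GridSize() already references. The helper name and the stream handling are assumptions; production code would normally go through ck_tile's kernel-launch helpers:

    #include <hip/hip_runtime.h>

    // Kernel is assumed to be a FlatmmKernel<...> instantiation and KArgs the type
    // returned by Kernel::MakeKernelArgs.
    template <typename Kernel, typename KArgs>
    void launch_flatmm(const KArgs& kargs, hipStream_t stream)
    {
        const dim3 grids  = Kernel::GridSize(kargs); // handles persistent and split-K grids
        const dim3 blocks = Kernel::BlockSize();
        // kentry<1, Kernel, KArgs> is the same entry point GridSize() queries occupancy for;
        // it forwards kargs to Kernel::operator() on the device.
        ck_tile::kentry<1, Kernel, KArgs><<<grids, blocks, 0, stream>>>(kargs);
    }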