gridwise_gemm_xdl_cshuffle_v3.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 #include "ck/utility/env.hpp"
17 
18 namespace ck {
19 
20 // Currently there is no elegant way to put the single-LDS-buffer and double-LDS-buffer pipelines in the same
21 // kernel function. Blockers:
22 // 1. Two separate declarations of the __shared__ pointer are the key to making sure data accesses operate on
23 // two distinct LDS chunks.
24 // 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. A/B and C may not use the same LDS
25 // buffer when __shared__ is declared inside the block GEMM pipeline.
26 template <typename GridwiseGemm,
27  bool HasMainKBlockLoop,
28  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
29  index_t MinimumOccupancy = 1,
30  TailNumber TailNum = TailNumber::Full>
31 __global__ void
32 #if CK_USE_LAUNCH_BOUNDS
33 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
34 #endif
35  // __attribute__((amdgpu_waves_per_eu(1, 1)))
36  kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
37 {
38 #if defined(__gfx9__) || defined(__gfx12__) || defined(__gfx11__)
39  if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
40  {
41  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
42 
43  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
44 
45  GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
46  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
47  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
48  karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
49  p_shared,
50  karg);
51  }
52 #else
53  ignore = karg;
54 #endif // end of if defined(__gfx9__) || defined(__gfx12__) || defined(__gfx11__)
55 }
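// A minimal host-side launch sketch (illustrative only; the GridwiseGemm instantiation, stream
// handling and the pipeline/tail selection below are assumptions, not the canonical device-op
// wrapper path of the library):
//
//   using Gemm = /* a fully specialized gridwise GEMM from this header */;
//   auto karg  = typename Gemm::Argument(p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, /*k_batch=*/1);
//   if(Gemm::CheckValidity(karg))
//   {
//       const auto [gdx, gdy, gdz] = Gemm::CalculateGridSize(karg.M, karg.N, karg.KBatch);
//       kernel_gemm_xdl_cshuffle_v3<Gemm, /*HasMainKBlockLoop=*/true, InMemoryDataOperationEnum::Set>
//           <<<dim3(gdx, gdy, gdz), dim3(/*BlockSize used to instantiate Gemm*/), 0, stream>>>(karg);
//   }
//
// In practice HasMainKBlockLoop / TailNum would be derived from CalculateHasMainKBlockLoop and
// CalculateKBlockLoopTailNum (defined further below).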
56 
57 template <typename GridwiseGemm,
58  bool HasMainKBlockLoop,
59  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
60  index_t MinimumOccupancy = 1,
61  TailNumber TailNum = TailNumber::Full>
62 __global__ void
63 #if CK_USE_LAUNCH_BOUNDS
64 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
65 #endif
66  // __attribute__((amdgpu_waves_per_eu(1, 1)))
67  kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
68 {
69 #if defined(__gfx9__) || defined(__gfx12__) || defined(__gfx11__)
70  // Passing two LDS pointers is the key to telling the compiler that ds_read/ds_write
71  // operate on different LDS chunks at the same time without an ordering dependency
72  if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
73  {
74  __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
75  __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
76 
77  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
78 
79  GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
80  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
81  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
82  karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
83  p_shared_0,
84  p_shared_1,
85  karg);
86  }
87 #else
88  ignore = karg;
89 #endif // end of if defined(__gfx9__) || defined(__gfx12__) || defined(__gfx11__)
90 }
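// The _2lds variant is intended for block-GEMM pipelines that double-buffer LDS: the A/B tiles
// alternate between p_shared_0 and p_shared_1 (the ping/pong buffers built in Run_2Lds below),
// so global loads into one buffer can overlap MFMA reads from the other.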
91 
197 template <typename ALayout,
198  typename BLayout,
199  typename CLayout,
200  typename ADataType,
201  typename BDataType,
202  typename AccDataType,
203  typename CShuffleDataType,
204  typename CDataType,
205  typename AElementwiseOperation,
206  typename BElementwiseOperation,
207  typename CElementwiseOperation,
209  index_t BlockSize,
210  index_t MPerBlock,
211  index_t NPerBlock,
212  index_t KPerBlock,
213  index_t AK1Value,
214  index_t BK1Value,
215  index_t MPerXdl,
216  index_t NPerXdl,
217  index_t MXdlPerWave,
218  index_t NXdlPerWave,
219  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
220  typename ABlockTransferThreadClusterArrangeOrder,
221  typename ABlockTransferSrcAccessOrder,
222  index_t ABlockTransferSrcVectorDim,
223  index_t ABlockTransferSrcScalarPerVector,
224  index_t ABlockTransferDstScalarPerVector_AK1,
225  bool AThreadTransferSrcResetCoordinateAfterRun,
226  index_t ABlockLdsExtraM,
227  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
228  typename BBlockTransferThreadClusterArrangeOrder,
229  typename BBlockTransferSrcAccessOrder,
230  index_t BBlockTransferSrcVectorDim,
231  index_t BBlockTransferSrcScalarPerVector,
232  index_t BBlockTransferDstScalarPerVector_BK1,
233  bool BThreadTransferSrcResetCoordinateAfterRun,
234  index_t BBlockLdsExtraN,
235  index_t CShuffleMXdlPerWavePerShuffle,
236  index_t CShuffleNXdlPerWavePerShuffle,
237  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
238  index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
241  typename ComputeTypeA = CDataType,
242  typename ComputeTypeB = ComputeTypeA,
243  bool PermuteA = false,
244  bool PermuteB = false,
245  bool DoElementwiseBeforeCShuffle = false>
246 struct GridwiseGemm_xdl_cshuffle_v3
247 {
248  static constexpr auto I0 = Number<0>{};
249  static constexpr auto I1 = Number<1>{};
250  static constexpr auto I2 = Number<2>{};
251  static constexpr auto I3 = Number<3>{};
252  static constexpr auto I4 = Number<4>{};
253  static constexpr auto I5 = Number<5>{};
254  static constexpr auto I6 = Number<6>{};
255  static constexpr auto I7 = Number<7>{};
256 
257  // K1 should be Number<...>
258  static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
259  static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
260  static constexpr auto AK1Number = Number<AK1Value>{};
261  static constexpr auto BK1Number = Number<BK1Value>{};
262 
263  static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
264  static constexpr bool is_single_rate_mfma =
266  lcm_AK1_BK1 <= 4) ||
268  // gfx950 double-rate mfma16x16 requires at least 128 KPerBlock to consume
270  KPerBlock < 128 && MPerXdl == 16))
271  ? true
272  : false;
273  static constexpr auto is_scale_mfma = false;
274  static constexpr index_t KPack =
276  MfmaSelector<ComputeTypeA,
277  MPerXdl,
278  NPerXdl,
279  ComputeTypeA,
281  is_scale_mfma>::selected_mfma.k_per_blk);
282 
284 
285  static constexpr index_t APackedSize = []() {
287  return 2;
288  else
289  return 1;
290  }();
291 
292  static constexpr index_t BPackedSize = []() {
294  return 2;
295  else
296  return 1;
297  }();
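 // APackedSize/BPackedSize are 2 when the corresponding data type packs two elements per byte
 // (e.g. pk_i4_t) and 1 otherwise; byte counts and K offsets below are divided by them so that
 // element-space sizes translate correctly to bytes.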
298 
299  __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
300  {
301  return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
302  }
303 
304  __host__ static auto CalculateMPadded(index_t M)
305  {
306  return math::integer_least_multiple(M, MPerBlock);
307  }
308 
309  __host__ static auto CalculateNPadded(index_t N)
310  {
311  return math::integer_least_multiple(N, NPerBlock);
312  }
313 
314  __host__ static auto CalculateKPadded(index_t K)
315  {
316  return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
317  }
318 
319  __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
320  {
321  auto K_t = K_Batch * KPerBlock;
322  return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
323  }
324 
325  __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
326  {
327  auto K_t = K_Batch * KPerBlock;
328  return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
329  }
330 
331  __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
332  {
333  auto K_t = K_Batch * KPerBlock;
334  return (K + K_t - 1) / K_t * KPerBlock;
335  }
336 
337  __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
338  {
339  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
340  auto K_t = K_Batch * KReadVec;
341  return (K + K_t - 1) / K_t * KReadVec;
342  }
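 // Worked example (illustrative numbers): with K = 1000, KPerBlock = 64, AK1 = BK1 = 8 and
 // K_Batch = 4, CalculateKRead gives ceil(1000 / (4 * 8)) * 8 = 256, so k-batches 0..2 each
 // consume 256 elements of K and the last batch consumes the remaining 1000 - 3 * 256 = 232
 // (see SplitKBatchOffset below); CalculateAK0Padded gives ceil(1000 / (4 * 64)) * (64 / 8) = 32.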
343 
344  __host__ static auto CalculateMBlock(index_t M)
345  {
346  return math::integer_divide_ceil(M, MPerBlock);
347  }
348 
349  __host__ static auto CalculateNBlock(index_t N)
350  {
351  return math::integer_divide_ceil(N, NPerBlock);
352  }
353 
354  template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
355  __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
356  {
357  constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
358  constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
359 
361  TileDesc_K0_MN_K1{},
367  }
368 
369  __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
370  index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
371  {
372  const auto a_grid_desc_mraw_kraw = [&]() {
373  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
374  {
375  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
376  }
377  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
378  {
379  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
380  }
381  }();
382 
384 
385  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
386  GemmSpec == GemmSpecialization::MNKPadding)
387  {
388  // pad both M and K
389  const auto a_grid_desc_m_k =
390  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
392  make_right_pad_transform(K, KPad - K)),
395 
396  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
397  a_grid_desc_m_k,
402 
403  return a_grid_desc_ak0_m_ak1;
404  }
405  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
406  GemmSpec == GemmSpecialization::MNPadding)
407  {
408  // pad M, but not K
409  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
410  a_grid_desc_mraw_kraw,
412  make_right_pad_transform(M, MPad - M)),
415 
416  return a_grid_desc_ak0_m_ak1;
417  }
418  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
419  GemmSpec == GemmSpecialization::NKPadding)
420  {
421  // pad K, but not M
422  const auto a_grid_desc_m_k = transform_tensor_descriptor(
423  a_grid_desc_mraw_kraw,
427 
428  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
429  a_grid_desc_m_k,
434 
435  return a_grid_desc_ak0_m_ak1;
436  }
437  else
438  {
439  // not pad M or K
440  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
441  a_grid_desc_mraw_kraw,
446 
447  return a_grid_desc_ak0_m_ak1;
448  }
449  }
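 // In all four branches above, the raw (M, K) descriptor is first padded according to GemmSpec
 // (M and/or K, or neither) and then reshaped so that K is split into (AK0, AK1), with AK1 as
 // the innermost, contiguous dimension, yielding the (AK0, M, AK1) layout the A block copy expects.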
450 
451  __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
452  index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
453  {
454  const auto b_grid_desc_nraw_kraw = [&]() {
456  {
457  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
458  }
460  {
461  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
462  }
463  }();
464 
466 
467  static_assert(!(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> &&
468  GemmSpec != GemmSpecialization::Default),
469  "pk_i4_t does not support padding");
470 
471  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
472  GemmSpec == GemmSpecialization::MNKPadding)
473  {
474  // pad both N and K
475  const auto b_grid_desc_n_k =
476  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
478  make_right_pad_transform(K, KPad - K)),
481 
482  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
483  b_grid_desc_n_k,
488 
489  return b_grid_desc_bk0_n_bk1;
490  }
491  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
492  GemmSpec == GemmSpecialization::MNPadding)
493  {
494  // pad N, but not K
495  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
496  b_grid_desc_nraw_kraw,
498  make_right_pad_transform(N, NPad - N)),
501 
502  return b_grid_desc_bk0_n_bk1;
503  }
504  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
505  GemmSpec == GemmSpecialization::MKPadding)
506  {
507  // pad K, but not N
508  const auto b_grid_desc_n_k = transform_tensor_descriptor(
509  b_grid_desc_nraw_kraw,
513 
514  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
515  b_grid_desc_n_k,
520 
521  return b_grid_desc_bk0_n_bk1;
522  }
523  else
524  {
525  if constexpr(!PermuteB)
526  {
527  // not pad N or K
528  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
529  b_grid_desc_nraw_kraw,
534 
535  return b_grid_desc_bk0_n_bk1;
536  }
537  else
538  {
539  // Pre-shuffled Weight
540  // BGlobal[K / KPerBlock, N, KPerBlock / K1, K1] -> BTile[K / K1, N, K1]
541  constexpr index_t BK01 = KPerBlock / BK1Value;
542  const index_t BK0_ = StrideB / BK1Value;
543  const index_t BK00 = BK0_ / BK01;
544 
545  const auto b_grid_desc_bk00_n_bk01_bk1_permute =
546  make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value));
547 
548  const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor(
549  b_grid_desc_bk00_n_bk01_bk1_permute,
552  make_pass_through_transform(BK1Value)),
555 
556  return b_grid_desc_bk0_n_bk1_permute;
557  }
558  }
559  }
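 // The B descriptor mirrors the A path: pad N and/or K per GemmSpec, then reshape K into
 // (BK0, BK1). The PermuteB branch instead consumes a pre-shuffled weight tensor laid out as
 // BGlobal[K / KPerBlock, N, KPerBlock / BK1, BK1] (see the comment above), intended so that
 // each KPerBlock slab of B can be read with contiguous accesses.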
560 
561  template <typename ABlockDesc_AK0_M_AK1>
562  __host__ __device__ static constexpr auto
563  MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
564  {
565  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
566 
567  return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
568  }
569 
570  template <typename BBlockDesc_BK0_N_BK1>
571  __host__ __device__ static constexpr auto
572  MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
573  {
574  constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
575 
576  return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
577  }
578 
579  __host__ __device__ static auto
580  MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
581  {
582  const auto c_grid_desc_mraw_nraw = [&]() {
584  {
585  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
586  }
588  {
589  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
590  }
591  }();
592 
593  // pad M and N
594  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
596  make_right_pad_transform(N, NPad - N)),
599 #if 0
601 
602  if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
603  GemmSpec == GemmSpecialization::MNKPadding)
604  {
605  // pad M and N
606  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
608  make_right_pad_transform(N, NPad - N)),
611  }
612  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
613  GemmSpec == GemmSpecialization::MKPadding)
614  {
615  // pad M, but not N
617  c_grid_desc_mraw_nraw,
621  }
622  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
623  GemmSpec == GemmSpecialization::NKPadding)
624  {
625  // pad N, but not M
627  c_grid_desc_mraw_nraw,
631  }
632  else
633  {
634  // not pad M or N
635  return c_grid_desc_mraw_nraw;
636  }
637 #endif
638  }
639 
640  struct Problem
641  {
642  __host__ Problem(index_t M_,
643  index_t N_,
644  index_t K_,
645  index_t StrideA_,
646  index_t StrideB_,
647  index_t StrideC_,
648  index_t KBatch_,
649  AElementwiseOperation a_element_op,
650  BElementwiseOperation b_element_op,
651  CElementwiseOperation c_element_op)
652  : M{M_},
653  N{N_},
654  K{K_},
655  StrideA{StrideA_},
656  StrideB{StrideB_},
657  StrideC{StrideC_},
658  KBatch{KBatch_},
659  MPadded{CalculateMPadded(M_)},
660  NPadded{CalculateNPadded(N_)},
661  KRead{CalculateKRead(K_, KBatch_)},
662  KPadded{CalculateKPadded(K_, KBatch_)},
663  AK0{CalculateAK0Padded(K_, KBatch_)},
664  BK0{CalculateBK0Padded(K_, KBatch_)},
665  MBlock{CalculateMBlock(M_)},
666  NBlock{CalculateNBlock(N_)},
667  a_element_op_{a_element_op},
668  b_element_op_{b_element_op},
669  c_element_op_{c_element_op}
670  {
671  }
672 
673  __host__ void Print() const
674  {
675  // clang-format off
676  std::cout << "problem {"
677  << "M:" << M << ", "
678  << "N:" << N << ", "
679  << "K:" << K << ", "
680  << "SA:" << StrideA << ", "
681  << "SB:" << StrideB << ", "
682  << "SC:" << StrideC << ", "
683  << "MP:" << MPadded << ", "
684  << "NP:" << NPadded << ", "
685  << "KRead:" << KRead << ", "
686  << "KP:" << KPadded << ", "
687  << "AK0:" << AK0 << ", "
688  << "BK0:" << BK0 << ", "
689  << "MBlock: " << MBlock << ", "
690  << "NBlock: " << NBlock << "}" << std::endl;
691  // clang-format on
692  }
693 
694  index_t M;
695  index_t N;
696  index_t K;
697  index_t StrideA;
698  index_t StrideB;
699  index_t StrideC;
700  index_t KBatch;
701  index_t MPadded;
702  index_t NPadded;
703  index_t KRead;
704  index_t KPadded;
705  index_t AK0;
706  index_t BK0;
707  index_t MBlock;
708  index_t NBlock;
709  AElementwiseOperation a_element_op_;
710  BElementwiseOperation b_element_op_;
711  CElementwiseOperation c_element_op_;
712  };
713 
714  // Argument
715  struct Argument : public Problem
716  {
717  __host__ Argument(const ADataType* p_a_grid_,
718  const BDataType* p_b_grid_,
719  CDataType* p_c_grid_,
720  index_t M_,
721  index_t N_,
722  index_t K_,
723  index_t StrideA_,
724  index_t StrideB_,
725  index_t StrideC_,
726  index_t k_batch_,
727  bool is_reduce_ = false,
728  AElementwiseOperation a_element_op = AElementwiseOperation{},
729  BElementwiseOperation b_element_op = BElementwiseOperation{},
730  CElementwiseOperation c_element_op = CElementwiseOperation{})
731  : Problem{M_,
732  N_,
733  K_,
734  StrideA_,
735  StrideB_,
736  StrideC_,
737  k_batch_,
738  a_element_op,
739  b_element_op,
740  c_element_op},
741  p_a_grid{p_a_grid_},
742  p_b_grid{p_b_grid_},
743  p_c_grid{p_c_grid_},
744  is_reduce(is_reduce_)
745  {
746  }
747 
748  __host__ __device__ inline bool IsReduceAdd() const
749  {
750  return (Problem::KBatch > 1) && is_reduce;
751  }
752 
753  __host__ __device__ inline bool IsAtomicAdd() const
754  {
755  return (Problem::KBatch > 1) && (!is_reduce);
756  }
757 
758  const ADataType* p_a_grid;
759  const BDataType* p_b_grid;
760  CDataType* p_c_grid;
761  bool is_reduce;
762  };
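 // Split-K behaviour selected by the Argument: with KBatch > 1 and is_reduce == true, the partial
 // result of each k-batch is written to its own slice of the output workspace (IsReduceAdd) and is
 // expected to be reduced afterwards; with is_reduce == false, partials are accumulated directly
 // into C via atomics (IsAtomicAdd). With KBatch == 1 neither path is taken.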
763 
764  struct SplitKBatchOffset
765  {
766 
767  __device__ SplitKBatchOffset(Argument& karg)
768  {
769  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
770  {
771  a_k_split_offset = blockIdx.z * karg.KRead / APackedSize;
772  }
773  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
774  {
775  a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
776  }
777 
778  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
779  {
780  b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
781  }
782  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
783  {
784  if constexpr(!PermuteB)
785  {
786  b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
787  }
788  else
789  {
790  const int k0_offset = karg.KRead * karg.N;
791  b_k_split_offset = blockIdx.z * k0_offset / BPackedSize;
792  }
793  }
794 
795  if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
796  {
797  karg.K = karg.KRead;
798  }
799  else
800  {
801  karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
802  }
803 
804  if(karg.IsReduceAdd())
805  {
806  c_reduce_offset = blockIdx.z * karg.M * karg.N;
807  }
808  else
809  {
810  c_reduce_offset = 0;
811  }
812  }
813 
814  index_t a_k_split_offset;
815  index_t b_k_split_offset;
816  index_t c_reduce_offset;
817  };
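 // Continuing the example above (K = 1000, KRead = 256, KBatch = 4, row-major A with
 // APackedSize = 1): the block with blockIdx.z == 1 starts reading A at K offset 256 and sees
 // karg.K == 256, while blockIdx.z == 3 starts at K offset 768 and sees
 // karg.K == 1000 - 3 * 256 == 232.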
818 
819  __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
820  {
821  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
822  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
823  constexpr index_t WaveSize = BlockSize / (MWave * NWave);
824  // A matrix in LDS memory, dst of blockwise copy
825  if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
826  {
827  // There is a bank conflict when writing the data into LDS, but the whole main loop is
828  // available to hide it in v4, and it may give some benefit from fewer VALU instructions in address computation
832  }
833  // The xor tensor transformation requires more (unnecessary) VGPRs and can cause register spills
834  // in some cases.
836  {
837  constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize;
838  constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize;
839  constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
840  make_tuple(
841  AK0Number * Number<MLdsLayer>{}, Number<MPerBlock / MLdsLayer>{}, AK1Number),
843 
844  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
845  a_lds_block_desc,
851 
852  constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor(
853  a_lds_block_desc_permuted,
859 
860  constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
861  a_lds_block_desc_ak0_mldslayer_m_ak1,
868 
869  return a_lds_block_desc_ak0_m_ak1;
870  }
871  else // ColumnMajor A
872  {
873  // The kfold and mpair dimensions are not always required.
874  // More dimensions in merge_transform increase the difficulty of generating immediate-argument offsets
875  // for the compiler.
876  constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
877  constexpr auto M1 = MPerBlock / M0;
878 
879  constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
880  constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
881  constexpr auto KThreadRead = WaveSize / MPerXdl;
882  constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
883 
884  constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
885  ? 1
886  : 128 / (AK1Number * M0 * sizeof(ADataType));
887  constexpr auto KThreadReadPerm =
888  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
889  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
890  : KThreadRead;
891 
892  // 1 <= mpair <= M0
893  constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
894  ? 1
895  : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
896  ? M0
897  : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));
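 // Example with ADataType = half_t (2 bytes), AK1 = 8, M0 = 4, MPerXdl = 32:
 // AK1 * M0 * 2 = 64 <= 128, so kfold = 128 / 64 = 2, while AK1 * MPerXdl * 2 = 512 > 128,
 // so mpair = 1. kfold and mpair only widen the LDS row when a K-line is narrower than 128 bytes,
 // matching the 32 banks x 4 bytes seen in the LdsSize computation above.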
898 
899  constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
903  Number<kfold * M0 / mpair>{},
904  Number<mpair>{},
905  AK1Number));
906 
907  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
908  a_lds_block_desc,
909  make_tuple(
913  make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
916  make_tuple(
918  make_tuple(
920 
921  constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
922  a_lds_block_desc_permuted,
923  make_tuple(
931  Sequence<1>{},
932  Sequence<2>{},
933  Sequence<3>{},
934  Sequence<4>{},
935  Sequence<5>{}),
937  Sequence<2>{},
938  Sequence<0, 3>{},
939  Sequence<4, 5>{},
940  Sequence<6>{},
941  Sequence<7>{}));
942 
943  constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
944  a_lds_block_desc_unmerged,
947  Number<KThreadWrite / kfold / KThreadReadPerm>{},
948  Number<kfold>{},
955 
956  return a_lds_block_desc_ak0_m_ak1;
957  }
958  }
959 
960  __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
961  {
962  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
963  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
964  constexpr index_t WaveSize = BlockSize / (MWave * NWave);
965  // B matrix in LDS memory, dst of blockwise copy
966  if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
967  {
968  // There is a bank conflict when writing the data into LDS, but the whole main loop is
969  // available to hide it in v4, and it may give some benefit from fewer VALU instructions in address computation
973  }
975  {
976  // NLdsLayer * K0 as logical Bank
977  constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize;
978  constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize;
979  constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
980  make_tuple(
981  BK0Number * Number<NLdsLayer>{}, Number<NPerBlock / NLdsLayer>{}, BK1Number),
983 
984  constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
985  b_lds_block_desc,
991 
992  constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor(
993  b_lds_block_desc_permuted,
999 
1000  constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
1001  b_lds_block_desc_bk0_nldslayer_n_bk1,
1008 
1009  return b_lds_block_desc_bk0_n_bk1;
1010  }
1011  else // RowMajor B
1012  {
1013  constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
1014  constexpr auto N1 = NPerBlock / N0;
1015 
1016  constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
1017  constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
1018  constexpr auto KThreadRead = WaveSize / NPerXdl;
1019  constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
1020 
1021  constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
1022  ? 1
1023  : 128 / (BK1Number * N0 * sizeof(BDataType));
1024  constexpr auto KThreadReadPerm =
1025  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
1026  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
1027  : KThreadRead;
1028 
1029  // 1 <= npair <= N0
1030  constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128)
1031  ? 1
1032  : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0
1033  ? N0
1034  : 128 / (BK1Number * NPerXdl * sizeof(BDataType)));
1035 
1036  constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
1040  Number<kfold * N0 / npair>{},
1041  Number<npair>{},
1042  BK1Number));
1043 
1044  constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
1045  b_lds_block_desc,
1046  make_tuple(
1050  make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
1053  make_tuple(
1055  make_tuple(
1057 
1058  constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
1059  b_lds_block_desc_permuted,
1060  make_tuple(
1068  Sequence<1>{},
1069  Sequence<2>{},
1070  Sequence<3>{},
1071  Sequence<4>{},
1072  Sequence<5>{}),
1074  Sequence<2>{},
1075  Sequence<0, 3>{},
1076  Sequence<4, 5>{},
1077  Sequence<6>{},
1078  Sequence<7>{}));
1079 
1080  constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
1081  b_lds_block_desc_unmerged,
1084  Number<KThreadWrite / kfold / KThreadReadPerm>{},
1085  Number<kfold>{},
1092 
1093  return b_lds_block_desc_bk0_n_bk1;
1094  }
1095  }
1096 
1098  {
1099  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1100  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1101 
1102  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1104  make_tuple(I1,
1106  I1,
1108 
1109  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
1110  }
1111 
1114  BlkGemmPipelineVer,
1115  BlkGemmPipeSched,
1116  BlockSize,
1117  ADataType,
1118  BDataType,
1119  ComputeTypeA,
1120  AccDataType,
1127  ABlockTransferSrcScalarPerVector,
1128  BBlockTransferSrcScalarPerVector,
1129  MPerBlock,
1130  NPerBlock,
1131  KPerBlock,
1132  MPerXdl,
1133  NPerXdl,
1134  MXdlPerWave,
1135  NXdlPerWave,
1136  KPack>())>;
1137 
1138  __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
1139  {
1140  // LDS allocation for A and B: be careful of alignment
1141  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1142  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1143 
1144  // lds max alignment
1145  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1146 
1147  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1148  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1149 
1150  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
1151  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1152 
1153  // LDS allocation for C shuffle in LDS
1154  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1156 
1157  constexpr auto c_block_size =
1158  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
1159 
1160  return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize +
1161  b_block_space_size_aligned * sizeof(BDataType) / BPackedSize),
1162  c_block_size * sizeof(CShuffleDataType));
1163  }
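 // The LDS requirement is max(A-tile + B-tile, C-shuffle tile) rather than their sum: the same
 // __shared__ allocation is reused for the C shuffle after the main K loop finishes (see the note
 // at the top of this file about __shared__ not being released until the shader ends).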
1164 
1165  template <InMemoryDataOperationEnum CGlobalMemoryDataOperation>
1166  __device__ static bool constexpr IsValidCompilationParameter()
1167  {
1168  enum struct Arch : bool
1169  {
1170 #if defined(__gfx950__)
1171  is_gfx950_build = true,
1172 #else
1173  is_gfx950_build = false,
1174 #endif
1175  };
1176 
1177  // skip building the instances with K1>=32 && PackedSize != 2 on pre-gfx950
1178  if constexpr(static_cast<bool>(Arch::is_gfx950_build) ||
1179  (AK1Number < 32 && BK1Number < 32) ||
1180  (AK1Number >= 32 && APackedSize == 2) ||
1181  (BK1Number >= 32 && BPackedSize == 2))
1182  {
1183 
1184  }
1185  else
1186  {
1187  return false;
1188  }
1189 
1190  // Check tile size
1191 #if defined(__gfx11__) || defined(__gfx12__)
1192  if constexpr(MPerXdl != 16 || NPerXdl != 16)
1193  {
1194  return false;
1195  }
1196 #endif
1197  // Check atomic caps
1198 #if defined(__gfx11__)
1199  constexpr bool SupportMemOp = CGlobalMemoryDataOperation == InMemoryDataOperationEnum::Set;
1200 #else
1201  constexpr bool SupportMemOp = sizeof(CDataType) >= 2 || (CGlobalMemoryDataOperation ==
1202  InMemoryDataOperationEnum::Set);
1203 #endif
1204  if constexpr(SupportMemOp == false)
1205  {
1206  return false;
1207  }
1208 
1209  // Check tile size
1210  if constexpr(MXdlPerWave > 0 && NXdlPerWave > 0)
1211  {
1212  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
1213  constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
1214  if constexpr(MWaves > 0 && NWaves > 0)
1215  {
1216  constexpr index_t WaveSize = BlockSize / (MWaves * NWaves);
1217  if constexpr(WaveSize == get_warp_size())
1218  {
1219  return true;
1220  }
1221  else
1222  {
1223  return false;
1224  }
1225  }
1226  else
1227  {
1228  return false;
1229  }
1230  }
1231  else
1232  {
1233  return false;
1234  }
1235  }
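 // Example of the wave-size check: BlockSize = 256, MPerBlock = NPerBlock = 128,
 // MPerXdl = NPerXdl = 32, MXdlPerWave = NXdlPerWave = 2 gives MWaves = NWaves = 2 and
 // WaveSize = 256 / 4 = 64, which matches the 64-lane wavefront on gfx9.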
1236  // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
1237  __host__ static constexpr bool CheckValidity(const Argument& karg)
1238  {
1239  if constexpr((MPerXdl * MXdlPerWave) == 0 || (NXdlPerWave * NPerXdl) == 0)
1240  {
1241  return false;
1242  }
1243  else
1244  {
1245  if constexpr((MPerBlock % (MPerXdl * MXdlPerWave) != 0) ||
1246  (NPerBlock % (NXdlPerWave * NPerXdl) != 0))
1247  {
1248  return false;
1249  }
1250  else
1251  {
1252  if(BlockwiseGemmPipe::WaveSize != get_warp_size())
1253  {
1254  return false;
1255  }
1256  }
1257  }
1258 
1259  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
1264  {
1265  if(!(karg.M % MPerBlock == 0))
1266  {
1267  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1268  {
1269  std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
1270  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1271  << std::endl;
1272  }
1273  return false;
1274  }
1275  }
1276 
1277  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
1282  {
1283  if(!(karg.N % NPerBlock == 0))
1284  {
1285  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1286  {
1287  std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
1288  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1289  << std::endl;
1290  }
1291  return false;
1292  }
1293  }
1294 
1295  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1299  {
1300 
1301  auto K_t = karg.KBatch * KPerBlock;
1302  if(!(karg.K % K_t == 0))
1303  {
1304  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1305  {
1306  std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1307  << karg.K << " " << __FILE__ << ":" << __LINE__
1308  << ", in function: " << __func__ << std::endl;
1309  }
1310  return false;
1311  }
1312  }
1313  else
1314  {
1315  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1316  auto K_t = karg.KBatch * KReadVec;
1317  auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1318  if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1319  {
1320  return false;
1321  }
1322  }
1323 
1325  {
1326  if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1327  {
1328  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1329  {
1330  std::cout << "Arg K (" << karg.K
1331  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1332  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1333  << __LINE__ << ", in function: " << __func__ << std::endl;
1334  }
1335  return false;
1336  }
1337  }
1338  else
1339  {
1340  if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1341  {
1342  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1343  {
1344  std::cout << "Arg M (" << karg.M
1345  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1346  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1347  << __LINE__ << ", in function: " << __func__ << std::endl;
1348  }
1349  return false;
1350  }
1351  }
1352 
1354  {
1355  if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1356  {
1357  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1358  {
1359  std::cout << "Arg N (" << karg.N
1360  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1361  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1362  << __LINE__ << ", in function: " << __func__ << std::endl;
1363  }
1364  return false;
1365  }
1366  }
1367  else
1368  {
1369  if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1370  {
1371  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1372  {
1373  std::cout << "Arg K (" << karg.K
1374  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1375  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1376  << __LINE__ << ", in function: " << __func__ << std::endl;
1377  }
1378  return false;
1379  }
1380  }
1381 
1383  {
1384  if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
1385  {
1386  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1387  {
1388  std::cout << "Arg N (" << karg.N
1389  << ") value is not a multiple of "
1390  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1391  << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
1392  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1393  << std::endl;
1394  }
1395  return false;
1396  }
1397  }
1398  else
1399  {
1400  if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
1401  {
1402  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1403  {
1404  std::cout << "Arg M (" << karg.M
1405  << ") value is not a multiple of "
1406  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1407  << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
1408  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1409  << std::endl;
1410  }
1411  return false;
1412  }
1413  }
1414 
1415  if constexpr(!(is_same<remove_cvref_t<CDataType>, half_t>::value ||
1419  {
1420  if(!karg.IsReduceAdd())
1421  {
1422  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1423  {
1424  std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported yet " << __FILE__
1425  << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1426  }
1427  if(karg.KBatch > 1)
1428  {
1429  return false;
1430  }
1431  }
1432  }
1433 
1434  // check gridwise gemm pipeline
1435  const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1436 
1437  if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1)
1438  {
1439  if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1440  {
1441  return false;
1442  }
1443  }
1444 
1445  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1446  return true;
1447  }
1448 
1449  __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1450  {
1451  const index_t num_loop = K / KPerBlock;
1452 
1453  return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1454  }
1455 
1456  __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1457  {
1458  const index_t num_loop = K / KPerBlock;
1459 
1460  return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1461  }
1462 
1463  template <typename CGridDesc>
1464  __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
1465  const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1466  {
1467  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1468  c_grid_desc_m_n,
1473 
1474  return c_grid_desc_mblock_mperblock_nblock_nperblock;
1475  }
1476 
1477  // return block_id to C matrix tile idx (m0, n0) mapping
1478  // if arch = gfx942
1480  // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
1481 
1482  template <typename AGridDesc_AK0_M_K1,
1483  typename BGridDesc_BK0_N_K1,
1484  typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
1485  bool HasMainKBlockLoop,
1486  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1487  TailNumber TailNum = TailNumber::Odd>
1488  __device__ static void Run(const ADataType* p_a_grid,
1489  const BDataType* p_b_grid,
1490  CDataType* p_c_grid,
1491  void* p_shared,
1492  const Problem& problem,
1493  const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
1494  const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
1495  const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
1496  c_grid_desc_mblock_mperblock_nblock_nperblock)
1497  {
1498  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1499  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1500  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1501  p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1502  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1503  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1504 
1505  // divide block work by [M, N]
1506  const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
1507 
1508  const auto block_work_idx =
1509  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
1510 
1511  if(!block_2_ctile_map.ValidCTileIndex(
1512  block_work_idx,
1513  make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
1514  c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
1515  {
1516  return;
1517  }
1518 
1519  const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
1520  const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
1521 
1522  // HACK: this forces m/n_block_data_idx_on_grid into SGPR
1523  const index_t m_block_data_idx_on_grid =
1524  __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1525 
1526  const index_t n_block_data_idx_on_grid =
1527  __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1528 
1529  // lds max alignment
1530  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1531 
1532  // A matrix in LDS memory, dst of blockwise copy
1533  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1534 
1535  // B matrix in LDS memory, dst of blockwise copy
1536  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1537 
1538  // A matrix blockwise copy
1539  auto a_blockwise_copy =
1541  AElementwiseOperation,
1545  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1546  ABlockTransferThreadClusterArrangeOrder,
1547  ADataType,
1548  ADataType,
1549  decltype(a_grid_desc_ak0_m_ak1),
1550  decltype(a_block_desc_ak0_m_ak1),
1551  ABlockTransferSrcAccessOrder,
1553  ABlockTransferSrcVectorDim,
1554  2,
1555  ABlockTransferSrcScalarPerVector,
1556  ABlockTransferDstScalarPerVector_AK1,
1557  1,
1558  1,
1559  AThreadTransferSrcResetCoordinateAfterRun,
1560  true,
1561  BlockwiseGemmPipe::GlobalBufferNum>(
1562  a_grid_desc_ak0_m_ak1,
1563  make_multi_index(0, m_block_data_idx_on_grid, 0),
1564  problem.a_element_op_,
1565  a_block_desc_ak0_m_ak1,
1566  make_multi_index(0, 0, 0),
1568 
1569  // B matrix blockwise copy
1570  auto b_blockwise_copy =
1572  BElementwiseOperation,
1576  BBlockTransferThreadClusterLengths_BK0_N_BK1,
1577  BBlockTransferThreadClusterArrangeOrder,
1578  BDataType,
1579  BDataType,
1580  decltype(b_grid_desc_bk0_n_bk1),
1581  decltype(b_block_desc_bk0_n_bk1),
1582  BBlockTransferSrcAccessOrder,
1584  BBlockTransferSrcVectorDim,
1585  2,
1586  BBlockTransferSrcScalarPerVector,
1587  BBlockTransferDstScalarPerVector_BK1,
1588  1,
1589  1,
1590  BThreadTransferSrcResetCoordinateAfterRun,
1591  true,
1592  BlockwiseGemmPipe::GlobalBufferNum>(
1593  b_grid_desc_bk0_n_bk1,
1594  make_multi_index(0, n_block_data_idx_on_grid, 0),
1595  problem.b_element_op_,
1596  b_block_desc_bk0_n_bk1,
1597  make_multi_index(0, 0, 0),
1599 
1600  // LDS allocation for A and B: be careful of alignment
1601  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1602  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1603 
1604  // Cast after lds
1605  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1606  static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1607 
1608  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1609  reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) + a_block_space_size_aligned *
1610  sizeof(ADataType) /
1611  APackedSize),
1612  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1613 
1614  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1615  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
1616 
1617  // Blockwise GEMM pipeline
1618  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1619  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1620  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1621 
1622  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1623  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1624  KPerBlock);
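 // GetLength(I0) * GetLength(I2) == AK0 * AK1 is the per-split K extent, so
 // num_k_block_main_loop is simply the number of KPerBlock tiles this workgroup iterates over.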
1625 
1626  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
1627  a_block_desc_ak0_m_ak1,
1628  a_blockwise_copy,
1629  a_grid_buf,
1630  a_block_buf,
1631  a_block_slice_copy_step,
1632  b_grid_desc_bk0_n_bk1,
1633  b_block_desc_bk0_n_bk1,
1634  b_blockwise_copy,
1635  b_grid_buf,
1636  b_block_buf,
1637  b_block_slice_copy_step,
1638  c_thread_buf,
1639  num_k_block_main_loop);
1640 
1641  // shuffle C and write out
1642  {
1643  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1644  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1645  "wrong!");
1646 
1647  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1648  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1649 
1650  // TODO: hacky, fix it!
1651  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1652  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1653 
1654  // TODO: hacky, fix it!
1655  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1656  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1657  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1658 
1659  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1660  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1661  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1662  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1663  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1664  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1665  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1666  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1667 
1668  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1670 
1671  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1672  static_cast<CShuffleDataType*>(p_shared),
1673  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1674 
1675  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1676  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1677  make_tuple(
1680  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1681  M1, // M1 = MWave
1682  M2, // M2 * M3 * M4 = MPerXdl
1683  M3,
1684  M4)),
1687  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1688  N1, // N1 = NWave
1689  N2))), // N2 = NPerXdl
1691  make_tuple(
1693 
1694  // calculate origin of thread output tensor on global memory
1695  // blockwise GEMM c matrix starting index
1696  const auto c_thread_mtx_on_block =
1697  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1698 
1699  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1700  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1701 
1702  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1704  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
1706  make_tuple(Sequence<0>{}));
1707 
1708  const auto m_thread_data_on_block_idx =
1709  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1710  make_multi_index(m_thread_data_on_block));
1711 
1712  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1716  make_tuple(Sequence<0>{}));
1717 
1718  const auto n_thread_data_on_block_idx =
1719  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1720  make_multi_index(n_thread_data_on_block));
1721 
1723  const auto& vgpr_to_lds_element_op = [&] {
1724  if constexpr(DoElementwiseBeforeCShuffle)
1725  {
1726  return problem.c_element_op_;
1727  }
1728  else
1729  {
1730  return pass_through;
1731  }
1732  };
1733  const auto& lds_to_global_element_op = [&] {
1734  if constexpr(!DoElementwiseBeforeCShuffle)
1735  {
1736  return problem.c_element_op_;
1737  }
1738  else
1739  {
1740  return pass_through;
1741  }
1742  };
1743 
1744  // shuffle: threadwise copy C from VGPR to LDS
1745  auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
1746  AccDataType,
1747  CShuffleDataType,
1748  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1749  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1750  conditional_t<DoElementwiseBeforeCShuffle,
1751  CElementwiseOperation,
1753  Sequence<CShuffleMXdlPerWavePerShuffle,
1754  CShuffleNXdlPerWavePerShuffle,
1755  I1,
1756  I1,
1757  M2,
1758  I1,
1759  M4,
1760  I1>,
1762  7,
1763  1,
1765  1,
1766  true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1767  make_multi_index(0,
1768  0,
1769  m_thread_data_on_block_idx[I1],
1770  n_thread_data_on_block_idx[I1],
1771  m_thread_data_on_block_idx[I2],
1772  m_thread_data_on_block_idx[I3],
1773  m_thread_data_on_block_idx[I4],
1774  n_thread_data_on_block_idx[I2]),
1775  vgpr_to_lds_element_op()};
1776 
1777  // shuffle: blockwise copy C from LDS to global
1778  auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
1779  ThisThreadBlock, // ThreadGroup
1780  conditional_t<!DoElementwiseBeforeCShuffle,
1781  CElementwiseOperation,
1783  CGlobalMemoryDataOperation, // DstInMemOp,
1784  Sequence<1,
1785  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1786  1,
1787  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1788  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1789  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1790  CShuffleDataType, // typename SrcData,
1791  CDataType, // typename DstData,
1792  decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1793  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1794  Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
1795  3, // index_t VectorDim,
1796  CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
1797  true, // bool ThreadTransferSrcResetCoordinateAfterRun,
1798  false> // bool ThreadTransferDstResetCoordinateAfterRun>
1799  {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1800  make_multi_index(0, 0, 0, 0),
1801  c_grid_desc_mblock_mperblock_nblock_nperblock,
1802  make_multi_index(block_m_id, 0, block_n_id, 0),
1803  lds_to_global_element_op()};
1804 
1805  // space filling curve for threadwise C in VGPR
1806  constexpr auto sfc_c_vgpr =
1809  Sequence<CShuffleMXdlPerWavePerShuffle,
1810  CShuffleNXdlPerWavePerShuffle,
1811  1,
1812  1,
1813  M2,
1814  1,
1815  M4,
1816  1>>{};
1817 
1818  // space filling curve for shuffled blockwise C in global mem
1819  constexpr auto sfc_c_global =
1822  Sequence<1,
1823  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1824  1,
1825  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1826 
1827  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1828 
1829  static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
1830 
1831  static_for<0, num_access, 1>{}([&](auto access_id) {
1832  // make sure it's safe to write to LDS
1833  block_sync_lds();
1834 
1835  // each thread write its data from VGPR to LDS
1836  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1837  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1838  c_thread_buf,
1839  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1840  c_shuffle_block_buf);
1841 
1842  // make sure it's safe to read from LDS
1843  block_sync_lds();
1844 
1845  // each block copy its data from LDS to global
1846  c_shuffle_block_copy_lds_to_global.Run(
1847  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1848  c_shuffle_block_buf,
1849  c_grid_desc_mblock_mperblock_nblock_nperblock,
1850  c_grid_buf);
1851 
1852  if constexpr(access_id < num_access - 1)
1853  {
1854  constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1855 
1856  // move on C
1857  c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1858  c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1859  }
1860  });
1861  }
1862  }
1863 
1864  template <bool HasMainKBlockLoop,
1865  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1866  TailNumber TailNum = TailNumber::Odd>
1867  __device__ static void Run(const ADataType* p_a_grid,
1868  const BDataType* p_b_grid,
1869  CDataType* p_c_grid,
1870  void* p_shared,
1871  const Problem& problem)
1872  {
1873  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1874  problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
1875  const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
1876  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
1877  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
1878  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
1879  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1881  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1882 
1883  Run<decltype(a_grid_desc_ak0_m_ak1),
1884  decltype(b_grid_desc_bk0_n_bk1),
1885  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1886  HasMainKBlockLoop,
1887  CGlobalMemoryDataOperation,
1888  TailNum>(p_a_grid,
1889  p_b_grid,
1890  p_c_grid,
1891  p_shared,
1892  problem,
1893  a_grid_desc_ak0_m_ak1,
1894  b_grid_desc_bk0_n_bk1,
1895  c_grid_desc_mblock_mperblock_nblock_nperblock);
1896  }
1897 
1898  template <typename AGridDesc_AK0_M_K1,
1899  typename BGridDesc_BK0_N_K1,
1900  typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
1901  bool HasMainKBlockLoop,
1902  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1903  TailNumber TailNum = TailNumber::Odd>
1904  __device__ static void Run_2Lds(const ADataType* p_a_grid,
1905  const BDataType* p_b_grid,
1906  CDataType* p_c_grid,
1907  void* p_shared_0,
1908  void* p_shared_1,
1909  const Problem& problem,
1910  const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
1911  const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
1912  const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
1913  c_grid_desc_mblock_mperblock_nblock_nperblock)
1914  {
1915  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1916  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1917  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1918  p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1919  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1920  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1921 
1922  // divide block work by [M, N]
1923  const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
1924 
1925  const auto block_work_idx =
1926  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
1927 
1928  if(!block_2_ctile_map.ValidCTileIndex(
1929  block_work_idx,
1930  make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
1931  c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
1932  {
1933  return;
1934  }
1935 
1936  const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
1937  const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
1938 
1939  // HACK: this forces m/n_block_data_idx_on_grid into SGPR
1940  const index_t m_block_data_idx_on_grid =
1941  __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1942 
1943  const index_t n_block_data_idx_on_grid =
1944  __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1945 
1946  // lds max alignment
1947  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1948 
1949  // A matrix in LDS memory, dst of blockwise copy
1950  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1951 
1952  // B matrix in LDS memory, dst of blockwise copy
1953  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1954 
1955  // A matrix blockwise copy
1956  auto a_blockwise_copy =
1958  AElementwiseOperation,
1962  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1963  ABlockTransferThreadClusterArrangeOrder,
1964  ADataType,
1965  ADataType,
1966  decltype(a_grid_desc_ak0_m_ak1),
1967  decltype(a_block_desc_ak0_m_ak1),
1968  ABlockTransferSrcAccessOrder,
1970  ABlockTransferSrcVectorDim,
1971  2,
1972  ABlockTransferSrcScalarPerVector,
1973  ABlockTransferDstScalarPerVector_AK1,
1974  1,
1975  1,
1976  AThreadTransferSrcResetCoordinateAfterRun,
1977  true,
1978  BlockwiseGemmPipe::GlobalBufferNum>(
1979  a_grid_desc_ak0_m_ak1,
1980  make_multi_index(0, m_block_data_idx_on_grid, 0),
1981  problem.a_element_op_,
1982  a_block_desc_ak0_m_ak1,
1983  make_multi_index(0, 0, 0),
1985 
1986  // B matrix blockwise copy
1987  auto b_blockwise_copy =
1989  BElementwiseOperation,
1993  BBlockTransferThreadClusterLengths_BK0_N_BK1,
1994  BBlockTransferThreadClusterArrangeOrder,
1995  BDataType,
1996  BDataType,
1997  decltype(b_grid_desc_bk0_n_bk1),
1998  decltype(b_block_desc_bk0_n_bk1),
1999  BBlockTransferSrcAccessOrder,
2001  BBlockTransferSrcVectorDim,
2002  2,
2003  BBlockTransferSrcScalarPerVector,
2004  BBlockTransferDstScalarPerVector_BK1,
2005  1,
2006  1,
2007  BThreadTransferSrcResetCoordinateAfterRun,
2008  true,
2009  BlockwiseGemmPipe::GlobalBufferNum>(
2010  b_grid_desc_bk0_n_bk1,
2011  make_multi_index(0, n_block_data_idx_on_grid, 0),
2012  problem.b_element_op_,
2013  b_block_desc_bk0_n_bk1,
2014  make_multi_index(0, 0, 0),
2016 
2017  // LDS allocation for A and B: be careful of alignment
2018  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
2019  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
2020 
2021  auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2022  static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2023 
2024  auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2025  bit_cast<BDataType*>(static_cast<char*>(p_shared_0) +
2026  a_block_space_size_aligned * sizeof(ADataType)),
2027  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2028 
2029  auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2030  static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2031 
2032  auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2033  bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) +
2034  a_block_space_size_aligned * sizeof(ADataType)),
2035  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2036 
2037  auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
2038  auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
2039 
2040  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
2041  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
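// The copy step only advances the K0 coordinate: one KPerBlock slab of K corresponds to
// KPerBlock / K1 rows of the K0-major layout. E.g. KPerBlock = 64 with AK1 = 8 (illustrative
// values) moves the A source window by 8 along AK0 per iteration.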
2042 
2043  // Blockwise GEMM pipeline
2044  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2045  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
2046  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2047 
2048  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2049  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
2050  KPerBlock);
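// GetLength(I0) * GetLength(I2) is AK0 * AK1, i.e. the (possibly padded) K extent this split-K
// batch covers, so dividing by KPerBlock yields the number of main-loop iterations.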
2051 
2052  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
2053  a_block_desc_ak0_m_ak1,
2054  a_blockwise_copy,
2055  a_grid_buf,
2056  a_block_bufs,
2057  a_block_slice_copy_step,
2058  b_grid_desc_bk0_n_bk1,
2059  b_block_desc_bk0_n_bk1,
2060  b_blockwise_copy,
2061  b_grid_buf,
2062  b_block_bufs,
2063  b_block_slice_copy_step,
2064  c_thread_buf,
2065  num_k_block_main_loop);
2066 
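// The pipeline runs the whole K loop: the blockwise copies refill a_block_bufs / b_block_bufs
// from global memory while the xdlops stages consume them, and per-thread partial sums
// accumulate in c_thread_buf, which the epilogue below shuffles and writes out.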
2067  // shuffle C and write out
2068  {
2069  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2070  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2071  "wrong!");
2072 
2073  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2074  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
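// MWave / NWave are the number of waves tiling M and N inside one block tile. E.g.
// MPerBlock = 128, MXdlPerWave = 2, MPerXdl = 32 (illustrative values) give
// MWave = 128 / (2 * 32) = 2.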
2075 
2076  // TODO: hacky, fix it!
2077  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2078  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2079 
2080  // TODO: hacky, fix it!
2081  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
2082  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2083  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2084 
2085  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
2086  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
2087  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
2088  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
2089  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
2090  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
2091  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
2092  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
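// This 8-d view factors the block tile as M = M0 * M1 * (M2 * M3 * M4) and N = N0 * N1 * N2,
// i.e. (XdlPerWave, Wave, intra-Xdl lane) splits along each dimension, matching the layout of
// the accumulator registers produced by the pipeline.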
2093 
2094  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2095  GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
2096 
2097  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2098  static_cast<CShuffleDataType*>(p_shared_0),
2099  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2100 
2101  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
2102  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2103  make_tuple(
2104  make_freeze_transform(I0),
2105  make_unmerge_transform(make_tuple(
2106  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
2107  M1, // M1 = MWave
2108  M2, // M2 * M3 * M4 = MPerXdl
2109  M3,
2110  M4)),
2111  make_freeze_transform(I0),
2112  make_unmerge_transform(make_tuple(
2113  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
2114  N1, // N1 = NWave
2115  N2))), // N2 = NPerXdl
2116  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
2117  make_tuple(
2118  Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
2119 
2120  // calculate origin of thread output tensor on global memory
2121  // blockwise GEMM c matrix starting index
2122  const auto c_thread_mtx_on_block =
2123  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2124 
2125  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2126  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2127 
2128  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2129  make_single_stage_tensor_adaptor(
2130  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
2131  make_tuple(Sequence<0, 1, 2, 3, 4>{}),
2132  make_tuple(Sequence<0>{}));
2133 
2134  const auto m_thread_data_on_block_idx =
2135  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2136  make_multi_index(m_thread_data_on_block));
2137 
2138  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2139  make_single_stage_tensor_adaptor(
2140  make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
2141  make_tuple(Sequence<0, 1, 2>{}),
2142  make_tuple(Sequence<0>{}));
2143 
2144  const auto n_thread_data_on_block_idx =
2145  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2146  make_multi_index(n_thread_data_on_block));
2147 
2148  // shuffle: threadwise copy C from VGPR to LDS
2149  auto c_thread_copy_vgpr_to_lds =
2150  ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
2151  CShuffleDataType,
2152  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2153  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2154  ck::tensor_operation::element_wise::PassThrough,
2155  Sequence<CShuffleMXdlPerWavePerShuffle,
2156  CShuffleNXdlPerWavePerShuffle,
2157  I1,
2158  I1,
2159  M2,
2160  I1,
2161  M4,
2162  I1>,
2163  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
2164  7,
2165  1,
2166  InMemoryDataOperationEnum::Set,
2167  1,
2168  true>{
2169  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2170  make_multi_index(0,
2171  0,
2172  m_thread_data_on_block_idx[I1],
2173  n_thread_data_on_block_idx[I1],
2174  m_thread_data_on_block_idx[I2],
2175  m_thread_data_on_block_idx[I3],
2176  m_thread_data_on_block_idx[I4],
2177  n_thread_data_on_block_idx[I2]),
2178  ck::tensor_operation::element_wise::PassThrough{}};
2179 
2180  // shuffle: blockwise copy C from LDS to global
2181  auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
2182  ThisThreadBlock, // ThreadGroup
2183  CElementwiseOperation, // ElementwiseOperation,
2184  CGlobalMemoryDataOperation, // DstInMemOp,
2185  Sequence<1,
2186  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2187  1,
2188  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2189  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
2190  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2191  CShuffleDataType, // typename SrcData,
2192  CDataType, // typename DstData,
2193  decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2194  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
2195  Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
2196  3, // index_t VectorDim,
2197  CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
2198  true, // bool ThreadTransferSrcResetCoordinateAfterRun,
2199  false> // bool ThreadTransferDstResetCoordinateAfterRun>
2200  {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2201  make_multi_index(0, 0, 0, 0),
2202  c_grid_desc_mblock_mperblock_nblock_nperblock,
2203  make_multi_index(block_m_id, 0, block_n_id, 0),
2204  problem.c_element_op_};
2205 
2206  // space filling curve for threadwise C in VGPR
2207  constexpr auto sfc_c_vgpr =
2208  SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
2209  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
2210  Sequence<CShuffleMXdlPerWavePerShuffle,
2211  CShuffleNXdlPerWavePerShuffle,
2212  1,
2213  1,
2214  M2,
2215  1,
2216  M4,
2217  1>>{};
2218 
2219  // space filling curve for shuffled blockwise C in global mem
2220  constexpr auto sfc_c_global =
2221  SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
2222  Sequence<0, 2, 1, 3>,
2223  Sequence<1,
2224  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2225  1,
2226  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2227 
2228  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2229 
2230  static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
2231 
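// Each pass below handles one shuffle sub-tile: sync, write that sub-tile of c_thread_buf
// from VGPRs into the LDS shuffle buffer, sync again, stream it from LDS to global memory,
// then advance the destination window along the global space-filling curve.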
2232  static_for<0, num_access, 1>{}([&](auto access_id) {
2233  // make sure it's safe to write to LDS
2234  block_sync_lds();
2235 
2236  // each thread write its data from VGPR to LDS
2237  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2238  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2239  c_thread_buf,
2240  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2241  c_shuffle_block_buf);
2242 
2243  // make sure it's safe to read from LDS
2244  block_sync_lds();
2245 
2246  // each block copy its data from LDS to global
2247  c_shuffle_block_copy_lds_to_global.Run(
2248  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2249  c_shuffle_block_buf,
2250  c_grid_desc_mblock_mperblock_nblock_nperblock,
2251  c_grid_buf);
2252 
2253  if constexpr(access_id < num_access - 1)
2254  {
2255  constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
2256 
2257  // move on C
2258  c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
2259  c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
2260  }
2261  });
2262  }
2263  }
2264 
2265  template <bool HasMainKBlockLoop,
2266  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
2267  TailNumber TailNum = TailNumber::Odd>
2268  __device__ static void Run_2Lds(const ADataType* p_a_grid,
2269  const BDataType* p_b_grid,
2270  CDataType* p_c_grid,
2271  void* p_shared_0,
2272  void* p_shared_1,
2273  const Problem& problem)
2274  {
2275  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
2276  problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
2277  const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
2278  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
2279  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
2280  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
2281 
2282  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
2283  MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
2284  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
2285 
2286  Run_2Lds<decltype(a_grid_desc_ak0_m_ak1),
2287  decltype(b_grid_desc_bk0_n_bk1),
2288  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
2289  HasMainKBlockLoop,
2290  CGlobalMemoryDataOperation,
2291  TailNum>(p_a_grid,
2292  p_b_grid,
2293  p_c_grid,
2294  p_shared_0,
2295  p_shared_1,
2296  problem,
2297  a_grid_desc_ak0_m_ak1,
2298  b_grid_desc_bk0_n_bk1,
2299  c_grid_desc_mblock_mperblock_nblock_nperblock);
2300  }
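// The Run_2Lds overload above only materializes the grid descriptors from the runtime Problem
// fields and forwards to the descriptor-taking Run_2Lds defined earlier in this struct.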
2301 };
2302 
2303 } // namespace ck