Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp Source File
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 #include "ck/utility/env.hpp"
17 
18 namespace ck {
19 
20 // Currently there is no elegant way to put the single-LDS-buffer and double-LDS-buffer pipelines
21 // into the same kernel function. Blockers:
22 // 1. Two separate __shared__ declarations are required to guarantee that data accesses operate on
23 //    two distinct LDS chunks.
24 // 2. A __shared__ allocation is not released until the whole shader ends, i.e. A/B and C could not
25 //    reuse the same LDS buffer if __shared__ were declared inside the block GEMM pipeline.
26 template <typename GridwiseGemm,
27  bool HasMainKBlockLoop,
28  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
29  index_t MinimumOccupancy = 1,
30  TailNumber TailNum = TailNumber::Full>
31 __global__ void
32 #if CK_USE_LAUNCH_BOUNDS
33 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
34 #endif
35  // __attribute__((amdgpu_waves_per_eu(1, 1)))
36  kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
37 {
38 #if defined(__gfx9__) || defined(__gfx12__) || defined(__gfx11__)
39  if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
40  {
41  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
42 
43  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
44 
45  GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
46  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
47  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
48  karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
49  p_shared,
50  karg);
51  }
52 #else
53  ignore = karg;
54 #endif // defined(__gfx9__) || defined(__gfx12__) || defined(__gfx11__)
55 }
56 
57 template <typename GridwiseGemm,
58  bool HasMainKBlockLoop,
59  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
60  index_t MinimumOccupancy = 1,
61  TailNumber TailNum = TailNumber::Full>
62 __global__ void
63 #if CK_USE_LAUNCH_BOUNDS
64 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
65 #endif
66  // __attribute__((amdgpu_waves_per_eu(1, 1)))
67  kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
68 {
69 #if defined(__gfx9__) || defined(__gfx12__) || defined(__gfx11__)
70  // Passing two LDS pointers is the key to telling the compiler that ds_read/ds_write
71  // can operate on different LDS chunks at the same time without an ordering dependency
72  if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
73  {
74  __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
75  __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
76 
77  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
78 
79  GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
80  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
81  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
82  karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
83  p_shared_0,
84  p_shared_1,
85  karg);
86  }
87 #else
88  ignore = karg;
89 #endif // defined(__gfx9__) || defined(__gfx12__) || defined(__gfx11__)
90 }
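// Illustrative host-side launch sketch (a minimal sketch, assuming a concrete instantiation named
// GridwiseGemmInstance with BlockSize = 256 and a HIP stream named `stream`; the device_op layer
// normally performs this step):
//
//   auto karg = typename GridwiseGemmInstance::Argument{p_a, p_b, p_c, M, N, K,
//                                                       StrideA, StrideB, StrideC, /*k_batch*/ 1};
//   const auto [gdx, gdy, gdz] = GridwiseGemmInstance::CalculateGridSize(karg.M, karg.N, karg.KBatch);
//   kernel_gemm_xdl_cshuffle_v3<GridwiseGemmInstance,
//                               true /*HasMainKBlockLoop*/,
//                               InMemoryDataOperationEnum::Set>
//       <<<dim3(gdx, gdy, gdz), dim3(256 /*BlockSize*/), 0, stream>>>(karg);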
91 
197 template <typename ALayout,
198  typename BLayout,
199  typename CLayout,
200  typename ADataType,
201  typename BDataType,
202  typename AccDataType,
203  typename CShuffleDataType,
204  typename CDataType,
205  typename AElementwiseOperation,
206  typename BElementwiseOperation,
207  typename CElementwiseOperation,
209  index_t BlockSize,
210  index_t MPerBlock,
211  index_t NPerBlock,
212  index_t KPerBlock,
213  index_t AK1Value,
214  index_t BK1Value,
215  index_t MPerXdl,
216  index_t NPerXdl,
217  index_t MXdlPerWave,
218  index_t NXdlPerWave,
219  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
220  typename ABlockTransferThreadClusterArrangeOrder,
221  typename ABlockTransferSrcAccessOrder,
222  index_t ABlockTransferSrcVectorDim,
223  index_t ABlockTransferSrcScalarPerVector,
224  index_t ABlockTransferDstScalarPerVector_AK1,
225  bool AThreadTransferSrcResetCoordinateAfterRun,
226  index_t ABlockLdsExtraM,
227  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
228  typename BBlockTransferThreadClusterArrangeOrder,
229  typename BBlockTransferSrcAccessOrder,
230  index_t BBlockTransferSrcVectorDim,
231  index_t BBlockTransferSrcScalarPerVector,
232  index_t BBlockTransferDstScalarPerVector_BK1,
233  bool BThreadTransferSrcResetCoordinateAfterRun,
234  index_t BBlockLdsExtraN,
235  index_t CShuffleMXdlPerWavePerShuffle,
236  index_t CShuffleNXdlPerWavePerShuffle,
237  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
238  index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
241  typename ComputeTypeA = CDataType,
242  typename ComputeTypeB = ComputeTypeA,
243  bool PermuteA = false,
244  bool PermuteB = false,
245  bool DoElementwiseBeforeCShuffle = false>
247 {
248  static constexpr auto I0 = Number<0>{};
249  static constexpr auto I1 = Number<1>{};
250  static constexpr auto I2 = Number<2>{};
251  static constexpr auto I3 = Number<3>{};
252  static constexpr auto I4 = Number<4>{};
253  static constexpr auto I5 = Number<5>{};
254  static constexpr auto I6 = Number<6>{};
255  static constexpr auto I7 = Number<7>{};
256 
257  // K1 should be Number<...>
258  static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
259  static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
260  static constexpr auto AK1Number = Number<AK1Value>{};
261  static constexpr auto BK1Number = Number<BK1Value>{};
262 
263  static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
264  static constexpr bool is_single_rate_mfma =
266  lcm_AK1_BK1 <= 4) ||
268  // gfx950 double-rate mfma 16x16 requires at least KPerBlock = 128 to be used
270  KPerBlock < 128 && MPerXdl == 16))
271  ? true
272  : false;
273  static constexpr auto is_scale_mfma = false;
274  static constexpr index_t KPack =
276  MfmaSelector<ComputeTypeA,
277  MPerXdl,
278  NPerXdl,
279  ComputeTypeA,
281  is_scale_mfma>::selected_mfma.k_per_blk);
282 
284 
285  static constexpr index_t APackedSize = []() {
287  return 2;
288  else
289  return 1;
290  }();
291 
292  static constexpr index_t BPackedSize = []() {
294  return 2;
295  else
296  return 1;
297  }();
298 
299  __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
300  {
301  return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
302  }
303 
304  __host__ static auto CalculateMPadded(index_t M)
305  {
306  return math::integer_least_multiple(M, MPerBlock);
307  }
308 
309  __host__ static auto CalculateNPadded(index_t N)
310  {
311  return math::integer_least_multiple(N, NPerBlock);
312  }
313 
314  __host__ static auto CalculateKPadded(index_t K)
315  {
316  return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
317  }
318 
319  __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
320  {
321  auto K_t = K_Batch * KPerBlock;
322  return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
323  }
324 
325  __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
326  {
327  auto K_t = K_Batch * KPerBlock;
328  return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
329  }
330 
331  __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
332  {
333  auto K_t = K_Batch * KPerBlock;
334  return (K + K_t - 1) / K_t * KPerBlock;
335  }
336 
337  __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
338  {
339  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
340  auto K_t = K_Batch * KReadVec;
341  return (K + K_t - 1) / K_t * KReadVec;
342  }
343 
344  __host__ static auto CalculateMBlock(index_t M)
345  {
346  return math::integer_divide_ceil(M, MPerBlock);
347  }
348 
349  __host__ static auto CalculateNBlock(index_t N)
350  {
351  return math::integer_divide_ceil(N, NPerBlock);
352  }
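 // Worked example for the padding / split-K helpers above (a minimal illustration, assuming
 // MPerBlock = 256, KPerBlock = 64, AK1Value = BK1Value = 8):
 //   CalculateMPadded(1000)     -> 1024  (M rounded up to a multiple of MPerBlock)
 //   CalculateMBlock(1000)      -> 4     (number of M tiles)
 //   CalculateKPadded(500)      -> 512   (K rounded up to a multiple of KPerBlock)
 //   CalculateKPadded(500, 2)   -> 256   (padded K handled by each split-K batch)
 //   CalculateAK0Padded(500, 2) -> 32    (= 256 / AK1Value, AK0 per split-K batch)
 //   CalculateKRead(500, 2)     -> 256   (K elements each split-K batch reads, lcm(AK1, BK1)-aligned)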
353 
354  template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
355  __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
356  {
357  constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
358  constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
359 
361  TileDesc_K0_MN_K1{},
367  }
368 
369  __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
370  index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
371  {
372  const auto a_grid_desc_mraw_kraw = [&]() {
373  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
374  {
375  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
376  }
377  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
378  {
379  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
380  }
381  }();
382 
384 
385  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
386  GemmSpec == GemmSpecialization::MNKPadding)
387  {
388  // pad both M and K
389  const auto a_grid_desc_m_k =
390  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
392  make_right_pad_transform(K, KPad - K)),
395 
396  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
397  a_grid_desc_m_k,
402 
403  return a_grid_desc_ak0_m_ak1;
404  }
405  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
406  GemmSpec == GemmSpecialization::MNPadding)
407  {
408  // pad M, but not K
409  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
410  a_grid_desc_mraw_kraw,
412  make_right_pad_transform(M, MPad - M)),
415 
416  return a_grid_desc_ak0_m_ak1;
417  }
418  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
419  GemmSpec == GemmSpecialization::NKPadding)
420  {
421  // pad K, but not M
422  const auto a_grid_desc_m_k = transform_tensor_descriptor(
423  a_grid_desc_mraw_kraw,
427 
428  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
429  a_grid_desc_m_k,
434 
435  return a_grid_desc_ak0_m_ak1;
436  }
437  else
438  {
439  // not pad M or K
440  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
441  a_grid_desc_mraw_kraw,
446 
447  return a_grid_desc_ak0_m_ak1;
448  }
449  }
450 
451  __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
452  index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
453  {
454  const auto b_grid_desc_nraw_kraw = [&]() {
456  {
457  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
458  }
460  {
461  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
462  }
463  }();
464 
466 
467  static_assert(!(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> &&
468  GemmSpec != GemmSpecialization::Default),
469  "pk_i4_t does not support padding");
470 
471  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
472  GemmSpec == GemmSpecialization::MNKPadding)
473  {
474  // pad both N and K
475  const auto b_grid_desc_n_k =
476  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
478  make_right_pad_transform(K, KPad - K)),
481 
482  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
483  b_grid_desc_n_k,
488 
489  return b_grid_desc_bk0_n_bk1;
490  }
491  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
492  GemmSpec == GemmSpecialization::MNPadding)
493  {
494  // pad N, but not K
495  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
496  b_grid_desc_nraw_kraw,
498  make_right_pad_transform(N, NPad - N)),
501 
502  return b_grid_desc_bk0_n_bk1;
503  }
504  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
505  GemmSpec == GemmSpecialization::MKPadding)
506  {
507  // pad K, but not N
508  const auto b_grid_desc_n_k = transform_tensor_descriptor(
509  b_grid_desc_nraw_kraw,
513 
514  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
515  b_grid_desc_n_k,
520 
521  return b_grid_desc_bk0_n_bk1;
522  }
523  else
524  {
525  if constexpr(!PermuteB)
526  {
527  // not pad N or K
528  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
529  b_grid_desc_nraw_kraw,
534 
535  return b_grid_desc_bk0_n_bk1;
536  }
537  else
538  {
539  // Pre-shuffled Weight
540  // BGlobal[K / KPerBlock, N, KPerBlock / K1, K1] -> BTile[K / K1, N, K1]
541  constexpr index_t BK01 = KPerBlock / BK1Value;
542  const index_t BK0_ = StrideB / BK1Value;
543  const index_t BK00 = BK0_ / BK01;
544 
545  const auto b_grid_desc_bk00_n_bk01_bk1_permute =
546  make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value));
547 
548  const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor(
549  b_grid_desc_bk00_n_bk01_bk1_permute,
552  make_pass_through_transform(BK1Value)),
555 
556  return b_grid_desc_bk0_n_bk1_permute;
557  }
558  }
559  }
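 // Worked example of the PermuteB (pre-shuffled weight) branch above (a minimal illustration,
 // assuming KPerBlock = 64, BK1Value = 8 and StrideB = K = 512):
 //   BK01 = KPerBlock / BK1Value = 8
 //   BK0_ = StrideB / BK1Value   = 64
 //   BK00 = BK0_ / BK01          = 8
 //   i.e. B is stored as [BK00 = 8, N, BK01 = 8, BK1 = 8] and merged back into [BK0 = 64, N, BK1 = 8].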
560 
561  template <typename ABlockDesc_AK0_M_AK1>
562  __host__ __device__ static constexpr auto
563  MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
564  {
565  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
566 
567  return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
568  }
569 
570  template <typename BBlockDesc_BK0_N_BK1>
571  __host__ __device__ static constexpr auto
572  MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
573  {
574  constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
575 
576  return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
577  }
578 
579  __host__ __device__ static auto
581  {
582  const auto c_grid_desc_mraw_nraw = [&]() {
584  {
585  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
586  }
588  {
589  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
590  }
591  }();
592 
593  // pad M and N
594  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
596  make_right_pad_transform(N, NPad - N)),
599 #if 0
601 
602  if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
603  GemmSpec == GemmSpecialization::MNKPadding)
604  {
605  // pad M and N
606  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
608  make_right_pad_transform(N, NPad - N)),
611  }
612  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
613  GemmSpec == GemmSpecialization::MKPadding)
614  {
615  // pad M, but not N
617  c_grid_desc_mraw_nraw,
621  }
622  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
623  GemmSpec == GemmSpecialization::NKPadding)
624  {
625  // pad N, but not M
627  c_grid_desc_mraw_nraw,
631  }
632  else
633  {
634  // not pad M or N
635  return c_grid_desc_mraw_nraw;
636  }
637 #endif
638  }
639 
640  struct Problem
641  {
642  __host__ Problem(index_t M_,
643  index_t N_,
644  index_t K_,
645  index_t StrideA_,
646  index_t StrideB_,
647  index_t StrideC_,
648  index_t KBatch_,
649  AElementwiseOperation a_element_op,
650  BElementwiseOperation b_element_op,
651  CElementwiseOperation c_element_op)
652  : M{M_},
653  N{N_},
654  K{K_},
655  StrideA{StrideA_},
656  StrideB{StrideB_},
657  StrideC{StrideC_},
658  KBatch{KBatch_},
661  KRead{CalculateKRead(K_, KBatch_)},
662  KPadded{CalculateKPadded(K_, KBatch_)},
663  AK0{CalculateAK0Padded(K_, KBatch_)},
664  BK0{CalculateBK0Padded(K_, KBatch_)},
665  MBlock{CalculateMBlock(M_)},
666  NBlock{CalculateNBlock(N_)},
667  a_element_op_{a_element_op},
668  b_element_op_{b_element_op},
669  c_element_op_{c_element_op}
670  {
671  }
672 
673  __host__ void Print() const
674  {
675  // clang-format off
676  std::cout << "problem {"
677  << "M:" << M << ", "
678  << "N:" << N << ", "
679  << "K:" << K << ", "
680  << "SA:" << StrideA << ", "
681  << "SB:" << StrideB << ", "
682  << "SC:" << StrideC << ", "
683  << "MP:" << MPadded << ", "
684  << "NP:" << NPadded << ", "
685  << "KRead:" << KRead << ", "
686  << "KP:" << KPadded << ", "
687  << "AK0:" << AK0 << ", "
688  << "BK0:" << BK0 << ", "
689  << "MBlock: " << MBlock << ", "
690  << "NBlock: " << NBlock << "}" << std::endl;
691  // clang-format on
692  }
693 
709  AElementwiseOperation a_element_op_;
710  BElementwiseOperation b_element_op_;
711  CElementwiseOperation c_element_op_;
712  };
713 
714  // Argument
716  {
717  __host__ Argument(const ADataType* p_a_grid_,
718  const BDataType* p_b_grid_,
719  CDataType* p_c_grid_,
720  index_t M_,
721  index_t N_,
722  index_t K_,
723  index_t StrideA_,
724  index_t StrideB_,
725  index_t StrideC_,
726  index_t k_batch_,
727  bool is_reduce_ = false,
728  AElementwiseOperation a_element_op = AElementwiseOperation{},
729  BElementwiseOperation b_element_op = BElementwiseOperation{},
730  CElementwiseOperation c_element_op = CElementwiseOperation{})
731  : Problem{M_,
732  N_,
733  K_,
734  StrideA_,
735  StrideB_,
736  StrideC_,
737  k_batch_,
738  a_element_op,
739  b_element_op,
740  c_element_op},
741  p_a_grid{p_a_grid_},
742  p_b_grid{p_b_grid_},
743  p_c_grid{p_c_grid_},
744  is_reduce(is_reduce_)
745  {
746  }
747 
748  __host__ __device__ inline bool IsReduceAdd() const
749  {
750  return (Problem::KBatch > 1) && is_reduce;
751  }
752 
753  __host__ __device__ inline bool IsAtomicAdd() const
754  {
755  return (Problem::KBatch > 1) && (!is_reduce);
756  }
757 
758  const ADataType* p_a_grid;
759  const BDataType* p_b_grid;
760  CDataType* p_c_grid;
761  bool is_reduce;
762  };
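 // Minimal usage sketch for Argument (illustrative only; GridwiseGemmInstance and the grid
 // pointers p_a/p_b/p_c are assumed names, and the element ops fall back to their defaults):
 //
 //   auto karg = typename GridwiseGemmInstance::Argument{
 //       p_a, p_b, p_c,
 //       M, N, K,
 //       StrideA, StrideB, StrideC,
 //       /*k_batch_*/ 2,
 //       /*is_reduce_*/ false};  // KBatch > 1 with is_reduce_ == false selects the AtomicAdd path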
763 
765  {
766 
767  __device__ SplitKBatchOffset(Argument& karg)
768  {
769  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
770  {
771  a_k_split_offset = blockIdx.z * karg.KRead / APackedSize;
772  }
773  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
774  {
775  a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
776  }
777 
778  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
779  {
780  b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
781  }
782  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
783  {
784  if constexpr(!PermuteB)
785  {
786  b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
787  }
788  else
789  {
790  const int k0_offset = karg.KRead * karg.N;
791  b_k_split_offset = blockIdx.z * k0_offset / BPackedSize;
792  }
793  }
794 
795  if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
796  {
797  karg.K = karg.KRead;
798  }
799  else
800  {
801  karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
802  }
803 
804  if(karg.IsReduceAdd())
805  {
806  c_reduce_offset = blockIdx.z * karg.M * karg.N;
807  }
808  else
809  {
810  c_reduce_offset = 0;
811  }
812  }
813 
817  };
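 // Worked example of SplitKBatchOffset (a minimal illustration, assuming row-major A,
 // column-major B without PermuteB, KRead = 256, KBatch = 2, APackedSize = BPackedSize = 1):
 //   blockIdx.z == 0: a_k_split_offset = 0,   b_k_split_offset = 0,   karg.K = 256
 //   blockIdx.z == 1: a_k_split_offset = 256, b_k_split_offset = 256, karg.K = K - 256
 //   c_reduce_offset is blockIdx.z * M * N only when IsReduceAdd() returns true, otherwise 0.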
818 
819  __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
820  {
821  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
822  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
823  constexpr index_t WaveSize = BlockSize / (MWave * NWave);
824  // A matrix in LDS memory, dst of blockwise copy
825  if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
826  {
827  // Bank conflicts occur when writing the data into LDS, but the v4 pipeline has the whole
828  // main loop to hide them; this layout may also save some VALU work in address computation
832  }
833  // An XOR-based tensor transformation would require extra, unnecessary VGPRs and could cause
834  // register spills in some cases.
836  {
837  constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize;
838  constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize;
839  constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
840  make_tuple(
841  AK0Number * Number<MLdsLayer>{}, Number<MPerBlock / MLdsLayer>{}, AK1Number),
843 
844  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
845  a_lds_block_desc,
851 
852  constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor(
853  a_lds_block_desc_permuted,
859 
860  constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
861  a_lds_block_desc_ak0_mldslayer_m_ak1,
868 
869  return a_lds_block_desc_ak0_m_ak1;
870  }
871  else // ColumnMajor A
872  {
873  // The kfold and mpair dimensions are not always required.
874  // More dimensions in the merge_transform make it harder for the compiler to generate
875  // immediate-argument (immarg) offsets.
876  constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
877  constexpr auto M1 = MPerBlock / M0;
878 
879  constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
880  constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
881  constexpr auto KThreadRead = WaveSize / MPerXdl;
882  constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
883 
884  constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
885  ? 1
886  : 128 / (AK1Number * M0 * sizeof(ADataType));
887  constexpr auto KThreadReadPerm =
888  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
889  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
890  : KThreadRead;
891 
892  // 1 <= mpair <= M0
893  constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
894  ? 1
895  : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
896  ? M0
897  : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));
898 
899  constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
903  Number<kfold * M0 / mpair>{},
904  Number<mpair>{},
905  AK1Number));
906 
907  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
908  a_lds_block_desc,
909  make_tuple(
913  make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
916  make_tuple(
918  make_tuple(
920 
921  constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
922  a_lds_block_desc_permuted,
923  make_tuple(
931  Sequence<1>{},
932  Sequence<2>{},
933  Sequence<3>{},
934  Sequence<4>{},
935  Sequence<5>{}),
937  Sequence<2>{},
938  Sequence<0, 3>{},
939  Sequence<4, 5>{},
940  Sequence<6>{},
941  Sequence<7>{}));
942 
943  constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
944  a_lds_block_desc_unmerged,
947  Number<KThreadWrite / kfold / KThreadReadPerm>{},
948  Number<kfold>{},
955 
956  return a_lds_block_desc_ak0_m_ak1;
957  }
958  }
959 
960  __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
961  {
962  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
963  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
964  constexpr index_t WaveSize = BlockSize / (MWave * NWave);
965  // B matrix in LDS memory, dst of blockwise copy
966  if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
967  {
968  // Bank conflicts occur when writing the data into LDS, but the v4 pipeline has the whole
969  // main loop to hide them; this layout may also save some VALU work in address computation
973  }
975  {
976  // Treat NLdsLayer * K0 as one logical bank
977  constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize;
978  constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize;
979  constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
980  make_tuple(
981  BK0Number * Number<NLdsLayer>{}, Number<NPerBlock / NLdsLayer>{}, BK1Number),
983 
984  constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
985  b_lds_block_desc,
991 
992  constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor(
993  b_lds_block_desc_permuted,
999 
1000  constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
1001  b_lds_block_desc_bk0_nldslayer_n_bk1,
1008 
1009  return b_lds_block_desc_bk0_n_bk1;
1010  }
1011  else // RowMajor B
1012  {
1013  constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
1014  constexpr auto N1 = NPerBlock / N0;
1015 
1016  constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
1017  constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
1018  constexpr auto KThreadRead = WaveSize / NPerXdl;
1019  constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
1020 
1021  constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
1022  ? 1
1023  : 128 / (BK1Number * N0 * sizeof(BDataType));
1024  constexpr auto KThreadReadPerm =
1025  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
1026  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
1027  : KThreadRead;
1028 
1029  // 1 <= npair <= N0
1030  constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128)
1031  ? 1
1032  : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0
1033  ? N0
1034  : 128 / (BK1Number * NPerXdl * sizeof(BDataType)));
1035 
1036  constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
1040  Number<kfold * N0 / npair>{},
1041  Number<npair>{},
1042  BK1Number));
1043 
1044  constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
1045  b_lds_block_desc,
1046  make_tuple(
1050  make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
1053  make_tuple(
1055  make_tuple(
1057 
1058  constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
1059  b_lds_block_desc_permuted,
1060  make_tuple(
1068  Sequence<1>{},
1069  Sequence<2>{},
1070  Sequence<3>{},
1071  Sequence<4>{},
1072  Sequence<5>{}),
1074  Sequence<2>{},
1075  Sequence<0, 3>{},
1076  Sequence<4, 5>{},
1077  Sequence<6>{},
1078  Sequence<7>{}));
1079 
1080  constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
1081  b_lds_block_desc_unmerged,
1084  Number<KThreadWrite / kfold / KThreadReadPerm>{},
1085  Number<kfold>{},
1092 
1093  return b_lds_block_desc_bk0_n_bk1;
1094  }
1095  }
1096 
1098  {
1099  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1100  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1101 
1102  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1104  make_tuple(I1,
1106  I1,
1108 
1109  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
1110  }
1111 
1114  BlkGemmPipelineVer,
1115  BlkGemmPipeSched,
1116  BlockSize,
1117  ADataType,
1118  BDataType,
1119  ComputeTypeA,
1120  AccDataType,
1127  ABlockTransferSrcScalarPerVector,
1128  BBlockTransferSrcScalarPerVector,
1129  MPerBlock,
1130  NPerBlock,
1131  KPerBlock,
1132  MPerXdl,
1133  NPerXdl,
1134  MXdlPerWave,
1135  NXdlPerWave,
1136  KPack>())>;
1137 
1138  __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
1139  {
1140  // LDS allocation for A and B: be careful of alignment
1141  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1142  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1143 
1144  // lds max alignment
1145  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1146 
1147  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1148  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1149 
1150  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
1151  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1152 
1153  // LDS allocation for C shuffle in LDS
1154  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1156 
1157  constexpr auto c_block_size =
1158  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
1159 
1160  return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize +
1161  b_block_space_size_aligned * sizeof(BDataType) / BPackedSize),
1162  c_block_size * sizeof(CShuffleDataType));
1163  }
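 // Illustrative size estimate for GetSharedMemoryNumberOfByte() (rough numbers, assuming
 // MPerBlock = 256, NPerBlock = 128, KPerBlock = 64, fp16 A/B and no extra LDS padding):
 //   A block: 256 * 64 * sizeof(half) = 32 KiB
 //   B block: 128 * 64 * sizeof(half) = 16 KiB
 //   A + B   = 48 KiB, which is compared against the C-shuffle staging tile; the larger of the
 //   two determines the __shared__ allocation made by the kernel entry points above.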
1164 
1165  template <InMemoryDataOperationEnum CGlobalMemoryDataOperation>
1166  __device__ static bool constexpr IsValidCompilationParameter()
1167  {
1168  enum struct Arch : bool
1169  {
1170 #if defined(__gfx950__)
1171  is_gfx950_build = true,
1172 #else
1173  is_gfx950_build = false,
1174 #endif
1175  };
1176 
1177  // skip building the instances with K1>=32 && PackedSize != 2 on pre-gfx950
1178  if constexpr(static_cast<bool>(Arch::is_gfx950_build) ||
1179  (AK1Number < 32 && BK1Number < 32) ||
1180  (AK1Number >= 32 && APackedSize == 2) ||
1181  (BK1Number >= 32 && BPackedSize == 2))
1182  {
1183 
1184  }
1185  else
1186  {
1187  return false;
1188  }
1189 
1190  return ck::tensor_operation::device::IsValidGemmCompilationParameter<
1191  BlockSize,
1192  MPerBlock,
1193  NPerBlock,
1194  MPerXdl,
1195  NPerXdl,
1196  MXdlPerWave,
1197  NXdlPerWave,
1198  CDataType,
1199  CGlobalMemoryDataOperation>();
1200  }
1201  // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
1202  __host__ static constexpr bool CheckValidity(const Argument& karg)
1203  {
1204  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
1205  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
1206  "Invalid tuning param!");
1207 
1208  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
1213  {
1214  if(!(karg.M % MPerBlock == 0))
1215  {
1216  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1217  {
1218  std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
1219  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1220  << std::endl;
1221  }
1222  return false;
1223  }
1224  }
1225 
1226  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
1231  {
1232  if(!(karg.N % NPerBlock == 0))
1233  {
1234  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1235  {
1236  std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
1237  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1238  << std::endl;
1239  }
1240  return false;
1241  }
1242  }
1243 
1244  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1248  {
1249 
1250  auto K_t = karg.KBatch * KPerBlock;
1251  if(!(karg.K % K_t == 0))
1252  {
1253  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1254  {
1255  std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1256  << karg.K << " " << __FILE__ << ":" << __LINE__
1257  << ", in function: " << __func__ << std::endl;
1258  }
1259  return false;
1260  }
1261  }
1262  else
1263  {
1264  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1265  auto K_t = karg.KBatch * KReadVec;
1266  auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1267  if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1268  {
1269  return false;
1270  }
1271  }
1272 
1274  {
1275  if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1276  {
1277  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1278  {
1279  std::cout << "Arg K (" << karg.K
1280  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1281  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1282  << __LINE__ << ", in function: " << __func__ << std::endl;
1283  }
1284  return false;
1285  }
1286  }
1287  else
1288  {
1289  if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1290  {
1291  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1292  {
1293  std::cout << "Arg M (" << karg.M
1294  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1295  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1296  << __LINE__ << ", in function: " << __func__ << std::endl;
1297  }
1298  return false;
1299  }
1300  }
1301 
1303  {
1304  if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1305  {
1306  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1307  {
1308  std::cout << "Arg N (" << karg.N
1309  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1310  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1311  << __LINE__ << ", in function: " << __func__ << std::endl;
1312  }
1313  return false;
1314  }
1315  }
1316  else
1317  {
1318  if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1319  {
1320  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1321  {
1322  std::cout << "Arg K (" << karg.K
1323  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1324  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1325  << __LINE__ << ", in function: " << __func__ << std::endl;
1326  }
1327  return false;
1328  }
1329  }
1330 
1332  {
1333  if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
1334  {
1335  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1336  {
1337  std::cout << "Arg N (" << karg.N
1338  << ") value is not a multiple of "
1339  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1340  << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
1341  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1342  << std::endl;
1343  }
1344  return false;
1345  }
1346  }
1347  else
1348  {
1349  if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
1350  {
1351  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1352  {
1353  std::cout << "Arg M (" << karg.M
1354  << ") value is not a multiple of "
1355  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1356  << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
1357  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1358  << std::endl;
1359  }
1360  return false;
1361  }
1362  }
1363 
1364  if constexpr(!(is_same<remove_cvref_t<CDataType>, half_t>::value ||
1368  {
1369  if(!karg.IsReduceAdd())
1370  {
1371  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1372  {
1373  std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported yet " << __FILE__
1374  << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1375  }
1376  if(karg.KBatch > 1)
1377  {
1378  return false;
1379  }
1380  }
1381  }
1382 
1383  // check gridwise gemm pipeline
1384  const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1385 
1386  if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1)
1387  {
1388  if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1389  {
1390  return false;
1391  }
1392  }
1393 
1394  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1395  return true;
1396  }
1397 
1398  __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1399  {
1400  const index_t num_loop = K / KPerBlock;
1401 
1402  return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1403  }
1404 
1405  __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1406  {
1407  const index_t num_loop = K / KPerBlock;
1408 
1409  return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1410  }
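 // Illustrative host-side dispatch sketch for the two helpers above (a simplified sketch of what
 // the device_op layer does; GridwiseGemmInstance is an assumed concrete instantiation):
 //
 //   const auto k_grain  = karg.KBatch * KPerBlock;
 //   const auto K_split  = (karg.K + k_grain - 1) / k_grain * KPerBlock;
 //   const bool has_loop = GridwiseGemmInstance::CalculateHasMainKBlockLoop(K_split);
 //   const auto tail_num = GridwiseGemmInstance::CalculateKBlockLoopTailNum(K_split);
 //   if(has_loop && tail_num == TailNumber::Full)
 //   {
 //       // launch kernel_gemm_xdl_cshuffle_v3<GridwiseGemmInstance, true,
 //       //                                    InMemoryDataOperationEnum::Set, 1, TailNumber::Full>
 //   }
 //   // ... the remaining (has_loop, tail_num) combinations select the other instantiations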
1411 
1412  template <typename CGridDesc>
1413  __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
1414  const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1415  {
1416  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1417  c_grid_desc_m_n,
1422 
1423  return c_grid_desc_mblock_mperblock_nblock_nperblock;
1424  }
1425 
1426  // return block_id to C matrix tile idx (m0, n0) mapping
1427  // if arch = gfx942
1429  // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
1430 
1431  template <typename AGridDesc_AK0_M_K1,
1432  typename BGridDesc_BK0_N_K1,
1433  typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
1434  bool HasMainKBlockLoop,
1435  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1436  TailNumber TailNum = TailNumber::Odd>
1437  __device__ static void Run(const ADataType* p_a_grid,
1438  const BDataType* p_b_grid,
1439  CDataType* p_c_grid,
1440  void* p_shared,
1441  const Problem& problem,
1442  const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
1443  const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
1444  const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
1445  c_grid_desc_mblock_mperblock_nblock_nperblock)
1446  {
1447  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1448  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1449  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1450  p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1451  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1452  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1453 
1454  // divide block work by [M, N]
1455  const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
1456 
1457  const auto block_work_idx =
1458  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
1459 
1460  if(!block_2_ctile_map.ValidCTileIndex(
1461  block_work_idx,
1462  make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
1463  c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
1464  {
1465  return;
1466  }
1467 
1468  const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
1469  const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
1470 
1471  // HACK: this forces m/n_block_data_idx_on_grid into SGPRs
1472  const index_t m_block_data_idx_on_grid =
1473  __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1474 
1475  const index_t n_block_data_idx_on_grid =
1476  __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1477 
1478  // lds max alignment
1479  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1480 
1481  // A matrix in LDS memory, dst of blockwise copy
1482  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1483 
1484  // B matrix in LDS memory, dst of blockwise copy
1485  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1486 
1487  // A matrix blockwise copy
1488  auto a_blockwise_copy =
1490  AElementwiseOperation,
1494  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1495  ABlockTransferThreadClusterArrangeOrder,
1496  ADataType,
1497  ADataType,
1498  decltype(a_grid_desc_ak0_m_ak1),
1499  decltype(a_block_desc_ak0_m_ak1),
1500  ABlockTransferSrcAccessOrder,
1502  ABlockTransferSrcVectorDim,
1503  2,
1504  ABlockTransferSrcScalarPerVector,
1505  ABlockTransferDstScalarPerVector_AK1,
1506  1,
1507  1,
1508  AThreadTransferSrcResetCoordinateAfterRun,
1509  true,
1510  BlockwiseGemmPipe::GlobalBufferNum>(
1511  a_grid_desc_ak0_m_ak1,
1512  make_multi_index(0, m_block_data_idx_on_grid, 0),
1513  problem.a_element_op_,
1514  a_block_desc_ak0_m_ak1,
1515  make_multi_index(0, 0, 0),
1517 
1518  // B matrix blockwise copy
1519  auto b_blockwise_copy =
1521  BElementwiseOperation,
1525  BBlockTransferThreadClusterLengths_BK0_N_BK1,
1526  BBlockTransferThreadClusterArrangeOrder,
1527  BDataType,
1528  BDataType,
1529  decltype(b_grid_desc_bk0_n_bk1),
1530  decltype(b_block_desc_bk0_n_bk1),
1531  BBlockTransferSrcAccessOrder,
1533  BBlockTransferSrcVectorDim,
1534  2,
1535  BBlockTransferSrcScalarPerVector,
1536  BBlockTransferDstScalarPerVector_BK1,
1537  1,
1538  1,
1539  BThreadTransferSrcResetCoordinateAfterRun,
1540  true,
1541  BlockwiseGemmPipe::GlobalBufferNum>(
1542  b_grid_desc_bk0_n_bk1,
1543  make_multi_index(0, n_block_data_idx_on_grid, 0),
1544  problem.b_element_op_,
1545  b_block_desc_bk0_n_bk1,
1546  make_multi_index(0, 0, 0),
1548 
1549  // LDS allocation for A and B: be careful of alignment
1550  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1551  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1552 
1553  // Cast the raw LDS pointer to the A/B element types when building the block buffers
1554  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1555  static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1556 
1557  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1558  reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) + a_block_space_size_aligned *
1559  sizeof(ADataType) /
1560  APackedSize),
1561  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1562 
1563  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1564  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
1565 
1566  // Blockwise GEMM pipeline
1567  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1568  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1569  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1570 
1571  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1572  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1573  KPerBlock);
1574 
1575  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
1576  a_block_desc_ak0_m_ak1,
1577  a_blockwise_copy,
1578  a_grid_buf,
1579  a_block_buf,
1580  a_block_slice_copy_step,
1581  b_grid_desc_bk0_n_bk1,
1582  b_block_desc_bk0_n_bk1,
1583  b_blockwise_copy,
1584  b_grid_buf,
1585  b_block_buf,
1586  b_block_slice_copy_step,
1587  c_thread_buf,
1588  num_k_block_main_loop);
1589 
1590  // shuffle C and write out
1591  {
1592  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1593  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1594  "wrong!");
1595 
1596  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1597  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1598 
1599  // TODO: hacky, fix it!
1600  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1601  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1602 
1603  // TODO: hacky, fix it!
1604  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1605  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1606  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1607 
1608  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1609  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1610  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1611  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1612  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1613  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1614  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1615  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1616 
1617  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1619 
1620  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1621  static_cast<CShuffleDataType*>(p_shared),
1622  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1623 
1624  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1625  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1626  make_tuple(
1629  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1630  M1, // M1 = MWave
1631  M2, // M2 * M3 * M4 = MPerXdl
1632  M3,
1633  M4)),
1636  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1637  N1, // N1 = NWave
1638  N2))), // N2 = NPerXdl
1640  make_tuple(
1642 
1643  // calculate origin of thread output tensor on global memory
1644  // blockwise GEMM c matrix starting index
1645  const auto c_thread_mtx_on_block =
1646  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1647 
1648  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1649  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1650 
1651  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1653  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
1655  make_tuple(Sequence<0>{}));
1656 
1657  const auto m_thread_data_on_block_idx =
1658  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1659  make_multi_index(m_thread_data_on_block));
1660 
1661  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1665  make_tuple(Sequence<0>{}));
1666 
1667  const auto n_thread_data_on_block_idx =
1668  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1669  make_multi_index(n_thread_data_on_block));
1670 
1672  const auto& vpgr_to_lds_element_op = [&] {
1673  if constexpr(DoElementwiseBeforeCShuffle)
1674  {
1675  return problem.c_element_op_;
1676  }
1677  else
1678  {
1679  return pass_through;
1680  }
1681  };
1682  const auto& lds_to_global_element_op = [&] {
1683  if constexpr(!DoElementwiseBeforeCShuffle)
1684  {
1685  return problem.c_element_op_;
1686  }
1687  else
1688  {
1689  return pass_through;
1690  }
1691  };
1692 
1693  // shuffle: threadwise copy C from VGPR to LDS
1694  auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
1695  AccDataType,
1696  CShuffleDataType,
1697  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1698  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1699  conditional_t<DoElementwiseBeforeCShuffle,
1700  CElementwiseOperation,
1702  Sequence<CShuffleMXdlPerWavePerShuffle,
1703  CShuffleNXdlPerWavePerShuffle,
1704  I1,
1705  I1,
1706  M2,
1707  I1,
1708  M4,
1709  I1>,
1711  7,
1712  1,
1714  1,
1715  true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1716  make_multi_index(0,
1717  0,
1718  m_thread_data_on_block_idx[I1],
1719  n_thread_data_on_block_idx[I1],
1720  m_thread_data_on_block_idx[I2],
1721  m_thread_data_on_block_idx[I3],
1722  m_thread_data_on_block_idx[I4],
1723  n_thread_data_on_block_idx[I2]),
1724  vpgr_to_lds_element_op()};
1725 
1726  // shuffle: blockwise copy C from LDS to global
1727  auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
1728  ThisThreadBlock, // ThreadGroup
1729  conditional_t<!DoElementwiseBeforeCShuffle,
1730  CElementwiseOperation,
1732  CGlobalMemoryDataOperation, // DstInMemOp,
1733  Sequence<1,
1734  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1735  1,
1736  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1737  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1738  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1739  CShuffleDataType, // typename SrcData,
1740  CDataType, // typename DstData,
1741  decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1742  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1743  Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
1744  3, // index_t VectorDim,
1745  CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
1746  true, // bool ThreadTransferSrcResetCoordinateAfterRun,
1747  false> // bool ThreadTransferDstResetCoordinateAfterRun>
1748  {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1749  make_multi_index(0, 0, 0, 0),
1750  c_grid_desc_mblock_mperblock_nblock_nperblock,
1751  make_multi_index(block_m_id, 0, block_n_id, 0),
1752  lds_to_global_element_op()};
1753 
1754  // space filling curve for threadwise C in VGPR
1755  constexpr auto sfc_c_vgpr =
1758  Sequence<CShuffleMXdlPerWavePerShuffle,
1759  CShuffleNXdlPerWavePerShuffle,
1760  1,
1761  1,
1762  M2,
1763  1,
1764  M4,
1765  1>>{};
1766 
1767  // space filling curve for shuffled blockwise C in global mem
1768  constexpr auto sfc_c_global =
1771  Sequence<1,
1772  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1773  1,
1774  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1775 
1776  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1777 
1778  static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
1779 
1780  static_for<0, num_access, 1>{}([&](auto access_id) {
1781  // make sure it's safe to write to LDS
1782  block_sync_lds();
1783 
1784  // each thread write its data from VGPR to LDS
1785  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1786  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1787  c_thread_buf,
1788  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1789  c_shuffle_block_buf);
1790 
1791  // make sure it's safe to read from LDS
1792  block_sync_lds();
1793 
1794  // each block copy its data from LDS to global
1795  c_shuffle_block_copy_lds_to_global.Run(
1796  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1797  c_shuffle_block_buf,
1798  c_grid_desc_mblock_mperblock_nblock_nperblock,
1799  c_grid_buf);
1800 
1801  if constexpr(access_id < num_access - 1)
1802  {
1803  constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1804 
1805  // move on C
1806  c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1807  c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1808  }
1809  });
1810  }
1811  }
1812 
1813  template <bool HasMainKBlockLoop,
1814  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1815  TailNumber TailNum = TailNumber::Odd>
1816  __device__ static void Run(const ADataType* p_a_grid,
1817  const BDataType* p_b_grid,
1818  CDataType* p_c_grid,
1819  void* p_shared,
1820  const Problem& problem)
1821  {
1822  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1823  problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
1824  const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
1825  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
1826  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
1827  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
1828  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1830  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1831 
1832  Run<decltype(a_grid_desc_ak0_m_ak1),
1833  decltype(b_grid_desc_bk0_n_bk1),
1834  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1835  HasMainKBlockLoop,
1836  CGlobalMemoryDataOperation,
1837  TailNum>(p_a_grid,
1838  p_b_grid,
1839  p_c_grid,
1840  p_shared,
1841  problem,
1842  a_grid_desc_ak0_m_ak1,
1843  b_grid_desc_bk0_n_bk1,
1844  c_grid_desc_mblock_mperblock_nblock_nperblock);
1845  }
1846 
1847  template <typename AGridDesc_AK0_M_K1,
1848  typename BGridDesc_BK0_N_K1,
1849  typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
1850  bool HasMainKBlockLoop,
1851  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1852  TailNumber TailNum = TailNumber::Odd>
1853  __device__ static void Run_2Lds(const ADataType* p_a_grid,
1854  const BDataType* p_b_grid,
1855  CDataType* p_c_grid,
1856  void* p_shared_0,
1857  void* p_shared_1,
1858  const Problem& problem,
1859  const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
1860  const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
1861  const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
1862  c_grid_desc_mblock_mperblock_nblock_nperblock)
1863  {
1864  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1865  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1866  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1867  p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1868  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1869  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1870 
1871  // divide block work by [M, N]
1872  const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
1873 
1874  const auto block_work_idx =
1875  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
1876 
1877  if(!block_2_ctile_map.ValidCTileIndex(
1878  block_work_idx,
1879  make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
1880  c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
1881  {
1882  return;
1883  }
1884 
1885  const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
1886  const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
1887 
1888  // HACK: this forces m/n_block_data_idx_on_grid into SGPRs
1889  const index_t m_block_data_idx_on_grid =
1890  __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1891 
1892  const index_t n_block_data_idx_on_grid =
1893  __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1894 
1895  // lds max alignment
1896  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1897 
1898  // A matrix in LDS memory, dst of blockwise copy
1899  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1900 
1901  // B matrix in LDS memory, dst of blockwise copy
1902  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1903 
1904  // A matrix blockwise copy
1905  auto a_blockwise_copy =
1907  AElementwiseOperation,
1911  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1912  ABlockTransferThreadClusterArrangeOrder,
1913  ADataType,
1914  ADataType,
1915  decltype(a_grid_desc_ak0_m_ak1),
1916  decltype(a_block_desc_ak0_m_ak1),
1917  ABlockTransferSrcAccessOrder,
1919  ABlockTransferSrcVectorDim,
1920  2,
1921  ABlockTransferSrcScalarPerVector,
1922  ABlockTransferDstScalarPerVector_AK1,
1923  1,
1924  1,
1925  AThreadTransferSrcResetCoordinateAfterRun,
1926  true,
1927  BlockwiseGemmPipe::GlobalBufferNum>(
1928  a_grid_desc_ak0_m_ak1,
1929  make_multi_index(0, m_block_data_idx_on_grid, 0),
1930  problem.a_element_op_,
1931  a_block_desc_ak0_m_ak1,
1932  make_multi_index(0, 0, 0),
1934 
1935  // B matrix blockwise copy
1936  auto b_blockwise_copy =
1938  BElementwiseOperation,
1942  BBlockTransferThreadClusterLengths_BK0_N_BK1,
1943  BBlockTransferThreadClusterArrangeOrder,
1944  BDataType,
1945  BDataType,
1946  decltype(b_grid_desc_bk0_n_bk1),
1947  decltype(b_block_desc_bk0_n_bk1),
1948  BBlockTransferSrcAccessOrder,
1950  BBlockTransferSrcVectorDim,
1951  2,
1952  BBlockTransferSrcScalarPerVector,
1953  BBlockTransferDstScalarPerVector_BK1,
1954  1,
1955  1,
1956  BThreadTransferSrcResetCoordinateAfterRun,
1957  true,
1958  BlockwiseGemmPipe::GlobalBufferNum>(
1959  b_grid_desc_bk0_n_bk1,
1960  make_multi_index(0, n_block_data_idx_on_grid, 0),
1961  problem.b_element_op_,
1962  b_block_desc_bk0_n_bk1,
1963  make_multi_index(0, 0, 0),
1965 
1966  // LDS allocation for A and B: be careful of alignment
1967  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1968  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1969 
1970  auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1971  static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1972 
1973  auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1974  bit_cast<BDataType*>(static_cast<char*>(p_shared_0) +
1975  a_block_space_size_aligned * sizeof(ADataType)),
1976  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1977 
1978  auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1979  static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1980 
1981  auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1982  bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) +
1983  a_block_space_size_aligned * sizeof(ADataType)),
1984  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1985 
1986  auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
1987  auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
1988 
1989  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1990  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
1991 
1992  // Blockwise GEMM pipeline
1993  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1994  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1995  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1996 
1997  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1998  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1999  KPerBlock);
2000 
2001  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
2002  a_block_desc_ak0_m_ak1,
2003  a_blockwise_copy,
2004  a_grid_buf,
2005  a_block_bufs,
2006  a_block_slice_copy_step,
2007  b_grid_desc_bk0_n_bk1,
2008  b_block_desc_bk0_n_bk1,
2009  b_blockwise_copy,
2010  b_grid_buf,
2011  b_block_bufs,
2012  b_block_slice_copy_step,
2013  c_thread_buf,
2014  num_k_block_main_loop);
2015 
2016  // shuffle C and write out
2017  {
2018  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2019  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2020  "wrong!");
2021 
2022  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2023  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
2024 
2025  // TODO: hacky, fix it!
2026  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2027  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2028 
2029  // TODO: hacky, fix it!
2030  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
2031  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2032  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2033 
2034  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
2035  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
2036  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
2037  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
2038  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
2039  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
2040  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
2041  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
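 // note: per the descriptor above, M0/N0 count the XDL tiles repeated per wave
 // (MXdlPerWave / NXdlPerWave), M1/N1 index the wave within the block,
 // M2 * M3 * M4 spans MPerXdl and N2 spans NPerXdl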
2042 
2043  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2044  GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
2045 
2046  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2047  static_cast<CShuffleDataType*>(p_shared_0),
2048  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2049 
2050  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
2051  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2052  make_tuple(
2053  make_freeze_transform(I0),
2054  make_unmerge_transform(make_tuple(
2055  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
2056  M1, // M1 = MWave
2057  M2, // M2 * M3 * M4 = MPerXdl
2058  M3,
2059  M4)),
2060  make_freeze_transform(I0),
2061  make_unmerge_transform(make_tuple(
2062  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
2063  N1, // N1 = NWave
2064  N2))), // N2 = NPerXdl
2065  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
2066  make_tuple(
2067  Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
2068 
2069  // calculate origin of thread output tensor on global memory
2070  // blockwise GEMM c matrix starting index
2071  const auto c_thread_mtx_on_block =
2072  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2073 
2074  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2075  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2076 
2077  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2078  make_single_stage_tensor_adaptor(
2079  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
2080  make_tuple(Sequence<0, 1, 2, 3, 4>{}),
2081  make_tuple(Sequence<0>{}));
2082 
2083  const auto m_thread_data_on_block_idx =
2084  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2085  make_multi_index(m_thread_data_on_block));
2086 
2087  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2088  make_single_stage_tensor_adaptor(
2089  make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
2090  make_tuple(Sequence<0, 1, 2>{}),
2091  make_tuple(Sequence<0>{}));
2092 
2093  const auto n_thread_data_on_block_idx =
2094  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2095  make_multi_index(n_thread_data_on_block));
2096 
2097  // shuffle: threadwise copy C from VGPR to LDS
2098  auto c_thread_copy_vgpr_to_lds =
2099  ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
2100  CShuffleDataType,
2101  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2102  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2103  ck::tensor_operation::element_wise::PassThrough,
2104  Sequence<CShuffleMXdlPerWavePerShuffle,
2105  CShuffleNXdlPerWavePerShuffle,
2106  I1,
2107  I1,
2108  M2,
2109  I1,
2110  M4,
2111  I1>,
2112  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
2113  7,
2114  1,
2115  InMemoryDataOperationEnum::Set,
2116  1,
2117  true>{
2118  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2119  make_multi_index(0,
2120  0,
2121  m_thread_data_on_block_idx[I1],
2122  n_thread_data_on_block_idx[I1],
2123  m_thread_data_on_block_idx[I2],
2124  m_thread_data_on_block_idx[I3],
2125  m_thread_data_on_block_idx[I4],
2126  n_thread_data_on_block_idx[I2]),
2127  ck::tensor_operation::element_wise::PassThrough{}};
2128 
2129  // shuffle: blockwise copy C from LDS to global
2130  auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
2131  ThisThreadBlock, // ThreadGroup
2132  CElementwiseOperation, // ElementwiseOperation,
2133  CGlobalMemoryDataOperation, // DstInMemOp,
2134  Sequence<1,
2135  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2136  1,
2137  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2138  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
2139  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2140  CShuffleDataType, // typename SrcData,
2141  CDataType, // typename DstData,
2142  decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2143  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
2144  Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
2145  3, // index_t VectorDim,
2146  CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
2147  true, // bool ThreadTransferSrcResetCoordinateAfterRun,
2148  false> // bool ThreadTransferDstResetCoordinateAfterRun>
2149  {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2150  make_multi_index(0, 0, 0, 0),
2151  c_grid_desc_mblock_mperblock_nblock_nperblock,
2152  make_multi_index(block_m_id, 0, block_n_id, 0),
2153  problem.c_element_op_};
2154 
2155  // space filling curve for threadwise C in VGPR
2156  constexpr auto sfc_c_vgpr =
2157  SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
2158  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
2159  Sequence<CShuffleMXdlPerWavePerShuffle,
2160  CShuffleNXdlPerWavePerShuffle,
2161  1,
2162  1,
2163  M2,
2164  1,
2165  M4,
2166  1>>{};
2167 
2168  // space filling curve for shuffled blockwise C in global mem
2169  constexpr auto sfc_c_global =
2170  SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
2171  Sequence<0, 2, 1, 3>,
2172  Sequence<1,
2173  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2174  1,
2175  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2176 
2177  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2178 
2179  static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
2180 
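 // note: each pass handles one CShuffleMXdlPerWavePerShuffle x CShuffleNXdlPerWavePerShuffle
 // chunk: barrier, per-thread VGPR -> LDS copy, barrier, blockwise LDS -> global copy,
 // then the global destination window moves to the next chunk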
2181  static_for<0, num_access, 1>{}([&](auto access_id) {
2182  // make sure it's safe to write to LDS
2183  block_sync_lds();
2184 
2185  // each thread write its data from VGPR to LDS
2186  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2187  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2188  c_thread_buf,
2189  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2190  c_shuffle_block_buf);
2191 
2192  // make sure it's safe to read from LDS
2193  block_sync_lds();
2194 
2195  // each block copy its data from LDS to global
2196  c_shuffle_block_copy_lds_to_global.Run(
2197  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2198  c_shuffle_block_buf,
2199  c_grid_desc_mblock_mperblock_nblock_nperblock,
2200  c_grid_buf);
2201 
2202  if constexpr(access_id < num_access - 1)
2203  {
2204  constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
2205 
2206  // move on C
2207  c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
2208  c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
2209  }
2210  });
2211  }
2212  }
2213 
2214  template <bool HasMainKBlockLoop,
2215  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
2216  TailNumber TailNum = TailNumber::Odd>
2217  __device__ static void Run_2Lds(const ADataType* p_a_grid,
2218  const BDataType* p_b_grid,
2219  CDataType* p_c_grid,
2220  void* p_shared_0,
2221  void* p_shared_1,
2222  const Problem& problem)
2223  {
2224  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
2225  problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
2226  const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
2227  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
2228  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
2229  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
2230 
2231  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
2232  MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
2233  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
2234 
2235  Run_2Lds<decltype(a_grid_desc_ak0_m_ak1),
2236  decltype(b_grid_desc_bk0_n_bk1),
2237  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
2238  HasMainKBlockLoop,
2239  CGlobalMemoryDataOperation,
2240  TailNum>(p_a_grid,
2241  p_b_grid,
2242  p_c_grid,
2243  p_shared_0,
2244  p_shared_1,
2245  problem,
2246  a_grid_desc_ak0_m_ak1,
2247  b_grid_desc_bk0_n_bk1,
2248  c_grid_desc_mblock_mperblock_nblock_nperblock);
2249  }
2250 };
2251 
2252 } // namespace ck