Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp Source File

gridwise_gemm_xdl_cshuffle_v2.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
17 
18 namespace ck {
19 
20 template <typename GridwiseGemm, bool HasMainKBlockLoop, index_t TailNum = 3>
21 __global__ void
22 #if CK_USE_LAUNCH_BOUNDS
23 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1)
24 #endif
25  // __attribute__((amdgpu_waves_per_eu(1, 1)))
26  kernel_gemm_xdl_cshuffle_v2(typename GridwiseGemm::Argument karg)
27 {
28 #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
29  defined(__gfx12__)
30  if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
31  {
 32  // Passing two LDS pointers tells the compiler that ds_read/ds_write can
 33  // operate on different LDS chunks at the same time without an ordering dependency
34  __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
35  __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
36 
37  GridwiseGemm::template Run<HasMainKBlockLoop, TailNum>(
38  karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared_0, p_shared_1, karg);
39  }
40 #else
41  ignore = karg;
42 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || defined(__gfx12__))
43 }
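 // Illustrative note (not part of the original source): with two independent LDS buffers the
 // pipeline can alternate them across K iterations, e.g. write tile k+1 into p_shared_1 while
 // the MFMAs still read tile k from p_shared_0, then swap roles on the next iteration, so
 // ds_write and ds_read never target the same buffer in the same stage.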
44 
45 template <typename GridwiseGemm,
46  typename FloatA,
47  typename FloatB,
48  typename FloatC,
49  bool HasMainKBlockLoop>
50 __global__ void
51 #if CK_USE_LAUNCH_BOUNDS
52 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1)
53 #endif
54  kernel_gemm_xdl_cshuffle_v2(const FloatA* p_a_grid,
55  const FloatB* p_b_grid,
56  FloatC* p_c_grid,
57  typename GridwiseGemm::Problem problem)
58 {
59 #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
60  defined(__gfx12__)
61  if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
62  {
63  __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
64  __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
65 
66  GridwiseGemm::template Run<HasMainKBlockLoop>(
67  p_a_grid, p_b_grid, p_c_grid, p_shared_0, p_shared_1, problem);
68  }
69 #else
70  ignore = p_a_grid;
71  ignore = p_b_grid;
72  ignore = p_c_grid;
73  ignore = problem;
 74 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || defined(__gfx12__))
75 }
76 
77 template <typename ALayout,
78  typename BLayout,
79  typename CLayout,
80  typename FloatA,
81  typename FloatB,
82  typename FloatGemmAcc,
83  typename FloatCShuffle,
84  typename FloatC,
85  typename AElementwiseOperation,
86  typename BElementwiseOperation,
87  typename CElementwiseOperation,
 88  tensor_operation::device::GemmSpecialization GemmSpec,
 89  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
90  index_t NumGemmKPrefetchStage,
91  index_t BlockSize,
92  index_t MPerBlock,
93  index_t NPerBlock,
94  index_t KPerBlock,
95  index_t AK1Value,
96  index_t BK1Value,
97  index_t MPerXdl,
98  index_t NPerXdl,
99  index_t MXdlPerWave,
100  index_t NXdlPerWave,
101  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
102  typename ABlockTransferThreadClusterArrangeOrder,
103  typename ABlockTransferSrcAccessOrder,
104  index_t ABlockTransferSrcVectorDim,
105  index_t ABlockTransferSrcScalarPerVector,
106  index_t ABlockTransferDstScalarPerVector_AK1,
107  bool AThreadTransferSrcResetCoordinateAfterRun,
108  index_t ABlockLdsExtraM,
109  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
110  typename BBlockTransferThreadClusterArrangeOrder,
111  typename BBlockTransferSrcAccessOrder,
112  index_t BBlockTransferSrcVectorDim,
113  index_t BBlockTransferSrcScalarPerVector,
114  index_t BBlockTransferDstScalarPerVector_BK1,
115  bool BThreadTransferSrcResetCoordinateAfterRun,
116  index_t BBlockLdsExtraN,
117  index_t CShuffleMXdlPerWavePerShuffle,
118  index_t CShuffleNXdlPerWavePerShuffle,
119  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
120  index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
121  LoopScheduler LoopSched,
122  PipelineVersion PipelineVer = PipelineVersion::v1,
123  typename ComputeTypeA = FloatC,
124  typename ComputeTypeB = ComputeTypeA>
 125 struct GridwiseGemm_xdl_cshuffle_v2
 126 {
127  static constexpr auto I0 = Number<0>{};
128  static constexpr auto I1 = Number<1>{};
129  static constexpr auto I2 = Number<2>{};
130  static constexpr auto I3 = Number<3>{};
131  static constexpr auto I4 = Number<4>{};
132  static constexpr auto I5 = Number<5>{};
133  static constexpr auto I6 = Number<6>{};
134  static constexpr auto I7 = Number<7>{};
135 
136  // K1 should be Number<...>
137  static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
138  static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
139  static constexpr auto AK1Number = Number<AK1Value>{};
140  static constexpr auto BK1Number = Number<BK1Value>{};
141 
 142  using ThisThreadBlock = ThisThreadBlock<BlockSize>;
 143 
144  __host__ static auto CalculateGridSize(index_t M, index_t N)
145  {
147  }
148 
149  __host__ static auto CalculateMPadded(index_t M)
150  {
151  return math::integer_divide_ceil(M, MPerBlock) * MPerBlock;
152  }
153 
154  __host__ static auto CalculateNPadded(index_t N)
155  {
156  return math::integer_divide_ceil(N, NPerBlock) * NPerBlock;
157  }
158 
159  __host__ static auto CalculateKPadded(index_t K)
160  {
161  return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
162  }
163 
164  __host__ static auto CalculateAK0(index_t K)
165  {
 166  using GemmSpecialization = tensor_operation::device::GemmSpecialization;
 167 
168  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
169  GemmSpec == GemmSpecialization::MNKPadding ||
170  GemmSpec == GemmSpecialization::KPadding ||
171  GemmSpec == GemmSpecialization::NKPadding)
172  {
173  return CalculateKPadded(K) / AK1Value;
174  }
175  else
176  {
177  return K / AK1Value;
178  }
179  }
180 
181  __host__ static auto CalculateBK0(index_t K)
182  {
 183  using GemmSpecialization = tensor_operation::device::GemmSpecialization;
 184 
185  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
186  GemmSpec == GemmSpecialization::MNKPadding ||
187  GemmSpec == GemmSpecialization::KPadding ||
188  GemmSpec == GemmSpecialization::MKPadding)
189  {
190  return CalculateKPadded(K) / BK1Value;
191  }
192  else
193  {
194  return K / BK1Value;
195  }
196  }
197 
198  __host__ static auto CalculateMBlock(index_t M)
199  {
200  return math::integer_divide_floor(M, MPerBlock);
201  }
202 
203  __host__ static auto CalculateNBlock(index_t N)
204  {
205  return math::integer_divide_floor(N, NPerBlock);
206  }
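 // Worked example (illustrative values, not from the original source): with MPerBlock = 256,
 // KPerBlock = 64, AK1Value = 8 and a non-padding GemmSpec, a problem with M = 1000, K = 4096
 // gives CalculateMPadded(1000) = ceil(1000/256) * 256 = 1024, CalculateKPadded(4096) = 4096
 // (already a multiple of 64), and CalculateAK0(4096) = 4096 / 8 = 512.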
207 
208  template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
209  __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
210  {
211  constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
212  constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
213 
215  TileDesc_K0_MN_K1{},
221  }
222 
223  __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
224  index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
225  {
226  const auto a_grid_desc_mraw_kraw = [&]() {
227  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
228  {
229  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
230  }
231  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
232  {
233  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
234  }
235  }();
236 
 237  using GemmSpecialization = tensor_operation::device::GemmSpecialization;
 238 
239  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
240  GemmSpec == GemmSpecialization::MNKPadding)
241  {
242  // pad both M and K
243  const auto a_grid_desc_m_k =
244  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
246  make_right_pad_transform(K, KPad - K)),
249 
250  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
251  a_grid_desc_m_k,
256 
257  return a_grid_desc_ak0_m_ak1;
258  }
259  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
260  GemmSpec == GemmSpecialization::MNPadding)
261  {
262  // pad M, but not K
263  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
264  a_grid_desc_mraw_kraw,
266  make_right_pad_transform(M, MPad - M)),
269 
270  return a_grid_desc_ak0_m_ak1;
271  }
272  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
273  GemmSpec == GemmSpecialization::NKPadding)
274  {
275  // pad K, but not M
276  const auto a_grid_desc_m_k = transform_tensor_descriptor(
277  a_grid_desc_mraw_kraw,
281 
282  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
283  a_grid_desc_m_k,
288 
289  return a_grid_desc_ak0_m_ak1;
290  }
291  else
292  {
293  // not pad M or K
294  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
295  a_grid_desc_mraw_kraw,
300 
301  return a_grid_desc_ak0_m_ak1;
302  }
303  }
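 // Illustrative shape (assumed values, not from the original source): for row-major A with
 // M = 1024, K = 4096, StrideA = 4096, AK1Value = 8 and no padding, the raw (M, K) view is
 // reshaped into the 3-D [AK0, M, AK1] descriptor [512, 1024, 8]; the padding branches above
 // first widen M and/or K to MPad/KPad before applying the same K -> (AK0, AK1) split.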
304 
305  __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
306  index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
307  {
308  const auto b_grid_desc_nraw_kraw = [&]() {
 309  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
 310  {
311  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
312  }
 313  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
 314  {
315  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
316  }
317  }();
318 
 319  using GemmSpecialization = tensor_operation::device::GemmSpecialization;
 320 
321  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
322  GemmSpec == GemmSpecialization::MNKPadding)
323  {
324  // pad both N and K
325  const auto b_grid_desc_n_k =
326  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
328  make_right_pad_transform(K, KPad - K)),
331 
332  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
333  b_grid_desc_n_k,
338 
339  return b_grid_desc_bk0_n_bk1;
340  }
341  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
342  GemmSpec == GemmSpecialization::MNPadding)
343  {
344  // pad N, but not K
345  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
346  b_grid_desc_nraw_kraw,
348  make_right_pad_transform(N, NPad - N)),
351 
352  return b_grid_desc_bk0_n_bk1;
353  }
354  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
355  GemmSpec == GemmSpecialization::MKPadding)
356  {
357  // pad K, but not N
358  const auto b_grid_desc_n_k = transform_tensor_descriptor(
359  b_grid_desc_nraw_kraw,
363 
364  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
365  b_grid_desc_n_k,
370 
371  return b_grid_desc_bk0_n_bk1;
372  }
373  else
374  {
375  // not pad N or K
376  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
377  b_grid_desc_nraw_kraw,
382 
383  return b_grid_desc_bk0_n_bk1;
384  }
385  }
386 
387  template <typename ABlockDesc_AK0_M_AK1>
388  __host__ __device__ static constexpr auto
389  MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
390  {
391  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
392 
393  return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
394  }
395 
396  template <typename BBlockDesc_BK0_N_BK1>
397  __host__ __device__ static constexpr auto
398  MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
399  {
400  constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
401 
402  return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
403  }
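 // Illustrative wave counts (assumed tile parameters, not from the original source): with
 // MPerBlock = 256, MXdlPerWave = 4, MPerXdl = 32 the A side has MWaves = 256 / (4 * 32) = 2;
 // with NPerBlock = 128, NXdlPerWave = 2, NPerXdl = 32 the B side has NWaves = 128 / (2 * 32) = 2.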
404 
405  __host__ __device__ static auto
 406  MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
 407  {
408  const auto c_grid_desc_mraw_nraw = [&]() {
 409  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, CLayout>)
 410  {
411  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
412  }
 413  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, CLayout>)
 414  {
415  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
416  }
417  }();
418 
 419  using GemmSpecialization = tensor_operation::device::GemmSpecialization;
 420 
421  if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
422  GemmSpec == GemmSpecialization::MNKPadding)
423  {
424  // pad M and N
425  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
427  make_right_pad_transform(N, NPad - N)),
430  }
431  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
432  GemmSpec == GemmSpecialization::MKPadding)
433  {
434  // pad M, but not N
436  c_grid_desc_mraw_nraw,
440  }
441  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
442  GemmSpec == GemmSpecialization::NKPadding)
443  {
444  // pad N, but not M
446  c_grid_desc_mraw_nraw,
450  }
451  else
452  {
453  // not pad M or N
454  return c_grid_desc_mraw_nraw;
455  }
456  }
457 
458  struct Problem
459  {
460  __host__ Problem(index_t M_,
461  index_t N_,
462  index_t K_,
463  index_t StrideA_,
464  index_t StrideB_,
465  index_t StrideC_)
466  : M{M_},
467  N{N_},
468  K{K_},
469  StrideA{StrideA_},
470  StrideB{StrideB_},
471  StrideC{StrideC_},
 472  MPadded{CalculateMPadded(M_)},
 473  NPadded{CalculateNPadded(N_)},
 474  KPadded{CalculateKPadded(K_)},
 475  AK0{CalculateAK0(K_)},
476  BK0{CalculateBK0(K_)},
477  MBlock{CalculateMBlock(M_)},
 478  NBlock{CalculateNBlock(N_)}
 479  {
480  }
481 
482  __host__ void Print() const
483  {
484  std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
485  << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
486  << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
487  << "KP:" << KPadded << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", "
488  << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl;
489  }
490 
 491  index_t M;
 492  index_t N;
 493  index_t K;
 494  index_t StrideA;
 495  index_t StrideB;
 496  index_t StrideC;
 497  index_t MPadded;
 498  index_t NPadded;
 499  index_t KPadded;
 500  index_t AK0;
 501  index_t BK0;
 502  index_t MBlock;
 503  index_t NBlock;
 504  };
505 
506  // Argument
 507  struct Argument : public Problem
 508  {
509  __host__ Argument(const FloatA* p_a_grid_,
510  const FloatB* p_b_grid_,
511  FloatC* p_c_grid_,
512  index_t M_,
513  index_t N_,
514  index_t K_,
515  index_t StrideA_,
516  index_t StrideB_,
517  index_t StrideC_)
518  : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
519  p_a_grid{p_a_grid_},
520  p_b_grid{p_b_grid_},
521  p_c_grid{p_c_grid_}
522  {
523  }
524 
525  const FloatA* p_a_grid;
526  const FloatB* p_b_grid;
527  FloatC* p_c_grid;
528  };
529 
 530  // FIXME: pass GridwiseGemmPipe as a template argument into GridwiseGemm
 531  using GridwiseGemmPipe = remove_cvref_t<
 532  decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
533 
534  __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
535  {
536  // A matrix in LDS memory, dst of blockwise copy
540  }
541 
542  __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
543  {
544  // B matrix in LDS memory, dst of blockwise copy
548  }
549 
 550  __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
 551  {
552  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
553  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
554 
555  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
557  make_tuple(I1,
559  I1,
561 
562  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
563  }
564 
565  __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
566  {
567  // LDS allocation for A and B: be careful of alignment
568  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
569  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
570 
571  // lds max alignment
572  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
573 
574  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
575  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
576 
577  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
578  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
579 
580  // LDS allocation for C shuffle in LDS
581  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
 582  GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
 583 
584  constexpr auto c_block_size =
585  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
586 
587  return math::max((a_block_space_size_aligned * sizeof(ComputeTypeA) +
588  b_block_space_size_aligned * sizeof(ComputeTypeB)),
589  c_block_size * sizeof(FloatCShuffle));
590  }
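 // Illustrative sizing (assumed parameters, not from the original source, ignoring the
 // ABlockLdsExtraM/BBlockLdsExtraN padding): with MPerBlock = NPerBlock = 128, KPerBlock = 64
 // and 2-byte compute types, each A/B tile holds 64 * 128 = 8192 elements, so one buffer needs
 // about (8192 + 8192) * 2 bytes = 32 KiB; the kernel allocates this twice (p_shared_0 and
 // p_shared_1) for the ping-pong pipeline, and the C-shuffle tile reuses the same space via
 // the max() above.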
591 
592  template <
593  InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set>
594  __device__ static bool constexpr IsValidCompilationParameter()
595  {
596  return ck::tensor_operation::device::IsValidGemmCompilationParameter<
597  BlockSize,
598  MPerBlock,
599  NPerBlock,
600  MPerXdl,
601  NPerXdl,
602  MXdlPerWave,
603  NXdlPerWave,
604  FloatC,
605  CGlobalMemoryDataOperation>();
606  }
607 
 608  // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
609  __host__ static constexpr bool CheckValidity(const Problem& problem)
610  {
611  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
612  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
613  "Invalid tuning param!");
614 
615  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
 616  GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
 617  GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
 618  GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
 619  {
620  if(!(problem.M % MPerBlock == 0))
621  {
622  return false;
623  }
624  }
625 
626  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
 627  GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
 628  GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
 629  GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
 630  {
631  if(!(problem.N % NPerBlock == 0))
632  {
633  return false;
634  }
635  }
636 
641  {
642  if(!(CalculateKPadded(problem.K) % AK1Value == 0) ||
643  !(CalculateKPadded(problem.K) % BK1Value == 0))
644  {
645  return false;
646  }
647  }
648  else
649  {
650  if(!(problem.K % AK1Value == 0) || !(problem.K % BK1Value == 0))
651  {
652  return false;
653  }
654  }
655 
 656  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
 657  {
658  if(problem.K % ABlockTransferSrcScalarPerVector != 0)
659  {
660  return false;
661  }
662  }
663  else
664  {
665  if(problem.M % ABlockTransferSrcScalarPerVector != 0)
666  {
667  return false;
668  }
669  }
670 
 671  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
 672  {
673  if(problem.N % BBlockTransferSrcScalarPerVector != 0)
674  {
675  return false;
676  }
677  }
678  else
679  {
680  if(problem.K % BBlockTransferSrcScalarPerVector != 0)
681  {
682  return false;
683  }
684  }
685 
 686  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, CLayout>)
 687  {
688  if(problem.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
689  {
690  return false;
691  }
692  }
693  else
694  {
695  if(problem.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
696  {
697  return false;
698  }
699  }
700 
701  // check gridwise gemm pipeline
702  const auto num_k_loop = (CalculateAK0(problem.K) * AK1Value) / KPerBlock;
703 
704  if(num_k_loop < 4)
705  {
706  return false;
707  }
708 
709  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
710  return true;
711  }
712 
713  __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
714  {
715  const index_t num_loop = K / KPerBlock;
716 
717  return num_loop > 3;
718  }
719 
720  __host__ static constexpr index_t CalculateKBlockLoopTailNum(index_t K)
721  {
722  const index_t num_loop = K / KPerBlock;
723 
724  if(num_loop % 2 == 1)
725  return 3;
726  else
727  return 2;
728  }
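 // Illustrative example (assumed values, not from the original source): with KPerBlock = 64,
 // K = 512 gives num_loop = 8, so CalculateHasMainKBlockLoop returns true (8 > 3) and
 // CalculateKBlockLoopTailNum returns 2 (even count); K = 448 gives 7 loops and a tail number of 3.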
729 
730  template <typename CGridDesc>
 731  __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
 732  const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
733  {
734  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
735  c_grid_desc_m_n,
740 
741  return c_grid_desc_mblock_mperblock_nblock_nperblock;
742  }
743 
744  // return block_id to C matrix tile idx (m0, n0) mapping
745  // if arch = gfx942
747 
748  template <bool HasMainKBlockLoop, index_t TailNum = 3>
749  __device__ static void Run(const FloatA* p_a_grid,
750  const FloatB* p_b_grid,
751  FloatC* p_c_grid,
752  void* p_shared_0,
753  void* p_shared_1,
754  const Problem& problem)
755  {
756  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
757  problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
758  const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
759  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
760  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
761  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
762 
763  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
 764  MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
 765  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
766 
767  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
768  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
769  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
770  p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
771  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
772  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
773 
774  const AElementwiseOperation a_element_op{};
775  const BElementwiseOperation b_element_op{};
776  const CElementwiseOperation c_element_op{};
777 
778  // divide block work by [M, N]
779  const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
780 
781  const auto block_work_idx =
782  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
783 
784  if(!block_2_ctile_map.ValidCTileIndex(
785  block_work_idx,
786  make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
787  c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
788  {
789  return;
790  }
791 #if 0
792  if(threadIdx.x == 0){
793  printf("Hardware assigned No. %03d workgroup of logical C tile (%02d, %02d) on %d th XCC Die, %d th SE, %d th CU\n",
794  get_block_1d_id(),
795  block_work_idx[I0],
796  block_work_idx[I1],
797  __smid()>>6 & 0xf,
798  __smid()>>4 & 0x3,
799  __smid() & 0xf);
800  }
801 #endif
 802  // HACK: this forces m/n_block_data_idx_on_grid into SGPR
803  const index_t m_block_data_idx_on_grid =
804  __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
805 
806  const index_t n_block_data_idx_on_grid =
807  __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
808 
809  // lds max alignment
810  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
811 
812  // A matrix in LDS memory, dst of blockwise copy
813  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
814 
815  // B matrix in LDS memory, dst of blockwise copy
816  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
817 
818  // A matrix blockwise copy
819  auto a_blockwise_copy =
821  AElementwiseOperation,
825  ABlockTransferThreadClusterLengths_AK0_M_AK1,
826  ABlockTransferThreadClusterArrangeOrder,
827  FloatA,
828  ComputeTypeA,
829  decltype(a_grid_desc_ak0_m_ak1),
830  decltype(a_block_desc_ak0_m_ak1),
831  ABlockTransferSrcAccessOrder,
833  ABlockTransferSrcVectorDim,
834  2,
835  ABlockTransferSrcScalarPerVector,
836  ABlockTransferDstScalarPerVector_AK1,
837  1,
838  1,
839  AThreadTransferSrcResetCoordinateAfterRun,
840  true>(
841  a_grid_desc_ak0_m_ak1,
842  make_multi_index(0, m_block_data_idx_on_grid, 0),
843  a_element_op,
844  a_block_desc_ak0_m_ak1,
845  make_multi_index(0, 0, 0),
847 
848  // B matrix blockwise copy
849  auto b_blockwise_copy =
851  BElementwiseOperation,
855  BBlockTransferThreadClusterLengths_BK0_N_BK1,
856  BBlockTransferThreadClusterArrangeOrder,
857  FloatB,
858  ComputeTypeB,
859  decltype(b_grid_desc_bk0_n_bk1),
860  decltype(b_block_desc_bk0_n_bk1),
861  BBlockTransferSrcAccessOrder,
863  BBlockTransferSrcVectorDim,
864  2,
865  BBlockTransferSrcScalarPerVector,
866  BBlockTransferDstScalarPerVector_BK1,
867  1,
868  1,
869  BThreadTransferSrcResetCoordinateAfterRun,
870  true>(
871  b_grid_desc_bk0_n_bk1,
872  make_multi_index(0, n_block_data_idx_on_grid, 0),
873  b_element_op,
874  b_block_desc_bk0_n_bk1,
875  make_multi_index(0, 0, 0),
877 
878  // GEMM definition
879  // c_mtx += transpose(a_mtx) * b_mtx
880  // a_mtx[K0PerBlock, MPerBlock] is in LDS
881  // b_mtx[K0PerBlock, NPerBlock] is in LDS
882  // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
883  // register
884  // sanity check
885  constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
886  constexpr bool is_single_rate_mfma =
888  lcm_AK1_BK1 <= 4) ||
889  (is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
891  lcm_AK1_BK1 < 32))
892  ? true
893  : false;
894  constexpr auto is_scale_mfma = false;
895  constexpr index_t KPack = math::max(lcm_AK1_BK1,
896  MfmaSelector<ComputeTypeA,
897  MPerXdl,
898  NPerXdl,
899  ComputeTypeA,
900  is_single_rate_mfma,
901  is_scale_mfma>::selected_mfma.k_per_blk);
902 
903  // auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
904  // BlockSize,
905  // ComputeType,
906  // FloatGemmAcc,
907  // decltype(a_block_desc_ak0_m_ak1),
908  // decltype(b_block_desc_bk0_n_bk1),
909  // MPerXdl,
910  // NPerXdl,
911  // MXdlPerWave,
912  // NXdlPerWave,
913  // KPack,
914  // LoopSched>();
915  auto blockwise_gemm_pipeline = BlockwiseGemmXdlops_pipeline_v4<
916  BlockSize,
917  ComputeTypeA,
918  FloatGemmAcc,
919  decltype(a_block_desc_ak0_m_ak1),
920  decltype(b_block_desc_bk0_n_bk1),
921  decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
922  decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)),
923  MPerBlock,
924  NPerBlock,
925  KPerBlock,
926  MPerXdl,
927  NPerXdl,
928  MXdlPerWave,
929  NXdlPerWave,
930  KPack>{}; // TransposeC
931 
932  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
933 
934  // LDS allocation for A and B: be careful of alignment
935  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
936  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
937 
938  auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
939  static_cast<ComputeTypeA*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
940 
941  auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
942  static_cast<ComputeTypeB*>(p_shared_0) + a_block_space_size_aligned,
943  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
944 
945  auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
946  static_cast<ComputeTypeA*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
947 
948  auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
949  static_cast<ComputeTypeB*>(p_shared_1) + a_block_space_size_aligned,
950  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
951 
952  auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
953  auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
954 
955  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
956  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
957 
958  // gridwise GEMM pipeline
959  static_assert(std::is_default_constructible_v<GridwiseGemmPipe>);
960  // const auto gridwise_gemm_pipeline = GridwiseGemmPipe{};
961 
962  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
963  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
964  KPerBlock);
965 
966  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
967  a_block_desc_ak0_m_ak1,
968  a_blockwise_copy,
969  a_grid_buf,
970  a_block_bufs,
971  a_block_slice_copy_step,
972  b_grid_desc_bk0_n_bk1,
973  b_block_desc_bk0_n_bk1,
974  b_blockwise_copy,
975  b_grid_buf,
976  b_block_bufs,
977  b_block_slice_copy_step,
978  c_thread_buf,
979  num_k_block_main_loop);
980 
981  // shuffle C and write out
982  {
983  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
984  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
985  "wrong!");
986 
987  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
988  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
989 
990  // TODO: hacky, fix it!
991  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
992  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
993 
994  // TODO: hacky, fix it!
995  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
996  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
997  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
998 
999  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1000  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1001  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1002  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1003  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1004  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1005  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1006  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1007 
1008  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
 1009  GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
 1010 
1011  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1012  static_cast<FloatCShuffle*>(p_shared_0),
1013  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1014 
1015  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1016  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1017  make_tuple(
1020  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1021  M1, // M1 = MWave
1022  M2, // M2 * M3 * M4 = MPerXdl
1023  M3,
1024  M4)),
1027  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1028  N1, // N1 = NWave
1029  N2))), // N2 = NPerXdl
1031  make_tuple(
1033 
1034  // calculate origin of thread output tensor on global memory
1035  // blockwise GEMM c matrix starting index
1036  const auto c_thread_mtx_on_block =
1037  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1038 
1039  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1040  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1041 
1042  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1044  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
1046  make_tuple(Sequence<0>{}));
1047 
1048  const auto m_thread_data_on_block_idx =
1049  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1050  make_multi_index(m_thread_data_on_block));
1051 
1052  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1056  make_tuple(Sequence<0>{}));
1057 
1058  const auto n_thread_data_on_block_idx =
1059  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1060  make_multi_index(n_thread_data_on_block));
1061 
1062  // shuffle: threadwise copy C from VGPR to LDS
1063  auto c_thread_copy_vgpr_to_lds =
1065  FloatCShuffle,
1066  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1067  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1069  Sequence<CShuffleMXdlPerWavePerShuffle,
1070  CShuffleNXdlPerWavePerShuffle,
1071  I1,
1072  I1,
1073  M2,
1074  I1,
1075  M4,
1076  I1>,
1078  7,
1079  1,
1081  1,
1082  true>{
1083  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1084  make_multi_index(0,
1085  0,
1086  m_thread_data_on_block_idx[I1],
1087  n_thread_data_on_block_idx[I1],
1088  m_thread_data_on_block_idx[I2],
1089  m_thread_data_on_block_idx[I3],
1090  m_thread_data_on_block_idx[I4],
1091  n_thread_data_on_block_idx[I2]),
1093 
1094  // shuffle: blockwise copy C from LDS to global
1095  auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
1096  ThisThreadBlock, // ThreadGroup
1097  CElementwiseOperation, // ElementwiseOperation,
1098  CGlobalMemoryDataOperation, // DstInMemOp,
1099  Sequence<1,
1100  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1101  1,
1102  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1103  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1104  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1105  FloatCShuffle, // typename SrcData,
1106  FloatC, // typename DstData,
1107  decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1108  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1109  Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
1110  3, // index_t VectorDim,
1111  CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
1112  true, // bool ThreadTransferSrcResetCoordinateAfterRun,
1113  false> // bool ThreadTransferDstResetCoordinateAfterRun>
1114  {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1115  make_multi_index(0, 0, 0, 0),
1116  c_grid_desc_mblock_mperblock_nblock_nperblock,
1117  make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
1118  c_element_op};
1119 
1120  // space filling curve for threadwise C in VGPR
1121  constexpr auto sfc_c_vgpr =
1124  Sequence<CShuffleMXdlPerWavePerShuffle,
1125  CShuffleNXdlPerWavePerShuffle,
1126  1,
1127  1,
1128  M2,
1129  1,
1130  M4,
1131  1>>{};
1132 
1133  // space filling curve for shuffled blockwise C in global mem
1134  constexpr auto sfc_c_global =
1137  Sequence<1,
1138  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1139  1,
1140  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1141 
1142  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1143 
1144  static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
1145 
1146  static_for<0, num_access, 1>{}([&](auto access_id) {
1147  // make sure it's safe to write to LDS
1148  block_sync_lds();
1149 
1150  // each thread write its data from VGPR to LDS
1151  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1152  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1153  c_thread_buf,
1154  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1155  c_shuffle_block_buf);
1156 
1157  // make sure it's safe to read from LDS
1158  block_sync_lds();
1159 
1160  // each block copy its data from LDS to global
1161  c_shuffle_block_copy_lds_to_global.Run(
1162  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1163  c_shuffle_block_buf,
1164  c_grid_desc_mblock_mperblock_nblock_nperblock,
1165  c_grid_buf);
1166 
1167  if constexpr(access_id < num_access - 1)
1168  {
1169  constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1170 
1171  // move on C
1172  c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1173  c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1174  }
1175  });
1176  }
1177  }
1178 };
1179 
1180 } // namespace ck
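To make the host-side tile bookkeeping above concrete, the following standalone sketch re-implements the padding and K-loop helpers for one assumed tile configuration (MPerBlock = 256, NPerBlock = 128, KPerBlock = 64, AK1 = 8). It is an illustration only: the names in the demo namespace are hypothetical and it does not include or use the Composable Kernel headers.

#include <cstdint>
#include <iostream>

// Standalone illustration of the host-side helpers of GridwiseGemm_xdl_cshuffle_v2,
// re-implemented for one assumed tile configuration; not the CK implementation.
namespace demo {
constexpr std::int32_t MPerBlock = 256;
constexpr std::int32_t NPerBlock = 128;
constexpr std::int32_t KPerBlock = 64;
constexpr std::int32_t AK1       = 8;

// Same rounding as math::integer_divide_ceil(x, tile) * tile in the header above.
constexpr std::int32_t pad_to(std::int32_t x, std::int32_t tile) { return (x + tile - 1) / tile * tile; }

// Mirrors CalculateHasMainKBlockLoop: the pipeline needs more than 3 K iterations.
constexpr bool has_main_k_block_loop(std::int32_t K) { return (K / KPerBlock) > 3; }

// Mirrors CalculateKBlockLoopTailNum: tail handling depends on loop-count parity.
constexpr std::int32_t k_block_loop_tail_num(std::int32_t K) { return (K / KPerBlock) % 2 == 1 ? 3 : 2; }
} // namespace demo

int main()
{
    const std::int32_t M = 1000, N = 1000, K = 4096;

    std::cout << "MPadded = " << demo::pad_to(M, demo::MPerBlock)           // 1024
              << ", NPadded = " << demo::pad_to(N, demo::NPerBlock)         // 1024
              << ", AK0 = " << demo::pad_to(K, demo::KPerBlock) / demo::AK1 // 512 (K already a multiple of KPerBlock)
              << ", has main K loop = " << demo::has_main_k_block_loop(K)   // 1 (4096 / 64 = 64 iterations)
              << ", tail num = " << demo::k_block_loop_tail_num(K)          // 2 (even iteration count)
              << '\n';
    return 0;
}

In the real header these values come from the Problem constructor (MPadded, NPadded, KPadded, AK0, BK0, MBlock, NBlock) and feed CheckValidity, CalculateHasMainKBlockLoop, and CalculateKBlockLoopTailNum when selecting how to launch kernel_gemm_xdl_cshuffle_v2.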