Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp Source File

gridwise_gemm_xdl_cshuffle_v2.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
17 
18 namespace ck {
19 
20 template <typename GridwiseGemm, bool HasMainKBlockLoop, index_t TailNum = 3>
21 __global__ void
22 #if CK_USE_LAUNCH_BOUNDS
23 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1)
24 #endif
25  // __attribute__((amdgpu_waves_per_eu(1, 1)))
26  kernel_gemm_xdl_cshuffle_v2(typename GridwiseGemm::Argument karg)
27 {
28 #if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__))
29  // Passing two LDS pointers is the key to telling the compiler that ds_read/ds_write
30  // can operate on different LDS chunks at the same time without an ordering dependency
31  __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
32  __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
33 
34  GridwiseGemm::template Run<HasMainKBlockLoop, TailNum>(
35  karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared_0, p_shared_1, karg);
36 #else
37  ignore = karg;
38 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__))
39 }
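// Illustrative host-side launch sketch for the Argument-based entry point above (comments
// only; the concrete GridwiseGemm instantiation, grid size, block size and stream are
// assumptions, not part of this header):
//
//   using Gemm = /* a full instantiation of the gridwise GEMM struct defined below */;
//   typename Gemm::Argument karg{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
//   // One workgroup per C tile, BlockSize threads per workgroup, TailNum chosen by the host.
//   kernel_gemm_xdl_cshuffle_v2<Gemm, /*HasMainKBlockLoop*/ true, /*TailNum*/ 3>
//       <<<grid_size, block_size, 0, stream>>>(karg);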
40 
41 template <typename GridwiseGemm,
42  typename FloatA,
43  typename FloatB,
44  typename FloatC,
45  bool HasMainKBlockLoop>
46 __global__ void
47 #if CK_USE_LAUNCH_BOUNDS
48 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1)
49 #endif
50  kernel_gemm_xdl_cshuffle_v2(const FloatA* p_a_grid,
51  const FloatB* p_b_grid,
52  FloatC* p_c_grid,
53  typename GridwiseGemm::Problem problem)
54 {
55 #if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__))
56  __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
57  __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
58 
59  GridwiseGemm::template Run<HasMainKBlockLoop>(
60  p_a_grid, p_b_grid, p_c_grid, p_shared_0, p_shared_1, problem);
61 #else
62  ignore = p_a_grid;
63  ignore = p_b_grid;
64  ignore = p_c_grid;
65  ignore = problem;
66 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__))
67 }
68 
69 template <typename ALayout,
70  typename BLayout,
71  typename CLayout,
72  typename FloatA,
73  typename FloatB,
74  typename FloatGemmAcc,
75  typename FloatCShuffle,
76  typename FloatC,
77  typename AElementwiseOperation,
78  typename BElementwiseOperation,
79  typename CElementwiseOperation,
80  tensor_operation::device::GemmSpecialization GemmSpec,
81  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
82  index_t NumGemmKPrefetchStage,
83  index_t BlockSize,
84  index_t MPerBlock,
85  index_t NPerBlock,
86  index_t KPerBlock,
87  index_t AK1Value,
88  index_t BK1Value,
89  index_t MPerXdl,
90  index_t NPerXdl,
91  index_t MXdlPerWave,
92  index_t NXdlPerWave,
93  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
94  typename ABlockTransferThreadClusterArrangeOrder,
95  typename ABlockTransferSrcAccessOrder,
96  index_t ABlockTransferSrcVectorDim,
97  index_t ABlockTransferSrcScalarPerVector,
98  index_t ABlockTransferDstScalarPerVector_AK1,
99  bool AThreadTransferSrcResetCoordinateAfterRun,
100  index_t ABlockLdsExtraM,
101  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
102  typename BBlockTransferThreadClusterArrangeOrder,
103  typename BBlockTransferSrcAccessOrder,
104  index_t BBlockTransferSrcVectorDim,
105  index_t BBlockTransferSrcScalarPerVector,
106  index_t BBlockTransferDstScalarPerVector_BK1,
107  bool BThreadTransferSrcResetCoordinateAfterRun,
108  index_t BBlockLdsExtraN,
109  index_t CShuffleMXdlPerWavePerShuffle,
110  index_t CShuffleNXdlPerWavePerShuffle,
111  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
112  index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
113  LoopScheduler LoopSched,
114  PipelineVersion PipelineVer = PipelineVersion::v1,
115  typename ComputeTypeA = FloatC,
116  typename ComputeTypeB = ComputeTypeA>
117 struct GridwiseGemm_xdl_cshuffle_v2
118 {
119  static constexpr auto I0 = Number<0>{};
120  static constexpr auto I1 = Number<1>{};
121  static constexpr auto I2 = Number<2>{};
122  static constexpr auto I3 = Number<3>{};
123  static constexpr auto I4 = Number<4>{};
124  static constexpr auto I5 = Number<5>{};
125  static constexpr auto I6 = Number<6>{};
126  static constexpr auto I7 = Number<7>{};
127 
128  // K1 should be Number<...>
129  static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
130  static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
131  static constexpr auto AK1Number = Number<AK1Value>{};
132  static constexpr auto BK1Number = Number<BK1Value>{};
133 
134  using ThisThreadBlock = ThisThreadBlock<BlockSize>;
135 
136  __host__ static auto CalculateGridSize(index_t M, index_t N)
137  {
139  }
140 
141  __host__ static auto CalculateMPadded(index_t M)
142  {
143  return math::integer_divide_ceil(M, MPerBlock) * MPerBlock;
144  }
145 
146  __host__ static auto CalculateNPadded(index_t N)
147  {
148  return math::integer_divide_ceil(N, NPerBlock) * NPerBlock;
149  }
150 
151  __host__ static auto CalculateKPadded(index_t K)
152  {
153  return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
154  }
155 
156  __host__ static auto CalculateAK0(index_t K)
157  {
158  using GemmSpecialization = tensor_operation::device::GemmSpecialization;
159 
160  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
161  GemmSpec == GemmSpecialization::MNKPadding ||
162  GemmSpec == GemmSpecialization::KPadding ||
163  GemmSpec == GemmSpecialization::NKPadding)
164  {
165  return CalculateKPadded(K) / AK1Value;
166  }
167  else
168  {
169  return K / AK1Value;
170  }
171  }
172 
173  __host__ static auto CalculateBK0(index_t K)
174  {
175  using GemmSpecialization = tensor_operation::device::GemmSpecialization;
176 
177  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
178  GemmSpec == GemmSpecialization::MNKPadding ||
179  GemmSpec == GemmSpecialization::KPadding ||
180  GemmSpec == GemmSpecialization::MKPadding)
181  {
182  return CalculateKPadded(K) / BK1Value;
183  }
184  else
185  {
186  return K / BK1Value;
187  }
188  }
189 
190  __host__ static auto CalculateMBlock(index_t M)
191  {
192  return math::integer_divide_floor(M, MPerBlock);
193  }
194 
195  __host__ static auto CalculateNBlock(index_t N)
196  {
197  return math::integer_divide_floor(N, NPerBlock);
198  }
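// Worked example (illustrative, assuming MPerBlock = 256, NPerBlock = 128, KPerBlock = 64
// and AK1Value = BK1Value = 8): for M = 1000, N = 1000, K = 200 with a K-padding GemmSpec,
//   CalculateMPadded(1000) = ceil(1000 / 256) * 256 = 1024
//   CalculateNPadded(1000) = ceil(1000 / 128) * 128 = 1024
//   CalculateKPadded(200)  = ceil(200 / 64) * 64    = 256
//   CalculateAK0(200)      = 256 / 8 = 32,  CalculateBK0(200) = 256 / 8 = 32
// Note that CalculateMBlock/CalculateNBlock use integer_divide_floor, i.e. they count the
// tiles fully covered by the unpadded M and N.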
199 
200  template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
201  __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
202  {
203  constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
204  constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
205 
207  TileDesc_K0_MN_K1{},
213  }
214 
215  __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
216  index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
217  {
218  const auto a_grid_desc_mraw_kraw = [&]() {
219  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
220  {
221  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
222  }
223  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
224  {
225  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
226  }
227  }();
228 
230 
231  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
232  GemmSpec == GemmSpecialization::MNKPadding)
233  {
234  // pad both M and K
235  const auto a_grid_desc_m_k =
236  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
238  make_right_pad_transform(K, KPad - K)),
241 
242  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
243  a_grid_desc_m_k,
248 
249  return a_grid_desc_ak0_m_ak1;
250  }
251  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
252  GemmSpec == GemmSpecialization::MNPadding)
253  {
254  // pad M, but not K
255  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
256  a_grid_desc_mraw_kraw,
258  make_right_pad_transform(M, MPad - M)),
261 
262  return a_grid_desc_ak0_m_ak1;
263  }
264  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
265  GemmSpec == GemmSpecialization::NKPadding)
266  {
267  // pad K, but not M
268  const auto a_grid_desc_m_k = transform_tensor_descriptor(
269  a_grid_desc_mraw_kraw,
273 
274  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
275  a_grid_desc_m_k,
280 
281  return a_grid_desc_ak0_m_ak1;
282  }
283  else
284  {
285  // not pad M or K
286  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
287  a_grid_desc_mraw_kraw,
292 
293  return a_grid_desc_ak0_m_ak1;
294  }
295  }
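// Illustrative note (not part of the original source): the function above views the raw
// M x K A matrix as a 3-D [AK0, M, AK1] tensor, applying the M/K right-pad transforms first
// when the GemmSpec requires them. For example, with K = 256 and AK1Value = 8 the K axis is
// unmerged into AK0 = 32 slow-moving chunks of AK1 = 8 contiguous elements, matching the
// layout expected by the blockwise A copy and the A LDS block descriptor.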
296 
297  __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
298  index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
299  {
300  const auto b_grid_desc_nraw_kraw = [&]() {
302  {
303  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
304  }
306  {
307  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
308  }
309  }();
310 
312 
313  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
314  GemmSpec == GemmSpecialization::MNKPadding)
315  {
316  // pad both N and K
317  const auto b_grid_desc_n_k =
318  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
320  make_right_pad_transform(K, KPad - K)),
323 
324  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
325  b_grid_desc_n_k,
330 
331  return b_grid_desc_bk0_n_bk1;
332  }
333  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
334  GemmSpec == GemmSpecialization::MNPadding)
335  {
336  // pad N, but not K
337  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
338  b_grid_desc_nraw_kraw,
340  make_right_pad_transform(N, NPad - N)),
343 
344  return b_grid_desc_bk0_n_bk1;
345  }
346  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
347  GemmSpec == GemmSpecialization::MKPadding)
348  {
349  // pad K, but not N
350  const auto b_grid_desc_n_k = transform_tensor_descriptor(
351  b_grid_desc_nraw_kraw,
355 
356  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
357  b_grid_desc_n_k,
362 
363  return b_grid_desc_bk0_n_bk1;
364  }
365  else
366  {
367  // not pad N or K
368  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
369  b_grid_desc_nraw_kraw,
374 
375  return b_grid_desc_bk0_n_bk1;
376  }
377  }
378 
379  template <typename ABlockDesc_AK0_M_AK1>
380  __host__ __device__ static constexpr auto
381  MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
382  {
383  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
384 
385  return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
386  }
387 
388  template <typename BBlockDesc_BK0_N_BK1>
389  __host__ __device__ static constexpr auto
390  MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
391  {
392  constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
393 
394  return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
395  }
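// Worked example (illustrative): with MPerBlock = 256, MXdlPerWave = 4 and MPerXdl = 32,
// MWaves = 256 / (4 * 32) = 2; with NPerBlock = 128, NXdlPerWave = 2 and NPerXdl = 32,
// NWaves = 128 / (2 * 32) = 2. MakeGemmMmaTileDescriptor then reshapes the K0 x MN x K1
// block descriptor into the per-wave MMA tile view consumed by the blockwise GEMM pipeline.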
396 
397  __host__ __device__ static auto
399  {
400  const auto c_grid_desc_mraw_nraw = [&]() {
402  {
403  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
404  }
406  {
407  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
408  }
409  }();
410 
412 
413  if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
414  GemmSpec == GemmSpecialization::MNKPadding)
415  {
416  // pad M and N
417  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
419  make_right_pad_transform(N, NPad - N)),
422  }
423  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
424  GemmSpec == GemmSpecialization::MKPadding)
425  {
426  // pad M, but not N
428  c_grid_desc_mraw_nraw,
432  }
433  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
434  GemmSpec == GemmSpecialization::NKPadding)
435  {
436  // pad N, but not M
438  c_grid_desc_mraw_nraw,
442  }
443  else
444  {
445  // not pad M or N
446  return c_grid_desc_mraw_nraw;
447  }
448  }
449 
450  struct Problem
451  {
452  __host__ Problem(index_t M_,
453  index_t N_,
454  index_t K_,
455  index_t StrideA_,
456  index_t StrideB_,
457  index_t StrideC_)
458  : M{M_},
459  N{N_},
460  K{K_},
461  StrideA{StrideA_},
462  StrideB{StrideB_},
463  StrideC{StrideC_},
464  MPadded{CalculateMPadded(M_)},
465  NPadded{CalculateNPadded(N_)},
466  KPadded{CalculateKPadded(K_)},
467  AK0{CalculateAK0(K_)},
468  BK0{CalculateBK0(K_)},
469  MBlock{CalculateMBlock(M_)},
470  NBlock{CalculateNBlock(N_)}
471  {
472  }
473 
474  __host__ void Print() const
475  {
476  std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
477  << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
478  << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
479  << "KP:" << KPadded << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", "
480  << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl;
481  }
482 
483  index_t M;
484  index_t N;
485  index_t K;
486  index_t StrideA;
487  index_t StrideB;
488  index_t StrideC;
489  index_t MPadded;
490  index_t NPadded;
491  index_t KPadded;
492  index_t AK0;
493  index_t BK0;
494  index_t MBlock;
495  index_t NBlock;
496  };
497 
498  // Argument
499  struct Argument : public Problem
500  {
501  __host__ Argument(const FloatA* p_a_grid_,
502  const FloatB* p_b_grid_,
503  FloatC* p_c_grid_,
504  index_t M_,
505  index_t N_,
506  index_t K_,
507  index_t StrideA_,
508  index_t StrideB_,
509  index_t StrideC_)
510  : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
511  p_a_grid{p_a_grid_},
512  p_b_grid{p_b_grid_},
513  p_c_grid{p_c_grid_}
514  {
515  }
516 
517  const FloatA* p_a_grid;
518  const FloatB* p_b_grid;
519  FloatC* p_c_grid;
520  };
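// Illustrative usage (assumed pointers and an all-row-major layout; not part of the source):
//
//   Argument karg{p_a, p_b, p_c,
//                 /*M*/ 1024, /*N*/ 1024, /*K*/ 512,
//                 /*StrideA*/ 512, /*StrideB*/ 1024, /*StrideC*/ 1024};
//   karg.Print(); // dumps M/N/K, the strides and the derived padded sizes / block counts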
521 
522  // FIXME: pass GridwiseGemmPipe as a template argument into GridwiseGemm
523  using GridwiseGemmPipe = remove_cvref_t<
524  decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
525 
526  __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
527  {
528  // A matrix in LDS memory, dst of blockwise copy
532  }
533 
534  __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
535  {
536  // B matrix in LDS memory, dst of blockwise copy
540  }
541 
543  {
544  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
545  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
546 
547  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
549  make_tuple(I1,
551  I1,
553 
554  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
555  }
556 
557  __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
558  {
559  // LDS allocation for A and B: be careful of alignment
560  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
561  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
562 
563  // lds max alignment
564  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
565 
566  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
567  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
568 
569  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
570  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
571 
572  // LDS allocation for C shuffle in LDS
573  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
575 
576  constexpr auto c_block_size =
577  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
578 
579  return math::max((a_block_space_size_aligned * sizeof(ComputeTypeA) +
580  b_block_space_size_aligned * sizeof(ComputeTypeB)),
581  c_block_size * sizeof(FloatCShuffle));
582  }
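// Worked example (illustrative, with assumed tuning parameters): for MPerBlock = 256,
// NPerBlock = 128, KPerBlock = 64, AK1 = BK1 = 8 and 16-bit compute types, the A tile takes
// about 256 * 64 * 2 B = 32 KiB and the B tile 128 * 64 * 2 B = 16 KiB of LDS (plus any
// ABlockLdsExtraM / BBlockLdsExtraN padding); the C-shuffle tile reuses the same allocation,
// so the function returns max(A + B, Cshuffle) bytes. The ping-pong kernel entry points at
// the top of this file allocate this amount twice (p_shared_0 and p_shared_1).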
583 
584  // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
585  __host__ static constexpr bool CheckValidity(const Problem& problem)
586  {
587  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
588  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
589  "Invalid tuning param!");
590 
591  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
595  {
596  if(!(problem.M % MPerBlock == 0))
597  {
598  return false;
599  }
600  }
601 
602  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
606  {
607  if(!(problem.N % NPerBlock == 0))
608  {
609  return false;
610  }
611  }
612 
617  {
618  if(!(CalculateKPadded(problem.K) % AK1Value == 0) ||
619  !(CalculateKPadded(problem.K) % BK1Value == 0))
620  {
621  return false;
622  }
623  }
624  else
625  {
626  if(!(problem.K % AK1Value == 0) || !(problem.K % BK1Value == 0))
627  {
628  return false;
629  }
630  }
631 
633  {
634  if(problem.K % ABlockTransferSrcScalarPerVector != 0)
635  {
636  return false;
637  }
638  }
639  else
640  {
641  if(problem.M % ABlockTransferSrcScalarPerVector != 0)
642  {
643  return false;
644  }
645  }
646 
648  {
649  if(problem.N % BBlockTransferSrcScalarPerVector != 0)
650  {
651  return false;
652  }
653  }
654  else
655  {
656  if(problem.K % BBlockTransferSrcScalarPerVector != 0)
657  {
658  return false;
659  }
660  }
661 
663  {
664  if(problem.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
665  {
666  return false;
667  }
668  }
669  else
670  {
671  if(problem.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
672  {
673  return false;
674  }
675  }
676 
677  // check gridwise gemm pipeline
678  const auto num_k_loop = (CalculateAK0(problem.K) * AK1Value) / KPerBlock;
679 
680  if(num_k_loop < 4)
681  {
682  return false;
683  }
684 
685  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
686  return true;
687  }
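// Illustrative note: because the kernel double-buffers the A/B tiles in LDS (p_shared_0 and
// p_shared_1), CheckValidity additionally requires at least 4 K iterations, i.e. problems
// with K < 4 * KPerBlock (after any K padding) are rejected. For example, with
// KPerBlock = 64 a problem with K = 128 gives num_k_loop = 2 and is reported as invalid.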
688 
689  __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
690  {
691  const index_t num_loop = K / KPerBlock;
692 
693  return num_loop > 3;
694  }
695 
696  __host__ static constexpr index_t CalculateKBlockLoopTailNum(index_t K)
697  {
698  const index_t num_loop = K / KPerBlock;
699 
700  if(num_loop % 2 == 1)
701  return 3;
702  else
703  return 2;
704  }
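// Worked example (illustrative): with KPerBlock = 64, K = 512 gives num_loop = 8 (even) and
// a tail of 2, while K = 576 gives num_loop = 9 (odd) and a tail of 3. The host is expected
// to pass this value, together with CalculateHasMainKBlockLoop(K), as the kernel's
// HasMainKBlockLoop / TailNum template arguments (see the sketch after the first kernel).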
705 
706  template <typename CGridDesc>
708  const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
709  {
710  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
711  c_grid_desc_m_n,
716 
717  return c_grid_desc_mblock_mperblock_nblock_nperblock;
718  }
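// Illustrative note: the transform above wraps the M x N C descriptor into a
// [MBlock, MPerBlock, NBlock, NPerBlock] view, so a workgroup with tile index (m0, n0) can
// address its tile directly. E.g. with MPerBlock = 256 and NPerBlock = 128, element (m, n)
// lands in block (m / 256, n / 128) at in-block offset (m % 256, n % 128).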
719 
720  // return block_id to C matrix tile idx (m0, n0) mapping
721  // if arch = gfx942
723 
724  template <bool HasMainKBlockLoop, index_t TailNum = 3>
725  __device__ static void Run(const FloatA* p_a_grid,
726  const FloatB* p_b_grid,
727  FloatC* p_c_grid,
728  void* p_shared_0,
729  void* p_shared_1,
730  const Problem& problem)
731  {
732  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
733  problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
734  const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
735  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
736  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
737  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
738 
739  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
741  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
742 
743  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
744  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
745  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
746  p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
747  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
748  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
749 
750  const AElementwiseOperation a_element_op{};
751  const BElementwiseOperation b_element_op{};
752  const CElementwiseOperation c_element_op{};
753 
754  // divide block work by [M, N]
755  const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
756 
757  const auto block_work_idx =
758  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
759 
760  if(!block_2_ctile_map.ValidCTileIndex(
761  block_work_idx,
762  make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
763  c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
764  {
765  return;
766  }
767 #if 0
768  if(threadIdx.x == 0){
769  printf("Hardware assigned No. %03d workgroup of logical C tile (%02d, %02d) on %d th XCC Die, %d th SE, %d th CU\n",
770  get_block_1d_id(),
771  block_work_idx[I0],
772  block_work_idx[I1],
773  __smid()>>6 & 0xf,
774  __smid()>>4 & 0x3,
775  __smid() & 0xf);
776  }
777 #endif
778  // HACK: this forces m/n_block_data_idx_on_grid into SGPR
779  const index_t m_block_data_idx_on_grid =
780  __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
781 
782  const index_t n_block_data_idx_on_grid =
783  __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
784 
785  // lds max alignment
786  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
787 
788  // A matrix in LDS memory, dst of blockwise copy
789  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
790 
791  // B matrix in LDS memory, dst of blockwise copy
792  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
793 
794  // A matrix blockwise copy
795  auto a_blockwise_copy =
797  AElementwiseOperation,
801  ABlockTransferThreadClusterLengths_AK0_M_AK1,
802  ABlockTransferThreadClusterArrangeOrder,
803  FloatA,
804  ComputeTypeA,
805  decltype(a_grid_desc_ak0_m_ak1),
806  decltype(a_block_desc_ak0_m_ak1),
807  ABlockTransferSrcAccessOrder,
809  ABlockTransferSrcVectorDim,
810  2,
811  ABlockTransferSrcScalarPerVector,
812  ABlockTransferDstScalarPerVector_AK1,
813  1,
814  1,
815  AThreadTransferSrcResetCoordinateAfterRun,
816  true>(
817  a_grid_desc_ak0_m_ak1,
818  make_multi_index(0, m_block_data_idx_on_grid, 0),
819  a_element_op,
820  a_block_desc_ak0_m_ak1,
821  make_multi_index(0, 0, 0),
823 
824  // B matrix blockwise copy
825  auto b_blockwise_copy =
827  BElementwiseOperation,
831  BBlockTransferThreadClusterLengths_BK0_N_BK1,
832  BBlockTransferThreadClusterArrangeOrder,
833  FloatB,
834  ComputeTypeB,
835  decltype(b_grid_desc_bk0_n_bk1),
836  decltype(b_block_desc_bk0_n_bk1),
837  BBlockTransferSrcAccessOrder,
839  BBlockTransferSrcVectorDim,
840  2,
841  BBlockTransferSrcScalarPerVector,
842  BBlockTransferDstScalarPerVector_BK1,
843  1,
844  1,
845  BThreadTransferSrcResetCoordinateAfterRun,
846  true>(
847  b_grid_desc_bk0_n_bk1,
848  make_multi_index(0, n_block_data_idx_on_grid, 0),
849  b_element_op,
850  b_block_desc_bk0_n_bk1,
851  make_multi_index(0, 0, 0),
853 
854  // GEMM definition
855  // c_mtx += transpose(a_mtx) * b_mtx
856  // a_mtx[K0PerBlock, MPerBlock] is in LDS
857  // b_mtx[K0PerBlock, NPerBlock] is in LDS
858  // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
859  // register
860  // sanity check
861  constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
862  constexpr bool is_single_rate_mfma =
863  (((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
864  lcm_AK1_BK1 <= 4) ||
865  (is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
866  ((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
867  lcm_AK1_BK1 < 32))
868  ? true
869  : false;
870  constexpr auto is_scale_mfma = false;
871  constexpr index_t KPack = math::max(lcm_AK1_BK1,
872  MfmaSelector<ComputeTypeA,
873  MPerXdl,
874  NPerXdl,
875  ComputeTypeA,
876  is_single_rate_mfma,
877  is_scale_mfma>::selected_mfma.k_per_blk);
878 
879  // auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
880  // BlockSize,
881  // ComputeType,
882  // FloatGemmAcc,
883  // decltype(a_block_desc_ak0_m_ak1),
884  // decltype(b_block_desc_bk0_n_bk1),
885  // MPerXdl,
886  // NPerXdl,
887  // MXdlPerWave,
888  // NXdlPerWave,
889  // KPack,
890  // LoopSched>();
891  auto blockwise_gemm_pipeline = BlockwiseGemmXdlops_pipeline_v4<
892  BlockSize,
893  ComputeTypeA,
894  FloatGemmAcc,
895  decltype(a_block_desc_ak0_m_ak1),
896  decltype(b_block_desc_bk0_n_bk1),
897  decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
898  decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)),
899  MPerBlock,
900  NPerBlock,
901  KPerBlock,
902  MPerXdl,
903  NPerXdl,
904  MXdlPerWave,
905  NXdlPerWave,
906  KPack>{}; // TransposeC
907 
908  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
909 
910  // LDS allocation for A and B: be careful of alignment
911  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
912  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
913 
914  auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
915  static_cast<ComputeTypeA*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
916 
917  auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
918  static_cast<ComputeTypeB*>(p_shared_0) + a_block_space_size_aligned,
919  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
920 
921  auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
922  static_cast<ComputeTypeA*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
923 
924  auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
925  static_cast<ComputeTypeB*>(p_shared_1) + a_block_space_size_aligned,
926  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
927 
928  auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
929  auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
930 
931  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
932  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
933 
934  // gridwise GEMM pipeline
935  static_assert(std::is_default_constructible_v<GridwiseGemmPipe>);
936  // const auto gridwise_gemm_pipeline = GridwiseGemmPipe{};
937 
938  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
939  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
940  KPerBlock);
941 
942  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
943  a_block_desc_ak0_m_ak1,
944  a_blockwise_copy,
945  a_grid_buf,
946  a_block_bufs,
947  a_block_slice_copy_step,
948  b_grid_desc_bk0_n_bk1,
949  b_block_desc_bk0_n_bk1,
950  b_blockwise_copy,
951  b_grid_buf,
952  b_block_bufs,
953  b_block_slice_copy_step,
954  c_thread_buf,
955  num_k_block_main_loop);
956 
957  // shuffle C and write out
958  {
959  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
960  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
961  "wrong!");
962 
963  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
964  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
965 
966  // TODO: hacky, fix it!
967  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
968  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
969 
970  // TODO: hacky, fix it!
971  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
972  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
973  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
974 
975  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
976  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
977  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
978  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
979  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
980  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
981  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
982  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
983 
984  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
986 
987  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
988  static_cast<FloatCShuffle*>(p_shared_0),
989  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
990 
991  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
992  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
993  make_tuple(
996  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
997  M1, // M1 = MWave
998  M2, // M2 * M3 * M4 = MPerXdl
999  M3,
1000  M4)),
1003  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1004  N1, // N1 = NWave
1005  N2))), // N2 = NPerXdl
1007  make_tuple(
1009 
1010  // calculate origin of thread output tensor on global memory
1011  // blockwise GEMM c matrix starting index
1012  const auto c_thread_mtx_on_block =
1013  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1014 
1015  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1016  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1017 
1018  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1020  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
1022  make_tuple(Sequence<0>{}));
1023 
1024  const auto m_thread_data_on_block_idx =
1025  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1026  make_multi_index(m_thread_data_on_block));
1027 
1028  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1032  make_tuple(Sequence<0>{}));
1033 
1034  const auto n_thread_data_on_block_idx =
1035  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1036  make_multi_index(n_thread_data_on_block));
1037 
1038  // shuffle: threadwise copy C from VGPR to LDS
1039  auto c_thread_copy_vgpr_to_lds =
1041  FloatCShuffle,
1042  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1043  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1045  Sequence<CShuffleMXdlPerWavePerShuffle,
1046  CShuffleNXdlPerWavePerShuffle,
1047  I1,
1048  I1,
1049  M2,
1050  I1,
1051  M4,
1052  I1>,
1054  7,
1055  1,
1057  1,
1058  true>{
1059  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1060  make_multi_index(0,
1061  0,
1062  m_thread_data_on_block_idx[I1],
1063  n_thread_data_on_block_idx[I1],
1064  m_thread_data_on_block_idx[I2],
1065  m_thread_data_on_block_idx[I3],
1066  m_thread_data_on_block_idx[I4],
1067  n_thread_data_on_block_idx[I2]),
1069 
1070  // shuffle: blockwise copy C from LDS to global
1071  auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
1072  ThisThreadBlock, // ThreadGroup
1073  CElementwiseOperation, // ElementwiseOperation,
1074  CGlobalMemoryDataOperation, // DstInMemOp,
1075  Sequence<1,
1076  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1077  1,
1078  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1079  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1080  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1081  FloatCShuffle, // typename SrcData,
1082  FloatC, // typename DstData,
1083  decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1084  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1085  Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
1086  3, // index_t VectorDim,
1087  CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
1088  true, // bool ThreadTransferSrcResetCoordinateAfterRun,
1089  false> // bool ThreadTransferDstResetCoordinateAfterRun>
1090  {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1091  make_multi_index(0, 0, 0, 0),
1092  c_grid_desc_mblock_mperblock_nblock_nperblock,
1093  make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
1094  c_element_op};
1095 
1096  // space filling curve for threadwise C in VGPR
1097  constexpr auto sfc_c_vgpr =
1100  Sequence<CShuffleMXdlPerWavePerShuffle,
1101  CShuffleNXdlPerWavePerShuffle,
1102  1,
1103  1,
1104  M2,
1105  1,
1106  M4,
1107  1>>{};
1108 
1109  // space filling curve for shuffled blockwise C in global mem
1110  constexpr auto sfc_c_global =
1113  Sequence<1,
1114  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1115  1,
1116  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1117 
1118  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1119 
1120  static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
1121 
1122  static_for<0, num_access, 1>{}([&](auto access_id) {
1123  // make sure it's safe to write to LDS
1124  block_sync_lds();
1125 
1126  // each thread write its data from VGPR to LDS
1127  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1128  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1129  c_thread_buf,
1130  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1131  c_shuffle_block_buf);
1132 
1133  // make sure it's safe to read from LDS
1134  block_sync_lds();
1135 
1136  // each block copy its data from LDS to global
1137  c_shuffle_block_copy_lds_to_global.Run(
1138  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1139  c_shuffle_block_buf,
1140  c_grid_desc_mblock_mperblock_nblock_nperblock,
1141  c_grid_buf);
1142 
1143  if constexpr(access_id < num_access - 1)
1144  {
1145  constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1146 
1147  // move on C
1148  c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1149  c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1150  }
1151  });
1152  }
1153  }
1154 };
1155 
1156 } // namespace ck
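// Illustrative end-to-end host sketch (assumptions: a concrete instantiation named Gemm,
// HIP launch syntax and user-provided device pointers; none of this is part of the header):
//
//   using Gemm = /* a full instantiation of the gridwise GEMM struct above */;
//
//   typename Gemm::Argument karg{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
//   if(!Gemm::CheckValidity(karg))
//       return; // problem shape / tuning parameters not supported by this kernel
//
//   const bool has_main_loop = Gemm::CalculateHasMainKBlockLoop(K);
//   const ck::index_t tail   = Gemm::CalculateKBlockLoopTailNum(K);
//
//   // HasMainKBlockLoop and TailNum are compile-time template parameters, so the host
//   // dispatches on the (has_main_loop, tail) pair, e.g.:
//   if(has_main_loop && tail == 3)
//       ck::kernel_gemm_xdl_cshuffle_v2<Gemm, true, 3>
//           <<<grid_size, block_size, 0, stream>>>(karg);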