Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp Source File

gridwise_gemm_xdl_cshuffle_v1.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
17 
18 namespace ck {
19 
20 template <typename GridwiseGemm, bool HasMainKBlockLoop>
21 __global__ void
22 #if CK_USE_LAUNCH_BOUNDS
23     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
24 #endif
25  kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg)
26 {
27 #if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__))
28  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
29 
30  GridwiseGemm::template Run<HasMainKBlockLoop>(
31  karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg);
32 #else
33  ignore = karg;
34 #endif // end of if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__))
35 }
36 
37 template <typename GridwiseGemm,
38  typename FloatA,
39  typename FloatB,
40  typename FloatC,
41  bool HasMainKBlockLoop>
42 __global__ void
43 #if CK_USE_LAUNCH_BOUNDS
44     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
45 #endif
46  kernel_gemm_xdl_cshuffle_v1(const FloatA* __restrict__ p_a_grid,
47  const FloatB* __restrict__ p_b_grid,
48  FloatC* __restrict__ p_c_grid,
49  typename GridwiseGemm::Problem problem)
50 {
51 #if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__))
52  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
53 
54  GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, p_b_grid, p_c_grid, p_shared, problem);
55 #else
56  ignore = p_a_grid;
57  ignore = p_b_grid;
58  ignore = p_c_grid;
59  ignore = problem;
60 #endif // end of if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__))
61 }
62 
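The two kernel entry points above are thin wrappers: before launching, the host is expected to validate the problem, derive the grid size, and select the HasMainKBlockLoop instantiation. A minimal host-side sketch under stated assumptions: MyGridwiseGemm stands for a fully instantiated specialization of the gridwise GEMM defined below, and block_size is supplied by the caller to match its BlockSize template argument (both names are hypothetical, not defined in this header).

#include <hip/hip_runtime.h>
#include <stdexcept>

template <typename MyGridwiseGemm>
void launch_gemm_xdl_cshuffle_v1(const typename MyGridwiseGemm::Argument& arg,
                                 int block_size,
                                 hipStream_t stream)
{
    // Reject problems this instance cannot handle (tile divisibility, vector widths, pipeline).
    if(!MyGridwiseGemm::CheckValidity(arg))
    {
        throw std::runtime_error("unsupported problem for this GEMM instance");
    }

    // Grid size comes from the block-to-C-tile map; one workgroup per C tile.
    const auto [gdx, gdy, gdz] = MyGridwiseGemm::CalculateGridSize(arg.M, arg.N);

    // Pick the kernel instantiation matching the K-loop structure of this problem.
    if(MyGridwiseGemm::CalculateHasMainKBlockLoop(arg.K))
    {
        hipLaunchKernelGGL((ck::kernel_gemm_xdl_cshuffle_v1<MyGridwiseGemm, true>),
                           dim3(gdx, gdy, gdz), dim3(block_size), 0, stream, arg);
    }
    else
    {
        hipLaunchKernelGGL((ck::kernel_gemm_xdl_cshuffle_v1<MyGridwiseGemm, false>),
                           dim3(gdx, gdy, gdz), dim3(block_size), 0, stream, arg);
    }
}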
63 template <typename ALayout,
64  typename BLayout,
65  typename CLayout,
66  typename FloatA,
67  typename FloatB,
68  typename FloatGemmAcc,
69  typename FloatCShuffle,
70  typename FloatC,
71  typename AElementwiseOperation,
72  typename BElementwiseOperation,
73  typename CElementwiseOperation,
74  tensor_operation::device::GemmSpecialization GemmSpec,
75  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
76  index_t NumGemmKPrefetchStage,
77  index_t BlockSize,
78  index_t MPerBlock,
79  index_t NPerBlock,
80  index_t KPerBlock,
81  index_t AK1Value,
82  index_t BK1Value,
83  index_t MPerXdl,
84  index_t NPerXdl,
85  index_t MXdlPerWave,
86  index_t NXdlPerWave,
87  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
88  typename ABlockTransferThreadClusterArrangeOrder,
89  typename ABlockTransferSrcAccessOrder,
90  index_t ABlockTransferSrcVectorDim,
91  index_t ABlockTransferSrcScalarPerVector,
92  index_t ABlockTransferDstScalarPerVector_AK1,
93  bool AThreadTransferSrcResetCoordinateAfterRun,
94  index_t ABlockLdsExtraM,
95  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
96  typename BBlockTransferThreadClusterArrangeOrder,
97  typename BBlockTransferSrcAccessOrder,
98  index_t BBlockTransferSrcVectorDim,
99  index_t BBlockTransferSrcScalarPerVector,
100  index_t BBlockTransferDstScalarPerVector_BK1,
101  bool BThreadTransferSrcResetCoordinateAfterRun,
102  index_t BBlockLdsExtraN,
103  index_t CShuffleMXdlPerWavePerShuffle,
104  index_t CShuffleNXdlPerWavePerShuffle,
105  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
106  index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
107  LoopScheduler LoopSched,
108  PipelineVersion PipelineVer = PipelineVersion::v1,
109  typename ComputeTypeA = FloatC,
110  typename ComputeTypeB = ComputeTypeA>
111 struct GridwiseGemm_xdl_cshuffle_v1
112 {
113  static constexpr auto I0 = Number<0>{};
114  static constexpr auto I1 = Number<1>{};
115  static constexpr auto I2 = Number<2>{};
116  static constexpr auto I3 = Number<3>{};
117  static constexpr auto I4 = Number<4>{};
118  static constexpr auto I5 = Number<5>{};
119  static constexpr auto I6 = Number<6>{};
120  static constexpr auto I7 = Number<7>{};
121 
122  // K1 should be Number<...>
123  static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
124  static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
125  static constexpr auto AK1Number = Number<AK1Value>{};
126  static constexpr auto BK1Number = Number<BK1Value>{};
127 
128     using ThisThreadBlock = ThisThreadBlock<BlockSize>;
129 
130  __host__ static auto CalculateGridSize(index_t M, index_t N)
131  {
132  return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1);
133  }
134 
135  __host__ static auto CalculateMPadded(index_t M)
136  {
137  return math::integer_divide_ceil(M, MPerBlock) * MPerBlock;
138  }
139 
140  __host__ static auto CalculateNPadded(index_t N)
141  {
142  return math::integer_divide_ceil(N, NPerBlock) * NPerBlock;
143  }
144 
145  __host__ static auto CalculateKPadded(index_t K)
146  {
147  return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
148  }
149 
150  __host__ static auto CalculateAK0(index_t K)
151  {
152         using GemmSpecialization = tensor_operation::device::GemmSpecialization;
153 
154  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
155  GemmSpec == GemmSpecialization::MNKPadding ||
156  GemmSpec == GemmSpecialization::KPadding ||
157  GemmSpec == GemmSpecialization::NKPadding)
158  {
159  return CalculateKPadded(K) / AK1Value;
160  }
161  else
162  {
163  return K / AK1Value;
164  }
165  }
166 
167  __host__ static auto CalculateBK0(index_t K)
168  {
169         using GemmSpecialization = tensor_operation::device::GemmSpecialization;
170 
171  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
172  GemmSpec == GemmSpecialization::MNKPadding ||
173  GemmSpec == GemmSpecialization::KPadding ||
174  GemmSpec == GemmSpecialization::MKPadding)
175  {
176  return CalculateKPadded(K) / BK1Value;
177  }
178  else
179  {
180  return K / BK1Value;
181  }
182  }
183 
184  __host__ static auto CalculateMBlock(index_t M)
185  {
186  return math::integer_divide_floor(M, MPerBlock);
187  }
188 
189  __host__ static auto CalculateNBlock(index_t N)
190  {
191  return math::integer_divide_floor(N, NPerBlock);
192  }
193 
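Illustrative arithmetic for the helpers above, assuming hypothetical tuning values MPerBlock = 128, KPerBlock = 32 and AK1Value = 8 (none of these are defaults of this header):

// CalculateMPadded(1000) = ceil(1000 / 128) * 128 = 8 * 128 = 1024
// CalculateKPadded(100)  = ceil(100 / 32) * 32    = 4 * 32  = 128
// CalculateAK0(100)      = 128 / 8 = 16   for a K-padding specialization
//                        = 100 / 8        otherwise, which is why CheckValidity below
//                                         requires K % AK1Value == 0 without K padding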
194  __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
195  index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
196  {
197  const auto a_grid_desc_mraw_kraw = [&]() {
198  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
199  {
200  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
201  }
202  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
203  {
204  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
205  }
206  }();
207 
208         using GemmSpecialization = tensor_operation::device::GemmSpecialization;
209 
210  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
211  GemmSpec == GemmSpecialization::MNKPadding)
212  {
213  // pad both M and K
214  const auto a_grid_desc_m_k =
215  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
217  make_right_pad_transform(K, KPad - K)),
220 
221  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
222  a_grid_desc_m_k,
227 
228  return a_grid_desc_ak0_m_ak1;
229  }
230  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
231  GemmSpec == GemmSpecialization::MNPadding)
232  {
233  // pad M, but not K
234  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
235  a_grid_desc_mraw_kraw,
237  make_right_pad_transform(M, MPad - M)),
240 
241  return a_grid_desc_ak0_m_ak1;
242  }
243  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
244  GemmSpec == GemmSpecialization::NKPadding)
245  {
246  // pad K, but not M
247  const auto a_grid_desc_m_k = transform_tensor_descriptor(
248  a_grid_desc_mraw_kraw,
252 
253  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
254  a_grid_desc_m_k,
259 
260  return a_grid_desc_ak0_m_ak1;
261  }
262  else
263  {
264  // not pad M or K
265  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
266  a_grid_desc_mraw_kraw,
271 
272  return a_grid_desc_ak0_m_ak1;
273  }
274  }
275 
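As a concrete reading of the transform above, assume a row-major A with M = 256, K = 64, StrideA = 64, AK1Value = 8 and the default (no padding) specialization. The raw (M, K) descriptor is split along K into an (AK0, M, AK1) = (8, 256, 8) view, so index (k0, m, k1) of the new descriptor resolves to the same memory offset as A[m][k0 * 8 + k1]:

// offset(k0, m, k1) = m * StrideA + (k0 * AK1 + k1) = m * 64 + k0 * 8 + k1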
276  __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
277  index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
278  {
279  const auto b_grid_desc_nraw_kraw = [&]() {
280             if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
281             {
282  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
283  }
284             else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
285             {
286  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
287  }
288  }();
289 
290         using GemmSpecialization = tensor_operation::device::GemmSpecialization;
291 
292  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
293  GemmSpec == GemmSpecialization::MNKPadding)
294  {
295  // pad both N and K
296  const auto b_grid_desc_n_k =
297  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
299  make_right_pad_transform(K, KPad - K)),
302 
303  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
304  b_grid_desc_n_k,
309 
310  return b_grid_desc_bk0_n_bk1;
311  }
312  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
313  GemmSpec == GemmSpecialization::MNPadding)
314  {
315  // pad N, but not K
316  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
317  b_grid_desc_nraw_kraw,
319  make_right_pad_transform(N, NPad - N)),
322 
323  return b_grid_desc_bk0_n_bk1;
324  }
325  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
326  GemmSpec == GemmSpecialization::MKPadding)
327  {
328  // pad K, but not N
329  const auto b_grid_desc_n_k = transform_tensor_descriptor(
330  b_grid_desc_nraw_kraw,
334 
335  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
336  b_grid_desc_n_k,
341 
342  return b_grid_desc_bk0_n_bk1;
343  }
344  else
345  {
346  // not pad N or K
347  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
348  b_grid_desc_nraw_kraw,
353 
354  return b_grid_desc_bk0_n_bk1;
355  }
356  }
357 
358  __host__ __device__ static auto
359     MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
360     {
361  const auto c_grid_desc_mraw_nraw = [&]() {
362             if constexpr(is_same_v<tensor_layout::gemm::RowMajor, CLayout>)
363             {
364  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
365  }
366             else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, CLayout>)
367             {
368  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
369  }
370  }();
371 
372         using GemmSpecialization = tensor_operation::device::GemmSpecialization;
373 
374  if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
375  GemmSpec == GemmSpecialization::MNKPadding)
376  {
377  // pad M and N
378  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
380  make_right_pad_transform(N, NPad - N)),
383  }
384  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
385  GemmSpec == GemmSpecialization::MKPadding)
386  {
387  // pad M, but not N
389  c_grid_desc_mraw_nraw,
393  }
394  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
395  GemmSpec == GemmSpecialization::NKPadding)
396  {
397  // pad N, but not M
399  c_grid_desc_mraw_nraw,
403  }
404  else
405  {
406  // not pad M or N
407  return c_grid_desc_mraw_nraw;
408  }
409  }
410 
411  struct Problem
412  {
413  __host__ Problem(index_t M_,
414  index_t N_,
415  index_t K_,
416  index_t StrideA_,
417  index_t StrideB_,
418  index_t StrideC_)
419  : M{M_},
420  N{N_},
421  K{K_},
422  StrideA{StrideA_},
423  StrideB{StrideB_},
424  StrideC{StrideC_},
425           MPadded{CalculateMPadded(M_)},
426           NPadded{CalculateNPadded(N_)},
427           KPadded{CalculateKPadded(K_)},
428           AK0{CalculateAK0(K_)},
429  BK0{CalculateBK0(K_)},
430  MBlock{CalculateMBlock(M_)},
431           NBlock{CalculateNBlock(N_)}
432         {
433  }
434 
435  __host__ void Print() const
436  {
437  std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
438  << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
439  << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
440  << "KP:" << KPadded << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", "
441  << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl;
442  }
443 
444         index_t M;
445         index_t N;
446         index_t K;
447         index_t StrideA;
448         index_t StrideB;
449         index_t StrideC;
450         index_t MPadded;
451         index_t NPadded;
452         index_t KPadded;
453         index_t AK0;
454         index_t BK0;
455         index_t MBlock;
456         index_t NBlock;
457     };
458 
459  // Argument
460     struct Argument : public tensor_operation::device::BaseArgument, public Problem
461     {
462  __host__ Argument(const FloatA* p_a_grid_,
463  const FloatB* p_b_grid_,
464  FloatC* p_c_grid_,
465  index_t M_,
466  index_t N_,
467  index_t K_,
468  index_t StrideA_,
469  index_t StrideB_,
470  index_t StrideC_)
471  : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
472  p_a_grid{p_a_grid_},
473  p_b_grid{p_b_grid_},
474  p_c_grid{p_c_grid_}
475  {
476  }
477 
478  const FloatA* p_a_grid;
479  const FloatB* p_b_grid;
480  FloatC* p_c_grid;
481  };
482 
483     // FIXME: pass GridwiseGemmPipe as a template argument into GridwiseGemm
484     using GridwiseGemmPipe = remove_cvref_t<
485         decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
486 
487  __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
488  {
489  // A matrix in LDS memory, dst of blockwise copy
493  }
494 
495  __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
496  {
497  // B matrix in LDS memory, dst of blockwise copy
501  }
502 
503     __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
504     {
505  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
506  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
507 
508         constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
509             make_naive_tensor_descriptor_packed(
510                 make_tuple(I1,
511                            Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
512                            I1,
513                            Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
514 
515  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
516  }
517 
518  __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
519  {
520  // LDS allocation for A and B: be careful of alignment
521  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
522  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
523 
524  // lds max alignment
525  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
526 
527  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
528  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
529 
530  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
531  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
532 
533  // LDS allocation for C shuffle in LDS
534  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
535             GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
536 
537  constexpr auto c_block_size =
538  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
539 
540  return math::max((a_block_space_size_aligned * sizeof(ComputeTypeA) +
541  b_block_space_size_aligned * sizeof(ComputeTypeB)),
542  c_block_size * sizeof(FloatCShuffle));
543  }
544 
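A worked example of the LDS budget computed above, for one hypothetical configuration (not defaults of this header): MPerBlock = NPerBlock = 128, KPerBlock = 32, AK1Value = BK1Value = 8, ABlockLdsExtraM = BBlockLdsExtraN = 0, ComputeTypeA/B = half_t, FloatCShuffle = float, CShuffleMXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle = 1, MWave = NWave = 2, MPerXdl = NPerXdl = 32:

// a_block   = (32 / 8) * 128 * 8 elements * sizeof(half_t) = 4096 * 2 = 8192 bytes
// b_block   = (32 / 8) * 128 * 8 elements * sizeof(half_t) = 4096 * 2 = 8192 bytes
// c_shuffle = (1 * 2 * 32) * (1 * 2 * 32) elements * sizeof(float) = 4096 * 4 = 16384 bytes
// GetSharedMemoryNumberOfByte() = max(8192 + 8192, 16384) = 16384 bytes
//
// A/B staging and the C shuffle tile reuse the same allocation, hence the max().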
545     // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
546  __host__ static constexpr bool CheckValidity(const Problem& problem)
547  {
548  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
549  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
550  "Invalid tuning param!");
551 
552  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
553                        GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
554                        GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
555                        GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
556         {
557  if(!(problem.M % MPerBlock == 0))
558  {
559  return false;
560  }
561  }
562 
563  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
564                        GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
565                        GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
566                        GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
567         {
568  if(!(problem.N % NPerBlock == 0))
569  {
570  return false;
571  }
572  }
573 
574         if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
575                      GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
576                      GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
577                      GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
578         {
579  if(!(CalculateKPadded(problem.K) % AK1Value == 0) ||
580  !(CalculateKPadded(problem.K) % BK1Value == 0))
581  {
582  return false;
583  }
584  }
585  else
586  {
587  if(!(problem.K % AK1Value == 0) || !(problem.K % BK1Value == 0))
588  {
589  return false;
590  }
591  }
592 
593         if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
594         {
595  if(problem.K % ABlockTransferSrcScalarPerVector != 0)
596  {
597  return false;
598  }
599  }
600  else
601  {
602  if(problem.M % ABlockTransferSrcScalarPerVector != 0)
603  {
604  return false;
605  }
606  }
607 
608         if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
609         {
610  if(problem.N % BBlockTransferSrcScalarPerVector != 0)
611  {
612  return false;
613  }
614  }
615  else
616  {
617  if(problem.K % BBlockTransferSrcScalarPerVector != 0)
618  {
619  return false;
620  }
621  }
622 
623         if constexpr(is_same_v<tensor_layout::gemm::RowMajor, CLayout>)
624         {
625  if(problem.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
626  {
627  return false;
628  }
629  }
630  else
631  {
632  if(problem.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
633  {
634  return false;
635  }
636  }
637 
638  // check gridwise gemm pipeline
639  const auto num_k_loop = (CalculateAK0(problem.K) * AK1Value) / KPerBlock;
640 
641  if(!GridwiseGemmPipe::IsSupported(num_k_loop))
642  {
643  return false;
644  }
645 
646  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
647  return true;
648  }
649 
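For example, with a hypothetical instance where ALayout = RowMajor, ABlockTransferSrcScalarPerVector = 8 and GemmSpec has no K padding, the A copy reads 8 contiguous elements along K per vector load, so:

// K = 132: 132 % 8 != 0  ->  CheckValidity(problem) == false
// K = 128: 128 % 8 == 0  ->  this particular check passes (the remaining checks still apply)
// With a K-padding specialization the requirement moves to the padded K instead.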
650  __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
651  {
652  const index_t num_loop = K / KPerBlock;
653 
654  return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
655  }
656 
657  template <typename CGridDesc>
658     __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
659         const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
660  {
661  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
662  c_grid_desc_m_n,
667 
668  return c_grid_desc_mblock_mperblock_nblock_nperblock;
669  }
670 
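For instance, with hypothetical tile sizes MPerBlock = 128 and NPerBlock = 256, an M = 1024, N = 512 C descriptor becomes an (MBlock, MPerBlock, NBlock, NPerBlock) = (8, 128, 2, 256) view, so each workgroup addresses its tile through the block indices in dims 0 and 2:

// C_view(mb, mp, nb, np) = C(mb * 128 + mp, nb * 256 + np)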
671  // return block_id to C matrix tile idx (m0, n0) mapping
672     using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
673 
674  template <bool HasMainKBlockLoop>
675  __device__ static void Run(const FloatA* __restrict__ p_a_grid,
676  const FloatB* __restrict__ p_b_grid,
677  FloatC* __restrict__ p_c_grid,
678  void* __restrict__ p_shared,
679  const Problem& problem)
680  {
681  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
682  problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
683  const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
684  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
685  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
686  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
687 
688  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
689             MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
690                 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
691 
692  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
693  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
694  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
695  p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
696  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
697  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
698 
699  const AElementwiseOperation a_element_op{};
700  const BElementwiseOperation b_element_op{};
701  const CElementwiseOperation c_element_op{};
702 
703  // divide block work by [M, N]
704  const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N};
705 
706  const auto block_work_idx =
707  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
708 
709  if(!block_2_ctile_map.ValidCTileIndex(
710  block_work_idx,
711  make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
712  c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
713  {
714  return;
715  }
716 
717         // HACK: this forces m/n_block_data_idx_on_grid into SGPR
718  const index_t m_block_data_idx_on_grid =
719  __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
720 
721  const index_t n_block_data_idx_on_grid =
722  __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
723 
724  // lds max alignment
725  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
726 
727  // A matrix in LDS memory, dst of blockwise copy
728  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
729 
730  // B matrix in LDS memory, dst of blockwise copy
731  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
732 
733  // A matrix blockwise copy
734  auto a_blockwise_copy =
736  AElementwiseOperation,
740  ABlockTransferThreadClusterLengths_AK0_M_AK1,
741  ABlockTransferThreadClusterArrangeOrder,
742  FloatA,
743  ComputeTypeA,
744  decltype(a_grid_desc_ak0_m_ak1),
745  decltype(a_block_desc_ak0_m_ak1),
746  ABlockTransferSrcAccessOrder,
748  ABlockTransferSrcVectorDim,
749  2,
750  ABlockTransferSrcScalarPerVector,
751  ABlockTransferDstScalarPerVector_AK1,
752  1,
753  1,
754  AThreadTransferSrcResetCoordinateAfterRun,
755  true,
756  NumGemmKPrefetchStage>(
757  a_grid_desc_ak0_m_ak1,
758  make_multi_index(0, m_block_data_idx_on_grid, 0),
759  a_element_op,
760  a_block_desc_ak0_m_ak1,
761  make_multi_index(0, 0, 0),
763 
764  // B matrix blockwise copy
765  auto b_blockwise_copy =
767  BElementwiseOperation,
771  BBlockTransferThreadClusterLengths_BK0_N_BK1,
772  BBlockTransferThreadClusterArrangeOrder,
773  FloatB,
774  ComputeTypeB,
775  decltype(b_grid_desc_bk0_n_bk1),
776  decltype(b_block_desc_bk0_n_bk1),
777  BBlockTransferSrcAccessOrder,
779  BBlockTransferSrcVectorDim,
780  2,
781  BBlockTransferSrcScalarPerVector,
782  BBlockTransferDstScalarPerVector_BK1,
783  1,
784  1,
785  BThreadTransferSrcResetCoordinateAfterRun,
786  true,
787  NumGemmKPrefetchStage>(
788  b_grid_desc_bk0_n_bk1,
789  make_multi_index(0, n_block_data_idx_on_grid, 0),
790  b_element_op,
791  b_block_desc_bk0_n_bk1,
792  make_multi_index(0, 0, 0),
794 
795  // GEMM definition
796  // c_mtx += transpose(a_mtx) * b_mtx
797  // a_mtx[K0PerBlock, MPerBlock] is in LDS
798  // b_mtx[K0PerBlock, NPerBlock] is in LDS
799  // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
800  // register
801  // sanity check
802  constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
803  constexpr bool is_single_rate_mfma =
804             (((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
805               lcm_AK1_BK1 <= 4) ||
806  (is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
807              ((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
808               lcm_AK1_BK1 < 32))
809  ? true
810  : false;
811  constexpr auto is_scale_mfma = false;
812  constexpr index_t KPack = math::max(lcm_AK1_BK1,
813  MfmaSelector<ComputeTypeA,
814  MPerXdl,
815  NPerXdl,
816  ComputeTypeB,
817  is_single_rate_mfma,
818  is_scale_mfma>::selected_mfma.k_per_blk);
819 
820         auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
821             BlockSize,
822  ComputeTypeA,
823  ComputeTypeB,
824  FloatGemmAcc,
825  decltype(a_block_desc_ak0_m_ak1),
826  decltype(b_block_desc_bk0_n_bk1),
827  MPerXdl,
828  NPerXdl,
829  MXdlPerWave,
830  NXdlPerWave,
831  KPack,
832  LoopSched>();
833 
834  auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
835 
836  // LDS allocation for A and B: be careful of alignment
837  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
838  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
839 
840  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
841  static_cast<ComputeTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
842 
843  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
844  static_cast<ComputeTypeB*>(p_shared) + a_block_space_size_aligned,
845  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
846 
847  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
848  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
849 
850  // gridwise GEMM pipeline
851  static_assert(std::is_default_constructible_v<GridwiseGemmPipe>);
852  const auto gridwise_gemm_pipeline = GridwiseGemmPipe{};
853 
854  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
855  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
856  KPerBlock);
857 
858  gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
859  a_block_desc_ak0_m_ak1,
860  a_blockwise_copy,
861  a_grid_buf,
862  a_block_buf,
863  a_block_slice_copy_step,
864  b_grid_desc_bk0_n_bk1,
865  b_block_desc_bk0_n_bk1,
866  b_blockwise_copy,
867  b_grid_buf,
868  b_block_buf,
869  b_block_slice_copy_step,
870  blockwise_gemm,
871  c_thread_buf,
872  num_k_block_main_loop);
873 
874  // shuffle C and write out
875  {
876  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
877  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
878  "wrong!");
879 
880  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
881  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
882 
883  // TODO: hacky, fix it!
884  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
885  blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
886 
887  // TODO: hacky, fix it!
888  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
889  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
890  blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
891 
892  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
893  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
894  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
895  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
896  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
897  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
898  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
899  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
900 
901  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
902                 GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
903 
904  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
905  static_cast<FloatCShuffle*>(p_shared),
906  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
907 
908  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
909  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
910  make_tuple(
913  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
914  M1, // M1 = MWave
915  M2, // M2 * M3 * M4 = MPerXdl
916  M3,
917  M4)),
920  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
921  N1, // N1 = NWave
922  N2))), // N2 = NPerXdl
924  make_tuple(
926 
927  // calculate origin of thread output tensor on global memory
928  // blockwise GEMM c matrix starting index
929  const auto c_thread_mtx_on_block =
930  blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
931 
932  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
933  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
934 
935  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
937  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
940 
941  const auto m_thread_data_on_block_idx =
942  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
943  make_multi_index(m_thread_data_on_block));
944 
945  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
950 
951  const auto n_thread_data_on_block_idx =
952  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
953  make_multi_index(n_thread_data_on_block));
954 
955  // shuffle: threadwise copy C from VGPR to LDS
956  auto c_thread_copy_vgpr_to_lds =
958  FloatCShuffle,
959  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
960  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
962  Sequence<CShuffleMXdlPerWavePerShuffle,
963  CShuffleNXdlPerWavePerShuffle,
964  I1,
965  I1,
966  M2,
967  I1,
968  M4,
969  I1>,
971  7,
972  1,
974  1,
975  true>{
976  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
978  0,
979  m_thread_data_on_block_idx[I1],
980  n_thread_data_on_block_idx[I1],
981  m_thread_data_on_block_idx[I2],
982  m_thread_data_on_block_idx[I3],
983  m_thread_data_on_block_idx[I4],
984  n_thread_data_on_block_idx[I2]),
986 
987  // shuffle: blockwise copy C from LDS to global
988  auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
989  ThisThreadBlock, // ThreadGroup
990  CElementwiseOperation, // ElementwiseOperation,
991  CGlobalMemoryDataOperation, // DstInMemOp,
992  Sequence<1,
993  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
994  1,
995  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
996  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
997  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
998  FloatCShuffle, // typename SrcData,
999  FloatC, // typename DstData,
1000  decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1001  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1002  Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
1003  3, // index_t VectorDim,
1004  CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
1005  true, // bool ThreadTransferSrcResetCoordinateAfterRun,
1006  false> // bool ThreadTransferDstResetCoordinateAfterRun>
1007  {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1008  make_multi_index(0, 0, 0, 0),
1009  c_grid_desc_mblock_mperblock_nblock_nperblock,
1010  make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
1011  c_element_op};
1012 
1013  // space filling curve for threadwise C in VGPR
1014  constexpr auto sfc_c_vgpr =
1017  Sequence<CShuffleMXdlPerWavePerShuffle,
1018  CShuffleNXdlPerWavePerShuffle,
1019  1,
1020  1,
1021  M2,
1022  1,
1023  M4,
1024  1>>{};
1025 
1026  // space filling curve for shuffled blockwise C in global mem
1027  constexpr auto sfc_c_global =
1030  Sequence<1,
1031  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1032  1,
1033  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1034 
1035  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1036 
1037  static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
1038 
1039  static_for<0, num_access, 1>{}([&](auto access_id) {
1040  // make sure it's safe to write to LDS
1041  block_sync_lds();
1042 
1043  // each thread write its data from VGPR to LDS
1044  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1045  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1046  c_thread_buf,
1047  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1048  c_shuffle_block_buf);
1049 
1050  // make sure it's safe to read from LDS
1051  block_sync_lds();
1052 
1053  // each block copy its data from LDS to global
1054  c_shuffle_block_copy_lds_to_global.Run(
1055  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1056  c_shuffle_block_buf,
1057  c_grid_desc_mblock_mperblock_nblock_nperblock,
1058  c_grid_buf);
1059 
1060  if constexpr(access_id < num_access - 1)
1061  {
1062  constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1063 
1064  // move on C
1065  c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1066  c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1067  }
1068  });
1069  }
1070  }
1071 };
1072 
1073 } // namespace ck