Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp Source File

gridwise_gemm_xdl_cshuffle_v1.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
17 
18 namespace ck {
19 
20 template <typename GridwiseGemm, bool HasMainKBlockLoop>
21 __global__ void
22 #if CK_USE_LAUNCH_BOUNDS
23  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
24 #endif
25  kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg)
26 {
27 #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
28  defined(__gfx12__)
29  if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
30  {
31  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
32 
33  GridwiseGemm::template Run<HasMainKBlockLoop>(
34  karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg);
35  }
36 #else
37  ignore = karg;
38 #endif // defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || defined(__gfx12__)
39 }
40 
41 template <typename GridwiseGemm,
42  typename FloatA,
43  typename FloatB,
44  typename FloatC,
45  bool HasMainKBlockLoop>
46 __global__ void
47 #if CK_USE_LAUNCH_BOUNDS
48  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
49 #endif
50  kernel_gemm_xdl_cshuffle_v1(const FloatA* __restrict__ p_a_grid,
51  const FloatB* __restrict__ p_b_grid,
52  FloatC* __restrict__ p_c_grid,
53  typename GridwiseGemm::Problem problem)
54 {
55 #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
56  defined(__gfx12__)
57  if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
58  {
59  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
60 
61  GridwiseGemm::template Run<HasMainKBlockLoop>(
62  p_a_grid, p_b_grid, p_c_grid, p_shared, problem);
63  }
64 #else
65  ignore = p_a_grid;
66  ignore = p_b_grid;
67  ignore = p_c_grid;
68  ignore = problem;
69 #endif // defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || defined(__gfx12__)
70 }
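A host-side launch sketch (an assumption, not part of this header): `Gemm` stands for a concrete instantiation of the gridwise GEMM defined below, `karg` for its `Argument`, and `grid_size` for the workgroup count from `CalculateGridSize`; the kernel allocates its LDS statically, so no dynamic shared-memory size is passed.

 // Hypothetical launch; Gemm, karg, grid_size and stream are placeholders.
 // constexpr bool has_main_loop = true; // i.e. Gemm::CalculateHasMainKBlockLoop(K)
 // hipLaunchKernelGGL((kernel_gemm_xdl_cshuffle_v1<Gemm, has_main_loop>),
 //                    dim3(grid_size), dim3(/*BlockSize*/ 256), 0, stream, karg);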
71 
72 template <typename ALayout,
73  typename BLayout,
74  typename CLayout,
75  typename FloatA,
76  typename FloatB,
77  typename FloatGemmAcc,
78  typename FloatCShuffle,
79  typename FloatC,
80  typename AElementwiseOperation,
81  typename BElementwiseOperation,
82  typename CElementwiseOperation,
 83  tensor_operation::device::GemmSpecialization GemmSpec,
 84  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
85  index_t NumGemmKPrefetchStage,
86  index_t BlockSize,
87  index_t MPerBlock,
88  index_t NPerBlock,
89  index_t KPerBlock,
90  index_t AK1Value,
91  index_t BK1Value,
92  index_t MPerXdl,
93  index_t NPerXdl,
94  index_t MXdlPerWave,
95  index_t NXdlPerWave,
96  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
97  typename ABlockTransferThreadClusterArrangeOrder,
98  typename ABlockTransferSrcAccessOrder,
99  index_t ABlockTransferSrcVectorDim,
100  index_t ABlockTransferSrcScalarPerVector,
101  index_t ABlockTransferDstScalarPerVector_AK1,
102  bool AThreadTransferSrcResetCoordinateAfterRun,
103  index_t ABlockLdsExtraM,
104  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
105  typename BBlockTransferThreadClusterArrangeOrder,
106  typename BBlockTransferSrcAccessOrder,
107  index_t BBlockTransferSrcVectorDim,
108  index_t BBlockTransferSrcScalarPerVector,
109  index_t BBlockTransferDstScalarPerVector_BK1,
110  bool BThreadTransferSrcResetCoordinateAfterRun,
111  index_t BBlockLdsExtraN,
112  index_t CShuffleMXdlPerWavePerShuffle,
113  index_t CShuffleNXdlPerWavePerShuffle,
114  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
115  index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
116  LoopScheduler LoopSched,
117  PipelineVersion PipelineVer = PipelineVersion::v1,
118  typename ComputeTypeA = FloatC,
119  typename ComputeTypeB = ComputeTypeA>
 120 struct GridwiseGemm_xdl_cshuffle_v1
 121 {
122  static constexpr auto I0 = Number<0>{};
123  static constexpr auto I1 = Number<1>{};
124  static constexpr auto I2 = Number<2>{};
125  static constexpr auto I3 = Number<3>{};
126  static constexpr auto I4 = Number<4>{};
127  static constexpr auto I5 = Number<5>{};
128  static constexpr auto I6 = Number<6>{};
129  static constexpr auto I7 = Number<7>{};
130 
131  // K1 should be Number<...>
132  static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
133  static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
134  static constexpr auto AK1Number = Number<AK1Value>{};
135  static constexpr auto BK1Number = Number<BK1Value>{};
136 
 137  using ThisThreadBlock = ThisThreadBlock<BlockSize>;
 138 
139  __host__ static auto CalculateGridSize(index_t M, index_t N)
140  {
141  return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1);
142  }
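Assuming the block-to-C-tile map assigns one workgroup per MPerBlock x NPerBlock output tile (the concrete Block2CTileMap type is elided in this listing), a worked example of the resulting workgroup count:

 // Hypothetical tile count: a 1024 x 1024 C matrix with 128 x 128 tiles needs 8 * 8 workgroups.
 static_assert((1024 / 128) * (1024 / 128) == 64, "one workgroup per C tile in this example");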
143 
144  __host__ static auto CalculateMPadded(index_t M)
145  {
146  return math::integer_divide_ceil(M, MPerBlock) * MPerBlock;
147  }
148 
149  __host__ static auto CalculateNPadded(index_t N)
150  {
151  return math::integer_divide_ceil(N, NPerBlock) * NPerBlock;
152  }
153 
154  __host__ static auto CalculateKPadded(index_t K)
155  {
156  return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
157  }
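The three padding helpers all round an extent up to the next tile boundary; for example, with MPerBlock = 128 a raw M of 1000 is padded to 1024 (a worked example, not taken from the source):

 // CalculateMPadded(1000) with MPerBlock = 128: ceil(1000 / 128) * 128 = 8 * 128 = 1024.
 static_assert((1000 + 128 - 1) / 128 * 128 == 1024, "M is rounded up to a multiple of MPerBlock");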
158 
159  __host__ static auto CalculateAK0(index_t K)
160  {
 161  using GemmSpecialization = tensor_operation::device::GemmSpecialization;
 162 
163  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
164  GemmSpec == GemmSpecialization::MNKPadding ||
165  GemmSpec == GemmSpecialization::KPadding ||
166  GemmSpec == GemmSpecialization::NKPadding)
167  {
168  return CalculateKPadded(K) / AK1Value;
169  }
170  else
171  {
172  return K / AK1Value;
173  }
174  }
175 
176  __host__ static auto CalculateBK0(index_t K)
177  {
 178  using GemmSpecialization = tensor_operation::device::GemmSpecialization;
 179 
180  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
181  GemmSpec == GemmSpecialization::MNKPadding ||
182  GemmSpec == GemmSpecialization::KPadding ||
183  GemmSpec == GemmSpecialization::MKPadding)
184  {
185  return CalculateKPadded(K) / BK1Value;
186  }
187  else
188  {
189  return K / BK1Value;
190  }
191  }
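A worked example of CalculateAK0 under one of the K-padding specializations, assuming KPerBlock = 32 and AK1Value = 8: K = 100 first pads to 128, which then splits into 16 AK1-wide pieces.

 // KPadded = ceil(100 / 32) * 32 = 128; AK0 = 128 / 8 = 16. Without K padding, K itself must be
 // divisible by AK1Value and AK0 is simply K / AK1Value.
 static_assert(((100 + 32 - 1) / 32 * 32) / 8 == 16, "AK0 counts AK1-wide slices of the padded K");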
192 
193  __host__ static auto CalculateMBlock(index_t M)
194  {
195  return math::integer_divide_floor(M, MPerBlock);
196  }
197 
198  __host__ static auto CalculateNBlock(index_t N)
199  {
200  return math::integer_divide_floor(N, NPerBlock);
201  }
202 
203  __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
204  index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
205  {
206  const auto a_grid_desc_mraw_kraw = [&]() {
207  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
208  {
209  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
210  }
211  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
212  {
213  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
214  }
215  }();
216 
 217  using GemmSpecialization = tensor_operation::device::GemmSpecialization;
 218 
219  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
220  GemmSpec == GemmSpecialization::MNKPadding)
221  {
222  // pad both M and K
223  const auto a_grid_desc_m_k =
224  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
226  make_right_pad_transform(K, KPad - K)),
229 
230  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
231  a_grid_desc_m_k,
236 
237  return a_grid_desc_ak0_m_ak1;
238  }
239  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
240  GemmSpec == GemmSpecialization::MNPadding)
241  {
242  // pad M, but not K
243  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
244  a_grid_desc_mraw_kraw,
246  make_right_pad_transform(M, MPad - M)),
249 
250  return a_grid_desc_ak0_m_ak1;
251  }
252  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
253  GemmSpec == GemmSpecialization::NKPadding)
254  {
255  // pad K, but not M
256  const auto a_grid_desc_m_k = transform_tensor_descriptor(
257  a_grid_desc_mraw_kraw,
261 
262  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
263  a_grid_desc_m_k,
268 
269  return a_grid_desc_ak0_m_ak1;
270  }
271  else
272  {
273  // not pad M or K
274  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
275  a_grid_desc_mraw_kraw,
280 
281  return a_grid_desc_ak0_m_ak1;
282  }
283  }
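The transform bodies are elided in this listing, but the intent of the AK0/M/AK1 view is to split K into AK0 outer slices of AK1 contiguous elements so the innermost AK1 elements can be loaded as one vector. Under that reading (an assumption), an unpadded row-major A with StrideA = 256 and AK1 = 8 maps index (ak0, m, ak1) to offset m * StrideA + ak0 * AK1 + ak1:

 // Hypothetical offset of element (ak0 = 3, m = 5, ak1 = 2) in the unpadded row-major case.
 static_assert(5 * 256 + 3 * 8 + 2 == 1306, "AK1 is the innermost, contiguous piece of K");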
284 
285  __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
286  index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
287  {
288  const auto b_grid_desc_nraw_kraw = [&]() {
 289  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
 290  {
291  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
292  }
 293  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
 294  {
295  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
296  }
297  }();
298 
 299  using GemmSpecialization = tensor_operation::device::GemmSpecialization;
 300 
301  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
302  GemmSpec == GemmSpecialization::MNKPadding)
303  {
304  // pad both N and K
305  const auto b_grid_desc_n_k =
306  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
308  make_right_pad_transform(K, KPad - K)),
311 
312  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
313  b_grid_desc_n_k,
318 
319  return b_grid_desc_bk0_n_bk1;
320  }
321  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
322  GemmSpec == GemmSpecialization::MNPadding)
323  {
324  // pad N, but not K
325  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
326  b_grid_desc_nraw_kraw,
328  make_right_pad_transform(N, NPad - N)),
331 
332  return b_grid_desc_bk0_n_bk1;
333  }
334  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
335  GemmSpec == GemmSpecialization::MKPadding)
336  {
337  // pad K, but not N
338  const auto b_grid_desc_n_k = transform_tensor_descriptor(
339  b_grid_desc_nraw_kraw,
343 
344  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
345  b_grid_desc_n_k,
350 
351  return b_grid_desc_bk0_n_bk1;
352  }
353  else
354  {
355  // not pad N or K
356  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
357  b_grid_desc_nraw_kraw,
362 
363  return b_grid_desc_bk0_n_bk1;
364  }
365  }
366 
367  __host__ __device__ static auto
 368  MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
 369  {
370  const auto c_grid_desc_mraw_nraw = [&]() {
 371  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, CLayout>)
 372  {
373  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
374  }
 375  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, CLayout>)
 376  {
377  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
378  }
379  }();
380 
 381  using GemmSpecialization = tensor_operation::device::GemmSpecialization;
 382 
383  if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
384  GemmSpec == GemmSpecialization::MNKPadding)
385  {
386  // pad M and N
387  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
389  make_right_pad_transform(N, NPad - N)),
392  }
393  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
394  GemmSpec == GemmSpecialization::MKPadding)
395  {
396  // pad M, but not N
398  c_grid_desc_mraw_nraw,
402  }
403  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
404  GemmSpec == GemmSpecialization::NKPadding)
405  {
406  // pad N, but not M
408  c_grid_desc_mraw_nraw,
412  }
413  else
414  {
415  // not pad M or N
416  return c_grid_desc_mraw_nraw;
417  }
418  }
419 
420  struct Problem
421  {
422  __host__ Problem(index_t M_,
423  index_t N_,
424  index_t K_,
425  index_t StrideA_,
426  index_t StrideB_,
427  index_t StrideC_)
428  : M{M_},
429  N{N_},
430  K{K_},
431  StrideA{StrideA_},
432  StrideB{StrideB_},
433  StrideC{StrideC_},
 434  MPadded{CalculateMPadded(M_)},
 435  NPadded{CalculateNPadded(N_)},
 436  KPadded{CalculateKPadded(K_)},
 437  AK0{CalculateAK0(K_)},
438  BK0{CalculateBK0(K_)},
439  MBlock{CalculateMBlock(M_)},
 440  NBlock{CalculateNBlock(N_)}
 441  {
442  }
443 
444  __host__ void Print() const
445  {
446  std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
447  << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
448  << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
449  << "KP:" << KPadded << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", "
450  << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl;
451  }
 452 
 453  index_t M;
 454  index_t N;
 455  index_t K;
 456  index_t StrideA;
 457  index_t StrideB;
 458  index_t StrideC;
 459  index_t MPadded;
 460  index_t NPadded;
 461  index_t KPadded;
 462  index_t AK0;
 463  index_t BK0;
 464  index_t MBlock;
 465  index_t NBlock;
 466  };
467 
468  // Argument
 469  struct Argument : public Problem
 470  {
471  __host__ Argument(const FloatA* p_a_grid_,
472  const FloatB* p_b_grid_,
473  FloatC* p_c_grid_,
474  index_t M_,
475  index_t N_,
476  index_t K_,
477  index_t StrideA_,
478  index_t StrideB_,
479  index_t StrideC_)
480  : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
481  p_a_grid{p_a_grid_},
482  p_b_grid{p_b_grid_},
483  p_c_grid{p_c_grid_}
484  {
485  }
486 
487  const FloatA* p_a_grid;
488  const FloatB* p_b_grid;
489  FloatC* p_c_grid;
490  };
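A host-side construction sketch (hypothetical; `p_a`, `p_b`, `p_c` and the concrete GridwiseGemm instantiation are placeholders). Argument derives from Problem, so it can be passed directly to CheckValidity before launching the kernel:

 // GridwiseGemm::Argument karg{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
 // if(GridwiseGemm::CheckValidity(karg)) { /* launch kernel_gemm_xdl_cshuffle_v1<GridwiseGemm, ...> */ }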
491 
 492  // FIXME: pass GridwiseGemmPipe as a template argument into GridwiseGemm
 493  using GridwiseGemmPipe = remove_cvref_t<
 494  decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
495 
496  __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
497  {
498  // A matrix in LDS memory, dst of blockwise copy
502  }
503 
504  __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
505  {
506  // B matrix in LDS memory, dst of blockwise copy
510  }
511 
 512  __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
 513  {
514  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
515  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
516 
517  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
519  make_tuple(I1,
521  I1,
523 
524  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
525  }
526 
527  __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
528  {
529  // LDS allocation for A and B: be careful of alignment
530  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
531  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
532 
533  // lds max alignment
534  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
535 
536  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
537  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
538 
539  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
540  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
541 
542  // LDS allocation for C shuffle in LDS
543  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
 544  GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
 545 
546  constexpr auto c_block_size =
547  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
548 
549  return math::max((a_block_space_size_aligned * sizeof(ComputeTypeA) +
550  b_block_space_size_aligned * sizeof(ComputeTypeB)),
551  c_block_size * sizeof(FloatCShuffle));
552  }
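A worked example of the LDS budget, assuming a 128 x 128 x 32 tile, 16-bit compute types and AK1 = BK1 (so the lcm alignment adds no extra padding): the A and B tiles need 8 KiB each, and the larger of that 16 KiB and the C-shuffle workspace is what gets allocated.

 // A tile: 32 * 128 elements, B tile: 32 * 128 elements, 2 bytes each -> 16384 bytes total.
 static_assert(32 * 128 * 2 + 32 * 128 * 2 == 16384, "A+B LDS footprint for a 128x128x32 fp16 tile");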
553 
554  template <
555  InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set>
556  __device__ static bool constexpr IsValidCompilationParameter()
557  {
558  return ck::tensor_operation::device::IsValidGemmCompilationParameter<
559  BlockSize,
560  MPerBlock,
561  NPerBlock,
562  MPerXdl,
563  NPerXdl,
564  MXdlPerWave,
565  NXdlPerWave,
566  FloatC,
567  CGlobalMemoryDataOperation>();
568  }
569 
 570  // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
571  __host__ static constexpr bool CheckValidity(const Problem& problem)
572  {
573  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
574  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
575  "Invalid tuning param!");
576 
577  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
581  {
582  if(!(problem.M % MPerBlock == 0))
583  {
584  return false;
585  }
586  }
587 
588  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
592  {
593  if(!(problem.N % NPerBlock == 0))
594  {
595  return false;
596  }
597  }
598 
603  {
604  if(!(CalculateKPadded(problem.K) % AK1Value == 0) ||
605  !(CalculateKPadded(problem.K) % BK1Value == 0))
606  {
607  return false;
608  }
609  }
610  else
611  {
612  if(!(problem.K % AK1Value == 0) || !(problem.K % BK1Value == 0))
613  {
614  return false;
615  }
616  }
617 
 618  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
 619  {
620  if(problem.K % ABlockTransferSrcScalarPerVector != 0)
621  {
622  return false;
623  }
624  }
625  else
626  {
627  if(problem.M % ABlockTransferSrcScalarPerVector != 0)
628  {
629  return false;
630  }
631  }
632 
 633  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
 634  {
635  if(problem.N % BBlockTransferSrcScalarPerVector != 0)
636  {
637  return false;
638  }
639  }
640  else
641  {
642  if(problem.K % BBlockTransferSrcScalarPerVector != 0)
643  {
644  return false;
645  }
646  }
647 
 648  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, CLayout>)
 649  {
650  if(problem.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
651  {
652  return false;
653  }
654  }
655  else
656  {
657  if(problem.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
658  {
659  return false;
660  }
661  }
662 
663  // check gridwise gemm pipeline
664  const auto num_k_loop = (CalculateAK0(problem.K) * AK1Value) / KPerBlock;
665 
666  if(!GridwiseGemmPipe::IsSupported(num_k_loop))
667  {
668  return false;
669  }
670 
671  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
672  return true;
673  }
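For instance, without an M-padding specialization the checks above reject any M that is not a multiple of MPerBlock; with MPerBlock = 128, M = 1024 passes while M = 1000 would require an M-padding GemmSpecialization (a worked example):

 static_assert(1024 % 128 == 0, "accepted without padding");
 static_assert(1000 % 128 != 0, "needs an M-padding GemmSpecialization");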
674 
675  __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
676  {
677  const index_t num_loop = K / KPerBlock;
678 
679  return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
680  }
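A worked example: K = 256 with KPerBlock = 32 gives eight K-block iterations; whether that qualifies as having a main loop is decided by the selected pipeline's CalculateHasMainLoop.

 static_assert(256 / 32 == 8, "number of K-block iterations handed to the pipeline");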
681 
682  template <typename CGridDesc>
 683  __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
 684  const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
685  {
686  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
687  c_grid_desc_m_n,
692 
693  return c_grid_desc_mblock_mperblock_nblock_nperblock;
694  }
695 
696  // return block_id to C matrix tile idx (m0, n0) mapping
698 
699  template <bool HasMainKBlockLoop>
700  __device__ static void Run(const FloatA* __restrict__ p_a_grid,
701  const FloatB* __restrict__ p_b_grid,
702  FloatC* __restrict__ p_c_grid,
703  void* __restrict__ p_shared,
704  const Problem& problem)
705  {
706  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
707  problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
708  const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
709  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
710  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
711  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
712 
713  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
 714  MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
 715  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
716 
717  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
718  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
719  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
720  p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
721  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
722  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
723 
724  const AElementwiseOperation a_element_op{};
725  const BElementwiseOperation b_element_op{};
726  const CElementwiseOperation c_element_op{};
727 
728  // divide block work by [M, N]
729  const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N};
730 
731  const auto block_work_idx =
732  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
733 
734  if(!block_2_ctile_map.ValidCTileIndex(
735  block_work_idx,
736  make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
737  c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
738  {
739  return;
740  }
741 
 742  // HACK: this forces m/n_block_data_idx_on_grid into SGPR
743  const index_t m_block_data_idx_on_grid =
744  __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
745 
746  const index_t n_block_data_idx_on_grid =
747  __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
748 
749  // lds max alignment
750  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
751 
752  // A matrix in LDS memory, dst of blockwise copy
753  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
754 
755  // B matrix in LDS memory, dst of blockwise copy
756  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
757 
758  // A matrix blockwise copy
759  auto a_blockwise_copy =
761  AElementwiseOperation,
765  ABlockTransferThreadClusterLengths_AK0_M_AK1,
766  ABlockTransferThreadClusterArrangeOrder,
767  FloatA,
768  ComputeTypeA,
769  decltype(a_grid_desc_ak0_m_ak1),
770  decltype(a_block_desc_ak0_m_ak1),
771  ABlockTransferSrcAccessOrder,
773  ABlockTransferSrcVectorDim,
774  2,
775  ABlockTransferSrcScalarPerVector,
776  ABlockTransferDstScalarPerVector_AK1,
777  1,
778  1,
779  AThreadTransferSrcResetCoordinateAfterRun,
780  true,
781  NumGemmKPrefetchStage>(
782  a_grid_desc_ak0_m_ak1,
783  make_multi_index(0, m_block_data_idx_on_grid, 0),
784  a_element_op,
785  a_block_desc_ak0_m_ak1,
786  make_multi_index(0, 0, 0),
788 
789  // B matrix blockwise copy
790  auto b_blockwise_copy =
792  BElementwiseOperation,
796  BBlockTransferThreadClusterLengths_BK0_N_BK1,
797  BBlockTransferThreadClusterArrangeOrder,
798  FloatB,
799  ComputeTypeB,
800  decltype(b_grid_desc_bk0_n_bk1),
801  decltype(b_block_desc_bk0_n_bk1),
802  BBlockTransferSrcAccessOrder,
804  BBlockTransferSrcVectorDim,
805  2,
806  BBlockTransferSrcScalarPerVector,
807  BBlockTransferDstScalarPerVector_BK1,
808  1,
809  1,
810  BThreadTransferSrcResetCoordinateAfterRun,
811  true,
812  NumGemmKPrefetchStage>(
813  b_grid_desc_bk0_n_bk1,
814  make_multi_index(0, n_block_data_idx_on_grid, 0),
815  b_element_op,
816  b_block_desc_bk0_n_bk1,
817  make_multi_index(0, 0, 0),
819 
820  // GEMM definition
821  // c_mtx += transpose(a_mtx) * b_mtx
822  // a_mtx[K0PerBlock, MPerBlock] is in LDS
823  // b_mtx[K0PerBlock, NPerBlock] is in LDS
824  // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
825  // register
826  // sanity check
827  constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
828  constexpr bool is_single_rate_mfma =
830  lcm_AK1_BK1 <= 4) ||
831  (is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
833  lcm_AK1_BK1 < 32))
834  ? true
835  : false;
836  constexpr auto is_scale_mfma = false;
837  constexpr index_t KPack = math::max(lcm_AK1_BK1,
838  MfmaSelector<ComputeTypeA,
839  MPerXdl,
840  NPerXdl,
841  ComputeTypeB,
842  is_single_rate_mfma,
843  is_scale_mfma>::selected_mfma.k_per_blk);
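 // Note (added commentary): KPack is at least lcm(AK1, BK1), so both LDS tiles can be read at
 // their configured vector widths, and at least the k_per_blk of the MFMA instruction selected
 // by MfmaSelector for these compute types and tile sizes.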
844 
846  BlockSize,
847  ComputeTypeA,
848  ComputeTypeB,
849  FloatGemmAcc,
850  decltype(a_block_desc_ak0_m_ak1),
851  decltype(b_block_desc_bk0_n_bk1),
852  MPerXdl,
853  NPerXdl,
854  MXdlPerWave,
855  NXdlPerWave,
856  KPack,
857  LoopSched>();
858 
859  auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
860 
861  // LDS allocation for A and B: be careful of alignment
862  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
863  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
864 
865  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
866  static_cast<ComputeTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
867 
868  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
869  static_cast<ComputeTypeB*>(p_shared) + a_block_space_size_aligned,
870  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
871 
872  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
873  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
874 
875  // gridwise GEMM pipeline
876  static_assert(std::is_default_constructible_v<GridwiseGemmPipe>);
877  const auto gridwise_gemm_pipeline = GridwiseGemmPipe{};
878 
879  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
880  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
881  KPerBlock);
882 
883  gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
884  a_block_desc_ak0_m_ak1,
885  a_blockwise_copy,
886  a_grid_buf,
887  a_block_buf,
888  a_block_slice_copy_step,
889  b_grid_desc_bk0_n_bk1,
890  b_block_desc_bk0_n_bk1,
891  b_blockwise_copy,
892  b_grid_buf,
893  b_block_buf,
894  b_block_slice_copy_step,
895  blockwise_gemm,
896  c_thread_buf,
897  num_k_block_main_loop);
898 
899  // shuffle C and write out
900  {
901  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
902  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
903  "wrong!");
904 
905  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
906  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
907 
908  // TODO: hacky, fix it!
909  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
910  blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
911 
912  // TODO: hacky, fix it!
913  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
914  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
915  blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
916 
917  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
918  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
919  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
920  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
921  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
922  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
923  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
924  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
925 
926  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
 927  GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
 928 
929  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
930  static_cast<FloatCShuffle*>(p_shared),
931  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
932 
933  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
934  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
935  make_tuple(
938  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
939  M1, // M1 = MWave
940  M2, // M2 * M3 * M4 = MPerXdl
941  M3,
942  M4)),
945  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
946  N1, // N1 = NWave
947  N2))), // N2 = NPerXdl
949  make_tuple(
951 
952  // calculate origin of thread output tensor on global memory
953  // blockwise GEMM c matrix starting index
954  const auto c_thread_mtx_on_block =
955  blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
956 
957  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
958  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
959 
960  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
962  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
965 
966  const auto m_thread_data_on_block_idx =
967  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
968  make_multi_index(m_thread_data_on_block));
969 
970  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
975 
976  const auto n_thread_data_on_block_idx =
977  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
978  make_multi_index(n_thread_data_on_block));
979 
980  // shuffle: threadwise copy C from VGPR to LDS
981  auto c_thread_copy_vgpr_to_lds =
983  FloatCShuffle,
984  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
985  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
987  Sequence<CShuffleMXdlPerWavePerShuffle,
988  CShuffleNXdlPerWavePerShuffle,
989  I1,
990  I1,
991  M2,
992  I1,
993  M4,
994  I1>,
996  7,
997  1,
999  1,
1000  true>{
1001  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1002  make_multi_index(0,
1003  0,
1004  m_thread_data_on_block_idx[I1],
1005  n_thread_data_on_block_idx[I1],
1006  m_thread_data_on_block_idx[I2],
1007  m_thread_data_on_block_idx[I3],
1008  m_thread_data_on_block_idx[I4],
1009  n_thread_data_on_block_idx[I2]),
1011 
1012  // shuffle: blockwise copy C from LDS to global
1013  auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
1014  ThisThreadBlock, // ThreadGroup
1015  CElementwiseOperation, // ElementwiseOperation,
1016  CGlobalMemoryDataOperation, // DstInMemOp,
1017  Sequence<1,
1018  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1019  1,
1020  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1021  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1022  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1023  FloatCShuffle, // typename SrcData,
1024  FloatC, // typename DstData,
1025  decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1026  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1027  Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
1028  3, // index_t VectorDim,
1029  CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
1030  true, // bool ThreadTransferSrcResetCoordinateAfterRun,
1031  false> // bool ThreadTransferDstResetCoordinateAfterRun>
1032  {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1033  make_multi_index(0, 0, 0, 0),
1034  c_grid_desc_mblock_mperblock_nblock_nperblock,
1035  make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
1036  c_element_op};
1037 
1038  // space filling curve for threadwise C in VGPR
1039  constexpr auto sfc_c_vgpr =
1042  Sequence<CShuffleMXdlPerWavePerShuffle,
1043  CShuffleNXdlPerWavePerShuffle,
1044  1,
1045  1,
1046  M2,
1047  1,
1048  M4,
1049  1>>{};
1050 
1051  // space filling curve for shuffled blockwise C in global mem
1052  constexpr auto sfc_c_global =
1055  Sequence<1,
1056  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1057  1,
1058  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1059 
1060  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1061 
1062  static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
1063 
1064  static_for<0, num_access, 1>{}([&](auto access_id) {
1065  // make sure it's safe to write to LDS
1066  block_sync_lds();
1067 
1068  // each thread write its data from VGPR to LDS
1069  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1070  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1071  c_thread_buf,
1072  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1073  c_shuffle_block_buf);
1074 
1075  // make sure it's safe to read from LDS
1076  block_sync_lds();
1077 
1078  // each block copy its data from LDS to global
1079  c_shuffle_block_copy_lds_to_global.Run(
1080  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1081  c_shuffle_block_buf,
1082  c_grid_desc_mblock_mperblock_nblock_nperblock,
1083  c_grid_buf);
1084 
1085  if constexpr(access_id < num_access - 1)
1086  {
1087  constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1088 
1089  // move on C
1090  c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1091  c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1092  }
1093  });
1094  }
1095  }
1096 };
1097 
1098 } // namespace ck