Composable Kernel source listing: include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp

1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 #include "ck/utility/env.hpp"
19 
20 namespace ck {
21 
22 template <typename GridwiseGemm,
23  bool HasMainKBlockLoop,
24  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
25  typename Block2CTileMap,
26  typename AElementwiseOperation,
27  typename BElementwiseOperation,
28  typename CElementwiseOperation>
29 __global__ void
30 #if CK_USE_LAUNCH_BOUNDS
31  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
32 #endif
33  kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg,
34  const Block2CTileMap& b2c_map,
35  const AElementwiseOperation a_element_op,
36  const BElementwiseOperation b_element_op,
37  const CElementwiseOperation c_element_op)
38 {
39 #if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__))
40  constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
41 
42  __shared__ uint8_t p_shared[shared_size];
43 
44  GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
45  karg, static_cast<void*>(p_shared), b2c_map, a_element_op, b_element_op, c_element_op);
46 #else
47  ignore = karg;
48  ignore = b2c_map;
49  ignore = a_element_op;
50  ignore = b_element_op;
51  ignore = c_element_op;
52 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__))
53 }
54 
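// Illustrative host-side launch sketch (not part of this header; `Gemm`, `a_op`, `b_op`,
// `c_op` and `stream` are placeholders supplied by the calling device operation). `Gemm`
// stands for a full specialization of the gridwise GEMM struct below, and `karg` is
// assumed to have passed Gemm::CheckValidity(). The grid comes from
// Gemm::CalculateGridSize(karg) (x over N tiles, y over M tiles, z = k_batch) and each
// workgroup uses BlockSize threads:
//
//   const auto [gdx, gdy, gdz] = Gemm::CalculateGridSize(karg);
//   const auto b2c_map         = Gemm::MakeDefaultBlock2CTileMap();
//
//   kernel_gemm_xdlops_v2r4r2_simplified<Gemm,
//                                        /* HasMainKBlockLoop = */ true,
//                                        InMemoryDataOperationEnum::Set,
//                                        decltype(b2c_map),
//                                        decltype(a_op),
//                                        decltype(b_op),
//                                        decltype(c_op)>
//       <<<dim3(gdx, gdy, gdz), dim3(BlockSize), 0, stream>>>(karg, b2c_map, a_op, b_op, c_op);
//
// HasMainKBlockLoop is normally taken from Gemm::CalculateHasMainK0BlockLoop(karg.K0Padded),
// and a split-K run (k_batch > 1) typically switches the memory operation from Set to
// AtomicAdd so that the partial results of the K batches accumulate in C.
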
55 template <index_t BlockSize,
56  typename FloatA,
57  typename FloatB,
58  typename FloatAcc,
59  typename FloatC,
60  typename ALayout,
61  typename BLayout,
62  typename CLayout,
63  typename AElementwiseOperation,
64  typename BElementwiseOperation,
65  typename CElementwiseOperation,
66  tensor_operation::device::GemmSpecialization GemmSpec,
67  index_t NumGemmKPrefetchStage,
68  index_t MPerBlock,
69  index_t NPerBlock,
70  index_t K0PerBlock,
71  index_t MPerXDL,
72  index_t NPerXDL,
73  index_t K1Value,
74  index_t MRepeat,
75  index_t NRepeat,
76  typename ABlockTransferThreadClusterLengths_K0_M_K1,
77  typename ABlockTransferThreadClusterArrangeOrder,
78  typename ABlockTransferSrcAccessOrder,
79  index_t ABlockTransferSrcVectorDim,
80  index_t ABlockTransferSrcScalarPerVector,
81  index_t ABlockTransferDstScalarPerVector_K1,
82  bool AThreadTransferSrcResetCoordinateAfterRun,
83  bool ABlockLdsExtraM,
84  typename BBlockTransferThreadClusterLengths_K0_N_K1,
85  typename BBlockTransferThreadClusterArrangeOrder,
86  typename BBlockTransferSrcAccessOrder,
87  index_t BBlockTransferSrcVectorDim,
88  index_t BBlockTransferSrcScalarPerVector,
89  index_t BBlockTransferDstScalarPerVector_K1,
90  bool BThreadTransferSrcResetCoordinateAfterRun,
91  bool BBlockLdsExtraN,
92  index_t CShuffleMRepeatPerShuffle,
93  index_t CShuffleNRepeatPerShuffle,
94  index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
95  typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
96  LoopScheduler LoopSched = make_default_loop_scheduler(),
97  PipelineVersion PipelineVer = PipelineVersion::v1,
98  typename ComputeTypeA = FloatC,
99  typename ComputeTypeB = ComputeTypeA,
100  typename LDSTypeA = ComputeTypeA,
101  typename LDSTypeB = ComputeTypeB>
103 {
104  static constexpr auto I0 = Number<0>{};
105  static constexpr auto I1 = Number<1>{};
106  static constexpr auto I2 = Number<2>{};
107  static constexpr auto I3 = Number<3>{};
108  static constexpr auto I4 = Number<4>{};
109  static constexpr auto I5 = Number<5>{};
110  static constexpr auto I6 = Number<6>{};
111  static constexpr auto I7 = Number<7>{};
112 
113  // K1 should be Number<...>
114  static constexpr auto K1 = Number<K1Value>{};
115  static constexpr auto M01 = 1;
116  static constexpr auto N01 = 1;
117 
118  static constexpr auto gemm_padder =
120  MPerBlock, NPerBlock, K1* K0PerBlock};
121 
122  using ThisThreadBlock = ThisThreadBlock<BlockSize>;
123 
124  using GridwiseGemmPipe = remove_cvref_t<
125  decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
126 
128  {
129  const FloatA* p_a_grid;
130  const FloatB* p_b_grid;
131  FloatC* p_c_grid;
132  index_t M;
133  index_t N;
134  index_t K;
135  index_t StrideA;
136  index_t StrideB;
137  index_t StrideC;
138  index_t MPadded;
139  index_t NPadded;
140  index_t KPadded;
141  index_t K0Padded;
142  index_t k_batch;
143 
144  Argument(const FloatA* p_a_grid_,
145  const FloatB* p_b_grid_,
146  FloatC* p_c_grid_,
147  index_t M_,
148  index_t N_,
149  index_t K_,
150  index_t StrideA_,
151  index_t StrideB_,
152  index_t StrideC_,
153  index_t MPadded_,
154  index_t NPadded_,
155  index_t KPadded_,
156  index_t K0Padded_,
157  index_t k_batch_)
158  : p_a_grid(p_a_grid_),
159  p_b_grid(p_b_grid_),
160  p_c_grid(p_c_grid_),
161  M(M_),
162  N(N_),
163  K(K_),
164  StrideA(StrideA_),
165  StrideB(StrideB_),
166  StrideC(StrideC_),
167  MPadded(MPadded_),
168  NPadded(NPadded_),
169  KPadded(KPadded_),
170  K0Padded(K0Padded_),
171  k_batch(k_batch_)
172  {
173  }
174 
175  void Print() const
176  {
177  std::cout << "arg {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
178  << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
179  << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
180  << "KP:" << KPadded << ", " << "K0Padded:" << K0Padded << ", "
181  << "KB:" << k_batch << "}" << std::endl;
182  }
183  };
184 
185  __host__ __device__ static auto CalculateGridSize(const Argument& karg)
186  {
187  return std::make_tuple(math::integer_divide_ceil(karg.N, NPerBlock),
188  math::integer_divide_ceil(karg.M, MPerBlock),
189  karg.k_batch);
190  }
191 
192  // prefer this to be called on host
193  __host__ __device__ static auto CalculateMPadded(index_t M)
194  {
195  return math::integer_least_multiple(M, MPerBlock);
196  }
197 
198  __host__ __device__ static auto CalculateNPadded(index_t N)
199  {
200  return math::integer_least_multiple(N, NPerBlock);
201  }
202 
203  __host__ __device__ static auto CalculateK0Padded(index_t K, index_t K_Batch = 1)
204  {
205  // padded K = k_batch * k0 * K0PerBlock * K1, so round K up to a multiple of K_t = K_Batch * K0PerBlock * K1
206  auto K_t = K_Batch * K0PerBlock * K1;
207  return (K + K_t - 1) / K_t * K0PerBlock;
208  }
209 
210  __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
211  {
212  auto K0Padded = CalculateK0Padded(K, K_Batch);
213  return K_Batch * K0Padded * K1;
214  }
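// Worked example (hypothetical tile configuration, for illustration only): with
// K0PerBlock = 8, K1 = 8, K = 1000 and K_Batch = 2, the rounding unit is
// K_t = 2 * 8 * 8 = 128, so CalculateK0Padded returns ceil(1000 / 128) * 8 = 64 and
// CalculateKPadded returns 2 * 64 * 8 = 1024; each of the two K batches then covers a
// contiguous K0Padded * K1 = 512 slice of the padded K axis.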
215 
216  __host__ __device__ static auto MakeAGridDescriptor_KBatch_K0_M_K1(index_t M,
217  index_t MPad,
218  index_t K,
219  index_t StrideA,
220  index_t KBatch,
221  index_t K0Padded,
222  index_t KPad)
223  {
224  const auto a_grid_desc_m_k = [&]() {
226  {
227  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
228  }
230  {
231  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
232  }
233  }();
234 
239  {
240 
241  const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
242  a_grid_desc_m_k,
246 
247  // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
249  a_grid_desc_m_kpad,
250  make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
251  make_right_pad_transform(M, MPad - M)),
254  }
255  else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
257  {
258  // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
260  a_grid_desc_m_k,
261  make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
262  make_right_pad_transform(M, MPad - M)),
265  }
266  else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding)
267  {
268  const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
269  a_grid_desc_m_k,
273 
275  a_grid_desc_m_kpad,
276  make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
280  }
281  else
282  {
284  a_grid_desc_m_k,
285  make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
289  }
290  }
291 
292  __host__ __device__ static auto MakeBGridDescriptor_KBatch_K0_N_K1(index_t K,
293  index_t NPad,
294  index_t N,
295  index_t StrideB,
296  index_t KBatch,
297  index_t K0Padded,
298  index_t KPad)
299  {
300  const auto b_grid_desc_k_n = [&]() {
302  {
303  return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
304  }
306  {
307  return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
308  }
309  }();
310 
315  {
316 
317  const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
318  b_grid_desc_k_n,
322 
323  // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
325  b_grid_desc_kpad_n,
326  make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
327  make_right_pad_transform(N, NPad - N)),
330  }
331  else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
333  {
334  // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
336  b_grid_desc_k_n,
337  make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
338  make_right_pad_transform(N, NPad - N)),
341  }
342  else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding)
343  {
344  const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
345  b_grid_desc_k_n,
349 
351  b_grid_desc_kpad_n,
352  make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
356  }
357  else
358  {
360  b_grid_desc_k_n,
361  make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
365  }
366  }
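// The A and B grid descriptors built above both unmerge the (optionally right-padded)
// K axis into (KBatch, K0Padded, K1), so each k_batch slice of the grid reads a disjoint
// K range; M and N are right-padded to MPad / NPad only when the selected
// GemmSpecialization requires it.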
367 
368  __host__ __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
369  {
370  const auto c_grid_desc_m_n = [&]() {
372  {
373  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
374  }
376  {
377  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
378  }
379  }();
380 
381  return gemm_padder.PadCDescriptor_M_N(c_grid_desc_m_n);
382  }
383 
384  __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
385  {
386  constexpr auto max_lds_align = K1;
387 
388  // A matrix in LDS memory, dst of blockwise copy
389  constexpr auto a_k0_m_k1_block_desc = [&]() {
390  if constexpr(ABlockLdsExtraM)
391  {
395  }
396  else
397  {
399  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
400  }
401  }();
402 
403  // B matrix in LDS memory, dst of blockwise copy
404  constexpr auto b_k0_n_k1_block_desc = [&]() {
405  if constexpr(BBlockLdsExtraN)
406  {
410  }
411  else
412  {
414  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
415  }
416  }();
417 
418  // LDS allocation for A and B: be careful of alignment
419  constexpr auto a_block_space_size =
420  math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
421 
422  constexpr auto b_block_space_size =
423  math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
424 
425  constexpr auto c_block_size =
426  GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize();
427 
428  return math::max(a_block_space_size * sizeof(LDSTypeA) +
429  b_block_space_size * sizeof(LDSTypeB),
430  c_block_size * sizeof(FloatC));
431  }
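// Rough sizing example (hypothetical configuration, for illustration only): with
// K0PerBlock = 8, MPerBlock = NPerBlock = 128, K1 = 8 and 16-bit LDS types, the A and B
// tiles each hold 8 * 128 * 8 = 8192 elements (16 KiB each, plus any ExtraM/ExtraN
// padding), so the A + B term is about 32 KiB; the returned byte count is the maximum of
// that and the C-shuffle staging tile computed above.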
432 
433  __host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
434  {
435  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
439  {
440  if(!(karg.M % MPerBlock == 0))
441  {
442  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
443  {
444  std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
445  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
446  << std::endl;
447  }
448  return false;
449  }
450  }
451 
452  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
456  {
457  if(!(karg.N % NPerBlock == 0))
458  {
459  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
460  {
461  std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
462  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
463  << std::endl;
464  }
465  return false;
466  }
467  }
468 
469  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
473  {
474 
475  auto K_t = karg.k_batch * K0PerBlock * K1;
476  if(!(karg.K % K_t == 0))
477  {
478  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
479  {
480  std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
481  << karg.K << " " << __FILE__ << ":" << __LINE__
482  << ", in function: " << __func__ << std::endl;
483  }
484  return false;
485  }
486  }
487 
489  {
490  if(karg.K % ABlockTransferSrcScalarPerVector != 0)
491  {
492  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
493  {
494  std::cout << "Arg K (" << karg.K
495  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
496  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
497  << __LINE__ << ", in function: " << __func__ << std::endl;
498  }
499  return false;
500  }
501  }
502  else
503  {
504  if(karg.M % ABlockTransferSrcScalarPerVector != 0)
505  {
506  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
507  {
508  std::cout << "Arg M (" << karg.M
509  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
510  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
511  << __LINE__ << ", in function: " << __func__ << std::endl;
512  }
513  return false;
514  }
515  }
516 
518  {
519  if(karg.N % BBlockTransferSrcScalarPerVector != 0)
520  {
521  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
522  {
523  std::cout << "Arg N (" << karg.N
524  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
525  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
526  << __LINE__ << ", in function: " << __func__ << std::endl;
527  }
528  return false;
529  }
530  }
531  else
532  {
533  if(karg.K % BBlockTransferSrcScalarPerVector != 0)
534  {
535  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
536  {
537  std::cout << "Arg K (" << karg.K
538  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
539  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
540  << __LINE__ << ", in function: " << __func__ << std::endl;
541  }
542  return false;
543  }
544  }
545 
547  {
548  if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
549  {
550  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
551  {
552  std::cout << "Arg N (" << karg.N
553  << ") value is not a multiple of "
554  "CBlockTransferScalarPerVector_NWaveNPerXDL ("
555  << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__
556  << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
557  }
558  return false;
559  }
560  }
561  else
562  {
563  if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
564  {
565  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
566  {
567  std::cout << "Arg M (" << karg.M
568  << ") value is not a multiple of "
569  "CBlockTransferScalarPerVector_NWaveNPerXDL ("
570  << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__
571  << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
572  }
573  return false;
574  }
575  }
576 
577  const auto num_k_loop = karg.K0Padded / K0PerBlock;
578  if(!GridwiseGemmPipe::IsSupported(num_k_loop))
579  {
580  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
581  {
582  std::cout << "The number of k loops (" << num_k_loop
583  << ") value is not supported by GridwiseGemm Pipeline."
584  << " K0Padded: " << karg.K0Padded << ", K0PerBlock: " << K0PerBlock << " "
585  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
586  << std::endl;
587  }
588  return false;
589  }
590 
591  return true;
592  }
593 
594  __host__ __device__ static auto GetKPad(index_t K, index_t KBatch)
595  {
596  const index_t K0Padded =
597  math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
598  const index_t KPad = KBatch * K0Padded * K1;
599  return KPad;
600  }
601 
602  __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0Padded)
603  {
604  const index_t num_loop = K0Padded / K0PerBlock;
605  return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
606  }
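// Host-side preparation sketch (illustrative; `Gemm` again stands for a full
// specialization of this struct, and the pointers, sizes and strides are assumed to come
// from the calling device operation):
//
//   const index_t MPadded  = Gemm::CalculateMPadded(M);
//   const index_t NPadded  = Gemm::CalculateNPadded(N);
//   const index_t K0Padded = Gemm::CalculateK0Padded(K, k_batch);
//   const index_t KPadded  = Gemm::CalculateKPadded(K, k_batch);
//
//   typename Gemm::Argument karg{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC,
//                                MPadded, NPadded, KPadded, K0Padded, k_batch};
//
//   if(!Gemm::CheckValidity(karg)) { /* fall back to another instance */ }
//   const bool has_main_k_loop = Gemm::CalculateHasMainK0BlockLoop(karg.K0Padded);
//
// has_main_k_loop then selects the HasMainKBlockLoop template argument used when
// kernel_gemm_xdlops_v2r4r2_simplified is instantiated.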
607 
608  template <typename CGridDesc>
609  __host__ __device__ static constexpr auto
610  MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc& c_m_n_grid_desc)
611  {
612  const auto M = c_m_n_grid_desc.GetLength(I0);
613  const auto N = c_m_n_grid_desc.GetLength(I1);
614 
615  const auto MBlock = M / MPerBlock;
616  const auto NBlock = N / NPerBlock;
617 
619  c_m_n_grid_desc,
624  }
625 
626  // return block_id to C matrix tile idx (m0, n0) mapping
627  template <typename CGridDesc>
628  __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
629  const CGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch)
630  {
632  c_m_n_grid_desc, 8, KBatch);
633  }
634 
635  __host__ __device__ static constexpr auto
637  {
638  constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL);
639  constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL);
640 
642  make_tuple(I1,
644  I1,
646  }
647 
648  // return block_id to C matrix tile idx (m0, n0, k_split) mapping
649  __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap()
650  {
652  }
653 
654  using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(1, 1, 1))>;
655  using DefaultBlock2CTileMap = remove_cvref_t<decltype(MakeDefaultBlock2CTileMap())>;
656 
657  template <bool HasMainKBlockLoop,
658  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
659  typename Block2CTileMap>
660  __device__ static void Run(const Argument& karg,
661  void* __restrict__ p_shared_block,
662  const Block2CTileMap& block_2_ctile_map,
663  const AElementwiseOperation a_element_op = AElementwiseOperation{},
664  const BElementwiseOperation b_element_op = BElementwiseOperation{},
665  const CElementwiseOperation c_element_op = CElementwiseOperation{})
666  {
667  const FloatA* p_a_grid = karg.p_a_grid;
668  const FloatB* p_b_grid = karg.p_b_grid;
669  FloatC* p_c_grid = karg.p_c_grid;
670  const auto a_b_k0_m_k1_grid_desc = MakeAGridDescriptor_KBatch_K0_M_K1(
671  karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0Padded, karg.KPadded);
672  const auto b_b_k0_n_k1_grid_desc = MakeBGridDescriptor_KBatch_K0_N_K1(
673  karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0Padded, karg.KPadded);
674  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
675 
676  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
677  MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
678 
679  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
680  p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize());
681  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
682  p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize());
683  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
684  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
685 
686  // divide block work by [KBatch, M, N]
687  const auto block_work_idx =
688  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
689 
690  if(!block_2_ctile_map.ValidCTileIndex(
691  block_work_idx,
692  make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
693  c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
694  {
695  return;
696  }
697 
698  const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
699  const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I2]);
700  const index_t k_batch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
701 
702  // HACK: this forces m/n_block_data_idx_on_grid into SGPR
703  const index_t m_block_data_idx_on_grid =
704  __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
705 
706  const index_t n_block_data_idx_on_grid =
707  __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
708 
709  // lds max alignment
710  constexpr auto max_lds_align = K1;
711 
712  // A matrix in LDS memory, dst of blockwise copy
713  constexpr auto a_k0_m_k1_block_desc = [&]() {
714  if constexpr(ABlockLdsExtraM)
715  {
717  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
718  make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
719  }
720  else
721  {
723  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
724  }
725  }();
726 
727  constexpr auto a_b_k0_m_k1_block_desc = [&]() {
728  if constexpr(ABlockLdsExtraM)
729  {
731  make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
732  make_tuple(Number<K0PerBlock>{} * Number<MPerBlock + 1>{} * K1,
733  Number<MPerBlock + 1>{} * K1,
734  K1,
735  I1));
736  }
737  else
738  {
740  make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
741  max_lds_align);
742  }
743  }();
744  // B matrix in LDS memory, dst of blockwise copy
745  constexpr auto b_k0_n_k1_block_desc = [&]() {
746  if constexpr(BBlockLdsExtraN)
747  {
749  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
750  make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
751  }
752  else
753  {
755  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
756  }
757  }();
758 
759  constexpr auto b_b_k0_n_k1_block_desc = [&]() {
760  if constexpr(BBlockLdsExtraN)
761  {
763  make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
764  make_tuple(Number<K0PerBlock>{} * Number<NPerBlock + 1>{} * K1,
765  Number<NPerBlock + 1>{} * K1,
766  K1,
767  I1));
768  }
769  else
770  {
772  make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
773  max_lds_align);
774  }
775  }();
776  // A matrix blockwise copy
777  auto a_blockwise_copy =
778  ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
779  AElementwiseOperation,
782  Sequence<1, K0PerBlock, MPerBlock, K1>,
783  ABlockTransferThreadClusterLengths_K0_M_K1,
784  ABlockTransferThreadClusterArrangeOrder,
785  FloatA,
786  LDSTypeA,
787  decltype(a_b_k0_m_k1_grid_desc),
788  decltype(a_b_k0_m_k1_block_desc),
789  ABlockTransferSrcAccessOrder,
790  Sequence<0, 2, 1, 3>,
791  ABlockTransferSrcVectorDim,
792  3,
793  ABlockTransferSrcScalarPerVector,
794  ABlockTransferDstScalarPerVector_K1,
795  1,
796  1,
797  AThreadTransferSrcResetCoordinateAfterRun,
798  true>(
799  a_b_k0_m_k1_grid_desc,
800  make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0),
801  a_element_op,
802  a_b_k0_m_k1_block_desc,
803  make_multi_index(0, 0, 0, 0),
805 
806  // B matrix blockwise copy
807  auto b_blockwise_copy =
808  ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
809  BElementwiseOperation,
812  Sequence<1, K0PerBlock, NPerBlock, K1>,
813  BBlockTransferThreadClusterLengths_K0_N_K1,
814  BBlockTransferThreadClusterArrangeOrder,
815  FloatB,
816  LDSTypeB,
817  decltype(b_b_k0_n_k1_grid_desc),
818  decltype(b_b_k0_n_k1_block_desc),
819  BBlockTransferSrcAccessOrder,
820  Sequence<0, 2, 1, 3>,
821  BBlockTransferSrcVectorDim,
822  3,
823  BBlockTransferSrcScalarPerVector,
824  BBlockTransferDstScalarPerVector_K1,
825  1,
826  1,
827  BThreadTransferSrcResetCoordinateAfterRun,
828  true>(
829  b_b_k0_n_k1_grid_desc,
830  make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0),
831  b_element_op,
832  b_b_k0_n_k1_block_desc,
833  make_multi_index(0, 0, 0, 0),
835 
836  // GEMM definition
837  // c_mtx += transpose(a_mtx) * b_mtx
838  // a_mtx[K0PerBlock, MPerBlock] is in LDS
839  // b_mtx[K0PerBlock, NPerBlock] is in LDS
840  // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
841  // register
842  // sanity check
843 
844  auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
845  BlockSize,
846  LDSTypeA,
847  LDSTypeB,
848  FloatAcc,
849  decltype(a_k0_m_k1_block_desc),
850  decltype(b_k0_n_k1_block_desc),
851  MPerXDL,
852  NPerXDL,
853  MRepeat,
854  NRepeat,
855  K1,
856  LoopSched,
857  ComputeTypeA,
858  ComputeTypeB>();
859 
860  auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
861 
862  // LDS allocation for A and B: be careful of alignment
863  constexpr auto a_block_space_size =
864  math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
865 
866  auto p_a_block = reinterpret_cast<LDSTypeA*>(p_shared_block);
867  auto p_b_block = reinterpret_cast<LDSTypeB*>(p_a_block + a_block_space_size);
868 
869  constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
870  constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
871 
872  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
873  p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
874  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
875  p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());
876 
877  // gridwise GEMM pipeline
878  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
879  (a_b_k0_m_k1_grid_desc.GetLength(I1) * a_b_k0_m_k1_grid_desc.GetLength(I3)) /
880  (K0PerBlock * K1));
881 
882  const auto gridwise_gemm_pipeline = GridwiseGemmPipe{};
883 
884  gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_b_k0_m_k1_grid_desc,
885  a_b_k0_m_k1_block_desc,
886  a_blockwise_copy,
887  a_grid_buf,
888  a_block_buf,
889  a_block_slice_copy_step,
890  b_b_k0_n_k1_grid_desc,
891  b_b_k0_n_k1_block_desc,
892  b_blockwise_copy,
893  b_grid_buf,
894  b_block_buf,
895  b_block_slice_copy_step,
896  blockwise_gemm,
897  c_thread_buf,
898  num_k_block_main_loop);
899 
900  // output: register to global memory
901  {
902  constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL);
903  constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL);
904 
905  constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
906  blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
907 
908  constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
909  blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
910 
911  constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
912  constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
913  constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
914  constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
915  constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
916  constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
917  constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
918  constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);
919 
920  constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock =
921  GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
922 
923  auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
924  static_cast<FloatC*>(p_shared_block),
925  c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
926 
927  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
928  c_block_desc_mblock_mperblock_nblock_nperblock,
929  make_tuple(
930  make_freeze_transform(I0), // freeze mblock
931  make_unmerge_transform(make_tuple(CShuffleMRepeatPerShuffle,
932  M1,
933  M2,
934  M3,
935  M4)), // M1 = MWave, M2 * M3 * M4 = MPerXDL
936  make_freeze_transform(I0), // freeze nblock
937  make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle,
938  N1,
939  N2))), // N1 = NWave, N2 = NPerXDL
940  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
941  make_tuple(
942  Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
943 
944  // calculate origin of thread output tensor on global memory
945  // blockwise GEMM c matrix starting index
946  const auto c_thread_mtx_on_block =
947  blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
948 
949  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
950  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
951 
952  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
954  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
955  make_tuple(Sequence<0, 1, 2, 3, 4>{}),
956  make_tuple(Sequence<0>{}));
957 
958  const auto m_thread_data_on_block_idx =
959  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
960  make_multi_index(m_thread_data_on_block));
961 
962  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
965  make_tuple(Sequence<0, 1, 2>{}),
966  make_tuple(Sequence<0>{}));
967 
968  const auto n_thread_data_on_block_idx =
969  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
970  make_multi_index(n_thread_data_on_block));
971 
972  // VGPR to LDS
973  auto c_thread_copy_vgpr_to_lds =
974  ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
975  FloatC,
976  decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
977  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
979  Sequence<CShuffleMRepeatPerShuffle,
980  CShuffleNRepeatPerShuffle,
981  I1,
982  I1,
983  M2,
984  I1,
985  M4,
986  I1>,
987  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
988  7,
989  1,
991  1,
992  true>{
993  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
995  0,
996  m_thread_data_on_block_idx[I1],
997  n_thread_data_on_block_idx[I1],
998  m_thread_data_on_block_idx[I2],
999  m_thread_data_on_block_idx[I3],
1000  m_thread_data_on_block_idx[I4],
1001  n_thread_data_on_block_idx[I2]),
1003 
1004  // LDS to global
1005  auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
1006  ThisThreadBlock, // index_t BlockSize,
1007  CElementwiseOperation, // ElementwiseOperation,
1008  CGlobalMemoryDataOperation, // DstInMemOp,
1009  Sequence<1,
1010  CShuffleMRepeatPerShuffle * MWave * MPerXDL,
1011  1,
1012  CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths,
1013  CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1014  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1015  FloatC, // typename SrcData,
1016  FloatC, // typename DstData,
1017  decltype(c_block_desc_mblock_mperblock_nblock_nperblock),
1018  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1019  Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
1020  3, // index_t VectorDim,
1021  CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
1022  true, // bool ThreadTransferSrcResetCoordinateAfterRun,
1023  false> // bool ThreadTransferDstResetCoordinateAfterRun
1024  {c_block_desc_mblock_mperblock_nblock_nperblock,
1025  make_multi_index(0, 0, 0, 0),
1026  c_grid_desc_mblock_mperblock_nblock_nperblock,
1027  make_multi_index(block_m_id, 0, block_n_id, 0),
1028  c_element_op};
1029 
1030  constexpr auto mxdlperwave_forward_step =
1031  make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0);
1032  constexpr auto nxdlperwave_forward_step =
1033  make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL);
1034  constexpr auto nxdlperwave_backward_step =
1035  make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL);
1036 
1037  static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) {
1038  constexpr auto mxdlperwave = mxdlperwave_iter;
1039 
1040  static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) {
1041  constexpr bool nxdlperwave_forward_sweep =
1042  (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0);
1043 
1044  constexpr index_t nxdlperwave_value =
1045  nxdlperwave_forward_sweep
1046  ? nxdlperwave_iter
1047  : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle);
1048 
1049  constexpr auto nxdlperwave = Number<nxdlperwave_value>{};
1050 
1051  // make sure it's safe to do ds_write
1052  block_sync_lds();
1053 
1054  // VGPR to LDS
1055  c_thread_copy_vgpr_to_lds.Run(
1056  c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
1057  make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
1058  c_thread_buf,
1059  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1060  c_block_buf);
1061 
1062  // make sure it's safe to do ds_read
1063  block_sync_lds();
1064 
1065  // LDS to global
1066  c_block_copy_lds_to_global.Run(c_block_desc_mblock_mperblock_nblock_nperblock,
1067  c_block_buf,
1068  c_grid_desc_mblock_mperblock_nblock_nperblock,
1069  c_grid_buf);
1070 
1071  // move on nxdlperwave dimension
1072  if constexpr(nxdlperwave_forward_sweep &&
1073  (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle))
1074  {
1075  c_block_copy_lds_to_global.MoveDstSliceWindow(
1076  c_grid_desc_mblock_mperblock_nblock_nperblock,
1077  nxdlperwave_forward_step);
1078  }
1079  else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
1080  {
1081  c_block_copy_lds_to_global.MoveDstSliceWindow(
1082  c_grid_desc_mblock_mperblock_nblock_nperblock,
1083  nxdlperwave_backward_step);
1084  }
1085  });
1086 
1087  // move on mxdlperwave dimension
1088  if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle)
1089  {
1090  c_block_copy_lds_to_global.MoveDstSliceWindow(
1091  c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step);
1092  }
1093  });
1094  }
1095  }
1096 
1097  static std::string GetTypeString()
1098  {
1099  auto str = std::stringstream();
1100 
1101  // clang-format off
1102  str << "GemmXdlSplitKCShuffle_"
1103  << getGemmSpecializationString(GemmSpec) << "_"
1104  << std::string(ALayout::name)[0]
1105  << std::string(BLayout::name)[0]
1106  << std::string(CLayout::name)[0]
1107  << "_"
1108  << "B" << BlockSize << "_"
1109  << "Vec" << ABlockTransferSrcScalarPerVector << "x"
1110  << BBlockTransferSrcScalarPerVector << "x"
1111  << CBlockTransferScalarPerVector_NWaveNPerXDL << "_"
1112  << MPerBlock << "x"
1113  << NPerBlock << "x"
1114  << K0PerBlock << "x"
1115  << K1 ;
1116  // clang-format on
1117 
1118  return str.str();
1119  }
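// For example (hypothetical instance), a row-major A / column-major B / row-major C
// configuration with BlockSize 256, source vector widths 8/8/8 and a 128x128x4x8 tile
// would be reported roughly as "GemmXdlSplitKCShuffle_Default_RCR_B256_Vec8x8x8_128x128x4x8".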
1120 };
1121 
1122 } // namespace ck