Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp

gridwise_gemm_xdlops_v2r3.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
17 
18 namespace ck {
19 
20 template <typename GridwiseGemm,
21  typename FloatAB,
22  typename FloatC,
23  typename AGridDesc_K0_M_K1,
24  typename BGridDesc_K0_N_K1,
25  typename CGridDesc_M_N,
26  bool HasMainKBlockLoop>
27 __global__ void
28 #if CK_USE_LAUNCH_BOUNDS
29  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
30 #endif
31 #if CK_USE_WAVES_PER_EU
32  __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU)))
33 #endif
34  kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
35  const FloatAB* __restrict__ p_b_grid,
36  FloatC* __restrict__ p_c_grid,
37  const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
38  const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
39  const CGridDesc_M_N c_grid_desc_m_n)
40 {
41 #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
42  defined(__gfx12__)
43  if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
44  {
45  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
46 
47  GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
48  p_b_grid,
49  p_c_grid,
50  p_shared,
51  a_grid_desc_k0_m_k1,
52  b_grid_desc_k0_n_k1,
53  c_grid_desc_m_n);
54  }
55 #else
56  ignore = p_a_grid;
57  ignore = p_b_grid;
58  ignore = p_c_grid;
59  ignore = a_grid_desc_k0_m_k1;
60  ignore = b_grid_desc_k0_n_k1;
61  ignore = c_grid_desc_m_n;
62 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || defined(__gfx12__))
63 }
64 
65 template <typename GridwiseGemm, bool HasMainKBlockLoop>
66 __global__ void
67 #if CK_USE_LAUNCH_BOUNDS
68  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
69 #endif
70 #if CK_USE_WAVES_PER_EU
71  __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU)))
72 #endif
73  kernel_gemm_xdlops_v2r3(const typename GridwiseGemm::Argument karg)
74 {
75 #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
76  defined(__gfx12__)
77  if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
78  {
79  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
80 
81  const auto a_grid_desc_k0_m_k1 =
82  amd_wave_read_first_lane(GridwiseGemm::MakeAGridDescriptor_K0_M_K1(
83  karg.M, karg.MPadded, karg.K, karg.K0, karg.StrideA));
84  const auto b_grid_desc_k0_n_k1 =
85  amd_wave_read_first_lane(GridwiseGemm::MakeBGridDescriptor_K0_N_K1(
86  karg.K, karg.N, karg.NPadded, karg.K0, karg.StrideB));
87  const auto c_grid_desc_m_n = amd_wave_read_first_lane(GridwiseGemm::MakeCGridDescriptor_M_N(
88  karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC));
89 
90  GridwiseGemm::template Run<HasMainKBlockLoop>(karg.p_a_grid,
91  karg.p_b_grid,
92  karg.p_c_grid,
93  p_shared,
94  a_grid_desc_k0_m_k1,
95  b_grid_desc_k0_n_k1,
96  c_grid_desc_m_n);
97  }
98 #else
99  ignore = karg;
100 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || defined(__gfx12__))
101 }
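// Note: the two kernel_gemm_xdlops_v2r3 overloads above differ only in where the grid
// descriptors come from. The first receives pre-built A/B/C descriptors from the host;
// the second rebuilds them on the device from a GridwiseGemm::Argument and passes the
// results through amd_wave_read_first_lane so they are uniform across the wave (kept in
// scalar registers). Both dispatch to GridwiseGemm::Run with the same shared-memory
// buffer sized by GetSharedMemoryNumberOfByte().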
102 
103 template <index_t BlockSize,
104  typename FloatAB,
105  typename FloatAcc,
106  typename FloatC,
107  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
108  typename AElementwiseOperation,
109  typename BElementwiseOperation,
110  typename CElementwiseOperation,
111  index_t MPerBlock,
112  index_t NPerBlock,
113  index_t K0PerBlock,
114  index_t MPerXdl,
115  index_t NPerXdl,
116  index_t K1Value,
117  index_t MXdlPerWave,
118  index_t NXdlPerWave,
119  typename ABlockTransferThreadClusterLengths_K0_M_K1,
120  typename ABlockTransferThreadClusterArrangeOrder,
121  typename ABlockTransferSrcAccessOrder,
122  index_t ABlockTransferSrcVectorDim,
123  index_t ABlockTransferSrcScalarPerVector,
124  index_t ABlockTransferDstScalarPerVector_K1,
125  bool AThreadTransferSrcResetCoordinateAfterRun,
126  bool ABlockLdsExtraM,
127  typename BBlockTransferThreadClusterLengths_K0_N_K1,
128  typename BBlockTransferThreadClusterArrangeOrder,
129  typename BBlockTransferSrcAccessOrder,
130  index_t BBlockTransferSrcVectorDim,
131  index_t BBlockTransferSrcScalarPerVector,
132  index_t BBlockTransferDstScalarPerVector_K1,
133  bool BThreadTransferSrcResetCoordinateAfterRun,
134  bool BBlockLdsExtraN,
135  typename CThreadTransferSrcDstAccessOrder,
136  index_t CThreadTransferSrcDstVectorDim,
137  index_t CThreadTransferDstScalarPerVector,
138  index_t NumGemmKPrefetchStage = 1,
140  PipelineVersion PipelineVer = PipelineVersion::v1>
142 {
143  static constexpr auto I0 = Number<0>{};
144  static constexpr auto I1 = Number<1>{};
145  static constexpr auto I2 = Number<2>{};
146  static constexpr auto I3 = Number<3>{};
147  static constexpr auto I4 = Number<4>{};
148  static constexpr auto I5 = Number<5>{};
149  static constexpr auto I6 = Number<6>{};
150  static constexpr auto I7 = Number<7>{};
151 
152  // K1 should be Number<...>
153  static constexpr bool is_single_rate_mfma =
155  (is_same<FloatAB, int8_t>::value && K1Value <= 8) ||
157  ? true
158  : false;
159  static constexpr auto is_scale_mfma = false;
160  static constexpr auto K1 = Number<math::max(
161  K1Value,
163  selected_mfma.k_per_blk)>{};
164 
165  using ThisThreadBlock = ThisThreadBlock<BlockSize>;
166 
167  using ElementDataTypeAB = conditional_t<is_same_v<FloatAB, ck::tf32_t>, float, FloatAB>;
168 
169  __host__ static auto CalculateGridSize(index_t M, index_t N)
170  {
171  return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1);
172  }
173 
174  template <typename CGridDesc_M_N>
175  __host__ static auto CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
176  {
177  return std::make_tuple(Block2CTileMap::CalculateGridSize(c_grid_desc_m_n), 1, 1);
178  }
179 
180  template <typename>
181  __host__ static auto CalculateGridSize(index_t M, index_t N)
182  {
183  return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1);
184  }
185 
186  __host__ static auto CalculateMPadded(index_t M)
187  {
188  return math::integer_divide_ceil(M, MPerBlock) * MPerBlock;
189  }
190 
191  __host__ static auto CalculateNPadded(index_t N)
192  {
193  return math::integer_divide_ceil(N, NPerBlock) * NPerBlock;
194  }
195 
196  __host__ static auto CalculateK0(index_t K) { return math::integer_divide_ceil(K, K1Value); }
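// Illustrative padding arithmetic (example values, not mandated by this header): with
// MPerBlock = 256 and M = 1000, CalculateMPadded returns ceil(1000/256) * 256 = 1024;
// with K1Value = 8 and K = 4096, CalculateK0 returns ceil(4096/8) = 512.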
197 
198  // Argument
199  struct Problem
200  {
201  __host__ Problem(index_t M_,
202  index_t N_,
203  index_t K_,
204  index_t StrideA_,
205  index_t StrideB_,
206  index_t StrideC_)
207  : M{M_},
208  N{N_},
209  K{K_},
210  StrideA{StrideA_},
211  StrideB{StrideB_},
212  StrideC{StrideC_},
213  MPadded{CalculateMPadded(M_)},
214  NPadded{CalculateNPadded(N_)},
215  K0{CalculateK0(K_)}
216  {
217  }
218 
219  __host__ void Print() const
220  {
221  std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
222  << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
223  << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " << "K0:" << K0
224  << "}" << std::endl;
225  }
226 
227  index_t M;
228  index_t N;
229  index_t K;
230  index_t StrideA;
231  index_t StrideB;
232  index_t StrideC;
233  index_t MPadded;
234  index_t NPadded;
235  index_t K0;
236  };
237 
238  // Argument
239  struct Argument : public Problem
240  {
241  __host__ Argument(const ElementDataTypeAB* p_a_grid_,
242  const ElementDataTypeAB* p_b_grid_,
243  FloatC* p_c_grid_,
244  index_t M_,
245  index_t N_,
246  index_t K_,
247  index_t StrideA_,
248  index_t StrideB_,
249  index_t StrideC_)
250  : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
251  p_a_grid{p_a_grid_},
252  p_b_grid{p_b_grid_},
253  p_c_grid{p_c_grid_}
254  {
255  }
256 
257  const ElementDataTypeAB* p_a_grid;
258  const ElementDataTypeAB* p_b_grid;
259  FloatC* p_c_grid;
260  };
261 
262  using GridwiseGemmPipe = remove_cvref_t<
263  decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
264 
265  // denorm test fix, required to work around fp16 mfma issue
266  // we convert fp16->fp32->bf16 and execute bf16 mfma instruction
267 // when mfma is fixed, remove this section and update
268  // FloatABAdjusted -> FloatAB throughout this file
269 #if CK_GFX90A_DENORM_WORKAROUND
271 #else
272  using FloatABAdjusted = FloatAB;
273 #endif
274 
275  __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
276  {
277  constexpr auto max_lds_align = K1;
278 
279  // A matrix in LDS memory, dst of blockwise copy
280  constexpr auto a_block_desc_k0_m_k1 = [&]() {
281  if constexpr(ABlockLdsExtraM)
282  {
286  }
287  else
288  {
289  return make_naive_tensor_descriptor_aligned(
290  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
291  }
292  }();
293 
294  return a_block_desc_k0_m_k1;
295  }
296 
297  __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
298  {
299  constexpr auto max_lds_align = K1;
300 
301  // B matrix in LDS memory, dst of blockwise copy
302  constexpr auto b_block_desc_k0_n_k1 = [&]() {
303  if constexpr(BBlockLdsExtraN)
304  {
308  }
309  else
310  {
311  return make_naive_tensor_descriptor_aligned(
312  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
313  }
314  }();
315 
316  return b_block_desc_k0_n_k1;
317  }
318 
319  __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
320  {
321  // LDS allocation for A and B: be careful of alignment
322  constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
323 
324  constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
325 
326  constexpr auto max_lds_align = K1;
327 
328  constexpr auto a_block_space_size_aligned =
329  math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
330 
331  constexpr auto b_block_space_size_aligned =
332  math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
333 
334  return (a_block_space_size_aligned + b_block_space_size_aligned) *
335  sizeof(ElementDataTypeAB);
336  }
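// Worked example of the LDS budget above (assumed values: no LDS padding, i.e.
// ABlockLdsExtraM and BBlockLdsExtraN both false, K0PerBlock = 4, MPerBlock = NPerBlock = 128,
// K1 = 8, fp16 data): each block descriptor spans 4 * 128 * 8 = 4096 elements, already a
// multiple of K1, so the kernel requests (4096 + 4096) * sizeof(half) = 16384 bytes of LDS.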
337 
338  template <
339  InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set>
340  __device__ static bool constexpr IsValidCompilationParameter()
341  {
342  return ck::tensor_operation::device::IsValidGemmCompilationParameter<
343  BlockSize,
344  MPerBlock,
345  NPerBlock,
346  MPerXdl,
347  NPerXdl,
348  MXdlPerWave,
349  NXdlPerWave,
350  FloatC,
351  CGlobalMemoryDataOperation>();
352  }
353 
354  template <typename AGridDesc_K0_M_K1, typename BGridDesc_K0_N_K1, typename CGridDesc_M_N>
355  __host__ __device__ static constexpr bool
356  CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
357  const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
358  const CGridDesc_M_N& c_grid_desc_m_n)
359  {
360  static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
361  "wrong! K1 need to be known at compile-time");
362 
363  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
364  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
365  "Invalid tuning param!");
366 
367  const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
368  const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
369  const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
370 
371  if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
372  K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
373  K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
374  return false;
375 
376  if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
377  return false;
378 
379  // check gridwise gemm pipeline
380  const auto num_k_loop = K0 / K0PerBlock;
381 
382  if(!GridwiseGemmPipe::IsSupported(num_k_loop))
383  {
384  return false;
385  }
386 
387  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
388  return true;
389  }
390 
391  __host__ static constexpr bool CheckValidity(const Problem& problem)
392  {
393  static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
394  "wrong! K1 need to be known at compile-time");
395 
396  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
397  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
398  "Invalid tuning param!");
399 
400  // check gridwise gemm pipeline
401  const auto num_k_loop = math::integer_divide_ceil(problem.K0, K0PerBlock);
402  if(!GridwiseGemmPipe::IsSupported(num_k_loop))
403  {
404  return false;
405  }
406 
407  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
408  return true;
409  }
410 
411  __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
412  {
413  const index_t num_loop = math::integer_divide_ceil(K, K0PerBlock * K1);
414 
415  return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
416  }
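// Example (illustrative values): one block-level K step covers K0PerBlock * K1 elements,
// so with K = 4096, K0PerBlock = 4 and K1 = 8 the pipeline sees ceil(4096 / 32) = 128
// K loops; GridwiseGemmPipe::CalculateHasMainLoop then decides whether that qualifies as
// having a main loop.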
417 
418  template <typename CGridDesc>
419  __host__ __device__ static constexpr auto
420  MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc& c_grid_desc_m_n)
421  {
422  constexpr auto max_lds_align = K1;
423 
424  // A matrix in LDS memory, dst of blockwise copy
425  constexpr auto a_block_desc_k0_m_k1 = [&]() {
426  if constexpr(ABlockLdsExtraM)
427  {
431  }
432  else
433  {
434  return make_naive_tensor_descriptor_aligned(
435  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
436  }
437  }();
438 
439  // B matrix in LDS memory, dst of blockwise copy
440  constexpr auto b_block_desc_k0_n_k1 = [&]() {
441  if constexpr(BBlockLdsExtraN)
442  {
446  }
447  else
448  {
449  return make_naive_tensor_descriptor_aligned(
450  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
451  }
452  }();
453 
454  using BlockwiseGemm =
458  FloatAcc,
459  decltype(a_block_desc_k0_m_k1),
460  decltype(b_block_desc_k0_n_k1),
461  MPerXdl,
462  NPerXdl,
463  MXdlPerWave,
464  NXdlPerWave,
465  K1,
468 
469  return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
470  }
471 
472  // return block_id to C matrix tile idx (m0, n0) mapping
474 
475  template <bool HasMainKBlockLoop,
476  typename AGridDesc_K0_M_K1,
477  typename BGridDesc_K0_N_K1,
478  typename CGridDesc_M_N>
479  __device__ static void Run(const ElementDataTypeAB* p_a_grid,
480  const ElementDataTypeAB* p_b_grid,
481  FloatC* p_c_grid,
482  void* __restrict__ p_shared,
483  const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
484  const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
485  const CGridDesc_M_N& c_grid_desc_m_n)
486  {
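// Run() roadmap (descriptive summary): wrap the raw pointers in global dynamic buffers,
// map this workgroup to its (m0, n0) C tile via Block2CTileMap (early-out on invalid
// tiles), stage A/B tiles from global memory into LDS with the blockwise copies, drive
// the gridwise GEMM pipeline over K while the xdlops blockwise GEMM accumulates into
// c_thread_buf, and finally write each thread's accumulators to the C grid through the
// threadwise copy at the bottom of this function.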
487  const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
489 
490  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
491  p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
492  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
493  p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
494  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
495  p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
496 
497  const AElementwiseOperation a_element_op{};
498  const BElementwiseOperation b_element_op{};
499  const CElementwiseOperation c_element_op{};
500 
501  const auto block_2_ctile_map =
502  Block2CTileMap{c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)};
503 
504  // divide block work by [M, N]
505  const auto block_work_idx =
506  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
507 
508  if(!block_2_ctile_map.ValidCTileIndex(
509  block_work_idx,
510  make_tuple(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0),
511  c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1))))
512  {
513  return;
514  }
515 
516  // HACK: this forces m/n_block_data_idx_on_grid into SGPR
517  const index_t m_block_data_idx_on_grid =
518  __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
519 
520  const index_t n_block_data_idx_on_grid =
521  __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
522 
523  // lds max alignment
524  constexpr auto max_lds_align = K1;
525 
526  // A matrix in LDS memory, dst of blockwise copy
527  constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
528 
529  // B matrix in LDS memory, dst of blockwise copy
530  constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
531 
532  // A matrix blockwise copy
533  auto a_blockwise_copy =
535  AElementwiseOperation,
539  ABlockTransferThreadClusterLengths_K0_M_K1,
540  ABlockTransferThreadClusterArrangeOrder,
543  decltype(a_grid_desc_k0_m_k1),
544  decltype(a_block_desc_k0_m_k1),
545  ABlockTransferSrcAccessOrder,
547  ABlockTransferSrcVectorDim,
548  2,
549  ABlockTransferSrcScalarPerVector,
550  ABlockTransferDstScalarPerVector_K1,
551  1,
552  1,
553  AThreadTransferSrcResetCoordinateAfterRun,
554  true,
555  NumGemmKPrefetchStage>(
556  a_grid_desc_k0_m_k1,
557  make_multi_index(0, m_block_data_idx_on_grid, 0),
558  a_element_op,
559  a_block_desc_k0_m_k1,
560  make_multi_index(0, 0, 0),
562 
563  // B matrix blockwise copy
564  auto b_blockwise_copy =
566  BElementwiseOperation,
570  BBlockTransferThreadClusterLengths_K0_N_K1,
571  BBlockTransferThreadClusterArrangeOrder,
574  decltype(b_grid_desc_k0_n_k1),
575  decltype(b_block_desc_k0_n_k1),
576  BBlockTransferSrcAccessOrder,
578  BBlockTransferSrcVectorDim,
579  2,
580  BBlockTransferSrcScalarPerVector,
581  BBlockTransferDstScalarPerVector_K1,
582  1,
583  1,
584  BThreadTransferSrcResetCoordinateAfterRun,
585  true,
586  NumGemmKPrefetchStage>(
587  b_grid_desc_k0_n_k1,
588  make_multi_index(0, n_block_data_idx_on_grid, 0),
589  b_element_op,
590  b_block_desc_k0_n_k1,
591  make_multi_index(0, 0, 0),
593 
594  // GEMM definition
595  // c_mtx += transpose(a_mtx) * b_mtx
596  // a_mtx[K0PerBlock, MPerBlock] is in LDS
597  // b_mtx[K0PerBlock, NPerBlock] is in LDS
598  // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
599  // register
600  // sanity check
602  BlockSize,
605  FloatAcc,
606  decltype(a_block_desc_k0_m_k1),
607  decltype(b_block_desc_k0_n_k1),
608  MPerXdl,
609  NPerXdl,
610  MXdlPerWave,
611  NXdlPerWave,
612  K1,
613  LoopSched,
615  FloatABAdjusted>();
616 
617  auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
618 
619  // LDS allocation for A and B: be careful of alignment
620  constexpr auto a_block_space_size_aligned =
621  math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
622 
623  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
624  static_cast<ElementDataTypeAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
625 
626  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
627  static_cast<ElementDataTypeAB*>(p_shared) + a_block_space_size_aligned,
628  b_block_desc_k0_n_k1.GetElementSpaceSize());
629 
630  constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
631  constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
632 
633  // gridwise GEMM pipeline
634  const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
635  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
636 
637  GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
638  a_block_desc_k0_m_k1,
639  a_blockwise_copy,
640  a_grid_buf,
641  a_block_buf,
642  a_block_slice_copy_step,
643  b_grid_desc_k0_n_k1,
644  b_block_desc_k0_n_k1,
645  b_blockwise_copy,
646  b_grid_buf,
647  b_block_buf,
648  b_block_slice_copy_step,
649  blockwise_gemm,
650  c_thread_buf,
651  num_k_block_main_loop);
652 
653  // output: register to global memory
654  {
655  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
656  blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
657 
658  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
659  blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
660 
661  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
662  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
663  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
664  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
665  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
666  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
667  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
668  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
669 
670  // calculate origin of thread output tensor on global memory
671  // blockwise GEMM c matrix starting index
672  const auto c_thread_mtx_on_block =
673  blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
674 
675  const index_t m_thread_data_on_grid =
676  m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
677 
678  const index_t n_thread_data_on_grid =
679  n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
680 
681  const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
683  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
686 
687  const auto m_thread_data_on_grid_idx =
688  m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
689  make_multi_index(m_thread_data_on_grid));
690 
691  const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
695 
696  const auto n_thread_data_on_grid_idx =
697  n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
698  make_multi_index(n_thread_data_on_grid));
699 
700  auto c_thread_copy =
702  FloatC,
703  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
704  decltype(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2),
705  CElementwiseOperation,
707  CThreadTransferSrcDstAccessOrder,
708  CThreadTransferSrcDstVectorDim,
709  CThreadTransferDstScalarPerVector,
710  CGlobalMemoryDataOperation,
711  1,
712  true>{
713  c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
714  make_multi_index(m_thread_data_on_grid_idx[I0],
715  n_thread_data_on_grid_idx[I0],
716  m_thread_data_on_grid_idx[I1],
717  n_thread_data_on_grid_idx[I1],
718  m_thread_data_on_grid_idx[I2],
719  m_thread_data_on_grid_idx[I3],
720  m_thread_data_on_grid_idx[I4],
721  n_thread_data_on_grid_idx[I2]),
722  c_element_op};
723 
724  c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
725  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
726  c_thread_buf,
727  c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
728  c_grid_buf);
729  }
730  }
731 };
732 
733 template <index_t BlockSize,
734  typename FloatAB,
735  typename FloatAcc,
736  typename FloatC,
737  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
738  typename ALayout,
739  typename BLayout,
740  typename CLayout,
741  typename AElementwiseOperation,
742  typename BElementwiseOperation,
743  typename CElementwiseOperation,
744  tensor_operation::device::GemmSpecialization GemmSpec,
745  index_t MPerBlock,
746  index_t NPerBlock,
747  index_t K0PerBlock,
748  index_t MPerXdl,
749  index_t NPerXdl,
750  index_t K1Value,
751  index_t MXdlPerWave,
752  index_t NXdlPerWave,
753  typename ABlockTransferThreadClusterLengths_K0_M_K1,
754  typename ABlockTransferThreadClusterArrangeOrder,
755  typename ABlockTransferSrcAccessOrder,
756  index_t ABlockTransferSrcVectorDim,
757  index_t ABlockTransferSrcScalarPerVector,
758  index_t ABlockTransferDstScalarPerVector_K1,
759  bool AThreadTransferSrcResetCoordinateAfterRun,
760  bool ABlockLdsExtraM,
761  typename BBlockTransferThreadClusterLengths_K0_N_K1,
762  typename BBlockTransferThreadClusterArrangeOrder,
763  typename BBlockTransferSrcAccessOrder,
764  index_t BBlockTransferSrcVectorDim,
765  index_t BBlockTransferSrcScalarPerVector,
766  index_t BBlockTransferDstScalarPerVector_K1,
767  bool BThreadTransferSrcResetCoordinateAfterRun,
768  bool BBlockLdsExtraN,
769  typename CThreadTransferSrcDstAccessOrder,
770  index_t CThreadTransferSrcDstVectorDim,
771  index_t CThreadTransferDstScalarPerVector,
772  index_t NumGemmKPrefetchStage = 1,
774  PipelineVersion PipelineVer = PipelineVersion::v1>
777  FloatAB,
778  FloatAcc,
779  FloatC,
780  CGlobalMemoryDataOperation,
781  AElementwiseOperation,
782  BElementwiseOperation,
783  CElementwiseOperation,
784  MPerBlock,
785  NPerBlock,
786  K0PerBlock,
787  MPerXdl,
788  NPerXdl,
789  K1Value,
790  MXdlPerWave,
791  NXdlPerWave,
792  ABlockTransferThreadClusterLengths_K0_M_K1,
793  ABlockTransferThreadClusterArrangeOrder,
794  ABlockTransferSrcAccessOrder,
795  ABlockTransferSrcVectorDim,
796  ABlockTransferSrcScalarPerVector,
797  ABlockTransferDstScalarPerVector_K1,
798  AThreadTransferSrcResetCoordinateAfterRun,
799  ABlockLdsExtraM,
800  BBlockTransferThreadClusterLengths_K0_N_K1,
801  BBlockTransferThreadClusterArrangeOrder,
802  BBlockTransferSrcAccessOrder,
803  BBlockTransferSrcVectorDim,
804  BBlockTransferSrcScalarPerVector,
805  BBlockTransferDstScalarPerVector_K1,
806  BThreadTransferSrcResetCoordinateAfterRun,
807  BBlockLdsExtraN,
808  CThreadTransferSrcDstAccessOrder,
809  CThreadTransferSrcDstVectorDim,
810  CThreadTransferDstScalarPerVector,
811  NumGemmKPrefetchStage,
812  LoopSched,
813  PipelineVer>
814 {
815  using Parent =
817  FloatAB,
818  FloatAcc,
819  FloatC,
820  CGlobalMemoryDataOperation,
821  AElementwiseOperation,
822  BElementwiseOperation,
823  CElementwiseOperation,
824  MPerBlock,
825  NPerBlock,
826  K0PerBlock,
827  MPerXdl,
828  NPerXdl,
829  K1Value,
830  MXdlPerWave,
831  NXdlPerWave,
832  ABlockTransferThreadClusterLengths_K0_M_K1,
833  ABlockTransferThreadClusterArrangeOrder,
834  ABlockTransferSrcAccessOrder,
835  ABlockTransferSrcVectorDim,
836  ABlockTransferSrcScalarPerVector,
837  ABlockTransferDstScalarPerVector_K1,
838  AThreadTransferSrcResetCoordinateAfterRun,
839  ABlockLdsExtraM,
840  BBlockTransferThreadClusterLengths_K0_N_K1,
841  BBlockTransferThreadClusterArrangeOrder,
842  BBlockTransferSrcAccessOrder,
843  BBlockTransferSrcVectorDim,
844  BBlockTransferSrcScalarPerVector,
845  BBlockTransferDstScalarPerVector_K1,
846  BThreadTransferSrcResetCoordinateAfterRun,
847  BBlockLdsExtraN,
848  CThreadTransferSrcDstAccessOrder,
849  CThreadTransferSrcDstVectorDim,
850  CThreadTransferDstScalarPerVector,
851  NumGemmKPrefetchStage,
852  LoopSched,
853  PipelineVer>;
854 
855  using typename Parent::GridwiseGemmPipe;
856  using typename Parent::Problem;
857 
858  using Parent::I1;
859 
860  using Parent::K1;
861 
862  __device__ static auto
864  {
865  const auto a_grid_desc_m_k = [&]() {
867  {
868  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
869  }
871  {
872  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
873  }
874  }();
875 
877  {
878  const auto K0Pad = math::integer_divide_ceil(K0, K0PerBlock) * K0PerBlock;
879  const auto KPad = K0Pad * K1Value;
880 
881  const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
882  a_grid_desc_m_k,
886 
888  a_grid_desc_m_kpad,
890  make_right_pad_transform(M, MPad - M)),
893  }
894  else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
895  {
897  a_grid_desc_m_k,
899  make_right_pad_transform(M, MPad - M)),
902  }
903  else
904  {
906  a_grid_desc_m_k,
911  }
912  }
913 
914  __device__ static auto
916  {
917  const auto b_grid_desc_k_n = [&]() {
919  {
920  return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
921  }
923  {
924  return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
925  }
926  }();
927 
929  {
930  const auto K0Pad = math::integer_divide_ceil(K0, K0PerBlock) * K0PerBlock;
931  const auto KPad = K0Pad * K1Value;
932 
933  const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
934  b_grid_desc_k_n,
938 
940  b_grid_desc_kpad_n,
942  make_right_pad_transform(N, NPad - N)),
945  }
946 
947  else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
948  {
950  b_grid_desc_k_n,
952  make_right_pad_transform(N, NPad - N)),
955  }
956  else
957  {
959  b_grid_desc_k_n,
964  }
965  }
966 
967  __device__ static auto
969  {
970  const auto c_grid_desc_m_n = [&]() {
972  {
973  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
974  }
976  {
977  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
978  }
979  }();
980 
983  {
984  return transform_tensor_descriptor(c_grid_desc_m_n,
986  make_right_pad_transform(N, NPad - N)),
989  }
990  else
991  {
992 
994  c_grid_desc_m_n,
998  }
999  }
1000 
1002 
1003  __host__ static constexpr bool CheckValidity(const Problem& problem)
1004  {
1005  static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
1006  "wrong! K1 need to be known at compile-time");
1007 
1008  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
1009  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
1010  "Invalid tuning param!");
1011 
1012  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
1016  {
1017  if(!(problem.M % MPerBlock == 0))
1018  {
1019  return false;
1020  }
1021  }
1022 
1023  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
1027  {
1028  if(!(problem.N % NPerBlock == 0))
1029  {
1030  return false;
1031  }
1032  }
1033 
1034  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1038  {
1039  if(!(problem.K0 % K0PerBlock == 0))
1040  {
1041  return false;
1042  }
1043  }
1044 
1046  {
1047  if(problem.K % ABlockTransferSrcScalarPerVector != 0)
1048  {
1049  return false;
1050  }
1051  }
1052  else
1053  {
1054  if(problem.M % ABlockTransferSrcScalarPerVector != 0)
1055  {
1056  return false;
1057  }
1058  }
1059 
1061  {
1062  if(problem.N % BBlockTransferSrcScalarPerVector != 0)
1063  {
1064  return false;
1065  }
1066  }
1067  else
1068  {
1069  if(problem.K % BBlockTransferSrcScalarPerVector != 0)
1070  {
1071  return false;
1072  }
1073  }
1074 
1075  // check gridwise gemm pipeline
1076  const auto num_k_loop = math::integer_divide_ceil(problem.K0, K0PerBlock);
1077 
1078  if(!GridwiseGemmPipe::IsSupported(num_k_loop))
1079  {
1080  return false;
1081  }
1082 
1083  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1084  return true;
1085  }
1086 };
1087 
1088 } // namespace ck
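For orientation, here is a minimal host-side sketch of how this gridwise GEMM is typically driven, assuming a concrete GridwiseGemm instantiation (whose struct name and full tuning-parameter pack are elided in the listing above) has been chosen elsewhere. The helper name run_gridwise_gemm, the ADataType/CDataType template parameters, and the explicit block_size argument are hypothetical illustrations, not part of this header; only members visible in this file are exercised (Argument, CheckValidity, CalculateGridSize, CalculateHasMainKBlockLoop, and the kernel_gemm_xdlops_v2r3 overload that takes a GridwiseGemm::Argument).

#include <hip/hip_runtime.h>
#include <tuple>
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp"

// Hypothetical driver: ADataType is assumed to match GridwiseGemm's ElementDataTypeAB,
// and block_size must equal the BlockSize the GridwiseGemm was instantiated with.
template <typename GridwiseGemm, typename ADataType, typename CDataType>
bool run_gridwise_gemm(const ADataType* p_a, const ADataType* p_b, CDataType* p_c,
                       ck::index_t M, ck::index_t N, ck::index_t K,
                       ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC,
                       int block_size, hipStream_t stream)
{
    // Pack the problem; Argument derives from Problem, so it also carries MPadded/NPadded/K0.
    typename GridwiseGemm::Argument karg{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};

    // Host-side sanity check (tile divisibility, vector widths, pipeline support).
    if(!GridwiseGemm::CheckValidity(karg))
        return false;

    // Grid size comes from the block-to-C-tile map; the kernel is launched 1D in x.
    const auto grid = GridwiseGemm::CalculateGridSize(M, N);
    const dim3 grid_dim(std::get<0>(grid), std::get<1>(grid), std::get<2>(grid));

    // Select the pipeline specialization with or without a main K loop.
    if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
        ck::kernel_gemm_xdlops_v2r3<GridwiseGemm, true>
            <<<grid_dim, dim3(block_size), 0, stream>>>(karg);
    else
        ck::kernel_gemm_xdlops_v2r3<GridwiseGemm, false>
            <<<grid_dim, dim3(block_size), 0, stream>>>(karg);

    return hipGetLastError() == hipSuccess;
}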