Composable Kernel: gridwise_gemm_wmma.hpp Source File
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck/utility/env.hpp"
18 
19 namespace ck {
20 
21 template <typename GridwiseGemm,
22  typename ADataType,
23  typename BDataType,
24  typename CDataType,
25  typename AGridDesc,
26  typename BGridDesc,
27  typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
28  typename AElementwiseOperation,
29  typename BElementwiseOperation,
30  typename CElementwiseOperation,
31  typename Block2CTileMap,
32  bool HasMainKBlockLoop>
33 __global__ void
34 #if CK_USE_LAUNCH_BOUNDS
35  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
36 #endif
37  kernel_gemm_wmma(const ADataType* __restrict__ p_a_grid,
38  const BDataType* __restrict__ p_b_grid,
39  CDataType* __restrict__ p_c_grid,
40  const AGridDesc a_grid_desc,
41  const BGridDesc b_grid_desc,
42  const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
43  c_grid_desc_mblock_mperblock_nblock_nperblock,
44  const AElementwiseOperation a_element_op,
45  const BElementwiseOperation b_element_op,
46  const CElementwiseOperation c_element_op,
47  const Block2CTileMap block_2_ctile_map)
48 {
49 #if(defined(__gfx11__) || defined(__gfx12__))
50  __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size];
51 
52  GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
53  p_b_grid,
54  p_c_grid,
55  p_shared,
56  a_grid_desc,
57  b_grid_desc,
58  c_grid_desc_mblock_mperblock_nblock_nperblock,
59  a_element_op,
60  b_element_op,
61  c_element_op,
62  block_2_ctile_map);
63 #else
64  ignore = p_a_grid;
65  ignore = p_b_grid;
66  ignore = p_c_grid;
67  ignore = a_grid_desc;
68  ignore = b_grid_desc;
69  ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
70  ignore = a_element_op;
71  ignore = b_element_op;
72  ignore = c_element_op;
73  ignore = block_2_ctile_map;
74 #endif // end of if (defined(__gfx11__) || defined(__gfx12__))
75 }
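// The class template below implements the grid-level GEMM that kernel_gemm_wmma dispatches to.
// In outline: each workgroup computes one MPerBlock x NPerBlock tile of C, stepping through K in
// KPerBlock chunks. A and B tiles are staged either through LDS (AEnableLds/BEnableLds = true,
// using the ABlockTransfer*/BBlockTransfer* tuning parameters) or held directly in VGPRs; the
// per-wave MPerWmma x NPerWmma products are issued as WMMA instructions (MRepeat/NRepeat times
// per wave), and the accumulators are written out through an LDS "C shuffle" stage controlled by
// the CShuffle* parameters. The A/B/C element-wise operators are applied as part of the
// corresponding data transfers.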
76 
77 template <index_t BlockSize,
78  typename ADataType,
79  typename BDataType,
80  typename AccDataType,
81  typename CShuffleDataType,
82  typename CDataType,
83  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
84  typename AGridDesc,
85  typename BGridDesc,
86  typename CGridDesc_M_N,
87  typename AElementwiseOperation,
88  typename BElementwiseOperation,
89  typename CElementwiseOperation,
90  index_t MPerBlock,
91  index_t NPerBlock,
92  index_t KPerBlock,
93  index_t MPerWmma,
94  index_t NPerWmma,
95  index_t K1Value,
96  index_t MRepeat,
97  index_t NRepeat,
98  typename ABlockTransferThreadClusterLengths_K0_M_K1,
99  typename ABlockTransferThreadClusterArrangeOrder,
100  typename ABlockTransferSrcAccessOrder,
101  index_t ABlockTransferSrcVectorDim,
102  index_t ABlockTransferSrcScalarPerVector,
103  index_t ABlockTransferDstScalarPerVector_K1,
104  bool AThreadTransferSrcResetCoordinateAfterRun,
105  bool AEnableLds,
106  bool ABlockLdsExtraM,
107  typename BBlockTransferThreadClusterLengths_K0_N_K1,
108  typename BBlockTransferThreadClusterArrangeOrder,
109  typename BBlockTransferSrcAccessOrder,
110  index_t BBlockTransferSrcVectorDim,
111  index_t BBlockTransferSrcScalarPerVector,
112  index_t BBlockTransferDstScalarPerVector_K1,
113  bool BThreadTransferSrcResetCoordinateAfterRun,
114  bool BEnableLds,
115  bool BBlockLdsExtraN,
116  index_t CShuffleMRepeatPerShuffle,
117  index_t CShuffleNRepeatPerShuffle,
118  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
119  index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
120  index_t NumGemmKPrefetchStage = 1,
121  LoopScheduler LoopSched = make_default_loop_scheduler(),
122  PipelineVersion PipelineVer = PipelineVersion::v1>
124 {
125  static constexpr auto I0 = Number<0>{};
126  static constexpr auto I1 = Number<1>{};
127  static constexpr auto I2 = Number<2>{};
128  static constexpr auto I3 = Number<3>{};
129  static constexpr auto I4 = Number<4>{};
130  static constexpr auto I5 = Number<5>{};
131  static constexpr auto I6 = Number<6>{};
132  static constexpr auto I7 = Number<7>{};
133 
134  // FIX ME: To be deprecated
135  static constexpr auto K1 = Number<K1Value>{};
136 
137  static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
138  static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
139  static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
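// MWaves/NWaves are the number of waves a block tile is split into along M and N. For example,
// with MPerBlock = 128, MRepeat = 2 and MPerWmma = 16, MWaves = 128 / (2 * 16) = 4. WmmaK is the
// K extent of one WMMA instruction as seen by this kernel: 32 when K1 == 16, otherwise 16.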
140 
141  using ThisThreadBlock = ThisThreadBlock<BlockSize>;
142 
143  using GridwiseGemmPipe =
144  remove_cvref_t<decltype(GridwiseGemmPipeline_Selector<PipelineVer,
145  NumGemmKPrefetchStage,
146  LoopSched,
147  AEnableLds,
148  BEnableLds>())>;
149 
150  // Describe how data is stored into the (LDS/VGPR) buffer from global memory
151  __host__ __device__ static constexpr auto MakeABlockDescriptor()
152  {
153  constexpr auto a_block_desc = [&]() {
154  if constexpr(AEnableLds)
155  {
156  // K0->M->K1 Per Block
157  constexpr auto K0PerBlock = KPerBlock / K1;
158  constexpr auto max_lds_align = K1;
159 
160  if constexpr(ABlockLdsExtraM)
161  {
165  }
166  else
167  {
169  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
170  }
171  }
172  else
173  {
174  constexpr auto A_KRow = I2;
175  constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
176  constexpr auto K0PerWmma = WmmaK / A_KRow / K1;
177  // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
180  Number<MRepeat>{},
181  I1,
183  I1,
184  I1,
185  K1),
187  Number<K0PerWmma>{} * K1,
188  Number<K0PerWmma>{} * K1,
189  K1,
190  K1,
191  K1,
192  I1));
193  }
194  }();
195 
196  return a_block_desc;
197  }
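// The A block descriptor above (and the B one below) describes how a block's tile is laid out in
// its staging buffer. With LDS enabled it is a K0 x MPerBlock x K1 (resp. K0 x NPerBlock x K1)
// descriptor aligned to max_lds_align (= K1); ABlockLdsExtraM/BBlockLdsExtraN select a variant
// with an extra stride, commonly used to reduce LDS bank conflicts. With LDS disabled it is a 7-D
// per-thread VGPR descriptor (KWmma, Repeat, Wave, K0PerWmma, KRow, PerWmma, K1), matching the
// layout the WMMA pipeline reads directly.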
198 
199  __host__ __device__ static constexpr auto MakeBBlockDescriptor()
200  {
201  constexpr auto b_block_desc = [&]() {
202  if constexpr(BEnableLds)
203  {
204  // K0->N->K1 Per Block
205  constexpr auto K0PerBlock = KPerBlock / K1;
206  constexpr auto max_lds_align = K1;
207 
208  if constexpr(BBlockLdsExtraN)
209  {
213  }
214  else
215  {
217  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
218  }
219  }
220  else
221  {
222 
223  constexpr auto B_KRow = I2;
224  constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
225  constexpr auto K0PerWmma = WmmaK / B_KRow / K1;
226  // KWmma->NRepeat->NWave->K0PerWmma->KRow->NPerWmma->K1 Per Thread
229  Number<NRepeat>{},
230  I1,
232  I1,
233  I1,
234  K1),
236  Number<K0PerWmma>{} * K1,
237  Number<K0PerWmma>{} * K1,
238  K1,
239  K1,
240  K1,
241  I1));
242  }
243  }();
244 
245  return b_block_desc;
246  }
247 
248  __host__ __device__ static constexpr auto MakeABlockSliceCopyStep()
249  {
250  constexpr auto a_block_copy_step = [&]() {
251  if constexpr(AEnableLds)
252  {
253  constexpr auto K0PerBlock = KPerBlock / K1;
254 
255  return make_multi_index(K0PerBlock, 0, 0);
256  }
257  else
258  {
259  constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
260 
261  return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
262  }
263  }();
264 
265  return a_block_copy_step;
266  }
267 
268  __host__ __device__ static constexpr auto MakeBBlockSliceCopyStep()
269  {
270  constexpr auto b_block_copy_step = [&]() {
271  if constexpr(BEnableLds)
272  {
273  constexpr auto K0PerBlock = KPerBlock / K1;
274 
275  return make_multi_index(K0PerBlock, 0, 0);
276  }
277  else
278  {
279  constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
280 
281  return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
282  }
283  }();
284 
285  return b_block_copy_step;
286  }
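// The slice-copy steps above are the per-iteration window advance of the blockwise A/B copies
// along K. With LDS enabled the step is (K0PerBlock, 0, 0); e.g. KPerBlock = 64, K1 = 8 gives
// (8, 0, 0). Without LDS the step is (KWmmaPerBlock, 0, 0, 0, 0, 0, 0) on the 7-D descriptor;
// e.g. KPerBlock = 64, WmmaK = 16 gives (4, 0, 0, 0, 0, 0, 0).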
287 
288  // Describe how data is read from the (LDS/VGPR) buffer
289  template <typename ABlockDesc_>
290  __host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&)
291  {
292 
293  constexpr auto a_wave_desc = [&]() {
294  if constexpr(AEnableLds)
295  {
296  // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
297  constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
298  constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
299 #ifdef __gfx12__
300  constexpr auto A_KRow = I2;
301 #else
302  constexpr auto A_KRow = I1;
303 #endif
304 
306  ABlockDesc_{},
313  }
314  else
315  {
316  // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
317  constexpr auto KWmma = ABlockDesc_{}.GetLength(I0);
318  constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3);
319  constexpr auto A_KRow = ABlockDesc_{}.GetLength(I4);
320  constexpr auto A_K1 = ABlockDesc_{}.GetLength(I6);
321 
322  // Err: merge transform causes a non-constexpr issue
323 
324  // return transform_tensor_descriptor(
325  // ABlockDesc_{},
326  // make_tuple(make_merge_transform(make_tuple(Number<KWmma>{}, I1)),
327  // make_pass_through_transform(Number<MRepeat>{}),
328  // make_pass_through_transform(I1),
329  // make_pass_through_transform(I1),
330  // make_pass_through_transform(Number<A_K1>{})),
331  // make_tuple(Sequence<0, 3>{},
332  // Sequence<1>{},
333  // Sequence<2>{},
334  // Sequence<4>{},
335  // Sequence<5>{}),
336  // make_tuple(
337  // Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{},
338  // Sequence<4>{}));
339 
340  // Workaround, Freeze transform
342  Number<MRepeat>{},
343  I1,
344  Number<A_KRow>{},
345  I1,
346  Number<A_K1>{}));
347  }
348  }();
349 
350  return a_wave_desc;
351  }
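// The wave descriptors re-interpret the A/B staging buffers in the shape consumed by
// BlockwiseGemmWMMA, roughly (K0, Repeat, Waves, KRow, PerWmma, K1). For the LDS path this is a
// transform of the K0 x M(N) x K1 block descriptor, with KRow = 2 on gfx12 and 1 otherwise. For
// the VGPR path the commented-out merge transform is not usable in a constexpr context, so the
// descriptor is rebuilt directly with the KWmma/K0PerWmma dimensions frozen out (the "Freeze
// transform" workaround noted above).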
352 
353  template <typename BBlockDesc_>
354  __host__ __device__ static constexpr auto MakeBWaveDescriptor(const BBlockDesc_&)
355  {
356  constexpr auto b_wave_desc = [&]() {
357  if constexpr(BEnableLds)
358  {
359  // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
360  constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
361  constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
362 #ifdef __gfx12__
363  constexpr auto B_KRow = I2;
364 #else
365  constexpr auto B_KRow = I1;
366 #endif
368  BBlockDesc_{},
375  }
376  else
377  {
378  // KWmma_NRepeat_NWave_K0PerWmma_KRow_NPerWmma_K1 -> K0_NRepeat_Nwaves_NPerWmma_K1
379  constexpr auto KWmma = BBlockDesc_{}.GetLength(I0);
380  constexpr auto K0PerWmma = BBlockDesc_{}.GetLength(I3);
381  constexpr auto B_KRow = BBlockDesc_{}.GetLength(I4);
382  constexpr auto B_K1 = BBlockDesc_{}.GetLength(I6);
383 
384  // Workaround, Freeze transform
386  Number<NRepeat>{},
387  I1,
388  Number<B_KRow>{},
389  I1,
390  Number<B_K1>{}));
391  }
392  }();
393 
394  return b_wave_desc;
395  }
396 
397  __host__ __device__ static constexpr auto
398  // *Caution: here, "repeat" means the shuffle repeat
399  GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
400  {
401  constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
403  make_tuple(I1,
405  I1,
407 
408  return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
409  }
410 
411  // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
412  template <typename Block2CTileMap>
413  __host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
414  const BGridDesc& b_grid_desc,
415  const CGridDesc_M_N& c_grid_desc_m_n,
416  const Block2CTileMap& block_2_ctile_map)
417  {
418  static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
419  "wrong! K1 need to be known at compile-time");
420 
421  static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) &&
422  (NPerBlock % (NRepeat * NPerWmma)) == 0,
423  "Invalid tuning param!");
424 
425  const auto GetAProblemsizeMK = [&]() {
426  if constexpr(AEnableLds)
427  {
428  return make_tuple(a_grid_desc.GetLength(I1),
429  a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2));
430  }
431  else
432  {
433  return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) *
434  a_grid_desc.GetLength(I5),
435  a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
436  a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6));
437  }
438  };
439 
440  const auto GetBProblemsizeNK = [&]() {
441  if constexpr(BEnableLds)
442  {
443  return make_tuple(b_grid_desc.GetLength(I1),
444  b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I2));
445  }
446  else
447  {
448  return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) *
449  b_grid_desc.GetLength(I5),
450  b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) *
451  b_grid_desc.GetLength(I4) * b_grid_desc.GetLength(I6));
452  }
453  };
454 
455  const auto M = GetAProblemsizeMK()[I0];
456  const auto N = GetBProblemsizeNK()[I0];
457  const auto K = GetAProblemsizeMK()[I1];
458 
459  if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
460  K == GetBProblemsizeNK()[I1]))
461  {
462  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
463  {
464  printf("A: MxK = %d x %d, B: NxK = %d x %d, C: MxN = %d x %d\n",
465  GetAProblemsizeMK()[I0],
466  GetAProblemsizeMK()[I1],
467  GetBProblemsizeNK()[I0],
468  GetBProblemsizeNK()[I1],
469  c_grid_desc_m_n.GetLength(I0),
470  c_grid_desc_m_n.GetLength(I1));
471  printf("GridwiseOp err: ProblemSize check");
472  }
473  return false;
474  }
475 
476  if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
477  {
478  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
479  {
480  printf("GridwiseOp err: ProblemSize division");
481  }
482  return false;
483  }
484 
485  // check gridwise gemm pipeline
486  const auto num_k_loop = K / KPerBlock;
487 
488  if(!GridwiseGemmPipe::IsSupported(num_k_loop))
489  {
490  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
491  {
492  printf("GridwiseOp err: Pipeline not support this k_loop");
493  }
494  return false;
495  }
496 
497  if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
498  {
499  return false;
500  }
501 
502  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
503  constexpr long_index_t TwoGB = (long_index_t{1} << 31);
504 
505  if(!(a_grid_desc.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
506  b_grid_desc.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB))
507  {
508  return false;
509  }
510  return true;
511  }
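// CheckValidity gates a tuning instance before launch: it verifies that K1 is a compile-time
// constant, that the block tile divides into WMMA tiles, that the M/N/K sizes implied by the A, B
// and C descriptors agree, that M, N and K are divisible by MPerBlock, NPerBlock and KPerBlock,
// that the selected pipeline supports the resulting K-loop count, that the block-to-C-tile map is
// valid, and that the A/B buffers stay below the 2 GB addressing limit. A host-side caller would
// typically do something like (sketch, names are illustrative):
//
//   if(!GridwiseGemm::CheckValidity(a_grid_desc, b_grid_desc, c_grid_desc_m_n, block_2_ctile_map))
//   {
//       return false; // reject this kernel configuration
//   }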
512 
513  __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
514  {
515  const index_t num_loop = K / KPerBlock;
516 
517  return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
518  }
519 
520  __host__ __device__ static constexpr auto
521  MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n)
522  {
523  const auto M = c_grid_desc_m_n.GetLength(I0);
524  const auto N = c_grid_desc_m_n.GetLength(I1);
525 
526  const auto MBlock = M / MPerBlock;
527  const auto NBlock = N / NPerBlock;
528 
529  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
530  c_grid_desc_m_n,
535 
536  return c_grid_desc_mblock_mperblock_nblock_nperblock;
537  }
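// MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock splits the M x N output descriptor into a
// 4-D (MBlock, MPerBlock, NBlock, NPerBlock) view, so a workgroup's tile is addressed by its block
// indices plus an in-block offset. For example, M = 512 and N = 256 with MPerBlock = 128 and
// NPerBlock = 128 yield MBlock = 4 and NBlock = 2.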
538 
539  // return block_id to C matrix tile idx (m0, n0) mapping
540  __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(
541  const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
542  {
544  c_grid_desc_m_n);
545  }
546 
547  struct SharedMemTrait
548  {
549  // LDS allocation for A and B: be careful of alignment
550 
551  static constexpr auto max_lds_align = K1;
552 
553  static constexpr auto a_block_space_size_aligned =
554  AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(),
556  : 0;
557  static constexpr auto b_block_space_size_aligned =
558  BEnableLds ? math::integer_least_multiple(MakeBBlockDescriptor().GetElementSpaceSize(),
560  : 0;
561 
562  static constexpr auto a_block_space_offset = 0;
564 
565  // LDS allocation for the C shuffle buffer
566  static constexpr auto c_shuffle_block_space_size =
567  GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
568  .GetElementSpaceSize();
569 
570  static constexpr auto c_shuffle_block_space_offset = 0;
571 
572  static constexpr auto lds_size =
573  math::max(c_shuffle_block_space_size * sizeof(CShuffleDataType),
574  a_block_space_size_aligned * sizeof(ADataType) +
575  b_block_space_size_aligned * sizeof(BDataType));
576  };
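// SharedMemTrait sizes the workgroup's LDS. The A and B staging buffers (when enabled) are packed
// back to back, each rounded up to the LDS alignment: a_block_space_offset is 0 and
// b_block_space_offset follows the aligned A allocation. The C-shuffle buffer starts at offset 0
// and reuses the same memory after the main K loop, so lds_size is the maximum of the two byte
// counts rather than their sum.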
577 
578  using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
579  remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
580  CGridDesc_M_N{}))>;
581  using DefaultBlock2CTileMap =
582  remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
583 
584  template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
585  __device__ static void Run(const ADataType* __restrict__ p_a_grid,
586  const BDataType* __restrict__ p_b_grid,
587  CDataType* __restrict__ p_c_grid,
588  void* __restrict__ p_shared,
589  const AGridDesc& a_grid_desc,
590  const BGridDesc& b_grid_desc,
591  const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
592  c_grid_desc_mblock_mperblock_nblock_nperblock,
593  const AElementwiseOperation& a_element_op,
594  const BElementwiseOperation& b_element_op,
595  const CElementwiseOperation& c_element_op,
596  const Block2CTileMap& block_2_ctile_map)
597  {
598  // clang-format off
599 /*******************************************************************************/
600 // Memory buffer zone.
601  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
602  p_a_grid, a_grid_desc.GetElementSpaceSize());
603  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
604  p_b_grid, b_grid_desc.GetElementSpaceSize());
605  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
606  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
607 
608 /*******************************************************************************/
609 // BlockIdx.x -> [BlockId.m, BlockId.n]
610  const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
611  if(!block_2_ctile_map.ValidCTileIndex(
612  block_work_idx,
613  make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
614  c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
615  { return; }
616 
617  // Store BlockId into SGPR
618  const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
619  const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
620 
621 /*******************************************************************************/
622 // BlockLevel, A/B Matrix ThreadMapping in WMMA Source buffer, As Destination of BlockWise_Copy
623  const auto K = [&](){
624  if constexpr(AEnableLds){
625  return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2);
626  }
627  else{
628  return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3)
629  * a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6);
630  }
631  }();
632 
633  constexpr auto a_block_desc = MakeABlockDescriptor();
634  constexpr auto b_block_desc = MakeBBlockDescriptor();
635 
636  auto a_block_trait = [&](){
637  // A matrix blockwise copy
638  if constexpr(AEnableLds)
639  {
640  constexpr auto K0PerBlock = KPerBlock/ K1;
641  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
642  static_cast<ADataType*>(p_shared),
644 
645  auto a_blockwise_copy =
647 /* typename SrcElementwiseOperation, */ AElementwiseOperation,
648 /* typename DstElementwiseOperation, */ ck::tensor_operation::element_wise::PassThrough,
649 /* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set,
650 /* typename BlockSliceLengths, */ Sequence<K0PerBlock, MPerBlock, K1>,
651 /* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1,
652 /* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder,
653 /* typename SrcData, */ ADataType,
654 /* typename DstData, */ ADataType,
655 /* typename SrcDesc, */ decltype(a_grid_desc),
656 /* typename DstDesc, */ decltype(a_block_desc),
657 /* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder,
658 /* typename DstDimAccessOrder, */ Sequence<0, 1, 2>,
659 /* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim,
660 /* index_t DstVectorDim, */ 2,
661 /* index_t SrcScalarPerVector, */ ABlockTransferSrcScalarPerVector,
662 /* index_t DstScalarPerVector, */ ABlockTransferDstScalarPerVector_K1,
663 /* index_t SrcScalarStrideInVector, */ 1,
664 /* index_t DstScalarStrideInVector, */ 1,
665 /* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun,
666 /* bool ThreadTransferDstResetCoordinateAfterRun, */ true,
667  NumGemmKPrefetchStage>(
668  a_grid_desc,
669  make_multi_index(0, m_block_data_idx_on_grid, 0),
670  a_element_op,
671  a_block_desc,
672  make_multi_index(0, 0, 0),
674 
675  return make_tuple(a_block_buf, a_blockwise_copy);
676  }
677  else
678  {
679  // Thread-wise copy
680  // KPerBlock/WmmaK -> MRepeat -> MWaves -> K0PerWmma -> KRow -> MPerWmma -> K1
681  constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
682  constexpr auto K0PerWmma = WmmaK/2/K1Value;
683  auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
684  a_block_desc.GetElementSpaceSize());
685 
686  // Limitation: NumDim of Src and Dst descriptor should be identical
687  auto a_blockwise_copy =
689  ADataType,
690  decltype(a_grid_desc),
691  decltype(a_block_desc),
693  Number<MRepeat>{},
694  I1,
696  I1,
697  I1,
698  Number<K1Value>{}>,
700  6,
701  ABlockTransferSrcScalarPerVector,
702  AThreadTransferSrcResetCoordinateAfterRun,
703  true>(
704  a_grid_desc,
705  make_multi_index(0,
706  m_block_data_idx_on_grid/(MWaves * MPerWmma),
707  get_thread_local_1d_id() / 32,
708  0,
709  (get_thread_local_1d_id() % 32 )/ 16,
710  get_thread_local_1d_id() % 16,
711  0));
712 
713  return make_tuple(a_block_buf, a_blockwise_copy);
714  }
715  };
716 
717  auto b_block_trait = [&](){
718  if constexpr(BEnableLds)
719  {
720  constexpr auto K0PerBlock = KPerBlock/ K1;
721  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
722  static_cast<BDataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
724 
725  auto b_blockwise_copy =
727  BElementwiseOperation,
731  BBlockTransferThreadClusterLengths_K0_N_K1,
732  BBlockTransferThreadClusterArrangeOrder,
733  BDataType,
734  BDataType,
735  decltype(b_grid_desc),
736  decltype(b_block_desc),
737  BBlockTransferSrcAccessOrder,
739  BBlockTransferSrcVectorDim,
740  2,
741  BBlockTransferSrcScalarPerVector,
742  BBlockTransferDstScalarPerVector_K1,
743  1,
744  1,
745  BThreadTransferSrcResetCoordinateAfterRun,
746  true,
747  NumGemmKPrefetchStage>(
748  b_grid_desc,
749  make_multi_index(0, n_block_data_idx_on_grid, 0),
750  b_element_op,
751  b_block_desc,
752  make_multi_index(0, 0, 0),
754 
755  return make_tuple(b_block_buf, b_blockwise_copy);
756  }
757  else
758  {
759  // Thread-wise copy
760  // KPerBlock/WmmaK -> NRepeat -> NWaves -> K0PerWmma -> KRow -> NPerWmma -> K1
761  constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
762  constexpr auto K0PerWmma = WmmaK/2/K1Value;
763  auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
764  b_block_desc.GetElementSpaceSize());
765 
766  // Limitation: NumDim of Src and Dst descriptor should be identical
767  auto b_blockwise_copy =
769  BDataType,
770  decltype(b_grid_desc),
771  decltype(b_block_desc),
773  Number<NRepeat>{},
774  I1,
776  I1,
777  I1,
778  Number<K1Value>{}>,
780  6,
781  BBlockTransferSrcScalarPerVector,
782  BThreadTransferSrcResetCoordinateAfterRun,
783  true>(
784  b_grid_desc,
785  make_multi_index(0,
786  n_block_data_idx_on_grid/(NWaves * NPerWmma),
787  get_thread_local_1d_id() / 32,
788  0,
789  (get_thread_local_1d_id() % 32 )/ 16,
790  get_thread_local_1d_id() % 16,
791  0));
792 
793  return make_tuple(b_block_buf, b_blockwise_copy);
794  }
795  };
796 
797  auto a_block_buf = a_block_trait()[I0];
798  auto a_blockwise_copy = a_block_trait()[I1];
799 
800  auto b_block_buf = b_block_trait()[I0];
801  auto b_blockwise_copy = b_block_trait()[I1];
802 /*******************************************************************************/
803  // GEMM
804  constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
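// KPack is the K granularity handed to the blockwise WMMA: K1 rounded up to a multiple of WmmaK.
// With the K1 values used here (K1 <= WmmaK, since WmmaK is 32 for K1 == 16 and 16 otherwise),
// KPack equals WmmaK, e.g. K1 = 8 -> KPack = 16, K1 = 16 -> KPack = 32.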
805 
806  auto blockwise_gemm =
807  BlockwiseGemmWMMA<BlockSize,
808  ADataType,
809  BDataType,
810  AccDataType,
811  decltype(MakeAWaveDescriptor(a_block_desc)),
812  decltype(MakeBWaveDescriptor(b_block_desc)),
813  MPerBlock,
814  NPerBlock,
815  KPerBlock,
816  MPerWmma,
817  NPerWmma,
818  MRepeat,
819  NRepeat,
820  KPack,
821  AEnableLds,
822  BEnableLds>{};
823 
824  // Prepare Register for C matrix
825  auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
826 
827 /*******************************************************************************/
828  // Shift Per SUB_K
829  constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep();
830  constexpr auto b_block_slice_copy_step = MakeBBlockSliceCopyStep();
831 
832  // gridwise GEMM pipeline
833  const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock);
834  GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc,
835  a_block_desc,
836  a_blockwise_copy,
837  a_grid_buf,
838  a_block_buf,
839  a_block_slice_copy_step,
840  b_grid_desc,
841  b_block_desc,
842  b_blockwise_copy,
843  b_grid_buf,
844  b_block_buf,
845  b_block_slice_copy_step,
846  blockwise_gemm,
847  c_thread_buf,
848  KBlockMainLoop);
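// The pipeline selected by GridwiseGemmPipeline_Selector drives the main K loop: for each of the
// KBlockMainLoop = K / KPerBlock iterations it moves the next A/B tiles from global memory into
// the LDS/VGPR staging buffers, advances the copy windows by the slice-copy steps above, and runs
// the blockwise WMMA, accumulating into c_thread_buf. NumGemmKPrefetchStage and PipelineVer
// control how the loads are overlapped with the math.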
849 /*******************************************************************************/
850  // write out to C, implement shuffle
851  {
852  // C mapping in single thread.
853  constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
854  blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
855 
856  // C mapping in single block
857  constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
858  blockwise_gemm.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
859 
860  constexpr auto MWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I1);
861  constexpr auto MSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I2);
862  constexpr auto NWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I4);
863  constexpr auto NThreadPerSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I5);
864  constexpr auto MAccVgprs = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I6);
865 
866  // LDS descriptor, shuffle and write out in MRepeat x NRepeat times
867  constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
869 
870  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
871  static_cast<CShuffleDataType*>(p_shared) + SharedMemTrait::c_shuffle_block_space_offset,
873 
874  constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = transform_tensor_descriptor(
875  c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
876  make_tuple(
879  Number<CShuffleMRepeatPerShuffle>{}, // MRepeat per shuffle repeat
880  MWave, // MWave
881  MSubGroup, // MSubGroup * MAccVgprs = MPerWmma
882  MAccVgprs)),
885  Number<CShuffleNRepeatPerShuffle>{}, // NRepeat per shuffle repeat
886  NWave, // NWave
887  NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma
890 
891  // calculate origin of thread output tensor on global memory
892  // blockwise GEMM c matrix starting index
893  const auto c_thread_mtx_on_block = blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0);
894 
895  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
896  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
897 
898  const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor =
900  make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))),
903 
904  const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor =
906  make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))),
909 
910  const auto m_thread_data_on_block_idx = m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex(
911  make_multi_index(m_thread_data_on_block));
912 
913  const auto n_thread_data_on_block_idx = n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex(
914  make_multi_index(n_thread_data_on_block));
915 
916  // shuffle: threadwise copy C from VGPR to LDS
917  auto c_thread_copy_vgpr_to_lds =
919  CShuffleDataType,
920  decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
921  decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
923  Sequence<CShuffleMRepeatPerShuffle,
924  I1,
925  I1,
926  CShuffleNRepeatPerShuffle,
927  I1,
928  I1,
929  MAccVgprs>,
931  6,
932  1, // vector write pixel
934  1,
935  true>{
936  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
938  m_thread_data_on_block_idx[I1],
939  m_thread_data_on_block_idx[I2],
940  0,
941  n_thread_data_on_block_idx[I1],
942  n_thread_data_on_block_idx[I2],
943  m_thread_data_on_block_idx[I3]),
945 
946  // shuffle: blockwise copy C from LDS to global
947  auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
948  ThisThreadBlock, // ThreadGroup
949  CElementwiseOperation, // ElementwiseOperation,
950  CGlobalMemoryDataOperation, // DstInMemOp,
951  Sequence<1,
952  CShuffleMRepeatPerShuffle * MWave * MPerWmma,
953  1,
954  CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
955  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
956  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
957  CShuffleDataType, // typename SrcData,
958  CDataType, // typename DstData,
959  decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
960  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
961  Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
962  3, // index_t VectorDim,
963  CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
964  true, // bool ThreadTransferSrcResetCoordinateAfterRun,
965  false> // bool ThreadTransferDstResetCoordinateAfterRun>
966  {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
967  make_multi_index(0, 0, 0, 0),
968  c_grid_desc_mblock_mperblock_nblock_nperblock,
969  make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
970  c_element_op};
971 
972  // space filling curve for local reg & global memory
973  // space filling curve for threadwise C in VGPR
974  constexpr auto sfc_c_vgpr =
977  Sequence<CShuffleMRepeatPerShuffle,
978  1,
979  1,
980  CShuffleNRepeatPerShuffle,
981  1,
982  1,
983  MAccVgprs>>{};
984 
985  // space filling curve for shuffled blockwise C in global mem
986  constexpr auto sfc_c_global =
989  Sequence<1,
990  CShuffleMRepeatPerShuffle * MWave * MPerWmma,
991  1,
992  CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{};
993 
994  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
995 
996  static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
997 
998  static_for<0, num_access, 1>{}([&](auto access_id) {
999  // make sure it's safe to write to LDS
1000  block_sync_lds();
1001 
1002  // each thread write its data from VGPR to LDS
1003  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
1004  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1005  c_thread_buf,
1006  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
1007  c_shuffle_block_buf);
1008 
1009  // make sure it's safe to read from LDS
1010  block_sync_lds();
1011 
1012  // each block copy its data from LDS to global
1013  c_shuffle_block_copy_lds_to_global.Run(
1014  c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
1015  c_shuffle_block_buf,
1016  c_grid_desc_mblock_mperblock_nblock_nperblock,
1017  c_grid_buf);
1018 
1019  if constexpr(access_id < num_access - 1)
1020  {
1021  constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1022 
1023  // move on C
1024  c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1025  c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1026  }
1027  });
1028  }
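// Epilogue summary: the accumulators are streamed out CShuffleMRepeatPerShuffle x
// CShuffleNRepeatPerShuffle tiles at a time. In each of the num_access passes, every thread first
// writes its VGPR accumulators into the LDS shuffle buffer, the block synchronizes, and the whole
// block then copies that LDS tile to global memory (applying c_element_op), with the two
// space-filling curves supplying matching per-pass offsets in registers and in C.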
1029  // clang-format on
1030  }
1031 };
1032 
1033 } // namespace ck