/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp Source File
gridwise_gemm_wmma_cshuffle_v3.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck/utility/env.hpp"
18 
19 namespace ck {
20 
124 template <typename ALayout,
125  typename BLayout,
126  typename CLayout,
127  typename ADataType,
128  typename BDataType,
129  typename AccDataType,
130  typename CShuffleDataType,
131  typename CDataType,
132  typename AElementwiseOperation,
133  typename BElementwiseOperation,
134  typename CElementwiseOperation,
136  index_t BlockSize,
137  index_t MPerBlock,
138  index_t NPerBlock,
139  index_t KPerBlock,
140  index_t AK1Value,
141  index_t BK1Value,
142  index_t MPerWmma,
143  index_t NPerWmma,
144  index_t MRepeat,
145  index_t NRepeat,
146  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
147  typename ABlockTransferThreadClusterArrangeOrder,
148  typename ABlockTransferSrcAccessOrder,
149  index_t ABlockTransferSrcVectorDim,
150  index_t ABlockTransferSrcScalarPerVector,
151  index_t ABlockTransferDstScalarPerVector_AK1,
152  bool AThreadTransferSrcResetCoordinateAfterRun,
153  index_t ABlockLdsExtraM,
154  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
155  typename BBlockTransferThreadClusterArrangeOrder,
156  typename BBlockTransferSrcAccessOrder,
157  index_t BBlockTransferSrcVectorDim,
158  index_t BBlockTransferSrcScalarPerVector,
159  index_t BBlockTransferDstScalarPerVector_BK1,
160  bool BThreadTransferSrcResetCoordinateAfterRun,
161  index_t BBlockLdsExtraN,
162  index_t CShuffleMRepeatPerShuffle,
163  index_t CShuffleNRepeatPerShuffle,
164  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
165  index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
166  BlockGemmPipelineScheduler BlkGemmPipeSched,
167  BlockGemmPipelineVersion BlkGemmPipelineVer,
168  typename ComputeTypeA,
169  typename ComputeTypeB,
170  bool PermuteA,
171  bool PermuteB>
174  ALayout,
175  BLayout,
176  CLayout,
177  ADataType,
178  BDataType,
179  AccDataType,
180  CShuffleDataType,
181  CDataType,
182  AElementwiseOperation,
183  BElementwiseOperation,
184  CElementwiseOperation,
185  GemmSpec,
186  BlockSize,
187  MPerBlock,
188  NPerBlock,
189  KPerBlock,
190  AK1Value,
191  BK1Value,
192  MPerWmma,
193  NPerWmma,
194  MRepeat,
195  NRepeat,
196  ABlockTransferThreadClusterLengths_AK0_M_AK1,
197  ABlockTransferThreadClusterArrangeOrder,
198  ABlockTransferSrcAccessOrder,
199  ABlockTransferSrcVectorDim,
200  ABlockTransferSrcScalarPerVector,
201  ABlockTransferDstScalarPerVector_AK1,
202  AThreadTransferSrcResetCoordinateAfterRun,
203  ABlockLdsExtraM,
204  BBlockTransferThreadClusterLengths_BK0_N_BK1,
205  BBlockTransferThreadClusterArrangeOrder,
206  BBlockTransferSrcAccessOrder,
207  BBlockTransferSrcVectorDim,
208  BBlockTransferSrcScalarPerVector,
209  BBlockTransferDstScalarPerVector_BK1,
210  BThreadTransferSrcResetCoordinateAfterRun,
211  BBlockLdsExtraN,
212  CShuffleMRepeatPerShuffle,
213  CShuffleNRepeatPerShuffle,
214  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
215  CShuffleBlockTransferScalarPerVector_NPerBlock,
216  BlkGemmPipeSched,
217  BlkGemmPipelineVer,
218  ComputeTypeA,
219  ComputeTypeB,
220  PermuteA,
221  PermuteB>
222 {
224  ALayout,
225  BLayout,
226  CLayout,
227  ADataType,
228  BDataType,
229  AccDataType,
230  CShuffleDataType,
231  CDataType,
232  AElementwiseOperation,
233  BElementwiseOperation,
234  CElementwiseOperation,
235  GemmSpec,
236  BlockSize,
237  MPerBlock,
238  NPerBlock,
239  KPerBlock,
240  AK1Value,
241  BK1Value,
242  MPerWmma,
243  NPerWmma,
244  MRepeat,
245  NRepeat,
246  ABlockTransferThreadClusterLengths_AK0_M_AK1,
247  ABlockTransferThreadClusterArrangeOrder,
248  ABlockTransferSrcAccessOrder,
249  ABlockTransferSrcVectorDim,
250  ABlockTransferSrcScalarPerVector,
251  ABlockTransferDstScalarPerVector_AK1,
252  AThreadTransferSrcResetCoordinateAfterRun,
253  ABlockLdsExtraM,
254  BBlockTransferThreadClusterLengths_BK0_N_BK1,
255  BBlockTransferThreadClusterArrangeOrder,
256  BBlockTransferSrcAccessOrder,
257  BBlockTransferSrcVectorDim,
258  BBlockTransferSrcScalarPerVector,
259  BBlockTransferDstScalarPerVector_BK1,
260  BThreadTransferSrcResetCoordinateAfterRun,
261  BBlockLdsExtraN,
262  CShuffleMRepeatPerShuffle,
263  CShuffleNRepeatPerShuffle,
264  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
265  CShuffleBlockTransferScalarPerVector_NPerBlock,
266  BlkGemmPipeSched,
267  BlkGemmPipelineVer,
268  ComputeTypeA,
269  ComputeTypeB,
270  PermuteA,
271  PermuteB>;
272 
273  using Base::I0;
274  using Base::I1;
275  using Base::I2;
276  using Base::I3;
277  using Base::I4;
278  using Base::I5;
279  using Base::I6;
280  using Base::I7;
281 
282  using Base::AK0Number;
283  using Base::AK1Number;
284  using Base::BK0Number;
285  using Base::BK1Number;
286 
287  using Base::APackedSize;
288  using Base::BPackedSize;
289 
293  using Base::CalculateKRead;
294  using Base::CalculateMBlock;
296  using Base::CalculateNBlock;
301 
303 
305 
307 
310 
 // Host-side description of a single split-K GEMM problem: raw sizes M/N/K,
 // leading-dimension strides, the split-K factor KBatch, and quantities
 // derived from them via the Base Calculate* helpers (KRead, KPadded, AK0,
 // BK0, MBlock, ...).
 // NOTE(review): this is a doc-extraction view; original lines 327-328 and
 // 334 (the MPadded/NPadded/NBlock member initializers) and the field
 // declarations (original lines 348-362) are missing here -- confirm against
 // the full header before relying on this block.
311  struct Problem
312  {
 // Capture the raw geometry, then precompute the per-batch K read size and
 // the padded / block-count quantities so device code never re-derives them.
313  __host__ Problem(index_t M_,
314  index_t N_,
315  index_t K_,
316  index_t StrideA_,
317  index_t StrideB_,
318  index_t StrideC_,
319  index_t KBatch_)
320  : M{M_},
321  N{N_},
322  K{K_},
323  StrideA{StrideA_},
324  StrideB{StrideB_},
325  StrideC{StrideC_},
326  KBatch{KBatch_},
 // KRead: how much of K each split-K batch reads (Base::CalculateKRead).
329  KRead{CalculateKRead(K_, KBatch_)},
330  KPadded{CalculateKPadded(K_, KBatch_)},
331  AK0{CalculateAK0Padded(K_, KBatch_)},
332  BK0{CalculateBK0Padded(K_, KBatch_)},
333  MBlock{CalculateMBlock(M_)},
335  {
336  }
337 
 // Debug aid: dump every field of the problem to stdout on the host.
338  __host__ void Print() const
339  {
340  std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
341  << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
342  << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
343  << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
344  << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
345  << "NBlock: " << NBlock << "}" << std::endl;
346  }
347 
363  };
364 
365  // Argument
367  {
368  __host__ Argument(const ADataType* p_a_grid_,
369  const BDataType* p_b_grid_,
370  CDataType* p_c_grid_,
371  index_t M_,
372  index_t N_,
373  index_t K_,
374  index_t StrideA_,
375  index_t StrideB_,
376  index_t StrideC_,
377  index_t k_batch_,
378  bool is_reduce_ = false)
379  : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_},
380  p_a_grid{p_a_grid_},
381  p_b_grid{p_b_grid_},
382  p_c_grid{p_c_grid_},
383  is_reduce(is_reduce_)
384  {
385  }
386 
387  __host__ __device__ inline bool IsReduceAdd() const
388  {
389  return (Problem::KBatch > 1) && is_reduce;
390  }
391 
392  __host__ __device__ inline bool IsAtomicAdd() const
393  {
394  return (Problem::KBatch > 1) && (!is_reduce);
395  }
396 
397  const ADataType* p_a_grid;
398  const BDataType* p_b_grid;
399  CDataType* p_c_grid;
400  bool is_reduce;
401  };
402 
404  {
405 
 // Computes, for this launch's split-K batch (indexed by blockIdx.z), the
 // element offsets to apply to the A/B/C base pointers, and trims karg.K
 // down to the slice of K this batch actually processes.
406  __device__ SplitKBatchOffset(Argument& karg)
407  {
 // A offset along K: row-major A is contiguous in K (divided by the packed
 // element size); column-major A advances whole StrideA columns per K.
408  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
409  {
410  a_k_split_offset = blockIdx.z * karg.KRead / APackedSize;
411  }
412  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
413  {
414  a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
415  }
416 
 // B offset along K, mirroring A but with the PermuteB layout special-cased:
 // a permuted B advances KRead * N elements per batch.
417  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
418  {
419  b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
420  }
421  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
422  {
423  if constexpr(!PermuteB)
424  {
425  b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
426  }
427  else
428  {
429  const int k0_offset = karg.KRead * karg.N;
430  b_k_split_offset = blockIdx.z * k0_offset / BPackedSize;
431  }
432  }
433 
 // Every batch except the last processes exactly KRead; the last batch takes
 // the remainder so the full K is covered when K % KBatch != 0.
 // NOTE: this mutates karg.K, so karg must be a per-thread/private copy.
434  if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
435  {
436  karg.K = karg.KRead;
437  }
438  else
439  {
440  karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
441  }
442 
 // Reduce path: each batch writes its own M*N slice of the C workspace.
 // Atomic path: all batches target the same C, so no offset.
 // NOTE(review): blockIdx.z * karg.M * karg.N is evaluated in 32-bit
 // index_t arithmetic -- could overflow for very large M*N*KBatch; confirm
 // upstream size limits.
443  if(karg.IsReduceAdd())
444  {
445  c_reduce_offset = blockIdx.z * karg.M * karg.N;
446  }
447  else
448  {
449  c_reduce_offset = 0;
450  }
451  }
452 
456  };
457 
459 
460  // return block_id to C matrix tile idx (m0, n0) mapping
461  // if arch = gfx942
463  // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
464 
465  __device__ static index_t GetKBlockPerScale() { return 1; }
466 
 // Device-side entry for one workgroup: builds the A/B/C grid descriptors
 // from the problem geometry, maps this block to a (m0, n0) C tile via the
 // block-to-ctile map, and forwards everything to the shared Base::Run
 // pipeline with an empty B-scale struct.
467  template <bool HasMainKBlockLoop,
468  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
469  TailNumber TailNum = TailNumber::Odd>
470  __device__ static void Run(const ADataType* p_a_grid,
471  const BDataType* p_b_grid,
472  CDataType* p_c_grid,
473  void* p_shared,
474  const Problem& problem)
475  {
476  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
477  problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
478  const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
479  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
480  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
481  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
 // NOTE(review): original line 483 (the call to
 // MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock) is missing from
 // this extracted view; the next two lines are the head and tail of that
 // call -- confirm against the full header.
482  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
484  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
485 
486  // divide block work by [M, N]
487  const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
488 
489  const auto block_work_idx =
490  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
491 
 // Early-out for blocks mapped beyond the valid C tile range.
492  if(!block_2_ctile_map.ValidCTileIndex(
493  block_work_idx,
494  make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
495  c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
496  {
497  return;
498  }
499 
 // readfirstlane keeps the tile ids in scalar registers (wave-uniform).
500  const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
501  const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
502 
503  // BScale struct (Empty)
504  using BScale = typename BlockwiseGemmPipe::Empty;
505  auto b_scale_struct = BScale{};
506 
507  const index_t num_k_block_per_scale = GetKBlockPerScale();
508 
 // Hand off to the common pipeline implementation in the base class.
509  Base::template Run<decltype(a_grid_desc_ak0_m_ak1),
510  decltype(b_grid_desc_bk0_n_bk1),
511  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
512  decltype(b_scale_struct),
513  HasMainKBlockLoop,
514  CGlobalMemoryDataOperation,
515  TailNum>(p_a_grid,
516  p_b_grid,
517  p_c_grid,
518  p_shared,
519  a_grid_desc_ak0_m_ak1,
520  b_grid_desc_bk0_n_bk1,
521  c_grid_desc_mblock_mperblock_nblock_nperblock,
522  block_m_id,
523  block_n_id,
524  num_k_block_per_scale,
525  b_scale_struct);
526  }
527 
528  // Wrapper function to have __global__ function in common
529  // between gemm_universal, b_scale, ab_scale, etc.
530  template <bool HasMainKBlockLoop,
531  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
532  TailNumber TailNum = TailNumber::Odd>
533  __device__ static void
534  Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, const Argument& karg)
535  {
536  Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
537  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
538  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
539  karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
540  p_shared,
541  karg);
542  }
543 };
544 
545 } // namespace ck
GemmSpecialization
Definition: gemm_specialization.hpp:11
Definition: ck.hpp:267
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
InMemoryDataOperationEnum
Definition: ck.hpp:276
BlockGemmPipelineVersion
Definition: blkgemmpipe_scheduler.hpp:12
TailNumber
Definition: blkgemmpipe_scheduler.hpp:31
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:58
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:298
unsigned int uint32_t
Definition: stdint.h:126
Definition: block_to_ctile_map.hpp:270
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:367
CDataType * p_c_grid
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:399
__host__ Argument(const ADataType *p_a_grid_, const BDataType *p_b_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t k_batch_, bool is_reduce_=false)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:368
bool is_reduce
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:400
__host__ __device__ bool IsReduceAdd() const
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:387
const ADataType * p_a_grid
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:397
const BDataType * p_b_grid
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:398
__host__ __device__ bool IsAtomicAdd() const
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:392
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:312
index_t M
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:348
index_t KPadded
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:358
index_t NPadded
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:356
index_t NBlock
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:362
index_t K
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:350
__host__ void Print() const
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:338
index_t N
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:349
index_t AK0
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:359
index_t BK0
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:360
index_t KBatch
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:354
index_t MPadded
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:355
index_t MBlock
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:361
index_t StrideA
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:351
index_t StrideB
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:352
__host__ Problem(index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t KBatch_)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:313
index_t StrideC
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:353
index_t KRead
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:357
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:404
index_t c_reduce_offset
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:455
index_t b_k_split_offset
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:454
index_t a_k_split_offset
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:453
__device__ SplitKBatchOffset(Argument &karg)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:406
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:104
static constexpr auto I3
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:109
static constexpr auto BK0Number
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:117
static constexpr index_t APackedSize
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:128
remove_cvref_t< decltype(BlockGemmPipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, decltype(MakeAWmmaTileDescriptor(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBWmmaTileDescriptor(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack >())> BlockwiseGemmPipe
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:806
static constexpr __device__ auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:634
static constexpr auto I5
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:111
static constexpr index_t BPackedSize
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:135
static constexpr __device__ auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:496
static __host__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:185
static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:173
static __host__ auto CalculateNBlock(index_t N)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:197
static __host__ auto CalculateKPadded(index_t K)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:162
__host__ static constexpr __device__ auto GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:770
static __host__ auto CalculateMPadded(index_t M)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:152
static constexpr auto I0
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:106
static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:167
__host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:223
static __host__ auto CalculateNPadded(index_t N)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:157
__host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:307
static __host__ auto CalculateMBlock(index_t M)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:192
__host__ static constexpr __device__ auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:809
static constexpr auto I1
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:107
static constexpr auto I4
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:110
static constexpr auto I7
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:113
static constexpr auto AK1Number
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:118
__host__ static __device__ auto MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:434
static constexpr auto BK1Number
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:119
static constexpr auto I2
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:108
static constexpr auto AK0Number
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:116
static constexpr auto I6
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:112
"Universal" GEMM kernel with SplitK support.
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:222
static constexpr index_t APackedSize
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:128
static constexpr index_t BPackedSize
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:135
static __host__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:185
static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:173
static __host__ auto CalculateNBlock(index_t N)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:197
static __host__ auto CalculateKPadded(index_t K)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:162
static __host__ auto CalculateMPadded(index_t M)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:152
static constexpr auto I0
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:106
static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:167
typename Base::BlockwiseGemmPipe BlockwiseGemmPipe
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:458
__host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:223
static __host__ auto CalculateNPadded(index_t N)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:157
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:306
__host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:307
static __host__ auto CalculateMBlock(index_t M)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:192
__host__ static constexpr __device__ auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:809
static constexpr auto I1
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:107
static __device__ void Run(void *p_shared, const SplitKBatchOffset &splitk_batch_offset, const Argument &karg)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:534
__host__ static __device__ auto MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:434
static constexpr auto I2
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:108
static __device__ index_t GetKBlockPerScale()
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:465
static __device__ void Run(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, void *p_shared, const Problem &problem)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:470
Definition: device_base.hpp:51