/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp Source File
gridwise_gemm_wmma_cshuffle_v3.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck/utility/env.hpp"
19 
20 namespace ck {
21 
28 // operations that could be applied on each tensor respectively. The CDE_op is an
29 // elementwise operation applied to the C and all D tensors.
129 template <typename ALayout,
130  typename BLayout,
131  typename DsLayout,
132  typename ELayout,
133  typename AsDataType,
134  typename BsDataType,
135  typename AccDataType,
136  typename CShuffleDataType,
137  typename DsDataType,
138  typename EDataType,
139  typename AElementwiseOperation,
140  typename BElementwiseOperation,
141  typename CDEElementwiseOperation,
143  index_t BlockSize,
144  index_t MPerBlock,
145  index_t NPerBlock,
146  index_t KPerBlock,
147  index_t AK1Value,
148  index_t BK1Value,
149  index_t MPerWmma,
150  index_t NPerWmma,
151  index_t MRepeat,
152  index_t NRepeat,
153  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
154  typename ABlockTransferThreadClusterArrangeOrder,
155  typename ABlockTransferSrcAccessOrder,
156  index_t ABlockTransferSrcVectorDim,
157  index_t ABlockTransferSrcScalarPerVector,
158  index_t ABlockTransferDstScalarPerVector_AK1,
159  bool AThreadTransferSrcResetCoordinateAfterRun,
160  index_t ABlockLdsExtraM,
161  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
162  typename BBlockTransferThreadClusterArrangeOrder,
163  typename BBlockTransferSrcAccessOrder,
164  index_t BBlockTransferSrcVectorDim,
165  index_t BBlockTransferSrcScalarPerVector,
166  index_t BBlockTransferDstScalarPerVector_BK1,
167  bool BThreadTransferSrcResetCoordinateAfterRun,
168  index_t BBlockLdsExtraN,
169  index_t CShuffleMRepeatPerShuffle,
170  index_t CShuffleNRepeatPerShuffle,
171  typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
172  typename CDEShuffleBlockTransferScalarPerVectors,
173  BlockGemmPipelineScheduler BlkGemmPipeSched,
174  BlockGemmPipelineVersion BlkGemmPipelineVer,
175  typename ComputeTypeA,
176  typename ComputeTypeB,
177  bool PermuteA,
178  bool PermuteB,
179  bool ForceThreadTileTransfer = false>
182  ALayout,
183  BLayout,
184  DsLayout,
185  ELayout,
186  AsDataType,
187  BsDataType,
188  AccDataType,
189  CShuffleDataType,
190  DsDataType,
191  EDataType,
192  AElementwiseOperation,
193  BElementwiseOperation,
194  CDEElementwiseOperation,
195  GemmSpec,
196  BlockSize,
197  MPerBlock,
198  NPerBlock,
199  KPerBlock,
200  AK1Value,
201  BK1Value,
202  MPerWmma,
203  NPerWmma,
204  MRepeat,
205  NRepeat,
206  ABlockTransferThreadClusterLengths_AK0_M_AK1,
207  ABlockTransferThreadClusterArrangeOrder,
208  ABlockTransferSrcAccessOrder,
209  ABlockTransferSrcVectorDim,
210  ABlockTransferSrcScalarPerVector,
211  ABlockTransferDstScalarPerVector_AK1,
212  AThreadTransferSrcResetCoordinateAfterRun,
213  ABlockLdsExtraM,
214  BBlockTransferThreadClusterLengths_BK0_N_BK1,
215  BBlockTransferThreadClusterArrangeOrder,
216  BBlockTransferSrcAccessOrder,
217  BBlockTransferSrcVectorDim,
218  BBlockTransferSrcScalarPerVector,
219  BBlockTransferDstScalarPerVector_BK1,
220  BThreadTransferSrcResetCoordinateAfterRun,
221  BBlockLdsExtraN,
222  CShuffleMRepeatPerShuffle,
223  CShuffleNRepeatPerShuffle,
224  CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
225  CDEShuffleBlockTransferScalarPerVectors,
226  BlkGemmPipeSched,
227  BlkGemmPipelineVer,
228  ComputeTypeA,
229  ComputeTypeB,
230  PermuteA,
231  PermuteB,
232  ForceThreadTileTransfer>
233 {
235  ALayout,
236  BLayout,
237  DsLayout,
238  ELayout,
239  AsDataType,
240  BsDataType,
241  AccDataType,
242  CShuffleDataType,
243  DsDataType,
244  EDataType,
245  AElementwiseOperation,
246  BElementwiseOperation,
247  CDEElementwiseOperation,
248  GemmSpec,
249  BlockSize,
250  MPerBlock,
251  NPerBlock,
252  KPerBlock,
253  AK1Value,
254  BK1Value,
255  MPerWmma,
256  NPerWmma,
257  MRepeat,
258  NRepeat,
259  ABlockTransferThreadClusterLengths_AK0_M_AK1,
260  ABlockTransferThreadClusterArrangeOrder,
261  ABlockTransferSrcAccessOrder,
262  ABlockTransferSrcVectorDim,
263  ABlockTransferSrcScalarPerVector,
264  ABlockTransferDstScalarPerVector_AK1,
265  AThreadTransferSrcResetCoordinateAfterRun,
266  ABlockLdsExtraM,
267  BBlockTransferThreadClusterLengths_BK0_N_BK1,
268  BBlockTransferThreadClusterArrangeOrder,
269  BBlockTransferSrcAccessOrder,
270  BBlockTransferSrcVectorDim,
271  BBlockTransferSrcScalarPerVector,
272  BBlockTransferDstScalarPerVector_BK1,
273  BThreadTransferSrcResetCoordinateAfterRun,
274  BBlockLdsExtraN,
275  CShuffleMRepeatPerShuffle,
276  CShuffleNRepeatPerShuffle,
277  CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
278  CDEShuffleBlockTransferScalarPerVectors,
279  BlkGemmPipeSched,
280  BlkGemmPipelineVer,
281  ComputeTypeA,
282  ComputeTypeB,
283  PermuteA,
284  PermuteB,
285  ForceThreadTileTransfer>;
286 
287  using Base::I0;
288  using Base::I1;
289  using Base::I2;
290  using Base::I3;
291  using Base::I4;
292  using Base::I5;
293  using Base::I6;
294  using Base::I7;
295 
296  using Base::AK0Number;
297  using Base::AK1Number;
298  using Base::BK0Number;
299  using Base::BK1Number;
300 
301  using Base::APackedSize;
302  using Base::BPackedSize;
303 
307  using Base::CalculateKRead;
308  using Base::CalculateMBlock;
310  using Base::CalculateNBlock;
317 
319 
321 
322  using Base::NumATensor;
323  using Base::NumBTensor;
324  using Base::NumDTensor;
325  using typename Base::AsGridPointer;
326  using typename Base::BsGridPointer;
327  using typename Base::DsGridPointer;
328  using AsDataType_ = AsDataType;
329  using BsDataType_ = BsDataType;
330 
331  struct Problem
332  {
// Host-side GEMM problem descriptor constructor: stores the raw sizes /
// strides and derives the split-K and padding quantities once, so device
// code only reads precomputed values.
//
// NOTE(review): this block came through a lossy extraction — several
// member-initializer lines are missing (marked inline below). Restore them
// from the upstream header before compiling.
__host__ Problem(index_t M_,
                 index_t N_,
                 index_t K_,
                 std::array<index_t, NumATensor> StrideAs_,
                 std::array<index_t, NumBTensor> StrideBs_,
                 std::array<index_t, NumDTensor> StrideDs_,
                 index_t StrideE_,
                 index_t KBatch_)
    : M{M_},
      N{N_},
      K{K_},
      StrideAs{StrideAs_},
      StrideBs{StrideBs_},
      StrideDs{StrideDs_},
      StrideE{StrideE_},
      KBatch{KBatch_},
      // NOTE(review): initializers for MPadded / NPadded (original doc lines
      // 349-350) were lost in extraction here — presumably
      // CalculateMPadded(M_) / CalculateNPadded(N_); confirm upstream.
      KRead{CalculateKRead(K_, KBatch_)},
      KPadded{CalculateKPadded(K_, KBatch_)},
      AK0{CalculateAK0Padded(K_, KBatch_)},
      BK0{CalculateBK0Padded(K_, KBatch_)},
      MBlock{CalculateMBlock(M_)},
      // NOTE(review): the trailing comma above indicates one more initializer
      // (original doc line 356) was lost — presumably
      // NBlock{CalculateNBlock(N_)}; confirm upstream.
{
}
359 
// Dump all problem parameters to stdout on a single logical line.
// Host-side debugging aid only; not used on the device.
__host__ void Print() const
{
    std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
              << "SAs: {";
    // Strides of every A tensor, comma-separated (no trailing comma).
    static_for<0, NumATensor, 1>{}([&](auto i) {
        std::cout << StrideAs[i] << (i.value < NumATensor - 1 ? ", " : "");
    });
    std::cout << "}, " << "SBs: {";
    // Strides of every B tensor.
    static_for<0, NumBTensor, 1>{}([&](auto i) {
        std::cout << StrideBs[i] << (i.value < NumBTensor - 1 ? ", " : "");
    });
    std::cout << "}, ";
    // D strides are printed only when there are D tensors at all.
    if constexpr(NumDTensor > 0)
    {
        std::cout << "SDs: { ";
        static_for<0, NumDTensor, 1>{}([&](auto i) {
            std::cout << StrideDs[i] << (i.value < NumDTensor - 1 ? ", " : "");
        });
        std::cout << " }, ";
    }
    // Derived quantities: padded sizes, per-batch K read, K0 splits and
    // block counts computed in the constructor.
    std::cout << "SE:" << StrideE << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded
              << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", "
              << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock
              << ", " << "NBlock: " << NBlock << "}" << std::endl;
}
385 
389  std::array<index_t, NumATensor> StrideAs;
390  std::array<index_t, NumBTensor> StrideBs;
391  std::array<index_t, NumDTensor> StrideDs;
402  };
403 
404  // Argument
406  {
// Host-side kernel argument pack constructor: extends Problem with the
// type-erased device pointers and the elementwise functors the kernel
// consumes.
//
// p_as_grid_/p_bs_grid_/p_ds_grid_ arrive as void pointers and are re-cast
// below to the per-tensor element types listed in AsDataType/BsDataType/
// DsDataType.
// is_reduce_ selects the SplitK+reduction path vs SplitK+atomicAdd
// (see IsReduceAdd / IsAtomicAdd).
__host__ Argument(std::array<const void*, NumATensor> p_as_grid_,
                  std::array<const void*, NumBTensor> p_bs_grid_,
                  std::array<const void*, NumDTensor> p_ds_grid_,
                  EDataType* p_e_grid_,
                  index_t M_,
                  index_t N_,
                  index_t K_,
                  std::array<index_t, NumATensor> StrideAs_,
                  std::array<index_t, NumBTensor> StrideBs_,
                  std::array<index_t, NumDTensor> StrideDs_,
                  index_t StrideE_,
                  index_t k_batch_,
                  AElementwiseOperation a_element_op_,
                  BElementwiseOperation b_element_op_,
                  CDEElementwiseOperation cde_element_op_,
                  bool is_reduce_ = false)
    : Problem{M_, N_, K_, StrideAs_, StrideBs_, StrideDs_, StrideE_, k_batch_},
      p_as_grid{},
      p_bs_grid{},
      p_ds_grid{},
      p_e_grid{p_e_grid_},
      a_element_op{a_element_op_},
      b_element_op{b_element_op_},
      cde_element_op{cde_element_op_},
      is_reduce(is_reduce_)
{
    // populate pointer, desc for As
    static_for<0, NumATensor, 1>{}([&](auto i) {
        using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;

        // A pointer: cast the i-th void* to its declared element type.
        p_as_grid(i) = static_cast<const ADataType_*>(p_as_grid_[i]);
    });

    // populate pointer, desc for Bs
    static_for<0, NumBTensor, 1>{}([&](auto i) {
        using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;

        // B pointer
        p_bs_grid(i) = static_cast<const BDataType_*>(p_bs_grid_[i]);
    });

    // populate pointer, desc for Ds
    static_for<0, NumDTensor, 1>{}([&](auto i) {
        using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;

        // D pointer
        p_ds_grid(i) = static_cast<const DDataType*>(p_ds_grid_[i]);
    });
}
457 
// True when the split-K reduction path applies: the caller requested
// reduction (is_reduce) and more than one K batch is actually in flight
// (with KBatch == 1 there is nothing to reduce).
__host__ __device__ inline bool IsReduceAdd() const
{
    const bool has_multiple_k_batches = Problem::KBatch > 1;
    return has_multiple_k_batches && is_reduce;
}
462 
// Complement of IsReduceAdd for split-K runs: more than one K batch but
// reduction not requested, so partial results are combined via atomicAdd.
__host__ __device__ inline bool IsAtomicAdd() const
{
    const bool has_multiple_k_batches = Problem::KBatch > 1;
    return has_multiple_k_batches && !is_reduce;
}
467 
471  EDataType* p_e_grid;
472 
473  const AElementwiseOperation a_element_op;
474  const BElementwiseOperation b_element_op;
475  const CDEElementwiseOperation cde_element_op;
476 
477  // TODO: it can be used with SplitK+reduction but currently only used with SplitK+atomicAdd
478  bool is_reduce;
479  };
480 
482  {
483 
// Computes, for split-K batch k_id, the element offsets to apply to each
// A/B grid pointer (layout-dependent) and the E/C reduction offset, and
// trims karg.K to this batch's share of K.
//
// NOTE(review): this block came through a lossy extraction — each offset
// assignment below originally sat inside a `static_for<0, Num?Tensor, 1>{}(`
// call whose opening line was lost (marked inline). Restore from upstream.
__device__ SplitKBatchOffset(Argument& karg, index_t k_id)
{
    // Note: in xdl implementation multiple AB supports one layout
    // but multiple strides, so we create an array of offsets with
    // the same values.
    // It should be fixed later on. Once we will have a thread transfer
    // more flexible.
    if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
    {
        // Row-major A: K is contiguous, so the offset is k_id slices of
        // KRead elements (divided by APackedSize for packed element types).
        // NOTE(review): `static_for<0, NumATensor, 1>{}(` opener lost here.
        [&](auto i) { a_k_split_offset[i] = k_id * karg.KRead / APackedSize; });
    }
    else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
    {
        // Column-major A: step k_id*KRead rows, each of stride StrideAs[i].
        // NOTE(review): `static_for<0, NumATensor, 1>{}(` opener lost here.
        [&](auto i) { a_k_split_offset[i] = k_id * karg.KRead * karg.StrideAs[i]; });
    }

    if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
    {
        // Row-major B: step k_id*KRead rows of stride StrideBs[i].
        // NOTE(review): `static_for<0, NumBTensor, 1>{}(` opener lost here.
        [&](auto i) { b_k_split_offset[i] = k_id * karg.KRead * karg.StrideBs[i]; });
    }
    else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
    {
        if constexpr(!PermuteB)
        {
            // Column-major B, no permutation: K contiguous, analogous to
            // the row-major A case.
            // NOTE(review): `static_for<0, NumBTensor, 1>{}(` opener lost here.
            [&](auto i) { b_k_split_offset[i] = k_id * karg.KRead / BPackedSize; });
        }
        else
        {
            // Permuted B: offsets advance in whole KRead*N panels.
            const int k0_offset = karg.KRead * karg.N;
            // NOTE(review): `static_for<0, NumBTensor, 1>{}(` opener lost here.
            [&](auto i) { b_k_split_offset[i] = k_id * k0_offset / BPackedSize; });
        }
    }

    // All batches but the last process exactly KRead; the last batch takes
    // the remainder (K may not divide evenly by KBatch).
    if(k_id < karg.KBatch - 1)
    {
        karg.K = karg.KRead;
    }
    else
    {
        karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
    }

    // Reduction path writes each batch's partial C into its own M*N slab;
    // atomicAdd path accumulates in place (offset 0).
    if(karg.IsReduceAdd())
    {
        c_reduce_offset = k_id * karg.M * karg.N;
    }
    else
    {
        c_reduce_offset = 0;
    }
}
540 
541  std::array<index_t, NumATensor> a_k_split_offset;
542  std::array<index_t, NumBTensor> b_k_split_offset;
544  };
545 
547 
548  // return block_id to C matrix tile idx (m0, n0) mapping
549  // if arch = gfx942
551  // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
552 
553  __device__ static index_t GetKBlockPerScale() { return 1; }
554 
// Grid-level GEMM entry point (pointer-based overload): builds the A/B/D/E
// grid descriptors from the Problem, maps this workgroup to its (m0, n0)
// C tile, and forwards everything to Base::Run for the pipelined multiply
// plus C-shuffle epilogue.
//
// NOTE(review): two descriptor-building call expressions below were lost in
// extraction (marked inline); restore them from the upstream header.
template <bool HasMainKBlockLoop,
          InMemoryDataOperationEnum EGlobalMemoryDataOperation,
          TailNumber TailNum,
          typename EpilogueArgument>
__device__ static void Run(AsGridPointer& p_as_grid,
                           BsGridPointer& p_bs_grid,
                           DsGridPointer& p_ds_grid,
                           EDataType* p_e_grid,
                           void* p_shared,
                           const Problem& problem,
                           AElementwiseOperation a_element_op,
                           BElementwiseOperation b_element_op,
                           CDEElementwiseOperation cde_element_op,
                           EpilogueArgument& epilogue_args)
{
    // Global-memory descriptors for the A and B tensor sets.
    const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
        problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
    const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(
        problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0);
    // M x N descriptors for the D tensors and the E output.
    const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
        problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
    const auto e_grid_desc_m_n = Base::template MakeDEGridDescriptor_M_N<ELayout>(
        problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideE);
    const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
        // NOTE(review): the call expression (presumably
        // MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(...) was
        // lost in extraction; only its trailing arguments survived.
        ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
    const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
        // NOTE(review): call expression lost here too — presumably
        // MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(...).
        e_grid_desc_m_n, problem.MBlock, problem.NBlock);

    // divide block work by [M, N]
    const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};

    const auto block_work_idx =
        block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

    // Workgroups mapped past the valid tile range exit immediately.
    if(!block_2_ctile_map.ValidCTileIndex(
           block_work_idx,
           make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
                      e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
    {
        return;
    }

    // readfirstlane: the tile index is uniform across the wave; keep it in
    // a scalar register.
    const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
    const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);

    // BScale struct (Empty) — this variant carries no B block scaling.
    using BScale = typename BlockwiseGemmPipe::Empty;
    auto b_scale_struct = BScale{};

    const index_t num_k_block_per_scale = GetKBlockPerScale();

    Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
                       decltype(bs_grid_desc_bk0_n_bk1),
                       decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
                       decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
                       decltype(b_scale_struct),
                       decltype(epilogue_args),
                       HasMainKBlockLoop,
                       EGlobalMemoryDataOperation,
                       TailNum>(p_as_grid,
                                p_bs_grid,
                                p_ds_grid,
                                p_e_grid,
                                p_shared,
                                as_grid_desc_ak0_m_ak1,
                                bs_grid_desc_bk0_n_bk1,
                                ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                e_grid_desc_mblock_mperblock_nblock_nperblock,
                                a_element_op,
                                b_element_op,
                                cde_element_op,
                                block_m_id,
                                block_n_id,
                                num_k_block_per_scale,
                                b_scale_struct,
                                epilogue_args);
}
634 
635  // Wrapper function to have __global__ function in common
636  // between gemm_universal, b_scale, ab_scale, etc.
// Wrapper entry point shared by the __global__ kernels of gemm_universal,
// b_scale, ab_scale, etc.: applies the per-batch split-K offsets to every
// A and B grid pointer (and the E pointer), then forwards to the
// pointer-based Run overload above.
template <bool HasMainKBlockLoop,
          InMemoryDataOperationEnum EGlobalMemoryDataOperation,
          TailNumber TailNum,
          typename EpilogueArgument>
__device__ static void Run(void* p_shared,
                           const SplitKBatchOffset& splitk_batch_offset,
                           Argument& karg,
                           EpilogueArgument& epilogue_args)
{
    // Shift every A tensor pointer into this k-batch's slice.
    AsGridPointer shifted_as_ptrs;
    static_for<0, NumATensor, 1>{}([&](auto i) {
        using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
        shifted_as_ptrs(i) = static_cast<const ADataType_*>(karg.p_as_grid[i]) +
                             splitk_batch_offset.a_k_split_offset[i];
    });

    // Shift every B tensor pointer likewise.
    BsGridPointer shifted_bs_ptrs;
    static_for<0, NumBTensor, 1>{}([&](auto i) {
        using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
        shifted_bs_ptrs(i) = static_cast<const BDataType_*>(karg.p_bs_grid[i]) +
                             splitk_batch_offset.b_k_split_offset[i];
    });

    // karg (sliced as Problem) plus the shifted pointers drive the real run;
    // E additionally carries the reduction slab offset for SplitK+reduce.
    Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
        shifted_as_ptrs,
        shifted_bs_ptrs,
        karg.p_ds_grid,
        karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
        p_shared,
        karg,
        karg.a_element_op,
        karg.b_element_op,
        karg.cde_element_op,
        epilogue_args);
}
674 };
675 
676 } // namespace ck
GemmSpecialization
Definition: gemm_specialization.hpp:11
Definition: ck.hpp:268
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
InMemoryDataOperationEnum
Definition: ck.hpp:277
BlockGemmPipelineVersion
Definition: blkgemmpipe_scheduler.hpp:12
TailNumber
Definition: blkgemmpipe_scheduler.hpp:31
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:47
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
int32_t index_t
Definition: ck.hpp:299
Definition: block_to_ctile_map.hpp:271
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:406
const BElementwiseOperation b_element_op
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:474
__host__ Argument(std::array< const void *, NumATensor > p_as_grid_, std::array< const void *, NumBTensor > p_bs_grid_, std::array< const void *, NumDTensor > p_ds_grid_, EDataType *p_e_grid_, index_t M_, index_t N_, index_t K_, std::array< index_t, NumATensor > StrideAs_, std::array< index_t, NumBTensor > StrideBs_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideE_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CDEElementwiseOperation cde_element_op_, bool is_reduce_=false)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:407
const AElementwiseOperation a_element_op
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:473
bool is_reduce
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:478
EDataType * p_e_grid
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:471
const CDEElementwiseOperation cde_element_op
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:475
__host__ __device__ bool IsAtomicAdd() const
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:463
BsGridPointer p_bs_grid
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:469
AsGridPointer p_as_grid
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:468
__host__ __device__ bool IsReduceAdd() const
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:458
DsGridPointer p_ds_grid
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:470
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:332
std::array< index_t, NumBTensor > StrideBs
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:390
index_t AK0
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:398
index_t N
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:387
index_t K
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:388
index_t BK0
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:399
index_t NPadded
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:395
index_t KPadded
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:397
index_t StrideE
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:392
std::array< index_t, NumDTensor > StrideDs
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:391
index_t NBlock
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:401
std::array< index_t, NumATensor > StrideAs
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:389
__host__ void Print() const
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:360
index_t KBatch
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:393
index_t M
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:386
__host__ Problem(index_t M_, index_t N_, index_t K_, std::array< index_t, NumATensor > StrideAs_, std::array< index_t, NumBTensor > StrideBs_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideE_, index_t KBatch_)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:333
index_t MBlock
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:400
index_t MPadded
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:394
index_t KRead
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:396
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:482
index_t c_reduce_offset
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:543
std::array< index_t, NumATensor > a_k_split_offset
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:541
__device__ SplitKBatchOffset(Argument &karg, index_t k_id)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:484
std::array< index_t, NumBTensor > b_k_split_offset
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:542
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:122
static __host__ auto CalculateKPadded(index_t K)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:288
static constexpr auto I2
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:126
static constexpr __device__ auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DEGridDesc &de_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:609
static constexpr __device__ auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc &ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:535
static constexpr auto I3
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:127
static __host__ auto CalculateNBlock(index_t N)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:323
static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:293
static __host__ auto CalculateMBlock(index_t M)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:318
decltype(MakeAsGridPointer()) AsGridPointer
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:350
static constexpr auto I1
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:125
static constexpr index_t NumATensor
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:133
static __host__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:311
static constexpr auto AK1Number
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:151
decltype(MakeDsGridPointer()) DsGridPointer
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:521
__host__ static __device__ auto MakeBsGridDescriptor_BK0_N_BK1(const index_t K, const index_t KPad, const index_t N, const index_t NPad, const std::array< index_t, NumBTensor > &StrideBs, const index_t BK0)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:405
__host__ static __device__ auto MakeDEGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideDE)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:446
__host__ static __device__ auto MakeAsGridDescriptor_AK0_M_AK1(const index_t M, const index_t MPad, const index_t K, const index_t KPad, const std::array< index_t, NumATensor > &StrideAs, const index_t AK0)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:378
static constexpr auto I6
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:130
static constexpr auto AK0Number
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:149
static constexpr index_t NumBTensor
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:134
static constexpr auto I0
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:124
__host__ static __device__ auto MakeDsGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, std::array< index_t, NumDTensor > StrideDs)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:523
static constexpr index_t APackedSize
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:161
static constexpr auto I7
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:131
static constexpr auto I4
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:128
static constexpr index_t BPackedSize
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:168
static constexpr auto BK1Number
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:152
static constexpr auto BK0Number
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:150
decltype(MakeBsGridPointer()) BsGridPointer
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:351
static __host__ auto CalculateMPadded(index_t M)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:278
remove_cvref_t< decltype(BlockGemmPipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, LDSTypeA, LDSTypeB, ComputeTypeA, ComputeTypeB, AccDataType, decltype(MakeAWmmaTileDescriptor()), decltype(MakeBWmmaTileDescriptor()), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack >())> BlockwiseGemmPipe
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:566
static constexpr index_t NumDTensor
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:508
static __host__ auto CalculateNPadded(index_t N)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:283
static constexpr auto I5
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:129
static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:299
"Universal" GEMM kernel with SplitK support.
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:233
static __host__ auto CalculateKPadded(index_t K)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:288
static constexpr auto I2
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:126
static constexpr __device__ auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DEGridDesc &de_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:609
static constexpr __device__ auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc &ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:535
AsDataType AsDataType_
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:328
static __host__ auto CalculateNBlock(index_t N)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:323
static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:293
static __host__ auto CalculateMBlock(index_t M)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:318
decltype(MakeAsGridPointer()) AsGridPointer
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:350
static constexpr auto I1
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:125
static constexpr index_t NumATensor
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:133
static __host__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:311
decltype(MakeDsGridPointer()) DsGridPointer
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:521
__host__ static __device__ auto MakeBsGridDescriptor_BK0_N_BK1(const index_t K, const index_t KPad, const index_t N, const index_t NPad, const std::array< index_t, NumBTensor > &StrideBs, const index_t BK0)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:405
static __device__ void Run(void *p_shared, const SplitKBatchOffset &splitk_batch_offset, Argument &karg, EpilogueArgument &epilogue_args)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:641
__host__ static __device__ auto MakeAsGridDescriptor_AK0_M_AK1(const index_t M, const index_t MPad, const index_t K, const index_t KPad, const std::array< index_t, NumATensor > &StrideAs, const index_t AK0)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:378
static constexpr index_t NumBTensor
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:134
static constexpr auto I0
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:124
BsDataType BsDataType_
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:329
__host__ static __device__ auto MakeDsGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, std::array< index_t, NumDTensor > StrideDs)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:523
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:320
static constexpr index_t APackedSize
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:161
static constexpr index_t BPackedSize
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:168
static __device__ void Run(AsGridPointer &p_as_grid, BsGridPointer &p_bs_grid, DsGridPointer &p_ds_grid, EDataType *p_e_grid, void *p_shared, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, EpilogueArgument &epilogue_args)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:559
static __device__ index_t GetKBlockPerScale()
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:553
decltype(MakeBsGridPointer()) BsGridPointer
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:351
static __host__ auto CalculateMPadded(index_t M)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:278
static constexpr index_t NumDTensor
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:508
static __host__ auto CalculateNPadded(index_t N)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:283
typename Base::BlockwiseGemmPipe BlockwiseGemmPipe
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:546
static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:299
Definition: functional2.hpp:33
Definition: device_base.hpp:197