Composable Kernel source file: include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck/utility/env.hpp"
19 
20 namespace ck {
21 
28 // operations that can be applied to each tensor, respectively. The CDE_op is an
29 // elementwise operation applied to the C and all D tensors.
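
For intuition, a CDE elementwise operation is a functor invoked once per output
element with the GEMM accumulator C and the corresponding element of each D
tensor. A minimal sketch in the style of CK's element_wise operations (this
AddBias functor is illustrative, not one shipped by the library):

    // Hypothetical CDE op: E = C + D0, i.e. GEMM with one bias tensor D0.
    struct AddBias
    {
        template <typename E, typename C, typename D0>
        __host__ __device__ void operator()(E& e, const C& c, const D0& d0) const
        {
            e = ck::type_convert<E>(c + d0);
        }
    };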
129 template <typename ALayout,
130  typename BLayout,
131  typename DsLayout,
132  typename ELayout,
133  typename AsDataType,
134  typename BsDataType,
135  typename AccDataType,
136  typename CShuffleDataType,
137  typename DsDataType,
138  typename EDataType,
139  typename AElementwiseOperation,
140  typename BElementwiseOperation,
141  typename CDEElementwiseOperation,
142  tensor_operation::device::GemmSpecialization GemmSpec,
143  index_t BlockSize,
144  index_t MPerBlock,
145  index_t NPerBlock,
146  index_t KPerBlock,
147  index_t AK1Value,
148  index_t BK1Value,
149  index_t MPerWmma,
150  index_t NPerWmma,
151  index_t MRepeat,
152  index_t NRepeat,
153  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
154  typename ABlockTransferThreadClusterArrangeOrder,
155  typename ABlockTransferSrcAccessOrder,
156  index_t ABlockTransferSrcVectorDim,
157  index_t ABlockTransferSrcScalarPerVector,
158  index_t ABlockTransferDstScalarPerVector_AK1,
159  bool AThreadTransferSrcResetCoordinateAfterRun,
160  index_t ABlockLdsExtraM,
161  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
162  typename BBlockTransferThreadClusterArrangeOrder,
163  typename BBlockTransferSrcAccessOrder,
164  index_t BBlockTransferSrcVectorDim,
165  index_t BBlockTransferSrcScalarPerVector,
166  index_t BBlockTransferDstScalarPerVector_BK1,
167  bool BThreadTransferSrcResetCoordinateAfterRun,
168  index_t BBlockLdsExtraN,
169  index_t CShuffleMRepeatPerShuffle,
170  index_t CShuffleNRepeatPerShuffle,
171  typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
172  typename CDEShuffleBlockTransferScalarPerVectors,
173  BlockGemmPipelineScheduler BlkGemmPipeSched,
174  BlockGemmPipelineVersion BlkGemmPipelineVer,
175  typename ComputeTypeA,
176  typename ComputeTypeB,
177  bool PermuteA,
178  bool PermuteB>
181  ALayout,
182  BLayout,
183  DsLayout,
184  ELayout,
185  AsDataType,
186  BsDataType,
187  AccDataType,
188  CShuffleDataType,
189  DsDataType,
190  EDataType,
191  AElementwiseOperation,
192  BElementwiseOperation,
193  CDEElementwiseOperation,
194  GemmSpec,
195  BlockSize,
196  MPerBlock,
197  NPerBlock,
198  KPerBlock,
199  AK1Value,
200  BK1Value,
201  MPerWmma,
202  NPerWmma,
203  MRepeat,
204  NRepeat,
205  ABlockTransferThreadClusterLengths_AK0_M_AK1,
206  ABlockTransferThreadClusterArrangeOrder,
207  ABlockTransferSrcAccessOrder,
208  ABlockTransferSrcVectorDim,
209  ABlockTransferSrcScalarPerVector,
210  ABlockTransferDstScalarPerVector_AK1,
211  AThreadTransferSrcResetCoordinateAfterRun,
212  ABlockLdsExtraM,
213  BBlockTransferThreadClusterLengths_BK0_N_BK1,
214  BBlockTransferThreadClusterArrangeOrder,
215  BBlockTransferSrcAccessOrder,
216  BBlockTransferSrcVectorDim,
217  BBlockTransferSrcScalarPerVector,
218  BBlockTransferDstScalarPerVector_BK1,
219  BThreadTransferSrcResetCoordinateAfterRun,
220  BBlockLdsExtraN,
221  CShuffleMRepeatPerShuffle,
222  CShuffleNRepeatPerShuffle,
223  CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
224  CDEShuffleBlockTransferScalarPerVectors,
225  BlkGemmPipeSched,
226  BlkGemmPipelineVer,
227  ComputeTypeA,
228  ComputeTypeB,
229  PermuteA,
230  PermuteB>
231 {
233  ALayout,
234  BLayout,
235  DsLayout,
236  ELayout,
237  AsDataType,
238  BsDataType,
239  AccDataType,
240  CShuffleDataType,
241  DsDataType,
242  EDataType,
243  AElementwiseOperation,
244  BElementwiseOperation,
245  CDEElementwiseOperation,
246  GemmSpec,
247  BlockSize,
248  MPerBlock,
249  NPerBlock,
250  KPerBlock,
251  AK1Value,
252  BK1Value,
253  MPerWmma,
254  NPerWmma,
255  MRepeat,
256  NRepeat,
257  ABlockTransferThreadClusterLengths_AK0_M_AK1,
258  ABlockTransferThreadClusterArrangeOrder,
259  ABlockTransferSrcAccessOrder,
260  ABlockTransferSrcVectorDim,
261  ABlockTransferSrcScalarPerVector,
262  ABlockTransferDstScalarPerVector_AK1,
263  AThreadTransferSrcResetCoordinateAfterRun,
264  ABlockLdsExtraM,
265  BBlockTransferThreadClusterLengths_BK0_N_BK1,
266  BBlockTransferThreadClusterArrangeOrder,
267  BBlockTransferSrcAccessOrder,
268  BBlockTransferSrcVectorDim,
269  BBlockTransferSrcScalarPerVector,
270  BBlockTransferDstScalarPerVector_BK1,
271  BThreadTransferSrcResetCoordinateAfterRun,
272  BBlockLdsExtraN,
273  CShuffleMRepeatPerShuffle,
274  CShuffleNRepeatPerShuffle,
275  CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
276  CDEShuffleBlockTransferScalarPerVectors,
277  BlkGemmPipeSched,
278  BlkGemmPipelineVer,
279  ComputeTypeA,
280  ComputeTypeB,
281  PermuteA,
282  PermuteB>;
283 
284  using Base::I0;
285  using Base::I1;
286  using Base::I2;
287  using Base::I3;
288  using Base::I4;
289  using Base::I5;
290  using Base::I6;
291  using Base::I7;
292 
293  using Base::AK0Number;
294  using Base::AK1Number;
295  using Base::BK0Number;
296  using Base::BK1Number;
297 
298  using Base::APackedSize;
299  using Base::BPackedSize;
300 
304  using Base::CalculateKRead;
305  using Base::CalculateMBlock;
306  using Base::CalculateMPadded;
307  using Base::CalculateNBlock;
308  using Base::CalculateNPadded;
314
316 
318
319  using ThisThreadBlock = ThisThreadBlock<BlockSize>;
320
323 
324  using Base::NumATensor;
325  using Base::NumBTensor;
326  using Base::NumDTensor;
327  using typename Base::AsGridPointer;
328  using typename Base::BsGridPointer;
329  using typename Base::DsGridPointer;
330  using AsDataType_ = AsDataType;
331  using BsDataType_ = BsDataType;
332 
333  struct Problem
334  {
335  __host__ Problem(index_t M_,
336  index_t N_,
337  index_t K_,
338  std::array<index_t, NumATensor> StrideAs_,
339  std::array<index_t, NumBTensor> StrideBs_,
340  std::array<index_t, NumDTensor> StrideDs_,
341  index_t StrideE_,
342  index_t KBatch_)
343  : M{M_},
344  N{N_},
345  K{K_},
346  StrideAs{StrideAs_},
347  StrideBs{StrideBs_},
348  StrideDs{StrideDs_},
349  StrideE{StrideE_},
350  KBatch{KBatch_},
351  MPadded{CalculateMPadded(M_)},
352  NPadded{CalculateNPadded(N_)},
353  KRead{CalculateKRead(K_, KBatch_)},
354  KPadded{CalculateKPadded(K_, KBatch_)},
355  AK0{CalculateAK0Padded(K_, KBatch_)},
356  BK0{CalculateBK0Padded(K_, KBatch_)},
357  MBlock{CalculateMBlock(M_)},
358  NBlock{CalculateNBlock(N_)}
359  {
360  }
361 
362  __host__ void Print() const
363  {
364  std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
365  << "SAs: {";
366  static_for<0, NumATensor, 1>{}([&](auto i) {
367  std::cout << StrideAs[i] << (i.value < NumATensor - 1 ? ", " : "");
368  });
369  std::cout << "}, " << "SBs: {";
370  static_for<0, NumBTensor, 1>{}([&](auto i) {
371  std::cout << StrideBs[i] << (i.value < NumBTensor - 1 ? ", " : "");
372  });
373  std::cout << "}, ";
374  if constexpr(NumDTensor > 0)
375  {
376  std::cout << "SDs: { ";
377  static_for<0, NumDTensor, 1>{}([&](auto i) {
378  std::cout << StrideDs[i] << (i.value < NumDTensor - 1 ? ", " : "");
379  });
380  std::cout << " }, ";
381  }
382  std::cout << "SE:" << StrideE << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded
383  << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", "
384  << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock
385  << ", " << "NBlock: " << NBlock << "}" << std::endl;
386  }
387 
388  index_t M;
389  index_t N;
390  index_t K;
391  std::array<index_t, NumATensor> StrideAs;
392  std::array<index_t, NumBTensor> StrideBs;
393  std::array<index_t, NumDTensor> StrideDs;
394  index_t StrideE;
395  index_t KBatch;
396  index_t MPadded;
397  index_t NPadded;
398  index_t KRead;
399  index_t KPadded;
400  index_t AK0;
401  index_t BK0;
402  index_t MBlock;
403  index_t NBlock;
404  };
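
The derived quantities follow straightforward tile arithmetic. A standalone
sketch of what the constructor computes, assuming the usual ceil-division
padding CK applies per block tile (an illustrative re-implementation, not the
library code itself):

    #include <cstdio>

    int main()
    {
        const int M = 1000, MPerBlock = 128;
        const int MBlock  = (M + MPerBlock - 1) / MPerBlock; // cf. CalculateMBlock
        const int MPadded = MBlock * MPerBlock;              // cf. CalculateMPadded
        std::printf("M=%d -> MBlock=%d, MPadded=%d\n", M, MBlock, MPadded);
        return 0;
    }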
405 
406  // Argument
407  struct Argument : public Problem
408  {
409  __host__ Argument(std::array<const void*, NumATensor> p_as_grid_,
410  std::array<const void*, NumBTensor> p_bs_grid_,
411  std::array<const void*, NumDTensor> p_ds_grid_,
412  EDataType* p_e_grid_,
413  index_t M_,
414  index_t N_,
415  index_t K_,
416  std::array<index_t, NumATensor> StrideAs_,
417  std::array<index_t, NumBTensor> StrideBs_,
418  std::array<index_t, NumDTensor> StrideDs_,
419  index_t StrideE_,
420  index_t k_batch_,
421  AElementwiseOperation a_element_op_,
422  BElementwiseOperation b_element_op_,
423  CDEElementwiseOperation cde_element_op_,
424  bool is_reduce_ = false)
425  : Problem{M_, N_, K_, StrideAs_, StrideBs_, StrideDs_, StrideE_, k_batch_},
426  p_as_grid{},
427  p_bs_grid{},
428  p_ds_grid{},
429  p_e_grid{p_e_grid_},
430  a_element_op{a_element_op_},
431  b_element_op{b_element_op_},
432  cde_element_op{cde_element_op_},
433  is_reduce(is_reduce_)
434  {
435  // populate pointer, desc for As
436  static_for<0, NumATensor, 1>{}([&](auto i) {
437  using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
438 
439  // A pointer
440  p_as_grid(i) = static_cast<const ADataType_*>(p_as_grid_[i]);
441  });
442 
443  // populate pointer, desc for Bs
444  static_for<0, NumBTensor, 1>{}([&](auto i) {
445  using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
446 
447  // B pointer
448  p_bs_grid(i) = static_cast<const BDataType_*>(p_bs_grid_[i]);
449  });
450 
451  // populate pointer, desc for Ds
452  static_for<0, NumDTensor, 1>{}([&](auto i) {
453  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
454 
455  // D pointer
456  p_ds_grid(i) = static_cast<const DDataType*>(p_ds_grid_[i]);
457  });
458  }
459 
460  __host__ __device__ inline bool IsReduceAdd() const
461  {
462  return (Problem::KBatch > 1) && is_reduce;
463  }
464 
465  __host__ __device__ inline bool IsAtomicAdd() const
466  {
467  return (Problem::KBatch > 1) && (!is_reduce);
468  }
469 
470  AsGridPointer p_as_grid;
471  BsGridPointer p_bs_grid;
472  DsGridPointer p_ds_grid;
473  EDataType* p_e_grid;
474 
475  const AElementwiseOperation a_element_op;
476  const BElementwiseOperation b_element_op;
477  const CDEElementwiseOperation cde_element_op;
478 
479  // TODO: it can be used with SplitK+reduction but currently only used with SplitK+atomicAdd
480  bool is_reduce;
481  };
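
The constructor above recovers typed pointers from type-erased const void*
arrays, with the element-type tuple driving each cast. The same pattern in a
standalone, compilable form (float/double stand in for AsDataType, and the
recursive template plays the role of static_for):

    #include <array>
    #include <cstdio>
    #include <tuple>

    using AsTypes = std::tuple<float, double>; // stand-in for AsDataType
    constexpr int NumA = std::tuple_size_v<AsTypes>;

    template <int I = 0>
    void print_first(const std::array<const void*, NumA>& p_as)
    {
        if constexpr(I < NumA)
        {
            using T = std::tuple_element_t<I, AsTypes>;
            std::printf("A%d[0] = %f\n",
                        I, static_cast<double>(static_cast<const T*>(p_as[I])[0]));
            print_first<I + 1>(p_as);
        }
    }

    int main()
    {
        float  a0[2] = {1.5f, 2.5f};
        double a1[2] = {3.5, 4.5};
        print_first({a0, a1});
        return 0;
    }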
482 
483  struct SplitKBatchOffset
484  {
485 
486  __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
487  {
488  // Note: as in the xdl implementation, multiple A/B tensors share a
489  // single layout but may have distinct strides, so we create an array
490  // of offsets holding the same values.
491  // This should be fixed later on, once the thread transfer becomes
492  // more flexible.
493  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
494  {
495  static_for<0, NumATensor, 1>{}(
496  [&](auto i) { a_k_split_offset[i] = k_id * karg.KRead / APackedSize; });
497  }
498  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
499  {
500  static_for<0, NumATensor, 1>{}(
501  [&](auto i) { a_k_split_offset[i] = k_id * karg.KRead * karg.StrideAs[i]; });
502  }
503 
504  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
505  {
506  static_for<0, NumBTensor, 1>{}(
507  [&](auto i) { b_k_split_offset[i] = k_id * karg.KRead * karg.StrideBs[i]; });
508  }
509  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
510  {
511  if constexpr(!PermuteB)
512  {
513  static_for<0, NumBTensor, 1>{}(
514  [&](auto i) { b_k_split_offset[i] = k_id * karg.KRead / BPackedSize; });
515  }
516  else
517  {
518  const int k0_offset = karg.KRead * karg.N;
519  static_for<0, NumBTensor, 1>{}(
520  [&](auto i) { b_k_split_offset[i] = k_id * k0_offset / BPackedSize; });
521  }
522  }
523 
524  if(k_id < karg.KBatch - 1)
525  {
526  karg.K = karg.KRead;
527  }
528  else
529  {
530  karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
531  }
532 
533  if(karg.IsReduceAdd())
534  {
535  c_reduce_offset = k_id * karg.M * karg.N;
536  }
537  else
538  {
539  c_reduce_offset = 0;
540  }
541  }
542 
543  std::array<index_t, NumATensor> a_k_split_offset;
544  std::array<index_t, NumATensor> a_k_split_offset;
545  index_t c_reduce_offset;
546  };
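
The K bookkeeping above is easier to see in isolation: each of the KBatch
sub-kernels reads KRead elements along K (scaled by a stride or packing
factor, depending on layout), and the last batch consumes whatever remains. A
standalone sketch follows; the KRead formula is an assumption that mirrors
rounding up to KPerBlock, and the library's CalculateKRead may round
differently:

    #include <cstdio>

    int main()
    {
        const int K = 1000, KBatch = 3, KPerBlock = 64;
        // Assumed: split K into KBatch chunks, each rounded up to KPerBlock.
        const int KRead =
            ((K + KBatch * KPerBlock - 1) / (KBatch * KPerBlock)) * KPerBlock;
        for(int k_id = 0; k_id < KBatch; ++k_id)
        {
            const int k_this =
                (k_id < KBatch - 1) ? KRead : K - KRead * (KBatch - 1);
            std::printf("k_id=%d: offset=%d, K=%d\n", k_id, k_id * KRead, k_this);
        }
        return 0;
    }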
547 
548  using BlockwiseGemmPipe = typename Base::BlockwiseGemmPipe;
549
550  // return block_id to C matrix tile idx (m0, n0) mapping
551  // if arch = gfx942
553  // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
554 
555  __device__ static index_t GetKBlockPerScale() { return 1; }
556 
557  template <bool HasMainKBlockLoop,
558  InMemoryDataOperationEnum EGlobalMemoryDataOperation,
559  TailNumber TailNum>
560  __device__ static void Run(AsGridPointer& p_as_grid,
561  BsGridPointer& p_bs_grid,
562  DsGridPointer& p_ds_grid,
563  EDataType* p_e_grid,
564  void* p_shared,
565  const Problem& problem,
566  AElementwiseOperation a_element_op,
567  BElementwiseOperation b_element_op,
568  CDEElementwiseOperation cde_element_op)
569  {
570  const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
571  problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
572  const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(
573  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0);
574  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
575  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
576  const auto e_grid_desc_m_n = Base::template MakeDEGridDescriptor_M_N<ELayout>(
577  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideE);
578  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
579  MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
580  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
581  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
582  MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
583  e_grid_desc_m_n, problem.MBlock, problem.NBlock);
584 
585  // divide block work by [M, N]
586  const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
587 
588  const auto block_work_idx =
589  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
590 
591  if(!block_2_ctile_map.ValidCTileIndex(
592  block_work_idx,
593  make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
594  e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
595  {
596  return;
597  }
598 
599  const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
600  const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
601 
602  // BScale struct (Empty)
603  using BScale = typename BlockwiseGemmPipe::Empty;
604  auto b_scale_struct = BScale{};
605 
606  const index_t num_k_block_per_scale = GetKBlockPerScale();
607 
608  Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
609  decltype(bs_grid_desc_bk0_n_bk1),
610  decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
611  decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
612  decltype(b_scale_struct),
613  HasMainKBlockLoop,
614  EGlobalMemoryDataOperation,
615  TailNum>(p_as_grid,
616  p_bs_grid,
617  p_ds_grid,
618  p_e_grid,
619  p_shared,
620  as_grid_desc_ak0_m_ak1,
621  bs_grid_desc_bk0_n_bk1,
622  ds_grid_desc_mblock_mperblock_nblock_nperblock,
623  e_grid_desc_mblock_mperblock_nblock_nperblock,
624  a_element_op,
625  b_element_op,
626  cde_element_op,
627  block_m_id,
628  block_n_id,
629  num_k_block_per_scale,
630  b_scale_struct);
631  }
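
The block_2_ctile_map above turns the 1-D workgroup id into an output-tile
index (block_m_id, block_n_id), with out-of-range blocks exiting early. The
simplest such mapping is a plain row-major decomposition, shown standalone
below for intuition; CK's Block2CTileMap variants in block_to_ctile_map.hpp
additionally reorder blocks to improve L2 locality:

    #include <cstdio>

    int main()
    {
        const int MBlock = 4, NBlock = 3;
        for(int bid = 0; bid < MBlock * NBlock; ++bid)
        {
            const int block_m_id = bid / NBlock;
            const int block_n_id = bid % NBlock;
            std::printf("block %2d -> tile (m0=%d, n0=%d)\n",
                        bid, block_m_id, block_n_id);
        }
        return 0;
    }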
632 
633  // Wrapper function so that the __global__ entry point can be shared
634  // between gemm_universal, b_scale, ab_scale, etc.
635  template <bool HasMainKBlockLoop,
636  InMemoryDataOperationEnum EGlobalMemoryDataOperation,
637  TailNumber TailNum>
638  __device__ static void
639  Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg)
640  {
641  // shift A matrices pointer for splitk
642  AsGridPointer p_as_grid_splitk;
643  static_for<0, NumATensor, 1>{}([&](auto i) {
644  using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
645  p_as_grid_splitk(i) = static_cast<const ADataType_*>(karg.p_as_grid[i]) +
646  splitk_batch_offset.a_k_split_offset[i];
647  });
648 
649  // shift B matrices pointer for splitk
650  BsGridPointer p_bs_grid_splitk;
651  static_for<0, NumBTensor, 1>{}([&](auto i) {
652  using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
653  p_bs_grid_splitk(i) = static_cast<const BDataType_*>(karg.p_bs_grid[i]) +
654  splitk_batch_offset.b_k_split_offset[i];
655  });
656 
657  Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
658  p_as_grid_splitk,
659  p_bs_grid_splitk,
660  karg.p_ds_grid,
661  karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
662  p_shared,
663  karg,
664  karg.a_element_op,
665  karg.b_element_op,
666  karg.cde_element_op);
667  }
668 };
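
For reference, a hedged sketch of the __global__ entry point that the wrapper
comment above alludes to; the launcher name, the shared-memory query, and the
use of blockIdx.z as the K-batch id are assumptions modeled on CK's analogous
xdl kernels, not taken from this file:

    // Hypothetical kernel wrapper (illustrative only):
    template <typename GridwiseGemm,
              bool HasMainKBlockLoop,
              InMemoryDataOperationEnum EGlobalMemoryDataOperation,
              TailNumber TailNum>
    __global__ void kernel_gemm_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg)
    {
        __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

        // One K batch per z-slice of the launch grid.
        const auto splitk_batch_offset =
            typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);

        GridwiseGemm::template Run<HasMainKBlockLoop,
                                   EGlobalMemoryDataOperation,
                                   TailNum>(p_shared, splitk_batch_offset, karg);
    }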
669 
670 } // namespace ck