gridwise_gemm_wmma_cshuffle_v3.hpp Source File
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck/utility/env.hpp"

namespace ck {

// "Universal" GEMM kernel with SplitK support. It accepts multiple A, B, and D
// tensors, together with elementwise operations that could be applied on each
// tensor respectively. The CDE_op is an elementwise operation applied to the C
// and all D tensors.
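//
// For instance (illustrative only, using ops from ck::tensor_operation::element_wise):
//   A_op   = PassThrough            -> A is consumed as-is
//   B_op   = PassThrough            -> B is consumed as-is
//   CDE_op = Bilinear{alpha, beta}  -> E = alpha * C + beta * D0, elementwise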
template <typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          typename AsDataType,
          typename BsDataType,
          typename AccDataType,
          typename CShuffleDataType,
          typename DsDataType,
          typename EDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          tensor_operation::device::GemmSpecialization GemmSpec,
          index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t AK1Value,
          index_t BK1Value,
          index_t MPerWmma,
          index_t NPerWmma,
          index_t MRepeat,
          index_t NRepeat,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          index_t ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          index_t BBlockLdsExtraN,
          index_t CShuffleMRepeatPerShuffle,
          index_t CShuffleNRepeatPerShuffle,
          typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          typename CDEShuffleBlockTransferScalarPerVectors,
          BlockGemmPipelineScheduler BlkGemmPipeSched,
          BlockGemmPipelineVersion BlkGemmPipelineVer,
          typename ComputeTypeA,
          typename ComputeTypeB,
          bool PermuteA,
          bool PermuteB,
          bool IsBPreShuffled = false,
          bool ForceThreadTileTransfer = false>
struct GridwiseGemm_wmma_cshuffle_v3
    : GridwiseGemm_wmma_cshuffle_v3_base<
          ALayout,
          BLayout,
          DsLayout,
          ELayout,
          AsDataType,
          BsDataType,
          AccDataType,
          CShuffleDataType,
          DsDataType,
          EDataType,
          AElementwiseOperation,
          BElementwiseOperation,
          CDEElementwiseOperation,
          GemmSpec,
          BlockSize,
          MPerBlock,
          NPerBlock,
          KPerBlock,
          AK1Value,
          BK1Value,
          MPerWmma,
          NPerWmma,
          MRepeat,
          NRepeat,
          ABlockTransferThreadClusterLengths_AK0_M_AK1,
          ABlockTransferThreadClusterArrangeOrder,
          ABlockTransferSrcAccessOrder,
          ABlockTransferSrcVectorDim,
          ABlockTransferSrcScalarPerVector,
          ABlockTransferDstScalarPerVector_AK1,
          AThreadTransferSrcResetCoordinateAfterRun,
          ABlockLdsExtraM,
          BBlockTransferThreadClusterLengths_BK0_N_BK1,
          BBlockTransferThreadClusterArrangeOrder,
          BBlockTransferSrcAccessOrder,
          BBlockTransferSrcVectorDim,
          BBlockTransferSrcScalarPerVector,
          BBlockTransferDstScalarPerVector_BK1,
          BThreadTransferSrcResetCoordinateAfterRun,
          BBlockLdsExtraN,
          CShuffleMRepeatPerShuffle,
          CShuffleNRepeatPerShuffle,
          CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          CDEShuffleBlockTransferScalarPerVectors,
          BlkGemmPipeSched,
          BlkGemmPipelineVer,
          ComputeTypeA,
          ComputeTypeB,
          PermuteA,
          PermuteB,
          IsBPreShuffled,
          ForceThreadTileTransfer>
{
    using Base = GridwiseGemm_wmma_cshuffle_v3_base<
        ALayout,
        BLayout,
        DsLayout,
        ELayout,
        AsDataType,
        BsDataType,
        AccDataType,
        CShuffleDataType,
        DsDataType,
        EDataType,
        AElementwiseOperation,
        BElementwiseOperation,
        CDEElementwiseOperation,
        GemmSpec,
        BlockSize,
        MPerBlock,
        NPerBlock,
        KPerBlock,
        AK1Value,
        BK1Value,
        MPerWmma,
        NPerWmma,
        MRepeat,
        NRepeat,
        ABlockTransferThreadClusterLengths_AK0_M_AK1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        AThreadTransferSrcResetCoordinateAfterRun,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_BK0_N_BK1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_BK1,
        BThreadTransferSrcResetCoordinateAfterRun,
        BBlockLdsExtraN,
        CShuffleMRepeatPerShuffle,
        CShuffleNRepeatPerShuffle,
        CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CDEShuffleBlockTransferScalarPerVectors,
        BlkGemmPipeSched,
        BlkGemmPipelineVer,
        ComputeTypeA,
        ComputeTypeB,
        PermuteA,
        PermuteB,
        IsBPreShuffled,
        ForceThreadTileTransfer>;

    using Base::I0;
    using Base::I1;
    using Base::I2;
    using Base::I3;
    using Base::I4;
    using Base::I5;
    using Base::I6;
    using Base::I7;

    using Base::AK0Number;
    using Base::AK1Number;
    using Base::BK0Number;
    using Base::BK1Number;

    using Base::APackedSize;
    using Base::BPackedSize;

    using Base::CalculateAK0Padded;
    using Base::CalculateBK0Padded;
    using Base::CalculateKPadded;
    using Base::CalculateKRead;
    using Base::CalculateMBlock;
    using Base::CalculateMPadded;
    using Base::CalculateNBlock;
    using Base::CalculateNPadded;

    using Base::MakeAsGridDescriptor_AK0_M_AK1;
    using Base::MakeBsGridDescriptor_BK0_N_BK1;
    using Base::MakeDsGridDescriptor_M_N;

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    using Base::NumATensor;
    using Base::NumBTensor;
    using Base::NumDTensor;
    using typename Base::AsGridPointer;
    using typename Base::BsGridPointer;
    using typename Base::DsGridPointer;
    using AsDataType_ = AsDataType;
    using BsDataType_ = BsDataType;

    struct Problem
    {
        __host__ Problem(index_t M_,
                         index_t N_,
                         index_t K_,
                         std::array<index_t, NumATensor> StrideAs_,
                         std::array<index_t, NumBTensor> StrideBs_,
                         std::array<index_t, NumDTensor> StrideDs_,
                         index_t StrideE_,
                         index_t KBatch_)
            : M{M_},
              N{N_},
              K{K_},
              StrideAs{StrideAs_},
              StrideBs{StrideBs_},
              StrideDs{StrideDs_},
              StrideE{StrideE_},
              KBatch{KBatch_},
              MPadded{CalculateMPadded(M_)},
              NPadded{CalculateNPadded(N_)},
              KRead{CalculateKRead(K_, KBatch_)},
              KPadded{CalculateKPadded(K_, KBatch_)},
              AK0{CalculateAK0Padded(K_, KBatch_)},
              BK0{CalculateBK0Padded(K_, KBatch_)},
              MBlock{CalculateMBlock(M_)},
              NBlock{CalculateNBlock(N_)},
              Kt{K_}
        {
        }

        __host__ void Print() const
        {
            std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K
                      << ", " << "SAs: {";
            static_for<0, NumATensor, 1>{}([&](auto i) {
                std::cout << StrideAs[i] << (i.value < NumATensor - 1 ? ", " : "");
            });
            std::cout << "}, " << "SBs: {";
            static_for<0, NumBTensor, 1>{}([&](auto i) {
                std::cout << StrideBs[i] << (i.value < NumBTensor - 1 ? ", " : "");
            });
            std::cout << "}, ";
            if constexpr(NumDTensor > 0)
            {
                std::cout << "SDs: { ";
                static_for<0, NumDTensor, 1>{}([&](auto i) {
                    std::cout << StrideDs[i] << (i.value < NumDTensor - 1 ? ", " : "");
                });
                std::cout << " }, ";
            }
            std::cout << "SE:" << StrideE << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded
                      << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", "
                      << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock
                      << ", " << "NBlock: " << NBlock << "}" << std::endl;
        }

        index_t M;
        index_t N;
        index_t K;
        std::array<index_t, NumATensor> StrideAs;
        std::array<index_t, NumBTensor> StrideBs;
        std::array<index_t, NumDTensor> StrideDs;
        index_t StrideE;
        index_t KBatch;
        index_t MPadded;
        index_t NPadded;
        index_t KRead;
        index_t KPadded;
        index_t AK0;
        index_t BK0;
        index_t MBlock;
        index_t NBlock;
        index_t Kt; // total (unsplit) K; used for pre-shuffled B, see Run()
    };
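
    // Worked example (assuming the usual ceiling-division block counts from the
    // common base, i.e. MBlock ~ ceil(M / MPerBlock), NBlock ~ ceil(N / NPerBlock),
    // and KRead ~ K split into KBatch chunks at KPerBlock granularity):
    //   M = 512, N = 768, MPerBlock = 128, NPerBlock = 128 -> MBlock = 4, NBlock = 6
    //   K = 4096, KBatch = 4, KPerBlock = 64               -> KRead = 1024 per k-batch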
408 
409  // Argument
411  {
        __host__ Argument(std::array<const void*, NumATensor> p_as_grid_,
                          std::array<const void*, NumBTensor> p_bs_grid_,
                          std::array<const void*, NumDTensor> p_ds_grid_,
                          EDataType* p_e_grid_,
                          index_t M_,
                          index_t N_,
                          index_t K_,
                          std::array<index_t, NumATensor> StrideAs_,
                          std::array<index_t, NumBTensor> StrideBs_,
                          std::array<index_t, NumDTensor> StrideDs_,
                          index_t StrideE_,
                          index_t k_batch_,
                          AElementwiseOperation a_element_op_,
                          BElementwiseOperation b_element_op_,
                          CDEElementwiseOperation cde_element_op_,
                          bool is_reduce_ = false)
            : Problem{M_, N_, K_, StrideAs_, StrideBs_, StrideDs_, StrideE_, k_batch_},
              p_as_grid{},
              p_bs_grid{},
              p_ds_grid{},
              p_e_grid{p_e_grid_},
              a_element_op{a_element_op_},
              b_element_op{b_element_op_},
              cde_element_op{cde_element_op_},
              is_reduce(is_reduce_)
        {
            // populate pointer, desc for As
            static_for<0, NumATensor, 1>{}([&](auto i) {
                using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;

                // A pointer
                p_as_grid(i) = static_cast<const ADataType_*>(p_as_grid_[i]);
            });

            // populate pointer, desc for Bs
            static_for<0, NumBTensor, 1>{}([&](auto i) {
                using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;

                // B pointer
                p_bs_grid(i) = static_cast<const BDataType_*>(p_bs_grid_[i]);
            });

            // populate pointer, desc for Ds
            static_for<0, NumDTensor, 1>{}([&](auto i) {
                using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;

                // D pointer
                p_ds_grid(i) = static_cast<const DDataType*>(p_ds_grid_[i]);
            });
        }

        __host__ __device__ inline bool IsReduceAdd() const
        {
            return (Problem::KBatch > 1) && is_reduce;
        }

        __host__ __device__ inline bool IsAtomicAdd() const
        {
            return (Problem::KBatch > 1) && (!is_reduce);
        }

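        // Split-K combining: with KBatch > 1 each k-batch computes a partial GEMM
        // over its slice of K. Partials are combined either via a KBatch-strided
        // workspace plus a separate reduction pass (IsReduceAdd) or, by default,
        // with atomicAdd directly on E (IsAtomicAdd).
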
        AsGridPointer p_as_grid;
        BsGridPointer p_bs_grid;
        DsGridPointer p_ds_grid;
        EDataType* p_e_grid;

        AElementwiseOperation a_element_op;
        BElementwiseOperation b_element_op;
        CDEElementwiseOperation cde_element_op;

        // TODO: it can be used with SplitK+reduction but currently only used with SplitK+atomicAdd
        bool is_reduce;
    };

    struct SplitKBatchOffset
    {
        __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
        {
            // Note: as in the xdl implementation, multiple A/B tensors share one
            // layout but may have different strides, so we create an array of
            // offsets holding the same values. This should be fixed once the
            // thread transfer becomes more flexible.
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            {
                static_for<0, NumATensor, 1>{}(
                    [&](auto i) { a_k_split_offset[i] = k_id * karg.KRead / APackedSize; });
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            {
                static_for<0, NumATensor, 1>{}(
                    [&](auto i) { a_k_split_offset[i] = k_id * karg.KRead * karg.StrideAs[i]; });
            }

            if constexpr(IsBPreShuffled)
            {
                static_for<0, NumBTensor, 1>{}([&](auto i) { b_k_split_offset[i] = 0; });
            }
            else
            {
                if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
                {
                    static_for<0, NumBTensor, 1>{}([&](auto i) {
                        b_k_split_offset[i] = k_id * karg.KRead * karg.StrideBs[i];
                    });
                }
                else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
                {
                    if constexpr(!PermuteB)
                    {
                        static_for<0, NumBTensor, 1>{}(
                            [&](auto i) { b_k_split_offset[i] = k_id * karg.KRead / BPackedSize; });
                    }
                    else
                    {
                        const int k0_offset = karg.KRead * karg.N;
                        static_for<0, NumBTensor, 1>{}(
                            [&](auto i) { b_k_split_offset[i] = k_id * k0_offset / BPackedSize; });
                    }
                }
            }

            // Every k-batch but the last reads KRead elements of K; the last one
            // takes whatever remains.
            if(k_id < karg.KBatch - 1)
            {
                karg.K = karg.KRead;
            }
            else
            {
                karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
            }

            if(karg.IsReduceAdd())
            {
                c_reduce_offset = k_id * karg.M * karg.N;
            }
            else
            {
                c_reduce_offset = 0;
            }
        }

        std::array<index_t, NumATensor> a_k_split_offset;
        std::array<index_t, NumBTensor> b_k_split_offset;
        index_t c_reduce_offset;
    };
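
    // Worked example for row-major A: K = 4096, KBatch = 4 gives KRead = 1024, so
    // k-batch k_id starts at element k_id * 1024 of each A row (divided by
    // APackedSize for sub-byte packed types). For column-major A the K dimension
    // is the strided axis, hence the extra StrideAs[i] factor.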

    using BlockwiseGemmPipe = typename Base::BlockwiseGemmPipe;

    // return block_id to C matrix tile idx (m0, n0) mapping
    using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
    // if arch = gfx942:
    // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;

    __device__ static index_t GetKBlockPerScale() { return 1; }

    template <bool HasMainKBlockLoop,
              InMemoryDataOperationEnum EGlobalMemoryDataOperation,
              TailNumber TailNum,
              typename Block2CTileMap,
              typename EpilogueArgument,
              int BlockMapMBlockIndex = 0,
              int BlockMapNBlockIndex = 1>
    __device__ static void Run(AsGridPointer& p_as_grid,
                               BsGridPointer& p_bs_grid,
                               DsGridPointer& p_ds_grid,
                               EDataType* p_e_grid,
                               void* p_shared,
                               const Problem& problem,
                               const Block2CTileMap& block_2_ctile_map,
                               AElementwiseOperation a_element_op,
                               BElementwiseOperation b_element_op,
                               CDEElementwiseOperation cde_element_op,
                               EpilogueArgument& epilogue_args,
                               const index_t k_id = 0)
    {
        const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
            problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
        // pre-shuffled B descriptors are built from the total (unsplit) K
        const index_t K_b = IsBPreShuffled ? problem.Kt : problem.K;
        const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(
            K_b, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0);
        const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
            problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
        const auto e_grid_desc_m_n = Base::template MakeDEGridDescriptor_M_N<ELayout>(
            problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideE);
        const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
            Base::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
        const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
            Base::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                e_grid_desc_m_n, problem.MBlock, problem.NBlock);

        const auto block_work_idx =
            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        if(!block_2_ctile_map.ValidCTileIndex(
               block_work_idx,
               make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
                          e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
        {
            return;
        }

        // readfirstlane keeps the tile indices wavefront-uniform (scalar registers)
        const index_t block_m_id =
            __builtin_amdgcn_readfirstlane(block_work_idx[Number<BlockMapMBlockIndex>{}]);
        const index_t block_n_id =
            __builtin_amdgcn_readfirstlane(block_work_idx[Number<BlockMapNBlockIndex>{}]);

        // BScale struct (Empty)
        using BScale = typename BlockwiseGemmPipe::Empty;
        auto b_scale_struct = BScale{};

        const index_t num_k_block_per_scale = GetKBlockPerScale();

        Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
                           decltype(bs_grid_desc_bk0_n_bk1),
                           decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
                           decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
                           decltype(b_scale_struct),
                           decltype(epilogue_args),
                           HasMainKBlockLoop,
                           EGlobalMemoryDataOperation,
                           TailNum>(p_as_grid,
                                    p_bs_grid,
                                    p_ds_grid,
                                    p_e_grid,
                                    p_shared,
                                    as_grid_desc_ak0_m_ak1,
                                    bs_grid_desc_bk0_n_bk1,
                                    ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                    e_grid_desc_mblock_mperblock_nblock_nperblock,
                                    a_element_op,
                                    b_element_op,
                                    cde_element_op,
                                    block_m_id,
                                    block_n_id,
                                    num_k_block_per_scale,
                                    b_scale_struct,
                                    epilogue_args,
                                    k_id);
    }
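
    // Three more Run overloads follow: one that supplies the default
    // block-to-ctile map, and two wrappers that take Argument + SplitKBatchOffset
    // so a single __global__ entry point can be shared between gemm_universal,
    // b_scale, ab_scale, etc.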

    template <bool HasMainKBlockLoop,
              InMemoryDataOperationEnum EGlobalMemoryDataOperation,
              TailNumber TailNum,
              typename EpilogueArgument>
    __device__ static void Run(AsGridPointer& p_as_grid,
                               BsGridPointer& p_bs_grid,
                               DsGridPointer& p_ds_grid,
                               EDataType* p_e_grid,
                               void* p_shared,
                               const Problem& problem,
                               AElementwiseOperation a_element_op,
                               BElementwiseOperation b_element_op,
                               CDEElementwiseOperation cde_element_op,
                               EpilogueArgument& epilogue_args)
    {
        Run<HasMainKBlockLoop,
            EGlobalMemoryDataOperation,
            TailNum,
            Block2CTileMap,
            EpilogueArgument>(p_as_grid,
                              p_bs_grid,
                              p_ds_grid,
                              p_e_grid,
                              p_shared,
                              problem,
                              DefaultBlock2CTileMap(problem),
                              a_element_op,
                              b_element_op,
                              cde_element_op,
                              epilogue_args);
    }

    // Wrapper function to have __global__ function in common
    // between gemm_universal, b_scale, ab_scale, etc.
    template <bool HasMainKBlockLoop,
              InMemoryDataOperationEnum EGlobalMemoryDataOperation,
              TailNumber TailNum,
              typename Block2CTileMap,
              typename EpilogueArgument,
              int BlockMapMBlockIndex = 0,
              int BlockMapNBlockIndex = 1>
    __device__ static void Run(void* p_shared,
                               const SplitKBatchOffset& splitk_batch_offset,
                               Argument& karg,
                               const Block2CTileMap& block_2_ctile_map,
                               EpilogueArgument& epilogue_args,
                               const index_t k_id = 0)
    {
        // shift A matrices pointer for splitk
        AsGridPointer p_as_grid_splitk;
        static_for<0, NumATensor, 1>{}([&](auto i) {
            using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
            p_as_grid_splitk(i) = static_cast<const ADataType_*>(karg.p_as_grid[i]) +
                                  splitk_batch_offset.a_k_split_offset[i];
        });

        // shift B matrices pointer for splitk
        BsGridPointer p_bs_grid_splitk;
        static_for<0, NumBTensor, 1>{}([&](auto i) {
            using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
            p_bs_grid_splitk(i) = static_cast<const BDataType_*>(karg.p_bs_grid[i]) +
                                  splitk_batch_offset.b_k_split_offset[i];
        });

        Run<HasMainKBlockLoop,
            EGlobalMemoryDataOperation,
            TailNum,
            Block2CTileMap,
            EpilogueArgument,
            BlockMapMBlockIndex,
            BlockMapNBlockIndex>(p_as_grid_splitk,
                                 p_bs_grid_splitk,
                                 karg.p_ds_grid,
                                 karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
                                 p_shared,
                                 karg,
                                 block_2_ctile_map,
                                 karg.a_element_op,
                                 karg.b_element_op,
                                 karg.cde_element_op,
                                 epilogue_args,
                                 k_id);
    }

    // Wrapper function to have __global__ function in common
    // between gemm_universal, b_scale, ab_scale, etc.
    template <bool HasMainKBlockLoop,
              InMemoryDataOperationEnum EGlobalMemoryDataOperation,
              TailNumber TailNum,
              typename EpilogueArgument>
    __device__ static void Run(void* p_shared,
                               const SplitKBatchOffset& splitk_batch_offset,
                               Argument& karg,
                               EpilogueArgument& epilogue_args,
                               const index_t k_id = 0)
    {
        Run<HasMainKBlockLoop,
            EGlobalMemoryDataOperation,
            TailNum,
            Block2CTileMap,
            EpilogueArgument>(
            p_shared, splitk_batch_offset, karg, DefaultBlock2CTileMap(karg), epilogue_args, k_id);
    }

    __device__ static auto DefaultBlock2CTileMap(const Problem& problem)
    {
        return Block2CTileMap{problem.M, problem.N, 4};
    }
};

} // namespace ck
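
For orientation, below is a minimal standalone sketch of the split-K bookkeeping
implemented by Problem and SplitKBatchOffset above. This is illustrative C++ and
not part of the CK API: calculate_k_read is a hypothetical stand-in for
Base::CalculateKRead (defined in gridwise_gemm_wmma_cshuffle_v3_common.hpp, not
shown in this listing), assuming K is split into KBatch chunks rounded up to
KPerBlock granularity.

#include <cstdint>
#include <iostream>

using index_t = std::int32_t;

// Hypothetical mirror of Base::CalculateKRead: split K into k_batch chunks,
// each rounded up to k_per_block granularity.
index_t calculate_k_read(index_t K, index_t k_batch, index_t k_per_block)
{
    const index_t k_per_batch = (K + k_batch - 1) / k_batch;
    return ((k_per_batch + k_per_block - 1) / k_per_block) * k_per_block;
}

int main()
{
    const index_t K = 4100, k_batch = 4, k_per_block = 64;
    const index_t k_read = calculate_k_read(K, k_batch, k_per_block); // 1088

    for(index_t k_id = 0; k_id < k_batch; ++k_id)
    {
        // Offset into row-major A (K contiguous), as in SplitKBatchOffset.
        const index_t a_offset = k_id * k_read;
        // Every k-batch but the last reads k_read elements; the last takes the
        // rest, mirroring the karg.K adjustment in the SplitKBatchOffset ctor.
        const index_t k_this = (k_id < k_batch - 1) ? k_read : K - k_read * (k_batch - 1);
        std::cout << "k_id " << k_id << ": A offset " << a_offset << ", K " << k_this << '\n';
    }
    return 0;
}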