Composable Kernel: gridwise_gemm_wmma_cshuffle_v3.hpp Source File

gridwise_gemm_wmma_cshuffle_v3.hpp
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck/utility/env.hpp"
19 
20 namespace ck {
21 
28 // operations that could be applied on each tensor respectively. The CDE_op is an
29 // elementwise operation applied to the C and all D tensors.
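For orientation, the sketch below shows what a CDE-style elementwise operation of this shape can look like: a functor that combines the GEMM result C with every D tensor to produce E. The functor name and the alpha/bias math are illustrative assumptions, not the operators shipped with Composable Kernel.

// Illustrative only: a hand-rolled CDE-style functor (E = alpha * C + D0 + D1),
// not one of the CK-provided elementwise operators.
#include <cstdio>

struct AddBiasResidual
{
    float alpha = 1.0f;

    // Applied element by element to C and all D tensors to produce E.
    void operator()(float& e, const float& c, const float& d0, const float& d1) const
    {
        e = alpha * c + d0 + d1;
    }
};

int main()
{
    AddBiasResidual op{0.5f};
    float e = 0.0f;
    op(e, /*c=*/4.0f, /*d0=*/1.0f, /*d1=*/2.0f); // e = 0.5*4 + 1 + 2 = 5
    std::printf("%f\n", e);
}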
129 template <typename ALayout,
130  typename BLayout,
131  typename DsLayout,
132  typename ELayout,
133  typename AsDataType,
134  typename BsDataType,
135  typename AccDataType,
136  typename CShuffleDataType,
137  typename DsDataType,
138  typename EDataType,
139  typename AElementwiseOperation,
140  typename BElementwiseOperation,
141  typename CDEElementwiseOperation,
143  index_t BlockSize,
144  index_t MPerBlock,
145  index_t NPerBlock,
146  index_t KPerBlock,
147  index_t AK1Value,
148  index_t BK1Value,
149  index_t MPerWmma,
150  index_t NPerWmma,
151  index_t MRepeat,
152  index_t NRepeat,
153  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
154  typename ABlockTransferThreadClusterArrangeOrder,
155  typename ABlockTransferSrcAccessOrder,
156  index_t ABlockTransferSrcVectorDim,
157  index_t ABlockTransferSrcScalarPerVector,
158  index_t ABlockTransferDstScalarPerVector_AK1,
159  bool AThreadTransferSrcResetCoordinateAfterRun,
160  index_t ABlockLdsExtraM,
161  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
162  typename BBlockTransferThreadClusterArrangeOrder,
163  typename BBlockTransferSrcAccessOrder,
164  index_t BBlockTransferSrcVectorDim,
165  index_t BBlockTransferSrcScalarPerVector,
166  index_t BBlockTransferDstScalarPerVector_BK1,
167  bool BThreadTransferSrcResetCoordinateAfterRun,
168  index_t BBlockLdsExtraN,
169  index_t CShuffleMRepeatPerShuffle,
170  index_t CShuffleNRepeatPerShuffle,
171  typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
172  typename CDEShuffleBlockTransferScalarPerVectors,
173  BlockGemmPipelineScheduler BlkGemmPipeSched,
174  BlockGemmPipelineVersion BlkGemmPipelineVer,
175  typename ComputeTypeA,
176  typename ComputeTypeB,
177  bool PermuteA,
178  bool PermuteB,
179  bool IsBPreShuffled = false,
180  bool ForceThreadTileTransfer = false>
183  ALayout,
184  BLayout,
185  DsLayout,
186  ELayout,
187  AsDataType,
188  BsDataType,
189  AccDataType,
190  CShuffleDataType,
191  DsDataType,
192  EDataType,
193  AElementwiseOperation,
194  BElementwiseOperation,
195  CDEElementwiseOperation,
196  GemmSpec,
197  BlockSize,
198  MPerBlock,
199  NPerBlock,
200  KPerBlock,
201  AK1Value,
202  BK1Value,
203  MPerWmma,
204  NPerWmma,
205  MRepeat,
206  NRepeat,
207  ABlockTransferThreadClusterLengths_AK0_M_AK1,
208  ABlockTransferThreadClusterArrangeOrder,
209  ABlockTransferSrcAccessOrder,
210  ABlockTransferSrcVectorDim,
211  ABlockTransferSrcScalarPerVector,
212  ABlockTransferDstScalarPerVector_AK1,
213  AThreadTransferSrcResetCoordinateAfterRun,
214  ABlockLdsExtraM,
215  BBlockTransferThreadClusterLengths_BK0_N_BK1,
216  BBlockTransferThreadClusterArrangeOrder,
217  BBlockTransferSrcAccessOrder,
218  BBlockTransferSrcVectorDim,
219  BBlockTransferSrcScalarPerVector,
220  BBlockTransferDstScalarPerVector_BK1,
221  BThreadTransferSrcResetCoordinateAfterRun,
222  BBlockLdsExtraN,
223  CShuffleMRepeatPerShuffle,
224  CShuffleNRepeatPerShuffle,
225  CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
226  CDEShuffleBlockTransferScalarPerVectors,
227  BlkGemmPipeSched,
228  BlkGemmPipelineVer,
229  ComputeTypeA,
230  ComputeTypeB,
231  PermuteA,
232  PermuteB,
233  IsBPreShuffled,
234  ForceThreadTileTransfer>
235 {
237  ALayout,
238  BLayout,
239  DsLayout,
240  ELayout,
241  AsDataType,
242  BsDataType,
243  AccDataType,
244  CShuffleDataType,
245  DsDataType,
246  EDataType,
247  AElementwiseOperation,
248  BElementwiseOperation,
249  CDEElementwiseOperation,
250  GemmSpec,
251  BlockSize,
252  MPerBlock,
253  NPerBlock,
254  KPerBlock,
255  AK1Value,
256  BK1Value,
257  MPerWmma,
258  NPerWmma,
259  MRepeat,
260  NRepeat,
261  ABlockTransferThreadClusterLengths_AK0_M_AK1,
262  ABlockTransferThreadClusterArrangeOrder,
263  ABlockTransferSrcAccessOrder,
264  ABlockTransferSrcVectorDim,
265  ABlockTransferSrcScalarPerVector,
266  ABlockTransferDstScalarPerVector_AK1,
267  AThreadTransferSrcResetCoordinateAfterRun,
268  ABlockLdsExtraM,
269  BBlockTransferThreadClusterLengths_BK0_N_BK1,
270  BBlockTransferThreadClusterArrangeOrder,
271  BBlockTransferSrcAccessOrder,
272  BBlockTransferSrcVectorDim,
273  BBlockTransferSrcScalarPerVector,
274  BBlockTransferDstScalarPerVector_BK1,
275  BThreadTransferSrcResetCoordinateAfterRun,
276  BBlockLdsExtraN,
277  CShuffleMRepeatPerShuffle,
278  CShuffleNRepeatPerShuffle,
279  CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
280  CDEShuffleBlockTransferScalarPerVectors,
281  BlkGemmPipeSched,
282  BlkGemmPipelineVer,
283  ComputeTypeA,
284  ComputeTypeB,
285  PermuteA,
286  PermuteB,
287  IsBPreShuffled,
288  ForceThreadTileTransfer>;
289 
290  using Base::I0;
291  using Base::I1;
292  using Base::I2;
293  using Base::I3;
294  using Base::I4;
295  using Base::I5;
296  using Base::I6;
297  using Base::I7;
298 
299  using Base::AK0Number;
300  using Base::AK1Number;
301  using Base::BK0Number;
302  using Base::BK1Number;
303 
304  using Base::APackedSize;
305  using Base::BPackedSize;
306 
310  using Base::CalculateKRead;
311  using Base::CalculateMBlock;
313  using Base::CalculateNBlock;
320 
322 
323  using ThisThreadBlock = ThisThreadBlock<BlockSize>;
324 
325  using Base::NumATensor;
326  using Base::NumBTensor;
327  using Base::NumDTensor;
328  using typename Base::AsGridPointer;
329  using typename Base::BsGridPointer;
330  using typename Base::DsGridPointer;
331  using AsDataType_ = AsDataType;
332  using BsDataType_ = BsDataType;
333 
334  struct Problem
335  {
336  __host__ Problem() = default;
337  __host__ Problem(index_t M_,
338  index_t N_,
339  index_t K_,
340  std::array<index_t, NumATensor> StrideAs_,
341  std::array<index_t, NumBTensor> StrideBs_,
342  std::array<index_t, NumDTensor> StrideDs_,
343  index_t StrideE_,
344  index_t KBatch_)
345  : M{M_},
346  N{N_},
347  K{K_},
348  StrideAs{StrideAs_},
349  StrideBs{StrideBs_},
350  StrideDs{StrideDs_},
351  StrideE{StrideE_},
352  KBatch{KBatch_},
353  MPadded{CalculateMPadded(M_)},
354  NPadded{CalculateNPadded(N_)},
355  KRead{CalculateKRead(K_, KBatch_)},
356  KPadded{CalculateKPadded(K_, KBatch_)},
357  AK0{CalculateAK0Padded(K_, KBatch_)},
358  BK0{CalculateBK0Padded(K_, KBatch_)},
359  MBlock{CalculateMBlock(M_)},
360  NBlock{CalculateNBlock(N_)},
361  Kt{K_}
362  {
363  }
364 
365  __host__ void Print() const
366  {
367  std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
368  << "SAs: {";
369  static_for<0, NumATensor, 1>{}([&](auto i) {
370  std::cout << StrideAs[i] << (i.value < NumATensor - 1 ? ", " : "");
371  });
372  std::cout << "}, " << "SBs: {";
373  static_for<0, NumBTensor, 1>{}([&](auto i) {
374  std::cout << StrideBs[i] << (i.value < NumBTensor - 1 ? ", " : "");
375  });
376  std::cout << "}, ";
377  if constexpr(NumDTensor > 0)
378  {
379  std::cout << "SDs: { ";
380  static_for<0, NumDTensor, 1>{}([&](auto i) {
381  std::cout << StrideDs[i] << (i.value < NumDTensor - 1 ? ", " : "");
382  });
383  std::cout << " }, ";
384  }
385  std::cout << "SE:" << StrideE << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded
386  << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", "
387  << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock
388  << ", " << "NBlock: " << NBlock << "}" << std::endl;
389  }
390 
391  index_t M;
392  index_t N;
393  index_t K;
394  std::array<index_t, NumATensor> StrideAs;
395  std::array<index_t, NumBTensor> StrideBs;
396  std::array<index_t, NumDTensor> StrideDs;
397  index_t StrideE;
398  index_t KBatch;
399  index_t MPadded;
400  index_t NPadded;
401  index_t KRead;
402  index_t KPadded;
403  index_t AK0;
404  index_t BK0;
405  index_t MBlock;
406  index_t NBlock;
407  index_t Kt;
408  };
409 
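The derived Problem fields (MPadded, NPadded, MBlock, NBlock, KRead, and so on) come from the Calculate* helpers in the common base header. The host-only sketch below shows the kind of arithmetic involved, assuming plain ceiling division; it is not a restatement of the actual CalculateMPadded/CalculateKRead implementations.

// Rough sketch of the tile/padding arithmetic (assumed ceiling division;
// the real Calculate* helpers live in gridwise_gemm_wmma_cshuffle_v3_common.hpp).
#include <cstdio>

int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

int main()
{
    const int M = 1000, K = 300, MPerBlock = 128, KPerBlock = 64, KBatch = 2;

    const int MBlock  = integer_divide_ceil(M, MPerBlock);          // number of tiles along M
    const int MPadded = MBlock * MPerBlock;                         // M rounded up to a tile multiple
    const int KRead   = integer_divide_ceil(K, KPerBlock * KBatch) * KPerBlock; // K per split-K batch

    std::printf("MBlock=%d MPadded=%d KRead=%d\n", MBlock, MPadded, KRead); // 8 1024 192
}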
410  // Argument
411  struct Argument : public Problem
412  {
413  __host__ Argument() = default;
414  __host__ Argument(std::array<const void*, NumATensor> p_as_grid_,
415  std::array<const void*, NumBTensor> p_bs_grid_,
416  std::array<const void*, NumDTensor> p_ds_grid_,
417  EDataType* p_e_grid_,
418  index_t M_,
419  index_t N_,
420  index_t K_,
421  std::array<index_t, NumATensor> StrideAs_,
422  std::array<index_t, NumBTensor> StrideBs_,
423  std::array<index_t, NumDTensor> StrideDs_,
424  index_t StrideE_,
425  index_t k_batch_,
426  AElementwiseOperation a_element_op_,
427  BElementwiseOperation b_element_op_,
428  CDEElementwiseOperation cde_element_op_,
429  bool is_reduce_ = false)
430  : Problem{M_, N_, K_, StrideAs_, StrideBs_, StrideDs_, StrideE_, k_batch_},
431  p_as_grid{},
432  p_bs_grid{},
433  p_ds_grid{},
434  p_e_grid{p_e_grid_},
435  a_element_op{a_element_op_},
436  b_element_op{b_element_op_},
437  cde_element_op{cde_element_op_},
438  is_reduce(is_reduce_)
439  {
440  // populate pointer, desc for As
441  static_for<0, NumATensor, 1>{}([&](auto i) {
442  using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
443 
444  // A pointer
445  p_as_grid(i) = static_cast<const ADataType_*>(p_as_grid_[i]);
446  });
447 
448  // populate pointer, desc for Bs
449  static_for<0, NumBTensor, 1>{}([&](auto i) {
450  using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
451 
452  // B pointer
453  p_bs_grid(i) = static_cast<const BDataType_*>(p_bs_grid_[i]);
454  });
455 
456  // populate pointer, desc for Ds
457  static_for<0, NumDTensor, 1>{}([&](auto i) {
458  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
459 
460  // D pointer
461  p_ds_grid(i) = static_cast<const DDataType*>(p_ds_grid_[i]);
462  });
463  }
464 
465  __host__ __device__ inline bool IsReduceAdd() const
466  {
467  return (Problem::KBatch > 1) && is_reduce;
468  }
469 
470  __host__ __device__ inline bool IsAtomicAdd() const
471  {
472  return (Problem::KBatch > 1) && (!is_reduce);
473  }
474 
475  AsGridPointer p_as_grid;
476  BsGridPointer p_bs_grid;
477  DsGridPointer p_ds_grid;
478  EDataType* p_e_grid;
479 
480  AElementwiseOperation a_element_op;
481  BElementwiseOperation b_element_op;
482  CDEElementwiseOperation cde_element_op;
483 
484  // TODO: it can be used with SplitK+reduction but currently only used with SplitK+atomicAdd
485  bool is_reduce;
486  };
487 
488  struct SplitKBatchOffset
489  {
490 
491  __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
492  {
493  // Note: in the xdl implementation, multiple A/B tensors support a single layout
494  // but multiple strides, so we create an array of offsets that all hold
495  // the same value.
496  // This should be fixed later, once we have a more flexible thread
497  // transfer.
498  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
499  {
500  static_for<0, NumATensor, 1>{}(
501  [&](auto i) { a_k_split_offset[i] = k_id * karg.KRead / APackedSize; });
502  }
503  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
504  {
505  static_for<0, NumATensor, 1>{}(
506  [&](auto i) { a_k_split_offset[i] = k_id * karg.KRead * karg.StrideAs[i]; });
507  }
508 
509  if constexpr(IsBPreShuffled)
510  {
511  static_for<0, NumBTensor, 1>{}([&](auto i) { b_k_split_offset[i] = 0; });
512  }
513  else
514  {
515  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
516  {
517  static_for<0, NumBTensor, 1>{}([&](auto i) {
518  b_k_split_offset[i] = k_id * karg.KRead * karg.StrideBs[i];
519  });
520  }
521  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
522  {
523  if constexpr(!PermuteB)
524  {
525  static_for<0, NumBTensor, 1>{}(
526  [&](auto i) { b_k_split_offset[i] = k_id * karg.KRead / BPackedSize; });
527  }
528  else
529  {
530  const int k0_offset = karg.KRead * karg.N;
531  static_for<0, NumBTensor, 1>{}(
532  [&](auto i) { b_k_split_offset[i] = k_id * k0_offset / BPackedSize; });
533  }
534  }
535  }
536 
537  if(k_id < karg.KBatch - 1)
538  {
539  karg.K = karg.KRead;
540  }
541  else
542  {
543  karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
544  }
545 
546  if(karg.IsReduceAdd())
547  {
548  c_reduce_offset = k_id * karg.M * karg.N;
549  }
550  else
551  {
552  c_reduce_offset = 0;
553  }
554  }
555 
556  std::array<index_t, NumATensor> a_k_split_offset;
557  std::array<index_t, NumBTensor> b_k_split_offset;
558  index_t c_reduce_offset;
559  };
560 
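SplitKBatchOffset turns a split-K batch id into element offsets into A and B and, when reducing, into a per-batch slice of the output. The host-side sketch below mirrors that arithmetic for a single A tensor with a plain row-major or column-major 2-D layout; the packed-size division and pre-shuffled B handling are deliberately omitted, and all values are assumptions for illustration.

// Sketch of the split-K offset computation for one A tensor; the layout handling
// mirrors the row-major/column-major branches above, everything else is assumed.
#include <cstdio>

int main()
{
    const int  K = 384, KBatch = 3, KRead = K / KBatch; // 128 K-elements per batch
    const int  StrideA = 512;                           // leading stride of A
    const bool a_is_row_major = true;

    for(int k_id = 0; k_id < KBatch; ++k_id)
    {
        // Row-major A (M x K): moving along K is contiguous, so the offset is k_id * KRead.
        // Column-major A: moving along K jumps by the stride.
        const long a_offset = a_is_row_major ? static_cast<long>(k_id) * KRead
                                             : static_cast<long>(k_id) * KRead * StrideA;

        // The last batch consumes whatever K is left over.
        const int k_this_batch = (k_id < KBatch - 1) ? KRead : K - KRead * (KBatch - 1);

        std::printf("k_id=%d a_offset=%ld k=%d\n", k_id, a_offset, k_this_batch);
    }
}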
561  using BlockwiseGemmPipe = typename Base::BlockwiseGemmPipe;
562 
563  // return block_id to C matrix tile idx (m0, n0) mapping
564  // if arch = gfx942
565  using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
566  // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
567 
568  __device__ static index_t GetKBlockPerScale() { return 1; }
569 
570  template <bool HasMainKBlockLoop,
571  InMemoryDataOperationEnum EGlobalMemoryDataOperation,
572  TailNumber TailNum,
573  typename Block2CTileMap,
574  typename EpilogueArgument,
575  int BlockMapMBlockIndex = 0,
576  int BlockMapNBlockIndex = 1>
577  __device__ static void Run(AsGridPointer& p_as_grid,
578  BsGridPointer& p_bs_grid,
579  DsGridPointer& p_ds_grid,
580  EDataType* p_e_grid,
581  void* p_shared,
582  const Problem& problem,
583  const Block2CTileMap& block_2_ctile_map,
584  AElementwiseOperation a_element_op,
585  BElementwiseOperation b_element_op,
586  CDEElementwiseOperation cde_element_op,
587  EpilogueArgument& epilogue_args,
588  const index_t A_k_id = 0,
589  const index_t B_k_id = 0)
590  {
591  const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
592  problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
593  const index_t K_b = IsBPreShuffled ? problem.Kt : problem.K;
594  const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(
595  K_b, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0);
596  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
597  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
598  const auto e_grid_desc_m_n = Base::template MakeDEGridDescriptor_M_N<ELayout>(
599  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideE);
600  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
601  MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
602  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
603  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
604  MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
605  e_grid_desc_m_n, problem.MBlock, problem.NBlock);
606 
607  const auto block_work_idx =
608  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
609 
610  if(!block_2_ctile_map.ValidCTileIndex(
611  block_work_idx,
612  make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
613  e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
614  {
615  return;
616  }
617 
618  const index_t block_m_id =
619  __builtin_amdgcn_readfirstlane(block_work_idx[Number<BlockMapMBlockIndex>{}]);
620  const index_t block_n_id =
621  __builtin_amdgcn_readfirstlane(block_work_idx[Number<BlockMapNBlockIndex>{}]);
622 
623  // A and B scale structs (Empty)
624  using Scale = typename BlockwiseGemmPipe::Empty;
625  auto a_scale_struct = Scale{};
626  auto b_scale_struct = Scale{};
627 
628  const index_t num_k_block_per_scale = GetKBlockPerScale();
629 
630  Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
631  decltype(bs_grid_desc_bk0_n_bk1),
632  decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
633  decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
634  decltype(a_scale_struct),
635  decltype(b_scale_struct),
636  decltype(epilogue_args),
637  HasMainKBlockLoop,
638  EGlobalMemoryDataOperation,
639  TailNum>(p_as_grid,
640  p_bs_grid,
641  p_ds_grid,
642  p_e_grid,
643  p_shared,
644  as_grid_desc_ak0_m_ak1,
645  bs_grid_desc_bk0_n_bk1,
646  ds_grid_desc_mblock_mperblock_nblock_nperblock,
647  e_grid_desc_mblock_mperblock_nblock_nperblock,
648  a_element_op,
649  b_element_op,
650  cde_element_op,
651  block_m_id,
652  block_n_id,
653  num_k_block_per_scale,
654  a_scale_struct,
655  b_scale_struct,
656  epilogue_args,
657  A_k_id,
658  B_k_id);
659  }
660 
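The descriptors built at the top of Run view A as a 3-D (AK0, M, AK1) tensor, i.e. K is split into AK0 outer chunks of AK1 contiguous elements. The small sketch below checks that index mapping for a row-major matrix, assuming K is already a multiple of AK1 and ignoring the padding that the real MakeAsGridDescriptor_AK0_M_AK1 handles.

// Sketch of the (AK0, M, AK1) view of a row-major M x K matrix A; the real
// descriptors are produced by MakeAsGridDescriptor_AK0_M_AK1 and also handle padding.
#include <cassert>
#include <cstdio>

int main()
{
    const int M = 4, K = 16, AK1 = 4, AK0 = K / AK1;

    // Flat row-major offset of A(m, k), and the same element addressed as (ak0, m, ak1).
    auto flat           = [&](int m, int k) { return m * K + k; };
    auto via_ak0_m_ak1  = [&](int ak0, int m, int ak1) { return m * K + ak0 * AK1 + ak1; };

    for(int m = 0; m < M; ++m)
        for(int k = 0; k < K; ++k)
            assert(flat(m, k) == via_ak0_m_ak1(k / AK1, m, k % AK1));

    std::printf("AK0=%d AK1=%d: (ak0, m, ak1) view is consistent\n", AK0, AK1);
}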
661  template <bool HasMainKBlockLoop,
662  InMemoryDataOperationEnum EGlobalMemoryDataOperation,
663  TailNumber TailNum,
664  typename EpilogueArgument>
665  __device__ static void Run(AsGridPointer& p_as_grid,
666  BsGridPointer& p_bs_grid,
667  DsGridPointer& p_ds_grid,
668  EDataType* p_e_grid,
669  void* p_shared,
670  const Problem& problem,
671  AElementwiseOperation a_element_op,
672  BElementwiseOperation b_element_op,
673  CDEElementwiseOperation cde_element_op,
674  EpilogueArgument& epilogue_args)
675  {
676  Run<HasMainKBlockLoop,
677  EGlobalMemoryDataOperation,
678  TailNum,
679  Block2CTileMap,
680  EpilogueArgument>(p_as_grid,
681  p_bs_grid,
682  p_ds_grid,
683  p_e_grid,
684  p_shared,
685  problem,
686  DefaultBlock2CTileMap(problem),
687  a_element_op,
688  b_element_op,
689  cde_element_op,
690  epilogue_args);
691  }
692 
693  // Wrapper function so that a common __global__ function can be shared
694  // between gemm_universal, b_scale, ab_scale, etc.
695  template <bool HasMainKBlockLoop,
696  InMemoryDataOperationEnum EGlobalMemoryDataOperation,
697  TailNumber TailNum,
698  typename Block2CTileMap,
699  typename EpilogueArgument,
700  int BlockMapMBlockIndex = 0,
701  int BlockMapNBlockIndex = 1>
702  __device__ static void Run(void* p_shared,
703  const SplitKBatchOffset& splitk_batch_offset,
704  Argument& karg,
705  const Block2CTileMap& block_2_ctile_map,
706  EpilogueArgument& epilogue_args,
707  const index_t A_k_id = 0,
708  const index_t B_k_id = 0)
709  {
710  // shift A matrices pointer for splitk
711  AsGridPointer p_as_grid_splitk;
712  static_for<0, NumATensor, 1>{}([&](auto i) {
713  using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
714  p_as_grid_splitk(i) = static_cast<const ADataType_*>(karg.p_as_grid[i]) +
715  splitk_batch_offset.a_k_split_offset[i];
716  });
717 
718  // shift B matrices pointer for splitk
719  BsGridPointer p_bs_grid_splitk;
720  static_for<0, NumBTensor, 1>{}([&](auto i) {
721  using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
722  p_bs_grid_splitk(i) = static_cast<const BDataType_*>(karg.p_bs_grid[i]) +
723  splitk_batch_offset.b_k_split_offset[i];
724  });
725 
726  Run<HasMainKBlockLoop,
727  EGlobalMemoryDataOperation,
728  TailNum,
729  Block2CTileMap,
730  EpilogueArgument,
731  BlockMapMBlockIndex,
732  BlockMapNBlockIndex>(p_as_grid_splitk,
733  p_bs_grid_splitk,
734  karg.p_ds_grid,
735  karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
736  p_shared,
737  karg,
738  block_2_ctile_map,
739  karg.a_element_op,
740  karg.b_element_op,
741  karg.cde_element_op,
742  epilogue_args,
743  A_k_id,
744  B_k_id);
745  }
746 
747  // Wrapper function so that a common __global__ function can be shared
748  // between gemm_universal, b_scale, ab_scale, etc.
749  template <bool HasMainKBlockLoop,
750  InMemoryDataOperationEnum EGlobalMemoryDataOperation,
751  TailNumber TailNum,
752  typename EpilogueArgument>
753  __device__ static void Run(void* p_shared,
754  const SplitKBatchOffset& splitk_batch_offset,
755  Argument& karg,
756  EpilogueArgument& epilogue_args,
757  const index_t A_k_id = 0,
758  const index_t B_k_id = 0)
759  {
760  Run<HasMainKBlockLoop,
761  EGlobalMemoryDataOperation,
762  TailNum,
763  Block2CTileMap,
764  EpilogueArgument>(p_shared,
765  splitk_batch_offset,
766  karg,
767  DefaultBlock2CTileMap(karg),
768  epilogue_args,
769  A_k_id,
770  B_k_id);
771  }
772 
773  __device__ static auto DefaultBlock2CTileMap(const Problem& problem)
774  {
775  return Block2CTileMap{problem.M, problem.N, 4};
776  }
777 
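DefaultBlock2CTileMap builds the Block2CTileMap used to convert a flat block id into an (m0, n0) output-tile index. The sketch below shows a deliberately simplified row-major version of such a mapping; the actual BlockToCTileMap_Grouped_M00_N0_M01Adapt additionally groups and swizzles tiles for cache locality, which is omitted here.

// Simplified block-id -> (m0, n0) tile mapping; the real map groups/swizzles
// tiles for better L2 locality, which is deliberately omitted in this sketch.
#include <cstdio>

int main()
{
    const int M = 512, N = 768, MPerBlock = 128, NPerBlock = 128;
    const int MBlock = (M + MPerBlock - 1) / MPerBlock; // 4
    const int NBlock = (N + NPerBlock - 1) / NPerBlock; // 6

    for(int block_id = 0; block_id < MBlock * NBlock; ++block_id)
    {
        const int m0 = block_id / NBlock; // tile row
        const int n0 = block_id % NBlock; // tile column
        if(block_id < 3 || block_id == MBlock * NBlock - 1)
            std::printf("block %2d -> tile (%d, %d)\n", block_id, m0, n0);
    }
}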
778  // Run method for convolution bwd_data (grid descriptors are passed as arguments,
779  // not generated internally)
780  template <typename AGridDesc_AK0_M_K1,
781  typename BGridDesc_BK0_N_K1,
782  typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
783  typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
784  typename Block2CTileMapExt,
785  typename ComputePtrOffsetOfBatch,
786  typename ComputePtrOffsetOfN,
787  bool HasMainKBlockLoop,
788  InMemoryDataOperationEnum EGlobalMemoryDataOperation,
789  bool CTranspose,
790  TailNumber TailNum,
791  typename EpilogueArgument>
792  __device__ static void Run(void* p_shared,
793  const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
794  const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
795  const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
796  ds_grid_desc_mblock_mperblock_nblock_nperblock,
797  const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
798  e_grid_desc_mblock_mperblock_nblock_nperblock,
799  const Block2CTileMapExt& block_2_ctile_map,
800  const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
801  const ComputePtrOffsetOfN compute_ptr_offset_of_n,
802  const index_t num_k_per_block,
803  Argument& karg,
804  EpilogueArgument& epilogue_args)
805  {
806  const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
807  const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / karg.KBatch);
808  const index_t k_idx =
809  __builtin_amdgcn_readfirstlane((blockIdx.z - n_idx * karg.KBatch) * num_k_per_block);
810 
811  // offset base pointer for each work-group
812  const long_index_t a_batch_offset =
813  CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))
814  : amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
815  const long_index_t b_batch_offset =
816  CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))
817  : amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
818  const long_index_t e_batch_offset =
819  amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
820 
821  const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
822 
823  const long_index_t a_n_offset =
824  CTranspose ? 0 : amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
825  const long_index_t b_n_offset =
826  CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)) : 0;
827  const long_index_t e_n_offset =
828  amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
829 
830  AsGridPointer p_as_grid_;
831  static_for<0, NumATensor, 1>{}([&](auto i) {
832  using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
833  p_as_grid_(i) =
834  static_cast<const ADataType_*>(karg.p_as_grid[i]) + a_batch_offset + a_n_offset;
835  });
836 
837  BsGridPointer p_bs_grid_;
838  static_for<0, NumBTensor, 1>{}([&](auto i) {
839  using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
840  p_bs_grid_(i) =
841  static_cast<const BDataType_*>(karg.p_bs_grid[i]) + b_batch_offset + b_n_offset;
842  });
843 
844  DsGridPointer p_ds_grid_grp;
845  static_for<0, NumDTensor, 1>{}(
846  [&](auto i) { p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_batch_offset[i]; });
847 
848  // Currently supporting one A and one B
849  const auto as_grid_desc_ak0_m_ak1 = generate_tuple(
850  [&](auto i) {
851  ignore = i;
852  return a_grid_desc_ak0_m_ak1;
853  },
854  Number<NumATensor>{});
855 
856  const auto bs_grid_desc_bk0_n_bk1 = generate_tuple(
857  [&](auto i) {
858  ignore = i;
859  return b_grid_desc_bk0_n_bk1;
860  },
861  Number<NumBTensor>{});
862 
863  const auto block_work_idx =
864  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
865 
866  if(!block_2_ctile_map.ValidCTileIndex(
867  block_work_idx,
868  make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
869  e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
870  {
871  return;
872  }
873 
874  const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
875  const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
876 
877  // AScale struct (Empty)
878  using AScale = typename BlockwiseGemmPipe::Empty;
879  auto a_scale_struct = AScale{};
880 
881  // BScale struct (Empty)
882  using BScale = typename BlockwiseGemmPipe::Empty;
883  auto b_scale_struct = BScale{};
884 
885  const index_t num_k_block_per_scale = GetKBlockPerScale();
886 
887  Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
888  decltype(bs_grid_desc_bk0_n_bk1),
889  decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
890  decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
891  decltype(a_scale_struct),
892  decltype(b_scale_struct),
893  decltype(epilogue_args),
894  HasMainKBlockLoop,
895  EGlobalMemoryDataOperation,
896  TailNum>(p_as_grid_,
897  p_bs_grid_,
898  p_ds_grid_grp,
899  karg.p_e_grid + e_batch_offset + e_n_offset,
900  p_shared,
901  as_grid_desc_ak0_m_ak1,
902  bs_grid_desc_bk0_n_bk1,
903  ds_grid_desc_mblock_mperblock_nblock_nperblock,
904  e_grid_desc_mblock_mperblock_nblock_nperblock,
905  karg.a_element_op,
906  karg.b_element_op,
907  karg.cde_element_op,
908  block_m_id,
909  block_n_id,
910  num_k_block_per_scale,
911  a_scale_struct,
912  b_scale_struct,
913  epilogue_args,
914  k_idx,
915  k_idx,
916  karg.KBatch);
917  }
918 
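In the grouped-convolution Run above, blockIdx.z encodes both the image index and the split-K index: n_idx = blockIdx.z / KBatch and k_idx = (blockIdx.z - n_idx * KBatch) * num_k_per_block. The host-side check below replays that decomposition with grid extents chosen purely for illustration.

// Host-side check of the blockIdx.z decoding used above (values are illustrative).
#include <cstdio>

int main()
{
    const int KBatch = 3, NumN = 2, num_k_per_block = 4;

    for(int block_z = 0; block_z < NumN * KBatch; ++block_z)
    {
        const int n_idx = block_z / KBatch;                              // image / N index
        const int k_idx = (block_z - n_idx * KBatch) * num_k_per_block;  // first K-block of this split
        std::printf("blockIdx.z=%d -> n_idx=%d k_idx=%d\n", block_z, n_idx, k_idx);
    }
}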
919  // Run method for convolution (grid descriptors are passed as arguments,
920  // not generated internally)
921  template <typename AGridDesc_AK0_M_K1,
922  typename BGridDesc_BK0_N_K1,
923  typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
924  typename ComputePtrOffsetOfBatch,
925  index_t NumGroupsToMerge,
926  bool HasMainKBlockLoop,
927  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
928  TailNumber TailNum,
929  typename EpilogueArgument>
930  __device__ static void Run(void* p_shared,
931  const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
932  const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
933  const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
934  c_grid_desc_mblock_mperblock_nblock_nperblock,
935  const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
936  const index_t num_k_per_block,
937  Argument& karg,
938  EpilogueArgument& epilogue_args)
939  {
940  const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
941  const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
942 
943  const long_index_t a_batch_offset =
944  amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
945  const long_index_t b_batch_offset =
946  amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
947  const long_index_t e_batch_offset =
948  amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
949 
950  AsGridPointer p_as_grid_;
951  static_for<0, NumATensor, 1>{}([&](auto i) {
952  using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
953  p_as_grid_(i) = static_cast<const ADataType_*>(karg.p_as_grid[i]) + a_batch_offset;
954  });
955 
956  BsGridPointer p_bs_grid_;
957  static_for<0, NumBTensor, 1>{}([&](auto i) {
958  using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
959  p_bs_grid_(i) = static_cast<const BDataType_*>(karg.p_bs_grid[i]) + b_batch_offset;
960  });
961 
962  const auto ds_grid_desc_m_n =
963  MakeDsGridDescriptor_M_N(karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideDs);
964 
965  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
966  MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
967  ds_grid_desc_m_n, karg.MBlock, karg.NBlock);
968 
969  const auto as_grid_desc_ak0_m_ak1 = generate_tuple(
970  [&](auto i) {
971  ignore = i;
972  return a_grid_desc_ak0_m_ak1;
973  },
974  Number<NumATensor>{});
975 
976  const auto bs_grid_desc_bk0_n_bk1 = generate_tuple(
977  [&](auto i) {
978  ignore = i;
979  return b_grid_desc_bk0_n_bk1;
980  },
981  Number<NumBTensor>{});
982 
983  // divide block work by [M, N]
984  const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4};
985 
986  const auto block_work_idx =
987  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
988 
989  if(!block_2_ctile_map.ValidCTileIndex(
990  block_work_idx,
991  make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
992  c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
993  {
994  return;
995  }
996 
997  const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
998  const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
999 
1000  // Scale structs (Empty)
1001  using Scale = typename BlockwiseGemmPipe::Empty;
1002  auto b_scale_struct = Scale{};
1003  auto a_scale_struct = Scale{};
1004 
1005  const index_t num_k_block_per_scale = GetKBlockPerScale();
1006 
1007  Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
1008  decltype(bs_grid_desc_bk0_n_bk1),
1009  decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
1010  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1011  decltype(a_scale_struct),
1012  decltype(b_scale_struct),
1013  decltype(epilogue_args),
1014  HasMainKBlockLoop,
1015  CGlobalMemoryDataOperation,
1016  TailNum>(p_as_grid_,
1017  p_bs_grid_,
1018  karg.p_ds_grid,
1019  karg.p_e_grid + e_batch_offset,
1020  p_shared,
1021  as_grid_desc_ak0_m_ak1,
1022  bs_grid_desc_bk0_n_bk1,
1023  ds_grid_desc_mblock_mperblock_nblock_nperblock,
1024  c_grid_desc_mblock_mperblock_nblock_nperblock,
1025  karg.a_element_op,
1026  karg.b_element_op,
1027  karg.cde_element_op,
1028  block_m_id,
1029  block_n_id,
1030  num_k_block_per_scale,
1031  a_scale_struct,
1032  b_scale_struct,
1033  epilogue_args,
1034  k_idx,
1035  k_idx,
1036  karg.KBatch);
1037  }
1038 
1039  // Run method for convolution fwd (grid descriptors are passed as arguments,
1040  // not generated internally)
1041  template <typename AGridDesc_AK0_M_K1,
1042  typename BGridDesc_BK0_N_K1,
1043  typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
1044  typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
1045  typename ComputePtrOffsetOfBatch,
1046  typename ComputePtrOffsetOfN,
1047  bool HasMainKBlockLoop,
1048  InMemoryDataOperationEnum EGlobalMemoryDataOperation,
1049  TailNumber TailNum,
1050  typename EpilogueArgument>
1051  __device__ static void Run(void* p_shared,
1052  const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
1053  const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
1054  const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
1055  ds_grid_desc_mblock_mperblock_nblock_nperblock,
1056  const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
1057  e_grid_desc_mblock_mperblock_nblock_nperblock,
1058  const ComputePtrOffsetOfBatch& compute_ptr_offset_of_batch,
1059  const ComputePtrOffsetOfN& compute_ptr_offset_of_n,
1060  [[maybe_unused]] const index_t num_k_per_block,
1061  Argument& karg,
1062  EpilogueArgument& epilogue_args)
1063  {
1064  const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
1065  const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / karg.KBatch);
1066  // offset base pointer for each work-group
1067  const long_index_t a_batch_offset =
1068  amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
1069  const long_index_t b_batch_offset =
1070  amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
1071  const long_index_t e_batch_offset =
1072  amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
1073 
1074  const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
1075 
1076  const long_index_t a_n_offset =
1077  amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
1078  const long_index_t b_n_offset =
1079  amd_wave_read_first_lane(compute_ptr_offset_of_n.GetBPtrOffset(n_idx));
1080  const long_index_t e_n_offset =
1081  amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
1082 
1083  const auto ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx);
1084 
1085  AsGridPointer p_as_grid_;
1086  static_for<0, NumATensor, 1>{}([&](auto i) {
1087  using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
1088  p_as_grid_(i) =
1089  static_cast<const ADataType_*>(karg.p_as_grid[i]) + a_batch_offset + a_n_offset;
1090  });
1091 
1092  BsGridPointer p_bs_grid_;
1093  static_for<0, NumBTensor, 1>{}([&](auto i) {
1094  using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
1095  p_bs_grid_(i) =
1096  static_cast<const BDataType_*>(karg.p_bs_grid[i]) + b_batch_offset + b_n_offset;
1097  });
1098 
1099  DsGridPointer p_ds_grid_grp;
1100  static_for<0, NumDTensor, 1>{}([&](auto i) {
1101  using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
1102  p_ds_grid_grp(i) = static_cast<const DDataType_*>(karg.p_ds_grid[i]) +
1103  ds_batch_offset[i] + ds_n_offset[i];
1104  });
1105 
1106  // Currently supporting one A and one B
1107  const auto as_grid_desc_ak0_m_ak1 = generate_tuple(
1108  [&](auto i) {
1109  ignore = i;
1110  return a_grid_desc_ak0_m_ak1;
1111  },
1112  Number<NumATensor>{});
1113 
1114  const auto bs_grid_desc_bk0_n_bk1 = generate_tuple(
1115  [&](auto i) {
1116  ignore = i;
1117  return b_grid_desc_bk0_n_bk1;
1118  },
1119  Number<NumBTensor>{});
1120 
1121  // divide block work by [M, N]
1122  const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4};
1123 
1124  const auto block_work_idx =
1125  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
1126 
1127  if(!block_2_ctile_map.ValidCTileIndex(
1128  block_work_idx,
1129  make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
1130  e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
1131  {
1132  return;
1133  }
1134 
1135  const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
1136  const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
1137 
1138  // AScale struct (Empty)
1139  using AScale = typename BlockwiseGemmPipe::Empty;
1140  auto a_scale_struct = AScale{};
1141 
1142  // BScale struct (Empty)
1143  using BScale = typename BlockwiseGemmPipe::Empty;
1144  auto b_scale_struct = BScale{};
1145 
1146  const index_t num_k_block_per_scale = GetKBlockPerScale();
1147 
1148  Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
1149  decltype(bs_grid_desc_bk0_n_bk1),
1150  decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
1151  decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
1152  decltype(a_scale_struct),
1153  decltype(b_scale_struct),
1154  decltype(epilogue_args),
1155  HasMainKBlockLoop,
1156  EGlobalMemoryDataOperation,
1157  TailNum>(p_as_grid_,
1158  p_bs_grid_,
1159  p_ds_grid_grp,
1160  karg.p_e_grid + e_batch_offset + e_n_offset,
1161  p_shared,
1162  as_grid_desc_ak0_m_ak1,
1163  bs_grid_desc_bk0_n_bk1,
1164  ds_grid_desc_mblock_mperblock_nblock_nperblock,
1165  e_grid_desc_mblock_mperblock_nblock_nperblock,
1166  karg.a_element_op,
1167  karg.b_element_op,
1168  karg.cde_element_op,
1169  block_m_id,
1170  block_n_id,
1171  num_k_block_per_scale,
1172  a_scale_struct,
1173  b_scale_struct,
1174  epilogue_args);
1175  }
1176 };
1177 
1178 } // namespace ck
GemmSpecialization
Definition: gemm_specialization.hpp:11
Definition: ck.hpp:270
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
InMemoryDataOperationEnum
Definition: ck.hpp:279
BlockGemmPipelineVersion
Block GEMM pipeline version enumeration.
Definition: scheduler_enum.hpp:17
int64_t long_index_t
Definition: ck.hpp:302
TailNumber
Tail number enumeration for pipeline buffering.
Definition: scheduler_enum.hpp:49
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
__device__ uint32_t amd_wave_read_first_lane(uint32_t value)
Definition: amd_wave_read_first_lane.hpp:100
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:47
BlockGemmPipelineScheduler
Block GEMM pipeline scheduler enumeration.
Definition: scheduler_enum.hpp:33
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
int32_t index_t
Definition: ck.hpp:301
Definition: block_to_ctile_map.hpp:271
__host__ constexpr __device__ auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition: block_to_ctile_map.hpp:298
__host__ __device__ bool ValidCTileIndex(const CTileIdx &, const CTileDim &) const
Definition: block_to_ctile_map.hpp:384
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:412
BElementwiseOperation b_element_op
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:481
bool is_reduce
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:485
__host__ __device__ bool IsAtomicAdd() const
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:470
AsGridPointer p_as_grid
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:475
AElementwiseOperation a_element_op
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:480
DsGridPointer p_ds_grid
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:477
__host__ __device__ bool IsReduceAdd() const
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:465
__host__ Argument(std::array< const void *, NumATensor > p_as_grid_, std::array< const void *, NumBTensor > p_bs_grid_, std::array< const void *, NumDTensor > p_ds_grid_, EDataType *p_e_grid_, index_t M_, index_t N_, index_t K_, std::array< index_t, NumATensor > StrideAs_, std::array< index_t, NumBTensor > StrideBs_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideE_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CDEElementwiseOperation cde_element_op_, bool is_reduce_=false)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:414
CDEElementwiseOperation cde_element_op
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:482
EDataType * p_e_grid
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:478
BsGridPointer p_bs_grid
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:476
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:335
index_t KPadded
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:402
index_t KRead
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:401
std::array< index_t, NumATensor > StrideAs
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:394
index_t MPadded
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:399
index_t NBlock
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:406
index_t Kt
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:407
index_t M
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:391
std::array< index_t, NumDTensor > StrideDs
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:396
index_t NPadded
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:400
index_t KBatch
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:398
index_t StrideE
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:397
__host__ void Print() const
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:365
std::array< index_t, NumBTensor > StrideBs
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:395
index_t MBlock
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:405
index_t AK0
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:403
index_t N
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:392
index_t K
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:393
__host__ Problem(index_t M_, index_t N_, index_t K_, std::array< index_t, NumATensor > StrideAs_, std::array< index_t, NumBTensor > StrideBs_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideE_, index_t KBatch_)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:337
index_t BK0
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:404
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:489
std::array< index_t, NumATensor > a_k_split_offset
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:556
__device__ SplitKBatchOffset(Argument &karg, index_t k_id)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:491
index_t c_reduce_offset
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:558
std::array< index_t, NumBTensor > b_k_split_offset
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:557
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:172
static constexpr auto BK1Number
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:202
__host__ static constexpr __device__ auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DEGridDesc &de_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:726
decltype(MakeAsGridPointer()) AsGridPointer
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:438
static constexpr index_t BPackedSize
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:228
decltype(MakeDsGridPointer()) DsGridPointer
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:609
static __host__ auto CalculateKPadded(index_t K)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:376
static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:381
static __host__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:399
static constexpr auto BK0Number
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:200
static constexpr auto I0
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:174
static constexpr auto I1
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:175
static constexpr index_t NumBTensor
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:184
static __host__ auto CalculateMPadded(index_t M)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:366
static constexpr index_t NumATensor
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:183
static constexpr auto I3
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:177
__host__ static __device__ auto MakeAsGridDescriptor_AK0_M_AK1(const index_t M, const index_t MPad, const index_t K, const index_t KPad, const std::array< index_t, NumATensor > &StrideAs, const index_t AK0)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:466
static __host__ auto CalculateMBlock(index_t M)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:406
static constexpr auto I2
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:176
static constexpr auto I6
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:180
static constexpr auto AK1Number
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:201
static constexpr index_t APackedSize
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:221
decltype(MakeBsGridPointer()) BsGridPointer
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:439
static __host__ auto CalculateNPadded(index_t N)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:371
static constexpr auto I7
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:181
static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:387
__host__ static __device__ auto MakeBsGridDescriptor_BK0_N_BK1(const index_t K, const index_t KPad, const index_t N, const index_t NPad, const std::array< index_t, NumBTensor > &StrideBs, const index_t BK0)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:493
static constexpr auto AK0Number
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:199
static constexpr index_t NumDTensor
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:596
__host__ static __device__ auto MakeDsGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, std::array< index_t, NumDTensor > StrideDs)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:611
__host__ static __device__ auto MakeDEGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideDE)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:534
remove_cvref_t< decltype(BlockGemmPipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, LDSTypeA, LDSTypeB, ComputeTypeA, ComputeTypeB, AccDataType, decltype(MakeAWmmaTileDescriptor()), decltype(MakeBWmmaTileDescriptor()), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack, KInner, false, IsBPreShuffled >())> BlockwiseGemmPipe
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:659
static constexpr auto I4
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:178
__device__ static constexpr __host__ auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc &ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:624
static constexpr auto I5
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:179
static __host__ auto CalculateNBlock(index_t N)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:411
"Universal" GEMM kernel with SplitK support.
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:235
static __device__ void Run(AsGridPointer &p_as_grid, BsGridPointer &p_bs_grid, DsGridPointer &p_ds_grid, EDataType *p_e_grid, void *p_shared, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, EpilogueArgument &epilogue_args)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:665
static __device__ void Run(void *p_shared, const SplitKBatchOffset &splitk_batch_offset, Argument &karg, const Block2CTileMap &block_2_ctile_map, EpilogueArgument &epilogue_args, const index_t A_k_id=0, const index_t B_k_id=0)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:702
static __device__ void Run(void *p_shared, const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const index_t num_k_per_block, Argument &karg, EpilogueArgument &epilogue_args)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:930
typename Base::BlockwiseGemmPipe BlockwiseGemmPipe
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:561
static __device__ void Run(void *p_shared, const SplitKBatchOffset &splitk_batch_offset, Argument &karg, EpilogueArgument &epilogue_args, const index_t A_k_id=0, const index_t B_k_id=0)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:753
AsDataType AsDataType_
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:331
static __device__ void Run(void *p_shared, const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2CTileMapExt &block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const ComputePtrOffsetOfN compute_ptr_offset_of_n, const index_t num_k_per_block, Argument &karg, EpilogueArgument &epilogue_args)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:792
static __device__ void Run(void *p_shared, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock &ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &e_grid_desc_mblock_mperblock_nblock_nperblock, const ComputePtrOffsetOfBatch &compute_ptr_offset_of_batch, const ComputePtrOffsetOfN &compute_ptr_offset_of_n, [[maybe_unused]] const index_t num_k_per_block, Argument &karg, EpilogueArgument &epilogue_args)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:1051
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:323
BsDataType BsDataType_
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:332
BlockToCTileMap_Grouped_M00_N0_M01Adapt< 8, MPerBlock, NPerBlock > Block2CTileMap
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:565
static __device__ index_t GetKBlockPerScale()
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:568
static __device__ void Run(AsGridPointer &p_as_grid, BsGridPointer &p_bs_grid, DsGridPointer &p_ds_grid, EDataType *p_e_grid, void *p_shared, const Problem &problem, const Block2CTileMap &block_2_ctile_map, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, EpilogueArgument &epilogue_args, const index_t A_k_id=0, const index_t B_k_id=0)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:577
static __device__ auto DefaultBlock2CTileMap(const Problem &problem)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:773
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33
Definition: device_base.hpp:270