Composable Kernel: include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp Source File

device_grouped_gemm_xdl.hpp
1 #pragma once
2 // SPDX-License-Identifier: MIT
3 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
4 
5 #pragma once
6 
7 #include <iostream>
8 #include <sstream>
9 
11 #include "ck/utility/env.hpp"
21 
22 namespace ck {
23 namespace tensor_operation {
24 namespace device {
25 
26 template <typename GridwiseGemm,
27  typename GemmDesc,
28  typename AElementwiseOperation,
29  typename BElementwiseOperation,
30  typename CDEElementwiseOperation,
31  bool HasMainKBlockLoop>
32 __global__ void
33 #if CK_USE_LAUNCH_BOUNDS
34     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
35 #endif
36  kernel_grouped_gemm_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
37  const index_t group_count,
38  const AElementwiseOperation a_element_op,
39  const BElementwiseOperation b_element_op,
40  const CDEElementwiseOperation c_element_op)
41 {
42 #if defined(__gfx9__)
43  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
44 
45  const index_t block_id = get_block_1d_id();
46 
47  const auto gemm_desc_ptr =
48  reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
49 
50  index_t left = 0;
51  index_t right = group_count;
52  index_t group_id = index_t((left + right) / 2);
53  while((!(block_id >= gemm_desc_ptr[group_id].BlockStart_ &&
54  block_id < gemm_desc_ptr[group_id].BlockEnd_)) &&
55  left <= right)
56  {
57  if(block_id < gemm_desc_ptr[group_id].BlockStart_)
58  {
59  right = group_id;
60  }
61  else
62  {
63  left = group_id;
64  }
65  group_id = index_t((left + right) / 2);
66  }
67 
68  GridwiseGemm::template Run<HasMainKBlockLoop, InMemoryDataOperationEnum::Set>(
69  gemm_desc_ptr[group_id].a_ptr_,
70  gemm_desc_ptr[group_id].b_ptr_,
71  gemm_desc_ptr[group_id].ds_ptr_,
72  gemm_desc_ptr[group_id].e_ptr_,
73  p_shared,
74  a_element_op,
75  b_element_op,
76  c_element_op,
77  gemm_desc_ptr[group_id].a_grid_desc_ak0_m_ak1_,
78  gemm_desc_ptr[group_id].b_grid_desc_bk0_n_bk1_,
79  gemm_desc_ptr[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
80  gemm_desc_ptr[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
81  gemm_desc_ptr[group_id].block_2_etile_map_);
82 #else
83  ignore = gemm_descs_const;
84  ignore = group_count;
85  ignore = a_element_op;
86  ignore = b_element_op;
87  ignore = c_element_op;
88 #endif
89 }
90 
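The loop above resolves which group a workgroup belongs to by bisecting the per-group [BlockStart_, BlockEnd_) ranges. The host-side sketch below reproduces that lookup on plain data so it can be read and tested in isolation; GroupRange and find_group_id are illustrative names, not part of this header.

// Host-side sketch of the group lookup performed in kernel_grouped_gemm_xdl.
// GroupRange/find_group_id are hypothetical helpers for illustration only.
#include <vector>

struct GroupRange
{
    int BlockStart_; // first workgroup id owned by this group (inclusive)
    int BlockEnd_;   // one past the last workgroup id owned by this group
};

int find_group_id(const std::vector<GroupRange>& groups, int block_id)
{
    int left     = 0;
    int right    = static_cast<int>(groups.size());
    int group_id = (left + right) / 2;

    // Bisect until group_id's block range contains block_id, mirroring the
    // device-side loop above (ranges are contiguous and cover the whole grid).
    while(!(block_id >= groups[group_id].BlockStart_ &&
            block_id < groups[group_id].BlockEnd_) &&
          left <= right)
    {
        if(block_id < groups[group_id].BlockStart_)
        {
            right = group_id;
        }
        else
        {
            left = group_id;
        }

        group_id = (left + right) / 2;
    }

    return group_id;
}

// Example: groups covering blocks [0,4), [4,10), [10,12):
//   find_group_id(groups, 7) == 1, find_group_id(groups, 11) == 2.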
91 template <typename ALayout,
92  typename BLayout,
93  typename DsLayout,
94  typename ELayout,
95  typename ADataType,
96  typename BDataType,
97  typename AccDataType,
98  typename CShuffleDataType,
99  typename DsDataType,
100  typename EDataType,
101  typename AElementwiseOperation,
102  typename BElementwiseOperation,
103  typename CDEElementwiseOperation,
104  GemmSpecialization GemmSpec,
105  ck::index_t NumPrefetch,
106  ck::index_t BlockSize,
107  ck::index_t MPerBlock,
108  ck::index_t NPerBlock,
109  ck::index_t KPerBlock,
110  ck::index_t AK1,
111  ck::index_t BK1,
112  ck::index_t MPerXDL,
113  ck::index_t NPerXDL,
114  ck::index_t MXdlPerWave,
115  ck::index_t NXdlPerWave,
116  typename ABlockTransferThreadClusterLengths_K0_M_K1,
117  typename ABlockTransferThreadClusterArrangeOrder,
118  typename ABlockTransferSrcAccessOrder,
119  ck::index_t ABlockTransferSrcVectorDim,
120  ck::index_t ABlockTransferSrcScalarPerVector,
121  ck::index_t ABlockTransferDstScalarPerVector_K1,
122  bool ABlockLdsExtraM,
123  typename BBlockTransferThreadClusterLengths_K0_N_K1,
124  typename BBlockTransferThreadClusterArrangeOrder,
125  typename BBlockTransferSrcAccessOrder,
126  ck::index_t BBlockTransferSrcVectorDim,
127  ck::index_t BBlockTransferSrcScalarPerVector,
128  ck::index_t BBlockTransferDstScalarPerVector_K1,
129  bool BBlockLdsExtraN,
130  index_t CShuffleMXdlPerWavePerShuffle,
131  index_t CShuffleNXdlPerWavePerShuffle,
132  typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
133  index_t CDEBlockTransferScalarPerVector_NPerBlock,
134  LoopScheduler LoopSched = make_default_loop_scheduler()>
135 struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
136  BLayout,
137  DsLayout,
138  ELayout,
139  ADataType,
140  BDataType,
141  DsDataType,
142  EDataType,
143  AElementwiseOperation,
144  BElementwiseOperation,
145  CDEElementwiseOperation>
146 {
147  using DeviceOp = DeviceGroupedGemm_Xdl;
148 
149  static constexpr index_t NumDTensor = DsDataType::Size();
150 
151  static constexpr auto I0 = Number<0>{};
152  static constexpr auto I1 = Number<1>{};
153  static constexpr auto I2 = Number<2>{};
154 
155  static constexpr auto matrix_padder =
156  MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
157 
158  static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
159  {
160  const auto a_grid_desc_mraw_kraw = [&]() {
161  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
162  {
163  return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
164  make_tuple(StrideA, I1));
165  }
166  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
167  {
168  return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
169  make_tuple(I1, StrideA));
170  }
171  }();
172 
173  return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
174  }
175 
176  static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
177  {
178  const auto b_grid_desc_nraw_kraw = [&]() {
179  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
180  {
181  return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
182  make_tuple(I1, StrideB));
183  }
184  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
185  {
186  return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
187  make_tuple(StrideB, I1));
188  }
189  }();
190 
191  return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
192  }
193 
194  template <typename ELay>
195  static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
196  {
197  const auto e_grid_desc_mraw_nraw = [&]() {
198  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ELay>)
199  {
200  return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
201  make_tuple(StrideE, I1));
202  }
203  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ELay>)
204  {
205  return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
206  make_tuple(I1, StrideE));
207  }
208  }();
209 
210  return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
211  }
212 
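MakeAGridDescriptor_M_K, MakeBGridDescriptor_N_K, and MakeEGridDescriptor_M_N above wrap the raw descriptors with matrix_padder, which extends a dimension to the next multiple of the corresponding per-block tile size when the chosen GemmSpec requests padding. The arithmetic below only illustrates that effect (it is not the MatrixPadder implementation), assuming GemmSpec = GemmSpecialization::MNKPadding and MPerBlock = 128:

// Illustrative arithmetic only: how an M of 100 is padded for MPerBlock = 128.
ck::index_t MRaw      = 100;
ck::index_t MPerBlock = 128;
ck::index_t MPadded   = (MRaw + MPerBlock - 1) / MPerBlock * MPerBlock; // 128
ck::index_t MPad      = MPadded - MRaw;                                 // 28 padded rows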
213  static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
214  const std::array<index_t, NumDTensor>& NRaws,
215  const std::array<index_t, NumDTensor>& DsStride)
216  {
217  return generate_tuple(
218  [&](auto i) {
219  using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
220 
221  return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
222  },
223  Number<NumDTensor>{});
224  }
225 
226  using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
227  using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
228  using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
229  using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
230 
231  using ComputeDataType = ADataType;
232 
233  // GridwiseGemm
234  using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
235  ADataType, // TODO: distinguish A/B datatype
236  BDataType,
237  ComputeDataType,
238  AccDataType,
239  CShuffleDataType,
240  DsDataType,
241  EDataType,
242  AElementwiseOperation,
243  BElementwiseOperation,
244  CDEElementwiseOperation,
245  NumPrefetch, // NumGemmKPrefetchStage
246  BlockSize,
247  MPerBlock,
248  NPerBlock,
249  KPerBlock,
250  AK1,
251  BK1,
252  MPerXDL,
253  NPerXDL,
254  MXdlPerWave,
255  NXdlPerWave,
256  ABlockTransferThreadClusterLengths_K0_M_K1,
257  ABlockTransferThreadClusterArrangeOrder,
258  ABlockTransferSrcAccessOrder,
259  ABlockTransferSrcVectorDim,
260  ABlockTransferSrcScalarPerVector,
261  ABlockTransferDstScalarPerVector_K1,
262  false, // AThreadTransferSrcResetCoordinateAfterRun,
263  ABlockLdsExtraM,
264  BBlockTransferThreadClusterLengths_K0_N_K1,
265  BBlockTransferThreadClusterArrangeOrder,
266  BBlockTransferSrcAccessOrder,
267  BBlockTransferSrcVectorDim,
268  BBlockTransferSrcScalarPerVector,
269  BBlockTransferDstScalarPerVector_K1,
270  false, // BThreadTransferSrcResetCoordinateAfterRun,
271  BBlockLdsExtraN,
272  CShuffleMXdlPerWavePerShuffle,
273  CShuffleNXdlPerWavePerShuffle,
274  CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
275  CDEBlockTransferScalarPerVector_NPerBlock,
276  LoopSched>;
277 
278  using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
279  GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
280  AGridDesc_M_K{}))>;
281  using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
282  GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
283  BGridDesc_N_K{}))>;
284  using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
285  GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
286  DsGridDesc_M_N{}))>;
287  using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
288  GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
289  EGridDesc_M_N{}))>;
290 
291  struct GroupedGemmBlock2ETileMap
292  {
293  using Block2ETileMap = remove_cvref_t<decltype(
294  GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
295 
296  GroupedGemmBlock2ETileMap()
297  {
298  block_2_etile_map_ = GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{});
299  BlockStart_ = -1;
300  }
301 
302  GroupedGemmBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n, ck::index_t BlockStart)
303  {
304  block_2_etile_map_ = GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
305  BlockStart_ = BlockStart;
306  }
307 
308  template <typename TopIdx>
309  __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
310  {
311  return block_2_etile_map_.CalculateBottomIndex(
312  make_multi_index(idx_top[I0] - BlockStart_));
313  }
314 
315  // it's actually E-Tile
316  template <typename CTileIdx, typename CTileDim>
317  __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
318  const CTileDim& c_tile_dim) const
319  {
320  return block_2_etile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
321  }
322 
323  __host__ bool CheckValidity(const EGridDesc_M_N& e_grid_desc_m_n) const
324  {
325  return block_2_etile_map_.CheckValidity(e_grid_desc_m_n);
326  }
327 
328  Block2ETileMap block_2_etile_map_;
329  ck::index_t BlockStart_;
330  };
331 
332  struct GemmBiasTransKernelArg
333  {
334  // pointers
335  const ADataType* a_ptr_;
336  const BDataType* b_ptr_;
337  typename GridwiseGemm::DsGridPointer ds_ptr_;
338  EDataType* e_ptr_;
339 
340  // tensor descriptors for problem definition
341  AGridDesc_M_K a_grid_desc_m_k_;
342  BGridDesc_N_K b_grid_desc_n_k_;
343  DsGridDesc_M_N ds_grid_desc_m_n_;
344  EGridDesc_M_N e_grid_desc_m_n_;
345 
346  // tensor descriptors for block/thread-wise copy
347  AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
348  BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
349  DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
350  ds_grid_desc_mblock_mperblock_nblock_nperblock_;
351  EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
352 
353  // block-to-e-tile map
354  GroupedGemmBlock2ETileMap block_2_etile_map_;
355  ck::index_t BlockStart_, BlockEnd_;
356  };
357 
358  // Argument
359  struct Argument : public BaseArgument
360  {
361  Argument(std::vector<const void*>& p_As,
362  std::vector<const void*>& p_Bs,
363  std::vector<std::array<const void*, NumDTensor>>& p_Ds,
364  std::vector<void*>& p_Es,
365  std::vector<GemmDesc>& gemm_descs,
366  AElementwiseOperation a_element_op,
367  BElementwiseOperation b_element_op,
368  CDEElementwiseOperation c_element_op)
369  : a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op}
370  {
371  grid_size_ = 0;
372 
373  group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());
374 
375  if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) &&
376  group_count_ == ck::type_convert<ck::index_t>(p_Bs.size()) &&
377  group_count_ == ck::type_convert<ck::index_t>(p_Es.size())))
378  {
379  throw std::runtime_error("wrong! group_count_ != p_As/b/c.size");
380  }
381 
383 
385 
386  for(std::size_t i = 0; i < gemm_descs.size(); i++)
387  {
388  const index_t M = gemm_descs[i].M_;
389  const index_t N = gemm_descs[i].N_;
390  const index_t K = gemm_descs[i].K_;
391 
392  a_mtx_mraw_kraw_.emplace_back(M, K);
393  b_mtx_nraw_kraw_.emplace_back(N, K);
394 
395  if(M == 0)
396  {
397  skipped_group_count_++;
398  continue;
399  }
400 
401  const index_t StrideA = gemm_descs[i].stride_A_;
402  const index_t StrideB = gemm_descs[i].stride_B_;
403  const index_t StrideC = gemm_descs[i].stride_C_;
404 
405  // pointer
406  typename GridwiseGemm::DsGridPointer p_ds_grid{};
407 
408  static_for<0, NumDTensor, 1>{}([&](auto j) {
409  using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
410 
411  p_ds_grid(j) = static_cast<const DDataType*>(p_Ds[i][j]);
412  });
413 
414  // tensor descriptors for problem definition
415  const auto a_grid_desc_m_k = DeviceOp::MakeAGridDescriptor_M_K(M, K, StrideA);
416  const auto b_grid_desc_n_k = DeviceOp::MakeBGridDescriptor_N_K(K, N, StrideB);
417 
418  DsGridDesc_M_N ds_grid_desc_m_n;
419 
420  static_for<0, NumDTensor, 1>{}([&](auto j) {
421  using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
422 
423  ds_grid_desc_m_n(j) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
424  M, N, gemm_descs[i].stride_Ds_[j]);
425  });
426 
427  const auto e_grid_desc_m_n =
428  DeviceOp::MakeEGridDescriptor_M_N<ELayout>(M, N, StrideC);
429 
430  // tensor descriptors for block/thread-wise copy
431  const auto a_grid_desc_ak0_m_ak1 =
432  GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
433 
434  const auto b_grid_desc_bk0_n_bk1 =
435  GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
436 
437  const index_t grid_size_grp =
438  GroupedGemmBlock2ETileMap(e_grid_desc_m_n, 0)
439  .block_2_etile_map_.CalculateGridSize(e_grid_desc_m_n);
440 
441  const index_t BlockStart = grid_size_;
442  const index_t BlockEnd = grid_size_ + grid_size_grp;
443 
444  grid_size_ += grid_size_grp;
445 
446  // block-to-e-tile map
447  const auto block_2_etile_map =
448  GroupedGemmBlock2ETileMap(e_grid_desc_m_n, BlockStart);
449 
450  if(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
451  b_grid_desc_n_k,
452  ds_grid_desc_m_n,
453  e_grid_desc_m_n,
454  block_2_etile_map))
455  {
456  // tensor descriptors for block/thread-wise copy
457  DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
458  ds_grid_desc_mblock_mperblock_nblock_nperblock;
459 
460  static_for<0, NumDTensor, 1>{}([&](auto j) {
461  ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
462  GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
463  ds_grid_desc_m_n[j]);
464  });
465 
466  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
467  GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
468  e_grid_desc_m_n);
469 
470  gemm_desc_kernel_arg_.push_back(
471  GemmBiasTransKernelArg{static_cast<const ADataType*>(p_As[i]),
472  static_cast<const BDataType*>(p_Bs[i]),
473  p_ds_grid,
474  static_cast<EDataType*>(p_Es[i]),
475  a_grid_desc_m_k,
476  b_grid_desc_n_k,
477  ds_grid_desc_m_n,
478  e_grid_desc_m_n,
479  a_grid_desc_ak0_m_ak1,
480  b_grid_desc_bk0_n_bk1,
481  ds_grid_desc_mblock_mperblock_nblock_nperblock,
482  e_grid_desc_mblock_mperblock_nblock_nperblock,
483  block_2_etile_map,
484  BlockStart,
485  BlockEnd});
486  }
487  }
488  }
489 
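To make the block-range bookkeeping in the constructor concrete, here is a hypothetical run with three groups whose E grids need 4, 6, and 2 workgroups respectively; the numbers are illustrative only:

// group 0: grid_size_grp = 4 -> BlockStart =  0, BlockEnd =  4
// group 1: grid_size_grp = 6 -> BlockStart =  4, BlockEnd = 10
// group 2: grid_size_grp = 2 -> BlockStart = 10, BlockEnd = 12
// grid_size_ = 12, so kernel_grouped_gemm_xdl is launched once with dim3(12),
// and each workgroup finds its group via the range search shown earlier.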
490  // private:
491  index_t group_count_;
492  index_t skipped_group_count_;
493 
494  AElementwiseOperation a_element_op_;
495  BElementwiseOperation b_element_op_;
496  CDEElementwiseOperation c_element_op_;
497 
498  std::vector<GemmBiasTransKernelArg> gemm_desc_kernel_arg_;
499  std::vector<Tuple<index_t, index_t>> a_mtx_mraw_kraw_;
500  std::vector<Tuple<index_t, index_t>> b_mtx_nraw_kraw_;
501 
502  index_t grid_size_;
503  void* gemm_kernel_host_args_ = nullptr;
504  };
505 
506  // Invoker
507  struct Invoker : public BaseInvoker
508  {
509  using Argument = DeviceOp::Argument;
510 
511  float Run(const Argument& arg,
512  const StreamConfig& stream_config = StreamConfig{},
513  hipStream_t cpy_stream = nullptr,
514  hipEvent_t cpy_event = nullptr)
515  {
516  bool has_main_k_block_loop = true;
517 
518  for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
519  {
520  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
521  {
522  std::cout << "group: " << i << " arg.a_grid_desc_ak0_m_ak1_{"
523  << arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0)
524  << ", "
525  << arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I1)
526  << ", "
527  << arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2)
528  << "}";
529 
530  std::cout << ", arg.b_grid_desc_bk0_n_bk1_{"
531  << arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I0)
532  << ", "
533  << arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I1)
534  << ", "
535  << arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I2)
536  << "}";
537 
538  std::cout << ", arg.e_grid_desc_m_n_{ "
539  << arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I0) << ", "
540  << arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I1) << "}"
541  << std::endl;
542  }
543 
544  if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_m_k_,
545  arg.gemm_desc_kernel_arg_[i].b_grid_desc_n_k_,
546  arg.gemm_desc_kernel_arg_[i].ds_grid_desc_m_n_,
547  arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_,
548  arg.gemm_desc_kernel_arg_[i].block_2_etile_map_))
549  {
550  throw std::runtime_error(
551  "wrong! GridwiseGemmMultipleD_xdl_cshuffle has invalid setting");
552  }
553 
554  const auto K = arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) *
555  arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2);
556 
557  if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop)
558  {
559  throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
560  }
561  }
562 
563  // If the user provides copy stream and copy event, we assume that they're also
564  // responsible for providing allocated host memory (eg. pinned) which
565  // would be used to copy kernel arguments to the device.
566  if(cpy_stream && cpy_event)
567  {
568  if(arg.gemm_kernel_host_args_ == nullptr)
569  {
570  std::ostringstream err;
571  err << "No memory has been allocated for gemm kernel host args "
572  << "when providing the copy stream and copy event! In " << __FILE__ << ":"
573  << __LINE__ << ", in function: " << __func__;
574  throw std::runtime_error(err.str());
575  }
576  hipGetErrorString(hipMemcpyAsync(arg.p_workspace_,
577  arg.gemm_kernel_host_args_,
578  arg.group_count_ * sizeof(GemmBiasTransKernelArg),
579  hipMemcpyHostToDevice,
580  cpy_stream));
581  hipGetErrorString(hipEventRecord(cpy_event, cpy_stream));
582  hipGetErrorString(hipEventSynchronize(cpy_event));
583  }
584  else // In this case CK owns memory allocated on host.
585  {
586  hipGetErrorString(hipMemcpyAsync(arg.p_workspace_,
587  arg.gemm_desc_kernel_arg_.data(),
588  arg.gemm_desc_kernel_arg_.size() *
589  sizeof(GemmBiasTransKernelArg),
590  hipMemcpyHostToDevice,
591  stream_config.stream_id_));
592  }
593 
594  float ave_time = 0;
595 
596  auto launch_kernel = [&](auto has_main_k_block_loop_) {
597  const auto kernel = kernel_grouped_gemm_xdl<GridwiseGemm,
598  GemmBiasTransKernelArg,
599  AElementwiseOperation,
600  BElementwiseOperation,
601  CDEElementwiseOperation,
602  has_main_k_block_loop_>;
603 
604  return launch_and_time_kernel(
605  stream_config,
606  kernel,
607  dim3(arg.grid_size_),
608  dim3(BlockSize),
609  0,
610  cast_pointer_to_constant_address_space(arg.p_workspace_),
611  arg.gemm_desc_kernel_arg_.size(),
612  arg.a_element_op_,
613  arg.b_element_op_,
614  arg.c_element_op_);
615  };
616 
617  if(has_main_k_block_loop)
618  {
619  ave_time = launch_kernel(integral_constant<bool, true>{});
620  }
621  else
622  {
623  ave_time = launch_kernel(integral_constant<bool, false>{});
624  }
625 
626  return ave_time;
627  }
628 
629  // polymorphic
630  float Run(const BaseArgument* p_arg,
631  const StreamConfig& stream_config = StreamConfig{}) override
632  {
633  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
634  }
635  };
636 
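Run() accepts an optional copy stream and copy event; per the comment inside it, a caller using that path must also provide the host staging buffer (for example pinned memory) through SetHostKernelArgsPointer(). A hedged sketch of that usage, assuming gemm is an instantiated DeviceGroupedGemm_Xdl and argument was built with gemm.MakeArgument(...); error checking is omitted:

// Sketch only: `gemm` and `argument` are assumed to exist already.
auto invoker = gemm.MakeInvoker();

hipStream_t cpy_stream;
hipEvent_t  cpy_event;
hipStreamCreate(&cpy_stream);
hipEventCreate(&cpy_event);

// Pinned host staging buffer holding the packed per-group kernel arguments.
void* host_args = nullptr;
hipHostMalloc(&host_args, gemm.GetHostKernelArgSize(&argument));
gemm.SetHostKernelArgsPointer(&argument, host_args);

// Device-side buffer the kernel reads the per-group descriptors from.
void* dev_args = nullptr;
hipMalloc(&dev_args, gemm.GetDeviceKernelArgSize(&argument));
gemm.SetDeviceKernelArgs(&argument, dev_args);

// Run() copies host_args to dev_args on cpy_stream, records and waits on
// cpy_event, then launches the grouped GEMM kernel.
float time_ms = invoker.Run(argument, StreamConfig{}, cpy_stream, cpy_event);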
637  static bool IsSupportedArgument(const Argument& arg)
638  {
639  if(!ck::is_xdl_supported())
640  {
641  return false;
642  }
643 
644  if((ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) +
645  arg.skipped_group_count_) != arg.group_count_)
646  {
647  return false;
648  }
649 
650  bool supported = true;
651 
652  // If we use padding we do not support vector loads for dimensions not divisible by vector
653  // load size.
654  if constexpr(GemmSpec != GemmSpecialization::Default)
655  {
656  // [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1} layout,
657  // thus we have to adapt it to the {M,K} or {N,K} layout.
658  const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0;
659  const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0;
660 
661  for(index_t i = 0; i < arg.group_count_; ++i)
662  {
663  const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number<a_raw_vector_dim>{});
664  const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number<b_raw_vector_dim>{});
665 
666  supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0);
667  supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0);
668  }
669  }
670 
671  return supported;
672  }
673 
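As a concrete instance of the divisibility rule above (hypothetical numbers): with a row-major A and ABlockTransferSrcVectorDim == 2, the vectorized raw dimension is K, so ABlockTransferSrcScalarPerVector = 8 requires every group's K to be a multiple of 8 whenever a padding specialization is selected:

//   group with K = 128 : 128 % 8 == 0 -> vector loads stay valid
//   group with K = 132 : 132 % 8 != 0 -> IsSupportedArgument() returns false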
674  // polymorphic
675  bool IsSupportedArgument(const BaseArgument* p_arg) override
676  {
677  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
678  }
679 
680  static auto MakeArgument(std::vector<const void*>& p_As,
681  std::vector<const void*>& p_Bs,
682  std::vector<std::array<const void*, NumDTensor>>& p_Ds,
683  std::vector<void*>& p_Es,
684  std::vector<GemmDesc> gemm_descs,
685  AElementwiseOperation a_element_op,
686  BElementwiseOperation b_element_op,
687  CDEElementwiseOperation c_element_op)
688  {
689  return Argument{
690  p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op};
691  }
692 
693  static auto MakeInvoker() { return Invoker{}; }
694 
695  // polymorphic
696  std::unique_ptr<BaseArgument>
697  MakeArgumentPointer(std::vector<const void*>& p_As,
698  std::vector<const void*>& p_Bs,
699  std::vector<std::array<const void*, NumDTensor>>& p_Ds,
700  std::vector<void*>& p_Es,
701  std::vector<GemmDesc>& gemm_descs,
702  AElementwiseOperation a_element_op,
703  BElementwiseOperation b_element_op,
704  CDEElementwiseOperation c_element_op) override
705  {
706  return std::make_unique<Argument>(
707  p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op);
708  }
709 
710  // polymorphic
711  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
712  {
713  return std::make_unique<Invoker>(Invoker{});
714  }
715 
716  // polymorphic
717  std::string GetTypeString() const override
718  {
719  auto str = std::stringstream();
720 
721  // clang-format off
722  str << "DeviceGroupedGemm_Xdl"
723  << "<"
724  << BlockSize << ", "
725  << MPerBlock << ", "
726  << NPerBlock << ", "
727  << KPerBlock << ", "
728  << AK1 << ", "
729  << BK1 << ", "
730  << MPerXDL << ", "
731  << NPerXDL << ", "
732  << MXdlPerWave << ", "
733  << NXdlPerWave << ", "
734  << ABlockTransferSrcScalarPerVector << ", "
735  << BBlockTransferSrcScalarPerVector << ", "
736  << CShuffleMXdlPerWavePerShuffle << ", "
737  << CShuffleNXdlPerWavePerShuffle << ", "
738  << getGemmSpecializationString(GemmSpec)
739  << ">";
740  // clang-format on
741 
742  return str.str();
743  }
744 
745  size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
746  {
747  auto p_arg_ = dynamic_cast<const Argument*>(p_arg);
748  if(p_arg_)
749  {
750  return p_arg_->group_count_ * sizeof(GemmBiasTransKernelArg);
751  }
752  else
753  throw std::runtime_error("The argument pointer is not an object of "
754  "DeviceGroupedGemm_Xdl::Argument structure!");
755  }
756 
757  size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
758  {
759  return GetWorkSpaceSize(p_arg);
760  }
761 
762  void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
763  {
764  return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args);
765  }
766 
767  size_t GetHostKernelArgSize(const BaseArgument* p_arg) const { return GetWorkSpaceSize(p_arg); }
768 
769  //----------------------------------------------------------------------------------------------
778  void SetHostKernelArgsPointer(BaseArgument* p_arg, void* p_host_kernel_args) const
779  {
780  Argument* pArg_ = dynamic_cast<Argument*>(p_arg);
781  if(!pArg_)
782  {
783  throw std::runtime_error("Failed to cast argument pointer!");
784  }
785 
786  pArg_->gemm_kernel_host_args_ = p_host_kernel_args;
787  std::copy(pArg_->gemm_desc_kernel_arg_.begin(),
788  pArg_->gemm_desc_kernel_arg_.end(),
789  static_cast<GemmBiasTransKernelArg*>(pArg_->gemm_kernel_host_args_));
790  }
791 };
792 
793 } // namespace device
794 } // namespace tensor_operation
795 } // namespace ck
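For orientation, a hedged end-to-end sketch of driving this device op follows. DeviceOpInstance stands for a concrete DeviceGroupedGemm_Xdl<...> instantiation, PassThrough is ck::tensor_operation::element_wise::PassThrough, and the per-group device pointers (a_dev, b_dev, e_dev) and problem sizes (Ms, Ns, Ks, strides) are caller-provided placeholders; headers and error checking are omitted.

// Sketch only; not a drop-in example.
std::vector<const void*> p_As, p_Bs;
std::vector<std::array<const void*, 0>> p_Ds; // no D tensors in this sketch
std::vector<void*> p_Es;
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;

for(int g = 0; g < group_count; ++g)
{
    p_As.push_back(a_dev[g]);
    p_Bs.push_back(b_dev[g]);
    p_Ds.push_back({});
    p_Es.push_back(e_dev[g]);

    ck::tensor_operation::device::GemmDesc desc{};
    desc.M_        = Ms[g];
    desc.N_        = Ns[g];
    desc.K_        = Ks[g];
    desc.stride_A_ = StrideAs[g];
    desc.stride_B_ = StrideBs[g];
    desc.stride_C_ = StrideEs[g];
    gemm_descs.push_back(desc);
}

auto gemm     = DeviceOpInstance{};
auto argument = gemm.MakeArgument(p_As, p_Bs, p_Ds, p_Es, gemm_descs,
                                  PassThrough{}, PassThrough{}, PassThrough{});

if(!gemm.IsSupportedArgument(argument))
{
    throw std::runtime_error("unsupported grouped GEMM configuration");
}

// The packed per-group descriptors are copied into a device-side workspace.
void* dev_args = nullptr;
hipMalloc(&dev_args, gemm.GetDeviceKernelArgSize(&argument));
gemm.SetDeviceKernelArgs(&argument, dev_args);

auto invoker = gemm.MakeInvoker();
float ave_ms = invoker.Run(argument, StreamConfig{nullptr, true});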
#define CK_CONSTANT_ADDRESS_SPACE
Definition: ck.hpp:22
#define CK_MIN_BLOCK_PER_CU
Definition: ck.hpp:30
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:29
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:14
auto copy(InputRange &&range, OutputIterator iter) -> decltype(std::copy(std::begin(std::forward< InputRange >(range)), std::end(std::forward< InputRange >(range)), iter))
Definition: algorithm.hpp:14
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition: gemm_specialization.hpp:32
GemmSpecialization
Definition: gemm_specialization.hpp:11
__global__ void kernel_grouped_gemm_xdl(const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, const index_t group_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation c_element_op)
Definition: device_grouped_gemm_xdl.hpp:36
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition: kernel_launch.hpp:140
Definition: ck.hpp:267
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
bool is_xdl_supported()
Definition: device_prop.hpp:68
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:58
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:140
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE * cast_pointer_to_constant_address_space(T *p)
Definition: amd_address_space.hpp:35
__device__ T * cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE *p)
Definition: amd_address_space.hpp:24
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
LoopScheduler
Definition: loop_scheduler.hpp:15
int32_t index_t
Definition: ck.hpp:298
constexpr LoopScheduler make_default_loop_scheduler()
Definition: loop_scheduler.hpp:20
Definition: stream_config.hpp:10
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:78
__host__ static constexpr __device__ auto MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K &a_grid_desc_m_k)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:188
__host__ static constexpr __device__ bool CheckValidity(const AGridDesc_M_K &a_grid_desc_m_k, const BGridDesc_N_K &b_grid_desc_n_k, const DsGridDesc_M_N &ds_grid_desc_m_n, const EGridDesc_M_N &e_grid_desc_m_n, [[maybe_unused]] const Block2ETileMap &, index_t k_batch=1)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:330
__host__ static constexpr __device__ auto MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K &b_grid_desc_n_k)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:205
__host__ static constexpr __device__ auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N &e_grid_desc_m_n)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:222
decltype(MakeDsGridPointer()) DsGridPointer
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:406
__host__ static constexpr __device__ auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N &ds_grid_desc_m_n)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:243
__host__ static constexpr __device__ auto MakeDefaultBlock2ETileMap(const EGridDesc_M_N &e_grid_desc_m_n)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:255
__host__ static constexpr __device__ bool CalculateHasMainKBlockLoop(index_t K, index_t k_batch=1)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:398
Definition: integral_constant.hpp:20
Definition: type.hpp:177
Definition: functional2.hpp:33
Definition: device_base.hpp:51
void * p_workspace_
Definition: device_base.hpp:58
Definition: device_base.hpp:62
virtual void SetWorkSpacePointer(BaseArgument *p_arg, void *p_workspace, const StreamConfig &=StreamConfig{}) const
Definition: device_base.hpp:102
Definition: device_grouped_gemm_xdl.hpp:360
std::vector< GemmBiasTransKernelArg > gemm_desc_kernel_arg_
Definition: device_grouped_gemm_xdl.hpp:498
std::vector< Tuple< index_t, index_t > > a_mtx_mraw_kraw_
Definition: device_grouped_gemm_xdl.hpp:499
Argument(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< std::array< const void *, NumDTensor >> &p_Ds, std::vector< void * > &p_Es, std::vector< GemmDesc > &gemm_descs, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation c_element_op)
Definition: device_grouped_gemm_xdl.hpp:361
AElementwiseOperation a_element_op_
Definition: device_grouped_gemm_xdl.hpp:494
CDEElementwiseOperation c_element_op_
Definition: device_grouped_gemm_xdl.hpp:496
std::vector< Tuple< index_t, index_t > > b_mtx_nraw_kraw_
Definition: device_grouped_gemm_xdl.hpp:500
index_t skipped_group_count_
Definition: device_grouped_gemm_xdl.hpp:492
index_t grid_size_
Definition: device_grouped_gemm_xdl.hpp:502
index_t group_count_
Definition: device_grouped_gemm_xdl.hpp:491
BElementwiseOperation b_element_op_
Definition: device_grouped_gemm_xdl.hpp:495
void * gemm_kernel_host_args_
Definition: device_grouped_gemm_xdl.hpp:503
EGridDesc_M_N e_grid_desc_m_n_
Definition: device_grouped_gemm_xdl.hpp:344
DsGridDesc_M_N ds_grid_desc_m_n_
Definition: device_grouped_gemm_xdl.hpp:343
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition: device_grouped_gemm_xdl.hpp:348
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_
Definition: device_grouped_gemm_xdl.hpp:351
const ADataType * a_ptr_
Definition: device_grouped_gemm_xdl.hpp:335
BGridDesc_N_K b_grid_desc_n_k_
Definition: device_grouped_gemm_xdl.hpp:342
AGridDesc_M_K a_grid_desc_m_k_
Definition: device_grouped_gemm_xdl.hpp:341
GroupedGemmBlock2ETileMap block_2_etile_map_
Definition: device_grouped_gemm_xdl.hpp:354
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition: device_grouped_gemm_xdl.hpp:347
EDataType * e_ptr_
Definition: device_grouped_gemm_xdl.hpp:338
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_
Definition: device_grouped_gemm_xdl.hpp:350
ck::index_t BlockStart_
Definition: device_grouped_gemm_xdl.hpp:355
GridwiseGemm::DsGridPointer ds_ptr_
Definition: device_grouped_gemm_xdl.hpp:337
ck::index_t BlockEnd_
Definition: device_grouped_gemm_xdl.hpp:355
const BDataType * b_ptr_
Definition: device_grouped_gemm_xdl.hpp:336
__host__ __device__ bool ValidCTileIndex(const CTileIdx &c_tile_idx, const CTileDim &c_tile_dim) const
Definition: device_grouped_gemm_xdl.hpp:317
GroupedGemmBlock2ETileMap()
Definition: device_grouped_gemm_xdl.hpp:296
Block2ETileMap block_2_etile_map_
Definition: device_grouped_gemm_xdl.hpp:328
GroupedGemmBlock2ETileMap(const EGridDesc_M_N &e_grid_desc_m_n, ck::index_t BlockStart)
Definition: device_grouped_gemm_xdl.hpp:302
ck::index_t BlockStart_
Definition: device_grouped_gemm_xdl.hpp:329
__host__ constexpr __device__ auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition: device_grouped_gemm_xdl.hpp:309
__host__ bool CheckValidity(const EGridDesc_M_N &e_grid_desc_m_n) const
Definition: device_grouped_gemm_xdl.hpp:323
remove_cvref_t< decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))> Block2ETileMap
Definition: device_grouped_gemm_xdl.hpp:294
Definition: device_grouped_gemm_xdl.hpp:508
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_grouped_gemm_xdl.hpp:630
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{}, hipStream_t cpy_stream=nullptr, hipEvent_t cpy_event=nullptr)
Definition: device_grouped_gemm_xdl.hpp:511
Definition: device_grouped_gemm_xdl.hpp:146
remove_cvref_t< decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))> BGridDesc_BK0_N_BK1
Definition: device_grouped_gemm_xdl.hpp:283
size_t GetHostKernelArgSize(const BaseArgument *p_arg) const
Definition: device_grouped_gemm_xdl.hpp:767
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
Definition: device_grouped_gemm_xdl.hpp:195
decltype(MakeBGridDescriptor_N_K(1, 1, 1)) BGridDesc_N_K
Definition: device_grouped_gemm_xdl.hpp:227
static auto MakeArgument(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< std::array< const void *, NumDTensor >> &p_Ds, std::vector< void * > &p_Es, std::vector< GemmDesc > gemm_descs, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation c_element_op)
Definition: device_grouped_gemm_xdl.hpp:680
std::string GetTypeString() const override
Definition: device_grouped_gemm_xdl.hpp:717
static auto MakeDsGridDescriptor_M_N(const std::array< index_t, NumDTensor > &MRaws, const std::array< index_t, NumDTensor > &NRaws, const std::array< index_t, NumDTensor > &DsStride)
Definition: device_grouped_gemm_xdl.hpp:213
void SetDeviceKernelArgs(BaseArgument *p_arg, void *p_dev_kernel_args) const override
Sets the device kernel arguments pointer and may copy data to device.
Definition: device_grouped_gemm_xdl.hpp:762
decltype(MakeAGridDescriptor_M_K(1, 1, 1)) AGridDesc_M_K
Definition: device_grouped_gemm_xdl.hpp:226
static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
Definition: device_grouped_gemm_xdl.hpp:176
std::unique_ptr< BaseArgument > MakeArgumentPointer(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< std::array< const void *, NumDTensor >> &p_Ds, std::vector< void * > &p_Es, std::vector< GemmDesc > &gemm_descs, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation c_element_op) override
Definition: device_grouped_gemm_xdl.hpp:697
ADataType ComputeDataType
Definition: device_grouped_gemm_xdl.hpp:231
remove_cvref_t< decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))> DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition: device_grouped_gemm_xdl.hpp:286
static constexpr auto I1
Definition: device_grouped_gemm_xdl.hpp:152
remove_cvref_t< decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))> AGridDesc_AK0_M_AK1
Definition: device_grouped_gemm_xdl.hpp:280
remove_cvref_t< decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))> EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition: device_grouped_gemm_xdl.hpp:289
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_grouped_gemm_xdl.hpp:675
size_t GetDeviceKernelArgSize(const BaseArgument *p_arg) const override
Gets the device kernel argument size.
Definition: device_grouped_gemm_xdl.hpp:757
static constexpr auto I2
Definition: device_grouped_gemm_xdl.hpp:153
static constexpr auto matrix_padder
Definition: device_grouped_gemm_xdl.hpp:155
remove_cvref_t< decltype(MakeDsGridDescriptor_M_N({}, {}, {}))> DsGridDesc_M_N
Definition: device_grouped_gemm_xdl.hpp:228
static bool IsSupportedArgument(const Argument &arg)
Definition: device_grouped_gemm_xdl.hpp:637
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_grouped_gemm_xdl.hpp:711
static auto MakeInvoker()
Definition: device_grouped_gemm_xdl.hpp:693
static constexpr index_t NumDTensor
Definition: device_grouped_gemm_xdl.hpp:149
GridwiseGemmMultipleD_xdl_cshuffle< ADataType, BDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, NumPrefetch, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched > GridwiseGemm
Definition: device_grouped_gemm_xdl.hpp:276
decltype(MakeEGridDescriptor_M_N< ELayout >(1, 1, 1)) EGridDesc_M_N
Definition: device_grouped_gemm_xdl.hpp:229
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
Definition: device_grouped_gemm_xdl.hpp:158
size_t GetWorkSpaceSize(const BaseArgument *p_arg) const override
Definition: device_grouped_gemm_xdl.hpp:745
void SetHostKernelArgsPointer(BaseArgument *p_arg, void *p_host_kernel_args) const
Sets the host kernel arguments pointer and copies that data on the host side. This function can be ut...
Definition: device_grouped_gemm_xdl.hpp:778
static constexpr auto I0
Definition: device_grouped_gemm_xdl.hpp:151
Definition: device_grouped_gemm.hpp:99
Definition: device_grouped_gemm.hpp:80
Definition: matrix_padder.hpp:180
#define CK_ENV(name)
Definition: env.hpp:129