Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp Source File
device_grouped_gemm_xdl.hpp
1 #pragma once
2 // SPDX-License-Identifier: MIT
3 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
4 
5 #pragma once
6 
7 #include <iostream>
8 #include <sstream>
9 
10 #include "ck/utility/common_header.hpp"
11 #include "ck/utility/env.hpp"
12 #include "ck/tensor_description/tensor_descriptor.hpp"
13 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
14 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
15 #include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
16 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
17 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
18 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
19 #include "ck/host_utility/device_prop.hpp"
20 #include "ck/host_utility/kernel_launch.hpp"
21 
22 namespace ck {
23 namespace tensor_operation {
24 namespace device {
25 
26 template <typename GridwiseGemm,
27  typename GemmDesc,
28  typename AElementwiseOperation,
29  typename BElementwiseOperation,
30  typename CDEElementwiseOperation,
31  bool HasMainKBlockLoop>
32 __global__ void
33 #if CK_USE_LAUNCH_BOUNDS
34  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
35 #endif
36  kernel_grouped_gemm_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
37  const index_t group_count,
38  const AElementwiseOperation a_element_op,
39  const BElementwiseOperation b_element_op,
40  const CDEElementwiseOperation c_element_op)
41 {
42 #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
43  if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
44  {
45  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
46 
47  const index_t block_id = get_block_1d_id();
48 
49  const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(
50  cast_pointer_to_generic_address_space(gemm_descs_const));
51 
52  index_t left = 0;
53  index_t right = group_count;
54  index_t group_id = index_t((left + right) / 2);
55  while((!(block_id >= gemm_desc_ptr[group_id].BlockStart_ &&
56  block_id < gemm_desc_ptr[group_id].BlockEnd_)) &&
57  left <= right)
58  {
59  if(block_id < gemm_desc_ptr[group_id].BlockStart_)
60  {
61  right = group_id;
62  }
63  else
64  {
65  left = group_id;
66  }
67  group_id = index_t((left + right) / 2);
68  }
69 
70  GridwiseGemm::template Run<HasMainKBlockLoop, InMemoryDataOperationEnum::Set>(
71  gemm_desc_ptr[group_id].a_ptr_,
72  gemm_desc_ptr[group_id].b_ptr_,
73  gemm_desc_ptr[group_id].ds_ptr_,
74  gemm_desc_ptr[group_id].e_ptr_,
75  p_shared,
76  a_element_op,
77  b_element_op,
78  c_element_op,
79  gemm_desc_ptr[group_id].a_grid_desc_ak0_m_ak1_,
80  gemm_desc_ptr[group_id].b_grid_desc_bk0_n_bk1_,
81  gemm_desc_ptr[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
82  gemm_desc_ptr[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
83  gemm_desc_ptr[group_id].block_2_etile_map_);
84  }
85 #else
86  ignore = gemm_descs_const;
87  ignore = group_count;
88  ignore = a_element_op;
89  ignore = b_element_op;
90  ignore = c_element_op;
91 #endif
92 }
93 
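// ---------------------------------------------------------------------------------------------
// Note (not part of the header): each workgroup of kernel_grouped_gemm_xdl above locates the
// GEMM group it belongs to by binary-searching the per-group [BlockStart_, BlockEnd_) block
// ranges stored in the device-side descriptor array. The standalone sketch below mirrors that
// lookup on the host for illustration only; BlockRange and find_group are hypothetical helpers,
// not part of the CK API.
// ---------------------------------------------------------------------------------------------
#include <cstdint>

struct BlockRange
{
    int32_t BlockStart_; // first flat block id owned by this group
    int32_t BlockEnd_;   // one past the last flat block id owned by this group
};

// Returns the index of the group whose [BlockStart_, BlockEnd_) range contains block_id,
// assuming the ranges are contiguous, sorted, and cover block_id.
inline int32_t find_group(const BlockRange* ranges, int32_t group_count, int32_t block_id)
{
    int32_t left     = 0;
    int32_t right    = group_count;
    int32_t group_id = (left + right) / 2;

    while(!(block_id >= ranges[group_id].BlockStart_ && block_id < ranges[group_id].BlockEnd_) &&
          left <= right)
    {
        if(block_id < ranges[group_id].BlockStart_)
        {
            right = group_id;
        }
        else
        {
            left = group_id;
        }
        group_id = (left + right) / 2;
    }

    return group_id;
}
// ---------------------------------------------------------------------------------------------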
94 template <typename ALayout,
95  typename BLayout,
96  typename DsLayout,
97  typename ELayout,
98  typename ADataType,
99  typename BDataType,
100  typename AccDataType,
101  typename CShuffleDataType,
102  typename DsDataType,
103  typename EDataType,
104  typename AElementwiseOperation,
105  typename BElementwiseOperation,
106  typename CDEElementwiseOperation,
107  GemmSpecialization GemmSpec,
108  ck::index_t NumPrefetch,
109  ck::index_t BlockSize,
110  ck::index_t MPerBlock,
111  ck::index_t NPerBlock,
112  ck::index_t KPerBlock,
113  ck::index_t AK1,
114  ck::index_t BK1,
115  ck::index_t MPerXDL,
116  ck::index_t NPerXDL,
117  ck::index_t MXdlPerWave,
118  ck::index_t NXdlPerWave,
119  typename ABlockTransferThreadClusterLengths_K0_M_K1,
120  typename ABlockTransferThreadClusterArrangeOrder,
121  typename ABlockTransferSrcAccessOrder,
122  ck::index_t ABlockTransferSrcVectorDim,
123  ck::index_t ABlockTransferSrcScalarPerVector,
124  ck::index_t ABlockTransferDstScalarPerVector_K1,
125  bool ABlockLdsExtraM,
126  typename BBlockTransferThreadClusterLengths_K0_N_K1,
127  typename BBlockTransferThreadClusterArrangeOrder,
128  typename BBlockTransferSrcAccessOrder,
129  ck::index_t BBlockTransferSrcVectorDim,
130  ck::index_t BBlockTransferSrcScalarPerVector,
131  ck::index_t BBlockTransferDstScalarPerVector_K1,
132  bool BBlockLdsExtraN,
133  index_t CShuffleMXdlPerWavePerShuffle,
134  index_t CShuffleNXdlPerWavePerShuffle,
135  typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
136  index_t CDEBlockTransferScalarPerVector_NPerBlock,
137  LoopScheduler LoopSched = make_default_loop_scheduler()>
138 struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
139  BLayout,
140  DsLayout,
141  ELayout,
142  ADataType,
143  BDataType,
144  DsDataType,
145  EDataType,
146  AElementwiseOperation,
147  BElementwiseOperation,
148  CDEElementwiseOperation>
149 {
150  using DeviceOp = DeviceGroupedGemm_Xdl;
151  GET_NXDL_PER_WAVE_IMPL
152  static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
153  static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
154  static constexpr index_t NumDTensor = DsDataType::Size();
155 
156  static constexpr auto I0 = Number<0>{};
157  static constexpr auto I1 = Number<1>{};
158  static constexpr auto I2 = Number<2>{};
159 
160  static constexpr auto matrix_padder =
161  MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
162 
163  static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
164  {
165  const auto a_grid_desc_mraw_kraw = [&]() {
166  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
167  {
168  return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
169  make_tuple(StrideA, I1));
170  }
171  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
172  {
173  return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
174  make_tuple(I1, StrideA));
175  }
176  }();
177 
178  return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
179  }
180 
181  static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
182  {
183  const auto b_grid_desc_nraw_kraw = [&]() {
184  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
185  {
186  return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
187  make_tuple(I1, StrideB));
188  }
189  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
190  {
191  return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
192  make_tuple(StrideB, I1));
193  }
194  }();
195 
196  return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
197  }
198 
199  template <typename ELay>
200  static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
201  {
202  const auto e_grid_desc_mraw_nraw = [&]() {
203  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ELay>)
204  {
205  return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
206  make_tuple(StrideE, I1));
207  }
208  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ELay>)
209  {
210  return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
211  make_tuple(I1, StrideE));
212  }
213  }();
214 
215  return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
216  }
217 
218  static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
219  const std::array<index_t, NumDTensor>& NRaws,
220  const std::array<index_t, NumDTensor>& DsStride)
221  {
222  return generate_tuple(
223  [&](auto i) {
224  using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
225 
226  return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
227  },
228  Number<NumDTensor>{});
229  }
230 
231  using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
232  using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
233  using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
234  using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
235 
236  using ComputeDataType = ADataType;
237 
238  // GridwiseGemm
239  template <index_t NXdlPerWave_>
240  using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_CShuffle<
241  ADataType, // TODO: distinguish A/B datatype
242  BDataType,
243  ComputeDataType,
244  AccDataType,
245  CShuffleDataType,
246  DsDataType,
247  EDataType,
248  AElementwiseOperation,
249  BElementwiseOperation,
250  CDEElementwiseOperation,
251  NumPrefetch, // NumGemmKPrefetchStage
252  BlockSize,
253  MPerBlock,
254  NPerBlock,
255  KPerBlock,
256  AK1,
257  BK1,
258  MPerXDL,
259  NPerXDL,
260  MXdlPerWave,
261  NXdlPerWave_,
262  ABlockTransferThreadClusterLengths_K0_M_K1,
263  ABlockTransferThreadClusterArrangeOrder,
264  ABlockTransferSrcAccessOrder,
265  ABlockTransferSrcVectorDim,
266  ABlockTransferSrcScalarPerVector,
267  ABlockTransferDstScalarPerVector_K1,
268  false, // AThreadTransferSrcResetCoordinateAfterRun,
269  ABlockLdsExtraM,
270  BBlockTransferThreadClusterLengths_K0_N_K1,
271  BBlockTransferThreadClusterArrangeOrder,
272  BBlockTransferSrcAccessOrder,
273  BBlockTransferSrcVectorDim,
274  BBlockTransferSrcScalarPerVector,
275  BBlockTransferDstScalarPerVector_K1,
276  false, // BThreadTransferSrcResetCoordinateAfterRun,
277  BBlockLdsExtraN,
278  CShuffleMXdlPerWavePerShuffle,
279  CShuffleNXdlPerWavePerShuffle,
280  CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
281  CDEBlockTransferScalarPerVector_NPerBlock,
282  LoopSched>;
283  using GridwiseGemm64 = GridwiseGemmBase<NXdlPerWave64>;
284  using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
285 
286  using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
287  GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(
288  AGridDesc_M_K{}))>;
289  using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
290  GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(
291  BGridDesc_N_K{}))>;
292  using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
293  GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
294  DsGridDesc_M_N{}))>;
295  using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
296  GridwiseGemm64::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
297  EGridDesc_M_N{}))>;
298 
299  struct GroupedGemmBlock2ETileMap
300  {
301  using Block2ETileMap = remove_cvref_t<decltype(
302  GridwiseGemm64::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
303 
304  GroupedGemmBlock2ETileMap()
305  {
306  block_2_etile_map_ = GridwiseGemm64::MakeDefaultBlock2ETileMap(EGridDesc_M_N{});
307  BlockStart_ = -1;
308  }
309 
310  GroupedGemmBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n, ck::index_t BlockStart)
311  {
312  block_2_etile_map_ = GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
313  BlockStart_ = BlockStart;
314  }
315 
316  template <typename TopIdx>
317  __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
318  {
319  return block_2_etile_map_.CalculateBottomIndex(
320  make_multi_index(idx_top[I0] - BlockStart_));
321  }
322 
323  // it's actually E-Tile
324  template <typename CTileIdx, typename CTileDim>
325  __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
326  const CTileDim& c_tile_dim) const
327  {
328  return block_2_etile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
329  }
330 
331  __host__ bool CheckValidity(const EGridDesc_M_N& e_grid_desc_m_n) const
332  {
333  return block_2_etile_map_.CheckValidity(e_grid_desc_m_n);
334  }
335 
336  Block2ETileMap block_2_etile_map_;
337  ck::index_t BlockStart_;
338  };
339 
340  struct GemmBiasTransKernelArg
341  {
342  // pointers
343  const ADataType* a_ptr_;
344  const BDataType* b_ptr_;
345  typename GridwiseGemm64::DsGridPointer ds_ptr_;
346  EDataType* e_ptr_;
347 
348  // tensor descriptors for problem definition
349  AGridDesc_M_K a_grid_desc_m_k_;
350  BGridDesc_N_K b_grid_desc_n_k_;
351  DsGridDesc_M_N ds_grid_desc_m_n_;
352  EGridDesc_M_N e_grid_desc_m_n_;
353 
354  // tensor descriptors for block/thread-wise copy
355  AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
356  BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
357  DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
358  ds_grid_desc_mblock_mperblock_nblock_nperblock_;
359  EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
360 
361  // block-to-e-tile map
362  GroupedGemmBlock2ETileMap block_2_etile_map_;
363  ck::index_t BlockStart_, BlockEnd_;
364  };
365 
366  // Argument
367  struct Argument : public BaseArgument
368  {
369  template <typename GridwiseGemm, typename DsPointer, typename Block2ETileMap>
370  void init_gridwise_gemm_desc(const ADataType* a_ptr,
371  const BDataType* b_ptr,
372  DsPointer ds_ptr,
373  EDataType* e_ptr,
374  const AGridDesc_M_K& a_grid_desc_m_k,
375  const BGridDesc_N_K& b_grid_desc_n_k,
376  const DsGridDesc_M_N& ds_grid_desc_m_n,
377  const EGridDesc_M_N& e_grid_desc_m_n,
378  const Block2ETileMap& block_2_etile_map,
379  index_t BlockStart,
380  index_t BlockEnd)
381  {
382  // tensor descriptors for block/thread-wise copy
383  const auto a_grid_desc_ak0_m_ak1 =
384  GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
385 
386  const auto b_grid_desc_bk0_n_bk1 =
387  GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
388 
389  if(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
390  b_grid_desc_n_k,
391  ds_grid_desc_m_n,
392  e_grid_desc_m_n,
393  block_2_etile_map))
394  {
395  // tensor descriptors for block/thread-wise copy
396  DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
397  ds_grid_desc_mblock_mperblock_nblock_nperblock;
398 
399  static_for<0, NumDTensor, 1>{}([&](auto j) {
400  ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
401  GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
402  ds_grid_desc_m_n[j]);
403  });
404 
405  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
406  GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
407  e_grid_desc_m_n);
408 
409  gemm_desc_kernel_arg_.push_back(
410  GemmBiasTransKernelArg{a_ptr,
411  b_ptr,
412  ds_ptr,
413  e_ptr,
414  a_grid_desc_m_k,
415  b_grid_desc_n_k,
416  ds_grid_desc_m_n,
417  e_grid_desc_m_n,
418  a_grid_desc_ak0_m_ak1,
419  b_grid_desc_bk0_n_bk1,
420  ds_grid_desc_mblock_mperblock_nblock_nperblock,
421  e_grid_desc_mblock_mperblock_nblock_nperblock,
422  block_2_etile_map,
423  BlockStart,
424  BlockEnd});
425  }
426  };
427  Argument(std::vector<const void*>& p_As,
428  std::vector<const void*>& p_Bs,
429  std::vector<std::array<const void*, NumDTensor>>& p_Ds,
430  std::vector<void*>& p_Es,
431  std::vector<GemmDesc>& gemm_descs,
432  AElementwiseOperation a_element_op,
433  BElementwiseOperation b_element_op,
434  CDEElementwiseOperation c_element_op)
435  : a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op}
436  {
437  grid_size_ = 0;
438 
439  group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());
440 
441  if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) &&
442  group_count_ == ck::type_convert<ck::index_t>(p_Bs.size()) &&
443  group_count_ == ck::type_convert<ck::index_t>(p_Es.size())))
444  {
445  throw std::runtime_error("wrong! group_count_ != p_As/b/c.size");
446  }
447 
448  gemm_desc_kernel_arg_.reserve(group_count_);
449 
450  skipped_group_count_ = 0;
451 
452  for(std::size_t i = 0; i < gemm_descs.size(); i++)
453  {
454  const index_t M = gemm_descs[i].M_;
455  const index_t N = gemm_descs[i].N_;
456  const index_t K = gemm_descs[i].K_;
457 
458  a_mtx_mraw_kraw_.emplace_back(M, K);
459  b_mtx_nraw_kraw_.emplace_back(N, K);
460 
461  if(M == 0)
462  {
463  skipped_group_count_++;
464  continue;
465  }
466 
467  const index_t StrideA = gemm_descs[i].stride_A_;
468  const index_t StrideB = gemm_descs[i].stride_B_;
469  const index_t StrideC = gemm_descs[i].stride_C_;
470 
471  // pointer
472  typename GridwiseGemm64::DsGridPointer p_ds_grid{};
473 
474  static_for<0, NumDTensor, 1>{}([&](auto j) {
475  using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
476 
477  p_ds_grid(j) = static_cast<const DDataType*>(p_Ds[i][j]);
478  });
479 
480  // tensor descriptors for problem definition
481  const auto a_grid_desc_m_k = DeviceOp::MakeAGridDescriptor_M_K(M, K, StrideA);
482  const auto b_grid_desc_n_k = DeviceOp::MakeBGridDescriptor_N_K(K, N, StrideB);
483 
484  DsGridDesc_M_N ds_grid_desc_m_n;
485 
486  static_for<0, NumDTensor, 1>{}([&](auto j) {
487  using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
488 
489  ds_grid_desc_m_n(j) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
490  M, N, gemm_descs[i].stride_Ds_[j]);
491  });
492 
493  const auto e_grid_desc_m_n =
494  DeviceOp::MakeEGridDescriptor_M_N<ELayout>(M, N, StrideC);
495 
496  const index_t grid_size_grp =
497  GroupedGemmBlock2ETileMap(e_grid_desc_m_n, 0)
498  .block_2_etile_map_.CalculateGridSize(e_grid_desc_m_n);
499 
500  const index_t BlockStart = grid_size_;
501  const index_t BlockEnd = grid_size_ + grid_size_grp;
502 
503  grid_size_ += grid_size_grp;
504 
505  // block-to-e-tile map
506  const auto block_2_etile_map =
507  GroupedGemmBlock2ETileMap(e_grid_desc_m_n, BlockStart);
508 
509  if(get_warp_size() == 64)
510  {
511  if constexpr(NXdlPerWave64 > 0)
512  {
513  init_gridwise_gemm_desc<GridwiseGemm64>(
514  static_cast<const ADataType*>(p_As[i]),
515  static_cast<const BDataType*>(p_Bs[i]),
516  p_ds_grid,
517  static_cast<EDataType*>(p_Es[i]),
518  a_grid_desc_m_k,
519  b_grid_desc_n_k,
520  ds_grid_desc_m_n,
521  e_grid_desc_m_n,
522  block_2_etile_map,
523  BlockStart,
524  BlockEnd);
525  }
526  }
527  else
528  {
529  if constexpr(NXdlPerWave32 > 0)
530  {
531  init_gridwise_gemm_desc<GridwiseGemm32>(
532  static_cast<const ADataType*>(p_As[i]),
533  static_cast<const BDataType*>(p_Bs[i]),
534  p_ds_grid,
535  static_cast<EDataType*>(p_Es[i]),
536  a_grid_desc_m_k,
537  b_grid_desc_n_k,
538  ds_grid_desc_m_n,
539  e_grid_desc_m_n,
540  block_2_etile_map,
541  BlockStart,
542  BlockEnd);
543  }
544  }
545  }
546  }
547 
548  // private:
549  index_t group_count_;
550  index_t skipped_group_count_;
551 
552  AElementwiseOperation a_element_op_;
553  BElementwiseOperation b_element_op_;
554  CDEElementwiseOperation c_element_op_;
555 
556  std::vector<GemmBiasTransKernelArg> gemm_desc_kernel_arg_;
557  std::vector<Tuple<index_t, index_t>> a_mtx_mraw_kraw_;
558  std::vector<Tuple<index_t, index_t>> b_mtx_nraw_kraw_;
559 
560  index_t grid_size_;
561  void* gemm_kernel_host_args_ = nullptr;
562  };
563 
564  // Invoker
565  struct Invoker : public BaseInvoker
566  {
567  using Argument = DeviceOp::Argument;
568 
569  template <typename GridwiseGemm>
570  float RunImp(const Argument& arg,
571  const StreamConfig& stream_config = StreamConfig{},
572  hipStream_t cpy_stream = nullptr,
573  hipEvent_t cpy_event = nullptr)
574  {
575  bool has_main_k_block_loop = true;
576 
577  for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
578  {
579  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
580  {
581  std::cout << "group: " << i << " arg.a_grid_desc_ak0_m_ak1_{"
582  << arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0)
583  << ", "
584  << arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I1)
585  << ", "
586  << arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2)
587  << "}";
588 
589  std::cout << ", arg.b_grid_desc_bk0_n_bk1_{"
590  << arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I0)
591  << ", "
592  << arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I1)
593  << ", "
594  << arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I2)
595  << "}";
596 
597  std::cout << ", arg.e_grid_desc_m_n_{ "
598  << arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I0) << ", "
599  << arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I1) << "}"
600  << std::endl;
601  }
602 
603  if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_m_k_,
604  arg.gemm_desc_kernel_arg_[i].b_grid_desc_n_k_,
605  arg.gemm_desc_kernel_arg_[i].ds_grid_desc_m_n_,
606  arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_,
607  arg.gemm_desc_kernel_arg_[i].block_2_etile_map_))
608  {
609  throw std::runtime_error(
610  "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
611  }
612 
613  const auto K = arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) *
614  arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2);
615 
616  if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop)
617  {
618  throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
619  }
620  }
621 
622  // If the user provides copy stream and copy event, we assume that they're also
623  // responsible for providing allocated host memory (eg. pinned) which
624  // would be used to copy kernel arguments to the device.
625  if(cpy_stream && cpy_event)
626  {
627  if(arg.gemm_kernel_host_args_ == nullptr)
628  {
629  std::ostringstream err;
630  err << "No memory has been allocated for gemm kernel host args "
631  << "when providing the copy stream and copy event! In " << __FILE__ << ":"
632  << __LINE__ << ", in function: " << __func__;
633  throw std::runtime_error(err.str());
634  }
635  hipGetErrorString(hipMemcpyAsync(arg.p_workspace_,
636  arg.gemm_kernel_host_args_,
637  arg.group_count_ * sizeof(GemmBiasTransKernelArg),
638  hipMemcpyHostToDevice,
639  cpy_stream));
640  hipGetErrorString(hipEventRecord(cpy_event, cpy_stream));
641  hipGetErrorString(hipEventSynchronize(cpy_event));
642  }
643  else // In this case CK owns memory allocated on host.
644  {
645  hipGetErrorString(hipMemcpyAsync(arg.p_workspace_,
646  arg.gemm_desc_kernel_arg_.data(),
647  arg.gemm_desc_kernel_arg_.size() *
648  sizeof(GemmBiasTransKernelArg),
649  hipMemcpyHostToDevice,
650  stream_config.stream_id_));
651  }
652 
653  float ave_time = 0;
654 
655  auto launch_kernel = [&](auto has_main_k_block_loop_) {
656  const auto kernel = kernel_grouped_gemm_xdl<GridwiseGemm,
657  GemmBiasTransKernelArg,
658  AElementwiseOperation,
659  BElementwiseOperation,
660  CDEElementwiseOperation,
661  has_main_k_block_loop_>;
662 
663  return launch_and_time_kernel(
664  stream_config,
665  kernel,
666  dim3(arg.grid_size_),
667  dim3(BlockSize),
668  0,
669  cast_pointer_to_constant_address_space(arg.p_workspace_),
670  arg.gemm_desc_kernel_arg_.size(),
671  arg.a_element_op_,
672  arg.b_element_op_,
673  arg.c_element_op_);
674  };
675 
676  if(has_main_k_block_loop)
677  {
678  ave_time = launch_kernel(integral_constant<bool, true>{});
679  }
680  else
681  {
682  ave_time = launch_kernel(integral_constant<bool, false>{});
683  }
684 
685  return ave_time;
686  }
687 
688  float Run(const Argument& arg,
689  const StreamConfig& stream_config = StreamConfig{},
690  hipStream_t cpy_stream = nullptr,
691  hipEvent_t cpy_event = nullptr)
692  {
693  if(get_warp_size() == 64)
694  {
695  if constexpr(NXdlPerWave64 > 0)
696  {
697  return RunImp<GridwiseGemm64>(arg, stream_config, cpy_stream, cpy_event);
698  }
699  }
700  else
701  {
702  if constexpr(NXdlPerWave32 > 0)
703  {
704  return RunImp<GridwiseGemm32>(arg, stream_config, cpy_stream, cpy_event);
705  }
706  }
707  return 0;
708  }
709 
710  // polymorphic
711  float Run(const BaseArgument* p_arg,
712  const StreamConfig& stream_config = StreamConfig{}) override
713  {
714  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
715  }
716  };
717 
718  static bool IsSupportedArgument(const Argument& arg)
719  {
720  if(!ck::is_xdl_wmma_supported<ADataType, BDataType, MPerXDL, NPerXDL>())
721  {
722  return false;
723  }
724  if((ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) +
725  arg.skipped_group_count_) != arg.group_count_)
726  {
727  return false;
728  }
729 
730  bool supported = true;
731 
732  // If we use padding we do not support vector loads for dimensions not divisible by
733  // vector load size.
734  if constexpr(GemmSpec != GemmSpecialization::Default)
735  {
736  // [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1}
737  // layout, thus we have to adapt it to the {M,K} or {N,K} layout.
738  const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0;
739  const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0;
740 
741  for(index_t i = 0; i < arg.group_count_; ++i)
742  {
743  const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number<a_raw_vector_dim>{});
744  const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number<b_raw_vector_dim>{});
745 
746  supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0);
747  supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0);
748  }
749  }
750 
751  return supported;
752  }
753 
754  // polymorphic
755  bool IsSupportedArgument(const BaseArgument* p_arg) override
756  {
757  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
758  }
759 
760  static auto MakeArgument(std::vector<const void*>& p_As,
761  std::vector<const void*>& p_Bs,
762  std::vector<std::array<const void*, NumDTensor>>& p_Ds,
763  std::vector<void*>& p_Es,
764  std::vector<GemmDesc> gemm_descs,
765  AElementwiseOperation a_element_op,
766  BElementwiseOperation b_element_op,
767  CDEElementwiseOperation c_element_op)
768  {
769  return Argument{
770  p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op};
771  }
772 
773  static auto MakeInvoker() { return Invoker{}; }
774 
775  // polymorphic
776  std::unique_ptr<BaseArgument>
777  MakeArgumentPointer(std::vector<const void*>& p_As,
778  std::vector<const void*>& p_Bs,
779  std::vector<std::array<const void*, NumDTensor>>& p_Ds,
780  std::vector<void*>& p_Es,
781  std::vector<GemmDesc>& gemm_descs,
782  AElementwiseOperation a_element_op,
783  BElementwiseOperation b_element_op,
784  CDEElementwiseOperation c_element_op) override
785  {
786  return std::make_unique<Argument>(
787  p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op);
788  }
789 
790  // polymorphic
791  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
792  {
793  return std::make_unique<Invoker>(Invoker{});
794  }
795 
796  // polymorphic
797  std::string GetTypeString() const override
798  {
799  auto str = std::stringstream();
800 
801  // clang-format off
802  str << "DeviceGroupedGemm_Xdl"
803  << "<"
804  << BlockSize << ", "
805  << MPerBlock << ", "
806  << NPerBlock << ", "
807  << KPerBlock << ", "
808  << AK1 << ", "
809  << BK1 << ", "
810  << MPerXDL << ", "
811  << NPerXDL << ", "
812  << MXdlPerWave << ", "
813  << NXdlPerWave << ", "
814  << ABlockTransferSrcScalarPerVector << ", "
815  << BBlockTransferSrcScalarPerVector << ", "
816  << CShuffleMXdlPerWavePerShuffle << ", "
817  << CShuffleNXdlPerWavePerShuffle << ", "
818  << getGemmSpecializationString(GemmSpec)
819  << ">";
820  // clang-format on
821 
822  return str.str();
823  }
824 
825  size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
826  {
827  auto p_arg_ = dynamic_cast<const Argument*>(p_arg);
828  if(p_arg_)
829  {
830  return p_arg_->group_count_ * sizeof(GemmBiasTransKernelArg);
831  }
832  else
833  throw std::runtime_error("The argument pointer is not an object of "
834  "DeviceGroupedGemmMultipleDXdlCShuffle::Argument structure!");
835  }
836 
837  size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
838  {
839  return GetWorkSpaceSize(p_arg);
840  }
841 
842  void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
843  {
844  return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args);
845  }
846 
847  size_t GetHostKernelArgSize(const BaseArgument* p_arg) const { return GetWorkSpaceSize(p_arg); }
848 
849  //----------------------------------------------------------------------------------------------
859  void SetHostKernelArgsPointer(BaseArgument* p_arg, void* p_host_kernel_args) const
860  {
861  Argument* pArg_ = dynamic_cast<Argument*>(p_arg);
862  if(!pArg_)
863  {
864  throw std::runtime_error("Failed to cast argument pointer!");
865  }
866 
867  pArg_->gemm_kernel_host_args_ = p_host_kernel_args;
868  std::copy(pArg_->gemm_desc_kernel_arg_.begin(),
869  pArg_->gemm_desc_kernel_arg_.end(),
870  static_cast<GemmBiasTransKernelArg*>(pArg_->gemm_kernel_host_args_));
871  }
872 };
873 
874 } // namespace device
875 } // namespace tensor_operation
876 } // namespace ck
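
The device op above is driven from the host through MakeArgument/MakeInvoker. The sketch below is a minimal, illustrative host-side flow, not a definitive implementation. Assumptions: DeviceOpInstance is a hypothetical alias for a fully specialized DeviceGroupedGemm_Xdl (template arguments omitted), the instance has an empty Ds tuple (NumDTensor == 0), all tensors already reside in device memory, PassThrough element-wise operators are used, and the translation unit is compiled with hipcc with this header included. Because the grouped kernel reads its per-group descriptors from a device buffer, a kernel-argument buffer must be allocated and registered with SetDeviceKernelArgs before Run.

#include <array>
#include <vector>
#include <hip/hip_runtime.h>

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// Hypothetical helper: launches one grouped GEMM over pre-allocated device buffers.
void run_grouped_gemm(std::vector<const void*>& p_as, // per-group A pointers (device memory)
                      std::vector<const void*>& p_bs, // per-group B pointers (device memory)
                      std::vector<void*>& p_es,       // per-group E pointers (device memory)
                      std::vector<ck::tensor_operation::device::GemmDesc>& gemm_descs)
{
    auto device_op = DeviceOpInstance{};

    // No D tensors in this sketch (NumDTensor == 0).
    std::vector<std::array<const void*, 0>> p_ds(p_as.size());

    auto argument = device_op.MakeArgument(
        p_as, p_bs, p_ds, p_es, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});

    if(!device_op.IsSupportedArgument(argument))
    {
        return; // fall back to another instance
    }

    // Allocate and register the device-side buffer holding the per-group kernel arguments.
    void* p_dev_kernel_args = nullptr;
    hipMalloc(&p_dev_kernel_args, device_op.GetDeviceKernelArgSize(&argument));
    device_op.SetDeviceKernelArgs(&argument, p_dev_kernel_args);

    auto invoker = device_op.MakeInvoker();
    invoker.Run(argument, StreamConfig{});

    hipFree(p_dev_kernel_args);
}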