/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp Source File
device_gemm_xdl_waveletmodel_cshuffle.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
19 
20 namespace ck {
21 
22 template <typename GridwiseGemm,
23  typename ABDataType,
24  typename EDataType,
25  typename AElementwiseOperation,
26  typename BElementwiseOperation,
27  typename EElementwiseOperation,
28  typename AGridDesc_AK0_M_AK1,
29  typename BGridDesc_BK0_N_BK1,
30  typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
31  typename Block2ETileMap,
32  bool HasMainKBlockLoop>
33 __global__ void
34 #if CK_USE_LAUNCH_BOUNDS
36 #endif
37  kernel_gemm_xdl_waveletmodel_cshuffle(const ABDataType* __restrict__ p_a_grid,
38  const ABDataType* __restrict__ p_b_grid,
39  EDataType* __restrict__ p_e_grid,
40  const AElementwiseOperation a_element_op,
41  const BElementwiseOperation b_element_op,
42  const EElementwiseOperation e_element_op,
43  const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
44  const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
45  const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
46  e_grid_desc_mblock_mperblock_nblock_nperblock,
47  const Block2ETileMap block_2_etile_map)
48 {
49 #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
50  if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
51  {
52  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
53 
54  GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
55  p_b_grid,
56  p_e_grid,
57  p_shared,
58  a_element_op,
59  b_element_op,
60  e_element_op,
61  a_grid_desc_ak0_m_ak1,
62  b_grid_desc_bk0_n_bk1,
63  e_grid_desc_mblock_mperblock_nblock_nperblock,
64  block_2_etile_map);
65  }
66 #else
67  ignore = p_a_grid;
68  ignore = p_b_grid;
69  ignore = p_e_grid;
70  ignore = a_element_op;
71  ignore = b_element_op;
72  ignore = e_element_op;
73  ignore = a_grid_desc_ak0_m_ak1;
74  ignore = b_grid_desc_bk0_n_bk1;
75  ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
76  ignore = block_2_etile_map;
77 #endif
78 }
79 
80 } // namespace ck
81 
82 namespace ck {
83 namespace tensor_operation {
84 namespace device {
85 
// DeviceGemm_Xdl_WaveletModel_CShuffle (name as reported by GetTypeString below).
//
// Device-level GEMM of A (M x K) and B (K x N) into E (M x N), built on the XDL
// "wavelet model" C-shuffle gridwise GEMM. The per-tensor elementwise operations are
// forwarded to the kernel unchanged.
//
// Each workgroup is launched with TileLoadThreadGroupSize + TileMathThreadGroupSize
// threads (see Invoker::RunImp). Presumably one group loads tiles and the other does
// the math; TODO confirm against gridwise_gemm_xdl_waveletmodel_cshuffle.hpp.
//
// NOTE(review): this rendered listing has dropped several source lines. Absent line
// numbers: 128, 139, 143, 173, 178, 192, 197, 213, 221-224, 257-258, 260-261,
// 263-264, 267, 290, 292, 315-317, 320-321, 324, 335, 381-382, 399-400, 407, 415,
// 435 and 445. Read it against the original header, not as-is.
86 template <typename ALayout,
87  typename BLayout,
88  typename ELayout,
89  typename ADataType,
90  typename BDataType,
91  typename GemmAcEDataType,
92  typename CShuffleDataType,
93  typename EDataType,
94  typename AElementwiseOperation,
95  typename BElementwiseOperation,
96  typename CDEElementwiseOperation,
97  GemmSpecialization GemmSpec,
98  index_t NumGemmKPrefetchStage,
99  index_t TileLoadThreadGroupSize,
100  index_t TileMathThreadGroupSize,
101  index_t MPerBlock,
102  index_t NPerBlock,
103  index_t KPerBlock,
104  index_t AK1,
105  index_t BK1,
106  index_t MPerXDL,
107  index_t NPerXDL,
108  index_t MXdlPerWave,
109  index_t NXdlPerWave,
110  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
111  typename ABlockTransferThreadClusterArrangeOrder,
112  typename ABlockTransferSrcAccessOrder,
113  index_t ABlockTransferSrcVectorDim,
114  index_t ABlockTransferSrcScalarPerVector,
115  index_t ABlockTransferDstScalarPerVector_AK1,
116  bool ABlockLdsExtraM,
117  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
118  typename BBlockTransferThreadClusterArrangeOrder,
119  typename BBlockTransferSrcAccessOrder,
120  index_t BBlockTransferSrcVectorDim,
121  index_t BBlockTransferSrcScalarPerVector,
122  index_t BBlockTransferDstScalarPerVector_BK1,
123  bool BBlockLdsExtraN,
124  index_t CShuffleMXdlPerWavePerShuffle,
125  index_t CShuffleNXdlPerWavePerShuffle,
126  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
127  index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
// NOTE(review): source line 128 (the `struct ... : public DeviceGemm<ALayout,` head of
// the base-class list, presumably) is missing from this listing.
129  BLayout,
130  ELayout,
131  ADataType,
132  BDataType,
133  EDataType,
134  AElementwiseOperation,
135  BElementwiseOperation,
136  CDEElementwiseOperation>
137 {
// Larger of the two thread-group sizes. Note that the kernel launch in RunImp uses
// their SUM as the block dimension; confirm how the gridwise GEMM consumes this value.
138  static constexpr auto BlockSize = math::max(TileLoadThreadGroupSize, TileMathThreadGroupSize);
// NXdlPerWave for wave64 and wave32 targets, from GetNXdlPerWave<>().
// GetNXdlPerWave<>() is presumably introduced by GET_NXDL_PER_WAVE_IMPL (source line
// 139, missing here). A value of 0 means "unsupported"; see IsSupportedArgument.
140  static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
141  static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
142 
// NOTE(review): source line 143 is missing. `DeviceOp`, used below, is presumably
// declared there.
144 
145  static constexpr auto I0 = Number<0>{};
146  static constexpr auto I1 = Number<1>{};
147  static constexpr auto I2 = Number<2>{};
148 
// Pads raw descriptors according to GemmSpec, using the MPerBlock / NPerBlock /
// KPerBlock tile sizes.
149  static constexpr auto matrix_padder =
150  MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
151 
// Builds the (M, K) descriptor of A and pads it with matrix_padder.
// - RowMajor A uses strides (StrideA, 1).
// - ColumnMajor A uses strides (1, StrideA).
// Any other ALayout leaves the lambda without a return statement.
152  static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
153  {
154  const auto a_grid_desc_mraw_kraw = [&]() {
155  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
156  {
157  return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
158  make_tuple(StrideA, I1));
159  }
160  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
161  {
162  return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
163  make_tuple(I1, StrideA));
164  }
165  }();
166 
167  return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
168  }
169 
// Builds the (N, K) view of B and pads it with matrix_padder.
// Parameter order is (KRaw, NRaw, StrideB).
// - The first branch uses strides (1, StrideB).
// - The second branch uses strides (StrideB, 1).
// NOTE(review): the `if constexpr` conditions (source lines 173 and 178) are missing
// here. Presumably they are RowMajor and ColumnMajor BLayout respectively; confirm.
170  static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
171  {
172  const auto b_grid_desc_nraw_kraw = [&]() {
174  {
175  return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
176  make_tuple(I1, StrideB));
177  }
179  {
180  return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
181  make_tuple(StrideB, I1));
182  }
183  }();
184 
185  return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
186  }
187 
// Builds the (M, N) descriptor of E for layout ELay and pads it with matrix_padder.
// - The first branch uses strides (StrideE, 1).
// - The second branch uses strides (1, StrideE).
// NOTE(review): the `if constexpr` conditions (source lines 192 and 197) are missing
// here. Presumably they are RowMajor and ColumnMajor ELay; confirm.
188  template <typename ELay>
189  static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
190  {
191  const auto e_grid_desc_mraw_nraw = [&]() {
193  {
194  return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
195  make_tuple(StrideE, I1));
196  }
198  {
199  return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
200  make_tuple(I1, StrideE));
201  }
202  }();
203 
204  return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
205  }
206 
// Descriptor types, deduced from dummy 1x1 problems.
207  using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
208  using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
209  using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
210 
// Gridwise GEMM alias template. It is parameterized on NXdlPerWave_ so that wave64
// and wave32 variants can both be instantiated.
// NOTE(review): the following source lines are missing from this listing:
// - the alias declaration itself (line 213);
// - template arguments 221-224;
// - the GridwiseGemm64 / GridwiseGemm32 aliases used below (presumably lines 257-258).
211  // GridwiseGemm
212  template <index_t NXdlPerWave_>
214  ADataType, // TODO: distinguish A/B datatype
215  GemmAcEDataType,
216  CShuffleDataType,
217  EDataType,
218  AElementwiseOperation,
219  BElementwiseOperation,
220  CDEElementwiseOperation,
225  NumGemmKPrefetchStage,
226  TileLoadThreadGroupSize,
227  TileMathThreadGroupSize,
228  MPerBlock,
229  NPerBlock,
230  KPerBlock,
231  AK1,
232  BK1,
233  MPerXDL,
234  NPerXDL,
235  MXdlPerWave,
236  NXdlPerWave_,
237  ABlockTransferThreadClusterLengths_AK0_M_AK1,
238  ABlockTransferThreadClusterArrangeOrder,
239  ABlockTransferSrcAccessOrder,
240  ABlockTransferSrcVectorDim,
241  ABlockTransferSrcScalarPerVector,
242  ABlockTransferDstScalarPerVector_AK1,
243  false,
244  ABlockLdsExtraM,
245  BBlockTransferThreadClusterLengths_BK0_N_BK1,
246  BBlockTransferThreadClusterArrangeOrder,
247  BBlockTransferSrcAccessOrder,
248  BBlockTransferSrcVectorDim,
249  BBlockTransferSrcScalarPerVector,
250  BBlockTransferDstScalarPerVector_BK1,
251  false,
252  BBlockLdsExtraN,
253  CShuffleMXdlPerWavePerShuffle,
254  CShuffleNXdlPerWavePerShuffle,
255  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
256  CShuffleBlockTransferScalarPerVector_NPerBlock>;
259 
// AGridDesc_AK0_M_AK1 and BGridDesc_BK0_N_BK1 are the remove_cvref_t'd results of
// GridwiseGemm64::MakeDefault{A,B}GridDescriptor_* applied to default-constructed
// M/K and N/K descriptors. Their declaration heads are missing here.
// Block2ETileMap (source line 267, missing) is GridwiseGemm64::DefaultBlock2ETileMap.
// All three are derived from GridwiseGemm64 only, yet are also passed to
// GridwiseGemm32. This presumably relies on them not depending on NXdlPerWave; confirm.
262  AGridDesc_M_K{}))>;
265  BGridDesc_N_K{}))>;
266 
268 
// Host-side argument pack. It holds:
// - the raw device pointers;
// - the padded problem descriptors (M/K, N/K, M/N);
// - the AK0/M/AK1 and BK0/N/BK1 copy descriptors and the block-to-E-tile map, all
//   built with GridwiseGemm64 helpers;
// - the three elementwise functors.
269  // Argument
270  struct Argument : public BaseArgument
271  {
// The static_casts below are no-ops: the parameters already have the member types.
// NOTE(review): the initializer heads `a_grid_desc_ak0_m_ak1_{` and
// `b_grid_desc_bk0_n_bk1_{` (presumably source lines 290 and 292) are missing here.
272  Argument(const ADataType* p_a_grid,
273  const BDataType* p_b_grid,
274  EDataType* p_e_grid,
275  index_t MRaw,
276  index_t NRaw,
277  index_t KRaw,
278  index_t StrideA,
279  index_t StrideB,
280  index_t StrideE,
281  AElementwiseOperation a_element_op,
282  BElementwiseOperation b_element_op,
283  CDEElementwiseOperation cde_element_op)
284  : p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
285  p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
286  p_e_grid_{static_cast<EDataType*>(p_e_grid)},
287  a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)},
288  b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
289  e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideE)},
291  GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
293  GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
294  block_2_etile_map_{GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
295  a_element_op_{a_element_op},
296  b_element_op_{b_element_op},
297  cde_element_op_{cde_element_op}
298  {
299  }
300 
// Debug helper: writes the three padded problem descriptors to std::cout.
301  void Print() const
302  {
303  std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
304  std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
305  std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
306  }
307 
308  // private:
309  // pointers
310  const ADataType* p_a_grid_;
311  const BDataType* p_b_grid_;
312  EDataType* p_e_grid_;
313 
// NOTE(review): the declarations of a_grid_desc_m_k_, b_grid_desc_n_k_ and
// e_grid_desc_m_n_ (source lines 315-317) are missing from this listing.
314  // tensor descriptors for problem definition
318 
// NOTE(review): the declarations of a_grid_desc_ak0_m_ak1_ and b_grid_desc_bk0_n_bk1_
// (source lines 320-321) are missing from this listing.
319  // tensor descriptors for block/thread-wise copy
322 
// NOTE(review): the declaration of block_2_etile_map_ (source line 324) is missing.
323  // block-to-e-tile map
325 
326  // element-wise op
327  AElementwiseOperation a_element_op_;
328  BElementwiseOperation b_element_op_;
329  CDEElementwiseOperation cde_element_op_;
330  };
331 
// Launches the kernel for an Argument.
// NOTE(review): source lines 335 and 415 are missing from this listing. Per the
// cross-reference index, Run(const BaseArgument*) is preceded by INVOKER_RUN_IMPL
// (device_base.hpp). That macro presumably supplies Run(const Argument&, ...),
// dispatching to RunImp<GridwiseGemm64> or RunImp<GridwiseGemm32>; confirm.
332  // Invoker
333  struct Invoker : public BaseInvoker
334  {
336 
// Steps, in order:
// 1. Validate the problem with GridwiseGemm::CheckValidity; throw std::runtime_error
//    if it fails.
// 2. Build the MBlock/MPerBlock/NBlock/NPerBlock view of E.
// 3. Launch the kernel through launch_and_time_kernel with grid_size workgroups of
//    TileLoadThreadGroupSize + TileMathThreadGroupSize threads and lds_byte = 0.
//    The kernel uses a static __shared__ buffer instead of dynamic LDS.
// The HasMainKBlockLoop instantiation is chosen at runtime from the K length of the
// padded A descriptor. Returns launch_and_time_kernel's result; presumably this is
// the elapsed time when timing is enabled (confirm in kernel_launch.hpp).
337  template <typename GridwiseGemm>
338  float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
339  {
// Disabled debug dump of the descriptor lengths.
340 #if 0
341  {
342  std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
343  << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
344  << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
345  << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
346 
347  std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
348  << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
349  << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
350  << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
351 
352  std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", "
353  << arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
354  }
355 #endif
356 
357  if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
358  arg.b_grid_desc_n_k_,
359  arg.e_grid_desc_m_n_,
360  arg.block_2_etile_map_))
361  {
362  throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
363  }
364  auto e_grid_desc_mblock_mperblock_nblock_nperblock =
365  GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
366  arg.e_grid_desc_m_n_);
367 
368  const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.e_grid_desc_m_n_);
369  const auto K = arg.a_grid_desc_m_k_.GetLength(I1);
370 
// Instantiates and launches the kernel for one HasMainKBlockLoop value.
// NOTE(review): template arguments 381-382 and kernel arguments 399-400 are missing
// here. Presumably they are the A/B copy-descriptor types and
// arg.a_grid_desc_ak0_m_ak1_ / arg.b_grid_desc_bk0_n_bk1_.
371  auto launch_kernel = [&](auto has_main_k_block_loop) {
372  constexpr bool has_main_loop = has_main_k_block_loop.value;
373 
374  const auto kernel = kernel_gemm_xdl_waveletmodel_cshuffle<
375  GridwiseGemm,
376  ADataType, // TODO: distinguish A/B datatype
377  EDataType,
378  AElementwiseOperation,
379  BElementwiseOperation,
380  CDEElementwiseOperation,
383  typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
384  typename GridwiseGemm::DefaultBlock2ETileMap,
385  has_main_loop>;
386 
387  return launch_and_time_kernel(
388  stream_config,
389  kernel,
390  dim3(grid_size),
391  dim3(TileLoadThreadGroupSize + TileMathThreadGroupSize),
392  0,
393  arg.p_a_grid_,
394  arg.p_b_grid_,
395  arg.p_e_grid_,
396  arg.a_element_op_,
397  arg.b_element_op_,
398  arg.cde_element_op_,
401  e_grid_desc_mblock_mperblock_nblock_nperblock,
402  arg.block_2_etile_map_);
403  };
404 
// NOTE(review): the true branch is empty in this listing because source line 407 is
// missing. It presumably reads `return launch_kernel(integral_constant<bool, true>{});`.
405  if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
406  {
408  }
409  else
410  {
411  return launch_kernel(integral_constant<bool, false>{});
412  }
413  }
414 
416 
// Type-erased entry point.
// NOTE(review): the dynamic_cast result is dereferenced unchecked. Passing a
// BaseArgument that this device op did not create dereferences a null pointer.
417  // polymorphic
418  float Run(const BaseArgument* p_arg,
419  const StreamConfig& stream_config = StreamConfig{}) override
420  {
421  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
422  }
423  };
424 
// Decision order:
// 1. Return false when is_xdl_wmma_supported rejects
//    (ADataType, BDataType, MPerXDL, NPerXDL).
// 2. Otherwise select a variant by the runtime warp size: GridwiseGemm64 for 64,
//    GridwiseGemm32 for anything else. Delegate to that variant's CheckValidity
//    when its NXdlPerWave is > 0.
// 3. Fall through to false when the selected NXdlPerWave is not > 0.
// NOTE(review): the `return GridwiseGemm64/32::CheckValidity(arg.a_grid_desc_m_k_,`
// heads (presumably source lines 435 and 445) are missing from this listing.
425  static bool IsSupportedArgument(const Argument& arg)
426  {
427  if(!ck::is_xdl_wmma_supported<ADataType, BDataType, MPerXDL, NPerXDL>())
428  {
429  return false;
430  }
431  if(get_warp_size() == 64)
432  {
433  if constexpr(NXdlPerWave64 > 0)
434  {
436  arg.b_grid_desc_n_k_,
437  arg.e_grid_desc_m_n_,
438  arg.block_2_etile_map_);
439  }
440  }
441  else
442  {
443  if constexpr(NXdlPerWave32 > 0)
444  {
446  arg.b_grid_desc_n_k_,
447  arg.e_grid_desc_m_n_,
448  arg.block_2_etile_map_);
449  }
450  }
451  return false;
452  }
453 
// Type-erased overload. The same unchecked dynamic_cast caveat as Invoker::Run applies.
454  // polymorphic
455  bool IsSupportedArgument(const BaseArgument* p_arg) override
456  {
457  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
458  }
459 
// Builds an Argument by value from typed device pointers, raw sizes and strides.
460  static auto MakeArgument(const ADataType* p_a,
461  const BDataType* p_b,
462  EDataType* p_e,
463  index_t MRaw,
464  index_t NRaw,
465  index_t KRaw,
466  index_t StrideA,
467  index_t StrideB,
468  index_t StrideE,
469  AElementwiseOperation a_element_op,
470  BElementwiseOperation b_element_op,
471  CDEElementwiseOperation cde_element_op)
472  {
473  return Argument{p_a,
474  p_b,
475  p_e,
476  MRaw,
477  NRaw,
478  KRaw,
479  StrideA,
480  StrideB,
481  StrideE,
482  a_element_op,
483  b_element_op,
484  cde_element_op};
485  }
486 
487  static auto MakeInvoker() { return Invoker{}; }
488 
// Type-erased factory. The void pointers are static_cast to ADataType / BDataType /
// EDataType, so the caller must pass buffers of exactly those element types.
489  // polymorphic
490  std::unique_ptr<BaseArgument>
491  MakeArgumentPointer(const void* p_a,
492  const void* p_b,
493  void* p_e,
494  index_t MRaw,
495  index_t NRaw,
496  index_t KRaw,
497  index_t StrideA,
498  index_t StrideB,
499  index_t StrideE,
500  AElementwiseOperation a_element_op,
501  BElementwiseOperation b_element_op,
502  CDEElementwiseOperation cde_element_op) override
503  {
504  return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
505  static_cast<const BDataType*>(p_b),
506  static_cast<EDataType*>(p_e),
507  MRaw,
508  NRaw,
509  KRaw,
510  StrideA,
511  StrideB,
512  StrideE,
513  a_element_op,
514  b_element_op,
515  cde_element_op);
516  }
517 
518  // polymorphic
519  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
520  {
521  return std::make_unique<Invoker>(Invoker{});
522  }
523 
// Human-readable instance name. It encodes only the thread-group sizes, the
// MPerBlock/NPerBlock/KPerBlock tile sizes and AK1/BK1, so instances that differ only
// in other template parameters produce the same string.
524  // polymorphic
525  std::string GetTypeString() const override
526  {
527  auto str = std::stringstream();
528 
529  // clang-format off
530  str << "DeviceGemm_Xdl_WaveletModel_CShuffle"
531  << "<"
532  << TileLoadThreadGroupSize << ", "
533  << TileMathThreadGroupSize << ", "
534  << MPerBlock << ", "
535  << NPerBlock << ", "
536  << KPerBlock << ", "
537  << AK1 << ", "
538  << BK1
539  << ">";
540  // clang-format on
541 
542  return str.str();
543  }
544 };
545 
546 } // namespace device
547 } // namespace tensor_operation
548 } // namespace ck
#define CK_WAVELET_MIN_BLOCK_PER_CU
Definition: ck.hpp:35
#define CK_WAVELET_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:34
#define GET_NXDL_PER_WAVE_IMPL
Definition: device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition: device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:14
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
GemmSpecialization
Definition: gemm_specialization.hpp:11
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition: kernel_launch.hpp:140
Definition: ck.hpp:268
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__global__ void kernel_gemm_xdl_waveletmodel_cshuffle(const ABDataType *__restrict__ p_a_grid, const ABDataType *__restrict__ p_b_grid, EDataType *__restrict__ p_e_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const EElementwiseOperation e_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map)
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:37
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
int32_t index_t
Definition: ck.hpp:299
Definition: stream_config.hpp:10
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:64
__host__ static constexpr __device__ auto MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K &b_grid_desc_n_k)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:315
__host__ static constexpr __device__ bool CheckValidity(const AGridDesc_M_K &a_grid_desc_m_k, const BGridDesc_N_K &b_grid_desc_n_k, const EGridDesc_M_N &e_grid_desc_m_n, const Block2ETileMap &)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:194
remove_cvref_t< decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))> DefaultBlock2ETileMap
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:355
__host__ static constexpr __device__ auto MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K &a_grid_desc_m_k)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:299
Definition: integral_constant.hpp:20
Definition: type.hpp:177
Definition: device_base.hpp:197
Definition: device_base.hpp:208
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:271
AElementwiseOperation a_element_op_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:327
BGridDesc_N_K b_grid_desc_n_k_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:316
const BDataType * p_b_grid_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:311
Block2ETileMap block_2_etile_map_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:324
CDEElementwiseOperation cde_element_op_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:329
const ADataType * p_a_grid_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:310
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:321
BElementwiseOperation b_element_op_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:328
EDataType * p_e_grid_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:312
void Print() const
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:301
AGridDesc_M_K a_grid_desc_m_k_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:315
Argument(const ADataType *p_a_grid, const BDataType *p_b_grid, EDataType *p_e_grid, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:272
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:320
EGridDesc_M_N e_grid_desc_m_n_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:317
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:334
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:418
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:338
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:137
static constexpr auto matrix_padder
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:149
static constexpr auto I1
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:146
static constexpr GET_NXDL_PER_WAVE_IMPL auto NXdlPerWave64
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:140
static auto MakeInvoker()
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:487
static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:170
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_e, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:491
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))> BGridDesc_BK0_N_BK1
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:265
static constexpr auto I2
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:147
static constexpr auto I0
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:145
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:519
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))> AGridDesc_AK0_M_AK1
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:262
static constexpr auto NXdlPerWave32
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:141
decltype(MakeEGridDescriptor_M_N< ELayout >(1, 1, 1)) EGridDesc_M_N
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:209
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:455
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:152
typename GridwiseGemm64::DefaultBlock2ETileMap Block2ETileMap
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:267
decltype(MakeAGridDescriptor_M_K(1, 1, 1)) AGridDesc_M_K
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:207
static bool IsSupportedArgument(const Argument &arg)
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:425
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:189
static constexpr auto BlockSize
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:138
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, EDataType *p_e, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:460
decltype(MakeBGridDescriptor_N_K(1, 1, 1)) BGridDesc_N_K
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:208
std::string GetTypeString() const override
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:525
Definition: device_gemm.hpp:22
Definition: matrix_padder.hpp:180