/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp Source File
device_gemm_xdl_waveletmodel_cshuffle.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
19 
20 namespace ck {
21 
// Device entry point for the wavelet-model XDL GEMM with C-shuffle.
//
// The kernel contains no algorithm of its own. When compiled for a gfx9
// target it declares a __shared__ byte buffer sized by
// GridwiseGemm::GetSharedMemoryNumberOfByte() and forwards all of its
// arguments (plus that buffer) to GridwiseGemm::Run<HasMainKBlockLoop>().
// For any other target every parameter is only assigned to `ignore`, so the
// kernel still compiles but does no work.
//
// Template parameters:
//   GridwiseGemm          - supplies GetSharedMemoryNumberOfByte() and Run<>().
//   ABDataType            - element type of both input pointers (A and B).
//   EDataType             - element type of the output pointer.
//   *ElementwiseOperation - functors passed through unchanged to Run().
//   *GridDesc*, Block2ETileMap
//                         - descriptor / tile-map objects taken by value and
//                           passed through unchanged to Run().
//   HasMainKBlockLoop     - compile-time flag selecting the Run<> instantiation.
22 template <typename GridwiseGemm,
23  typename ABDataType,
24  typename EDataType,
25  typename AElementwiseOperation,
26  typename BElementwiseOperation,
27  typename EElementwiseOperation,
28  typename AGridDesc_AK0_M_AK1,
29  typename BGridDesc_BK0_N_BK1,
30  typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
31  typename Block2ETileMap,
32  bool HasMainKBlockLoop>
33 __global__ void
34 #if CK_USE_LAUNCH_BOUNDS
// NOTE(review): listing line 35 is absent from this extract. The symbol index
// on this page lists CK_WAVELET_MAX_THREAD_PER_BLOCK and
// CK_WAVELET_MIN_BLOCK_PER_CU, so it presumably held a launch-bounds attribute
// built from them - confirm against the real header.
36 #endif
37  kernel_gemm_xdl_waveletmodel_cshuffle(const ABDataType* __restrict__ p_a_grid,
38  const ABDataType* __restrict__ p_b_grid,
39  EDataType* __restrict__ p_e_grid,
40  const AElementwiseOperation a_element_op,
41  const BElementwiseOperation b_element_op,
42  const EElementwiseOperation e_element_op,
43  const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
44  const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
45  const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
46  e_grid_desc_mblock_mperblock_nblock_nperblock,
47  const Block2ETileMap block_2_etile_map)
48 {
49 #if defined(__gfx9__)
// Scratch buffer handed to the gridwise GEMM; its size is decided entirely by
// GridwiseGemm.
50  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
51 
52  GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
53  p_b_grid,
54  p_e_grid,
55  p_shared,
56  a_element_op,
57  b_element_op,
58  e_element_op,
59  a_grid_desc_ak0_m_ak1,
60  b_grid_desc_bk0_n_bk1,
61  e_grid_desc_mblock_mperblock_nblock_nperblock,
62  block_2_etile_map);
63 #else
// Not a gfx9 target: consume every parameter via `ignore` so none is reported
// as unused; nothing is computed.
64  ignore = p_a_grid;
65  ignore = p_b_grid;
66  ignore = p_e_grid;
67  ignore = a_element_op;
68  ignore = b_element_op;
69  ignore = e_element_op;
70  ignore = a_grid_desc_ak0_m_ak1;
71  ignore = b_grid_desc_bk0_n_bk1;
72  ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
73  ignore = block_2_etile_map;
74 #endif
75 }
76 
77 } // namespace ck
78 
79 namespace ck {
80 namespace tensor_operation {
81 namespace device {
82 
83 template <typename ALayout,
84  typename BLayout,
85  typename ELayout,
86  typename ADataType,
87  typename BDataType,
88  typename GemmAcEDataType,
89  typename CShuffleDataType,
90  typename EDataType,
91  typename AElementwiseOperation,
92  typename BElementwiseOperation,
93  typename CDEElementwiseOperation,
94  GemmSpecialization GemmSpec,
95  index_t NumGemmKPrefetchStage,
96  index_t TileLoadThreadGroupSize,
97  index_t TileMathThreadGroupSize,
98  index_t MPerBlock,
99  index_t NPerBlock,
100  index_t KPerBlock,
101  index_t AK1,
102  index_t BK1,
103  index_t MPerXDL,
104  index_t NPerXDL,
105  index_t MXdlPerWave,
106  index_t NXdlPerWave,
107  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
108  typename ABlockTransferThreadClusterArrangeOrder,
109  typename ABlockTransferSrcAccessOrder,
110  index_t ABlockTransferSrcVectorDim,
111  index_t ABlockTransferSrcScalarPerVector,
112  index_t ABlockTransferDstScalarPerVector_AK1,
113  bool ABlockLdsExtraM,
114  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
115  typename BBlockTransferThreadClusterArrangeOrder,
116  typename BBlockTransferSrcAccessOrder,
117  index_t BBlockTransferSrcVectorDim,
118  index_t BBlockTransferSrcScalarPerVector,
119  index_t BBlockTransferDstScalarPerVector_BK1,
120  bool BBlockLdsExtraN,
121  index_t CShuffleMXdlPerWavePerShuffle,
122  index_t CShuffleNXdlPerWavePerShuffle,
123  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
124  index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
126  BLayout,
127  ELayout,
128  ADataType,
129  BDataType,
130  EDataType,
131  AElementwiseOperation,
132  BElementwiseOperation,
133  CDEElementwiseOperation>
134 {
136 
// Compile-time index constants. Within this class I1 is used as the unit
// stride when building descriptors, and I0/I1/I2 as arguments to GetLength().
137  static constexpr auto I0 = Number<0>{};
138  static constexpr auto I1 = Number<1>{};
139  static constexpr auto I2 = Number<2>{};
140 
// Padder configured with the GemmSpec specialization and the per-block tile
// sizes (MPerBlock, NPerBlock, KPerBlock). The Make*GridDescriptor helpers
// pass their raw descriptors through its Pad*Descriptor_* methods.
141  static constexpr auto matrix_padder =
142  MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
143 
144  static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
145  {
146  const auto a_grid_desc_mraw_kraw = [&]() {
147  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
148  {
149  return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
150  make_tuple(StrideA, I1));
151  }
152  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
153  {
154  return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
155  make_tuple(I1, StrideA));
156  }
157  }();
158 
159  return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
160  }
161 
// Builds the B grid descriptor with lengths (NRaw, KRaw) and runs it through
// matrix_padder.PadBDescriptor_N_K().
//
// Note the parameter order is (KRaw, NRaw, StrideB) while the descriptor
// lengths are ordered (N, K).
//
// NOTE(review): listing lines 165 and 170, which held the conditions selecting
// between the two branches below, are absent from this extract. Presumably
// they are `if constexpr` tests on BLayout analogous to those in
// MakeAGridDescriptor_M_K - confirm against the real header.
162  static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
163  {
164  const auto b_grid_desc_nraw_kraw = [&]() {
// First branch: N has unit stride, K is strided by StrideB.
166  {
167  return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
168  make_tuple(I1, StrideB));
169  }
// Second branch: N is strided by StrideB, K has unit stride.
171  {
172  return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
173  make_tuple(StrideB, I1));
174  }
175  }();
176 
177  return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
178  }
179 
// Builds the E (output) grid descriptor with lengths (MRaw, NRaw) and runs it
// through matrix_padder.PadCDescriptor_M_N().
//
// NOTE(review): listing lines 184 and 189, which held the conditions selecting
// between the two branches below, are absent from this extract. Presumably
// they are `if constexpr` tests on the ELay template parameter - confirm
// against the real header.
180  template <typename ELay>
181  static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
182  {
183  const auto e_grid_desc_mraw_nraw = [&]() {
// First branch: M is strided by StrideE, N has unit stride.
185  {
186  return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
187  make_tuple(StrideE, I1));
188  }
// Second branch: M has unit stride, N is strided by StrideE.
190  {
191  return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
192  make_tuple(I1, StrideE));
193  }
194  }();
195 
196  return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
197  }
198 
// Descriptor types produced by the Make*GridDescriptor helpers. The calls sit
// inside decltype, so the literal arguments (1, 1, 1) are placeholders and
// only the resulting types are used.
199  using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
200  using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
201  using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
202 
203  // GridwiseGemm
// NOTE(review): listing lines 204 and 212-215 are absent from this extract.
// According to the symbol index on this page, the statement is
//   using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle<...>;
// and the arguments missing after CDEElementwiseOperation are
//   InMemoryDataOperationEnum::Set, AGridDesc_M_K, BGridDesc_N_K, EGridDesc_M_N.
// The remaining arguments forward this class's own template parameters, with
// two literal `false` values placed before ABlockLdsExtraM and BBlockLdsExtraN.
205  ADataType, // TODO: distinguish A/B datatype
206  GemmAcEDataType,
207  CShuffleDataType,
208  EDataType,
209  AElementwiseOperation,
210  BElementwiseOperation,
211  CDEElementwiseOperation,
216  NumGemmKPrefetchStage,
217  TileLoadThreadGroupSize,
218  TileMathThreadGroupSize,
219  MPerBlock,
220  NPerBlock,
221  KPerBlock,
222  AK1,
223  BK1,
224  MPerXDL,
225  NPerXDL,
226  MXdlPerWave,
227  NXdlPerWave,
228  ABlockTransferThreadClusterLengths_AK0_M_AK1,
229  ABlockTransferThreadClusterArrangeOrder,
230  ABlockTransferSrcAccessOrder,
231  ABlockTransferSrcVectorDim,
232  ABlockTransferSrcScalarPerVector,
233  ABlockTransferDstScalarPerVector_AK1,
234  false,
235  ABlockLdsExtraM,
236  BBlockTransferThreadClusterLengths_BK0_N_BK1,
237  BBlockTransferThreadClusterArrangeOrder,
238  BBlockTransferSrcAccessOrder,
239  BBlockTransferSrcVectorDim,
240  BBlockTransferSrcScalarPerVector,
241  BBlockTransferDstScalarPerVector_BK1,
242  false,
243  BBlockLdsExtraN,
244  CShuffleMXdlPerWavePerShuffle,
245  CShuffleNXdlPerWavePerShuffle,
246  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
247  CShuffleBlockTransferScalarPerVector_NPerBlock>;
248 
// NOTE(review): listing lines 249-250 and 252-253 are absent from this
// extract; only the tails of two alias declarations remain. Per the symbol
// index on this page they are
//   AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
//       GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>
//   BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
//       GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>
251  AGridDesc_M_K{}))>;
254  BGridDesc_N_K{}))>;
255 
// NOTE(review): listing line 256 is absent from this extract; per the symbol
// index it declares
//   Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap
257 
258  // Argument
// Host-side bundle of everything one launch needs: the three data pointers,
// the problem descriptors derived from the raw sizes and strides, the
// block-to-tile map and the three element-wise functors. Members are public
// (the `private:` marker below is commented out) and are read directly by
// Invoker::Run and IsSupportedArgument.
259  struct Argument : public BaseArgument
260  {
// Stores the pointers and functors and builds the descriptors:
//   a_grid_desc_m_k_   from (MRaw, KRaw, StrideA)
//   b_grid_desc_n_k_   from (KRaw, NRaw, StrideB)
//   e_grid_desc_m_n_   from (MRaw, NRaw, StrideE)
//   block_2_etile_map_ from e_grid_desc_m_n_
// The static_casts on the pointers are no-ops: each parameter already has the
// target type.
261  Argument(const ADataType* p_a_grid,
262  const BDataType* p_b_grid,
263  EDataType* p_e_grid,
264  index_t MRaw,
265  index_t NRaw,
266  index_t KRaw,
267  index_t StrideA,
268  index_t StrideB,
269  index_t StrideE,
270  AElementwiseOperation a_element_op,
271  BElementwiseOperation b_element_op,
272  CDEElementwiseOperation cde_element_op)
273  : p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
274  p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
275  p_e_grid_{static_cast<EDataType*>(p_e_grid)},
276  a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)},
277  b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
278  e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideE)},
// NOTE(review): listing lines 279, 281 and 283 are absent from this extract,
// so the member names opening the next initializers are not visible.
// Presumably they are a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_ and
// e_grid_desc_mblock_mperblock_nblock_nperblock_ (all listed in the symbol
// index on this page) - confirm against the real header.
280  GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
282  GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
284  block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
285  a_element_op_{a_element_op},
286  b_element_op_{b_element_op},
287  cde_element_op_{cde_element_op}
288  {
// NOTE(review): listing lines 289-290 and 292-294 are absent from this
// extract, so the condition and the contents of the inner block are not
// visible. Presumably it fills
// e_grid_desc_mblock_mperblock_nblock_nperblock_ once the descriptors pass a
// validity check - confirm against the real header.
291  {
295  }
296  }
297 
// Streams the three problem descriptors (A[M, K], B[N, K], E[M, N]) to
// std::cout, one per line.
298  void Print() const
299  {
300  std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
301  std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
302  std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
303  }
304 
305  // private:
306  // pointers
307  const ADataType* p_a_grid_;
308  const BDataType* p_b_grid_;
309  EDataType* p_e_grid_;
310 
311  // tensor descriptors for problem definition
// NOTE(review): listing lines 312-314 are absent from this extract; per the
// symbol index on this page they declare
//   AGridDesc_M_K a_grid_desc_m_k_;
//   BGridDesc_N_K b_grid_desc_n_k_;
//   EGridDesc_M_N e_grid_desc_m_n_;
315 
316  // tensor descriptors for block/thread-wise copy
// NOTE(review): listing lines 317-320 are absent from this extract; per the
// symbol index on this page they declare
//   AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
//   BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
//   GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
//       e_grid_desc_mblock_mperblock_nblock_nperblock_;
321 
322  // block-to-e-tile map
// NOTE(review): listing line 323 is absent from this extract; per the symbol
// index on this page it declares
//   Block2ETileMap block_2_etile_map_;
324 
325  // element-wise op
326  AElementwiseOperation a_element_op_;
327  BElementwiseOperation b_element_op_;
328  CDEElementwiseOperation cde_element_op_;
329  };
330 
331  // Invoker
332  struct Invoker : public BaseInvoker
333  {
335 
// Launches kernel_gemm_xdl_waveletmodel_cshuffle for `arg` and returns the
// value produced by launch_and_time_kernel().
//
// Throws std::runtime_error("wrong! GridwiseGemm has invalid setting") when
// the check on the A/B/E descriptors and the block-to-tile map fails.
//
// The launch uses TileLoadThreadGroupSize + TileMathThreadGroupSize threads
// per block and passes 0 as the lds_byte argument of launch_and_time_kernel.
336  float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
337  {
// Debug dump of the descriptor lengths; compiled out.
338 #if 0
339  {
340  std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
341  << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
342  << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
343  << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
344 
345  std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
346  << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
347  << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
348  << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
349 
350  std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", "
351  << arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
352  }
353 #endif
354 
// NOTE(review): listing line 355 is absent from this extract. Presumably it
// opens `if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,` (CheckValidity
// with this four-argument shape appears in the symbol index on this page) -
// confirm against the real header.
356  arg.b_grid_desc_n_k_,
357  arg.e_grid_desc_m_n_,
358  arg.block_2_etile_map_))
359  {
360  throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
361  }
362 
// NOTE(review): listing line 363 is absent from this extract. It presumably
// defines `grid_size` (used in the launch below), likely via
// GridwiseGemm::CalculateGridSize(arg.e_grid_desc_m_n_) - confirm.
// K is the second length of the A[M, K] descriptor.
364  const auto K = arg.a_grid_desc_m_k_.GetLength(I1);
365 
// Generic lambda taking an integral_constant<bool, ...>, so the main-K-loop
// flag reaches the kernel as a compile-time template argument.
366  auto launch_kernel = [&](auto has_main_k_block_loop) {
367  constexpr bool has_main_loop = has_main_k_block_loop.value;
368 
369  const auto kernel = kernel_gemm_xdl_waveletmodel_cshuffle<
370  GridwiseGemm,
371  ADataType, // TODO: distinguish A/B datatype
372  EDataType,
373  AElementwiseOperation,
374  BElementwiseOperation,
375  CDEElementwiseOperation,
// NOTE(review): listing lines 376-379 are absent from this extract. Going by
// the kernel's template parameter list they presumably name the
// A/B/E descriptor types and the Block2ETileMap type - confirm.
380  has_main_loop>;
381 
382  return launch_and_time_kernel(
383  stream_config,
384  kernel,
385  dim3(grid_size),
386  dim3(TileLoadThreadGroupSize + TileMathThreadGroupSize),
387  0,
388  arg.p_a_grid_,
389  arg.p_b_grid_,
390  arg.p_e_grid_,
391  arg.a_element_op_,
392  arg.b_element_op_,
393  arg.cde_element_op_,
// NOTE(review): listing lines 394-396 are absent from this extract. Going by
// the kernel's parameter list they presumably pass the three grid
// descriptors held by `arg` - confirm.
397  arg.block_2_etile_map_);
398  };
399 
// NOTE(review): listing lines 400 and 402 are absent from this extract.
// Presumably the condition is GridwiseGemm::CalculateHasMainKBlockLoop(K) and
// the first branch launches with integral_constant<bool, true>{} - confirm.
401  {
403  }
404  else
405  {
406  return launch_kernel(integral_constant<bool, false>{});
407  }
408  }
409 
410  // polymorphic
411  float Run(const BaseArgument* p_arg,
412  const StreamConfig& stream_config = StreamConfig{}) override
413  {
414  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
415  }
416  };
417 
// Reports whether this operation can run `arg`.
//
// Returns false immediately when ck::is_xdl_supported() is false; otherwise
// returns the result of the expression below, which is evaluated on the
// A/B/E descriptors and the block-to-tile map stored in `arg`.
418  static bool IsSupportedArgument(const Argument& arg)
419  {
420  if(!ck::is_xdl_supported())
421  {
422  return false;
423  }
424 
// NOTE(review): listing line 425 is absent from this extract. Presumably it
// reads `return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,` (the same
// four-argument check appears in the symbol index on this page) - confirm
// against the real header.
426  arg.b_grid_desc_n_k_,
427  arg.e_grid_desc_m_n_,
428  arg.block_2_etile_map_);
429  }
430 
431  // polymorphic
432  bool IsSupportedArgument(const BaseArgument* p_arg) override
433  {
434  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
435  }
436 
437  static auto MakeArgument(const ADataType* p_a,
438  const BDataType* p_b,
439  EDataType* p_e,
440  index_t MRaw,
441  index_t NRaw,
442  index_t KRaw,
443  index_t StrideA,
444  index_t StrideB,
445  index_t StrideE,
446  AElementwiseOperation a_element_op,
447  BElementwiseOperation b_element_op,
448  CDEElementwiseOperation cde_element_op)
449  {
450  return Argument{p_a,
451  p_b,
452  p_e,
453  MRaw,
454  NRaw,
455  KRaw,
456  StrideA,
457  StrideB,
458  StrideE,
459  a_element_op,
460  b_element_op,
461  cde_element_op};
462  }
463 
464  static auto MakeInvoker() { return Invoker{}; }
465 
466  // polymorphic
467  std::unique_ptr<BaseArgument>
468  MakeArgumentPointer(const void* p_a,
469  const void* p_b,
470  void* p_e,
471  index_t MRaw,
472  index_t NRaw,
473  index_t KRaw,
474  index_t StrideA,
475  index_t StrideB,
476  index_t StrideE,
477  AElementwiseOperation a_element_op,
478  BElementwiseOperation b_element_op,
479  CDEElementwiseOperation cde_element_op) override
480  {
481  return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
482  static_cast<const BDataType*>(p_b),
483  static_cast<EDataType*>(p_e),
484  MRaw,
485  NRaw,
486  KRaw,
487  StrideA,
488  StrideB,
489  StrideE,
490  a_element_op,
491  b_element_op,
492  cde_element_op);
493  }
494 
495  // polymorphic
496  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
497  {
498  return std::make_unique<Invoker>(Invoker{});
499  }
500 
501  // polymorphic
502  std::string GetTypeString() const override
503  {
504  auto str = std::stringstream();
505 
506  // clang-format off
507  str << "DeviceGemm_Xdl_WaveletModel_CShuffle"
508  << "<"
509  << TileLoadThreadGroupSize << ", "
510  << TileMathThreadGroupSize << ", "
511  << MPerBlock << ", "
512  << NPerBlock << ", "
513  << KPerBlock << ", "
514  << AK1 << ", "
515  << BK1
516  << ">";
517  // clang-format on
518 
519  return str.str();
520  }
521 };
522 
523 } // namespace device
524 } // namespace tensor_operation
525 } // namespace ck
#define CK_WAVELET_MIN_BLOCK_PER_CU
Definition: ck.hpp:34
#define CK_WAVELET_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:33
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:14
GemmSpecialization
Definition: gemm_specialization.hpp:11
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition: kernel_launch.hpp:140
Definition: ck.hpp:267
bool is_xdl_supported()
Definition: device_prop.hpp:68
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__global__ void kernel_gemm_xdl_waveletmodel_cshuffle(const ABDataType *__restrict__ p_a_grid, const ABDataType *__restrict__ p_b_grid, EDataType *__restrict__ p_e_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const EElementwiseOperation e_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map)
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:37
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
int32_t index_t
Definition: ck.hpp:298
Definition: stream_config.hpp:10
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:64
__host__ static constexpr __device__ auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDescriptor_M_N &e_grid_desc_m_n)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:314
__host__ static constexpr __device__ auto MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K &b_grid_desc_n_k)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:298
__host__ static constexpr __device__ bool CheckValidity(const AGridDesc_M_K &a_grid_desc_m_k, const BGridDesc_N_K &b_grid_desc_n_k, const EGridDesc_M_N &e_grid_desc_m_n, const Block2ETileMap &)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:177
remove_cvref_t< decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))> DefaultBlock2ETileMap
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:338
__host__ static constexpr __device__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:226
__host__ static constexpr __device__ index_t CalculateGridSize(const EGridDesc_M_N &e_grid_desc_m_n)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:270
remove_cvref_t< decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))> EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:335
__host__ static constexpr __device__ auto MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K &a_grid_desc_m_k)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:282
Definition: integral_constant.hpp:20
Definition: type.hpp:177
Definition: device_base.hpp:51
Definition: device_base.hpp:62
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:260
AElementwiseOperation a_element_op_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:326
BGridDesc_N_K b_grid_desc_n_k_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:313
const BDataType * p_b_grid_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:308
Block2ETileMap block_2_etile_map_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:323
CDEElementwiseOperation cde_element_op_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:328
const ADataType * p_a_grid_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:307
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:318
BElementwiseOperation b_element_op_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:327
EDataType * p_e_grid_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:309
void Print() const
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:298
AGridDesc_M_K a_grid_desc_m_k_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:312
Argument(const ADataType *p_a_grid, const BDataType *p_b_grid, EDataType *p_e_grid, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:261
GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:320
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:317
EGridDesc_M_N e_grid_desc_m_n_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:314
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:333
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:336
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:411
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:134
static constexpr auto matrix_padder
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:141
remove_cvref_t< decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))> AGridDesc_AK0_M_AK1
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:251
static constexpr auto I1
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:138
static auto MakeInvoker()
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:464
static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:162
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_e, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:468
static constexpr auto I2
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:139
static constexpr auto I0
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:137
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:496
typename GridwiseGemm::DefaultBlock2ETileMap Block2ETileMap
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:256
decltype(MakeEGridDescriptor_M_N< ELayout >(1, 1, 1)) EGridDesc_M_N
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:201
remove_cvref_t< decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))> BGridDesc_BK0_N_BK1
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:254
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:432
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:144
GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle< ADataType, GemmAcEDataType, CShuffleDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_M_K, BGridDesc_N_K, EGridDesc_M_N, NumGemmKPrefetchStage, TileLoadThreadGroupSize, TileMathThreadGroupSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock > GridwiseGemm
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:247
decltype(MakeAGridDescriptor_M_K(1, 1, 1)) AGridDesc_M_K
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:199
static bool IsSupportedArgument(const Argument &arg)
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:418
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:181
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, EDataType *p_e, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:437
decltype(MakeBGridDescriptor_N_K(1, 1, 1)) BGridDesc_N_K
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:200
std::string GetTypeString() const override
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:502
Definition: device_gemm.hpp:22
Definition: matrix_padder.hpp:180