device_gemm_dl.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

#include "ck/utility/common_header.hpp"
#include "ck/utility/env.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
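// DeviceGemmDl implements the DeviceGemm interface on top of the
// GridwiseGemmDl_km_kn_mn_v1r3 kernel. The DL path does not rely on XDL/MFMA
// matrix instructions; IsSupportedArgument below restricts it to gfx906,
// gfx103x, gfx11 and gfx12 devices. Only PassThrough elementwise operators are
// accepted, enforced by the enable_if constraint on the template parameters.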
template <
    typename ADataType,
    typename BDataType,
    typename CDataType,
    typename AccDataType,
    typename ALayout,
    typename BLayout,
    typename CLayout,
    typename AElementwiseOperation,
    typename BElementwiseOperation,
    typename CElementwiseOperation,
    GemmSpecialization GemmSpec,
    index_t BlockSize,
    index_t MPerBlock,
    index_t NPerBlock,
    index_t K0PerBlock,
    index_t K1,
    index_t M1PerThread,
    index_t N1PerThread,
    index_t KPerThread,
    typename M1N1ThreadClusterM1Xs,
    typename M1N1ThreadClusterN1Xs,
    typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
    typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
    typename ABlockTransferThreadClusterArrangeOrder,
    typename ABlockTransferSrcAccessOrder,
    typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
    typename ABlockTransferSrcVectorTensorContiguousDimOrder,
    typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
    typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
    typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
    typename BBlockTransferThreadClusterArrangeOrder,
    typename BBlockTransferSrcAccessOrder,
    typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
    typename BBlockTransferSrcVectorTensorContiguousDimOrder,
    typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
    typename CThreadTransferSrcDstAccessOrder,
    index_t CThreadTransferSrcDstVectorDim,
    index_t CThreadTransferDstScalarPerVector,
    enable_if_t<
        is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
            is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
            is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
        bool> = false>
struct DeviceGemmDl : public DeviceGemm<ALayout,
                                        BLayout,
                                        CLayout,
                                        ADataType,
                                        BDataType,
                                        CDataType,
                                        AElementwiseOperation,
                                        BElementwiseOperation,
                                        CElementwiseOperation>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};

    static constexpr auto K1Number = Number<K1>{};

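    // The DL pipeline consumes A and B in a blocked [K0, M (or N), K1] view,
    // where K = K0 * K1 and K1 is the innermost, vector-loaded dimension. The
    // three helpers below build those views from the raw row-/column-major
    // matrices, right-padding M and N up to multiples of MPerBlock/NPerBlock
    // when GemmSpec == GemmSpecialization::MNPadding.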
    static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
    {
        assert(K % K1 == 0);

        const index_t K0 = K / K1;

        const auto a_grid_desc_m_k = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
            }
        }();

        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
        {
            const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;

            return transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                           make_right_pad_transform(M, PadM)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
        else
        {
            return transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                           make_pass_through_transform(M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
    }

    static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
    {
        assert(K % K1 == 0);

        const index_t K0 = K / K1;

        const auto b_grid_desc_k_n = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
            }
        }();

        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
        {
            const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;

            return transform_tensor_descriptor(
                b_grid_desc_k_n,
                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                           make_right_pad_transform(N, PadN)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
        else
        {
            return transform_tensor_descriptor(
                b_grid_desc_k_n,
                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
    }

    static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
    {
        const auto c_grid_desc_m_n = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
            }
        }();

        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
        {
            const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
            const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;

            return transform_tensor_descriptor(
                c_grid_desc_m_n,
                make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else
        {
            return transform_tensor_descriptor(
                c_grid_desc_m_n,
                make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
    }

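    // Descriptor *types* are fixed at compile time, so they are deduced here from
    // a dummy 1x1x1 problem; the actual lengths and strides are runtime values
    // carried by objects of these same types.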
    using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
    using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
    using CGridDesc_M_N     = decltype(MakeCGridDescriptor_M_N(1, 1, 1));

    // GridwiseGemm
    using GridwiseGemm =
        GridwiseGemmDl_km_kn_mn_v1r3<BlockSize,
                                     ADataType,
                                     AccDataType,
                                     CDataType,
                                     InMemoryDataOperationEnum::Set,
                                     AGridDesc_K0_M_K1,
                                     BGridDesc_K0_N_K1,
                                     CGridDesc_M_N,
                                     MPerBlock,
                                     NPerBlock,
                                     K0PerBlock,
                                     K1,
                                     M1PerThread,
                                     N1PerThread,
                                     KPerThread,
                                     M1N1ThreadClusterM1Xs,
                                     M1N1ThreadClusterN1Xs,
                                     ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
                                     ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
                                     ABlockTransferThreadClusterArrangeOrder,
                                     ABlockTransferSrcAccessOrder,
                                     ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
                                     ABlockTransferSrcVectorTensorContiguousDimOrder,
                                     ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
                                     BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
                                     BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
                                     BBlockTransferThreadClusterArrangeOrder,
                                     BBlockTransferSrcAccessOrder,
                                     BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
                                     BBlockTransferSrcVectorTensorContiguousDimOrder,
                                     BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
                                     CThreadTransferSrcDstAccessOrder,
                                     CThreadTransferSrcDstVectorDim,
                                     CThreadTransferDstScalarPerVector>;

    using AGridDesc_K0_M0_M1_K1 =
        decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
    using BGridDesc_K0_N0_N1_K1 =
        decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
    using CGridDesc_M0_M10_M11_N0_N10_N11 =
        decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
    using DefaultBlock2CTileMap =
        decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));

    // Argument
    struct Argument : public BaseArgument
    {
        Argument(const ADataType* p_a_grid,
                 const BDataType* p_b_grid,
                 CDataType* p_c_grid,
                 index_t M,
                 index_t N,
                 index_t K,
                 index_t StrideA,
                 index_t StrideB,
                 index_t StrideC,
                 index_t M01,
                 index_t N01,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CElementwiseOperation c_element_op)
            : p_a_grid_{p_a_grid},
              p_b_grid_{p_b_grid},
              p_c_grid_{p_c_grid},
              a_grid_desc_k0_m_k1_{},
              b_grid_desc_k0_n_k1_{},
              c_grid_desc_m_n_{},
              block_2_ctile_map_{},
              M01_{M01},
              N01_{N01},
              M_raw_{M},
              N_raw_{N},
              K_raw_{K},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              c_element_op_{c_element_op}
        {
            a_grid_desc_k0_m_k1_ = DeviceGemmDl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
            b_grid_desc_k0_n_k1_ = DeviceGemmDl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
            c_grid_desc_m_n_     = DeviceGemmDl::MakeCGridDescriptor_M_N(M, N, StrideC);

            if(GridwiseGemm::CheckValidity(
                   a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_))
            {
                a_grid_desc_k0_m0_m1_k1_ =
                    GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1_);
                b_grid_desc_k0_n0_n1_k1_ =
                    GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1_);
                c_grid_desc_m0_m10_m11_n0_n10_n11_ =
                    GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n_);

                block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_);
            }
        }

        // private:
        const ADataType* p_a_grid_;
        const BDataType* p_b_grid_;
        CDataType* p_c_grid_;

        AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
        BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
        CGridDesc_M_N c_grid_desc_m_n_;

        AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_;
        BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_;
        CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11_;

        DefaultBlock2CTileMap block_2_ctile_map_;

        // TODO: unused, but may be useful in future.
        index_t M01_;
        index_t N01_;

        index_t M_raw_;
        index_t N_raw_;
        index_t K_raw_;

        // TODO: unused since gridwise_gemm_dl_v1r3 does NOT support prologue for the time being.
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CElementwiseOperation c_element_op_;
    };

    // Invoker
    struct Invoker : public BaseInvoker
    {
        using Argument = DeviceGemmDl::Argument;

        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
            {
                std::cout << "arg.a_grid_desc_k0_m_k1_{"
                          << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
                          << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
                          << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;

                std::cout << "arg.b_grid_desc_k0_n_k1_{"
                          << arg.b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
                          << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
                          << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;

                std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
                          << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
            }

            if(!GridwiseGemm::CheckValidity(
                   arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_))
            {
                throw std::runtime_error(
                    "wrong! GridwiseGemmDl_km_kn_mn_v1r3 has invalid setting");
            }

            const index_t grid_size = GridwiseGemm::CalculateGridSize(
                arg.c_grid_desc_m_n_.GetLength(I0), arg.c_grid_desc_m_n_.GetLength(I1));

            const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0);
            const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
            const bool has_double_tail_k_block_loop =
                GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);

            float ave_time = 0;

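            // The K loop is peeled at compile time: each combination of
            // (has_main_k_block_loop, has_double_tail_k_block_loop) selects a
            // kernel_gemm_dl_v1r3 instantiation specialized for that loop shape.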
            if(has_main_k_block_loop && has_double_tail_k_block_loop)
            {
                const auto kernel =
                    kernel_gemm_dl_v1r3<GridwiseGemm,
                                        ADataType,
                                        CDataType,
                                        remove_reference_t<AGridDesc_K0_M0_M1_K1>,
                                        remove_reference_t<BGridDesc_K0_N0_N1_K1>,
                                        remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
                                        remove_reference_t<DefaultBlock2CTileMap>,
                                        true,
                                        true>;

                ave_time = launch_and_time_kernel(stream_config,
                                                  kernel,
                                                  dim3(grid_size),
                                                  dim3(BlockSize),
                                                  0,
                                                  arg.p_a_grid_,
                                                  arg.p_b_grid_,
                                                  arg.p_c_grid_,
                                                  arg.a_grid_desc_k0_m0_m1_k1_,
                                                  arg.b_grid_desc_k0_n0_n1_k1_,
                                                  arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
                                                  arg.block_2_ctile_map_);
            }
            else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
            {
                const auto kernel =
                    kernel_gemm_dl_v1r3<GridwiseGemm,
                                        ADataType,
                                        CDataType,
                                        remove_reference_t<AGridDesc_K0_M0_M1_K1>,
                                        remove_reference_t<BGridDesc_K0_N0_N1_K1>,
                                        remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
                                        remove_reference_t<DefaultBlock2CTileMap>,
                                        true,
                                        false>;

                ave_time = launch_and_time_kernel(stream_config,
                                                  kernel,
                                                  dim3(grid_size),
                                                  dim3(BlockSize),
                                                  0,
                                                  arg.p_a_grid_,
                                                  arg.p_b_grid_,
                                                  arg.p_c_grid_,
                                                  arg.a_grid_desc_k0_m0_m1_k1_,
                                                  arg.b_grid_desc_k0_n0_n1_k1_,
                                                  arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
                                                  arg.block_2_ctile_map_);
            }
            else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
            {
                const auto kernel =
                    kernel_gemm_dl_v1r3<GridwiseGemm,
                                        ADataType,
                                        CDataType,
                                        remove_reference_t<AGridDesc_K0_M0_M1_K1>,
                                        remove_reference_t<BGridDesc_K0_N0_N1_K1>,
                                        remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
                                        remove_reference_t<DefaultBlock2CTileMap>,
                                        false,
                                        true>;

                ave_time = launch_and_time_kernel(stream_config,
                                                  kernel,
                                                  dim3(grid_size),
                                                  dim3(BlockSize),
                                                  0,
                                                  arg.p_a_grid_,
                                                  arg.p_b_grid_,
                                                  arg.p_c_grid_,
                                                  arg.a_grid_desc_k0_m0_m1_k1_,
                                                  arg.b_grid_desc_k0_n0_n1_k1_,
                                                  arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
                                                  arg.block_2_ctile_map_);
            }
            else
            {
                const auto kernel =
                    kernel_gemm_dl_v1r3<GridwiseGemm,
                                        ADataType,
                                        CDataType,
                                        remove_reference_t<AGridDesc_K0_M0_M1_K1>,
                                        remove_reference_t<BGridDesc_K0_N0_N1_K1>,
                                        remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
                                        remove_reference_t<DefaultBlock2CTileMap>,
                                        false,
                                        false>;

                ave_time = launch_and_time_kernel(stream_config,
                                                  kernel,
                                                  dim3(grid_size),
                                                  dim3(BlockSize),
                                                  0,
                                                  arg.p_a_grid_,
                                                  arg.p_b_grid_,
                                                  arg.p_c_grid_,
                                                  arg.a_grid_desc_k0_m0_m1_k1_,
                                                  arg.b_grid_desc_k0_n0_n1_k1_,
                                                  arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
                                                  arg.block_2_ctile_map_);
            }

            return ave_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        // Make sure that the M, N, K dimensions before padding are divisible by respective vector
        // lengths.
        if constexpr(is_same_v<ALayout, ck::tensor_layout::gemm::RowMajor>)
        {
            constexpr auto A_K_vec_length =
                ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I0) *
                ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I3);
            if(arg.K_raw_ % A_K_vec_length != 0)
            {
                return false;
            }
        }
        else
        {
            constexpr auto A_M_vec_length =
                ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I1) *
                ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I2);
            if(arg.M_raw_ % A_M_vec_length != 0)
            {
                return false;
            }
        }

        if constexpr(is_same_v<BLayout, ck::tensor_layout::gemm::RowMajor>)
        {
            constexpr auto B_N_vec_length =
                BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I1) *
                BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I2);
            if(arg.N_raw_ % B_N_vec_length != 0)
            {
                return false;
            }
        }
        else
        {
            constexpr auto B_K_vec_length =
                BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I0) *
                BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I3);
            if(arg.K_raw_ % B_K_vec_length != 0)
            {
                return false;
            }
        }

        if(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
           ck::is_gfx11_supported() || ck::is_gfx12_supported())
        {
            return GridwiseGemm::CheckValidity(
                arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
        }
        return false;
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const ADataType* p_a,
                             const BDataType* p_b,
                             CDataType* p_c,
                             index_t M,
                             index_t N,
                             index_t K,
                             index_t StrideA,
                             index_t StrideB,
                             index_t StrideC,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CElementwiseOperation c_element_op)
    {
        return Argument{p_a,
                        p_b,
                        p_c,
                        M,
                        N,
                        K,
                        StrideA,
                        StrideB,
                        StrideC,
                        1,
                        1,
                        a_element_op,
                        b_element_op,
                        c_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
                                                      void* p_c,
                                                      index_t M,
                                                      index_t N,
                                                      index_t K,
                                                      index_t StrideA,
                                                      index_t StrideB,
                                                      index_t StrideC,
                                                      AElementwiseOperation a_element_op,
                                                      BElementwiseOperation b_element_op,
                                                      CElementwiseOperation c_element_op) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
                                          static_cast<CDataType*>(p_c),
                                          M,
                                          N,
                                          K,
                                          StrideA,
                                          StrideB,
                                          StrideC,
                                          1,
                                          1,
                                          a_element_op,
                                          b_element_op,
                                          c_element_op);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceGemmDl"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << K0PerBlock << ", "
            << K1 << ", "
            << M1PerThread << ", "
            << N1PerThread << ", "
            << KPerThread
            << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
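For context, here is a minimal host-side usage sketch of the interface defined above. It is generic over a concrete DeviceGemmDl instantiation (all tile and transfer template arguments filled in, for example by one of CK's prebuilt instance factories). The names DeviceGemmInstance and run_device_gemm_dl, and the device pointers passed in, are illustrative and not part of this header.

// Usage sketch, assuming a fully specialized DeviceGemmDl<...> is supplied as
// DeviceGemmInstance; only the type-erased API shown above is used.
#include <memory>
#include <stdexcept>

template <typename DeviceGemmInstance> // hypothetical: a concrete DeviceGemmDl<...>
float run_device_gemm_dl(const void* p_a, // device buffer of ADataType
                         const void* p_b, // device buffer of BDataType
                         void* p_c,       // device buffer of CDataType
                         ck::index_t M, ck::index_t N, ck::index_t K,
                         ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC)
{
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    auto gemm = DeviceGemmInstance{};

    // Build the type-erased argument; M01/N01 default to 1 inside MakeArgumentPointer.
    auto argument = gemm.MakeArgumentPointer(p_a, p_b, p_c, M, N, K,
                                             StrideA, StrideB, StrideC,
                                             PassThrough{}, PassThrough{}, PassThrough{});

    // Rejects problem sizes that violate the vector-length divisibility rules
    // and devices outside the DL path (see IsSupportedArgument above).
    if(!gemm.IsSupportedArgument(argument.get()))
    {
        throw std::runtime_error(gemm.GetTypeString() + " does not support this problem");
    }

    auto invoker = gemm.MakeInvokerPointer();

    // Run with kernel timing enabled; returns the averaged kernel time
    // as reported by launch_and_time_kernel.
    return invoker->Run(argument.get(), StreamConfig{nullptr, true});
}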