Composable Kernel: device_gemm_reduce_xdl_cshuffle.hpp Source File
device_gemm_reduce_xdl_cshuffle.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
18 
19 namespace ck {
20 namespace tensor_operation {
21 namespace device {
22 
23 // Note: the inter-wave loop scheduler is rolled out to the c-shuffle version first, because the
24 // non-c-shuffle version currently has compiler issues with register spill, which in turn causes
25 // validation failures.
26 template <typename ALayout,
27  typename BLayout,
28  typename CLayout,
29  typename ADataType,
30  typename BDataType,
31  typename CDataType,
32  typename GemmAccDataType,
33  typename CShuffleDataType,
34  typename ReduceAccDataType,
35  typename ReducePtrsGlobal,
36  typename AElementwiseOperation,
37  typename BElementwiseOperation,
38  typename CElementwiseOperation,
39  typename ReduceOperations,
40  typename ReduceInElementwiseOperations,
41  typename ReduceAccElementwiseOperations,
42  typename ReduceGlobalMemoryDataOperation,
43  GemmSpecialization GemmSpec,
44  index_t NumGemmKPrefetchStage,
45  index_t BlockSize,
46  index_t MPerBlock,
47  index_t NPerBlock,
48  index_t KPerBlock,
49  index_t AK1,
50  index_t BK1,
51  index_t MPerXDL,
52  index_t NPerXDL,
53  index_t MXdlPerWave,
54  index_t NXdlPerWave,
55  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
56  typename ABlockTransferThreadClusterArrangeOrder,
57  typename ABlockTransferSrcAccessOrder,
58  index_t ABlockTransferSrcVectorDim,
59  index_t ABlockTransferSrcScalarPerVector,
60  index_t ABlockTransferDstScalarPerVector_AK1,
61  bool ABlockLdsExtraM,
62  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
63  typename BBlockTransferThreadClusterArrangeOrder,
64  typename BBlockTransferSrcAccessOrder,
65  index_t BBlockTransferSrcVectorDim,
66  index_t BBlockTransferSrcScalarPerVector,
67  index_t BBlockTransferDstScalarPerVector_BK1,
68  bool BBlockLdsExtraN,
69  index_t CShuffleMXdlPerWavePerShuffle,
70  index_t CShuffleNXdlPerWavePerShuffle,
71  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
72  index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
73  typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
74  index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
75  index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
76  LoopScheduler LoopSched = make_default_loop_scheduler()>
77 struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperations::Size()>
78 {
80 
82  static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
83  static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
84 
85  static constexpr auto I0 = Number<0>{};
86  static constexpr auto I1 = Number<1>{};
87  static constexpr auto I2 = Number<2>{};
88 
89  static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
90  {
91  const auto a_grid_desc_mraw_kraw = [&]() {
92  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
93  {
94  return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
95  make_tuple(StrideA, I1));
96  }
97  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
98  {
99  return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
100  make_tuple(I1, StrideA));
101  }
102  }();
103 
104  const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
105  const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
106 
107  const auto MPad = M - MRaw;
108  const auto KPad = K - KRaw;
109 
110  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
111  GemmSpec == GemmSpecialization::MNKPadding)
112  {
113  // pad both M and K
114  assert(K % AK1 == 0);
115 
116  const auto AK0 = K / AK1;
117 
118  const auto a_grid_desc_m_k =
119  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
121  make_right_pad_transform(KRaw, KPad)),
124 
125  const auto a_grid_desc_ak0_m_ak1 =
126  transform_tensor_descriptor(a_grid_desc_m_k,
131 
132  return a_grid_desc_ak0_m_ak1;
133  }
134  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
135  GemmSpec == GemmSpecialization::MNPadding)
136  {
137  // pad M, but not K
138  assert(KRaw % AK1 == 0);
139 
140  const auto AK0 = KRaw / AK1;
141 
142  const auto a_grid_desc_ak0_m_ak1 =
143  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
145  make_right_pad_transform(MRaw, MPad)),
148 
149  return a_grid_desc_ak0_m_ak1;
150  }
151  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
152  GemmSpec == GemmSpecialization::NKPadding)
153  {
154  // pad K, but not M
155  assert(K % AK1 == 0);
156 
157  const auto AK0 = K / AK1;
158 
159  const auto a_grid_desc_m_k = transform_tensor_descriptor(
160  a_grid_desc_mraw_kraw,
164 
165  const auto a_grid_desc_ak0_m_ak1 =
166  transform_tensor_descriptor(a_grid_desc_m_k,
171 
172  return a_grid_desc_ak0_m_ak1;
173  }
174  else
175  {
176  // pad neither M nor K
177  assert(KRaw % AK1 == 0);
178 
179  const auto AK0 = KRaw / AK1;
180 
181  const auto a_grid_desc_ak0_m_ak1 =
182  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
187 
188  return a_grid_desc_ak0_m_ak1;
189  }
190  }
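The same pad-to-tile arithmetic appears in every descriptor builder in this file: raw extents are rounded up to a multiple of the block tile, and the (possibly padded) K extent is split into (AK0, AK1). A minimal standalone sketch of that arithmetic with hypothetical sizes (MRaw = 1000, KRaw = 72; the tile parameters below are illustrative, not taken from this file):

#include <cassert>

int main()
{
    const int MRaw = 1000, KRaw = 72;                   // hypothetical problem sizes
    const int MPerBlock = 256, KPerBlock = 32, AK1 = 8; // hypothetical tile parameters

    // Same rounding as math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock above.
    const int M = (MRaw + MPerBlock - 1) / MPerBlock * MPerBlock; // 1024
    const int K = (KRaw + KPerBlock - 1) / KPerBlock * KPerBlock; // 96

    const int MPad = M - MRaw; // 24 padded rows
    const int KPad = K - KRaw; // 24 padded K elements

    // With K padding enabled, the padded K extent must split evenly into (AK0, AK1).
    assert(K % AK1 == 0);
    const int AK0 = K / AK1; // 12

    (void)MPad;
    (void)KPad;
    return AK0 == 12 ? 0 : 1;
}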
191 
192  static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
193  {
194  const auto b_grid_desc_nraw_kraw = [&]() {
196  {
197  return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
198  make_tuple(I1, StrideB));
199  }
201  {
202  return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
203  make_tuple(StrideB, I1));
204  }
205  }();
206 
207  const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
208  const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
209 
210  const auto NPad = N - NRaw;
211  const auto KPad = K - KRaw;
212 
213  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
214  GemmSpec == GemmSpecialization::MNKPadding)
215  {
216  // pad both N and K
217  assert(K % BK1 == 0);
218 
219  const auto BK0 = K / BK1;
220 
221  const auto b_grid_desc_n_k =
222  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
224  make_right_pad_transform(KRaw, KPad)),
227 
228  const auto b_grid_desc_bk0_n_bk1 =
229  transform_tensor_descriptor(b_grid_desc_n_k,
234 
235  return b_grid_desc_bk0_n_bk1;
236  }
237  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
238  GemmSpec == GemmSpecialization::MNPadding)
239  {
240  // pad N, but not K
241  assert(KRaw % BK1 == 0);
242 
243  const auto BK0 = KRaw / BK1;
244 
245  const auto b_grid_desc_bk0_n_bk1 =
246  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
248  make_right_pad_transform(NRaw, NPad)),
251 
252  return b_grid_desc_bk0_n_bk1;
253  }
254  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
255  GemmSpec == GemmSpecialization::MKPadding)
256  {
257  // pad K, but not N
258  assert(K % BK1 == 0);
259 
260  const auto BK0 = K / BK1;
261 
262  const auto b_grid_desc_n_k = transform_tensor_descriptor(
263  b_grid_desc_nraw_kraw,
267 
268  const auto b_grid_desc_bk0_n_bk1 =
269  transform_tensor_descriptor(b_grid_desc_n_k,
274 
275  return b_grid_desc_bk0_n_bk1;
276  }
277  else
278  {
279  // pad neither N nor K
280  assert(KRaw % BK1 == 0);
281 
282  const auto BK0 = KRaw / BK1;
283 
284  const auto b_grid_desc_bk0_n_bk1 =
285  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
290 
291  return b_grid_desc_bk0_n_bk1;
292  }
293  }
294 
295  static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
296  {
297  const auto c_grid_desc_mraw_nraw = [&]() {
299  {
300  return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
301  make_tuple(StrideC, I1));
302  }
304  {
305  return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
306  make_tuple(I1, StrideC));
307  }
308  }();
309 
310  const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
311  const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
312 
313  const auto MPad = M - MRaw;
314  const auto NPad = N - NRaw;
315 
316  if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
317  GemmSpec == GemmSpecialization::MNKPadding)
318  {
319  // pad M and N
320  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
322  make_right_pad_transform(NRaw, NPad)),
325  }
326  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
327  GemmSpec == GemmSpecialization::MKPadding)
328  {
329  // pad M, but not N
331  c_grid_desc_mraw_nraw,
335  }
336  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
337  GemmSpec == GemmSpecialization::NKPadding)
338  {
339  // pad N, but not M
341  c_grid_desc_mraw_nraw,
345  }
346  else
347  {
348  // pad neither M nor N
349  return c_grid_desc_mraw_nraw;
350  }
351  }
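For reference, the GemmSpec branches in the three builders above reduce to three predicates: does the specialization pad M, N, and/or K. The Spec enum below is only a local stand-in for GemmSpecialization so that the sketch is self-contained:

enum class Spec { Default, MPadding, NPadding, KPadding, MNPadding, MKPadding, NKPadding, MNKPadding };

constexpr bool PadsM(Spec s)
{
    return s == Spec::MPadding || s == Spec::MNPadding || s == Spec::MKPadding || s == Spec::MNKPadding;
}
constexpr bool PadsN(Spec s)
{
    return s == Spec::NPadding || s == Spec::MNPadding || s == Spec::NKPadding || s == Spec::MNKPadding;
}
constexpr bool PadsK(Spec s)
{
    return s == Spec::KPadding || s == Spec::MKPadding || s == Spec::NKPadding || s == Spec::MNKPadding;
}

static_assert(PadsM(Spec::MNKPadding) && PadsN(Spec::MNKPadding) && PadsK(Spec::MNKPadding), "");
static_assert(PadsM(Spec::MNPadding) && PadsN(Spec::MNPadding) && !PadsK(Spec::MNPadding), "");
static_assert(!PadsM(Spec::Default) && !PadsN(Spec::Default) && !PadsK(Spec::Default), "");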
352 
353  // assume the reduce output is a packed tensor
355  {
356  const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
357 
358  const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
359  const auto MPad = M - MRaw;
360 
361  if constexpr(GemmSpec == GemmSpecialization::MPadding ||
362  GemmSpec == GemmSpecialization::MNPadding ||
363  GemmSpec == GemmSpecialization::MKPadding ||
364  GemmSpec == GemmSpecialization::MNKPadding)
365  {
366  // pad M
367  return transform_tensor_descriptor(d_grid_desc_mraw,
371  }
372  else
373  {
374  // do not pad M
375  return d_grid_desc_mraw;
376  }
377  }
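The reduce descriptor is one-dimensional because each reduce output holds one value per row of C, accumulated across the N dimension by the kernel. A rough host-side reference for that behaviour, assuming a plain sum reduction with pass-through element-wise operations (an illustration of the intended semantics, not CK code):

#include <vector>

// d[m] aggregates row m of a row-major MRaw x NRaw matrix c.
std::vector<float> reference_row_reduce(const std::vector<float>& c, int MRaw, int NRaw)
{
    std::vector<float> d(MRaw, 0.0f);
    for(int m = 0; m < MRaw; ++m)
    {
        for(int n = 0; n < NRaw; ++n)
        {
            d[m] += c[m * NRaw + n]; // e.g. a "sum over N" reduce operation
        }
    }
    return d;
}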
378 
381  using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
383 
384  // GridwiseGemm
385  template <index_t NXdlPerWave_>
387  ADataType, // TODO: distinguish A/B datatype
388  GemmAccDataType,
389  CShuffleDataType,
390  CDataType,
391  ReduceAccDataType,
392  ReducePtrsGlobal,
393  AElementwiseOperation,
394  BElementwiseOperation,
395  CElementwiseOperation,
396  ReduceOperations,
397  ReduceInElementwiseOperations,
398  ReduceAccElementwiseOperations,
400  ReduceGlobalMemoryDataOperation,
405  NumGemmKPrefetchStage,
406  BlockSize,
407  MPerBlock,
408  NPerBlock,
409  KPerBlock,
410  AK1,
411  BK1,
412  MPerXDL,
413  NPerXDL,
414  MXdlPerWave,
415  NXdlPerWave_,
416  ABlockTransferThreadClusterLengths_AK0_M_AK1,
417  ABlockTransferThreadClusterArrangeOrder,
418  ABlockTransferSrcAccessOrder,
419  ABlockTransferSrcVectorDim,
420  ABlockTransferSrcScalarPerVector,
421  ABlockTransferDstScalarPerVector_AK1,
422  false,
423  ABlockLdsExtraM,
424  BBlockTransferThreadClusterLengths_BK0_N_BK1,
425  BBlockTransferThreadClusterArrangeOrder,
426  BBlockTransferSrcAccessOrder,
427  BBlockTransferSrcVectorDim,
428  BBlockTransferSrcScalarPerVector,
429  BBlockTransferDstScalarPerVector_BK1,
430  false,
431  BBlockLdsExtraN,
432  CShuffleMXdlPerWavePerShuffle,
433  CShuffleNXdlPerWavePerShuffle,
434  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
435  CShuffleBlockTransferScalarPerVector_NPerBlock,
436  CReduceThreadClusterLengths_MPerBlock_NPerBlock,
437  CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
438  CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
439  LoopSched>;
442 
443  // Argument
444  struct Argument : public BaseArgument
445  {
446  Argument(const ADataType* p_a_grid,
447  const BDataType* p_b_grid,
448  CDataType* p_c_grid,
449  ReducePtrsGlobal p_reduces_grid,
450  index_t MRaw,
451  index_t NRaw,
452  index_t KRaw,
453  index_t StrideA,
454  index_t StrideB,
455  index_t StrideC,
456  AElementwiseOperation a_element_op,
457  BElementwiseOperation b_element_op,
458  CElementwiseOperation c_element_op,
459  ReduceInElementwiseOperations reduce_in_element_ops,
460  ReduceAccElementwiseOperations reduce_out_element_ops)
461  : p_a_grid_{p_a_grid},
462  p_b_grid_{p_b_grid},
463  p_c_grid_{p_c_grid},
464  p_reduces_grid_{p_reduces_grid},
467  c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
469  block_2_ctile_map_{GridwiseGemm64::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
470  a_element_op_{a_element_op},
471  b_element_op_{b_element_op},
472  c_element_op_{c_element_op},
473  reduce_in_element_ops_{reduce_in_element_ops},
474  reduce_out_element_ops_{reduce_out_element_ops}
475  {
476  }
477 
478  // private:
479  const ADataType* p_a_grid_;
480  const BDataType* p_b_grid_;
481  CDataType* p_c_grid_;
482  ReducePtrsGlobal p_reduces_grid_;
488  AElementwiseOperation a_element_op_;
489  BElementwiseOperation b_element_op_;
490  CElementwiseOperation c_element_op_;
491  ReduceInElementwiseOperations reduce_in_element_ops_;
492  ReduceAccElementwiseOperations reduce_out_element_ops_;
493  };
494 
495  // Invoker
496  struct Invoker : public BaseInvoker
497  {
499 
500  template <typename GridwiseGemm>
501  float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
502  {
503  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
504  {
505  std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
506  << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
507  << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
508  << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
509 
510  std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
511  << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
512  << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
513  << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
514 
515  std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
516  << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
517 
518  std::cout << "arg.reduce_grid_desc_m_{ " << arg.reduce_grid_desc_m_.GetLength(I0)
519  << "}" << std::endl;
520  }
521 
522  if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
524  arg.c_grid_desc_m_n_,
525  arg.block_2_ctile_map_))
526  {
527  throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
528  }
529  auto c_grid_desc_mblock_mperblock_nblock_nperblock =
530  GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
531  arg.c_grid_desc_m_n_);
532 
533  auto reduce_grid_desc_mblock_mperblock =
534  GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(arg.reduce_grid_desc_m_);
535 
536  const index_t grid_size =
537  arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
538 
539  const auto K =
540  arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
541 
542  float elapsed_time = 0.0f;
543  if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
544  {
545  const auto kernel = kernel_gemm_reduce_xdl_cshuffle_v1<
546  GridwiseGemm,
547  ADataType, // TODO: distinguish A/B datatype
548  CDataType,
549  ReducePtrsGlobal,
550  AElementwiseOperation,
551  BElementwiseOperation,
552  CElementwiseOperation,
553  ReduceInElementwiseOperations,
554  ReduceAccElementwiseOperations,
557  typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
558  typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
559  typename GridwiseGemm::DefaultBlock2CTileMap,
560  true>;
561 
562  elapsed_time = launch_and_time_kernel(stream_config,
563  kernel,
564  dim3(grid_size),
565  dim3(BlockSize),
566  0,
567  arg.p_a_grid_,
568  arg.p_b_grid_,
569  arg.p_c_grid_,
570  arg.p_reduces_grid_,
571  arg.a_element_op_,
572  arg.b_element_op_,
573  arg.c_element_op_,
578  c_grid_desc_mblock_mperblock_nblock_nperblock,
579  reduce_grid_desc_mblock_mperblock,
580  arg.block_2_ctile_map_);
581  }
582  else
583  {
584  const auto kernel = kernel_gemm_reduce_xdl_cshuffle_v1<
585  GridwiseGemm,
586  ADataType, // TODO: distinguish A/B datatype
587  CDataType,
588  ReducePtrsGlobal,
589  AElementwiseOperation,
590  BElementwiseOperation,
591  CElementwiseOperation,
592  ReduceInElementwiseOperations,
593  ReduceAccElementwiseOperations,
596  typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
597  typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
598  typename GridwiseGemm::DefaultBlock2CTileMap,
599  false>;
600 
601  elapsed_time = launch_and_time_kernel(stream_config,
602  kernel,
603  dim3(grid_size),
604  dim3(BlockSize),
605  0,
606  arg.p_a_grid_,
607  arg.p_b_grid_,
608  arg.p_c_grid_,
609  arg.p_reduces_grid_,
610  arg.a_element_op_,
611  arg.b_element_op_,
612  arg.c_element_op_,
617  c_grid_desc_mblock_mperblock_nblock_nperblock,
618  reduce_grid_desc_mblock_mperblock,
619  arg.block_2_ctile_map_);
620  }
621 
622  return elapsed_time;
623  }
624 
626 
627  // polymorphic
628  float Run(const BaseArgument* p_arg,
629  const StreamConfig& stream_config = StreamConfig{}) override
630  {
631  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
632  }
633  };
634 
635  static constexpr bool IsValidCompilationParameter()
636  {
637  // TODO: properly implement this check
638  return true;
639  }
640 
641  static bool IsSupportedArgument(const Argument& arg)
642  {
643  if(!ck::is_xdl_wmma_supported<ADataType, BDataType, MPerXDL, NPerXDL>())
644  {
645  return false;
646  }
647  if(get_warp_size() == 64)
648  {
649  if constexpr(NXdlPerWave64 > 0)
650  {
653  arg.c_grid_desc_m_n_,
654  arg.block_2_ctile_map_);
655  }
656  }
657  else
658  {
659  if constexpr(NXdlPerWave32 > 0)
660  {
663  arg.c_grid_desc_m_n_,
664  arg.block_2_ctile_map_);
665  }
666  }
667  return false;
668  }
669 
670  // polymorphic
671  bool IsSupportedArgument(const BaseArgument* p_arg) override
672  {
673  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
674  }
675 
676  static constexpr int NumReduce = ReduceOperations::Size();
677  static auto MakeArgument(const void* p_a,
678  const void* p_b,
679  const void* p_bias,
680  std::array<const void*, 0> p_ds,
681  void* p_c,
682  std::array<void*, NumReduce> p_reduces,
683  ck::index_t M,
684  ck::index_t N,
685  ck::index_t K,
686  ck::index_t StrideA,
687  ck::index_t StrideB,
688  ck::index_t StrideC,
689  std::array<ck::index_t, 0> StrideDs,
690  std::array<void*, 3> gemm_element_ops,
691  std::array<void*, 0> d_element_ops,
692  std::array<void*, NumReduce> reduce_in_element_op,
693  std::array<void*, NumReduce> reduce_out_element_op)
694  {
695  (void)p_bias;
696  (void)p_ds;
697  (void)StrideDs;
698  (void)d_element_ops;
699 
700  ReducePtrsGlobal reduce_tuple = generate_tuple(
701  [&](auto I) {
702  auto tmp = ReducePtrsGlobal{}[I];
703  using T = remove_pointer_t<decltype(tmp)>;
704  return static_cast<T*>(p_reduces[I]);
705  },
707 
708  ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
709  [&](auto I) {
710  auto tmp = ReduceInElementwiseOperations{}[I];
711  using T = remove_pointer_t<decltype(tmp)>;
712  return *(static_cast<T*>(reduce_in_element_op[I]));
713  },
715  ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
716  [&](auto I) {
717  auto tmp = ReduceAccElementwiseOperations{}[I];
718  using T = remove_pointer_t<decltype(tmp)>;
719  return *(static_cast<T*>(reduce_out_element_op[I]));
720  },
722 
723  AElementwiseOperation a_element_op =
724  *(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
725  BElementwiseOperation b_element_op =
726  *(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
727  CElementwiseOperation c_element_op =
728  *(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
729 
730  return Argument{static_cast<const ADataType*>(p_a),
731  static_cast<const BDataType*>(p_b),
732  static_cast<CDataType*>(p_c),
733  reduce_tuple,
734  M,
735  N,
736  K,
737  StrideA,
738  StrideB,
739  StrideC,
740  a_element_op,
741  b_element_op,
742  c_element_op,
743  reduce_in_element_ops,
744  reduce_out_element_ops};
745  }
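MakeArgument above rebuilds the typed operator tuples from type-erased void* arrays via generate_tuple and static_cast. A rough standard-library analogue of that unpacking step, with hypothetical operator types (this is not CK's generate_tuple, only an illustration of the idea):

#include <array>
#include <cstddef>
#include <tuple>
#include <utility>

// Hypothetical typed operator set standing in for e.g. ReduceInElementwiseOperations.
struct UnaryIdentity { };
struct UnarySquare { };
using ReduceInOps = std::tuple<UnaryIdentity, UnarySquare>;

// Rebuild a typed tuple element-by-element from an array of type-erased pointers,
// mirroring what generate_tuple + static_cast do in MakeArgument.
template <typename Tuple, std::size_t... Is>
Tuple UnpackOps(const std::array<void*, sizeof...(Is)>& ptrs, std::index_sequence<Is...>)
{
    return Tuple{*static_cast<std::tuple_element_t<Is, Tuple>*>(ptrs[Is])...};
}

int main()
{
    UnaryIdentity id{};
    UnarySquare sq{};
    std::array<void*, 2> erased{&id, &sq};
    auto ops = UnpackOps<ReduceInOps>(erased, std::make_index_sequence<2>{});
    (void)ops;
    return 0;
}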
746 
747  static auto MakeInvoker() { return Invoker{}; }
748 
749  // polymorphic
750  std::unique_ptr<BaseArgument>
751  MakeArgumentPointer(const void* p_a,
752  const void* p_b,
753  const void* p_bias,
754  std::array<const void*, 0> p_ds,
755  void* p_c,
756  std::array<void*, NumReduce> p_reduces,
757  ck::index_t M,
758  ck::index_t N,
759  ck::index_t K,
760  ck::index_t StrideA,
761  ck::index_t StrideB,
762  ck::index_t StrideC,
763  std::array<ck::index_t, 0> StrideDs,
764  std::array<void*, 3> gemm_element_ops,
765  std::array<void*, 0> d_element_ops,
766  std::array<void*, NumReduce> reduce_in_element_op,
767  std::array<void*, NumReduce> reduce_out_element_op,
768  ck::index_t = 1) override
769  {
770  (void)p_bias;
771  (void)p_ds;
772  (void)StrideDs;
773  (void)d_element_ops;
774 
775  ReducePtrsGlobal reduce_tuple = generate_tuple(
776  [&](auto I) {
777  auto tmp = ReducePtrsGlobal{}[I];
778  using T = remove_pointer_t<decltype(tmp)>;
779  return static_cast<T*>(p_reduces[I]);
780  },
782 
783  ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
784  [&](auto I) {
785  auto tmp = ReduceInElementwiseOperations{}[I];
786  using T = remove_pointer_t<decltype(tmp)>;
787  return *(static_cast<T*>(reduce_in_element_op[I]));
788  },
790  ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
791  [&](auto I) {
792  auto tmp = ReduceAccElementwiseOperations{}[I];
793  using T = remove_pointer_t<decltype(tmp)>;
794  return *(static_cast<T*>(reduce_out_element_op[I]));
795  },
797 
798  AElementwiseOperation a_element_op =
799  *(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
800  BElementwiseOperation b_element_op =
801  *(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
802  CElementwiseOperation c_element_op =
803  *(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
804 
805  return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
806  static_cast<const BDataType*>(p_b),
807  static_cast<CDataType*>(p_c),
808  reduce_tuple,
809  M,
810  N,
811  K,
812  StrideA,
813  StrideB,
814  StrideC,
815  a_element_op,
816  b_element_op,
817  c_element_op,
818  reduce_in_element_ops,
819  reduce_out_element_ops);
820  }
821 
822  // polymorphic
823  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
824  {
825  return std::make_unique<Invoker>(Invoker{});
826  }
827 
828  // polymorphic
829  std::string GetTypeString() const override
830  {
831  auto str = std::stringstream();
832 
833  // clang-format off
834  str << "DeviceGemmReduce_Xdl_CShuffle"
835  << "<"
836  << BlockSize << ", "
837  << MPerBlock << ", "
838  << NPerBlock << ", "
839  << KPerBlock << ", "
840  << AK1 << ", "
841  << BK1 << ", "
842  << MPerXDL << ", "
843  << NPerXDL << ", "
844  << MXdlPerWave << ", "
845  << NXdlPerWave << ", "
846  << ABlockTransferSrcScalarPerVector << ", "
847  << BBlockTransferSrcScalarPerVector << ", "
848  << CShuffleMXdlPerWavePerShuffle << ", "
849  << CShuffleNXdlPerWavePerShuffle
850  << ">";
851  // clang-format on
852 
853  return str.str();
854  }
855 };
856 
857 } // namespace device
858 } // namespace tensor_operation
859 } // namespace ck
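For completeness, the polymorphic MakeArgumentPointer / IsSupportedArgument / MakeInvokerPointer interface defined above is normally driven as in the following host-side sketch. DeviceOpT stands for some fully-specialized DeviceGemmReduce_Xdl_CShuffle alias, the pointers are assumed to be valid device buffers, and the void* operator arrays are assumed to point at objects of the instance's operation types; none of this is defined in this file:

#include <array>
#include <memory>

#include "ck/tensor_operation/gpu/device/impl/device_gemm_reduce_xdl_cshuffle.hpp"

template <typename DeviceOpT>
float run_gemm_reduce(const void* p_a,
                      const void* p_b,
                      void* p_c,
                      std::array<void*, DeviceOpT::NumReduce> p_reduces,
                      ck::index_t M,
                      ck::index_t N,
                      ck::index_t K,
                      ck::index_t StrideA,
                      ck::index_t StrideB,
                      ck::index_t StrideC,
                      std::array<void*, 3> gemm_element_ops,
                      std::array<void*, DeviceOpT::NumReduce> reduce_in_element_ops,
                      std::array<void*, DeviceOpT::NumReduce> reduce_out_element_ops)
{
    DeviceOpT device_op;

    // Build a type-erased argument through the polymorphic interface.
    auto argument = device_op.MakeArgumentPointer(p_a,
                                                  p_b,
                                                  nullptr, // p_bias, unused by this device op
                                                  {},      // p_ds
                                                  p_c,
                                                  p_reduces,
                                                  M,
                                                  N,
                                                  K,
                                                  StrideA,
                                                  StrideB,
                                                  StrideC,
                                                  {}, // StrideDs
                                                  gemm_element_ops,
                                                  {}, // d_element_ops
                                                  reduce_in_element_ops,
                                                  reduce_out_element_ops);

    // Reject problem sizes this particular instance cannot handle
    // (padding specialization, vector widths, wave size).
    if(!device_op.IsSupportedArgument(argument.get()))
    {
        return -1.0f;
    }

    auto invoker = device_op.MakeInvokerPointer();

    // Launch and time the kernel; StreamConfig{nullptr, true} follows the usual CK example pattern.
    return invoker->Run(argument.get(), StreamConfig{nullptr, true});
}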