Composable Kernel: grouped_convolution_backward_weight_kernel.hpp Source File
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <string>

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"
// Assumption: exact paths reconstructed; these project headers declare
// GroupedConvBwdWeightHostArgs and TransformConvBwdWeightToGemm used below.
#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp"
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp"

namespace ck_tile {

/// @brief The Grouped Convolution kernel device arguments.
template <typename GroupedConvTraitsType_>
struct GroupedConvBwdWeightKernelArgs
{
    using ConvToGemmTransformer =
        TransformConvBwdWeightToGemm<GroupedConvTraitsType_::NDimSpatial,
                                     GroupedConvTraitsType_::ConvSpecialization,
                                     GroupedConvTraitsType_::VectorSizeA,
                                     GroupedConvTraitsType_::VectorSizeB,
                                     GroupedConvTraitsType_::VectorSizeC,
                                     GroupedConvTraitsType_::NumGroupsToMerge>;
    static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
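
    // Implicit-GEMM mapping used by the constructors below and by operator():
    // A is the output gradient (K x M), B is the input image (K x N), and C is
    // the weight gradient (M x N), i.e. conv_bwd_weight: Out * In = Wei.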

    template <
        typename InLay  = typename GroupedConvTraitsType_::InLayout,
        typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
        typename OutLay = typename GroupedConvTraitsType_::OutLayout,
        typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NWGC> &&
                                    std::is_same_v<WeiLay, tensor_layout::convolution::GKXC> &&
                                    std::is_same_v<OutLay, tensor_layout::convolution::NWGK>,
                                bool>::type = false>
    CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args)
    {
        in_g_n_c_wis_lengths  = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.N_),
                                 static_cast<index_t>(args.C_),
                                 static_cast<index_t>(args.input_spatial_lengths_[0])};
        wei_g_k_c_xs_lengths  = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.K_),
                                 static_cast<index_t>(args.C_),
                                 static_cast<index_t>(args.filter_spatial_lengths_[0])};
        out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.N_),
                                 static_cast<index_t>(args.K_),
                                 static_cast<index_t>(args.output_spatial_lengths_[0])};

        conv_filter_strides   = {static_cast<index_t>(args.conv_filter_strides_[0])};
        conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0])};
        input_left_pads       = {static_cast<index_t>(args.input_left_pads_[0])};
        input_right_pads      = {static_cast<index_t>(args.input_right_pads_[0])};

        k_batch = args.k_batch;

        in_ptr  = args.in_ptr;
        wei_ptr = args.wei_ptr;
        for(index_t d = 0; d < NumDTensor; d++)
        {
            ds_ptr[d] = args.ds_ptr[d];
        }
        out_ptr = args.out_ptr;

        ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
                                                       wei_g_k_c_xs_lengths,
                                                       out_g_n_k_wos_lengths,
                                                       conv_filter_strides,
                                                       conv_filter_dilations,
                                                       input_left_pads,
                                                       input_right_pads};
        // tuple
        auto grid_descs =
            conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
                GroupedConvTraitsType_::NDimSpatial>();

        a_grid_desc_k_m = grid_descs.at(number<0>{});
        b_grid_desc_k_n = grid_descs.at(number<1>{});
        c_grid_desc_m_n = grid_descs.at(number<2>{});

        NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
        group_stride_a    = args.K_ * NumGroupsPerBatch; // A: Out NWGK
        group_stride_b    = args.C_ * NumGroupsPerBatch; // B: In NWGC
        group_stride_c    = args.K_ * args.C_ *          // C: Wei GKXC
                         std::accumulate(args.filter_spatial_lengths_.begin(),
                                         args.filter_spatial_lengths_.end(),
                                         1,
                                         std::multiplies<index_t>());

        GemmM = a_grid_desc_k_m.get_length(number<1>{});
        GemmN = b_grid_desc_k_n.get_length(number<1>{});
        GemmK = a_grid_desc_k_m.get_length(number<0>{});
        // ConvG = GemmBatch * NumGroupsPerBatch (see Preprocess below).
        GemmBatch = static_cast<index_t>(args.G_) / NumGroupsPerBatch;

        if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
        {
            std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
                      << ", GemmBatch: " << GemmBatch
                      << ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
        }
    }

    template <
        typename InLay  = typename GroupedConvTraitsType_::InLayout,
        typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
        typename OutLay = typename GroupedConvTraitsType_::OutLayout,
        typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NHWGC> &&
                                    std::is_same_v<WeiLay, tensor_layout::convolution::GKYXC> &&
                                    std::is_same_v<OutLay, tensor_layout::convolution::NHWGK>,
                                bool>::type = false>
    CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args)
    {
        in_g_n_c_wis_lengths  = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.N_),
                                 static_cast<index_t>(args.C_),
                                 static_cast<index_t>(args.input_spatial_lengths_[0]),
                                 static_cast<index_t>(args.input_spatial_lengths_[1])};
        wei_g_k_c_xs_lengths  = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.K_),
                                 static_cast<index_t>(args.C_),
                                 static_cast<index_t>(args.filter_spatial_lengths_[0]),
                                 static_cast<index_t>(args.filter_spatial_lengths_[1])};
        out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.N_),
                                 static_cast<index_t>(args.K_),
                                 static_cast<index_t>(args.output_spatial_lengths_[0]),
                                 static_cast<index_t>(args.output_spatial_lengths_[1])};

        conv_filter_strides   = {static_cast<index_t>(args.conv_filter_strides_[0]),
                                 static_cast<index_t>(args.conv_filter_strides_[1])};
        conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
                                 static_cast<index_t>(args.conv_filter_dilations_[1])};
        input_left_pads       = {static_cast<index_t>(args.input_left_pads_[0]),
                                 static_cast<index_t>(args.input_left_pads_[1])};
        input_right_pads      = {static_cast<index_t>(args.input_right_pads_[0]),
                                 static_cast<index_t>(args.input_right_pads_[1])};

        k_batch = args.k_batch;

        in_ptr  = args.in_ptr;
        wei_ptr = args.wei_ptr;
        for(index_t d = 0; d < NumDTensor; d++)
        {
            ds_ptr[d] = args.ds_ptr[d];
        }
        out_ptr = args.out_ptr;

        ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
                                                       wei_g_k_c_xs_lengths,
                                                       out_g_n_k_wos_lengths,
                                                       conv_filter_strides,
                                                       conv_filter_dilations,
                                                       input_left_pads,
                                                       input_right_pads};
        // tuple
        auto grid_descs =
            conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
                GroupedConvTraitsType_::NDimSpatial>();

        a_grid_desc_k_m = grid_descs.at(number<0>{});
        b_grid_desc_k_n = grid_descs.at(number<1>{});
        c_grid_desc_m_n = grid_descs.at(number<2>{});

        NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
        group_stride_a    = args.K_ * NumGroupsPerBatch; // A: Out NHWGK
        group_stride_b    = args.C_ * NumGroupsPerBatch; // B: In NHWGC
        group_stride_c    = args.K_ * args.C_ *          // C: Wei GKYXC
                         std::accumulate(args.filter_spatial_lengths_.begin(),
                                         args.filter_spatial_lengths_.end(),
                                         1,
                                         std::multiplies<index_t>());

        GemmM = a_grid_desc_k_m.get_length(number<1>{});
        GemmN = b_grid_desc_k_n.get_length(number<1>{});
        GemmK = a_grid_desc_k_m.get_length(number<0>{});
        // ConvG = GemmBatch * NumGroupsPerBatch (see Preprocess below).
        GemmBatch = static_cast<index_t>(args.G_) / NumGroupsPerBatch;

        if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
        {
            std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
                      << ", GemmBatch: " << GemmBatch
                      << ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
        }
    }

    template <
        typename InLay  = typename GroupedConvTraitsType_::InLayout,
        typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
        typename OutLay = typename GroupedConvTraitsType_::OutLayout,
        typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NDHWGC> &&
                                    std::is_same_v<WeiLay, tensor_layout::convolution::GKZYXC> &&
                                    std::is_same_v<OutLay, tensor_layout::convolution::NDHWGK>,
                                bool>::type = false>
    CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args)
    {
        in_g_n_c_wis_lengths  = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.N_),
                                 static_cast<index_t>(args.C_),
                                 static_cast<index_t>(args.input_spatial_lengths_[0]),
                                 static_cast<index_t>(args.input_spatial_lengths_[1]),
                                 static_cast<index_t>(args.input_spatial_lengths_[2])};
        wei_g_k_c_xs_lengths  = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.K_),
                                 static_cast<index_t>(args.C_),
                                 static_cast<index_t>(args.filter_spatial_lengths_[0]),
                                 static_cast<index_t>(args.filter_spatial_lengths_[1]),
                                 static_cast<index_t>(args.filter_spatial_lengths_[2])};
        out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.N_),
                                 static_cast<index_t>(args.K_),
                                 static_cast<index_t>(args.output_spatial_lengths_[0]),
                                 static_cast<index_t>(args.output_spatial_lengths_[1]),
                                 static_cast<index_t>(args.output_spatial_lengths_[2])};

        conv_filter_strides   = {static_cast<index_t>(args.conv_filter_strides_[0]),
                                 static_cast<index_t>(args.conv_filter_strides_[1]),
                                 static_cast<index_t>(args.conv_filter_strides_[2])};
        conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
                                 static_cast<index_t>(args.conv_filter_dilations_[1]),
                                 static_cast<index_t>(args.conv_filter_dilations_[2])};
        input_left_pads       = {static_cast<index_t>(args.input_left_pads_[0]),
                                 static_cast<index_t>(args.input_left_pads_[1]),
                                 static_cast<index_t>(args.input_left_pads_[2])};
        input_right_pads      = {static_cast<index_t>(args.input_right_pads_[0]),
                                 static_cast<index_t>(args.input_right_pads_[1]),
                                 static_cast<index_t>(args.input_right_pads_[2])};

        k_batch = args.k_batch;

        in_ptr  = args.in_ptr;
        wei_ptr = args.wei_ptr;
        for(index_t d = 0; d < NumDTensor; d++)
        {
            ds_ptr[d] = args.ds_ptr[d];
        }
        out_ptr = args.out_ptr;

        ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
                                                       wei_g_k_c_xs_lengths,
                                                       out_g_n_k_wos_lengths,
                                                       conv_filter_strides,
                                                       conv_filter_dilations,
                                                       input_left_pads,
                                                       input_right_pads};
        // tuple
        auto grid_descs =
            conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
                GroupedConvTraitsType_::NDimSpatial>();

        a_grid_desc_k_m = grid_descs.at(number<0>{});
        b_grid_desc_k_n = grid_descs.at(number<1>{});
        c_grid_desc_m_n = grid_descs.at(number<2>{});

        NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
        group_stride_a    = args.K_ * NumGroupsPerBatch; // A: Out NDHWGK
        group_stride_b    = args.C_ * NumGroupsPerBatch; // B: In NDHWGC
        group_stride_c    = args.K_ * args.C_ *          // C: Wei GKZYXC
                         std::accumulate(args.filter_spatial_lengths_.begin(),
                                         args.filter_spatial_lengths_.end(),
                                         1,
                                         std::multiplies<index_t>());

        GemmM = a_grid_desc_k_m.get_length(number<1>{});
        GemmN = b_grid_desc_k_n.get_length(number<1>{});
        GemmK = a_grid_desc_k_m.get_length(number<0>{});
        // ConvG = GemmBatch * NumGroupsPerBatch (see Preprocess below).
        GemmBatch = static_cast<index_t>(args.G_) / NumGroupsPerBatch;

        if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
        {
            std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
                      << ", GemmBatch: " << GemmBatch
                      << ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
        }
    }

    using ABCGridDescs = remove_cvref_t<decltype(
        ConvToGemmTransformer{}.MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N())>;

    using AGridDescKM = remove_cvref_t<decltype(ABCGridDescs{}[number<0>{}])>;
    using BGridDescKN = remove_cvref_t<decltype(ABCGridDescs{}[number<1>{}])>;
    using CGridDescMN = remove_cvref_t<decltype(ABCGridDescs{}[number<2>{}])>;

    static constexpr index_t NonSpatialDims = 3;
    array<index_t, NonSpatialDims + GroupedConvTraitsType_::NDimSpatial> in_g_n_c_wis_lengths;
    array<index_t, NonSpatialDims + GroupedConvTraitsType_::NDimSpatial> wei_g_k_c_xs_lengths;
    array<index_t, NonSpatialDims + GroupedConvTraitsType_::NDimSpatial> out_g_n_k_wos_lengths;

    array<index_t, GroupedConvTraitsType_::NDimSpatial> conv_filter_strides;
    array<index_t, GroupedConvTraitsType_::NDimSpatial> conv_filter_dilations;
    array<index_t, GroupedConvTraitsType_::NDimSpatial> input_left_pads;
    array<index_t, GroupedConvTraitsType_::NDimSpatial> input_right_pads;

    index_t k_batch;
    index_t GemmM;
    index_t GemmN;
    index_t GemmK;
    index_t GemmBatch;
    index_t NumGroupsPerBatch;

    const void* out_ptr;
    const void* in_ptr;
    std::array<const void*, NumDTensor> ds_ptr;
    void* wei_ptr;

    AGridDescKM a_grid_desc_k_m;
    BGridDescKN b_grid_desc_k_n;
    CGridDescMN c_grid_desc_m_n;

    long_index_t group_stride_a;
    long_index_t group_stride_b;
    long_index_t group_stride_c;
};

/// @brief The Grouped Convolution Backward Weight kernel template.
template <typename GroupedConvTraitsType_,
          typename TilePartitioner_,
          typename GemmPipeline_,
          typename EpiloguePipeline_>
struct GroupedConvolutionBackwardWeightKernel
{
    static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial_;
    static constexpr ConvolutionSpecialization ConvSpecialization =
        GroupedConvTraitsType_::ConvSpecialization;
    using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
    using GemmPipeline     = remove_cvref_t<GemmPipeline_>;
    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
    using GemmALayout      = remove_cvref_t<typename GemmPipeline::ALayout>;
    using GemmBLayout      = remove_cvref_t<typename GemmPipeline::BLayout>;
    using GemmCLayout      = remove_cvref_t<typename GemmPipeline::CLayout>;

    using InLayout     = remove_cvref_t<typename GroupedConvTraitsType_::InLayout>;
    using WeiLayout    = remove_cvref_t<typename GroupedConvTraitsType_::WeiLayout>;
    using OutLayout    = remove_cvref_t<typename GroupedConvTraitsType_::OutLayout>;
    using DsLayout     = remove_cvref_t<typename GroupedConvTraitsType_::DsLayout>;
    using GemmDsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;

    static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;

    static constexpr index_t kBlockSize = GemmPipeline::BlockSize;

    using OutDataType = remove_cvref_t<typename GemmPipeline::ADataType>;
    using InDataType  = remove_cvref_t<typename GemmPipeline::BDataType>;
    using DsDataType  = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
    using WeiDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;

    using GroupedConvBwdWeightKernelArgsSpecialized =
        GroupedConvBwdWeightKernelArgs<GroupedConvTraitsType_>;

    // TODO: Enable this
    static constexpr bool IsSplitKSupported = true;

    static constexpr auto I0 = number<0>();
    static constexpr auto I1 = number<1>();
    static constexpr auto I2 = number<2>();
    static constexpr auto I3 = number<3>();

    static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
                  "Not supported!");
    static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
    static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
    static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>, "Not supported!");

    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
    {
        constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
        // clang-format off
        if (NumGroupsToMerge == 1) {
            return concat('_', "grouped_convolution_backward_weight",
                          gemm_prec_str<InDataType, WeiDataType>(),
                          "gemm",
                          GemmPipeline::GetName(),
                          "epilogue",
                          EpiloguePipeline::GetName());
        } else {
            return concat('_', "grouped_convolution_backward_weight",
                          gemm_prec_str<InDataType, WeiDataType>(),
                          "gemm",
                          GemmPipeline::GetName(),
                          "epilogue",
                          EpiloguePipeline::GetName(), "merge", NumGroupsToMerge);
        }
        // clang-format on
    }
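
    // Illustrative name shape (actual tokens depend on the pipeline/epilogue types):
    // "grouped_convolution_backward_weight_<prec>_gemm_<pipeline>_epilogue_<epilogue>[_merge_<N>]"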

    CK_TILE_HOST static constexpr auto
    GridSize(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
    {
        return dim3(
            TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch);
    }

    CK_TILE_HOST static constexpr auto BlockSize()
    {
        return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
    }
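
    // is_wave32() halves the thread count so the same number of waves per
    // workgroup is launched on wave32 and wave64 targets.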

    CK_TILE_HOST static constexpr GroupedConvBwdWeightKernelArgsSpecialized
    MakeKernelArgs(const GroupedConvBwdWeightHostArgs& hostArgs)
    {
        if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
        {
            std::cout << "MPerBlock: " << number<TilePartitioner::MPerBlock>{} << std::endl;
            std::cout << "NPerBlock: " << number<TilePartitioner::NPerBlock>{} << std::endl;
            std::cout << "KPerBlock: " << number<TilePartitioner::KPerBlock>{} << std::endl;
        }
        return GroupedConvBwdWeightKernelArgsSpecialized{hostArgs};
    }

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }

    struct SplitKBatchOffset
    {
        __device__ SplitKBatchOffset(const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
                                     const std::size_t k_id = blockIdx.z)
        {
            constexpr auto K1   = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
            const index_t K_t   = amd_wave_read_first_lane(kargs.k_batch * K1);
            const index_t KRead = amd_wave_read_first_lane((kargs.GemmK + K_t - 1) / K_t * K1);

            // Assumption (reconstructed): both A (K x M) and B (K x N) advance by
            // one whole KRead chunk per split-K batch id.
            a_k_split_offset = amd_wave_read_first_lane(k_id * KRead);
            b_k_split_offset = amd_wave_read_first_lane(k_id * KRead);

            if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
            {
                splitted_k = amd_wave_read_first_lane(KRead);
            }
            else
            {
                splitted_k = amd_wave_read_first_lane(kargs.GemmK - KRead * (kargs.k_batch - 1));
            }
        }

        index_t a_k_split_offset;
        index_t b_k_split_offset;
        index_t splitted_k;
    };
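
    // Worked example of the split above: with GemmK = 1000, k_batch = 4 and K1 = 16,
    // K_t = 64 and KRead = (1000 + 63) / 64 * 16 = 256, so batches 0..2 each cover 256
    // elements of K and the last batch covers the remaining 1000 - 3 * 256 = 232.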

    CK_TILE_HOST static auto Preprocess(const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
                                        const stream_config& s)
    {
        return [&]() {
            if(kargs.k_batch > 1)
            {
                // Total number of convolution groups (ConvG) = GemmBatch * NumGroupsPerBatch
                // since we require that ConvG % NumGroupsPerBatch == 0.
                const auto wei_size =
                    kargs.GemmBatch * kargs.GemmM * kargs.GemmN * kargs.NumGroupsPerBatch;
                hipGetErrorString(
                    hipMemsetAsync(kargs.wei_ptr, 0, wei_size * sizeof(WeiDataType), s.stream_id_));
            }
        };
    }
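
    // The returned lambda must run on the host before the kernel launch: with
    // k_batch > 1 each split-K batch accumulates into the weight buffer via atomic
    // adds, so the buffer has to start out zeroed.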

    CK_TILE_HOST static bool
    IsSupportedArgument(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
    {
        // Assumption (reconstructed condition): split-K accumulates with atomic adds,
        // which packed 16-bit types only support for even C-vector sizes.
        if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
                      is_any_of<WeiDataType, fp16_t, bf16_t>::value) ||
                     !IsSplitKSupported)
        {
            if(kargs.k_batch != 1)
            {
                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
                {
                    CK_TILE_ERROR("Conditions not met for Kbatch > 1!");
                }
                return false;
            }
        }

        const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
        const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];

        // check ConvSpecialization
        if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0)
        {
            // check if it's 1x1, stride=1 conv
            for(index_t i = 0; i < NDimSpatial; ++i)
            {
                const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
                const index_t ConvStride = kargs.conv_filter_strides[i];
                const index_t LeftPad    = kargs.input_left_pads[i];
                const index_t RightPad   = kargs.input_right_pads[i];

                if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
                {
                    return false;
                }
            }
        }
        else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Pad0)
        {
            // check if it's 1x1 conv
            for(index_t i = 0; i < NDimSpatial; ++i)
            {
                const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
                const index_t LeftPad    = kargs.input_left_pads[i];
                const index_t RightPad   = kargs.input_right_pads[i];

                if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0))
                {
                    return false;
                }
            }
        }
        else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3)
        {
            if(ConvC != 1)
            {
                return false;
            }
            for(index_t i = 0; i < NDimSpatial; ++i)
            {
                const index_t filter_spatial_dim = kargs.wei_g_k_c_xs_lengths[i + I3];

                if(filter_spatial_dim != I3)
                {
                    return false;
                }
            }
        }

        namespace ctc = tensor_layout::convolution;

        if constexpr(std::is_same_v<InLayout, ctc::NWGC> || std::is_same_v<InLayout, ctc::NHWGC> ||
                     std::is_same_v<InLayout, ctc::NDHWGC>)
        {
            // Check access per C
            if(ConvC % GroupedConvTraitsType_::VectorSizeB != 0)
            {
                CK_TILE_ERROR("Conv C is not a multiple of vector load size for "
                              "input image!");
                return false;
            }
        }
        else
        {
            CK_TILE_ERROR("Not supported input layout!");
            return false;
        }

        if constexpr(std::is_same_v<WeiLayout, ctc::GKXC> ||
                     std::is_same_v<WeiLayout, ctc::GKYXC> ||
                     std::is_same_v<WeiLayout, ctc::GKZYXC>)
        {
            if(ConvC % GroupedConvTraitsType_::VectorSizeC != 0)
            {
                CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!");
                return false;
            }
        }
        else
        {
            CK_TILE_ERROR("Not supported weight layout!");
            return false;
        }

        if constexpr(std::is_same_v<OutLayout, ctc::NWGK> ||
                     std::is_same_v<OutLayout, ctc::NHWGK> ||
                     std::is_same_v<OutLayout, ctc::NDHWGK>)
        {
            if(ConvK % GroupedConvTraitsType_::VectorSizeA != 0)
            {
                CK_TILE_ERROR("Conv K is not a multiple of vector store size "
                              "for output image!");
                return false;
            }
        }
        else
        {
            CK_TILE_ERROR("Not supported output layout!");
            return false;
        }

        if constexpr(GroupedConvTraitsType_::NumGroupsToMerge > 1)
        {
            const index_t ConvG = kargs.wei_g_k_c_xs_lengths[number<0>{}];
            if(ConvG % GroupedConvTraitsType_::NumGroupsToMerge != 0)
            {
                CK_TILE_ERROR("ConvG must be a multiple of NumGroupsToMerge!");
                return false;
            }

            // TODO: Should we also check that GemmM <= MPerBlock and GemmN <= NPerBlock?
        }

        return true;
    }

    template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
    CK_TILE_DEVICE static auto
    MakeGemmTensorViews(const OutDataType* a_ptr,
                        const InDataType* b_ptr,
                        const std::array<const void*, NumDTensor>& ds_ptr,
                        WeiDataType* c_ptr,
                        const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
    {
        static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
        static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!");
        const auto& a_tensor_view = [&]() {
            return make_tensor_view<address_space_enum::global>(a_ptr,
                                                                kargs.a_grid_desc_k_m); // A: out
        }();

        const auto& b_tensor_view = [&]() {
            return make_tensor_view<address_space_enum::global>(b_ptr,
                                                                kargs.b_grid_desc_k_n); // B: in
        }();

        const auto& c_tensor_view = [&]() {
            return make_tensor_view<address_space_enum::global, DstInMemOp>(c_ptr,
                                                                            kargs.c_grid_desc_m_n);
        }();

        const auto& ds_tensor_view = generate_tuple(
            [&](auto i) {
                static_assert(std::is_same_v<std::tuple_element_t<i, DsLayout>, OutLayout>,
                              "Not supported!");
                static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
                              "Not supported!");
                static_assert(std::is_same_v<std::tuple_element_t<i, DsDataType>, WeiDataType>,
                              "Not supported!");

                return make_tensor_view<address_space_enum::global>(
                    static_cast<const WeiDataType*>(ds_ptr[i]), kargs.c_grid_desc_m_n);
            },
            number<NumDTensor>{});

        return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view);
    }

    template <typename TensorView>
    CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views, const index_t k_batch)
    {
        // Assumption: the pad extents below are reconstructed from the universal GEMM
        // kernel pattern; every view is padded up to whole per-block tiles (the
        // static_asserts above require kPadM/kPadN/kPadK). k_batch is accepted for
        // parity with the split-K path.
        const auto& a_pad_view = [&]() {
            const auto& a_tensor_view = views.at(I0);
            return pad_tensor_view(a_tensor_view,
                                   make_tuple(number<TilePartitioner::KPerBlock>{},
                                              number<TilePartitioner::MPerBlock>{}),
                                   sequence<GemmPipeline::kPadK, GemmPipeline::kPadM>{});
        }();

        const auto& b_pad_view = [&]() {
            const auto& b_tensor_view = views.at(I1);
            return pad_tensor_view(b_tensor_view,
                                   make_tuple(number<TilePartitioner::KPerBlock>{},
                                              number<TilePartitioner::NPerBlock>{}),
                                   sequence<GemmPipeline::kPadK, GemmPipeline::kPadN>{});
        }();

        const auto& ds_tensor_view = views.at(I2);
        const auto& ds_pad_view    = generate_tuple(
            [&](auto i) {
                return pad_tensor_view(ds_tensor_view[i],
                                       make_tuple(number<TilePartitioner::MPerBlock>{},
                                                  number<TilePartitioner::NPerBlock>{}),
                                       sequence<GemmPipeline::kPadM, GemmPipeline::kPadN>{});
            },
            number<NumDTensor>{});

        const auto& c_pad_view = [&]() {
            const auto& c_tensor_view = views.at(I3);
            return pad_tensor_view(c_tensor_view,
                                   make_tuple(number<TilePartitioner::MPerBlock>{},
                                              number<TilePartitioner::NPerBlock>{}),
                                   sequence<GemmPipeline::kPadM, GemmPipeline::kPadN>{});
        }();

        return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view);
    }
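
    // All four views are padded up to whole MPerBlock/NPerBlock/KPerBlock tiles,
    // matching the kPadM/kPadN/kPadK static_asserts near the top of this kernel.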

    /// @brief Create views to the data that each workgroup will process.
    template <typename PadView>
    CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
                                                   const index_t i_m,
                                                   const index_t i_n,
                                                   const index_t i_k)
    {
        const auto& a_pad_view  = views.at(I0);
        const auto& b_pad_view  = views.at(I1);
        const auto& ds_pad_view = views.at(I2);
        const auto& c_pad_view  = views.at(I3);

        const auto& a_block_window = [&]() {
            return make_tile_window(a_pad_view,
                                    make_tuple(number<TilePartitioner::KPerBlock>{},
                                               number<TilePartitioner::MPerBlock>{}),
                                    {i_k, i_m});
        }();

        const auto& b_block_window = [&]() {
            return make_tile_window(b_pad_view,
                                    make_tuple(number<TilePartitioner::KPerBlock>{},
                                               number<TilePartitioner::NPerBlock>{}),
                                    {i_k, i_n});
        }();

        const auto ds_block_window = generate_tuple(
            [&](auto i) {
                return make_tile_window(ds_pad_view[i],
                                        make_tuple(number<TilePartitioner::MPerBlock>{},
                                                   number<TilePartitioner::NPerBlock>{}),
                                        {i_m, i_n});
            },
            number<NumDTensor>{});

        auto c_block_window = make_tile_window(
            c_pad_view,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
            {i_m, i_n});

        return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window);
    }
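
    // The A and B windows originate at the block's K offset (i_k), so each split-K
    // batch reads a disjoint K-range; the D and C windows are positioned by (i_m, i_n).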

    /// @brief Runs single GEMM problem cooperatively by whole workgroup.
    CK_TILE_DEVICE static void RunGemm(const OutDataType* a_ptr,
                                       const InDataType* b_ptr,
                                       const std::array<const void*, NumDTensor>& ds_ptr,
                                       WeiDataType* c_ptr,
                                       void* smem_ptr_0,
                                       const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
                                       const index_t num_loop,
                                       const index_t block_idx_m,
                                       const index_t block_idx_n,
                                       const index_t block_idx_k)
    {
        // Create Gemm tensor views, pad views and tile windows
        const auto& gemm_tensor_views_tuple =
            MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
                a_ptr, b_ptr, ds_ptr, c_ptr, kargs);

        const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple, kargs.k_batch);
        auto gemm_tile_windows =
            MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);

        // Run GEMM cooperatively by whole workgroup.
        const auto& a_block_window = gemm_tile_windows.at(I0);
        const auto& b_block_window = gemm_tile_windows.at(I1);
        const auto& d_block_window = gemm_tile_windows.at(I2);

        const auto& c_block_tile = GemmPipeline{}.template operator()(
            a_block_window, b_block_window, num_loop, smem_ptr_0);

        // Run Epilogue Pipeline
        auto& c_block_window = gemm_tile_windows.at(I3);

        EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
            c_block_window, c_block_tile, d_block_window, smem_ptr_0);
    }

    /// @brief Runs single GEMM problem cooperatively by whole workgroup,
    /// double-buffering through two LDS allocations.
    CK_TILE_DEVICE static void RunGemm2LDS(const OutDataType* a_ptr,
                                           const InDataType* b_ptr,
                                           const std::array<const void*, NumDTensor>& ds_ptr,
                                           WeiDataType* c_ptr,
                                           void* __restrict__ smem_ptr_0,
                                           void* __restrict__ smem_ptr_1,
                                           const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
                                           const index_t num_loop,
                                           const index_t block_idx_m,
                                           const index_t block_idx_n,
                                           const index_t block_idx_k)
    {
        // Create Gemm tensor views, pad views and tile windows
        const auto& gemm_tensor_views_tuple =
            MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
                a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
        const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple, kargs.k_batch);
        auto gemm_tile_windows =
            MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);

        // Run GEMM cooperatively by whole workgroup.
        const auto& a_block_window = gemm_tile_windows.at(I0);
        const auto& b_block_window = gemm_tile_windows.at(I1);
        const auto& d_block_window = gemm_tile_windows.at(I2);

        const auto& c_block_tile = GemmPipeline{}.template operator()(
            a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);

        // Run Epilogue Pipeline
        auto& c_block_window = gemm_tile_windows.at(I3);

        EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
            c_block_window, c_block_tile, d_block_window, smem_ptr_0);
    }

    CK_TILE_DEVICE void operator()(GroupedConvBwdWeightKernelArgsSpecialized kargs) const
    {
        const auto blockIdX = amd_wave_read_first_lane(blockIdx.x);
        const auto [iM, iN] =
            TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(blockIdX);
        const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
        const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);

        const auto blockIdZ    = amd_wave_read_first_lane(blockIdx.z);
        const index_t num_loop = amd_wave_read_first_lane(
            ck_tile::integer_divide_ceil(kargs.GemmK, kargs.k_batch * TilePartitioner::KPerBlock));
        const index_t i_k =
            amd_wave_read_first_lane(blockIdZ * num_loop * TilePartitioner::KPerBlock);

        const auto blockIdY       = amd_wave_read_first_lane(blockIdx.y);
        const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
        const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
        const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);

        // options
        // conv_bwd_weight = Out * In = Weight
        const OutDataType* a_ptr = static_cast<const OutDataType*>(kargs.out_ptr) + group_offset_a;
        const InDataType* b_ptr  = static_cast<const InDataType*>(kargs.in_ptr) + group_offset_b;
        WeiDataType* c_ptr       = static_cast<WeiDataType*>(kargs.wei_ptr) + group_offset_c;

        __shared__ char smem_ptr_0[GetSmemSize()];

        if constexpr(GemmPipeline::DoubleSmemBuffer == true)
        {
            __shared__ char smem_ptr_1[GetSmemSize()];
            // Assumption (reconstructed condition): odd C-vector atomic adds on 16-bit
            // outputs are excluded, mirroring the guard in IsSupportedArgument.
            if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
                           GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
                           is_any_of<WeiDataType, fp16_t, bf16_t>::value))
            {
                RunGemm2LDS(a_ptr,
                            b_ptr,
                            kargs.ds_ptr,
                            c_ptr,
                            smem_ptr_0,
                            smem_ptr_1,
                            kargs,
                            num_loop,
                            i_m,
                            i_n,
                            i_k);
            }
        }
        else
        {
            if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
                           GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
                           is_any_of<WeiDataType, fp16_t, bf16_t>::value))
            {
                RunGemm(
                    a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, num_loop, i_m, i_n, i_k);
            }
        }
    }
};

} // namespace ck_tile
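
The members above form the complete host-side protocol for this kernel: MakeKernelArgs, IsSupportedArgument, Preprocess, GridSize, BlockSize and the device-side operator(). The sketch below shows one plausible way to drive a compiled specialization; it is illustrative only, the driver function is hypothetical, and the launch helper named in the final comment is an assumption rather than part of this header.

#include <stdexcept>
#include <hip/hip_runtime.h> // dim3, hipStream_t
// (assumes this header and its ck_tile dependencies are already included)

// Hypothetical driver; Kernel is a full specialization of the kernel template above.
template <typename Kernel>
void run_grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& host_args,
                                 const ck_tile::stream_config& s)
{
    auto kargs = Kernel::MakeKernelArgs(host_args);
    if(!Kernel::IsSupportedArgument(kargs))
    {
        throw std::runtime_error("unsupported arguments for " + Kernel::GetName());
    }

    Kernel::Preprocess(kargs, s)(); // zeroes the weight buffer when k_batch > 1

    const dim3 grids  = Kernel::GridSize(kargs);
    const dim3 blocks = Kernel::BlockSize();

    // Launch through the framework's generic launcher (helper signature is an
    // assumption; any mechanism that invokes Kernel{}(kargs) per block works):
    //   ck_tile::launch_kernel(s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs));
    (void)grids;
    (void)blocks;
}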