Composable Kernel: include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp Source File
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include <iostream>
#include <string>

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"

#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_weight.hpp"
#endif

namespace ck_tile {

/// @brief The Grouped Convolution kernel device arguments.
template <typename GroupedConvTraitsType_>
struct GroupedConvBwdWeightKernelArgs
{
    using ConvToGemmTransformer =
        TransformConvBwdWeightToGemm<GroupedConvTraitsType_::NDimSpatial,
                                     GroupedConvTraitsType_::ConvSpecialization,
                                     GroupedConvTraitsType_::VectorSizeA,
                                     GroupedConvTraitsType_::VectorSizeB,
                                     GroupedConvTraitsType_::VectorSizeC,
                                     GroupedConvTraitsType_::NumGroupsToMerge>;
    static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;

    template <
        typename InLay  = typename GroupedConvTraitsType_::InLayout,
        typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
        typename OutLay = typename GroupedConvTraitsType_::OutLayout,
        typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NWGC> &&
                                    std::is_same_v<WeiLay, tensor_layout::convolution::GKXC> &&
                                    std::is_same_v<OutLay, tensor_layout::convolution::NWGK>,
                                bool>::type = false>
    CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args)
    {
        in_g_n_c_wis_lengths  = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.N_),
                                 static_cast<index_t>(args.C_),
                                 static_cast<index_t>(args.input_spatial_lengths_[0])};
        wei_g_k_c_xs_lengths  = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.K_),
                                 static_cast<index_t>(args.C_),
                                 static_cast<index_t>(args.filter_spatial_lengths_[0])};
        out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.N_),
                                 static_cast<index_t>(args.K_),
                                 static_cast<index_t>(args.output_spatial_lengths_[0])};

        conv_filter_strides   = {static_cast<index_t>(args.conv_filter_strides_[0])};
        conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0])};
        input_left_pads       = {static_cast<index_t>(args.input_left_pads_[0])};
        input_right_pads      = {static_cast<index_t>(args.input_right_pads_[0])};

        in_ptr  = args.in_ptr;
        wei_ptr = args.wei_ptr;
        for(index_t d = 0; d < NumDTensor; d++)
        {
            ds_ptr[d] = args.ds_ptr[d];
        }
        out_ptr = args.out_ptr;

        ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
                                                       wei_g_k_c_xs_lengths,
                                                       out_g_n_k_wos_lengths,
                                                       conv_filter_strides,
                                                       conv_filter_dilations,
                                                       input_left_pads,
                                                       input_right_pads};

        // tuple
        auto grid_descs =
            conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
                GroupedConvTraitsType_::NDimSpatial>();

        a_grid_desc_k_m = grid_descs.at(number<0>{});
        b_grid_desc_k_n = grid_descs.at(number<1>{});
        c_grid_desc_m_n = grid_descs.at(number<2>{});

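        // Each grid.y slice processes NumGroupsToMerge consecutive convolution groups;
        // group_stride_a/b/c below are the per-slice base-pointer offsets for A (output
        // gradient), B (input image) and C (weight gradient).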
        NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
        group_stride_a    = args.K_ * NumGroupsPerBatch; // A: Out NWGK
        group_stride_b    = args.C_ * NumGroupsPerBatch; // B: In NWGC
        group_stride_c    = args.K_ * args.C_ * NumGroupsPerBatch * // C: Wei GKXC
                         std::accumulate(args.filter_spatial_lengths_.begin(),
                                         args.filter_spatial_lengths_.end(),
                                         1,
                                         std::multiplies<index_t>());

        GemmM     = a_grid_desc_k_m.get_length(number<1>{});
        GemmN     = b_grid_desc_k_n.get_length(number<1>{});
        GemmK     = a_grid_desc_k_m.get_length(number<0>{});
        GemmBatch = in_g_n_c_wis_lengths[0] / NumGroupsPerBatch;

        k_batch = args.k_batch;

        if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
        {
            std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
                      << ", GemmBatch: " << GemmBatch
                      << ", NumGroupsPerBatch: " << NumGroupsPerBatch << ", k_batch: " << k_batch
                      << std::endl;
        }
    }

    template <
        typename InLay  = typename GroupedConvTraitsType_::InLayout,
        typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
        typename OutLay = typename GroupedConvTraitsType_::OutLayout,
        typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NHWGC> &&
                                    std::is_same_v<WeiLay, tensor_layout::convolution::GKYXC> &&
                                    std::is_same_v<OutLay, tensor_layout::convolution::NHWGK>,
                                bool>::type = false>
    CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args)
    {
        in_g_n_c_wis_lengths  = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.N_),
                                 static_cast<index_t>(args.C_),
                                 static_cast<index_t>(args.input_spatial_lengths_[0]),
                                 static_cast<index_t>(args.input_spatial_lengths_[1])};
        wei_g_k_c_xs_lengths  = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.K_),
                                 static_cast<index_t>(args.C_),
                                 static_cast<index_t>(args.filter_spatial_lengths_[0]),
                                 static_cast<index_t>(args.filter_spatial_lengths_[1])};
        out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.N_),
                                 static_cast<index_t>(args.K_),
                                 static_cast<index_t>(args.output_spatial_lengths_[0]),
                                 static_cast<index_t>(args.output_spatial_lengths_[1])};

        conv_filter_strides   = {static_cast<index_t>(args.conv_filter_strides_[0]),
                                 static_cast<index_t>(args.conv_filter_strides_[1])};
        conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
                                 static_cast<index_t>(args.conv_filter_dilations_[1])};
        input_left_pads       = {static_cast<index_t>(args.input_left_pads_[0]),
                                 static_cast<index_t>(args.input_left_pads_[1])};
        input_right_pads      = {static_cast<index_t>(args.input_right_pads_[0]),
                                 static_cast<index_t>(args.input_right_pads_[1])};

        in_ptr  = args.in_ptr;
        wei_ptr = args.wei_ptr;
        for(index_t d = 0; d < NumDTensor; d++)
        {
            ds_ptr[d] = args.ds_ptr[d];
        }
        out_ptr = args.out_ptr;

        ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
                                                       wei_g_k_c_xs_lengths,
                                                       out_g_n_k_wos_lengths,
                                                       conv_filter_strides,
                                                       conv_filter_dilations,
                                                       input_left_pads,
                                                       input_right_pads};

        // tuple
        auto grid_descs =
            conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
                GroupedConvTraitsType_::NDimSpatial>();

        a_grid_desc_k_m = grid_descs.at(number<0>{});
        b_grid_desc_k_n = grid_descs.at(number<1>{});
        c_grid_desc_m_n = grid_descs.at(number<2>{});

        NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
        group_stride_a    = args.K_ * NumGroupsPerBatch; // A: Out NHWGK
        group_stride_b    = args.C_ * NumGroupsPerBatch; // B: In NHWGC
        group_stride_c    = args.K_ * args.C_ * NumGroupsPerBatch * // C: Wei GKYXC
                         std::accumulate(args.filter_spatial_lengths_.begin(),
                                         args.filter_spatial_lengths_.end(),
                                         1,
                                         std::multiplies<index_t>());

        GemmM     = a_grid_desc_k_m.get_length(number<1>{});
        GemmN     = b_grid_desc_k_n.get_length(number<1>{});
        GemmK     = a_grid_desc_k_m.get_length(number<0>{});
        GemmBatch = in_g_n_c_wis_lengths[0] / NumGroupsPerBatch;

        k_batch = args.k_batch;

        if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
        {
            std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
                      << ", GemmBatch: " << GemmBatch
                      << ", NumGroupsPerBatch: " << NumGroupsPerBatch << ", k_batch: " << k_batch
                      << std::endl;
        }
    }

    template <
        typename InLay  = typename GroupedConvTraitsType_::InLayout,
        typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
        typename OutLay = typename GroupedConvTraitsType_::OutLayout,
        typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NDHWGC> &&
                                    std::is_same_v<WeiLay, tensor_layout::convolution::GKZYXC> &&
                                    std::is_same_v<OutLay, tensor_layout::convolution::NDHWGK>,
                                bool>::type = false>
    CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args)
    {
        in_g_n_c_wis_lengths  = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.N_),
                                 static_cast<index_t>(args.C_),
                                 static_cast<index_t>(args.input_spatial_lengths_[0]),
                                 static_cast<index_t>(args.input_spatial_lengths_[1]),
                                 static_cast<index_t>(args.input_spatial_lengths_[2])};
        wei_g_k_c_xs_lengths  = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.K_),
                                 static_cast<index_t>(args.C_),
                                 static_cast<index_t>(args.filter_spatial_lengths_[0]),
                                 static_cast<index_t>(args.filter_spatial_lengths_[1]),
                                 static_cast<index_t>(args.filter_spatial_lengths_[2])};
        out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.N_),
                                 static_cast<index_t>(args.K_),
                                 static_cast<index_t>(args.output_spatial_lengths_[0]),
                                 static_cast<index_t>(args.output_spatial_lengths_[1]),
                                 static_cast<index_t>(args.output_spatial_lengths_[2])};

        conv_filter_strides   = {static_cast<index_t>(args.conv_filter_strides_[0]),
                                 static_cast<index_t>(args.conv_filter_strides_[1]),
                                 static_cast<index_t>(args.conv_filter_strides_[2])};
        conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
                                 static_cast<index_t>(args.conv_filter_dilations_[1]),
                                 static_cast<index_t>(args.conv_filter_dilations_[2])};
        input_left_pads       = {static_cast<index_t>(args.input_left_pads_[0]),
                                 static_cast<index_t>(args.input_left_pads_[1]),
                                 static_cast<index_t>(args.input_left_pads_[2])};
        input_right_pads      = {static_cast<index_t>(args.input_right_pads_[0]),
                                 static_cast<index_t>(args.input_right_pads_[1]),
                                 static_cast<index_t>(args.input_right_pads_[2])};

        in_ptr  = args.in_ptr;
        wei_ptr = args.wei_ptr;
        for(index_t d = 0; d < NumDTensor; d++)
        {
            ds_ptr[d] = args.ds_ptr[d];
        }
        out_ptr = args.out_ptr;

        ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
                                                       wei_g_k_c_xs_lengths,
                                                       out_g_n_k_wos_lengths,
                                                       conv_filter_strides,
                                                       conv_filter_dilations,
                                                       input_left_pads,
                                                       input_right_pads};

        // tuple
        auto grid_descs =
            conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
                GroupedConvTraitsType_::NDimSpatial>();

        a_grid_desc_k_m = grid_descs.at(number<0>{});
        b_grid_desc_k_n = grid_descs.at(number<1>{});
        c_grid_desc_m_n = grid_descs.at(number<2>{});

        NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
        group_stride_a    = args.K_ * NumGroupsPerBatch; // A: Out NDHWGK
        group_stride_b    = args.C_ * NumGroupsPerBatch; // B: In NDHWGC
        group_stride_c    = args.K_ * args.C_ * NumGroupsPerBatch * // C: Wei GKZYXC
                         std::accumulate(args.filter_spatial_lengths_.begin(),
                                         args.filter_spatial_lengths_.end(),
                                         1,
                                         std::multiplies<index_t>());

        GemmM     = a_grid_desc_k_m.get_length(number<1>{});
        GemmN     = b_grid_desc_k_n.get_length(number<1>{});
        GemmK     = a_grid_desc_k_m.get_length(number<0>{});
        GemmBatch = in_g_n_c_wis_lengths[0] / NumGroupsPerBatch;

        k_batch = args.k_batch;

        if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
        {
            std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
                      << ", GemmBatch: " << GemmBatch
                      << ", NumGroupsPerBatch: " << NumGroupsPerBatch << ", k_batch: " << k_batch
                      << std::endl;
        }
    }

    using ABCGridDescs = remove_cvref_t<
        decltype(ConvToGemmTransformer{}.MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N())>;

    using AGridDescKM = remove_cvref_t<decltype(ABCGridDescs{}[number<0>{}])>;
    using BGridDescKN = remove_cvref_t<decltype(ABCGridDescs{}[number<1>{}])>;
    using CGridDescMN = remove_cvref_t<decltype(ABCGridDescs{}[number<2>{}])>;

    static constexpr index_t NonSpatialDims = 3;

    array<index_t, NonSpatialDims + GroupedConvTraitsType_::NDimSpatial> in_g_n_c_wis_lengths;
    array<index_t, NonSpatialDims + GroupedConvTraitsType_::NDimSpatial> wei_g_k_c_xs_lengths;
    array<index_t, NonSpatialDims + GroupedConvTraitsType_::NDimSpatial> out_g_n_k_wos_lengths;

    array<index_t, GroupedConvTraitsType_::NDimSpatial> conv_filter_strides;
    array<index_t, GroupedConvTraitsType_::NDimSpatial> conv_filter_dilations;
    array<index_t, GroupedConvTraitsType_::NDimSpatial> input_left_pads;
    array<index_t, GroupedConvTraitsType_::NDimSpatial> input_right_pads;

    index_t k_batch;
    index_t GemmM;
    index_t GemmN;
    index_t GemmK;
    index_t GemmBatch;
    index_t NumGroupsPerBatch;

    const void* out_ptr;
    const void* in_ptr;
    std::array<const void*, NumDTensor> ds_ptr;
    void* wei_ptr;

    AGridDescKM a_grid_desc_k_m;
    BGridDescKN b_grid_desc_k_n;
    CGridDescMN c_grid_desc_m_n;

    long_index_t group_stride_a;
    long_index_t group_stride_b;
    long_index_t group_stride_c;
};

/// @brief The Grouped Convolution Backward Weight kernel template.
template <typename GroupedConvTraitsType_,
          typename TilePartitioner_,
          typename GemmPipeline_,
          typename EpiloguePipeline_>
struct GroupedConvolutionBackwardWeightKernel
{
    static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial;
    static constexpr ConvolutionSpecialization ConvSpecialization =
        GroupedConvTraitsType_::ConvSpecialization;
    using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
    using GemmPipeline     = remove_cvref_t<GemmPipeline_>;
    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
    using GemmALayout      = remove_cvref_t<typename GemmPipeline::ALayout>;
    using GemmBLayout      = remove_cvref_t<typename GemmPipeline::BLayout>;
    using GemmCLayout      = remove_cvref_t<typename GemmPipeline::CLayout>;

    using InLayout  = remove_cvref_t<typename GroupedConvTraitsType_::InLayout>;
    using WeiLayout = remove_cvref_t<typename GroupedConvTraitsType_::WeiLayout>;
    using OutLayout = remove_cvref_t<typename GroupedConvTraitsType_::OutLayout>;
    using DsLayout  = remove_cvref_t<typename GroupedConvTraitsType_::DsLayout>;

    using GemmDsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
    static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;

    static constexpr index_t kBlockSize = GemmPipeline::BlockSize;

    using OutDataType = remove_cvref_t<typename GemmPipeline::ADataType>;
    using InDataType  = remove_cvref_t<typename GemmPipeline::BDataType>;
    using DsDataType  = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
    using WeiDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;

    using GroupedConvBwdWeightKernelArgsSpecialized =
        GroupedConvBwdWeightKernelArgs<GroupedConvTraitsType_>;

    static constexpr bool IsSplitKSupported = true;

    static constexpr auto I0 = number<0>();
    static constexpr auto I1 = number<1>();
    static constexpr auto I2 = number<2>();
    static constexpr auto I3 = number<3>();

    static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
                  "Not supported!");
    static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
    static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
    static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
    static_assert(GroupedConvTraitsType_::ExplicitGemm == false ||
                      GroupedConvTraitsType_::NumGroupsToMerge == 1,
                  "Not supported!");

    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
    {
        static constexpr bool EnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
        constexpr auto NumGroupsToMerge        = GroupedConvTraitsType_::NumGroupsToMerge;
        // clang-format off
        return concat('_', "grouped_convolution_backward_weight",
                           gemm_prec_str<InDataType, WeiDataType>(),
                           InLayout::name,
                           WeiLayout::name,
                           OutLayout::name,
                           "gemm",
                           GemmPipeline::GetName(),
                           "epilogue",
                           EpiloguePipeline::GetName(),
                           "MergedGroups",
                           NumGroupsToMerge,
                           "SplitImage",
                           EnableSplitImage,
                           "ExplicitGemm",
                           GroupedConvTraitsType_::ExplicitGemm
                           );
        // clang-format on
    }

    [[nodiscard]] CK_TILE_HOST static const std::string GetTypeString() { return GetName(); }

#ifdef CK_EXPERIMENTAL_BUILDER
    CK_TILE_HOST std::string GetInstanceString() const
    {
        static_assert(ck_tile::reflect::HasInstanceTraits<GroupedConvolutionBackwardWeightKernel>,
                      "Specialization of instance_traits not found. Please check that a "
                      "specialization exists in file "
                      "ck_tile/builder/reflect/"
                      "instance_traits_tile_grouped_convolution_backward_weight.hpp "
                      "for the given template parameters.");
        return ck_tile::reflect::instance_string<GroupedConvolutionBackwardWeightKernel>();
    }
#endif

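    // Grid mapping: grid.x enumerates the M/N output tiles produced by the TilePartitioner,
    // grid.y enumerates batches of NumGroupsToMerge convolution groups (GemmBatch), and
    // grid.z enumerates the split-K slices (k_batch).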
    CK_TILE_HOST static constexpr auto
    GridSize(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
    {
        return dim3(
            TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch);
    }

    CK_TILE_HOST static constexpr auto BlockSize()
    {
        return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
    }

    CK_TILE_HOST static constexpr GroupedConvBwdWeightKernelArgsSpecialized
    MakeKernelArgs(const GroupedConvBwdWeightHostArgs& hostArgs)
    {
        if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
        {
            std::cout << "MPerBlock: " << number<TilePartitioner::MPerBlock>{} << std::endl;
            std::cout << "NPerBlock: " << number<TilePartitioner::NPerBlock>{} << std::endl;
            std::cout << "KPerBlock: " << number<TilePartitioner::KPerBlock>{} << std::endl;
        }

        auto kernel_args = GroupedConvBwdWeightKernelArgsSpecialized(hostArgs);

        using KernelImpl = GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType_,
                                                                  TilePartitioner_,
                                                                  GemmPipeline_,
                                                                  EpiloguePipeline_>;

        // Negative k_batch value: split-K autodeduction.
        if(kernel_args.k_batch < 0)
        {
            const auto optimal_split_k =
                calculate_optimal_k_batch<GemmPipeline_::BlockSize, KernelImpl, TilePartitioner_>(
                    kernel_args);
            kernel_args.k_batch = optimal_split_k;
        }

        return kernel_args;
    }

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }

    CK_TILE_HOST static bool
    IsSupportedArgument(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
    {
        if(kargs.k_batch < 1)
        {
            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
            {
                CK_TILE_ERROR(
                    "k_batch must be at least one. Ensure argument is created via MakeKernelArgs.");
            }
            return false;
        }

        if constexpr(!std::is_same_v<typename EpiloguePipeline::ODataType, float> &&
                     !std::is_same_v<typename EpiloguePipeline::ODataType, double>)
        {
            // The epilogue performs atomic add related to split-K using the ODataType.
            // If the type is less accurate than float, large split-K values may lead to
            // accuracy issues. Hence, we limit the maximum split-K value to 128 in such cases.
            if(kargs.k_batch > 128)
            {
                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
                {
                    CK_TILE_ERROR(
                        "For epilogue output data type that is not float/double, we must have "
                        "k_batch <= 128.");
                }
                return false;
            }
        }

        if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
        {
            if(kargs.k_batch != 1)
            {
                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
                {
                    CK_TILE_ERROR("Conditions not met for K_batch > 1!");
                }
                return false;
            }
        }

        if(kargs.GemmK < TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}) * kargs.k_batch)
        {
            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
            {
                CK_TILE_ERROR("KBatch is too large, part of GPU wouldn't be utilized!");
            }
            return false;
        }

        const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
        const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];

        // check ConvSpecialization
        if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0)
        {
            // check if it's 1x1, stride=1 conv
            for(index_t i = 0; i < NDimSpatial; ++i)
            {
                const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
                const index_t ConvStride = kargs.conv_filter_strides[i];
                const index_t LeftPad    = kargs.input_left_pads[i];
                const index_t RightPad   = kargs.input_right_pads[i];

                if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
                {
                    return false;
                }
            }
        }
        else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Pad0)
        {
            // check if it's 1x1 conv
            for(index_t i = 0; i < NDimSpatial; ++i)
            {
                const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
                const index_t LeftPad    = kargs.input_left_pads[i];
                const index_t RightPad   = kargs.input_right_pads[i];

                if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0))
                {
                    return false;
                }
            }
        }
        else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3)
        {
            if(ConvC != 1)
            {
                return false;
            }
            for(index_t i = 0; i < NDimSpatial; ++i)
            {
                const index_t filter_spatial_dim = kargs.wei_g_k_c_xs_lengths[i + I3];

                if(filter_spatial_dim != I3)
                {
                    return false;
                }
            }
        }

        if constexpr(GroupedConvTraitsType_::ExplicitGemm &&
                     ConvSpecialization != ConvolutionSpecialization::Filter1x1Stride1Pad0)
        {
            CK_TILE_ERROR(
                "Explicit Gemm is supported only for Filter1x1Stride1Pad0 specialization!");
            return false;
        }

        namespace ctc = tensor_layout::convolution;

        if constexpr(std::is_same_v<InLayout, ctc::NWGC> || std::is_same_v<InLayout, ctc::NHWGC> ||
                     std::is_same_v<InLayout, ctc::NDHWGC>)
        {
            // Check access per C
            if(ConvC % GroupedConvTraitsType_::VectorSizeB != 0)
            {
                CK_TILE_ERROR("Conv C is not a multiple of vector load size for "
                              "input image!");
                return false;
            }
        }
        else
        {
            CK_TILE_ERROR("Not supported input layout!");
            return false;
        }

        if constexpr(std::is_same_v<WeiLayout, ctc::GKXC> ||
                     std::is_same_v<WeiLayout, ctc::GKYXC> ||
                     std::is_same_v<WeiLayout, ctc::GKZYXC>)
        {
            if(ConvC % GroupedConvTraitsType_::VectorSizeC != 0)
            {
                CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!");
                return false;
            }
        }
        else
        {
            CK_TILE_ERROR("Not supported weight layout!");
            return false;
        }

        if constexpr(std::is_same_v<OutLayout, ctc::NWGK> ||
                     std::is_same_v<OutLayout, ctc::NHWGK> ||
                     std::is_same_v<OutLayout, ctc::NDHWGK>)
        {
            if(ConvK % GroupedConvTraitsType_::VectorSizeA != 0)
            {
                CK_TILE_ERROR("Conv K is not a multiple of vector store size "
                              "for output image!");
                return false;
            }
        }
        else
        {
            CK_TILE_ERROR("Not supported output layout!");
            return false;
        }

        if constexpr(GroupedConvTraitsType_::NumGroupsToMerge > 1)
        {
            const index_t ConvG = kargs.wei_g_k_c_xs_lengths[number<0>{}];
            if(ConvG % GroupedConvTraitsType_::NumGroupsToMerge != 0)
            {
                CK_TILE_ERROR("ConvG must be a multiple of NumGroupsToMerge!");
                return false;
            }
        }

        return true;
    }

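    // The C (weight) tile window is created with memory_operation_enum::set when k_batch == 1
    // and with atomic_add when K is split, so partial results from different split-K slices
    // accumulate into the same weight tile.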
    template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
    CK_TILE_DEVICE static auto
    MakeCBlockWindow(WeiDataType* c_ptr,
                     const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
                     const index_t block_idx_m,
                     const index_t block_idx_n)
    {
        const auto& c_tensor_view =
            make_tensor_view<address_space_enum::global, DstInMemOp>(c_ptr, kargs.c_grid_desc_m_n);

        const auto& c_pad_view = pad_tensor_view(
            c_tensor_view,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
            sequence<GemmPipeline::kPadM, GemmPipeline::kPadN>{});

        return make_tile_window(
            c_pad_view,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
            {block_idx_m, block_idx_n});
    }

    CK_TILE_DEVICE static auto
    MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
                      const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
                      const index_t block_idx_m,
                      const index_t block_idx_n)
    {
        const auto& ds_tensor_view = generate_tuple(
            [&](auto i) {
                static_assert(std::is_same_v<std::tuple_element_t<i, DsLayout>, OutLayout>,
                              "Not supported!");
                static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
                              "Not supported!");
                static_assert(std::is_same_v<std::tuple_element_t<i, DsDataType>, WeiDataType>,
                              "Not supported!");

                return make_tensor_view<address_space_enum::global>(
                    static_cast<WeiDataType*>(ds_ptr[i]), kargs.c_grid_desc_m_n);
            },
            number<NumDTensor>{});

        const auto& ds_pad_view = generate_tuple(
            [&](auto i) {
                return pad_tensor_view(ds_tensor_view[i],
                                       make_tuple(number<TilePartitioner::MPerBlock>{},
                                                  number<TilePartitioner::NPerBlock>{}),
                                       sequence<GemmPipeline::kPadM, GemmPipeline::kPadN>{});
            },
            number<NumDTensor>{});

        return generate_tuple(
            [&](auto i) {
                return make_tile_window(ds_pad_view[i],
                                        make_tuple(number<TilePartitioner::MPerBlock>{},
                                                   number<TilePartitioner::NPerBlock>{}),
                                        {block_idx_m, block_idx_n});
            },
            number<NumDTensor>{});
    }

    CK_TILE_DEVICE static auto
    MakeBBlockWindow(const InDataType* b_ptr,
                     const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
                     const index_t block_idx_n,
                     const index_t block_idx_k)
    {
        static_assert(!GemmPipeline::BlockGemmShape::PermuteB, "Not implemented!");
        const auto& b_tensor_view =
            make_tensor_view<address_space_enum::global>(b_ptr, kargs.b_grid_desc_k_n);

        const auto& b_pad_view =
            pad_tensor_view(b_tensor_view,
                            make_tuple(number<TilePartitioner::KPerBlock>{},
                                       number<TilePartitioner::NPerBlock>{}),
                            sequence<GemmPipeline::kPadK, GemmPipeline::kPadN>{});

        return make_tile_window(
            b_pad_view,
            make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
            {block_idx_k, block_idx_n});
    }

    CK_TILE_DEVICE static auto
    MakeABlockWindow(const OutDataType* a_ptr,
                     const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
                     const index_t block_idx_m,
                     const index_t block_idx_k)
    {
        static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
        const auto& a_tensor_view =
            make_tensor_view<address_space_enum::global>(a_ptr, kargs.a_grid_desc_k_m);

        const auto& a_pad_view =
            pad_tensor_view(a_tensor_view,
                            make_tuple(number<TilePartitioner::KPerBlock>{},
                                       number<TilePartitioner::MPerBlock>{}),
                            sequence<GemmPipeline::kPadK, GemmPipeline::kPadM>{});

        return make_tile_window(
            a_pad_view,
            make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::MPerBlock>{}),
            {block_idx_k, block_idx_m});
    }

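    /// @brief Runs single GEMM problem cooperatively by whole workgroup.
    /// The A (output-gradient) and B (input-image) tile windows feed the GEMM pipeline;
    /// the epilogue then writes the C (weight-gradient) tile, dispatching on k_batch to
    /// choose between a plain store and an atomic-add store.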
    CK_TILE_DEVICE static void RunGemm(const OutDataType* a_ptr,
                                       const InDataType* b_ptr,
                                       const std::array<const void*, NumDTensor>& ds_ptr,
                                       WeiDataType* c_ptr,
                                       void* smem_ptr_0,
                                       const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
                                       const index_t num_loop,
                                       const index_t block_idx_m,
                                       const index_t block_idx_n,
                                       const index_t block_idx_k)
    {
        // Create block windows using helper methods
        const auto& a_block_window = MakeABlockWindow(a_ptr, kargs, block_idx_m, block_idx_k);
        const auto& b_block_window = MakeBBlockWindow(b_ptr, kargs, block_idx_n, block_idx_k);
        const auto& d_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);

        // Run GEMM cooperatively by whole workgroup.
        const auto& c_block_tile = GemmPipeline{}.template operator()(
            a_block_window, b_block_window, num_loop, smem_ptr_0);

        // Run Epilogue Pipeline with k_batch dispatching
        if(kargs.k_batch == 1)
        {
            auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
                c_ptr, kargs, block_idx_m, block_idx_n);

            EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
        }
        else
        {
            if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
            {
                auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
                    c_ptr, kargs, block_idx_m, block_idx_n);

                EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
            }
        }
    }

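    // Explicit-GEMM path: for the Filter1x1Stride1Pad0 specialization the backward-weight
    // problem is lowered directly to a batched GEMM, with GemmBatch as the batch count.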
    CK_TILE_DEVICE void CallExplicitGemm(GroupedConvBwdWeightKernelArgsSpecialized& kargs) const
    {
        static_assert(NumDTensor == 0, "Not supported!");
        using ExplicitBatchedGemmKernel =
            BatchedGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
        const auto batched_gemm_kargs = typename ExplicitBatchedGemmKernel::BatchedGemmKernelArgs{
            {{kargs.out_ptr},
             {kargs.in_ptr},
             {},
             kargs.wei_ptr,
             kargs.GemmM,
             kargs.GemmN,
             kargs.GemmK,
             {kargs.GemmM * kargs.GemmBatch},
             {kargs.GemmN * kargs.GemmBatch},
             {},
             kargs.GemmN,
             kargs.k_batch},
            kargs.GemmM,
            kargs.GemmN,
            kargs.GemmM * kargs.GemmN,
            kargs.GemmBatch};
        ExplicitBatchedGemmKernel{}(batched_gemm_kargs);
    }

    CK_TILE_DEVICE void operator()(GroupedConvBwdWeightKernelArgsSpecialized& kargs) const
    {
        if constexpr(GroupedConvTraitsType_::ExplicitGemm)
        {
            CallExplicitGemm(kargs);
        }
        else
        {
            const auto blockIdX = amd_wave_read_first_lane(blockIdx.x);
            const auto [iM, iN] =
                TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(blockIdX);
            const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
            const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);

            const auto blockIdZ = amd_wave_read_first_lane(blockIdx.z);
            const index_t num_loop = amd_wave_read_first_lane(integer_divide_ceil(
                kargs.GemmK, kargs.k_batch * TilePartitioner::KPerBlock));
            const index_t i_k =
                amd_wave_read_first_lane(blockIdZ * num_loop * TilePartitioner::KPerBlock);

            const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);
            const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
            const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
            const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);

            // options
            // conv_bwd_weight = Out * In = Weight
            const OutDataType* a_ptr =
                static_cast<const OutDataType*>(kargs.out_ptr) + group_offset_a;
            const InDataType* b_ptr = static_cast<const InDataType*>(kargs.in_ptr) + group_offset_b;
            WeiDataType* c_ptr = static_cast<WeiDataType*>(kargs.wei_ptr) + group_offset_c;

            __shared__ char smem_ptr[GetSmemSize()];

            RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr, kargs, num_loop, i_m, i_n, i_k);
        }
    }
};

} // namespace ck_tile
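
// A minimal host-side sketch of how this kernel is typically driven. It is illustrative only:
// `ConvTraits`, `Partitioner`, `Pipeline`, `Epilogue` and `host_args` stand for concrete
// instantiations defined outside this header, and the launch step is left as a comment because
// the launch helpers live elsewhere.
//
//   using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<ConvTraits,
//                                                                  Partitioner,
//                                                                  Pipeline,
//                                                                  Epilogue>;
//
//   auto kargs = Kernel::MakeKernelArgs(host_args); // host_args: GroupedConvBwdWeightHostArgs
//   if(Kernel::IsSupportedArgument(kargs))
//   {
//       const dim3 grids  = Kernel::GridSize(kargs); // {M/N tiles, GemmBatch, k_batch}
//       const dim3 blocks = Kernel::BlockSize();
//       // Launch Kernel{} with (grids, blocks) and kargs via the framework's kernel
//       // launch utilities; the required shared memory size is Kernel::GetSmemSize().
//   }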