/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp Source File
grouped_convolution_backward_weight_kernel.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <string>
8 
9 #include "ck_tile/core.hpp"
10 #include "ck_tile/ops/common.hpp"
11 #include "ck_tile/host/concat.hpp"
16 
18 
19 #ifdef CK_EXPERIMENTAL_BUILDER
20 #include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_weight.hpp"
21 #endif
22 
23 namespace ck_tile {
24 
// GroupedConvBwdWeightKernelArgs: device-side argument bundle for the grouped
// convolution backward-weight kernel. The struct's name line (source line 27) is
// missing from this rendered listing; the name matches the constructor entry
// `GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs &args)` in the
// Doxygen index at the end of this page.
26 template <typename GroupedConvTraitsType_>
28 {
29 
// ConvToGemmTransformer (its `using ... =` line, source line 30, is missing here):
// the conv-to-GEMM transform parameterized by the traits' spatial rank, convolution
// specialization, A/B/C vector sizes and NumGroupsToMerge.
31  TransformConvBwdWeightToGemm<GroupedConvTraitsType_::NDimSpatial,
32  GroupedConvTraitsType_::ConvSpecialization,
33  GroupedConvTraitsType_::VectorSizeA,
34  GroupedConvTraitsType_::VectorSizeB,
35  GroupedConvTraitsType_::VectorSizeC,
36  GroupedConvTraitsType_::NumGroupsToMerge>;
// Number of auxiliary D tensors; each constructor copies this many pointers into ds_ptr.
37  static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
38 
// 1-D constructor. It is enabled via enable_if only when the traits' layouts are
// NWGC (input), GKXC (weight) and NWGK (output). Its signature line (source line 47,
// `GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs &args)` per the
// Doxygen index) is missing from this listing.
39  template <
40  typename InLay = typename GroupedConvTraitsType_::InLayout,
41  typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
42  typename OutLay = typename GroupedConvTraitsType_::OutLayout,
43  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NWGC> &&
44  std::is_same_v<WeiLay, tensor_layout::convolution::GKXC> &&
45  std::is_same_v<OutLay, tensor_layout::convolution::NWGK>,
46  bool>::type = false>
48  {
// Narrow the host sizes to index_t and store them in the order
// [G, N, C, in_spatial0], [G, K, C, filter_spatial0] and [G, N, K, out_spatial0].
49  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
50  static_cast<index_t>(args.N_),
51  static_cast<index_t>(args.C_),
52  static_cast<index_t>(args.input_spatial_lengths_[0])};
53  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
54  static_cast<index_t>(args.K_),
55  static_cast<index_t>(args.C_),
56  static_cast<index_t>(args.filter_spatial_lengths_[0])};
57  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
58  static_cast<index_t>(args.N_),
59  static_cast<index_t>(args.K_),
60  static_cast<index_t>(args.output_spatial_lengths_[0])};
61 
62  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0])};
63  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0])};
64  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0])};
65  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0])};
66 
// Keep the raw tensor pointers. in/out/ds are read-only (const void*); wei is the destination.
67  in_ptr = args.in_ptr;
68  wei_ptr = args.wei_ptr;
69  for(index_t d = 0; d < NumDTensor; d++)
70  {
71  ds_ptr[d] = args.ds_ptr[d];
72  }
73  out_ptr = args.out_ptr;
74 
// The remaining transformer constructor arguments (source lines 76-81) are missing
// from this listing.
75  ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
82 
83  // tuple
84  auto grid_descs =
85  conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
86  GroupedConvTraitsType_::NDimSpatial>();
87 
// Tuple order is A (K x M), B (K x N), C (M x N), matching the member names.
88  a_grid_desc_k_m = grid_descs.at(number<0>{});
89  b_grid_desc_k_n = grid_descs.at(number<1>{});
90  c_grid_desc_m_n = grid_descs.at(number<2>{});
91 
// Each GEMM batch covers NumGroupsToMerge groups. The kernel multiplies these strides
// by blockIdx.y to locate a batch's A/B/C data.
92  NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
93  group_stride_a = args.K_ * NumGroupsPerBatch; // A: Out NWGK
94  group_stride_b = args.C_ * NumGroupsPerBatch; // B: In NWGC
// NOTE(review): source line 96, which joins `args.K_ * args.C_` to the accumulate
// (product of filter spatial lengths) below, is missing from this listing. It is
// presumably a multiplication, possibly also by NumGroupsPerBatch; confirm against the header.
95  group_stride_c = args.K_ * args.C_ // C: Wei GKXC
97  std::accumulate(args.filter_spatial_lengths_.begin(),
98  args.filter_spatial_lengths_.end(),
99  1,
100  std::multiplies<index_t>());
101 
// GEMM extents are read back from the descriptors: M and K from A, N from B.
// NOTE(review): GemmBatch is logged below but not assigned in the visible lines;
// its assignment is presumably on missing source line 105.
102  GemmM = a_grid_desc_k_m.get_length(number<1>{});
103  GemmN = b_grid_desc_k_n.get_length(number<1>{});
104  GemmK = a_grid_desc_k_m.get_length(number<0>{});
106 
// Split-K factor is copied as-is. A negative value requests automatic selection,
// which the kernel's host-side argument factory resolves (its `k_batch < 0` branch).
107  k_batch = args.k_batch;
108 
// Optional dump of the derived GEMM problem when CK_TILE_LOGGING is enabled.
109  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
110  {
111  std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
112  << ", GemmBatch: " << GemmBatch
113  << ", NumGroupsPerBatch: " << NumGroupsPerBatch << ", k_batch: " << k_batch
114  << std::endl;
115  }
116  }
117 
// 2-D constructor. It is enabled via enable_if only when the traits' layouts are
// NHWGC (input), GKYXC (weight) and NHWGK (output). Its signature line (source line 126)
// is missing from this listing; presumably it takes the same host-args parameter as
// the 1-D overload.
118  template <
119  typename InLay = typename GroupedConvTraitsType_::InLayout,
120  typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
121  typename OutLay = typename GroupedConvTraitsType_::OutLayout,
122  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NHWGC> &&
123  std::is_same_v<WeiLay, tensor_layout::convolution::GKYXC> &&
124  std::is_same_v<OutLay, tensor_layout::convolution::NHWGK>,
125  bool>::type = false>
127  {
// Narrow the host sizes to index_t. Order: G, N|K, C, then the two spatial extents
// in host-array index order.
128  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
129  static_cast<index_t>(args.N_),
130  static_cast<index_t>(args.C_),
131  static_cast<index_t>(args.input_spatial_lengths_[0]),
132  static_cast<index_t>(args.input_spatial_lengths_[1])};
133  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
134  static_cast<index_t>(args.K_),
135  static_cast<index_t>(args.C_),
136  static_cast<index_t>(args.filter_spatial_lengths_[0]),
137  static_cast<index_t>(args.filter_spatial_lengths_[1])};
138  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
139  static_cast<index_t>(args.N_),
140  static_cast<index_t>(args.K_),
141  static_cast<index_t>(args.output_spatial_lengths_[0]),
142  static_cast<index_t>(args.output_spatial_lengths_[1])};
143 
144  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
145  static_cast<index_t>(args.conv_filter_strides_[1])};
146  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
147  static_cast<index_t>(args.conv_filter_dilations_[1])};
148  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
149  static_cast<index_t>(args.input_left_pads_[1])};
150  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
151  static_cast<index_t>(args.input_right_pads_[1])};
152 
// Keep the raw tensor pointers. in/out/ds are read-only (const void*); wei is the destination.
153  in_ptr = args.in_ptr;
154  wei_ptr = args.wei_ptr;
155  for(index_t d = 0; d < NumDTensor; d++)
156  {
157  ds_ptr[d] = args.ds_ptr[d];
158  }
159  out_ptr = args.out_ptr;
160 
// The remaining transformer constructor arguments (source lines 162-167) are missing
// from this listing.
161  ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
168 
169  // tuple
170  auto grid_descs =
171  conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
172  GroupedConvTraitsType_::NDimSpatial>();
173 
// Tuple order is A (K x M), B (K x N), C (M x N), matching the member names.
174  a_grid_desc_k_m = grid_descs.at(number<0>{});
175  b_grid_desc_k_n = grid_descs.at(number<1>{});
176  c_grid_desc_m_n = grid_descs.at(number<2>{});
177 
// Each GEMM batch covers NumGroupsToMerge groups. The kernel multiplies these strides
// by blockIdx.y to locate a batch's A/B/C data.
178  NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
179  group_stride_a = args.K_ * NumGroupsPerBatch; // A: Out NHWGK
180  group_stride_b = args.C_ * NumGroupsPerBatch; // B: In NHWGC
// NOTE(review): source line 182, which joins `args.K_ * args.C_` to the accumulate
// (product of filter spatial lengths) below, is missing from this listing. It is
// presumably a multiplication, possibly also by NumGroupsPerBatch; confirm against the header.
181  group_stride_c = args.K_ * args.C_ // C: Wei GKYXC
183  std::accumulate(args.filter_spatial_lengths_.begin(),
184  args.filter_spatial_lengths_.end(),
185  1,
186  std::multiplies<index_t>());
187 
// GEMM extents are read back from the descriptors: M and K from A, N from B.
// NOTE(review): GemmBatch is logged below but not assigned in the visible lines;
// its assignment is presumably on missing source line 191.
188  GemmM = a_grid_desc_k_m.get_length(number<1>{});
189  GemmN = b_grid_desc_k_n.get_length(number<1>{});
190  GemmK = a_grid_desc_k_m.get_length(number<0>{});
192 
// Split-K factor is copied as-is. A negative value requests automatic selection,
// which the kernel's host-side argument factory resolves (its `k_batch < 0` branch).
193  k_batch = args.k_batch;
194 
// Optional dump of the derived GEMM problem when CK_TILE_LOGGING is enabled.
195  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
196  {
197  std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
198  << ", GemmBatch: " << GemmBatch
199  << ", NumGroupsPerBatch: " << NumGroupsPerBatch << ", k_batch: " << k_batch
200  << std::endl;
201  }
202  }
203 
// 3-D constructor. It is enabled via enable_if only when the traits' layouts are
// NDHWGC (input), GKZYXC (weight) and NDHWGK (output). Its signature line (source line 212)
// is missing from this listing; presumably it takes the same host-args parameter as
// the 1-D overload.
204  template <
205  typename InLay = typename GroupedConvTraitsType_::InLayout,
206  typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
207  typename OutLay = typename GroupedConvTraitsType_::OutLayout,
208  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NDHWGC> &&
209  std::is_same_v<WeiLay, tensor_layout::convolution::GKZYXC> &&
210  std::is_same_v<OutLay, tensor_layout::convolution::NDHWGK>,
211  bool>::type = false>
213  {
// Narrow the host sizes to index_t. Order: G, N|K, C, then the three spatial extents
// in host-array index order.
214  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
215  static_cast<index_t>(args.N_),
216  static_cast<index_t>(args.C_),
217  static_cast<index_t>(args.input_spatial_lengths_[0]),
218  static_cast<index_t>(args.input_spatial_lengths_[1]),
219  static_cast<index_t>(args.input_spatial_lengths_[2])};
220  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
221  static_cast<index_t>(args.K_),
222  static_cast<index_t>(args.C_),
223  static_cast<index_t>(args.filter_spatial_lengths_[0]),
224  static_cast<index_t>(args.filter_spatial_lengths_[1]),
225  static_cast<index_t>(args.filter_spatial_lengths_[2])};
226  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
227  static_cast<index_t>(args.N_),
228  static_cast<index_t>(args.K_),
229  static_cast<index_t>(args.output_spatial_lengths_[0]),
230  static_cast<index_t>(args.output_spatial_lengths_[1]),
231  static_cast<index_t>(args.output_spatial_lengths_[2])};
232 
233  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
234  static_cast<index_t>(args.conv_filter_strides_[1]),
235  static_cast<index_t>(args.conv_filter_strides_[2])};
236  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
237  static_cast<index_t>(args.conv_filter_dilations_[1]),
238  static_cast<index_t>(args.conv_filter_dilations_[2])};
239  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
240  static_cast<index_t>(args.input_left_pads_[1]),
241  static_cast<index_t>(args.input_left_pads_[2])};
242  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
243  static_cast<index_t>(args.input_right_pads_[1]),
244  static_cast<index_t>(args.input_right_pads_[2])};
245 
// Keep the raw tensor pointers. in/out/ds are read-only (const void*); wei is the destination.
246  in_ptr = args.in_ptr;
247  wei_ptr = args.wei_ptr;
248  for(index_t d = 0; d < NumDTensor; d++)
249  {
250  ds_ptr[d] = args.ds_ptr[d];
251  }
252  out_ptr = args.out_ptr;
253 
// The remaining transformer constructor arguments (source lines 255-260) are missing
// from this listing.
254  ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
261 
262  // tuple
263  auto grid_descs =
264  conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
265  GroupedConvTraitsType_::NDimSpatial>();
266 
// Tuple order is A (K x M), B (K x N), C (M x N), matching the member names.
267  a_grid_desc_k_m = grid_descs.at(number<0>{});
268  b_grid_desc_k_n = grid_descs.at(number<1>{});
269  c_grid_desc_m_n = grid_descs.at(number<2>{});
270 
// Each GEMM batch covers NumGroupsToMerge groups. The kernel multiplies these strides
// by blockIdx.y to locate a batch's A/B/C data.
271  NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
272  group_stride_a = args.K_ * NumGroupsPerBatch; // A: Out NDHWGK
273  group_stride_b = args.C_ * NumGroupsPerBatch; // B: In NDHWGC
// NOTE(review): source line 275, which joins `args.K_ * args.C_` to the accumulate
// (product of filter spatial lengths) below, is missing from this listing. It is
// presumably a multiplication, possibly also by NumGroupsPerBatch; confirm against the header.
274  group_stride_c = args.K_ * args.C_ // C: Wei GKZYXC
276  std::accumulate(args.filter_spatial_lengths_.begin(),
277  args.filter_spatial_lengths_.end(),
278  1,
279  std::multiplies<index_t>());
280 
// GEMM extents are read back from the descriptors: M and K from A, N from B.
// NOTE(review): GemmBatch is logged below but not assigned in the visible lines;
// its assignment is presumably on missing source line 284.
281  GemmM = a_grid_desc_k_m.get_length(number<1>{});
282  GemmN = b_grid_desc_k_n.get_length(number<1>{});
283  GemmK = a_grid_desc_k_m.get_length(number<0>{});
285 
// Split-K factor is copied as-is. A negative value requests automatic selection,
// which the kernel's host-side argument factory resolves (its `k_batch < 0` branch).
286  k_batch = args.k_batch;
287 
// Optional dump of the derived GEMM problem when CK_TILE_LOGGING is enabled.
288  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
289  {
290  std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
291  << ", GemmBatch: " << GemmBatch
292  << ", NumGroupsPerBatch: " << NumGroupsPerBatch << ", k_batch: " << k_batch
293  << std::endl;
294  }
295  }
296 
299 
303 
304  static constexpr index_t NonSpatialDims = 3;
308 
313 
320 
321  const void* out_ptr;
322  const void* in_ptr;
323  std::array<const void*, NumDTensor> ds_ptr;
324  void* wei_ptr;
325 
329 
333 };
334 
// GroupedConvolutionBackwardWeightKernel (struct name line 376 is missing from this
// listing; the name appears in the reflection and KernelImpl references inside it).
// It computes the weight gradient as an implicit GEMM, Weight = Out * In,
// with GEMM A = Out, B = In and C = Wei.
372 template <typename GroupedConvTraitsType_,
373  typename TilePartitioner_,
374  typename GemmPipeline_,
375  typename EpiloguePipeline_>
377 {
378  static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial;
// Source line 379 (the start of a declaration initialized from the traits'
// ConvSpecialization) is missing from this listing.
380  GroupedConvTraitsType_::ConvSpecialization;
// Missing source lines 381-393 and 398-405 presumably declare the type aliases used below
// (e.g. GemmPipeline, TilePartitioner, In/Wei/Out layouts and data types, GemmA/B/CLayout).
387 
392 
394  static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
395 
// Workgroup size comes from the GEMM pipeline. BlockSize() halves it on wave32 targets.
396  static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
397 
402 
405 
// Split-K is supported: k_batch slices map to gridDim.z (see GridSize / operator()).
406  static constexpr bool IsSplitKSupported = true;
407 
// Compile-time indices into the (A, B, Ds, C) tuples built by the helpers below.
408  static constexpr auto I0 = number<0>();
409  static constexpr auto I1 = number<1>();
410  static constexpr auto I2 = number<2>();
411  static constexpr auto I3 = number<3>();
412 
// Supported configuration:
// - the pipeline pads M, N and K;
// - A is column-major, B and C are row-major;
// - ExplicitGemm cannot be combined with merging groups.
413  static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
414  "Not supported!");
415  static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
416  static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
417  static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
418  static_assert(GroupedConvTraitsType_::ExplicitGemm == false ||
419  GroupedConvTraitsType_::NumGroupsToMerge == 1,
420  "Not supported!");
421 
// Builds a '_'-joined kernel identifier from:
// - the in/wei precision and the three layouts;
// - the GEMM and epilogue pipeline names;
// - the merged-group, split-image and explicit-GEMM flags.
// Source line 436 (one more concat argument between the epilogue name and
// "MergedGroups") is missing from this listing.
422  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
423  {
424  static constexpr bool EnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
425  constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
426  // clang-format off
427  return concat('_', "grouped_convolution_backward_weight",
428  gemm_prec_str<InDataType, WeiDataType>(),
429  InLayout::name,
430  WeiLayout::name,
431  OutLayout::name,
432  "gemm",
433  GemmPipeline::GetName(),
434  "epilogue",
435  EpiloguePipeline::GetName(),
437  "MergedGroups",
438  NumGroupsToMerge,
439  "SplitImage",
440  EnableSplitImage,
441  "ExplicitGemm",
442  GroupedConvTraitsType_::ExplicitGemm
443  );
444  // clang-format on
445  }
446 
// Alias of GetName(); returns the same kernel identifier string.
447  [[nodiscard]] CK_TILE_HOST static const std::string GetTypeString() { return GetName(); }
448 
// Only compiled with CK_EXPERIMENTAL_BUILDER. Returns the reflection-based instance
// string. The static_assert fails the build if no instance_traits specialization
// exists for this kernel's template arguments.
449 #ifdef CK_EXPERIMENTAL_BUILDER
450  CK_TILE_HOST std::string GetInstanceString() const
451  {
452  static_assert(ck_tile::reflect::HasInstanceTraits<GroupedConvolutionBackwardWeightKernel>,
453  "Specialization of instance_traits not found. Please check that a "
454  "specialization exists in file "
455  "ck_tile/builder/reflect/"
456  "instance_traits_tile_grouped_convolution_backward_weight.hpp "
457  "for the given template parameters.");
458  return ck_tile::reflect::instance_string<GroupedConvolutionBackwardWeightKernel>();
459  }
460 #endif
461 
// Launch grid for a set of kernel args (name/parameter line, source line 463, is missing):
//   x = number of output tiles the partitioner produces for GemmM x GemmN,
//   y = GemmBatch (operator() uses blockIdx.y to pick the group batch),
//   z = k_batch (operator() uses blockIdx.z to pick the split-K slice).
462  CK_TILE_HOST static constexpr auto
464  {
465  return dim3(
466  TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch);
467  }
468 
// Workgroup dimensions: kBlockSize threads, or half of that when is_wave32() is true.
// NOTE(review): presumably the pipeline's BlockSize assumes 64-lane waves; confirm.
469  CK_TILE_HOST static constexpr auto BlockSize()
470  {
471  return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
472  }
473 
// Host-side factory: builds the specialized device args from hostArgs. Its signature
// (source lines 474-475) is missing from this listing; the name MakeKernelArgs is the
// one referenced by the k_batch error message in the argument check.
// A negative k_batch is replaced by calculate_optimal_k_batch's choice. A k_batch of 0
// is passed through unchanged and is later rejected by the `k_batch < 1` check.
476  {
477  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
478  {
479  std::cout << "MPerBlock: " << number<TilePartitioner::MPerBlock>{} << std::endl;
480  std::cout << "NPerBlock: " << number<TilePartitioner::NPerBlock>{} << std::endl;
481  std::cout << "KPerBlock: " << number<TilePartitioner::KPerBlock>{} << std::endl;
482  }
483 
484  auto kernel_args = GroupedConvBwdWeightKernelArgsSpecialized(hostArgs);
485 
// The split-K heuristic is parameterized on this exact kernel instantiation.
486  using KernelImpl = GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType_,
487  TilePartitioner_,
488  GemmPipeline_,
489  EpiloguePipeline_>;
490 
491  // Negative k_batch value: split-K autodeduction.
492  if(kernel_args.k_batch < 0)
493  {
494  const auto optimal_split_k =
495  calculate_optimal_k_batch<GemmPipeline_::BlockSize, KernelImpl, TilePartitioner_>(
496  kernel_args);
497  kernel_args.k_batch = optimal_split_k;
498  }
499 
500  return kernel_args;
501  }
502 
// LDS bytes per buffer (signature line 503 missing; called as GetSmemSize() in operator()).
// The GEMM pipeline and the epilogue both receive smem_ptr_0, so the buffer must hold
// the larger of their requirements.
504  {
505  return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
506  }
507 
// Nested split-K helper. Its header (source line 508) and the first constructor parameter
// (line 510) are missing; k_id defaults to blockIdx.z. Lines 517-518, 522 and 530-532
// (offset computations, the non-last branch body and member declarations) are also missing.
509  {
511  const std::size_t k_id = blockIdx.z)
512  {
// K1 is the warp tile's K extent. KRead is ceil(GemmK / (k_batch * K1)) * K1, i.e. each
// split's share of K rounded up to a multiple of K1.
513  constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{});
514  const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1);
515  const index_t KRead = amd_wave_read_first_lane((kargs.GemmK + K_t - 1) / K_t * K1);
516 
519 
// Every split except the last presumably takes KRead; the last takes the remainder.
// NOTE(review): if KRead * (k_batch - 1) >= GemmK, the remainder is zero or negative.
// Example: GemmK = 10, K1 = 8, k_batch = 3 gives KRead = 8 and a last split of -6.
// Confirm that callers bound k_batch accordingly.
520  if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
521  {
523  }
524  else
525  {
526  splitted_k = amd_wave_read_first_lane(kargs.GemmK - KRead * (kargs.k_batch - 1));
527  }
528  }
529 
533  };
534 
// Host-side validation of kernel args. Returns false at the first unsupported condition.
// Several source lines are missing from this listing: 536 (the name/parameter line),
// 542 and 570 (CK_TILE_ERROR call openings), 579 (end of a constexpr condition),
// 595/611/626 (ConvSpecialization branch conditions) and 644/646 (ExplicitGemm condition
// and error call opening).
535  CK_TILE_HOST static bool
537  {
// k_batch must already be resolved to >= 1. Negative "auto" values are replaced by
// MakeKernelArgs before this check.
538  if(kargs.k_batch < 1)
539  {
540  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
541  {
543  "k_batch must be at least one. Ensure argument is created via MakeKernelArgs.");
544  }
545  return false;
546  }
547 
// An atomic-add epilogue is only accepted together with split-K (k_batch > 1).
548  if constexpr(EpiloguePipeline_::MemoryOperation == memory_operation_enum::atomic_add)
549  {
550  if(kargs.k_batch == 1)
551  {
552  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
553  {
554  CK_TILE_ERROR("Atomic add epilogue only supports k_batch > 1.");
555  }
556  return false;
557  }
558  }
559 
560  if constexpr(!std::is_same_v<typename EpiloguePipeline::ODataType, float> &&
561  !std::is_same_v<typename EpiloguePipeline::ODataType, double>)
562  {
563  // The epilogue performs atomic add related to split-K using the ODataType.
564  // If the type is less accurate than float, large split-K values may lead to
565  // accuracy issues. Hence, we limit the maximum split-K value to 128 in such cases.
566  if(kargs.k_batch > 128)
567  {
568  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
569  {
571  "For epilogue output data type that is not float/double, we must have "
572  "k_batch <= 128.");
573  }
574  return false;
575  }
576  }
577 
// For an odd VectorSizeC (the rest of this compile-time condition is on missing line 579),
// only k_batch == 1 is accepted.
578  if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
580  {
581  if(kargs.k_batch != 1)
582  {
583  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
584  {
585  CK_TILE_ERROR("Conditions not met for K_batch > 1!");
586  }
587  return false;
588  }
589  }
590 
// K and C come from the weight lengths [G, K, C, spatial...].
591  const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
592  const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];
593 
// NOTE(review): the three specialization checks below return false without any
// diagnostic, unlike the checks above.
594  // check ConvSpecialization
596  {
// Filter spatial extents start at index 3 of the weight lengths.
597  // check if it's 1x1, stride=1 conv
598  for(index_t i = 0; i < NDimSpatial; ++i)
599  {
600  const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
601  const index_t ConvStride = kargs.conv_filter_strides[i];
602  const index_t LeftPad = kargs.input_left_pads[i];
603  const index_t RightPad = kargs.input_right_pads[i];
604 
605  if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
606  {
607  return false;
608  }
609  }
610  }
612  {
// Same as above but without the stride constraint.
613  // check if it's 1x1 conv
614  for(index_t i = 0; i < NDimSpatial; ++i)
615  {
616  const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
617  const index_t LeftPad = kargs.input_left_pads[i];
618  const index_t RightPad = kargs.input_right_pads[i];
619 
620  if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0))
621  {
622  return false;
623  }
624  }
625  }
627  {
// This branch requires C == 1 and every filter spatial extent == 3.
// NOTE(review): I3 (number<3>) is used below both as an index offset and as the
// literal value 3; a plain `3` in the comparison would read more clearly.
628  if(ConvC != 1)
629  {
630  return false;
631  }
632  for(index_t i = 0; i < NDimSpatial; ++i)
633  {
634  const index_t filter_spatial_dim = kargs.wei_g_k_c_xs_lengths[i + I3];
635 
636  if(filter_spatial_dim != I3)
637  {
638  return false;
639  }
640  }
641  }
642 
// ExplicitGemm is only accepted with the Filter1x1Stride1Pad0 specialization (per the
// error text); the rest of the condition is on missing line 644.
643  if constexpr(GroupedConvTraitsType_::ExplicitGemm &&
645  {
647  "Explicit Gemm is supported only for Filter1x1Stride1Pad0 specialization!");
648  return false;
649  }
650 
// Layout and vector-width checks:
// - input (GEMM B): C must be divisible by VectorSizeB;
// - weight (GEMM C): C must be divisible by VectorSizeC;
// - output (GEMM A): K must be divisible by VectorSizeA.
// NOTE(review): these CK_TILE_ERROR calls are not gated by CK_TILE_LOGGING, unlike the
// checks above. The "load"/"store" wording also looks inherited from the forward
// kernel: here the weight is written and the output image is read.
651  namespace ctc = tensor_layout::convolution;
652 
653  if constexpr(std::is_same_v<InLayout, ctc::NWGC> || std::is_same_v<InLayout, ctc::NHWGC> ||
654  std::is_same_v<InLayout, ctc::NDHWGC>)
655  {
656  // Check access per C
657  if(ConvC % GroupedConvTraitsType_::VectorSizeB != 0)
658  {
659  CK_TILE_ERROR("Conv C is not a multiple of vector load size for "
660  "input image!");
661  return false;
662  }
663  }
664  else
665  {
666  CK_TILE_ERROR("Not supported input layout!");
667  return false;
668  }
669 
670  if constexpr(std::is_same_v<WeiLayout, ctc::GKXC> ||
671  std::is_same_v<WeiLayout, ctc::GKYXC> ||
672  std::is_same_v<WeiLayout, ctc::GKZYXC>)
673  {
674  if(ConvC % GroupedConvTraitsType_::VectorSizeC != 0)
675  {
676  CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!");
677  return false;
678  }
679  }
680  else
681  {
682  CK_TILE_ERROR("Not supported weight layout!");
683  return false;
684  }
685 
686  if constexpr(std::is_same_v<OutLayout, ctc::NWGK> ||
687  std::is_same_v<OutLayout, ctc::NHWGK> ||
688  std::is_same_v<OutLayout, ctc::NDHWGK>)
689  {
690  if(ConvK % GroupedConvTraitsType_::VectorSizeA != 0)
691  {
692  CK_TILE_ERROR("Conv K is not a multiple of vector store size "
693  "for output image!");
694  return false;
695  }
696  }
697  else
698  {
699  CK_TILE_ERROR("Not supported output layout!");
700  return false;
701  }
702 
// When groups are merged into one GEMM batch, G must split evenly into batches.
703  if constexpr(GroupedConvTraitsType_::NumGroupsToMerge > 1)
704  {
705  const index_t ConvG = kargs.wei_g_k_c_xs_lengths[number<0>{}];
706  if(ConvG % GroupedConvTraitsType_::NumGroupsToMerge != 0)
707  {
708  CK_TILE_ERROR("ConvG must be a multiple of NumGroupsToMerge!");
709  return false;
710  }
711  }
712 
713  return true;
714  }
715 
// Wraps the raw GEMM pointers in global-memory tensor views using the precomputed
// descriptors:
// - A = out (a_grid_desc_k_m);
// - B = in (b_grid_desc_k_n);
// - C = wei (c_grid_desc_m_n), written with DstInMemOp;
// - each D tensor reuses C's descriptor.
// Returns (A, B, Ds, C). Source lines 718 (name + a_ptr parameter), 722 (kargs parameter)
// and 753 (the generate_tuple count) are missing from this listing.
716  template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
717  CK_TILE_DEVICE static auto
719  const InDataType* b_ptr,
720  const std::array<const void*, NumDTensor>& ds_ptr,
721  WeiDataType* c_ptr,
723  {
724  static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
725  static_assert(!GemmPipeline::BlockGemmShape::PermuteB, "Not implemented!");
726  const auto& a_tensor_view = [&]() {
727  return make_tensor_view<address_space_enum::global>(a_ptr,
728  kargs.a_grid_desc_k_m); // A: out
729  }();
730 
731  const auto& b_tensor_view = [&]() {
732  return make_tensor_view<address_space_enum::global>(b_ptr,
733  kargs.b_grid_desc_k_n); // B: in
734  }();
735 
736  const auto& c_tensor_view = [&]() {
737  return make_tensor_view<address_space_enum::global, DstInMemOp>(c_ptr,
738  kargs.c_grid_desc_m_n);
739  }();
740 
// NOTE(review): two issues that only surface once NumDTensor > 0:
// (1) Each D is asserted to have OutLayout, yet it is viewed through the weight (GEMM C)
//     descriptor with WeiDataType; WeiLayout may be the intended check.
// (2) `static_cast<WeiDataType*>(ds_ptr[i])` converts a `const void*` to a non-const
//     pointer, which static_cast does not allow (it cannot cast away const).
//     `static_cast<const WeiDataType*>` is presumably intended.
741  const auto& ds_tensor_view = generate_tuple(
742  [&](auto i) {
743  static_assert(std::is_same_v<std::tuple_element_t<i, DsLayout>, OutLayout>,
744  "Not supported!");
745  static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
746  "Not supported!");
747  static_assert(std::is_same_v<std::tuple_element_t<i, DsDataType>, WeiDataType>,
748  "Not supported!");
749 
750  return make_tensor_view<address_space_enum::global>(
751  static_cast<WeiDataType*>(ds_ptr[i]), kargs.c_grid_desc_m_n);
752  },
754 
755  return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view);
756  }
757 
// Applies pad_tensor_view to each of the (A, B, Ds, C) views produced by
// MakeGemmTensorViews and returns them in the same order. The tile lengths and
// padding flags passed to each call (source lines 764-766, 772-774, 781-783, 785,
// 790-792) are missing from this listing.
// NOTE(review): k_batch is not referenced in the visible lines; presumably it is used
// in the missing tile-length arguments.
758  template <typename TensorView>
759  CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views, const index_t k_batch)
760  {
761  const auto& a_pad_view = [&]() {
762  const auto& a_tensor_view = views.at(I0);
763  return pad_tensor_view(a_tensor_view,
767  }();
768 
769  const auto& b_pad_view = [&]() {
770  const auto& b_tensor_view = views.at(I1);
771  return pad_tensor_view(b_tensor_view,
775  }();
776 
777  const auto& ds_tensor_view = views.at(I2);
778  const auto& ds_pad_view = generate_tuple(
779  [&](auto i) {
780  return pad_tensor_view(ds_tensor_view[i],
784  },
786 
787  const auto& c_pad_view = [&]() {
788  const auto& c_tensor_view = views.at(I3);
789  return pad_tensor_view(c_tensor_view,
793  }();
794 
795  return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view);
796  }
797 
// Creates per-workgroup tile windows over the padded (A, B, Ds, C) views.
// A and B are K-major, with origins {i_k, i_m} and {i_k, i_n}. Ds and C have origin
// {i_m, i_n}. The window-length arguments (source lines 821-822, 828-829, 836-837, 840,
// 844) are missing from this listing.
808  template <typename PadView>
809  CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
810  const index_t i_m,
811  const index_t i_n,
812  const index_t i_k)
813  {
814  const auto& a_pad_view = views.at(I0);
815  const auto& b_pad_view = views.at(I1);
816  const auto& ds_pad_view = views.at(I2);
817  const auto& c_pad_view = views.at(I3);
818 
819  const auto& a_block_window = [&]() {
820  return make_tile_window(a_pad_view,
823  {i_k, i_m});
824  }();
825 
826  const auto& b_block_window = [&]() {
827  return make_tile_window(b_pad_view,
830  {i_k, i_n});
831  }();
832 
833  const auto ds_block_window = generate_tuple(
834  [&](auto i) {
835  return make_tile_window(ds_pad_view[i],
838  {i_m, i_n});
839  },
841 
842  auto c_block_window = make_tile_window(
843  c_pad_view,
845  {i_m, i_n});
846 
847  return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window);
848  }
849 
// Single-LDS-buffer path: computes one C tile (weight gradient) from A (out) and B (in)
// over num_loop K iterations starting at block_idx_k, then runs the epilogue.
// The GEMM pipeline and the epilogue share smem_ptr_0. Source line 867 (the kargs
// parameter) is missing from this listing.
862  CK_TILE_DEVICE static void RunGemm(const OutDataType* a_ptr,
863  const InDataType* b_ptr,
864  const std::array<const void*, NumDTensor>& ds_ptr,
865  WeiDataType* c_ptr,
866  void* smem_ptr_0,
868  const index_t num_loop,
869  const index_t block_idx_m,
870  const index_t block_idx_n,
871  const index_t block_idx_k)
872  {
873  // Create Gemm tensor views, pad views and tile windows
874  const auto& gemm_tensor_views_tuple =
875  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
876  a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
877 
878  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple, kargs.k_batch);
879  auto gemm_tile_windows =
880  MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
881 
882  // Run GEMM cooperatively by whole workgroup.
883  const auto& a_block_window = gemm_tile_windows.at(I0);
884  const auto& b_block_window = gemm_tile_windows.at(I1);
885  const auto& d_block_window = gemm_tile_windows.at(I2);
886 
887  const auto& c_block_tile = GemmPipeline{}.template operator()(
888  a_block_window, b_block_window, num_loop, smem_ptr_0);
889 
// The epilogue writes the tile through the C window using the memory operation
// chosen in MakeGemmTensorViews (EpiloguePipeline::MemoryOperation).
890  // Run Epilogue Pipeline
891  auto& c_block_window = gemm_tile_windows.at(I3);
892 
893  EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
894  c_block_window, c_block_tile, d_block_window, smem_ptr_0);
895  }
896 
// Double-LDS-buffer variant of RunGemm. The GEMM pipeline receives both smem_ptr_0 and
// smem_ptr_1, while the epilogue uses only smem_ptr_0. Source line 918 (the kargs
// parameter) is missing from this listing.
912  CK_TILE_DEVICE static void RunGemm2LDS(const OutDataType* a_ptr,
913  const InDataType* b_ptr,
914  const std::array<const void*, NumDTensor>& ds_ptr,
915  WeiDataType* c_ptr,
916  void* __restrict__ smem_ptr_0,
917  void* __restrict__ smem_ptr_1,
919  const index_t num_loop,
920  const index_t block_idx_m,
921  const index_t block_idx_n,
922  const index_t block_idx_k)
923  {
924  // Create Gemm tensor views, pad views and tile windows
925  const auto& gemm_tensor_views_tuple =
926  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
927  a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
928  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple, kargs.k_batch);
929  auto gemm_tile_windows =
930  MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
931 
932  // Run GEMM cooperatively by whole workgroup.
933  const auto& a_block_window = gemm_tile_windows.at(I0);
934  const auto& b_block_window = gemm_tile_windows.at(I1);
935  const auto& d_block_window = gemm_tile_windows.at(I2);
936 
937  const auto& c_block_tile = GemmPipeline{}.template operator()(
938  a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
939 
940  // Run Epilogue Pipeline
941  auto& c_block_window = gemm_tile_windows.at(I3);
942 
943  EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
944  c_block_window, c_block_tile, d_block_window, smem_ptr_0);
945  }
946 
// ExplicitGemm path (signature line 947 missing; invoked as CallExplicitGemm(kargs)).
// It delegates the whole problem to a batched GEMM kernel whose type is on missing
// source line 951. D tensors are not supported here.
// NOTE(review): the field order of BatchedGemmKernelArgs is defined elsewhere. The values
// passed appear to be, in order:
// - out/in/(no D)/wei pointers;
// - GemmM, GemmN, GemmK;
// - A/B leading strides scaled by GemmBatch, then C stride GemmN and k_batch;
// - batch strides GemmM, GemmN, GemmM*GemmN;
// - GemmBatch batches.
// Confirm this mapping against batched_gemm_kernel.hpp.
948  {
949  static_assert(NumDTensor == 0, "Not supported!");
950  using ExplicitBatchedGemmKernel =
952  const auto batched_gemm_kargs = typename ExplicitBatchedGemmKernel::BatchedGemmKernelArgs{
953  {{kargs.out_ptr},
954  {kargs.in_ptr},
955  {},
956  kargs.wei_ptr,
957  kargs.GemmM,
958  kargs.GemmN,
959  kargs.GemmK,
960  {kargs.GemmM * kargs.GemmBatch},
961  {kargs.GemmN * kargs.GemmBatch},
962  {},
963  kargs.GemmN,
964  kargs.k_batch},
965  kargs.GemmM,
966  kargs.GemmN,
967  kargs.GemmM * kargs.GemmN,
968  kargs.GemmBatch};
969  ExplicitBatchedGemmKernel{}(batched_gemm_kargs);
970  }
971 
// Kernel entry point; its signature (source line 972) is missing from this listing.
// With ExplicitGemm the work goes to CallExplicitGemm. Otherwise each workgroup computes
// one MPerBlock x NPerBlock output tile for the group batch selected by blockIdx.y and
// the split-K slice selected by blockIdx.z.
973  {
974  if constexpr(GroupedConvTraitsType_::ExplicitGemm)
975  {
976  CallExplicitGemm(kargs);
977  }
978  else
979  {
// blockIdx.x -> (iM, iN) tile index via the partitioner, scaled to element offsets.
980  const auto blockIdX = amd_wave_read_first_lane(blockIdx.x);
981  const auto [iM, iN] =
982  TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(blockIdX);
983  const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
984  const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
985 
// blockIdx.z selects the split-K slice. `num_loop` is declared on missing source line 987.
// The visible arguments suggest num_loop = ceil(GemmK / (k_batch * KPerBlock)) K-iterations
// per slice (confirm), and each slice starts num_loop * KPerBlock elements after the previous.
// NOTE(review): for trailing slices the start can already lie past GemmK
// (e.g. GemmK = 100, KPerBlock = 32, k_batch = 8 gives slice 4 starting at 128). This
// presumably relies on the kPadK padding required by the struct's static_assert; confirm
// that out-of-range K tiles contribute nothing.
986  const auto blockIdZ = amd_wave_read_first_lane(blockIdx.z);
988  kargs.GemmK, kargs.k_batch * TilePartitioner::KPerBlock));
989  const index_t i_k =
990  amd_wave_read_first_lane(blockIdZ * num_loop * TilePartitioner::KPerBlock);
991 
// blockIdx.y selects a batch of merged groups; offset each tensor by its group stride.
992  const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);
993  const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
994  const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
995  const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);
996 
997  // options
998  // conv_bwd_weight = Out * In = Weight
999  const OutDataType* a_ptr =
1000  static_cast<const OutDataType*>(kargs.out_ptr) + group_offset_a;
1001  const InDataType* b_ptr = static_cast<const InDataType*>(kargs.in_ptr) + group_offset_b;
1002  WeiDataType* c_ptr = static_cast<WeiDataType*>(kargs.wei_ptr) + group_offset_c;
1003 
// Statically sized LDS buffer(s) of GetSmemSize() bytes; a second buffer is added
// when the pipeline double-buffers.
1004  __shared__ char smem_ptr_0[GetSmemSize()];
1005 
1006  if constexpr(GemmPipeline::DoubleSmemBuffer == true)
1007  {
1008  __shared__ char smem_ptr_1[GetSmemSize()];
// The GEMM is compiled out for a specific combination of epilogue MemoryOperation and
// odd VectorSizeC (parts of the condition are on missing lines 1010/1012). This
// presumably mirrors the k_batch != 1 rejection in the argument check; confirm.
1009  if constexpr(!(EpiloguePipeline::MemoryOperation ==
1011  GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
1013  {
1014  RunGemm2LDS(a_ptr,
1015  b_ptr,
1016  kargs.ds_ptr,
1017  c_ptr,
1018  smem_ptr_0,
1019  smem_ptr_1,
1020  kargs,
1021  num_loop,
1022  i_m,
1023  i_n,
1024  i_k);
1025  }
1026  }
1027  else
1028  {
// Same compile-time exclusion as above (condition continues on missing lines 1030/1032).
1029  if constexpr(!(EpiloguePipeline::MemoryOperation ==
1031  GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
1033  {
1034  RunGemm(a_ptr,
1035  b_ptr,
1036  kargs.ds_ptr,
1037  c_ptr,
1038  smem_ptr_0,
1039  kargs,
1040  num_loop,
1041  i_m,
1042  i_n,
1043  i_k);
1044  }
1045  }
1046  }
1047  }
1048 };
1049 
1050 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:49
#define CK_TILE_HOST
Definition: config.hpp:48
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:50
Definition: cluster_descriptor.hpp:13
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:156
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:145
void CK_TILE_ERROR(Args &&... args) noexcept
Definition: env.hpp:12
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:35
ConvolutionSpecialization
Definition: convolution_specialization.hpp:11
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:545
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
int64_t long_index_t
Definition: integer.hpp:11
CK_TILE_HOST std::string getConvSpecializationString(const ConvolutionSpecialization &s)
Definition: convolution_specialization.hpp:18
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:157
constexpr bool is_same_v
Definition: type.hpp:283
__device__ X atomic_add(X *p_dst, const X &x)
unsigned int uint32_t
Definition: stdint.h:126
Definition: batched_gemm_kernel.hpp:62
The Grouped Convolution kernel device arguments.
Definition: grouped_convolution_backward_weight_kernel.hpp:28
long_index_t group_stride_a
Definition: grouped_convolution_backward_weight_kernel.hpp:330
remove_cvref_t< decltype(ConvToGemmTransformer{}.MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N())> ABCGridDescs
Definition: grouped_convolution_backward_weight_kernel.hpp:298
array< index_t, GroupedConvTraitsType_::NDimSpatial > conv_filter_strides
Definition: grouped_convolution_backward_weight_kernel.hpp:309
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > wei_g_k_c_xs_lengths
Definition: grouped_convolution_backward_weight_kernel.hpp:306
void * wei_ptr
Definition: grouped_convolution_backward_weight_kernel.hpp:324
long_index_t group_stride_b
Definition: grouped_convolution_backward_weight_kernel.hpp:331
CGridDescMN c_grid_desc_m_n
Definition: grouped_convolution_backward_weight_kernel.hpp:328
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > in_g_n_c_wis_lengths
Definition: grouped_convolution_backward_weight_kernel.hpp:305
array< index_t, GroupedConvTraitsType_::NDimSpatial > conv_filter_dilations
Definition: grouped_convolution_backward_weight_kernel.hpp:310
AGridDescKM a_grid_desc_k_m
Definition: grouped_convolution_backward_weight_kernel.hpp:326
BGridDescKN b_grid_desc_k_n
Definition: grouped_convolution_backward_weight_kernel.hpp:327
index_t GemmN
Definition: grouped_convolution_backward_weight_kernel.hpp:316
index_t GemmBatch
Definition: grouped_convolution_backward_weight_kernel.hpp:318
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > out_g_n_k_wos_lengths
Definition: grouped_convolution_backward_weight_kernel.hpp:307
CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs &args)
Definition: grouped_convolution_backward_weight_kernel.hpp:47
array< index_t, GroupedConvTraitsType_::NDimSpatial > input_left_pads
Definition: grouped_convolution_backward_weight_kernel.hpp:311
remove_cvref_t< decltype(ABCGridDescs{}[number< 1 >{}])> BGridDescKN
Definition: grouped_convolution_backward_weight_kernel.hpp:301
std::array< const void *, NumDTensor > ds_ptr
Definition: grouped_convolution_backward_weight_kernel.hpp:323
index_t GemmM
Definition: grouped_convolution_backward_weight_kernel.hpp:315
index_t NumGroupsPerBatch
Definition: grouped_convolution_backward_weight_kernel.hpp:319
remove_cvref_t< decltype(ABCGridDescs{}[number< 2 >{}])> CGridDescMN
Definition: grouped_convolution_backward_weight_kernel.hpp:302
array< index_t, GroupedConvTraitsType_::NDimSpatial > input_right_pads
Definition: grouped_convolution_backward_weight_kernel.hpp:312
index_t GemmK
Definition: grouped_convolution_backward_weight_kernel.hpp:317
const void * in_ptr
Definition: grouped_convolution_backward_weight_kernel.hpp:322
index_t k_batch
Definition: grouped_convolution_backward_weight_kernel.hpp:314
static constexpr index_t NonSpatialDims
Definition: grouped_convolution_backward_weight_kernel.hpp:304
const void * out_ptr
Definition: grouped_convolution_backward_weight_kernel.hpp:321
remove_cvref_t< decltype(ABCGridDescs{}[number< 0 >{}])> AGridDescKM
Definition: grouped_convolution_backward_weight_kernel.hpp:300
static constexpr index_t NumDTensor
Definition: grouped_convolution_backward_weight_kernel.hpp:37
long_index_t group_stride_c
Definition: grouped_convolution_backward_weight_kernel.hpp:332
The Grouped Conv kernel host arguments.
Definition: grouped_convolution_utils.hpp:27
InPtr in_ptr
Definition: grouped_convolution_utils.hpp:46
OutPtr out_ptr
Definition: grouped_convolution_utils.hpp:49
WeiPtr wei_ptr
Definition: grouped_convolution_utils.hpp:47
index_t k_batch
Definition: grouped_convolution_utils.hpp:50
const std::vector< const void * > ds_ptr
Definition: grouped_convolution_utils.hpp:48
Definition: grouped_convolution_backward_weight_kernel.hpp:509
index_t b_k_split_offset
Definition: grouped_convolution_backward_weight_kernel.hpp:531
index_t splitted_k
Definition: grouped_convolution_backward_weight_kernel.hpp:532
__device__ SplitKBatchOffset(const GroupedConvBwdWeightKernelArgsSpecialized &kargs, const std::size_t k_id=blockIdx.z)
Definition: grouped_convolution_backward_weight_kernel.hpp:510
index_t a_k_split_offset
Definition: grouped_convolution_backward_weight_kernel.hpp:530
The Grouped Convolution Backward Weight kernel template.
Definition: grouped_convolution_backward_weight_kernel.hpp:377
remove_cvref_t< typename EpiloguePipeline::DsLayout > GemmDsLayout
Definition: grouped_convolution_backward_weight_kernel.hpp:393
static constexpr index_t kBlockSize
Definition: grouped_convolution_backward_weight_kernel.hpp:396
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views, const index_t k_batch)
Definition: grouped_convolution_backward_weight_kernel.hpp:759
static constexpr CK_TILE_HOST auto BlockSize()
Definition: grouped_convolution_backward_weight_kernel.hpp:469
remove_cvref_t< typename GroupedConvTraitsType_::OutLayout > OutLayout
Definition: grouped_convolution_backward_weight_kernel.hpp:390
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: grouped_convolution_backward_weight_kernel.hpp:381
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition: grouped_convolution_backward_weight_kernel.hpp:382
GroupedConvBwdWeightKernelArgs< GroupedConvTraitsType_ > GroupedConvBwdWeightKernelArgsSpecialized
Definition: grouped_convolution_backward_weight_kernel.hpp:404
static CK_TILE_HOST const std::string GetName()
Definition: grouped_convolution_backward_weight_kernel.hpp:422
static CK_TILE_DEVICE void RunGemm2LDS(const OutDataType *a_ptr, const InDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, WeiDataType *c_ptr, void *__restrict__ smem_ptr_0, void *__restrict__ smem_ptr_1, const GroupedConvBwdWeightKernelArgsSpecialized &kargs, const index_t num_loop, const index_t block_idx_m, const index_t block_idx_n, const index_t block_idx_k)
Runs single GEMM problem cooperatively by whole workgroup.
Definition: grouped_convolution_backward_weight_kernel.hpp:912
remove_cvref_t< typename GemmPipeline::CLayout > GemmCLayout
Definition: grouped_convolution_backward_weight_kernel.hpp:386
static constexpr auto I2
Definition: grouped_convolution_backward_weight_kernel.hpp:410
static constexpr CK_TILE_HOST GroupedConvBwdWeightKernelArgsSpecialized MakeKernelArgs(const GroupedConvBwdWeightHostArgs &hostArgs)
Definition: grouped_convolution_backward_weight_kernel.hpp:475
CK_TILE_DEVICE void operator()(GroupedConvBwdWeightKernelArgsSpecialized &kargs) const
Definition: grouped_convolution_backward_weight_kernel.hpp:972
remove_cvref_t< typename GemmPipeline::ALayout > GemmALayout
Definition: grouped_convolution_backward_weight_kernel.hpp:384
static CK_TILE_HOST bool IsSupportedArgument(const GroupedConvBwdWeightKernelArgsSpecialized &kargs)
Definition: grouped_convolution_backward_weight_kernel.hpp:536
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: grouped_convolution_backward_weight_kernel.hpp:383
CK_TILE_DEVICE void CallExplicitGemm(GroupedConvBwdWeightKernelArgsSpecialized &kargs) const
Definition: grouped_convolution_backward_weight_kernel.hpp:947
static constexpr ConvolutionSpecialization ConvSpecialization
Definition: grouped_convolution_backward_weight_kernel.hpp:379
remove_cvref_t< typename GroupedConvTraitsType_::WeiLayout > WeiLayout
Definition: grouped_convolution_backward_weight_kernel.hpp:389
static constexpr bool IsSplitKSupported
Definition: grouped_convolution_backward_weight_kernel.hpp:406
static CK_TILE_DEVICE void RunGemm(const OutDataType *a_ptr, const InDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, WeiDataType *c_ptr, void *smem_ptr_0, const GroupedConvBwdWeightKernelArgsSpecialized &kargs, const index_t num_loop, const index_t block_idx_m, const index_t block_idx_n, const index_t block_idx_k)
Runs single GEMM problem cooperatively by whole workgroup.
Definition: grouped_convolution_backward_weight_kernel.hpp:862
static constexpr index_t NDimSpatial
Definition: grouped_convolution_backward_weight_kernel.hpp:378
remove_cvref_t< typename GroupedConvTraitsType_::DsLayout > DsLayout
Definition: grouped_convolution_backward_weight_kernel.hpp:391
remove_cvref_t< typename GroupedConvTraitsType_::InLayout > InLayout
Definition: grouped_convolution_backward_weight_kernel.hpp:388
remove_cvref_t< typename EpiloguePipeline::ODataType > WeiDataType
Definition: grouped_convolution_backward_weight_kernel.hpp:401
static constexpr auto I3
Definition: grouped_convolution_backward_weight_kernel.hpp:411
static constexpr auto I0
Definition: grouped_convolution_backward_weight_kernel.hpp:408
static constexpr auto I1
Definition: grouped_convolution_backward_weight_kernel.hpp:409
remove_cvref_t< typename EpiloguePipeline::DsDataType > DsDataType
Definition: grouped_convolution_backward_weight_kernel.hpp:400
static constexpr index_t NumDTensor
Definition: grouped_convolution_backward_weight_kernel.hpp:394
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: grouped_convolution_backward_weight_kernel.hpp:503
remove_cvref_t< typename GemmPipeline::ADataType > OutDataType
Definition: grouped_convolution_backward_weight_kernel.hpp:398
static CK_TILE_DEVICE auto MakeGemmTensorViews(const OutDataType *a_ptr, const InDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, WeiDataType *c_ptr, const GroupedConvBwdWeightKernelArgsSpecialized &kargs)
Definition: grouped_convolution_backward_weight_kernel.hpp:718
remove_cvref_t< typename GemmPipeline::BDataType > InDataType
Definition: grouped_convolution_backward_weight_kernel.hpp:399
remove_cvref_t< typename GemmPipeline::BLayout > GemmBLayout
Definition: grouped_convolution_backward_weight_kernel.hpp:385
static constexpr CK_TILE_HOST auto GridSize(const GroupedConvBwdWeightKernelArgsSpecialized &kargs)
Definition: grouped_convolution_backward_weight_kernel.hpp:463
static CK_TILE_HOST const std::string GetTypeString()
Definition: grouped_convolution_backward_weight_kernel.hpp:447
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n, const index_t i_k)
Create views to the data that each workgroup will process.
Definition: grouped_convolution_backward_weight_kernel.hpp:809
Definition: transform_conv_bwd_weight_to_gemm.hpp:21
CK_TILE_HOST auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N() const
Definition: transform_conv_bwd_weight_to_gemm.hpp:818
Definition: integral_constant.hpp:13
std::vector< ck_tile::long_index_t > input_spatial_lengths_
Definition: convolution_parameter.hpp:130
ck_tile::long_index_t K_
Definition: convolution_parameter.hpp:126
std::vector< ck_tile::long_index_t > output_spatial_lengths_
Definition: convolution_parameter.hpp:131
std::vector< ck_tile::long_index_t > input_right_pads_
Definition: convolution_parameter.hpp:137
ck_tile::long_index_t G_
Definition: convolution_parameter.hpp:124
std::vector< ck_tile::long_index_t > conv_filter_strides_
Definition: convolution_parameter.hpp:133
std::vector< ck_tile::long_index_t > filter_spatial_lengths_
Definition: convolution_parameter.hpp:129
ck_tile::long_index_t C_
Definition: convolution_parameter.hpp:127
ck_tile::long_index_t N_
Definition: convolution_parameter.hpp:125
std::vector< ck_tile::long_index_t > input_left_pads_
Definition: convolution_parameter.hpp:136
std::vector< ck_tile::long_index_t > conv_filter_dilations_
Definition: convolution_parameter.hpp:134
Definition: type_traits.hpp:115
Definition: sequence.hpp:49
#define CK_TILE_ENV(name)
Definition: env.hpp:145