/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp Source File
grouped_convolution_backward_data_kernel.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <string>
8 
9 #include "ck_tile/core.hpp"
10 #include "ck_tile/ops/common.hpp"
11 #include "ck_tile/host/concat.hpp"
16 
17 #ifdef CK_EXPERIMENTAL_BUILDER
18 #include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_data.hpp"
19 #endif
20 
21 namespace ck_tile {
22 
// Device-side argument pack for the grouped convolution backward-data kernel.
// NOTE(review): the Doxygen extraction dropped several original lines here
// (23, 25, 27, 29), including the struct declaration itself — the footer
// index names it `GroupedConvBwdDataKernelArgs` ("The Grouped Convolution
// kernel device arguments") — and the head of a `using` alias (presumably
// `using ConvToGemmTransformer =`) that the TransformConvBwdDataToGemm<...>
// argument list below belongs to. Confirm against the repository.
24 template <typename GroupedConvTraitsType_, typename TilePartitioner_>
26 {
28 
// Conv-to-GEMM transform shared by all constructor overloads below; the
// trailing `true` template argument enables splitting the batch dimension N.
30  TransformConvBwdDataToGemm<GroupedConvTraitsType_::NDimSpatial,
31  GroupedConvTraitsType_::ConvSpecialization,
32  GroupedConvTraitsType_::VectorSizeA,
33  GroupedConvTraitsType_::VectorSizeB,
34  GroupedConvTraitsType_::VectorSizeC,
35  true>; // Split N enabled
36  static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
37 
// Compile-time indices used for descriptor get_length() queries.
38  static constexpr auto I0 = number<0>();
39  static constexpr auto I1 = number<1>();
// Constructor overload for 1D convolution with NWGC / GKXC / NWGK layouts
// (SFINAE-selected on the traits' layout types). Backward-data is decomposed
// into up to XTilde independent GEMMs — one per filter-tap residue class of
// the stride/dilation lattice — and each GEMM's grid size is accumulated
// into grid_size_.
// NOTE(review): the extraction dropped the signature line (original line
// 49); per the footer index it is
//   CK_TILE_HOST GroupedConvBwdDataKernelArgs(const GroupedConvBwdDataHostArgs& args)
41  template <
42  typename InLay = typename GroupedConvTraitsType_::InLayout,
43  typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
44  typename OutLay = typename GroupedConvTraitsType_::OutLayout,
45  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NWGC> &&
46  std::is_same_v<WeiLay, tensor_layout::convolution::GKXC> &&
47  std::is_same_v<OutLay, tensor_layout::convolution::NWGK>,
48  bool>::type = false>
50  {
// Repack host-side conv dimensions into G-major length arrays.
51  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
52  static_cast<index_t>(args.N_),
53  static_cast<index_t>(args.C_),
54  static_cast<index_t>(args.input_spatial_lengths_[0])};
55  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
56  static_cast<index_t>(args.K_),
57  static_cast<index_t>(args.C_),
58  static_cast<index_t>(args.filter_spatial_lengths_[0])};
59  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
60  static_cast<index_t>(args.N_),
61  static_cast<index_t>(args.K_),
62  static_cast<index_t>(args.output_spatial_lengths_[0])};
63 
64  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0])};
65  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0])};
66  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0])};
67  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0])};
68 
69  k_batch = args.k_batch;
70 
71  in_ptr = args.in_ptr;
72  wei_ptr = args.wei_ptr;
73  for(index_t d = 0; d < NumDTensor; d++)
74  {
75  ds_ptr[d] = args.ds_ptr[d];
76  }
77  out_ptr = args.out_ptr;
78 
// XTilde = number of residue classes of filter taps w.r.t. stride/dilation;
// each class becomes one GEMM of the backward-data decomposition.
79  const index_t X = wei_g_k_c_xs_lengths[3];
80  const index_t ConvStrideW = conv_filter_strides[0];
81  const index_t ConvDilationW = conv_filter_dilations[0];
82  const auto GcdStrideDilationW = gcd(ConvStrideW, ConvDilationW);
83  const auto XTilde = ConvStrideW / GcdStrideDilationW;
84 
85  for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
86  {
// Number of filter taps that fall into this residue class; empty slices
// contribute no GEMM.
87  const auto XDotSlice = integer_divide_ceil(X - i_xtilde, XTilde);
88 
89  if(XDotSlice <= 0)
90  {
91  continue;
92  }
93 
// NOTE(review): the guard condition on the dropped line 94 is not visible;
// presumably it checks gemm_count against MaxGroupedGemmGroupsNum (the
// comment below says "Avoid array segfault") — confirm against the repo.
95  {
96  gemm_count++;
97  // Avoid array segfault
98  continue;
99  }
100 
101  tildes = {i_xtilde};
102 
// NOTE(review): ctor argument lines 104-109 were dropped by the extraction;
// presumably the remaining length/stride/dilation/pad arrays are passed
// between in_g_n_c_wis_lengths and tildes — confirm against the repo.
103  ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
110  tildes};
111 
112  auto grid_descs =
113  conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
114  GroupedConvTraitsType_::NDimSpatial>(1);
115 
116  a_grid_descs_m_k[gemm_count] = grid_descs.at(number<0>{});
117  b_grid_descs_n_k[gemm_count] = grid_descs.at(number<1>{});
118  c_grid_descs_m_n[gemm_count] = grid_descs.at(number<2>{});
119 
120  const index_t grid_size_grp =
121  TilePartitioner::GridSize(c_grid_descs_m_n[gemm_count].get_length(I0),
122  c_grid_descs_m_n[gemm_count].get_length(I1));
123 
// NOTE(review): dropped line 124 presumably records
// block_starts[gemm_count] = grid_size_; — confirm against the repo.
125  block_ends[gemm_count] = grid_size_ + grid_size_grp;
126 
127  grid_size_ += grid_size_grp;
128 
129  // Get the actual split N from transformer
130  n_per_split = conv_to_gemm_transformer.GetN();
131  original_n = conv_to_gemm_transformer.GetOriginalN();
133 
134  ++gemm_count;
135  }
// Per-group (G dimension) pointer strides for the batched grouped GEMM.
136  group_stride_a = args.K_; // A: Out NWGK
137  group_stride_b = args.K_ * args.C_ *
138  std::accumulate(args.filter_spatial_lengths_.begin(),
139  args.filter_spatial_lengths_.end(),
140  1,
141  std::multiplies<index_t>()); // B: Wei GKXC
142  group_stride_c = args.C_; // C: In NWGC
143 
// Strides to the next batch (N) element, used by the split-N offsets.
144  input_batch_stride = args.C_ * args.G_ * args.input_spatial_lengths_[0];
145  output_batch_stride = args.K_ * args.G_ * args.output_spatial_lengths_[0];
146 
147  GemmBatch = args.G_;
148  }
149 
// Constructor overload for 2D convolution with NHWGC / GKYXC / NHWGK layouts
// (SFINAE-selected). Mirrors the 1D overload: the backward-data problem is
// decomposed into up to YTilde * XTilde GEMMs, one per (y, x) residue class.
// NOTE(review): the signature line (original 158) was dropped; presumably
//   CK_TILE_HOST GroupedConvBwdDataKernelArgs(const GroupedConvBwdDataHostArgs& args)
150  template <
151  typename InLay = typename GroupedConvTraitsType_::InLayout,
152  typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
153  typename OutLay = typename GroupedConvTraitsType_::OutLayout,
154  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NHWGC> &&
155  std::is_same_v<WeiLay, tensor_layout::convolution::GKYXC> &&
156  std::is_same_v<OutLay, tensor_layout::convolution::NHWGK>,
157  bool>::type = false>
159  {
160  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
161  static_cast<index_t>(args.N_),
162  static_cast<index_t>(args.C_),
163  static_cast<index_t>(args.input_spatial_lengths_[0]),
164  static_cast<index_t>(args.input_spatial_lengths_[1])};
165  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
166  static_cast<index_t>(args.K_),
167  static_cast<index_t>(args.C_),
168  static_cast<index_t>(args.filter_spatial_lengths_[0]),
169  static_cast<index_t>(args.filter_spatial_lengths_[1])};
170  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
171  static_cast<index_t>(args.N_),
172  static_cast<index_t>(args.K_),
173  static_cast<index_t>(args.output_spatial_lengths_[0]),
174  static_cast<index_t>(args.output_spatial_lengths_[1])};
175 
176  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
177  static_cast<index_t>(args.conv_filter_strides_[1])};
178  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
179  static_cast<index_t>(args.conv_filter_dilations_[1])};
180  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
181  static_cast<index_t>(args.input_left_pads_[1])};
182  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
183  static_cast<index_t>(args.input_right_pads_[1])};
184 
185  k_batch = args.k_batch;
186 
187  in_ptr = args.in_ptr;
188  wei_ptr = args.wei_ptr;
189  for(index_t d = 0; d < NumDTensor; d++)
190  {
191  ds_ptr[d] = args.ds_ptr[d];
192  }
193  out_ptr = args.out_ptr;
194 
// Residue-class counts per spatial dimension (H and W).
195  const index_t Y = wei_g_k_c_xs_lengths[3];
196  const index_t X = wei_g_k_c_xs_lengths[4];
197  const index_t ConvStrideH = conv_filter_strides[0];
198  const index_t ConvStrideW = conv_filter_strides[1];
199  const index_t ConvDilationH = conv_filter_dilations[0];
200  const index_t ConvDilationW = conv_filter_dilations[1];
201  const auto GcdStrideDilationH = gcd(ConvStrideH, ConvDilationH);
202  const auto GcdStrideDilationW = gcd(ConvStrideW, ConvDilationW);
203  const auto YTilde = ConvStrideH / GcdStrideDilationH;
204  const auto XTilde = ConvStrideW / GcdStrideDilationW;
205 
206  for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
207  {
208  for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
209  {
210  const auto YDotSlice = integer_divide_ceil(Y - i_ytilde, YTilde);
211  const auto XDotSlice = integer_divide_ceil(X - i_xtilde, XTilde);
212 
213  if(XDotSlice * YDotSlice <= 0)
214  {
215  continue;
216  }
217 
// NOTE(review): the guard condition on the dropped line 218 is not visible;
// presumably it checks gemm_count against MaxGroupedGemmGroupsNum — confirm.
219  {
220  gemm_count++;
221  // Avoid array segfault
222  continue;
223  }
224 
225  tildes = {i_ytilde, i_xtilde};
226 
// NOTE(review): ctor argument lines 228-233 were dropped by the extraction —
// confirm the full argument list against the repository.
227  ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
234  tildes};
235 
236  auto grid_descs = conv_to_gemm_transformer
237  .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
238  GroupedConvTraitsType_::NDimSpatial>(1);
239 
240  a_grid_descs_m_k[gemm_count] = grid_descs.at(number<0>{});
241  b_grid_descs_n_k[gemm_count] = grid_descs.at(number<1>{});
242  c_grid_descs_m_n[gemm_count] = grid_descs.at(number<2>{});
243 
244  const index_t grid_size_grp =
245  TilePartitioner::GridSize(c_grid_descs_m_n[gemm_count].get_length(I0),
246  c_grid_descs_m_n[gemm_count].get_length(I1));
247 
// NOTE(review): dropped line 248 presumably records
// block_starts[gemm_count] = grid_size_; — confirm against the repo.
249  block_ends[gemm_count] = grid_size_ + grid_size_grp;
250 
251  grid_size_ += grid_size_grp;
252 
253  // Get the actual split N from transformer
254  n_per_split = conv_to_gemm_transformer.GetN();
255  original_n = conv_to_gemm_transformer.GetOriginalN();
257 
258  ++gemm_count;
259  }
260  }
261  group_stride_a = args.K_; // A: Out NHWGK
262  group_stride_b = args.K_ * args.C_ *
263  std::accumulate(args.filter_spatial_lengths_.begin(),
264  args.filter_spatial_lengths_.end(),
265  1,
266  std::multiplies<index_t>()); // B: Wei GKYXC
267  group_stride_c = args.C_; // C: In NHWGC
268 
// NOTE(review): the assignment heads on dropped lines 269 and 271 are not
// visible; presumably `input_batch_stride =` and `output_batch_stride =`,
// matching the 1D/3D overloads — confirm against the repo.
270  args.C_ * args.G_ * args.input_spatial_lengths_[0] * args.input_spatial_lengths_[1];
272  args.K_ * args.G_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
273 
274  GemmBatch = args.G_;
275  }
276 
// Constructor overload for 3D convolution with NDHWGC / GKZYXC / NDHWGK
// layouts (SFINAE-selected). Decomposes backward-data into up to
// ZTilde * YTilde * XTilde GEMMs, one per (z, y, x) residue class.
// NOTE(review): the signature line (original 285) was dropped; presumably
//   CK_TILE_HOST GroupedConvBwdDataKernelArgs(const GroupedConvBwdDataHostArgs& args)
277  template <
278  typename InLay = typename GroupedConvTraitsType_::InLayout,
279  typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
280  typename OutLay = typename GroupedConvTraitsType_::OutLayout,
281  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NDHWGC> &&
282  std::is_same_v<WeiLay, tensor_layout::convolution::GKZYXC> &&
283  std::is_same_v<OutLay, tensor_layout::convolution::NDHWGK>,
284  bool>::type = false>
286  {
287  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
288  static_cast<index_t>(args.N_),
289  static_cast<index_t>(args.C_),
290  static_cast<index_t>(args.input_spatial_lengths_[0]),
291  static_cast<index_t>(args.input_spatial_lengths_[1]),
292  static_cast<index_t>(args.input_spatial_lengths_[2])};
293  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
294  static_cast<index_t>(args.K_),
295  static_cast<index_t>(args.C_),
296  static_cast<index_t>(args.filter_spatial_lengths_[0]),
297  static_cast<index_t>(args.filter_spatial_lengths_[1]),
298  static_cast<index_t>(args.filter_spatial_lengths_[2])};
299  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
300  static_cast<index_t>(args.N_),
301  static_cast<index_t>(args.K_),
302  static_cast<index_t>(args.output_spatial_lengths_[0]),
303  static_cast<index_t>(args.output_spatial_lengths_[1]),
304  static_cast<index_t>(args.output_spatial_lengths_[2])};
305 
306  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
307  static_cast<index_t>(args.conv_filter_strides_[1]),
308  static_cast<index_t>(args.conv_filter_strides_[2])};
309  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
310  static_cast<index_t>(args.conv_filter_dilations_[1]),
311  static_cast<index_t>(args.conv_filter_dilations_[2])};
312  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
313  static_cast<index_t>(args.input_left_pads_[1]),
314  static_cast<index_t>(args.input_left_pads_[2])};
315  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
316  static_cast<index_t>(args.input_right_pads_[1]),
317  static_cast<index_t>(args.input_right_pads_[2])};
318 
319  k_batch = args.k_batch;
320 
321  in_ptr = args.in_ptr;
322  wei_ptr = args.wei_ptr;
323  for(index_t d = 0; d < NumDTensor; d++)
324  {
325  ds_ptr[d] = args.ds_ptr[d];
326  }
327  out_ptr = args.out_ptr;
328 
// Residue-class counts per spatial dimension (D, H and W).
329  const index_t Z = wei_g_k_c_xs_lengths[3];
330  const index_t Y = wei_g_k_c_xs_lengths[4];
331  const index_t X = wei_g_k_c_xs_lengths[5];
332  const index_t ConvStrideD = conv_filter_strides[0];
333  const index_t ConvStrideH = conv_filter_strides[1];
334  const index_t ConvStrideW = conv_filter_strides[2];
335  const index_t ConvDilationD = conv_filter_dilations[0];
336  const index_t ConvDilationH = conv_filter_dilations[1];
337  const index_t ConvDilationW = conv_filter_dilations[2];
338  const auto GcdStrideDilationD = gcd(ConvStrideD, ConvDilationD);
339  const auto GcdStrideDilationH = gcd(ConvStrideH, ConvDilationH);
340  const auto GcdStrideDilationW = gcd(ConvStrideW, ConvDilationW);
341  const auto ZTilde = ConvStrideD / GcdStrideDilationD;
342  const auto YTilde = ConvStrideH / GcdStrideDilationH;
343  const auto XTilde = ConvStrideW / GcdStrideDilationW;
344 
345  for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde)
346  {
347  for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
348  {
349  for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
350  {
351  const auto ZDotSlice = integer_divide_ceil(Z - i_ztilde, ZTilde);
352  const auto YDotSlice = integer_divide_ceil(Y - i_ytilde, YTilde);
353  const auto XDotSlice = integer_divide_ceil(X - i_xtilde, XTilde);
354 
355  if(ZDotSlice * XDotSlice * YDotSlice <= 0)
356  {
357  continue;
358  }
359 
// NOTE(review): the guard condition on the dropped line 360 is not visible;
// presumably it checks gemm_count against MaxGroupedGemmGroupsNum — confirm.
361  {
362  gemm_count++;
363  // Avoid array segfault
364  continue;
365  }
366 
367  tildes = {i_ztilde, i_ytilde, i_xtilde};
368 
// NOTE(review): ctor argument lines 370-375 were dropped by the extraction —
// confirm the full argument list against the repository.
369  ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
376  tildes};
377 
378  auto grid_descs = conv_to_gemm_transformer
379  .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
380  GroupedConvTraitsType_::NDimSpatial>(1);
381 
382  a_grid_descs_m_k[gemm_count] = grid_descs.at(number<0>{});
383  b_grid_descs_n_k[gemm_count] = grid_descs.at(number<1>{});
384  c_grid_descs_m_n[gemm_count] = grid_descs.at(number<2>{});
385 
386  const index_t grid_size_grp =
387  TilePartitioner::GridSize(c_grid_descs_m_n[gemm_count].get_length(I0),
388  c_grid_descs_m_n[gemm_count].get_length(I1));
389 
// NOTE(review): dropped line 390 presumably records
// block_starts[gemm_count] = grid_size_; — confirm against the repo.
391  block_ends[gemm_count] = grid_size_ + grid_size_grp;
392 
393  grid_size_ += grid_size_grp;
394 
395  // Get the actual split N from transformer
396  n_per_split = conv_to_gemm_transformer.GetN();
397  original_n = conv_to_gemm_transformer.GetOriginalN();
399 
400  ++gemm_count;
401  }
402  }
403  }
404 
405  group_stride_a = args.K_; // A: Out NDHWGK
406  group_stride_b = args.K_ * args.C_ *
407  std::accumulate(args.filter_spatial_lengths_.begin(),
408  args.filter_spatial_lengths_.end(),
409  1,
410  std::multiplies<index_t>()); // B: Wei GKZYXC
411  group_stride_c = args.C_; // C: In NDHWGC
412 
// NOTE(review): the continuation lines 414/416 (the remaining spatial
// factors [1] * [2]) were dropped by the extraction — confirm.
413  input_batch_stride = args.C_ * args.G_ * args.input_spatial_lengths_[0] *
415  output_batch_stride = args.K_ * args.G_ * args.output_spatial_lengths_[0] *
417 
418  GemmBatch = args.G_; // one GEMM batch per convolution group
419  }
420 
// Upper bound on the number of per-residue-class GEMMs this argument pack
// can describe (sizes the grid-descriptor / block-range arrays).
421  static constexpr index_t MaxGroupedGemmGroupsNum = 128;
422 
// NOTE(review): the extraction dropped many member declarations between the
// visible lines here (originals 423-461): per the footer index these include
// the a/b/c grid-descriptor arrays, block_starts/block_ends, the G-major
// length arrays (e.g. wei_g_k_c_xs_lengths at original line 432),
// conv_filter_strides/dilations, input pads, tildes, gemm_count, grid_size_,
// k_batch, group_stride_a/b/c and GemmBatch — confirm against the repo.
425 
429 
430  static constexpr index_t NonSpatialDims = 3;
434 
440 
445 
// Raw device pointers; const-ness reflects read (out/wei/ds) vs. write (in).
446  const void* out_ptr;
447  void* in_ptr;
448  std::array<const void*, NumDTensor> ds_ptr;
449  const void* wei_ptr;
450 
454 
457 
461 
462  // Split-N support fields - initialize to safe defaults
463  index_t n_splits = 1; // Number of batch splits (e.g., 2 for 128→64×2)
464  index_t n_per_split = 1; // Batches per split (N_ from transformer)
465  index_t original_n = 1; // Original batch size before splitting
466  index_t input_batch_stride = 0; // Stride to next batch in input tensor
467  index_t output_batch_stride = 0; // Stride to next batch in output tensor
468 };
469 
// Grouped convolution backward-data kernel: computes the input-image
// gradient as a batched grouped GEMM (Out * Wei = In).
// NOTE(review): the extraction dropped the struct declaration line (original
// 512) — the static_asserts in GetInstanceString() name the type
// `GroupedConvolutionBackwardDataKernel` — as well as the using-alias lines
// (originals 515-529, 534-543: data types, layouts, TilePartitioner,
// GemmPipeline, EpiloguePipeline, GemmA/B/CLayout, ...). Confirm against
// the repository.
508 template <typename GroupedConvTraitsType_,
509  typename TilePartitioner_,
510  typename GemmPipeline_,
511  typename EpiloguePipeline_>
513 {
514  static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial;
516  GroupedConvTraitsType_::ConvSpecialization;
523 
528 
530  static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
531 
532  static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
533 
537 
539 
// NOTE(review): the right-hand side of this constant (original line 543,
// presumably KernelArgs::MaxGroupedGemmGroupsNum) was dropped — confirm.
542  static constexpr index_t MaxGroupedGemmGroupsNum =
544 
545  // TODO: Enable this
546  static constexpr bool IsSplitKSupported = false;
547 
548  static constexpr auto I0 = number<0>();
549  static constexpr auto I1 = number<1>();
550  static constexpr auto I2 = number<2>();
551  static constexpr auto I3 = number<3>();
552 
// This kernel only supports fully-padded, row-major GEMM views and the
// implicit (non-explicit) GEMM path.
553  static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
554  "Not supported!");
555  static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>, "Not supported!");
556  static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
557  static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
558  "Not supported C GEMM layout!");
559  static_assert(GroupedConvTraitsType_::ExplicitGemm == false, "Not supported yet!");
560 
// Builds a human-readable, '_'-separated identifier for this kernel
// instantiation (precision, layouts, pipeline names and traits flags).
561  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
562  {
563  static constexpr bool EnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
564  constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
565  // clang-format off
566  return concat('_', "grouped_convolution_backward_data",
567  gemm_prec_str<InDataType, WeiDataType>(),
568  InLayout::name,
569  WeiLayout::name,
570  OutLayout::name,
571  "gemm",
572  GemmPipeline::GetName(),
573  "epilogue",
574  EpiloguePipeline::GetName(),
// NOTE(review): original line 575 was dropped by the extraction; presumably
// it passes getConvSpecializationString(ConvSpecialization) (that helper is
// referenced in this header's index) — confirm against the repo.
576  "MergedGroups",
577  NumGroupsToMerge,
578  "SplitImage",
579  EnableSplitImage,
580  "ExplicitGemm",
581  GroupedConvTraitsType_::ExplicitGemm
582  );
583  // clang-format on
584  }
585 
586  [[nodiscard]] CK_TILE_HOST static const std::string GetTypeString() { return GetName(); }
587 
588 #ifdef CK_EXPERIMENTAL_BUILDER
// Returns the reflection-based instance string for this kernel; fails at
// compile time if no instance_traits specialization exists for it.
589  CK_TILE_HOST std::string GetInstanceString() const
590  {
591  static_assert(ck_tile::reflect::HasInstanceTraits<GroupedConvolutionBackwardDataKernel>,
592  "Specialization of instance_traits not found. Please check that a "
593  "specialization exists in file "
594  "ck_tile/builder/reflect/"
595  "instance_traits_tile_grouped_convolution_backward_data.hpp "
596  "for the given template parameters.");
597  return ck_tile::reflect::instance_string<GroupedConvolutionBackwardDataKernel>();
598  }
599 #endif
600 
// Launch grid: x = concatenated per-GEMM tile grids (grid_size_),
// y = convolution groups (GemmBatch), z = batch splits * k-batch splits.
// NOTE(review): the signature line (original 601) was dropped; presumably
//   CK_TILE_HOST static constexpr auto GridSize(const KernelArgs& kargs)
// — confirm against the repo.
602  {
603  // enable batched grouped gemm
604  return dim3(kargs.grid_size_, kargs.GemmBatch, kargs.n_splits * kargs.k_batch);
605  }
606 
607  CK_TILE_HOST static constexpr auto BlockSize()
608  {
609  return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
610  }
611 
// NOTE(review): the extraction dropped the signatures of the two helpers
// below (original lines 613, 615 and 618). The second is evidently
// GetSmemSize() judging by its body (LDS requirement = max of the GEMM and
// epilogue pipelines); the first is unidentifiable from here — confirm both
// against the repository. Bodies preserved verbatim.
614  {
616  }
617 
619  {
620  return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
621  }
622 
// Host-side validation of the kernel arguments: rejects unsupported k_batch
// settings, checks the convolution specialization's structural requirements
// (1x1/stride-1/pad-0 etc.), and verifies the vector-access divisibility of
// C and K for each supported layout family.
// NOTE(review): the parameter line (original 624) was dropped; presumably
//   IsSupportedArgument(const KernelArgs& kargs) — confirm.
623  CK_TILE_HOST static bool
625  {
// NOTE(review): the tail of this condition (originals 627-628) was dropped;
// it presumably also tests the epilogue memory operation / data type that
// make atomic k_batch>1 stores unsafe — confirm against the repo.
626  if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
629  {
630  if(kargs.k_batch != 1)
631  {
632  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
633  {
634  CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
635  }
636  return false;
637  }
638  }
639 
// NOTE(review): the condition on dropped line 640 is not visible — confirm.
641  {
642  return false;
643  }
644 
645  const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
646  const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];
647 
648  // check ConvSpecialization
// NOTE(review): the branch conditions on dropped lines 649, 665 and 680 are
// not visible; from the bodies they are presumably the Filter1x1Stride1Pad0,
// Filter1x1Pad0 and Filter3x3 specializations respectively — confirm.
650  {
651  // check if it's 1x1, stride=1 conv
652  for(index_t i = 0; i < NDimSpatial; ++i)
653  {
654  const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
655  const index_t ConvStride = kargs.conv_filter_strides[i];
656  const index_t LeftPad = kargs.input_left_pads[i];
657  const index_t RightPad = kargs.input_right_pads[i];
658 
659  if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
660  {
661  return false;
662  }
663  }
664  }
666  {
667  // check if it's 1x1 conv
668  for(index_t i = 0; i < NDimSpatial; ++i)
669  {
670  const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
671  const index_t LeftPad = kargs.input_left_pads[i];
672  const index_t RightPad = kargs.input_right_pads[i];
673 
674  if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0))
675  {
676  return false;
677  }
678  }
679  }
681  {
682  if(ConvC != 1)
683  {
684  return false;
685  }
// Every filter spatial dimension must equal 3 (I3) for this branch.
686  for(index_t i = 0; i < NDimSpatial; ++i)
687  {
688  const index_t filter_spatial_dim = kargs.wei_g_k_c_xs_lengths[i + I3];
689 
690  if(filter_spatial_dim != I3)
691  {
692  return false;
693  }
694  }
695  }
696 
697  namespace ctc = tensor_layout::convolution;
698 
699  if constexpr(std::is_same_v<InLayout, ctc::NWGC> || std::is_same_v<InLayout, ctc::NHWGC> ||
700  std::is_same_v<InLayout, ctc::NDHWGC>)
701  {
702  // Check access per C
// NOTE(review): the input image is checked against VectorSizeB and the
// weight below against VectorSizeC, although the group-stride comments in
// the args struct label A=out, B=wei, C=in — possibly intentional for the
// bwd-data GEMM mapping, but worth confirming against the repository.
703  if(ConvC % GroupedConvTraitsType_::VectorSizeB != 0)
704  {
705  CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!");
706  return false;
707  }
708  }
709  else
710  {
711  CK_TILE_ERROR("Not supported input layout!");
712  return false;
713  }
714 
715  // FIXME: layout
716  if constexpr(std::is_same_v<WeiLayout, ctc::GKXC> ||
717  std::is_same_v<WeiLayout, ctc::GKYXC> ||
718  std::is_same_v<WeiLayout, ctc::GKZYXC>)
719  {
720  if(ConvC % GroupedConvTraitsType_::VectorSizeC != 0)
721  {
722  CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!");
723  return false;
724  }
725  }
726  else
727  {
728  CK_TILE_ERROR("Not supported weight layout!");
729  return false;
730  }
731 
732  if constexpr(std::is_same_v<OutLayout, ctc::NWGK> ||
733  std::is_same_v<OutLayout, ctc::NHWGK> ||
734  std::is_same_v<OutLayout, ctc::NDHWGK>)
735  {
736  if(ConvK % GroupedConvTraitsType_::VectorSizeA != 0)
737  {
738  CK_TILE_ERROR("Conv K is not a multiple of vector store size for output image!");
739  return false;
740  }
741  }
742  else
743  {
744  CK_TILE_ERROR("Not supported output layout!");
745  return false;
746  }
747 
748  return true;
749  }
750 
// Wraps the raw A (out-grad), B (weight), Ds and C (in-grad) pointers in
// global-memory tensor views using the per-group grid descriptors selected
// by group_id.
// NOTE(review): parameter lines 753 and 757 were dropped by the extraction;
// presumably `(const OutDataType* a_ptr,` and the KernelArgs reference
// `kargs` — confirm against the repo.
751  template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
752  CK_TILE_DEVICE static auto
754  const InDataType* b_ptr,
755  const std::array<const void*, NumDTensor>& ds_ptr,
756  WeiDataType* c_ptr,
758  const index_t group_id)
759  {
760  static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
761  static_assert(!GemmPipeline::BlockGemmShape::PermuteB, "Not implemented!");
762  const auto& a_tensor_view = [&]() {
763  return make_tensor_view<address_space_enum::global>(
764  a_ptr,
765  kargs.a_grid_descs_m_k[group_id]); // A: out
766  }();
767 
768  const auto& b_tensor_view = [&]() {
769  return make_tensor_view<address_space_enum::global>(
770  b_ptr,
771  kargs.b_grid_descs_n_k[group_id]); // B: weight
772  }();
773 
774  const auto& c_tensor_view = [&]() {
775  return make_tensor_view<address_space_enum::global>(c_ptr,
776  kargs.c_grid_descs_m_n[group_id]);
777  }();
778 
// D tensors must match the C layout/descriptor; each D view reuses the C
// grid descriptor of the selected group.
779  const auto& ds_tensor_view = generate_tuple(
780  [&](auto i) {
781  static_assert(std::is_same_v<std::tuple_element_t<i, DsLayout>, OutLayout>,
782  "Not supported!");
783  static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
784  "Not supported!");
785  static_assert(std::is_same_v<std::tuple_element_t<i, DsDataType>, OutDataType>,
786  "Not supported!");
787 
// NOTE(review): ds_ptr[i] is a const void*, yet it appears cast to a
// non-const OutDataType* here — possibly the original reads
// `static_cast<const OutDataType*>`; confirm against the repo.
788  return make_tensor_view<address_space_enum::global>(
789  static_cast<OutDataType*>(ds_ptr[i]), kargs.c_grid_descs_m_n[group_id]);
790  },
// NOTE(review): dropped line 791 presumably closes the generate_tuple call
// with number<NumDTensor>{}); — confirm.
792 
793  return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view);
794  }
795 
// Applies tile-size padding to the A/B/Ds/C tensor views so partial tiles
// at the problem boundary can be processed uniformly.
// NOTE(review): the pad_tensor_view argument lines (originals 802-804,
// 810-812, 819-823 and 828-830: tile lengths and DoPads sequences) were
// dropped by the extraction — confirm against the repository.
796  template <typename TensorView>
797  CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
798  {
799  const auto& a_pad_view = [&]() {
800  const auto& a_tensor_view = views.at(I0);
801  return pad_tensor_view(a_tensor_view,
805  }();
806 
807  const auto& b_pad_view = [&]() {
808  const auto& b_tensor_view = views.at(I1);
809  return pad_tensor_view(b_tensor_view,
813  }();
814 
815  const auto& ds_tensor_view = views.at(I2);
816  const auto& ds_pad_view = generate_tuple(
817  [&](auto i) {
818  return pad_tensor_view(ds_tensor_view[i],
822  },
824 
825  const auto& c_pad_view = [&]() {
826  const auto& c_tensor_view = views.at(I3);
827  return pad_tensor_view(c_tensor_view,
831  }();
832 
833  return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view);
834  }
835 
// Creates the per-block tile windows into the padded A/B/Ds/C views at the
// block's (i_m, i_n) tile origin; A/B start at k-offset i_k (default 0).
// NOTE(review): the make_tile_window length arguments (originals 849-850,
// 856-857, 864-865, 868 and 872: the MPerBlock/NPerBlock/KPerBlock window
// extents and the tuple size) were dropped by the extraction — confirm
// against the repository.
836  template <typename PadView>
837  CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
838  const index_t i_m,
839  const index_t i_n,
840  const index_t i_k = 0)
841  {
842  const auto& a_pad_view = views.at(I0);
843  const auto& b_pad_view = views.at(I1);
844  const auto& ds_pad_view = views.at(I2);
845  const auto& c_pad_view = views.at(I3);
846 
847  const auto& a_block_window = [&]() {
848  return make_tile_window(a_pad_view,
851  {i_m, i_k});
852  }();
853 
854  const auto& b_block_window = [&]() {
855  return make_tile_window(b_pad_view,
858  {i_k, i_n});
859  }();
860 
861  const auto ds_block_window = generate_tuple(
862  [&](auto i) {
863  return make_tile_window(ds_pad_view[i],
866  {i_m, i_n});
867  },
869 
870  auto c_block_window = make_tile_window(
871  c_pad_view,
873  {i_m, i_n});
874 
875  return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window);
876  }
877 
// Runs one tile of the grouped GEMM (single LDS buffer): builds views, pad
// views and tile windows for group_id, executes the GEMM pipeline and then
// the epilogue into the C (input-gradient) window.
// NOTE(review): the kargs parameter line (original 895) and the preceding
// Doxygen block (originals 878-889) were dropped by the extraction.
// NOTE(review): the call site in operator() passes b_ptr as WeiDataType*
// and c_ptr as InDataType*, while this signature declares the opposite
// pairing (InDataType* b, WeiDataType* c) — likely a copy-paste from the
// forward kernel that only compiles when the types coincide; confirm
// against the repository.
890  CK_TILE_DEVICE static void RunGemm(const OutDataType* a_ptr,
891  const InDataType* b_ptr,
892  const std::array<const void*, NumDTensor>& ds_ptr,
893  WeiDataType* c_ptr,
894  void* smem_ptr_0,
896  const index_t block_idx_m,
897  const index_t block_idx_n,
898  const index_t group_id)
899  {
900  // Create Gemm tensor views, pad views and tile windows
901  const auto& gemm_tensor_views_tuple =
902  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
903  a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id);
904 
905  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
906  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
907 
// Loop count derives from the (padded) K extent of the A view.
908  const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(
909  gemm_pad_views.at(I0).get_tensor_descriptor().get_length(I1)));
910 
911  // Run GEMM cooperatively by whole workgroup.
912  const auto& a_block_window = gemm_tile_windows.at(I0);
913  const auto& b_block_window = gemm_tile_windows.at(I1);
914  const auto& d_block_window = gemm_tile_windows.at(I2);
915 
916  const auto& c_block_tile = GemmPipeline{}.template operator()(
917  a_block_window, b_block_window, num_loop, smem_ptr_0);
918 
919  // Run Epilogue Pipeline
920  auto& c_block_window = gemm_tile_windows.at(I3);
921 
922  EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
923  c_block_window, c_block_tile, d_block_window, smem_ptr_0);
924  }
925 
// Double-LDS-buffer variant of RunGemm: identical flow, but the GEMM
// pipeline ping-pongs between smem_ptr_0 and smem_ptr_1.
// NOTE(review): the kargs parameter line (original 947) and the preceding
// Doxygen block (originals 926-940) were dropped by the extraction.
941  CK_TILE_DEVICE static void RunGemm2LDS(const OutDataType* a_ptr,
942  const InDataType* b_ptr,
943  const std::array<const void*, NumDTensor>& ds_ptr,
944  WeiDataType* c_ptr,
945  void* __restrict__ smem_ptr_0,
946  void* __restrict__ smem_ptr_1,
948  const index_t block_idx_m,
949  const index_t block_idx_n,
950  const index_t group_id)
951  {
952  // Create Gemm tensor views, pad views and tile windows
953  const auto& gemm_tensor_views_tuple =
954  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
955  a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id);
956  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
957  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
958 
959  const index_t num_loop = amd_wave_read_first_lane(
960  TilePartitioner::GetLoopNum(gemm_tile_windows.at(I0).get_length(I1)));
961 
962  // Run GEMM cooperatively by whole workgroup.
963  const auto& a_block_window = gemm_tile_windows.at(I0);
964  const auto& b_block_window = gemm_tile_windows.at(I1);
965  const auto& d_block_window = gemm_tile_windows.at(I2);
966 
967  const auto& c_block_tile = GemmPipeline{}.template operator()(
968  a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
969 
970  // Run Epilogue Pipeline
971  auto& c_block_window = gemm_tile_windows.at(I3);
972 
973  EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
974  c_block_window, c_block_tile, d_block_window, smem_ptr_0);
975  }
976 
// Binary-searches [block_starts, block_ends) for the GEMM group that owns
// block_id.
// NOTE(review): the signature head (original line 977) was dropped;
// presumably CK_TILE_DEVICE index_t FindGroupId(const KernelArgs& kargs, —
// confirm against the repo.
// NOTE(review): the bounds are updated with left = group_id / right =
// group_id (no +/-1), so the search interval may stop shrinking; this
// terminates only because every launched block_id lies in some group's
// range — worth confirming/tightening upstream.
978  index_t block_id) const
979  {
980  index_t left = 0;
981  index_t right = kargs.gemm_count;
982  index_t group_id = index_t((left + right) >> 1);
983 
984  while((!(block_id >= kargs.block_starts[group_id] &&
985  block_id < kargs.block_ends[group_id])) &&
986  left <= right)
987  {
988  if(block_id < kargs.block_starts[group_id])
989  {
990  right = group_id;
991  }
992  else
993  {
994  left = group_id;
995  }
996  group_id = index_t((left + right) >> 1);
997  }
998 
999  return group_id;
1000  }
1001 
// Kernel entry point: maps blockIdx.x to a (group, tile) pair, blockIdx.y
// to the convolution group G, and blockIdx.z to (split-N index, k-batch),
// then dispatches RunGemm / RunGemm2LDS on the offset pointers.
// NOTE(review): the operator() signature line (original 1002) was dropped;
// presumably CK_TILE_DEVICE void operator()(KernelArgs kargs) const —
// confirm against the repo.
1003  {
1004  const auto blockIdX = amd_wave_read_first_lane(blockIdx.x);
1005  const index_t group_id = FindGroupId(kargs, blockIdX);
1006 
// NOTE(review): dropped line 1007 presumably reads
// const auto [iM, iN] = TilePartitioner{}.GetOutputTileIndex(blockIdX - ...,
// taking the arguments below — confirm against the repo.
1008  kargs.block_starts[group_id],
1009  kargs.c_grid_descs_m_n[group_id].get_length(I0),
1010  kargs.c_grid_descs_m_n[group_id].get_length(I1));
1011 
1012  const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
1013  const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
1014 
// Per-G-group pointer offsets (blockIdx.y indexes the convolution group).
1015  const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);
1016  const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
1017  const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
1018  const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);
1019 
1020  const auto blockIdZ = amd_wave_read_first_lane(blockIdx.z);
1021 
1022  // SplitN
1023  const index_t split_n_idx = __builtin_amdgcn_readfirstlane(blockIdZ / kargs.k_batch);
1024  const index_t split_n_offset =
1025  __builtin_amdgcn_readfirstlane(split_n_idx * kargs.n_per_split);
1026 
// 64-bit batch offsets to avoid 32-bit overflow on large tensors.
1027  const long_index_t output_batch_offset =
1028  static_cast<long_index_t>(split_n_offset) *
1029  static_cast<long_index_t>(kargs.output_batch_stride);
1030  const long_index_t input_batch_offset = static_cast<long_index_t>(split_n_offset) *
1031  static_cast<long_index_t>(kargs.input_batch_stride);
1032 
1033  // SplitK
1034  // TODO: Implement SplitK support
1035  // const index_t split_k_idx =
1036  // __builtin_amdgcn_readfirstlane(blockIdZ - split_n_idx * kargs.k_batch);
1037 
1038  // options
1039  // conv_bwd_data = Out * Weight = In
1040  const OutDataType* a_ptr =
1041  static_cast<const OutDataType*>(kargs.out_ptr) + group_offset_a + output_batch_offset;
1042  const WeiDataType* b_ptr = static_cast<const WeiDataType*>(kargs.wei_ptr) + group_offset_b;
1043  InDataType* c_ptr =
1044  static_cast<InDataType*>(kargs.in_ptr) + group_offset_c + input_batch_offset;
1045 
1046  // allocate LDS
1047  __shared__ char smem_ptr_0[GetSmemSize()];
1048 
1049  if constexpr(GemmPipeline::DoubleSmemBuffer == true)
1050  {
1051  __shared__ char smem_ptr_1[GetSmemSize()];
// NOTE(review): the tails of the two guard conditions below (originals 1054
// and 1072) were dropped by the extraction; each guard skips the run when an
// odd VectorSizeC is combined with atomic_add (plus whatever the dropped
// lines add) — confirm against the repo.
1052  if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
1053  GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
1055  {
1056  RunGemm2LDS(a_ptr,
1057  b_ptr,
1058  kargs.ds_ptr,
1059  c_ptr,
1060  smem_ptr_0,
1061  smem_ptr_1,
1062  kargs,
1063  i_m,
1064  i_n,
1065  group_id);
1066  }
1067  }
1068  else
1069  {
1070  if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
1071  GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
1073  {
1074  RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, i_m, i_n, group_id);
1075  }
1076  }
1077  }
1078 };
1079 
1080 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:49
#define CK_TILE_HOST
Definition: config.hpp:48
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:50
Definition: cluster_descriptor.hpp:13
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:156
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:145
void CK_TILE_ERROR(Args &&... args) noexcept
Definition: env.hpp:12
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:35
ConvolutionSpecialization
Definition: convolution_specialization.hpp:11
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:545
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE index_t gcd(index_t x, index_t y)
Definition: math.hpp:264
int64_t long_index_t
Definition: integer.hpp:11
CK_TILE_HOST std::string getConvSpecializationString(const ConvolutionSpecialization &s)
Definition: convolution_specialization.hpp:18
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:157
constexpr bool is_same_v
Definition: type.hpp:283
__device__ X atomic_add(X *p_dst, const X &x)
The Grouped Convolution kernel device arguments.
Definition: grouped_convolution_backward_data_kernel.hpp:26
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > wei_g_k_c_xs_lengths
Definition: grouped_convolution_backward_data_kernel.hpp:432
static constexpr auto I1
Definition: grouped_convolution_backward_data_kernel.hpp:39
CK_TILE_HOST GroupedConvBwdDataKernelArgs(const GroupedConvBwdDataHostArgs &args)
Definition: grouped_convolution_backward_data_kernel.hpp:49
array< index_t, GroupedConvTraitsType_::NDimSpatial > conv_filter_dilations
Definition: grouped_convolution_backward_data_kernel.hpp:436
std::array< const void *, NumDTensor > ds_ptr
Definition: grouped_convolution_backward_data_kernel.hpp:448
array< index_t, GroupedConvTraitsType_::NDimSpatial > conv_filter_strides
Definition: grouped_convolution_backward_data_kernel.hpp:435
array< index_t, MaxGroupedGemmGroupsNum > block_starts
Definition: grouped_convolution_backward_data_kernel.hpp:455
array< index_t, GroupedConvTraitsType_::NDimSpatial > input_left_pads
Definition: grouped_convolution_backward_data_kernel.hpp:437
long_index_t group_stride_b
Definition: grouped_convolution_backward_data_kernel.hpp:459
long_index_t group_stride_c
Definition: grouped_convolution_backward_data_kernel.hpp:460
array< index_t, MaxGroupedGemmGroupsNum > block_ends
Definition: grouped_convolution_backward_data_kernel.hpp:456
const void * out_ptr
Definition: grouped_convolution_backward_data_kernel.hpp:446
remove_cvref_t< decltype(ConvToGemmTransformer{}.MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(1))> ABCGridDescs
Definition: grouped_convolution_backward_data_kernel.hpp:424
remove_cvref_t< decltype(ABCGridDescs{}[number< 1 >{}])> BGridDescNK
Definition: grouped_convolution_backward_data_kernel.hpp:427
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: grouped_convolution_backward_data_kernel.hpp:27
array< index_t, GroupedConvTraitsType_::NDimSpatial > tildes
Definition: grouped_convolution_backward_data_kernel.hpp:439
remove_cvref_t< decltype(ABCGridDescs{}[number< 0 >{}])> AGridDescMK
Definition: grouped_convolution_backward_data_kernel.hpp:426
const void * wei_ptr
Definition: grouped_convolution_backward_data_kernel.hpp:449
index_t n_per_split
Definition: grouped_convolution_backward_data_kernel.hpp:464
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > out_g_n_k_wos_lengths
Definition: grouped_convolution_backward_data_kernel.hpp:433
long_index_t group_stride_a
Definition: grouped_convolution_backward_data_kernel.hpp:458
index_t GemmBatch
Definition: grouped_convolution_backward_data_kernel.hpp:442
void * in_ptr
Definition: grouped_convolution_backward_data_kernel.hpp:447
index_t n_splits
Definition: grouped_convolution_backward_data_kernel.hpp:463
index_t gemm_count
Definition: grouped_convolution_backward_data_kernel.hpp:444
array< CGridDescMN, MaxGroupedGemmGroupsNum > c_grid_descs_m_n
Definition: grouped_convolution_backward_data_kernel.hpp:453
index_t original_n
Definition: grouped_convolution_backward_data_kernel.hpp:465
index_t grid_size_
Definition: grouped_convolution_backward_data_kernel.hpp:443
array< index_t, GroupedConvTraitsType_::NDimSpatial > input_right_pads
Definition: grouped_convolution_backward_data_kernel.hpp:438
array< BGridDescNK, MaxGroupedGemmGroupsNum > b_grid_descs_n_k
Definition: grouped_convolution_backward_data_kernel.hpp:452
index_t k_batch
Definition: grouped_convolution_backward_data_kernel.hpp:441
static constexpr auto I0
Definition: grouped_convolution_backward_data_kernel.hpp:38
static constexpr index_t MaxGroupedGemmGroupsNum
Definition: grouped_convolution_backward_data_kernel.hpp:421
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > in_g_n_c_wis_lengths
Definition: grouped_convolution_backward_data_kernel.hpp:431
static constexpr index_t NumDTensor
Definition: grouped_convolution_backward_data_kernel.hpp:36
index_t output_batch_stride
Definition: grouped_convolution_backward_data_kernel.hpp:467
index_t input_batch_stride
Definition: grouped_convolution_backward_data_kernel.hpp:466
array< AGridDescMK, MaxGroupedGemmGroupsNum > a_grid_descs_m_k
Definition: grouped_convolution_backward_data_kernel.hpp:451
remove_cvref_t< decltype(ABCGridDescs{}[number< 2 >{}])> CGridDescMN
Definition: grouped_convolution_backward_data_kernel.hpp:428
static constexpr index_t NonSpatialDims
Definition: grouped_convolution_backward_data_kernel.hpp:430
The Grouped Conv kernel host arguments.
Definition: grouped_convolution_utils.hpp:27
InPtr in_ptr
Definition: grouped_convolution_utils.hpp:46
OutPtr out_ptr
Definition: grouped_convolution_utils.hpp:49
WeiPtr wei_ptr
Definition: grouped_convolution_utils.hpp:47
index_t k_batch
Definition: grouped_convolution_utils.hpp:50
const std::vector< const void * > ds_ptr
Definition: grouped_convolution_utils.hpp:48
The Grouped Convolution Backward Data kernel template.
Definition: grouped_convolution_backward_data_kernel.hpp:513
static constexpr index_t NDimSpatial
Definition: grouped_convolution_backward_data_kernel.hpp:514
static constexpr CK_TILE_HOST auto BlockSize()
Definition: grouped_convolution_backward_data_kernel.hpp:607
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition: grouped_convolution_backward_data_kernel.hpp:518
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n, const index_t i_k=0)
Definition: grouped_convolution_backward_data_kernel.hpp:837
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views)
Definition: grouped_convolution_backward_data_kernel.hpp:797
GroupedConvBwdDataKernelArgs< GroupedConvTraitsType_, TilePartitioner > GroupedConvBwdDataKernelArgsSpecialized
Definition: grouped_convolution_backward_data_kernel.hpp:541
remove_cvref_t< typename GemmPipeline::ADataType > InDataType
Definition: grouped_convolution_backward_data_kernel.hpp:534
static constexpr index_t MaxGroupedGemmGroupsNum
Definition: grouped_convolution_backward_data_kernel.hpp:542
static constexpr auto I1
Definition: grouped_convolution_backward_data_kernel.hpp:549
CK_TILE_DEVICE void operator()(GroupedConvBwdDataKernelArgsSpecialized &kargs) const
Definition: grouped_convolution_backward_data_kernel.hpp:1002
static constexpr auto I3
Definition: grouped_convolution_backward_data_kernel.hpp:551
remove_cvref_t< typename GroupedConvTraitsType_::OutLayout > OutLayout
Definition: grouped_convolution_backward_data_kernel.hpp:526
static constexpr ConvolutionSpecialization ConvSpecialization
Definition: grouped_convolution_backward_data_kernel.hpp:515
static constexpr CK_TILE_HOST GroupedConvBwdDataKernelArgsSpecialized MakeKernelArgs(const GroupedConvBwdDataHostArgs &hostArgs)
Definition: grouped_convolution_backward_data_kernel.hpp:613
static constexpr index_t NumDTensor
Definition: grouped_convolution_backward_data_kernel.hpp:530
remove_cvref_t< typename GemmPipeline::BDataType > WeiDataType
Definition: grouped_convolution_backward_data_kernel.hpp:535
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: grouped_convolution_backward_data_kernel.hpp:519
remove_cvref_t< typename EpiloguePipeline::ODataType > OutDataType
Definition: grouped_convolution_backward_data_kernel.hpp:538
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: grouped_convolution_backward_data_kernel.hpp:517
remove_cvref_t< typename GroupedConvTraitsType_::WeiLayout > WeiLayout
Definition: grouped_convolution_backward_data_kernel.hpp:525
static constexpr index_t kBlockSize
Definition: grouped_convolution_backward_data_kernel.hpp:532
static CK_TILE_HOST const std::string GetTypeString()
Definition: grouped_convolution_backward_data_kernel.hpp:586
static CK_TILE_HOST bool IsSupportedArgument(const GroupedConvBwdDataKernelArgsSpecialized &kargs)
Definition: grouped_convolution_backward_data_kernel.hpp:624
remove_cvref_t< typename GemmPipeline::BLayout > GemmBLayout
Definition: grouped_convolution_backward_data_kernel.hpp:521
remove_cvref_t< typename GroupedConvTraitsType_::DsLayout > DsLayout
Definition: grouped_convolution_backward_data_kernel.hpp:527
static constexpr auto I2
Definition: grouped_convolution_backward_data_kernel.hpp:550
static CK_TILE_DEVICE auto MakeGemmTensorViews(const OutDataType *a_ptr, const InDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, WeiDataType *c_ptr, const GroupedConvBwdDataKernelArgsSpecialized &kargs, const index_t group_id)
Definition: grouped_convolution_backward_data_kernel.hpp:753
static CK_TILE_HOST auto GridSize(const GroupedConvBwdDataKernelArgsSpecialized &kargs)
Definition: grouped_convolution_backward_data_kernel.hpp:601
remove_cvref_t< typename GemmPipeline::ALayout > GemmALayout
Definition: grouped_convolution_backward_data_kernel.hpp:520
remove_cvref_t< typename EpiloguePipeline::DsLayout > GemmDsLayout
Definition: grouped_convolution_backward_data_kernel.hpp:529
CK_TILE_DEVICE index_t FindGroupId(const GroupedConvBwdDataKernelArgsSpecialized &kargs, index_t block_id) const
Definition: grouped_convolution_backward_data_kernel.hpp:977
static CK_TILE_DEVICE void RunGemm(const OutDataType *a_ptr, const InDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, WeiDataType *c_ptr, void *smem_ptr_0, const GroupedConvBwdDataKernelArgsSpecialized &kargs, const index_t block_idx_m, const index_t block_idx_n, const index_t group_id)
Runs single GEMM problem cooperatively by whole workgroup.
Definition: grouped_convolution_backward_data_kernel.hpp:890
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: grouped_convolution_backward_data_kernel.hpp:618
static CK_TILE_DEVICE void RunGemm2LDS(const OutDataType *a_ptr, const InDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, WeiDataType *c_ptr, void *__restrict__ smem_ptr_0, void *__restrict__ smem_ptr_1, const GroupedConvBwdDataKernelArgsSpecialized &kargs, const index_t block_idx_m, const index_t block_idx_n, const index_t group_id)
Runs single GEMM problem cooperatively by whole workgroup.
Definition: grouped_convolution_backward_data_kernel.hpp:941
static constexpr bool IsSplitKSupported
Definition: grouped_convolution_backward_data_kernel.hpp:546
remove_cvref_t< typename GroupedConvTraitsType_::InLayout > InLayout
Definition: grouped_convolution_backward_data_kernel.hpp:524
remove_cvref_t< typename GemmPipeline::CLayout > GemmCLayout
Definition: grouped_convolution_backward_data_kernel.hpp:522
remove_cvref_t< typename EpiloguePipeline::DsDataType > DsDataType
Definition: grouped_convolution_backward_data_kernel.hpp:536
static CK_TILE_HOST const std::string GetName()
Definition: grouped_convolution_backward_data_kernel.hpp:561
static constexpr auto I0
Definition: grouped_convolution_backward_data_kernel.hpp:548
static CK_TILE_DEVICE auto GetOffsetedTileIndex(index_t block_start, index_t M, index_t N) noexcept -> const tuple< index_t, index_t >
The function subtracts the block's start (offset) from 1D raw-indexes.
Definition: gemm_tile_partitioner.hpp:192
Definition: transform_conv_bwd_data_to_gemm.hpp:21
CK_TILE_HOST auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N([[maybe_unused]] const index_t GemmKBatch) const
Definition: transform_conv_bwd_data_to_gemm.hpp:659
constexpr CK_TILE_HOST_DEVICE auto & at(index_t i)
Definition: array.hpp:110
Definition: integral_constant.hpp:13
std::vector< ck_tile::long_index_t > input_spatial_lengths_
Definition: convolution_parameter.hpp:130
ck_tile::long_index_t K_
Definition: convolution_parameter.hpp:126
std::vector< ck_tile::long_index_t > output_spatial_lengths_
Definition: convolution_parameter.hpp:131
std::vector< ck_tile::long_index_t > input_right_pads_
Definition: convolution_parameter.hpp:137
ck_tile::long_index_t G_
Definition: convolution_parameter.hpp:124
std::vector< ck_tile::long_index_t > conv_filter_strides_
Definition: convolution_parameter.hpp:133
std::vector< ck_tile::long_index_t > filter_spatial_lengths_
Definition: convolution_parameter.hpp:129
ck_tile::long_index_t C_
Definition: convolution_parameter.hpp:127
ck_tile::long_index_t N_
Definition: convolution_parameter.hpp:125
std::vector< ck_tile::long_index_t > input_left_pads_
Definition: convolution_parameter.hpp:136
std::vector< ck_tile::long_index_t > conv_filter_dilations_
Definition: convolution_parameter.hpp:134
Definition: type_traits.hpp:115
Definition: sequence.hpp:49
#define CK_TILE_ENV(name)
Definition: env.hpp:145