Composable Kernel: include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp — doxygen source listing
grouped_convolution_backward_data_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <string>
8 
9 #include "ck_tile/core.hpp"
10 #include "ck_tile/ops/common.hpp"
11 #include "ck_tile/host/concat.hpp"
16 
17 namespace ck_tile {
18 
20 template <typename GroupedConvTraitsType_, typename TilePartitioner_>
22 {
24 
26  TransformConvBwdDataToGemm<GroupedConvTraitsType_::NDimSpatial,
27  GroupedConvTraitsType_::ConvSpecialization,
28  GroupedConvTraitsType_::VectorSizeA,
29  GroupedConvTraitsType_::VectorSizeB,
30  GroupedConvTraitsType_::VectorSizeC,
31  true>; // Split N enabled
32  static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
33 
34  static constexpr auto I0 = number<0>();
35  static constexpr auto I1 = number<1>();
36 
37  template <
38  typename InLay = typename GroupedConvTraitsType_::InLayout,
39  typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
40  typename OutLay = typename GroupedConvTraitsType_::OutLayout,
41  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NWGC> &&
42  std::is_same_v<WeiLay, tensor_layout::convolution::GKXC> &&
43  std::is_same_v<OutLay, tensor_layout::convolution::NWGK>,
44  bool>::type = false>
46  {
47  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
48  static_cast<index_t>(args.N_),
49  static_cast<index_t>(args.C_),
50  static_cast<index_t>(args.input_spatial_lengths_[0])};
51  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
52  static_cast<index_t>(args.K_),
53  static_cast<index_t>(args.C_),
54  static_cast<index_t>(args.filter_spatial_lengths_[0])};
55  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
56  static_cast<index_t>(args.N_),
57  static_cast<index_t>(args.K_),
58  static_cast<index_t>(args.output_spatial_lengths_[0])};
59 
60  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0])};
61  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0])};
62  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0])};
63  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0])};
64 
65  k_batch = args.k_batch;
66 
67  in_ptr = args.in_ptr;
68  wei_ptr = args.wei_ptr;
69  for(index_t d = 0; d < NumDTensor; d++)
70  {
71  ds_ptr[d] = args.ds_ptr[d];
72  }
73  out_ptr = args.out_ptr;
74 
75  const index_t X = wei_g_k_c_xs_lengths[3];
76  const index_t ConvStrideW = conv_filter_strides[0];
77  const index_t ConvDilationW = conv_filter_dilations[0];
78  const auto GcdStrideDilationW = gcd(ConvStrideW, ConvDilationW);
79  const auto XTilde = ConvStrideW / GcdStrideDilationW;
80 
81  for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
82  {
83  const auto XDotSlice = integer_divide_ceil(X - i_xtilde, XTilde);
84 
85  if(XDotSlice <= 0)
86  {
87  continue;
88  }
89 
91  {
92  gemm_count++;
93  // Avoid array segfault
94  continue;
95  }
96 
97  tildes = {i_xtilde};
98 
99  ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
106  tildes};
107 
108  auto grid_descs =
109  conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
110  GroupedConvTraitsType_::NDimSpatial>(1);
111 
112  a_grid_descs_m_k[gemm_count] = grid_descs.at(number<0>{});
113  b_grid_descs_n_k[gemm_count] = grid_descs.at(number<1>{});
114  c_grid_descs_m_n[gemm_count] = grid_descs.at(number<2>{});
115 
116  const index_t grid_size_grp =
117  TilePartitioner::GridSize(c_grid_descs_m_n[gemm_count].get_length(I0),
118  c_grid_descs_m_n[gemm_count].get_length(I1));
119 
121  block_ends[gemm_count] = grid_size_ + grid_size_grp;
122 
123  grid_size_ += grid_size_grp;
124 
125  // Get the actual split N from transformer
126  n_per_split = conv_to_gemm_transformer.GetN();
127  original_n = conv_to_gemm_transformer.GetOriginalN();
129 
130  ++gemm_count;
131  }
132  group_stride_a = args.K_; // A: Out NWGK
133  group_stride_b = args.K_ * args.C_ *
134  std::accumulate(args.filter_spatial_lengths_.begin(),
135  args.filter_spatial_lengths_.end(),
136  1,
137  std::multiplies<index_t>()); // B: Wei GKXC
138  group_stride_c = args.C_; // C: In NWGC
139 
140  input_batch_stride = args.C_ * args.G_ * args.input_spatial_lengths_[0];
141  output_batch_stride = args.K_ * args.G_ * args.output_spatial_lengths_[0];
142 
143  GemmBatch = args.G_;
144  }
145 
146  template <
147  typename InLay = typename GroupedConvTraitsType_::InLayout,
148  typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
149  typename OutLay = typename GroupedConvTraitsType_::OutLayout,
150  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NHWGC> &&
151  std::is_same_v<WeiLay, tensor_layout::convolution::GKYXC> &&
152  std::is_same_v<OutLay, tensor_layout::convolution::NHWGK>,
153  bool>::type = false>
155  {
156  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
157  static_cast<index_t>(args.N_),
158  static_cast<index_t>(args.C_),
159  static_cast<index_t>(args.input_spatial_lengths_[0]),
160  static_cast<index_t>(args.input_spatial_lengths_[1])};
161  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
162  static_cast<index_t>(args.K_),
163  static_cast<index_t>(args.C_),
164  static_cast<index_t>(args.filter_spatial_lengths_[0]),
165  static_cast<index_t>(args.filter_spatial_lengths_[1])};
166  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
167  static_cast<index_t>(args.N_),
168  static_cast<index_t>(args.K_),
169  static_cast<index_t>(args.output_spatial_lengths_[0]),
170  static_cast<index_t>(args.output_spatial_lengths_[1])};
171 
172  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
173  static_cast<index_t>(args.conv_filter_strides_[1])};
174  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
175  static_cast<index_t>(args.conv_filter_dilations_[1])};
176  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
177  static_cast<index_t>(args.input_left_pads_[1])};
178  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
179  static_cast<index_t>(args.input_right_pads_[1])};
180 
181  k_batch = args.k_batch;
182 
183  in_ptr = args.in_ptr;
184  wei_ptr = args.wei_ptr;
185  for(index_t d = 0; d < NumDTensor; d++)
186  {
187  ds_ptr[d] = args.ds_ptr[d];
188  }
189  out_ptr = args.out_ptr;
190 
191  const index_t Y = wei_g_k_c_xs_lengths[3];
192  const index_t X = wei_g_k_c_xs_lengths[4];
193  const index_t ConvStrideH = conv_filter_strides[0];
194  const index_t ConvStrideW = conv_filter_strides[1];
195  const index_t ConvDilationH = conv_filter_dilations[0];
196  const index_t ConvDilationW = conv_filter_dilations[1];
197  const auto GcdStrideDilationH = gcd(ConvStrideH, ConvDilationH);
198  const auto GcdStrideDilationW = gcd(ConvStrideW, ConvDilationW);
199  const auto YTilde = ConvStrideH / GcdStrideDilationH;
200  const auto XTilde = ConvStrideW / GcdStrideDilationW;
201 
202  for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
203  {
204  for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
205  {
206  const auto YDotSlice = integer_divide_ceil(Y - i_ytilde, YTilde);
207  const auto XDotSlice = integer_divide_ceil(X - i_xtilde, XTilde);
208 
209  if(XDotSlice * YDotSlice <= 0)
210  {
211  continue;
212  }
213 
215  {
216  gemm_count++;
217  // Avoid array segfault
218  continue;
219  }
220 
221  tildes = {i_ytilde, i_xtilde};
222 
223  ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
230  tildes};
231 
232  auto grid_descs = conv_to_gemm_transformer
233  .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
234  GroupedConvTraitsType_::NDimSpatial>(1);
235 
236  a_grid_descs_m_k[gemm_count] = grid_descs.at(number<0>{});
237  b_grid_descs_n_k[gemm_count] = grid_descs.at(number<1>{});
238  c_grid_descs_m_n[gemm_count] = grid_descs.at(number<2>{});
239 
240  const index_t grid_size_grp =
241  TilePartitioner::GridSize(c_grid_descs_m_n[gemm_count].get_length(I0),
242  c_grid_descs_m_n[gemm_count].get_length(I1));
243 
245  block_ends[gemm_count] = grid_size_ + grid_size_grp;
246 
247  grid_size_ += grid_size_grp;
248 
249  // Get the actual split N from transformer
250  n_per_split = conv_to_gemm_transformer.GetN();
251  original_n = conv_to_gemm_transformer.GetOriginalN();
253 
254  ++gemm_count;
255  }
256  }
257  group_stride_a = args.K_; // A: Out NWGK
258  group_stride_b = args.K_ * args.C_ *
259  std::accumulate(args.filter_spatial_lengths_.begin(),
260  args.filter_spatial_lengths_.end(),
261  1,
262  std::multiplies<index_t>()); // B: Wei GKXC
263  group_stride_c = args.C_; // C: In NWGC
264 
266  args.C_ * args.G_ * args.input_spatial_lengths_[0] * args.input_spatial_lengths_[1];
268  args.K_ * args.G_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
269 
270  GemmBatch = args.G_;
271  }
272 
273  template <
274  typename InLay = typename GroupedConvTraitsType_::InLayout,
275  typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
276  typename OutLay = typename GroupedConvTraitsType_::OutLayout,
277  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NDHWGC> &&
278  std::is_same_v<WeiLay, tensor_layout::convolution::GKZYXC> &&
279  std::is_same_v<OutLay, tensor_layout::convolution::NDHWGK>,
280  bool>::type = false>
282  {
283  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
284  static_cast<index_t>(args.N_),
285  static_cast<index_t>(args.C_),
286  static_cast<index_t>(args.input_spatial_lengths_[0]),
287  static_cast<index_t>(args.input_spatial_lengths_[1]),
288  static_cast<index_t>(args.input_spatial_lengths_[2])};
289  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
290  static_cast<index_t>(args.K_),
291  static_cast<index_t>(args.C_),
292  static_cast<index_t>(args.filter_spatial_lengths_[0]),
293  static_cast<index_t>(args.filter_spatial_lengths_[1]),
294  static_cast<index_t>(args.filter_spatial_lengths_[2])};
295  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
296  static_cast<index_t>(args.N_),
297  static_cast<index_t>(args.K_),
298  static_cast<index_t>(args.output_spatial_lengths_[0]),
299  static_cast<index_t>(args.output_spatial_lengths_[1]),
300  static_cast<index_t>(args.output_spatial_lengths_[2])};
301 
302  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
303  static_cast<index_t>(args.conv_filter_strides_[1]),
304  static_cast<index_t>(args.conv_filter_strides_[2])};
305  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
306  static_cast<index_t>(args.conv_filter_dilations_[1]),
307  static_cast<index_t>(args.conv_filter_dilations_[2])};
308  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
309  static_cast<index_t>(args.input_left_pads_[1]),
310  static_cast<index_t>(args.input_left_pads_[2])};
311  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
312  static_cast<index_t>(args.input_right_pads_[1]),
313  static_cast<index_t>(args.input_right_pads_[2])};
314 
315  k_batch = args.k_batch;
316 
317  in_ptr = args.in_ptr;
318  wei_ptr = args.wei_ptr;
319  for(index_t d = 0; d < NumDTensor; d++)
320  {
321  ds_ptr[d] = args.ds_ptr[d];
322  }
323  out_ptr = args.out_ptr;
324 
325  const index_t Z = wei_g_k_c_xs_lengths[3];
326  const index_t Y = wei_g_k_c_xs_lengths[4];
327  const index_t X = wei_g_k_c_xs_lengths[5];
328  const index_t ConvStrideD = conv_filter_strides[0];
329  const index_t ConvStrideH = conv_filter_strides[1];
330  const index_t ConvStrideW = conv_filter_strides[2];
331  const index_t ConvDilationD = conv_filter_dilations[0];
332  const index_t ConvDilationH = conv_filter_dilations[1];
333  const index_t ConvDilationW = conv_filter_dilations[2];
334  const auto GcdStrideDilationD = gcd(ConvStrideD, ConvDilationD);
335  const auto GcdStrideDilationH = gcd(ConvStrideH, ConvDilationH);
336  const auto GcdStrideDilationW = gcd(ConvStrideW, ConvDilationW);
337  const auto ZTilde = ConvStrideD / GcdStrideDilationD;
338  const auto YTilde = ConvStrideH / GcdStrideDilationH;
339  const auto XTilde = ConvStrideW / GcdStrideDilationW;
340 
341  for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde)
342  {
343  for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
344  {
345  for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
346  {
347  const auto ZDotSlice = integer_divide_ceil(Z - i_ztilde, ZTilde);
348  const auto YDotSlice = integer_divide_ceil(Y - i_ytilde, YTilde);
349  const auto XDotSlice = integer_divide_ceil(X - i_xtilde, XTilde);
350 
351  if(ZDotSlice * XDotSlice * YDotSlice <= 0)
352  {
353  continue;
354  }
355 
357  {
358  gemm_count++;
359  // Avoid array segfault
360  continue;
361  }
362 
363  tildes = {i_ztilde, i_ytilde, i_xtilde};
364 
365  ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
372  tildes};
373 
374  auto grid_descs = conv_to_gemm_transformer
375  .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
376  GroupedConvTraitsType_::NDimSpatial>(1);
377 
378  a_grid_descs_m_k[gemm_count] = grid_descs.at(number<0>{});
379  b_grid_descs_n_k[gemm_count] = grid_descs.at(number<1>{});
380  c_grid_descs_m_n[gemm_count] = grid_descs.at(number<2>{});
381 
382  const index_t grid_size_grp =
383  TilePartitioner::GridSize(c_grid_descs_m_n[gemm_count].get_length(I0),
384  c_grid_descs_m_n[gemm_count].get_length(I1));
385 
387  block_ends[gemm_count] = grid_size_ + grid_size_grp;
388 
389  grid_size_ += grid_size_grp;
390 
391  // Get the actual split N from transformer
392  n_per_split = conv_to_gemm_transformer.GetN();
393  original_n = conv_to_gemm_transformer.GetOriginalN();
395 
396  ++gemm_count;
397  }
398  }
399  }
400 
401  group_stride_a = args.K_; // A: Out NWGK
402  group_stride_b = args.K_ * args.C_ *
403  std::accumulate(args.filter_spatial_lengths_.begin(),
404  args.filter_spatial_lengths_.end(),
405  1,
406  std::multiplies<index_t>()); // B: Wei GKXC
407  group_stride_c = args.C_; // C: In NWGC
408 
409  input_batch_stride = args.C_ * args.G_ * args.input_spatial_lengths_[0] *
411  output_batch_stride = args.K_ * args.G_ * args.output_spatial_lengths_[0] *
413 
414  GemmBatch = args.G_; // C: In NWGC
415  }
416 
417  static constexpr index_t MaxGroupedGemmGroupsNum = 128;
418 
421 
425 
426  static constexpr index_t NonSpatialDims = 3;
430 
436 
441 
442  const void* out_ptr;
443  void* in_ptr;
444  std::array<const void*, NumDTensor> ds_ptr;
445  const void* wei_ptr;
446 
450 
453 
457 
458  // Split-N support fields - initialize to safe defaults
459  index_t n_splits = 1; // Number of batch splits (e.g., 2 for 128→64×2)
460  index_t n_per_split = 1; // Batches per split (N_ from transformer)
461  index_t original_n = 1; // Original batch size before splitting
462  index_t input_batch_stride = 0; // Stride to next batch in input tensor
463  index_t output_batch_stride = 0; // Stride to next batch in output tensor
464 };
465 
504 template <typename GroupedConvTraitsType_,
505  typename TilePartitioner_,
506  typename GemmPipeline_,
507  typename EpiloguePipeline_>
509 {
510  static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial_;
512  GroupedConvTraitsType_::ConvSpecialization;
519 
524 
526  static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
527 
528  static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
529 
533 
535 
538  static constexpr index_t MaxGroupedGemmGroupsNum =
540 
541  // TODO: Enable this
542  static constexpr bool IsSplitKSupported = false;
543 
544  static constexpr auto I0 = number<0>();
545  static constexpr auto I1 = number<1>();
546  static constexpr auto I2 = number<2>();
547  static constexpr auto I3 = number<3>();
548 
549  static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
550  "Not supported!");
551  static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>, "Not supported!");
552  static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
553  static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
554  "Not supported C GEMM layout!");
555 
556  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
557  {
558  // clang-format off
559  return concat('_', "grouped_convolution_backward_data",
560  gemm_prec_str<InDataType, WeiDataType>(),
561  "gemm",
562  GemmPipeline::GetName(),
563  "epilogue",
564  EpiloguePipeline::GetName());
565  // clang-format on
566  }
567 
569  {
570  // enable batched grouped gemm
571  return dim3(kargs.grid_size_, kargs.GemmBatch, kargs.n_splits * kargs.k_batch);
572  }
573 
574  CK_TILE_HOST static constexpr auto BlockSize()
575  {
576  return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
577  }
578 
581  {
583  }
584 
586  {
587  return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
588  }
589 
590  CK_TILE_HOST static bool
592  {
593  if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
596  {
597  if(kargs.k_batch != 1)
598  {
599  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
600  {
601  CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
602  }
603  return false;
604  }
605  }
606 
608  {
609  return false;
610  }
611 
612  const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
613  const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];
614 
615  // check ConvSpecialization
617  {
618  // check if it's 1x1, stride=1 conv
619  for(index_t i = 0; i < NDimSpatial; ++i)
620  {
621  const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
622  const index_t ConvStride = kargs.conv_filter_strides[i];
623  const index_t LeftPad = kargs.input_left_pads[i];
624  const index_t RightPad = kargs.input_right_pads[i];
625 
626  if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
627  {
628  return false;
629  }
630  }
631  }
633  {
634  // check if it's 1x1 conv
635  for(index_t i = 0; i < NDimSpatial; ++i)
636  {
637  const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
638  const index_t LeftPad = kargs.input_left_pads[i];
639  const index_t RightPad = kargs.input_right_pads[i];
640 
641  if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0))
642  {
643  return false;
644  }
645  }
646  }
648  {
649  if(ConvC != 1)
650  {
651  return false;
652  }
653  for(index_t i = 0; i < NDimSpatial; ++i)
654  {
655  const index_t filter_spatial_dim = kargs.wei_g_k_c_xs_lengths[i + I3];
656 
657  if(filter_spatial_dim != I3)
658  {
659  return false;
660  }
661  }
662  }
663 
664  namespace ctc = tensor_layout::convolution;
665 
666  if constexpr(std::is_same_v<InLayout, ctc::NWGC> || std::is_same_v<InLayout, ctc::NHWGC> ||
667  std::is_same_v<InLayout, ctc::NDHWGC>)
668  {
669  // Check access per C
670  if(ConvC % GroupedConvTraitsType_::VectorSizeB != 0)
671  {
672  CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!");
673  return false;
674  }
675  }
676  else
677  {
678  CK_TILE_ERROR("Not supported input layout!");
679  return false;
680  }
681 
682  // FIXME: layout
683  if constexpr(std::is_same_v<WeiLayout, ctc::GKXC> ||
684  std::is_same_v<WeiLayout, ctc::GKYXC> ||
685  std::is_same_v<WeiLayout, ctc::GKZYXC>)
686  {
687  if(ConvC % GroupedConvTraitsType_::VectorSizeC != 0)
688  {
689  CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!");
690  return false;
691  }
692  }
693  else
694  {
695  CK_TILE_ERROR("Not supported weight layout!");
696  return false;
697  }
698 
699  if constexpr(std::is_same_v<OutLayout, ctc::NWGK> ||
700  std::is_same_v<OutLayout, ctc::NHWGK> ||
701  std::is_same_v<OutLayout, ctc::NDHWGK>)
702  {
703  if(ConvK % GroupedConvTraitsType_::VectorSizeA != 0)
704  {
705  CK_TILE_ERROR("Conv K is not a multiple of vector store size for output image!");
706  return false;
707  }
708  }
709  else
710  {
711  CK_TILE_ERROR("Not supported output layout!");
712  return false;
713  }
714 
715  return true;
716  }
717 
718  template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
719  CK_TILE_DEVICE static auto
721  const InDataType* b_ptr,
722  const std::array<const void*, NumDTensor>& ds_ptr,
723  WeiDataType* c_ptr,
725  const index_t group_id)
726  {
727  static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
728  static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!");
729  const auto& a_tensor_view = [&]() {
730  return make_tensor_view<address_space_enum::global>(
731  a_ptr,
732  kargs.a_grid_descs_m_k[group_id]); // A: out
733  }();
734 
735  const auto& b_tensor_view = [&]() {
736  return make_tensor_view<address_space_enum::global>(
737  b_ptr,
738  kargs.b_grid_descs_n_k[group_id]); // B: weight
739  }();
740 
741  const auto& c_tensor_view = [&]() {
742  return make_tensor_view<address_space_enum::global>(c_ptr,
743  kargs.c_grid_descs_m_n[group_id]);
744  }();
745 
746  const auto& ds_tensor_view = generate_tuple(
747  [&](auto i) {
748  static_assert(std::is_same_v<std::tuple_element_t<i, DsLayout>, OutLayout>,
749  "Not supported!");
750  static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
751  "Not supported!");
752  static_assert(std::is_same_v<std::tuple_element_t<i, DsDataType>, OutDataType>,
753  "Not supported!");
754 
755  return make_tensor_view<address_space_enum::global>(
756  static_cast<OutDataType*>(ds_ptr[i]), kargs.c_grid_descs_m_n[group_id]);
757  },
759 
760  return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view);
761  }
762 
763  template <typename TensorView>
764  CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
765  {
766  const auto& a_pad_view = [&]() {
767  const auto& a_tensor_view = views.at(I0);
768  return pad_tensor_view(a_tensor_view,
772  }();
773 
774  const auto& b_pad_view = [&]() {
775  const auto& b_tensor_view = views.at(I1);
776  return pad_tensor_view(b_tensor_view,
780  }();
781 
782  const auto& ds_tensor_view = views.at(I2);
783  const auto& ds_pad_view = generate_tuple(
784  [&](auto i) {
785  return pad_tensor_view(ds_tensor_view[i],
789  },
791 
792  const auto& c_pad_view = [&]() {
793  const auto& c_tensor_view = views.at(I3);
794  return pad_tensor_view(c_tensor_view,
798  }();
799 
800  return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view);
801  }
802 
803  template <typename PadView>
804  CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
805  const index_t i_m,
806  const index_t i_n,
807  const index_t i_k = 0)
808  {
809  const auto& a_pad_view = views.at(I0);
810  const auto& b_pad_view = views.at(I1);
811  const auto& ds_pad_view = views.at(I2);
812  const auto& c_pad_view = views.at(I3);
813 
814  const auto& a_block_window = [&]() {
815  return make_tile_window(a_pad_view,
818  {i_m, i_k});
819  }();
820 
821  const auto& b_block_window = [&]() {
822  return make_tile_window(b_pad_view,
825  {i_k, i_n});
826  }();
827 
828  const auto ds_block_window = generate_tuple(
829  [&](auto i) {
830  return make_tile_window(ds_pad_view[i],
833  {i_m, i_n});
834  },
836 
837  auto c_block_window = make_tile_window(
838  c_pad_view,
840  {i_m, i_n});
841 
842  return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window);
843  }
844 
857  CK_TILE_DEVICE static void RunGemm(const OutDataType* a_ptr,
858  const InDataType* b_ptr,
859  const std::array<const void*, NumDTensor>& ds_ptr,
860  WeiDataType* c_ptr,
861  void* smem_ptr_0,
863  const index_t block_idx_m,
864  const index_t block_idx_n,
865  const index_t group_id)
866  {
867  // Create Gemm tensor views, pad views and tile windows
868  const auto& gemm_tensor_views_tuple =
869  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
870  a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id);
871 
872  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
873  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
874 
875  const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(
876  gemm_pad_views.at(I0).get_tensor_descriptor().get_length(I1)));
877 
878  // Run GEMM cooperatively by whole workgroup.
879  const auto& a_block_window = gemm_tile_windows.at(I0);
880  const auto& b_block_window = gemm_tile_windows.at(I1);
881  const auto& d_block_window = gemm_tile_windows.at(I2);
882 
883  const auto& c_block_tile = GemmPipeline{}.template operator()(
884  a_block_window, b_block_window, num_loop, smem_ptr_0);
885 
886  // Run Epilogue Pipeline
887  auto& c_block_window = gemm_tile_windows.at(I3);
888 
889  EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
890  c_block_window, c_block_tile, d_block_window, smem_ptr_0);
891  }
892 
908  CK_TILE_DEVICE static void RunGemm2LDS(const OutDataType* a_ptr,
909  const InDataType* b_ptr,
910  const std::array<const void*, NumDTensor>& ds_ptr,
911  WeiDataType* c_ptr,
912  void* __restrict__ smem_ptr_0,
913  void* __restrict__ smem_ptr_1,
915  const index_t block_idx_m,
916  const index_t block_idx_n,
917  const index_t group_id)
918  {
919  // Create Gemm tensor views, pad views and tile windows
920  const auto& gemm_tensor_views_tuple =
921  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
922  a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id);
923  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
924  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
925 
926  const index_t num_loop = amd_wave_read_first_lane(
927  TilePartitioner::GetLoopNum(gemm_tile_windows.at(I0).get_length(I1)));
928 
929  // Run GEMM cooperatively by whole workgroup.
930  const auto& a_block_window = gemm_tile_windows.at(I0);
931  const auto& b_block_window = gemm_tile_windows.at(I1);
932  const auto& d_block_window = gemm_tile_windows.at(I2);
933 
934  const auto& c_block_tile = GemmPipeline{}.template operator()(
935  a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
936 
937  // Run Epilogue Pipeline
938  auto& c_block_window = gemm_tile_windows.at(I3);
939 
940  EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
941  c_block_window, c_block_tile, d_block_window, smem_ptr_0);
942  }
943 
// NOTE(review): maps a linear workgroup index (blockIdx.x) to the grouped-GEMM
// group whose half-open range [block_starts[g], block_ends[g]) contains it,
// via a binary search over kargs.gemm_count groups. The opening line of the
// signature is missing from this doxygen extraction; only the trailing
// parameter (block_id) and the const qualifier are visible.
945  index_t block_id) const
946  {
947  index_t left = 0;
948  index_t right = kargs.gemm_count;
// Probe at the midpoint of [left, right].
949  index_t group_id = index_t((left + right) >> 1);
950 
// Iterate until block_id lands inside the current group's range (or the
// bounds cross). The ranges are built back-to-back by the constructor
// (block_ends[g] == next block_starts), so a valid block_id is found.
951  while((!(block_id >= kargs.block_starts[group_id] &&
952  block_id < kargs.block_ends[group_id])) &&
953  left <= right)
954  {
955  if(block_id < kargs.block_starts[group_id])
956  {
957  right = group_id;
958  }
959  else
960  {
// NOTE(review): a textbook binary search would use group_id + 1 here;
// with `left = group_id`, progress/termination relies on the ranges
// being contiguous and exhaustive. Worth confirming behavior for
// skipped groups (the constructors also bump gemm_count for groups
// past MaxGroupedGemmGroupsNum without recording start/end ranges).
961  left = group_id;
962  }
963  group_id = index_t((left + right) >> 1);
964  }
965 
966  return group_id;
967  }
968 
970  {
971  const auto blockIdX = amd_wave_read_first_lane(blockIdx.x);
972  const index_t group_id = FindGroupId(kargs, blockIdX);
973 
975  kargs.block_starts[group_id],
976  kargs.c_grid_descs_m_n[group_id].get_length(I0),
977  kargs.c_grid_descs_m_n[group_id].get_length(I1));
978 
979  const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
980  const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
981 
982  const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);
983  const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
984  const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
985  const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);
986 
987  const auto blockIdZ = amd_wave_read_first_lane(blockIdx.z);
988 
989  // SplitN
990  const index_t split_n_idx = __builtin_amdgcn_readfirstlane(blockIdZ / kargs.k_batch);
991  const index_t split_n_offset =
992  __builtin_amdgcn_readfirstlane(split_n_idx * kargs.n_per_split);
993 
994  const long_index_t output_batch_offset =
995  static_cast<long_index_t>(split_n_offset) *
996  static_cast<long_index_t>(kargs.output_batch_stride);
997  const long_index_t input_batch_offset = static_cast<long_index_t>(split_n_offset) *
998  static_cast<long_index_t>(kargs.input_batch_stride);
999 
1000  // SplitK
1001  // TODO: Implement SplitK support
1002  // const index_t split_k_idx =
1003  // __builtin_amdgcn_readfirstlane(blockIdZ - split_n_idx * kargs.k_batch);
1004 
1005  // options
1006  // conv_bwd_data = Out * Weight = In
1007  const OutDataType* a_ptr =
1008  static_cast<const OutDataType*>(kargs.out_ptr) + group_offset_a + output_batch_offset;
1009  const WeiDataType* b_ptr = static_cast<const WeiDataType*>(kargs.wei_ptr) + group_offset_b;
1010  InDataType* c_ptr =
1011  static_cast<InDataType*>(kargs.in_ptr) + group_offset_c + input_batch_offset;
1012 
1013  // allocate LDS
1014  __shared__ char smem_ptr_0[GetSmemSize()];
1015 
1016  if constexpr(GemmPipeline::DoubleSmemBuffer == true)
1017  {
1018  __shared__ char smem_ptr_1[GetSmemSize()];
1019  if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
1020  GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
1022  {
1023  RunGemm2LDS(a_ptr,
1024  b_ptr,
1025  kargs.ds_ptr,
1026  c_ptr,
1027  smem_ptr_0,
1028  smem_ptr_1,
1029  kargs,
1030  i_m,
1031  i_n,
1032  group_id);
1033  }
1034  }
1035  else
1036  {
1037  if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
1038  GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
1040  {
1041  RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, i_m, i_n, group_id);
1042  }
1043  }
1044  }
1045 };
1046 
1047 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:156
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
void CK_TILE_ERROR(Args &&... args) noexcept
Definition: env.hpp:12
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:35
ConvolutionSpecialization
Definition: convolution_specialization.hpp:11
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE index_t gcd(index_t x, index_t y)
Definition: math.hpp:268
int64_t long_index_t
Definition: integer.hpp:11
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
constexpr bool is_same_v
Definition: type.hpp:283
__device__ X atomic_add(X *p_dst, const X &x)
The Grouped Convolution kernel device arguments.
Definition: grouped_convolution_backward_data_kernel.hpp:22
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > wei_g_k_c_xs_lengths
Definition: grouped_convolution_backward_data_kernel.hpp:428
static constexpr auto I1
Definition: grouped_convolution_backward_data_kernel.hpp:35
CK_TILE_HOST GroupedConvBwdDataKernelArgs(const GroupedConvBwdDataHostArgs &args)
Definition: grouped_convolution_backward_data_kernel.hpp:45
array< index_t, GroupedConvTraitsType_::NDimSpatial > conv_filter_dilations
Definition: grouped_convolution_backward_data_kernel.hpp:432
std::array< const void *, NumDTensor > ds_ptr
Definition: grouped_convolution_backward_data_kernel.hpp:444
array< index_t, GroupedConvTraitsType_::NDimSpatial > conv_filter_strides
Definition: grouped_convolution_backward_data_kernel.hpp:431
array< index_t, MaxGroupedGemmGroupsNum > block_starts
Definition: grouped_convolution_backward_data_kernel.hpp:451
array< index_t, GroupedConvTraitsType_::NDimSpatial > input_left_pads
Definition: grouped_convolution_backward_data_kernel.hpp:433
long_index_t group_stride_b
Definition: grouped_convolution_backward_data_kernel.hpp:455
long_index_t group_stride_c
Definition: grouped_convolution_backward_data_kernel.hpp:456
array< index_t, MaxGroupedGemmGroupsNum > block_ends
Definition: grouped_convolution_backward_data_kernel.hpp:452
const void * out_ptr
Definition: grouped_convolution_backward_data_kernel.hpp:442
remove_cvref_t< decltype(ConvToGemmTransformer{}.MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(1))> ABCGridDescs
Definition: grouped_convolution_backward_data_kernel.hpp:420
remove_cvref_t< decltype(ABCGridDescs{}[number< 1 >{}])> BGridDescNK
Definition: grouped_convolution_backward_data_kernel.hpp:423
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: grouped_convolution_backward_data_kernel.hpp:23
array< index_t, GroupedConvTraitsType_::NDimSpatial > tildes
Definition: grouped_convolution_backward_data_kernel.hpp:435
remove_cvref_t< decltype(ABCGridDescs{}[number< 0 >{}])> AGridDescMK
Definition: grouped_convolution_backward_data_kernel.hpp:422
const void * wei_ptr
Definition: grouped_convolution_backward_data_kernel.hpp:445
index_t n_per_split
Definition: grouped_convolution_backward_data_kernel.hpp:460
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > out_g_n_k_wos_lengths
Definition: grouped_convolution_backward_data_kernel.hpp:429
long_index_t group_stride_a
Definition: grouped_convolution_backward_data_kernel.hpp:454
index_t GemmBatch
Definition: grouped_convolution_backward_data_kernel.hpp:438
void * in_ptr
Definition: grouped_convolution_backward_data_kernel.hpp:443
index_t n_splits
Definition: grouped_convolution_backward_data_kernel.hpp:459
index_t gemm_count
Definition: grouped_convolution_backward_data_kernel.hpp:440
array< CGridDescMN, MaxGroupedGemmGroupsNum > c_grid_descs_m_n
Definition: grouped_convolution_backward_data_kernel.hpp:449
index_t original_n
Definition: grouped_convolution_backward_data_kernel.hpp:461
index_t grid_size_
Definition: grouped_convolution_backward_data_kernel.hpp:439
array< index_t, GroupedConvTraitsType_::NDimSpatial > input_right_pads
Definition: grouped_convolution_backward_data_kernel.hpp:434
array< BGridDescNK, MaxGroupedGemmGroupsNum > b_grid_descs_n_k
Definition: grouped_convolution_backward_data_kernel.hpp:448
index_t k_batch
Definition: grouped_convolution_backward_data_kernel.hpp:437
static constexpr auto I0
Definition: grouped_convolution_backward_data_kernel.hpp:34
static constexpr index_t MaxGroupedGemmGroupsNum
Definition: grouped_convolution_backward_data_kernel.hpp:417
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > in_g_n_c_wis_lengths
Definition: grouped_convolution_backward_data_kernel.hpp:427
static constexpr index_t NumDTensor
Definition: grouped_convolution_backward_data_kernel.hpp:32
index_t output_batch_stride
Definition: grouped_convolution_backward_data_kernel.hpp:463
index_t input_batch_stride
Definition: grouped_convolution_backward_data_kernel.hpp:462
array< AGridDescMK, MaxGroupedGemmGroupsNum > a_grid_descs_m_k
Definition: grouped_convolution_backward_data_kernel.hpp:447
remove_cvref_t< decltype(ABCGridDescs{}[number< 2 >{}])> CGridDescMN
Definition: grouped_convolution_backward_data_kernel.hpp:424
static constexpr index_t NonSpatialDims
Definition: grouped_convolution_backward_data_kernel.hpp:426
The Grouped Conv kernel host arguments.
Definition: grouped_convolution_utils.hpp:20
InPtr in_ptr
Definition: grouped_convolution_utils.hpp:39
OutPtr out_ptr
Definition: grouped_convolution_utils.hpp:42
WeiPtr wei_ptr
Definition: grouped_convolution_utils.hpp:40
index_t k_batch
Definition: grouped_convolution_utils.hpp:43
const std::vector< const void * > ds_ptr
Definition: grouped_convolution_utils.hpp:41
The Grouped Convolution Backward Data kernel template.
Definition: grouped_convolution_backward_data_kernel.hpp:509
static constexpr index_t NDimSpatial
Definition: grouped_convolution_backward_data_kernel.hpp:510
static constexpr CK_TILE_HOST auto BlockSize()
Definition: grouped_convolution_backward_data_kernel.hpp:574
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition: grouped_convolution_backward_data_kernel.hpp:514
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n, const index_t i_k=0)
Definition: grouped_convolution_backward_data_kernel.hpp:804
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views)
Definition: grouped_convolution_backward_data_kernel.hpp:764
GroupedConvBwdDataKernelArgs< GroupedConvTraitsType_, TilePartitioner > GroupedConvBwdDataKernelArgsSpecialized
Definition: grouped_convolution_backward_data_kernel.hpp:537
remove_cvref_t< typename GemmPipeline::ADataType > InDataType
Definition: grouped_convolution_backward_data_kernel.hpp:530
static constexpr index_t MaxGroupedGemmGroupsNum
Definition: grouped_convolution_backward_data_kernel.hpp:538
static constexpr auto I1
Definition: grouped_convolution_backward_data_kernel.hpp:545
static constexpr auto I3
Definition: grouped_convolution_backward_data_kernel.hpp:547
remove_cvref_t< typename GroupedConvTraitsType_::OutLayout > OutLayout
Definition: grouped_convolution_backward_data_kernel.hpp:522
static constexpr ConvolutionSpecialization ConvSpecialization
Definition: grouped_convolution_backward_data_kernel.hpp:511
static constexpr CK_TILE_HOST GroupedConvBwdDataKernelArgsSpecialized MakeKernelArgs(const GroupedConvBwdDataHostArgs &hostArgs)
Definition: grouped_convolution_backward_data_kernel.hpp:580
static constexpr index_t NumDTensor
Definition: grouped_convolution_backward_data_kernel.hpp:526
remove_cvref_t< typename GemmPipeline::BDataType > WeiDataType
Definition: grouped_convolution_backward_data_kernel.hpp:531
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: grouped_convolution_backward_data_kernel.hpp:515
remove_cvref_t< typename EpiloguePipeline::ODataType > OutDataType
Definition: grouped_convolution_backward_data_kernel.hpp:534
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: grouped_convolution_backward_data_kernel.hpp:513
remove_cvref_t< typename GroupedConvTraitsType_::WeiLayout > WeiLayout
Definition: grouped_convolution_backward_data_kernel.hpp:521
static constexpr index_t kBlockSize
Definition: grouped_convolution_backward_data_kernel.hpp:528
static CK_TILE_HOST bool IsSupportedArgument(const GroupedConvBwdDataKernelArgsSpecialized &kargs)
Definition: grouped_convolution_backward_data_kernel.hpp:591
remove_cvref_t< typename GemmPipeline::BLayout > GemmBLayout
Definition: grouped_convolution_backward_data_kernel.hpp:517
remove_cvref_t< typename GroupedConvTraitsType_::DsLayout > DsLayout
Definition: grouped_convolution_backward_data_kernel.hpp:523
static constexpr auto I2
Definition: grouped_convolution_backward_data_kernel.hpp:546
static CK_TILE_DEVICE auto MakeGemmTensorViews(const OutDataType *a_ptr, const InDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, WeiDataType *c_ptr, const GroupedConvBwdDataKernelArgsSpecialized &kargs, const index_t group_id)
Definition: grouped_convolution_backward_data_kernel.hpp:720
static CK_TILE_HOST auto GridSize(const GroupedConvBwdDataKernelArgsSpecialized &kargs)
Definition: grouped_convolution_backward_data_kernel.hpp:568
remove_cvref_t< typename GemmPipeline::ALayout > GemmALayout
Definition: grouped_convolution_backward_data_kernel.hpp:516
remove_cvref_t< typename EpiloguePipeline::DsLayout > GemmDsLayout
Definition: grouped_convolution_backward_data_kernel.hpp:525
CK_TILE_DEVICE index_t FindGroupId(const GroupedConvBwdDataKernelArgsSpecialized &kargs, index_t block_id) const
Definition: grouped_convolution_backward_data_kernel.hpp:944
static CK_TILE_DEVICE void RunGemm(const OutDataType *a_ptr, const InDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, WeiDataType *c_ptr, void *smem_ptr_0, const GroupedConvBwdDataKernelArgsSpecialized &kargs, const index_t block_idx_m, const index_t block_idx_n, const index_t group_id)
Runs single GEMM problem cooperatively by whole workgroup.
Definition: grouped_convolution_backward_data_kernel.hpp:857
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: grouped_convolution_backward_data_kernel.hpp:585
static CK_TILE_DEVICE void RunGemm2LDS(const OutDataType *a_ptr, const InDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, WeiDataType *c_ptr, void *__restrict__ smem_ptr_0, void *__restrict__ smem_ptr_1, const GroupedConvBwdDataKernelArgsSpecialized &kargs, const index_t block_idx_m, const index_t block_idx_n, const index_t group_id)
Runs single GEMM problem cooperatively by whole workgroup.
Definition: grouped_convolution_backward_data_kernel.hpp:908
CK_TILE_DEVICE void operator()(GroupedConvBwdDataKernelArgsSpecialized kargs) const
Definition: grouped_convolution_backward_data_kernel.hpp:969
static constexpr bool IsSplitKSupported
Definition: grouped_convolution_backward_data_kernel.hpp:542
remove_cvref_t< typename GroupedConvTraitsType_::InLayout > InLayout
Definition: grouped_convolution_backward_data_kernel.hpp:520
remove_cvref_t< typename GemmPipeline::CLayout > GemmCLayout
Definition: grouped_convolution_backward_data_kernel.hpp:518
remove_cvref_t< typename EpiloguePipeline::DsDataType > DsDataType
Definition: grouped_convolution_backward_data_kernel.hpp:532
static CK_TILE_HOST const std::string GetName()
Definition: grouped_convolution_backward_data_kernel.hpp:556
static constexpr auto I0
Definition: grouped_convolution_backward_data_kernel.hpp:544
static CK_TILE_DEVICE auto GetOffsetedTileIndex(index_t block_start, index_t M, index_t N) noexcept -> const tuple< index_t, index_t >
The function subtracts the block's start (offset) from 1D raw-indexes.
Definition: gemm_tile_partitioner.hpp:192
Definition: transform_conv_bwd_data_to_gemm.hpp:22
CK_TILE_HOST auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N([[maybe_unused]] const index_t GemmKBatch) const
Definition: transform_conv_bwd_data_to_gemm.hpp:598
constexpr CK_TILE_HOST_DEVICE auto & at(index_t i)
Definition: array.hpp:110
Definition: integral_constant.hpp:13
std::vector< ck_tile::long_index_t > input_spatial_lengths_
Definition: convolution_parameter.hpp:130
ck_tile::long_index_t K_
Definition: convolution_parameter.hpp:126
std::vector< ck_tile::long_index_t > output_spatial_lengths_
Definition: convolution_parameter.hpp:131
std::vector< ck_tile::long_index_t > input_right_pads_
Definition: convolution_parameter.hpp:137
ck_tile::long_index_t G_
Definition: convolution_parameter.hpp:124
std::vector< ck_tile::long_index_t > conv_filter_strides_
Definition: convolution_parameter.hpp:133
std::vector< ck_tile::long_index_t > filter_spatial_lengths_
Definition: convolution_parameter.hpp:129
ck_tile::long_index_t C_
Definition: convolution_parameter.hpp:127
ck_tile::long_index_t N_
Definition: convolution_parameter.hpp:125
std::vector< ck_tile::long_index_t > input_left_pads_
Definition: convolution_parameter.hpp:136
std::vector< ck_tile::long_index_t > conv_filter_dilations_
Definition: convolution_parameter.hpp:134
Definition: type_traits.hpp:115
Definition: sequence.hpp:49
#define CK_TILE_ENV(name)
Definition: env.hpp:145