Composable Kernel — Doxygen source listing
File: include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp
grouped_convolution_backward_data_kernel.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <string>
8 
9 #include "ck_tile/core.hpp"
10 #include "ck_tile/ops/common.hpp"
11 #include "ck_tile/host/concat.hpp"
16 
17 #ifdef CK_EXPERIMENTAL_BUILDER
18 #include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_data.hpp"
19 #endif
20 
21 namespace ck_tile {
22 
24 template <typename GroupedConvTraitsType_, typename TilePartitioner_>
26 {
28 
30  TransformConvBwdDataToGemm<GroupedConvTraitsType_::NDimSpatial,
31  GroupedConvTraitsType_::ConvSpecialization,
32  GroupedConvTraitsType_::VectorSizeA,
33  GroupedConvTraitsType_::VectorSizeB,
34  GroupedConvTraitsType_::VectorSizeC,
35  true>; // Split N enabled
36  static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
37 
38  static constexpr auto I0 = number<0>();
39  static constexpr auto I1 = number<1>();
40 
// Initializes the kernel arguments for the 1D layout case (In: NWGC, Wei: GKXC, Out: NWGK).
// Packs the conv problem sizes into transformer order, then decomposes the backward-data
// problem into up to XTilde sub-GEMMs (one per non-empty filter slice) and records per-GEMM
// grid descriptors and block-id ranges.
// NOTE(review): the function signature line (Doxygen line 49) is elided in this listing —
// presumably a constructor/Init of the kernel-arguments struct; confirm against the original.
41 template <
42 typename InLay = typename GroupedConvTraitsType_::InLayout,
43 typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
44 typename OutLay = typename GroupedConvTraitsType_::OutLayout,
45 typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NWGC> &&
46 std::is_same_v<WeiLay, tensor_layout::convolution::GKXC> &&
47 std::is_same_v<OutLay, tensor_layout::convolution::NWGK>,
48 bool>::type = false>
50 {
// Repack problem sizes into the G / N / C|K / spatial order the GEMM transformer expects.
51 in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
52 static_cast<index_t>(args.N_),
53 static_cast<index_t>(args.C_),
54 static_cast<index_t>(args.input_spatial_lengths_[0])};
55 wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
56 static_cast<index_t>(args.K_),
57 static_cast<index_t>(args.C_),
58 static_cast<index_t>(args.filter_spatial_lengths_[0])};
59 out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
60 static_cast<index_t>(args.N_),
61 static_cast<index_t>(args.K_),
62 static_cast<index_t>(args.output_spatial_lengths_[0])};
63 
64 conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0])};
65 conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0])};
66 input_left_pads = {static_cast<index_t>(args.input_left_pads_[0])};
67 input_right_pads = {static_cast<index_t>(args.input_right_pads_[0])};
68 
69 k_batch = args.k_batch;
70 
// Copy the raw tensor pointers (in/wei/out plus the NumDTensor auxiliary D tensors).
71 in_ptr = args.in_ptr;
72 wei_ptr = args.wei_ptr;
73 for(index_t d = 0; d < NumDTensor; d++)
74 {
75 ds_ptr[d] = args.ds_ptr[d];
76 }
77 out_ptr = args.out_ptr;
78 
// Backward-data decomposition: gcd(stride, dilation) splits the filter width X into
// XTilde interleaved sub-filters; each non-empty slice becomes its own GEMM.
79 const index_t X = wei_g_k_c_xs_lengths[3];
80 const index_t ConvStrideW = conv_filter_strides[0];
81 const index_t ConvDilationW = conv_filter_dilations[0];
82 const auto GcdStrideDilationW = gcd(ConvStrideW, ConvDilationW);
83 const auto XTilde = ConvStrideW / GcdStrideDilationW;
84 
85 for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
86 {
87 const auto XDotSlice = integer_divide_ceil(X - i_xtilde, XTilde);
88 
// Skip empty filter slices — they contribute no GEMM work.
89 if(XDotSlice <= 0)
90 {
91 continue;
92 }
93 
// NOTE(review): the condition on Doxygen line 94 is elided in this listing —
// presumably a cap at MaxGroupedGemmGroupsNum to avoid overrunning the
// fixed-size descriptor arrays below; confirm against the original.
95 {
96 gemm_count++;
97 // Avoid array segfault
98 continue;
99 }
100 
101 tildes = {i_xtilde};
102 
// NOTE(review): transformer constructor arguments on Doxygen lines 104-109
// (weights/output lengths, strides, dilations, pads) are elided in this listing.
103 ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
110 tildes};
111 
112 auto grid_descs =
113 conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
114 GroupedConvTraitsType_::NDimSpatial>(1);
115 
// Store per-sub-GEMM A/B/C grid descriptors at this group's slot.
116 a_grid_descs_m_k[gemm_count] = grid_descs.at(number<0>{});
117 b_grid_descs_n_k[gemm_count] = grid_descs.at(number<1>{});
118 c_grid_descs_m_n[gemm_count] = grid_descs.at(number<2>{});
119 
// Accumulate grid sizes; [block_starts, block_ends) gives each sub-GEMM its
// contiguous block-id range within the fused launch.
120 const index_t grid_size_grp =
121 TilePartitioner::GridSize(c_grid_descs_m_n[gemm_count].get_length(I0),
122 c_grid_descs_m_n[gemm_count].get_length(I1));
123 
125 block_ends[gemm_count] = grid_size_ + grid_size_grp;
126 
127 grid_size_ += grid_size_grp;
128 
129 // Get the actual split N from transformer
130 n_per_split = conv_to_gemm_transformer.GetN();
131 original_n = conv_to_gemm_transformer.GetOriginalN();
133 
134 ++gemm_count;
135 }
// Per-conv-group (blockIdx.y) pointer strides for the innermost G dimension of each layout.
136 group_stride_a = args.K_; // A: Out NWGK
137 group_stride_b = args.K_ * args.C_ *
138 std::accumulate(args.filter_spatial_lengths_.begin(),
139 args.filter_spatial_lengths_.end(),
140 1,
141 std::multiplies<index_t>()); // B: Wei GKXC
142 group_stride_c = args.C_; // C: In NWGC
143 
// Strides to step one batch (N) in the input/output tensors, used by split-N offsets.
144 input_batch_stride = args.C_ * args.G_ * args.input_spatial_lengths_[0];
145 output_batch_stride = args.K_ * args.G_ * args.output_spatial_lengths_[0];
146 
147 GemmBatch = args.G_;
148 }
149 
150  template <
151  typename InLay = typename GroupedConvTraitsType_::InLayout,
152  typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
153  typename OutLay = typename GroupedConvTraitsType_::OutLayout,
154  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NHWGC> &&
155  std::is_same_v<WeiLay, tensor_layout::convolution::GKYXC> &&
156  std::is_same_v<OutLay, tensor_layout::convolution::NHWGK>,
157  bool>::type = false>
159  {
160  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
161  static_cast<index_t>(args.N_),
162  static_cast<index_t>(args.C_),
163  static_cast<index_t>(args.input_spatial_lengths_[0]),
164  static_cast<index_t>(args.input_spatial_lengths_[1])};
165  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
166  static_cast<index_t>(args.K_),
167  static_cast<index_t>(args.C_),
168  static_cast<index_t>(args.filter_spatial_lengths_[0]),
169  static_cast<index_t>(args.filter_spatial_lengths_[1])};
170  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
171  static_cast<index_t>(args.N_),
172  static_cast<index_t>(args.K_),
173  static_cast<index_t>(args.output_spatial_lengths_[0]),
174  static_cast<index_t>(args.output_spatial_lengths_[1])};
175 
176  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
177  static_cast<index_t>(args.conv_filter_strides_[1])};
178  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
179  static_cast<index_t>(args.conv_filter_dilations_[1])};
180  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
181  static_cast<index_t>(args.input_left_pads_[1])};
182  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
183  static_cast<index_t>(args.input_right_pads_[1])};
184 
185  k_batch = args.k_batch;
186 
187  in_ptr = args.in_ptr;
188  wei_ptr = args.wei_ptr;
189  for(index_t d = 0; d < NumDTensor; d++)
190  {
191  ds_ptr[d] = args.ds_ptr[d];
192  }
193  out_ptr = args.out_ptr;
194 
195  const index_t Y = wei_g_k_c_xs_lengths[3];
196  const index_t X = wei_g_k_c_xs_lengths[4];
197  const index_t ConvStrideH = conv_filter_strides[0];
198  const index_t ConvStrideW = conv_filter_strides[1];
199  const index_t ConvDilationH = conv_filter_dilations[0];
200  const index_t ConvDilationW = conv_filter_dilations[1];
201  const auto GcdStrideDilationH = gcd(ConvStrideH, ConvDilationH);
202  const auto GcdStrideDilationW = gcd(ConvStrideW, ConvDilationW);
203  const auto YTilde = ConvStrideH / GcdStrideDilationH;
204  const auto XTilde = ConvStrideW / GcdStrideDilationW;
205 
206  for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
207  {
208  for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
209  {
210  const auto YDotSlice = integer_divide_ceil(Y - i_ytilde, YTilde);
211  const auto XDotSlice = integer_divide_ceil(X - i_xtilde, XTilde);
212 
213  if(XDotSlice * YDotSlice <= 0)
214  {
215  continue;
216  }
217 
219  {
220  gemm_count++;
221  // Avoid array segfault
222  continue;
223  }
224 
225  tildes = {i_ytilde, i_xtilde};
226 
227  ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
234  tildes};
235 
236  auto grid_descs = conv_to_gemm_transformer
237  .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
238  GroupedConvTraitsType_::NDimSpatial>(1);
239 
240  a_grid_descs_m_k[gemm_count] = grid_descs.at(number<0>{});
241  b_grid_descs_n_k[gemm_count] = grid_descs.at(number<1>{});
242  c_grid_descs_m_n[gemm_count] = grid_descs.at(number<2>{});
243 
244  const index_t grid_size_grp =
245  TilePartitioner::GridSize(c_grid_descs_m_n[gemm_count].get_length(I0),
246  c_grid_descs_m_n[gemm_count].get_length(I1));
247 
249  block_ends[gemm_count] = grid_size_ + grid_size_grp;
250 
251  grid_size_ += grid_size_grp;
252 
253  // Get the actual split N from transformer
254  n_per_split = conv_to_gemm_transformer.GetN();
255  original_n = conv_to_gemm_transformer.GetOriginalN();
257 
258  ++gemm_count;
259  }
260  }
261  group_stride_a = args.K_; // A: Out NWGK
262  group_stride_b = args.K_ * args.C_ *
263  std::accumulate(args.filter_spatial_lengths_.begin(),
264  args.filter_spatial_lengths_.end(),
265  1,
266  std::multiplies<index_t>()); // B: Wei GKXC
267  group_stride_c = args.C_; // C: In NWGC
268 
270  args.C_ * args.G_ * args.input_spatial_lengths_[0] * args.input_spatial_lengths_[1];
272  args.K_ * args.G_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
273 
274  GemmBatch = args.G_;
275  }
276 
277  template <
278  typename InLay = typename GroupedConvTraitsType_::InLayout,
279  typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
280  typename OutLay = typename GroupedConvTraitsType_::OutLayout,
281  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NDHWGC> &&
282  std::is_same_v<WeiLay, tensor_layout::convolution::GKZYXC> &&
283  std::is_same_v<OutLay, tensor_layout::convolution::NDHWGK>,
284  bool>::type = false>
286  {
287  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
288  static_cast<index_t>(args.N_),
289  static_cast<index_t>(args.C_),
290  static_cast<index_t>(args.input_spatial_lengths_[0]),
291  static_cast<index_t>(args.input_spatial_lengths_[1]),
292  static_cast<index_t>(args.input_spatial_lengths_[2])};
293  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
294  static_cast<index_t>(args.K_),
295  static_cast<index_t>(args.C_),
296  static_cast<index_t>(args.filter_spatial_lengths_[0]),
297  static_cast<index_t>(args.filter_spatial_lengths_[1]),
298  static_cast<index_t>(args.filter_spatial_lengths_[2])};
299  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
300  static_cast<index_t>(args.N_),
301  static_cast<index_t>(args.K_),
302  static_cast<index_t>(args.output_spatial_lengths_[0]),
303  static_cast<index_t>(args.output_spatial_lengths_[1]),
304  static_cast<index_t>(args.output_spatial_lengths_[2])};
305 
306  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
307  static_cast<index_t>(args.conv_filter_strides_[1]),
308  static_cast<index_t>(args.conv_filter_strides_[2])};
309  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
310  static_cast<index_t>(args.conv_filter_dilations_[1]),
311  static_cast<index_t>(args.conv_filter_dilations_[2])};
312  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
313  static_cast<index_t>(args.input_left_pads_[1]),
314  static_cast<index_t>(args.input_left_pads_[2])};
315  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
316  static_cast<index_t>(args.input_right_pads_[1]),
317  static_cast<index_t>(args.input_right_pads_[2])};
318 
319  k_batch = args.k_batch;
320 
321  in_ptr = args.in_ptr;
322  wei_ptr = args.wei_ptr;
323  for(index_t d = 0; d < NumDTensor; d++)
324  {
325  ds_ptr[d] = args.ds_ptr[d];
326  }
327  out_ptr = args.out_ptr;
328 
329  const index_t Z = wei_g_k_c_xs_lengths[3];
330  const index_t Y = wei_g_k_c_xs_lengths[4];
331  const index_t X = wei_g_k_c_xs_lengths[5];
332  const index_t ConvStrideD = conv_filter_strides[0];
333  const index_t ConvStrideH = conv_filter_strides[1];
334  const index_t ConvStrideW = conv_filter_strides[2];
335  const index_t ConvDilationD = conv_filter_dilations[0];
336  const index_t ConvDilationH = conv_filter_dilations[1];
337  const index_t ConvDilationW = conv_filter_dilations[2];
338  const auto GcdStrideDilationD = gcd(ConvStrideD, ConvDilationD);
339  const auto GcdStrideDilationH = gcd(ConvStrideH, ConvDilationH);
340  const auto GcdStrideDilationW = gcd(ConvStrideW, ConvDilationW);
341  const auto ZTilde = ConvStrideD / GcdStrideDilationD;
342  const auto YTilde = ConvStrideH / GcdStrideDilationH;
343  const auto XTilde = ConvStrideW / GcdStrideDilationW;
344 
345  for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde)
346  {
347  for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
348  {
349  for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
350  {
351  const auto ZDotSlice = integer_divide_ceil(Z - i_ztilde, ZTilde);
352  const auto YDotSlice = integer_divide_ceil(Y - i_ytilde, YTilde);
353  const auto XDotSlice = integer_divide_ceil(X - i_xtilde, XTilde);
354 
355  if(ZDotSlice * XDotSlice * YDotSlice <= 0)
356  {
357  continue;
358  }
359 
361  {
362  gemm_count++;
363  // Avoid array segfault
364  continue;
365  }
366 
367  tildes = {i_ztilde, i_ytilde, i_xtilde};
368 
369  ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
376  tildes};
377 
378  auto grid_descs = conv_to_gemm_transformer
379  .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
380  GroupedConvTraitsType_::NDimSpatial>(1);
381 
382  a_grid_descs_m_k[gemm_count] = grid_descs.at(number<0>{});
383  b_grid_descs_n_k[gemm_count] = grid_descs.at(number<1>{});
384  c_grid_descs_m_n[gemm_count] = grid_descs.at(number<2>{});
385 
386  const index_t grid_size_grp =
387  TilePartitioner::GridSize(c_grid_descs_m_n[gemm_count].get_length(I0),
388  c_grid_descs_m_n[gemm_count].get_length(I1));
389 
391  block_ends[gemm_count] = grid_size_ + grid_size_grp;
392 
393  grid_size_ += grid_size_grp;
394 
395  // Get the actual split N from transformer
396  n_per_split = conv_to_gemm_transformer.GetN();
397  original_n = conv_to_gemm_transformer.GetOriginalN();
399 
400  ++gemm_count;
401  }
402  }
403  }
404 
405  group_stride_a = args.K_; // A: Out NWGK
406  group_stride_b = args.K_ * args.C_ *
407  std::accumulate(args.filter_spatial_lengths_.begin(),
408  args.filter_spatial_lengths_.end(),
409  1,
410  std::multiplies<index_t>()); // B: Wei GKXC
411  group_stride_c = args.C_; // C: In NWGC
412 
413  input_batch_stride = args.C_ * args.G_ * args.input_spatial_lengths_[0] *
415  output_batch_stride = args.K_ * args.G_ * args.output_spatial_lengths_[0] *
417 
418  GemmBatch = args.G_; // C: In NWGC
419  }
420 
421  static constexpr index_t MaxGroupedGemmGroupsNum = 128;
422 
425 
429 
430  static constexpr index_t NonSpatialDims = 3;
434 
440 
445 
446  const void* out_ptr;
447  void* in_ptr;
448  std::array<const void*, NumDTensor> ds_ptr;
449  const void* wei_ptr;
450 
454 
457 
461 
462  // Split-N support fields - initialize to safe defaults
463  index_t n_splits = 1; // Number of batch splits (e.g., 2 for 128→64×2)
464  index_t n_per_split = 1; // Batches per split (N_ from transformer)
465  index_t original_n = 1; // Original batch size before splitting
466  index_t input_batch_stride = 0; // Stride to next batch in input tensor
467  index_t output_batch_stride = 0; // Stride to next batch in output tensor
468 };
469 
508 template <typename GroupedConvTraitsType_,
509  typename TilePartitioner_,
510  typename GemmPipeline_,
511  typename EpiloguePipeline_>
513 {
514  static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial;
516  GroupedConvTraitsType_::ConvSpecialization;
523 
528 
530  static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
531 
532  static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
533 
537 
539 
542  static constexpr index_t MaxGroupedGemmGroupsNum =
544 
545  static constexpr auto I0 = number<0>();
546  static constexpr auto I1 = number<1>();
547  static constexpr auto I2 = number<2>();
548  static constexpr auto I3 = number<3>();
549 
550  static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
551  "Not supported!");
552  static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>, "Not supported!");
553  static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
554  static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
555  "Not supported C GEMM layout!");
556  static_assert(GroupedConvTraitsType_::ExplicitGemm == false, "Not supported yet!");
557 
// Builds a unique, underscore-separated instance name from the data types, tensor layouts,
// GEMM/epilogue pipeline names and the conv-trait flags (merged groups, split image,
// explicit GEMM). NOTE(review): Doxygen line 572 is elided in this listing — presumably a
// conv-specialization component of the name; confirm against the original header.
558 [[nodiscard]] CK_TILE_HOST static const std::string GetName()
559 {
560 static constexpr bool EnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
561 constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
562 // clang-format off
563 return concat('_', "grouped_convolution_backward_data",
564 gemm_prec_str<InDataType, WeiDataType>(),
565 InLayout::name,
566 WeiLayout::name,
567 OutLayout::name,
568 "gemm",
569 GemmPipeline::GetName(),
570 "epilogue",
571 EpiloguePipeline::GetName(),
573 "MergedGroups",
574 NumGroupsToMerge,
575 "SplitImage",
576 EnableSplitImage,
577 "ExplicitGemm",
578 GroupedConvTraitsType_::ExplicitGemm
579 );
580 // clang-format on
581 }
582 
// Human-readable type string for this kernel instantiation; identical to GetName().
583 [[nodiscard]] CK_TILE_HOST static const std::string GetTypeString() { return GetName(); }
584 
585 #ifdef CK_EXPERIMENTAL_BUILDER
// Returns the reflection-based instance string for this kernel (experimental builder only).
// Fails at compile time with a descriptive message if no instance_traits specialization
// exists for the current template parameters.
586 CK_TILE_HOST std::string GetInstanceString() const
587 {
588 static_assert(ck_tile::reflect::HasInstanceTraits<GroupedConvolutionBackwardDataKernel>,
589 "Specialization of instance_traits not found. Please check that a "
590 "specialization exists in file "
591 "ck_tile/builder/reflect/"
592 "instance_traits_tile_grouped_convolution_backward_data.hpp "
593 "for the given template parameters.");
594 return ck_tile::reflect::instance_string<GroupedConvolutionBackwardDataKernel>();
595 }
596 #endif
597 
599  {
600  // enable batched grouped gemm
601  return dim3(kargs.grid_size_, kargs.GemmBatch, kargs.n_splits * kargs.k_batch);
602  }
603 
// Threads per work-group. On wave32 targets half the thread count is used, which keeps
// the number of waves per block equal to the wave64 configuration.
604 CK_TILE_HOST static constexpr auto BlockSize()
605 {
606 return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
607 }
608 
611  {
613  }
614 
// Returns the LDS byte requirement: the GEMM pipeline and the epilogue reuse the same
// buffer (RunGemm passes one smem pointer to both), so the maximum of the two suffices.
// NOTE(review): the signature line (Doxygen line 615) is elided in this listing.
616 {
617 return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
618 }
619 
620  CK_TILE_DEVICE static auto
623  const index_t group_id,
624  const index_t i_m,
625  const index_t i_k)
626  {
627  // Step 1: Create tensor view for A (Out tensor)
628  const auto& a_tensor_view =
629  make_tensor_view<address_space_enum::global>(a_ptr, kargs.a_grid_descs_m_k[group_id]);
630 
631  // Step 2: Create padded view
632  const auto& a_pad_view = pad_tensor_view(
633  a_tensor_view,
636 
637  // Step 3: Create tile window
638  auto a_block_window = make_tile_window(
639  a_pad_view,
641  {i_m, i_k});
642 
643  return a_block_window;
644  }
645 
646  CK_TILE_DEVICE static auto
649  const index_t group_id,
650  const index_t i_n,
651  const index_t i_k)
652  {
653  // Step 1: Create tensor view for B (Weight tensor)
654  const auto& b_tensor_view =
655  make_tensor_view<address_space_enum::global>(b_ptr, kargs.b_grid_descs_n_k[group_id]);
656 
657  // Step 2: Create padded view
658  const auto& b_pad_view = pad_tensor_view(
659  b_tensor_view,
662 
663  // Step 3: Create tile window
664  auto b_block_window = make_tile_window(
665  b_pad_view,
667  {i_k, i_n});
668 
669  return b_block_window;
670  }
671 
672  CK_TILE_DEVICE static auto
673  MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
675  const index_t group_id,
676  const index_t i_m,
677  const index_t i_n)
678  {
679  // Create D tensor block windows
680  const auto ds_block_window = generate_tuple(
681  [&](auto i) {
682  // Step 1: Create tensor view for D
683  const auto& d_tensor_view = make_tensor_view<address_space_enum::global>(
684  static_cast<const OutDataType*>(ds_ptr[i]), kargs.c_grid_descs_m_n[group_id]);
685 
686  // Step 2: Create padded view
687  const auto& d_pad_view =
688  pad_tensor_view(d_tensor_view,
692 
693  // Step 3: Create tile window
694  return make_tile_window(d_pad_view,
697  {i_m, i_n});
698  },
700 
701  return ds_block_window;
702  }
703 
704  template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
705  CK_TILE_DEVICE static auto
708  const index_t group_id,
709  const index_t i_m,
710  const index_t i_n)
711  {
712  // Step 1: Create tensor view for C (Input tensor)
713  const auto& c_tensor_view = make_tensor_view<address_space_enum::global, DstInMemOp>(
714  c_ptr, kargs.c_grid_descs_m_n[group_id]);
715 
716  // Step 2: Create padded view
717  const auto& c_pad_view = pad_tensor_view(
718  c_tensor_view,
721 
722  // Step 3: Create tile window
723  auto c_block_window = make_tile_window(
724  c_pad_view,
726  {i_m, i_n});
727 
728  return c_block_window;
729  }
730 
731  CK_TILE_HOST static bool
733  {
734  if constexpr(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
736  {
737  if(kargs.k_batch != 1)
738  {
739  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
740  {
741  CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
742  }
743  return false;
744  }
745  }
746 
748  {
749  return false;
750  }
751 
752  const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
753  const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];
754 
755  // check ConvSpecialization
757  {
758  // check if it's 1x1, stride=1 conv
759  for(index_t i = 0; i < NDimSpatial; ++i)
760  {
761  const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
762  const index_t ConvStride = kargs.conv_filter_strides[i];
763  const index_t LeftPad = kargs.input_left_pads[i];
764  const index_t RightPad = kargs.input_right_pads[i];
765 
766  if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
767  {
768  return false;
769  }
770  }
771  }
773  {
774  // check if it's 1x1 conv
775  for(index_t i = 0; i < NDimSpatial; ++i)
776  {
777  const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
778  const index_t LeftPad = kargs.input_left_pads[i];
779  const index_t RightPad = kargs.input_right_pads[i];
780 
781  if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0))
782  {
783  return false;
784  }
785  }
786  }
788  {
789  if(ConvC != 1)
790  {
791  return false;
792  }
793  for(index_t i = 0; i < NDimSpatial; ++i)
794  {
795  const index_t filter_spatial_dim = kargs.wei_g_k_c_xs_lengths[i + I3];
796 
797  if(filter_spatial_dim != I3)
798  {
799  return false;
800  }
801  }
802  }
803 
804  namespace ctc = tensor_layout::convolution;
805 
806  if constexpr(std::is_same_v<InLayout, ctc::NWGC> || std::is_same_v<InLayout, ctc::NHWGC> ||
807  std::is_same_v<InLayout, ctc::NDHWGC>)
808  {
809  // Check access per C
810  if(ConvC % GroupedConvTraitsType_::VectorSizeB != 0)
811  {
812  CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!");
813  return false;
814  }
815  }
816  else
817  {
818  CK_TILE_ERROR("Not supported input layout!");
819  return false;
820  }
821 
822  // FIXME: layout
823  if constexpr(std::is_same_v<WeiLayout, ctc::GKXC> ||
824  std::is_same_v<WeiLayout, ctc::GKYXC> ||
825  std::is_same_v<WeiLayout, ctc::GKZYXC>)
826  {
827  if(ConvC % GroupedConvTraitsType_::VectorSizeC != 0)
828  {
829  CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!");
830  return false;
831  }
832  }
833  else
834  {
835  CK_TILE_ERROR("Not supported weight layout!");
836  return false;
837  }
838 
839  if constexpr(std::is_same_v<OutLayout, ctc::NWGK> ||
840  std::is_same_v<OutLayout, ctc::NHWGK> ||
841  std::is_same_v<OutLayout, ctc::NDHWGK>)
842  {
843  if(ConvK % GroupedConvTraitsType_::VectorSizeA != 0)
844  {
845  CK_TILE_ERROR("Conv K is not a multiple of vector store size for output image!");
846  return false;
847  }
848  }
849  else
850  {
851  CK_TILE_ERROR("Not supported output layout!");
852  return false;
853  }
854 
855  return true;
856  }
857 
858  template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
859  CK_TILE_DEVICE static auto
861  const InDataType* b_ptr,
862  const std::array<const void*, NumDTensor>& ds_ptr,
863  WeiDataType* c_ptr,
865  const index_t group_id)
866  {
867  static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
868  static_assert(!GemmPipeline::BlockGemmShape::PermuteB, "Not implemented!");
869  const auto& a_tensor_view = [&]() {
870  return make_tensor_view<address_space_enum::global>(
871  a_ptr,
872  kargs.a_grid_descs_m_k[group_id]); // A: out
873  }();
874 
875  const auto& b_tensor_view = [&]() {
876  return make_tensor_view<address_space_enum::global>(
877  b_ptr,
878  kargs.b_grid_descs_n_k[group_id]); // B: weight
879  }();
880 
881  const auto& c_tensor_view = [&]() {
882  return make_tensor_view<address_space_enum::global, DstInMemOp>(
883  c_ptr, kargs.c_grid_descs_m_n[group_id]);
884  }();
885 
886  const auto& ds_tensor_view = generate_tuple(
887  [&](auto i) {
888  static_assert(std::is_same_v<std::tuple_element_t<i, DsLayout>, OutLayout>,
889  "Not supported!");
890  static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
891  "Not supported!");
892  static_assert(std::is_same_v<std::tuple_element_t<i, DsDataType>, OutDataType>,
893  "Not supported!");
894 
895  return make_tensor_view<address_space_enum::global>(
896  static_cast<OutDataType*>(ds_ptr[i]), kargs.c_grid_descs_m_n[group_id]);
897  },
899 
900  return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view);
901  }
902 
903  template <typename TensorView>
904  CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
905  {
906  const auto& a_pad_view = [&]() {
907  const auto& a_tensor_view = views.at(I0);
908  return pad_tensor_view(a_tensor_view,
912  }();
913 
914  const auto& b_pad_view = [&]() {
915  const auto& b_tensor_view = views.at(I1);
916  return pad_tensor_view(b_tensor_view,
920  }();
921 
922  const auto& ds_tensor_view = views.at(I2);
923  const auto& ds_pad_view = generate_tuple(
924  [&](auto i) {
925  return pad_tensor_view(ds_tensor_view[i],
929  },
931 
932  const auto& c_pad_view = [&]() {
933  const auto& c_tensor_view = views.at(I3);
934  return pad_tensor_view(c_tensor_view,
938  }();
939 
940  return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view);
941  }
942 
943  template <typename PadView>
944  CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
945  const index_t i_m,
946  const index_t i_n,
947  const index_t i_k)
948  {
949  const auto& a_pad_view = views.at(I0);
950  const auto& b_pad_view = views.at(I1);
951  const auto& ds_pad_view = views.at(I2);
952  const auto& c_pad_view = views.at(I3);
953 
954  const auto& a_block_window = [&]() {
955  return make_tile_window(a_pad_view,
958  {i_m, i_k});
959  }();
960 
961  const auto& b_block_window = [&]() {
962  return make_tile_window(b_pad_view,
965  {i_k, i_n});
966  }();
967 
968  const auto ds_block_window = generate_tuple(
969  [&](auto i) {
970  return make_tile_window(ds_pad_view[i],
973  {i_m, i_n});
974  },
976 
977  auto c_block_window = make_tile_window(
978  c_pad_view,
980  {i_m, i_n});
981 
982  return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window);
983  }
984 
// Runs one work-group's tile of the sub-GEMM `group_id`: builds A/B/D tile windows,
// executes the GEMM pipeline over `splitted_k`, then dispatches the epilogue —
// plain stores (memory_operation_enum::set) when k_batch == 1, atomic adds when
// split-K is active and the C vector size permits atomics.
// NOTE(review): Doxygen lines 1002 (kargs parameter), 1033/1045 (EpiloguePipeline{}
// invocation lines) and 1040 (condition continuation) are elided in this listing.
997 CK_TILE_DEVICE static void RunGemm(const OutDataType* a_ptr,
998 const InDataType* b_ptr,
999 const std::array<const void*, NumDTensor>& ds_ptr,
1000 WeiDataType* c_ptr,
1001 void* smem_ptr_0,
1003 const index_t splitted_k,
1004 const index_t block_idx_m,
1005 const index_t block_idx_n,
1006 const index_t block_idx_k,
1007 const index_t group_id)
1008 {
1009 // Create block windows using specialized methods
1010 const auto& a_block_window =
1011 MakeABlockWindow(a_ptr, kargs, group_id, block_idx_m, block_idx_k);
1012 const auto& b_block_window =
1013 MakeBBlockWindow(b_ptr, kargs, group_id, block_idx_n, block_idx_k);
1014 const auto& d_block_window =
1015 MakeDBlockWindows(ds_ptr, kargs, group_id, block_idx_m, block_idx_n);
1016 
// Loop trip count is uniform across the wave; read it from the first lane to keep it scalar.
1017 const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitted_k));
1018 const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
1019 const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
1020 
1021 // Run GEMM cooperatively by whole workgroup.
1022 const auto& c_block_tile = GemmPipeline{}.template operator()(
1023 a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
1024 
1025 const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch);
1026 
1027 // Run Epilogue Pipeline with k_batch dispatch
1028 if(k_batch == 1)
1029 {
// Single K-split: ordinary stores into C.
1030 auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
1031 c_ptr, kargs, group_id, block_idx_m, block_idx_n);
1032 
1034 .template operator()<decltype(c_block_window), decltype(c_block_tile)>(
1035 c_block_window, c_block_tile, d_block_window, smem_ptr_0);
1036 }
1037 else
1038 {
// Split-K: partial results from each split are accumulated with atomic adds.
// The elided constexpr condition mirrors the odd-VectorSizeC guard checked in
// IsSupportedArgument, which rejects k_batch > 1 for unsupported atomics.
1039 if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
1041 {
1042 auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
1043 c_ptr, kargs, group_id, block_idx_m, block_idx_n);
1044 
1046 .template operator()<decltype(c_block_window), decltype(c_block_tile)>(
1047 c_block_window, c_block_tile, d_block_window, smem_ptr_0);
1048 }
1049 }
1050 }
1051 
// Maps a flat block id to the sub-GEMM (group) whose [block_starts, block_ends) range
// contains it, via binary search over the gemm_count groups.
// NOTE(review): the opening of the signature (Doxygen line 1052) is elided in this listing.
// NOTE(review): termination relies on every launched block id falling inside some group's
// range; with `left = group_id` (not group_id + 1) the `left <= right` guard alone would
// not converge for an out-of-range id — confirm the launch never produces one.
1053 index_t block_id) const
1054 {
1055 index_t left = 0;
1056 index_t right = kargs.gemm_count;
1057 index_t group_id = index_t((left + right) >> 1);
1058 
1059 while((!(block_id >= kargs.block_starts[group_id] &&
1060 block_id < kargs.block_ends[group_id])) &&
1061 left <= right)
1062 {
1063 if(block_id < kargs.block_starts[group_id])
1064 {
1065 right = group_id;
1066 }
1067 else
1068 {
1069 left = group_id;
1070 }
1071 group_id = index_t((left + right) >> 1);
1072 }
1073 
1074 return group_id;
1075 }
1076 
// Kernel entry point. Grid mapping: blockIdx.x = flat tile id across all fused sub-GEMMs,
// blockIdx.y = convolution group (G), blockIdx.z packs split-N and split-K indices
// (z = split_n_idx * k_batch + split_k_idx; see GridSize).
// NOTE(review): the signature (Doxygen line 1077) and the tile-coordinate computation
// (line 1082, presumably TilePartitioner yielding iM/iN from blockIdX - block_starts)
// are elided in this listing.
1078 {
// Locate which sub-GEMM this block belongs to, then its M/N tile within it.
1079 const auto blockIdX = amd_wave_read_first_lane(blockIdx.x);
1080 const index_t group_id = FindGroupId(kargs, blockIdX);
1081 
1083 kargs.block_starts[group_id],
1084 kargs.c_grid_descs_m_n[group_id].get_length(I0),
1085 kargs.c_grid_descs_m_n[group_id].get_length(I1));
1086 
1087 const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
1088 const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
1089 
// Per-conv-group pointer offsets along the G dimension of each tensor layout.
1090 const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);
1091 const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
1092 const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
1093 const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);
1094 
1095 const auto blockIdZ = amd_wave_read_first_lane(blockIdx.z);
1096 
1097 // SplitN
1098 const index_t split_n_idx = __builtin_amdgcn_readfirstlane(blockIdZ / kargs.k_batch);
1099 const index_t split_n_offset =
1100 __builtin_amdgcn_readfirstlane(split_n_idx * kargs.n_per_split);
1101 
// 64-bit batch offsets: split_n_offset * batch stride can exceed 32 bits on large tensors.
1102 const long_index_t output_batch_offset =
1103 static_cast<long_index_t>(split_n_offset) *
1104 static_cast<long_index_t>(kargs.output_batch_stride);
1105 const long_index_t input_batch_offset = static_cast<long_index_t>(split_n_offset) *
1106 static_cast<long_index_t>(kargs.input_batch_stride);
1107 
1108 // SplitK
1109 const index_t split_k_idx =
1110 __builtin_amdgcn_readfirstlane(blockIdZ - split_n_idx * kargs.k_batch);
1111 
1112 const index_t gemm_k = kargs.a_grid_descs_m_k[group_id].get_length(I1);
1113 
// Per-split K chunk, rounded up to whole KPerBlock tiles.
// NOTE(review): the last split reuses the same KRead without clamping to gemm_k —
// presumably the kPadK padding asserted above absorbs the tail; confirm.
1114 constexpr auto K1 = TilePartitioner::KPerBlock;
1115 const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1);
1116 const index_t KRead = amd_wave_read_first_lane((gemm_k + K_t - 1) / K_t * K1);
1117 
1118 const index_t i_k = amd_wave_read_first_lane(split_k_idx * KRead);
1119 const index_t splitted_k = amd_wave_read_first_lane(KRead);
1120 
1121 // options
1122 // conv_bwd_data = Out * Weight = In
1123 const OutDataType* a_ptr =
1124 static_cast<const OutDataType*>(kargs.out_ptr) + group_offset_a + output_batch_offset;
1125 const WeiDataType* b_ptr = static_cast<const WeiDataType*>(kargs.wei_ptr) + group_offset_b;
1126 InDataType* c_ptr =
1127 static_cast<InDataType*>(kargs.in_ptr) + group_offset_c + input_batch_offset;
1128 
1129 // allocate LDS
1130 __shared__ char smem_ptr[GetSmemSize()];
1131 RunGemm(a_ptr,
1132 b_ptr,
1133 kargs.ds_ptr,
1134 c_ptr,
1135 smem_ptr,
1136 kargs,
1137 splitted_k,
1138 i_m,
1139 i_n,
1140 i_k,
1141 group_id);
1142 }
1143 };
1144 
1145 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_HOST
Definition: config.hpp:44
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
Definition: cluster_descriptor.hpp:13
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:156
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:145
void CK_TILE_ERROR(Args &&... args) noexcept
Definition: env.hpp:12
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:36
TailNumber
Definition: gemm_pipeline_ag_bg_cr_scheduler.hpp:21
ConvolutionSpecialization
Definition: convolution_specialization.hpp:11
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:545
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE index_t gcd(index_t x, index_t y)
Definition: math.hpp:264
int64_t long_index_t
Definition: integer.hpp:11
CK_TILE_HOST std::string getConvSpecializationString(const ConvolutionSpecialization &s)
Definition: convolution_specialization.hpp:18
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:157
constexpr bool is_same_v
Definition: type.hpp:283
The Grouped Convolution kernel device arguments.
Definition: grouped_convolution_backward_data_kernel.hpp:26
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > wei_g_k_c_xs_lengths
Definition: grouped_convolution_backward_data_kernel.hpp:432
static constexpr auto I1
Definition: grouped_convolution_backward_data_kernel.hpp:39
CK_TILE_HOST GroupedConvBwdDataKernelArgs(const GroupedConvBwdDataHostArgs &args)
Definition: grouped_convolution_backward_data_kernel.hpp:49
array< index_t, GroupedConvTraitsType_::NDimSpatial > conv_filter_dilations
Definition: grouped_convolution_backward_data_kernel.hpp:436
std::array< const void *, NumDTensor > ds_ptr
Definition: grouped_convolution_backward_data_kernel.hpp:448
array< index_t, GroupedConvTraitsType_::NDimSpatial > conv_filter_strides
Definition: grouped_convolution_backward_data_kernel.hpp:435
array< index_t, MaxGroupedGemmGroupsNum > block_starts
Definition: grouped_convolution_backward_data_kernel.hpp:455
array< index_t, GroupedConvTraitsType_::NDimSpatial > input_left_pads
Definition: grouped_convolution_backward_data_kernel.hpp:437
long_index_t group_stride_b
Definition: grouped_convolution_backward_data_kernel.hpp:459
long_index_t group_stride_c
Definition: grouped_convolution_backward_data_kernel.hpp:460
array< index_t, MaxGroupedGemmGroupsNum > block_ends
Definition: grouped_convolution_backward_data_kernel.hpp:456
const void * out_ptr
Definition: grouped_convolution_backward_data_kernel.hpp:446
remove_cvref_t< decltype(ConvToGemmTransformer{}.MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(1))> ABCGridDescs
Definition: grouped_convolution_backward_data_kernel.hpp:424
remove_cvref_t< decltype(ABCGridDescs{}[number< 1 >{}])> BGridDescNK
Definition: grouped_convolution_backward_data_kernel.hpp:427
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: grouped_convolution_backward_data_kernel.hpp:27
array< index_t, GroupedConvTraitsType_::NDimSpatial > tildes
Definition: grouped_convolution_backward_data_kernel.hpp:439
remove_cvref_t< decltype(ABCGridDescs{}[number< 0 >{}])> AGridDescMK
Definition: grouped_convolution_backward_data_kernel.hpp:426
const void * wei_ptr
Definition: grouped_convolution_backward_data_kernel.hpp:449
index_t n_per_split
Definition: grouped_convolution_backward_data_kernel.hpp:464
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > out_g_n_k_wos_lengths
Definition: grouped_convolution_backward_data_kernel.hpp:433
long_index_t group_stride_a
Definition: grouped_convolution_backward_data_kernel.hpp:458
index_t GemmBatch
Definition: grouped_convolution_backward_data_kernel.hpp:442
void * in_ptr
Definition: grouped_convolution_backward_data_kernel.hpp:447
index_t n_splits
Definition: grouped_convolution_backward_data_kernel.hpp:463
index_t gemm_count
Definition: grouped_convolution_backward_data_kernel.hpp:444
array< CGridDescMN, MaxGroupedGemmGroupsNum > c_grid_descs_m_n
Definition: grouped_convolution_backward_data_kernel.hpp:453
index_t original_n
Definition: grouped_convolution_backward_data_kernel.hpp:465
index_t grid_size_
Definition: grouped_convolution_backward_data_kernel.hpp:443
array< index_t, GroupedConvTraitsType_::NDimSpatial > input_right_pads
Definition: grouped_convolution_backward_data_kernel.hpp:438
array< BGridDescNK, MaxGroupedGemmGroupsNum > b_grid_descs_n_k
Definition: grouped_convolution_backward_data_kernel.hpp:452
index_t k_batch
Definition: grouped_convolution_backward_data_kernel.hpp:441
static constexpr auto I0
Definition: grouped_convolution_backward_data_kernel.hpp:38
static constexpr index_t MaxGroupedGemmGroupsNum
Definition: grouped_convolution_backward_data_kernel.hpp:421
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > in_g_n_c_wis_lengths
Definition: grouped_convolution_backward_data_kernel.hpp:431
static constexpr index_t NumDTensor
Definition: grouped_convolution_backward_data_kernel.hpp:36
index_t output_batch_stride
Definition: grouped_convolution_backward_data_kernel.hpp:467
index_t input_batch_stride
Definition: grouped_convolution_backward_data_kernel.hpp:466
array< AGridDescMK, MaxGroupedGemmGroupsNum > a_grid_descs_m_k
Definition: grouped_convolution_backward_data_kernel.hpp:451
remove_cvref_t< decltype(ABCGridDescs{}[number< 2 >{}])> CGridDescMN
Definition: grouped_convolution_backward_data_kernel.hpp:428
static constexpr index_t NonSpatialDims
Definition: grouped_convolution_backward_data_kernel.hpp:430
The Grouped Conv kernel host arguments.
Definition: grouped_convolution_utils.hpp:27
InPtr in_ptr
Definition: grouped_convolution_utils.hpp:46
OutPtr out_ptr
Definition: grouped_convolution_utils.hpp:49
WeiPtr wei_ptr
Definition: grouped_convolution_utils.hpp:47
index_t k_batch
Definition: grouped_convolution_utils.hpp:50
const std::vector< const void * > ds_ptr
Definition: grouped_convolution_utils.hpp:48
The Grouped Convolution Backward Data kernel template.
Definition: grouped_convolution_backward_data_kernel.hpp:513
static constexpr index_t NDimSpatial
Definition: grouped_convolution_backward_data_kernel.hpp:514
static constexpr CK_TILE_HOST auto BlockSize()
Definition: grouped_convolution_backward_data_kernel.hpp:604
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition: grouped_convolution_backward_data_kernel.hpp:518
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views)
Definition: grouped_convolution_backward_data_kernel.hpp:904
GroupedConvBwdDataKernelArgs< GroupedConvTraitsType_, TilePartitioner > GroupedConvBwdDataKernelArgsSpecialized
Definition: grouped_convolution_backward_data_kernel.hpp:541
remove_cvref_t< typename GemmPipeline::ADataType > InDataType
Definition: grouped_convolution_backward_data_kernel.hpp:534
static constexpr index_t MaxGroupedGemmGroupsNum
Definition: grouped_convolution_backward_data_kernel.hpp:542
static constexpr auto I1
Definition: grouped_convolution_backward_data_kernel.hpp:546
CK_TILE_DEVICE void operator()(GroupedConvBwdDataKernelArgsSpecialized &kargs) const
Definition: grouped_convolution_backward_data_kernel.hpp:1077
static CK_TILE_DEVICE auto MakeDBlockWindows(const std::array< const void *, NumDTensor > &ds_ptr, const GroupedConvBwdDataKernelArgsSpecialized &kargs, const index_t group_id, const index_t i_m, const index_t i_n)
Definition: grouped_convolution_backward_data_kernel.hpp:673
static constexpr auto I3
Definition: grouped_convolution_backward_data_kernel.hpp:548
remove_cvref_t< typename GroupedConvTraitsType_::OutLayout > OutLayout
Definition: grouped_convolution_backward_data_kernel.hpp:526
static constexpr ConvolutionSpecialization ConvSpecialization
Definition: grouped_convolution_backward_data_kernel.hpp:515
static constexpr CK_TILE_HOST GroupedConvBwdDataKernelArgsSpecialized MakeKernelArgs(const GroupedConvBwdDataHostArgs &hostArgs)
Definition: grouped_convolution_backward_data_kernel.hpp:610
static CK_TILE_DEVICE void RunGemm(const OutDataType *a_ptr, const InDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, WeiDataType *c_ptr, void *smem_ptr_0, const GroupedConvBwdDataKernelArgsSpecialized &kargs, const index_t splitted_k, const index_t block_idx_m, const index_t block_idx_n, const index_t block_idx_k, const index_t group_id)
Runs single GEMM problem cooperatively by whole workgroup.
Definition: grouped_convolution_backward_data_kernel.hpp:997
static constexpr index_t NumDTensor
Definition: grouped_convolution_backward_data_kernel.hpp:530
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n, const index_t i_k)
Definition: grouped_convolution_backward_data_kernel.hpp:944
remove_cvref_t< typename GemmPipeline::BDataType > WeiDataType
Definition: grouped_convolution_backward_data_kernel.hpp:535
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: grouped_convolution_backward_data_kernel.hpp:519
remove_cvref_t< typename EpiloguePipeline::ODataType > OutDataType
Definition: grouped_convolution_backward_data_kernel.hpp:538
static CK_TILE_DEVICE auto MakeABlockWindow(const OutDataType *a_ptr, const GroupedConvBwdDataKernelArgsSpecialized &kargs, const index_t group_id, const index_t i_m, const index_t i_k)
Definition: grouped_convolution_backward_data_kernel.hpp:621
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: grouped_convolution_backward_data_kernel.hpp:517
remove_cvref_t< typename GroupedConvTraitsType_::WeiLayout > WeiLayout
Definition: grouped_convolution_backward_data_kernel.hpp:525
static CK_TILE_DEVICE auto MakeBBlockWindow(const InDataType *b_ptr, const GroupedConvBwdDataKernelArgsSpecialized &kargs, const index_t group_id, const index_t i_n, const index_t i_k)
Definition: grouped_convolution_backward_data_kernel.hpp:647
static constexpr index_t kBlockSize
Definition: grouped_convolution_backward_data_kernel.hpp:532
static CK_TILE_HOST const std::string GetTypeString()
Definition: grouped_convolution_backward_data_kernel.hpp:583
static CK_TILE_HOST bool IsSupportedArgument(const GroupedConvBwdDataKernelArgsSpecialized &kargs)
Definition: grouped_convolution_backward_data_kernel.hpp:732
remove_cvref_t< typename GemmPipeline::BLayout > GemmBLayout
Definition: grouped_convolution_backward_data_kernel.hpp:521
remove_cvref_t< typename GroupedConvTraitsType_::DsLayout > DsLayout
Definition: grouped_convolution_backward_data_kernel.hpp:527
static constexpr auto I2
Definition: grouped_convolution_backward_data_kernel.hpp:547
static CK_TILE_DEVICE auto MakeGemmTensorViews(const OutDataType *a_ptr, const InDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, WeiDataType *c_ptr, const GroupedConvBwdDataKernelArgsSpecialized &kargs, const index_t group_id)
Definition: grouped_convolution_backward_data_kernel.hpp:860
static CK_TILE_HOST auto GridSize(const GroupedConvBwdDataKernelArgsSpecialized &kargs)
Definition: grouped_convolution_backward_data_kernel.hpp:598
remove_cvref_t< typename GemmPipeline::ALayout > GemmALayout
Definition: grouped_convolution_backward_data_kernel.hpp:520
remove_cvref_t< typename EpiloguePipeline::DsLayout > GemmDsLayout
Definition: grouped_convolution_backward_data_kernel.hpp:529
CK_TILE_DEVICE index_t FindGroupId(const GroupedConvBwdDataKernelArgsSpecialized &kargs, index_t block_id) const
Definition: grouped_convolution_backward_data_kernel.hpp:1052
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: grouped_convolution_backward_data_kernel.hpp:615
remove_cvref_t< typename GroupedConvTraitsType_::InLayout > InLayout
Definition: grouped_convolution_backward_data_kernel.hpp:524
remove_cvref_t< typename GemmPipeline::CLayout > GemmCLayout
Definition: grouped_convolution_backward_data_kernel.hpp:522
remove_cvref_t< typename EpiloguePipeline::DsDataType > DsDataType
Definition: grouped_convolution_backward_data_kernel.hpp:536
static CK_TILE_HOST const std::string GetName()
Definition: grouped_convolution_backward_data_kernel.hpp:558
static constexpr auto I0
Definition: grouped_convolution_backward_data_kernel.hpp:545
static CK_TILE_DEVICE auto MakeCBlockWindow(WeiDataType *c_ptr, const GroupedConvBwdDataKernelArgsSpecialized &kargs, const index_t group_id, const index_t i_m, const index_t i_n)
Definition: grouped_convolution_backward_data_kernel.hpp:706
static CK_TILE_DEVICE auto GetOffsetedTileIndex(index_t block_start, index_t M, index_t N) noexcept -> const tuple< index_t, index_t >
The function subtracts the block's start (offset) from 1D raw-indexes.
Definition: gemm_tile_partitioner.hpp:192
Definition: transform_conv_bwd_data_to_gemm.hpp:21
CK_TILE_HOST auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N([[maybe_unused]] const index_t GemmKBatch) const
Definition: transform_conv_bwd_data_to_gemm.hpp:659
constexpr CK_TILE_HOST_DEVICE auto & at(index_t i)
Definition: array.hpp:110
Definition: integral_constant.hpp:13
std::vector< ck_tile::long_index_t > input_spatial_lengths_
Definition: convolution_parameter.hpp:130
ck_tile::long_index_t K_
Definition: convolution_parameter.hpp:126
std::vector< ck_tile::long_index_t > output_spatial_lengths_
Definition: convolution_parameter.hpp:131
std::vector< ck_tile::long_index_t > input_right_pads_
Definition: convolution_parameter.hpp:137
ck_tile::long_index_t G_
Definition: convolution_parameter.hpp:124
std::vector< ck_tile::long_index_t > conv_filter_strides_
Definition: convolution_parameter.hpp:133
std::vector< ck_tile::long_index_t > filter_spatial_lengths_
Definition: convolution_parameter.hpp:129
ck_tile::long_index_t C_
Definition: convolution_parameter.hpp:127
ck_tile::long_index_t N_
Definition: convolution_parameter.hpp:125
std::vector< ck_tile::long_index_t > input_left_pads_
Definition: convolution_parameter.hpp:136
std::vector< ck_tile::long_index_t > conv_filter_dilations_
Definition: convolution_parameter.hpp:134
Definition: type_traits.hpp:115
Definition: sequence.hpp:49
#define CK_TILE_ENV(name)
Definition: env.hpp:145