Composable Kernel — source listing of
include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp

grouped_convolution_backward_data_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <string>
8 
9 #include "ck_tile/core.hpp"
10 #include "ck_tile/ops/common.hpp"
11 #include "ck_tile/host/concat.hpp"
16 
17 namespace ck_tile {
18 
20 template <typename GroupedConvTraitsType_, typename TilePartitioner_>
22 {
24 
26  TransformConvBwdDataToGemm<GroupedConvTraitsType_::NDimSpatial,
27  GroupedConvTraitsType_::ConvSpecialization,
28  GroupedConvTraitsType_::VectorSizeA,
29  GroupedConvTraitsType_::VectorSizeB,
30  GroupedConvTraitsType_::VectorSizeC>;
31  static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
32 
33  static constexpr auto I0 = number<0>();
34  static constexpr auto I1 = number<1>();
35 
36  template <
37  typename InLay = typename GroupedConvTraitsType_::InLayout,
38  typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
39  typename OutLay = typename GroupedConvTraitsType_::OutLayout,
40  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NWGC> &&
41  std::is_same_v<WeiLay, tensor_layout::convolution::GKXC> &&
42  std::is_same_v<OutLay, tensor_layout::convolution::NWGK>,
43  bool>::type = false>
45  {
46  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
47  static_cast<index_t>(args.N_),
48  static_cast<index_t>(args.C_),
49  static_cast<index_t>(args.input_spatial_lengths_[0])};
50  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
51  static_cast<index_t>(args.K_),
52  static_cast<index_t>(args.C_),
53  static_cast<index_t>(args.filter_spatial_lengths_[0])};
54  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
55  static_cast<index_t>(args.N_),
56  static_cast<index_t>(args.K_),
57  static_cast<index_t>(args.output_spatial_lengths_[0])};
58 
59  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0])};
60  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0])};
61  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0])};
62  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0])};
63 
64  k_batch = args.k_batch;
65 
66  in_ptr = args.in_ptr;
67  wei_ptr = args.wei_ptr;
68  for(index_t d = 0; d < NumDTensor; d++)
69  {
70  ds_ptr[d] = args.ds_ptr[d];
71  }
72  out_ptr = args.out_ptr;
73 
74  const index_t X = wei_g_k_c_xs_lengths[3];
75  const index_t ConvStrideW = conv_filter_strides[0];
76  const index_t ConvDilationW = conv_filter_dilations[0];
77  const auto GcdStrideDilationW = gcd(ConvStrideW, ConvDilationW);
78  const auto XTilde = ConvStrideW / GcdStrideDilationW;
79 
80  for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
81  {
82  const auto XDotSlice = integer_divide_ceil(X - i_xtilde, XTilde);
83 
84  if(XDotSlice <= 0)
85  {
86  continue;
87  }
88 
90  {
91  gemm_count++;
92  // Avoid array segfault
93  continue;
94  }
95 
96  tildes = {i_xtilde};
97 
98  ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
105  tildes};
106 
107  auto grid_descs =
108  conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
109  GroupedConvTraitsType_::NDimSpatial>(1);
110 
111  a_grid_descs_m_k[gemm_count] = grid_descs.at(number<0>{});
112  b_grid_descs_n_k[gemm_count] = grid_descs.at(number<1>{});
113  c_grid_descs_m_n[gemm_count] = grid_descs.at(number<2>{});
114 
115  const index_t grid_size_grp =
116  TilePartitioner::GridSize(c_grid_descs_m_n[gemm_count].get_length(I0),
117  c_grid_descs_m_n[gemm_count].get_length(I1));
118 
120  block_ends[gemm_count] = grid_size_ + grid_size_grp;
121 
122  grid_size_ += grid_size_grp;
123 
124  ++gemm_count;
125  }
126  group_stride_a = args.K_; // A: Out NWGK
127  group_stride_b = args.K_ * args.C_ *
128  std::accumulate(args.filter_spatial_lengths_.begin(),
129  args.filter_spatial_lengths_.end(),
130  1,
131  std::multiplies<index_t>()); // B: Wei GKXC
132  group_stride_c = args.C_; // C: In NWGC
133 
134  GemmBatch = args.G_;
135  }
136 
137  template <
138  typename InLay = typename GroupedConvTraitsType_::InLayout,
139  typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
140  typename OutLay = typename GroupedConvTraitsType_::OutLayout,
141  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NHWGC> &&
142  std::is_same_v<WeiLay, tensor_layout::convolution::GKYXC> &&
143  std::is_same_v<OutLay, tensor_layout::convolution::NHWGK>,
144  bool>::type = false>
146  {
147  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
148  static_cast<index_t>(args.N_),
149  static_cast<index_t>(args.C_),
150  static_cast<index_t>(args.input_spatial_lengths_[0]),
151  static_cast<index_t>(args.input_spatial_lengths_[1])};
152  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
153  static_cast<index_t>(args.K_),
154  static_cast<index_t>(args.C_),
155  static_cast<index_t>(args.filter_spatial_lengths_[0]),
156  static_cast<index_t>(args.filter_spatial_lengths_[1])};
157  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
158  static_cast<index_t>(args.N_),
159  static_cast<index_t>(args.K_),
160  static_cast<index_t>(args.output_spatial_lengths_[0]),
161  static_cast<index_t>(args.output_spatial_lengths_[1])};
162 
163  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
164  static_cast<index_t>(args.conv_filter_strides_[1])};
165  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
166  static_cast<index_t>(args.conv_filter_dilations_[1])};
167  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
168  static_cast<index_t>(args.input_left_pads_[1])};
169  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
170  static_cast<index_t>(args.input_right_pads_[1])};
171 
172  k_batch = args.k_batch;
173 
174  in_ptr = args.in_ptr;
175  wei_ptr = args.wei_ptr;
176  for(index_t d = 0; d < NumDTensor; d++)
177  {
178  ds_ptr[d] = args.ds_ptr[d];
179  }
180  out_ptr = args.out_ptr;
181 
182  const index_t Y = wei_g_k_c_xs_lengths[3];
183  const index_t X = wei_g_k_c_xs_lengths[4];
184  const index_t ConvStrideH = conv_filter_strides[0];
185  const index_t ConvStrideW = conv_filter_strides[1];
186  const index_t ConvDilationH = conv_filter_dilations[0];
187  const index_t ConvDilationW = conv_filter_dilations[1];
188  const auto GcdStrideDilationH = gcd(ConvStrideH, ConvDilationH);
189  const auto GcdStrideDilationW = gcd(ConvStrideW, ConvDilationW);
190  const auto YTilde = ConvStrideH / GcdStrideDilationH;
191  const auto XTilde = ConvStrideW / GcdStrideDilationW;
192 
193  for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
194  {
195  for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
196  {
197  const auto YDotSlice = integer_divide_ceil(Y - i_ytilde, YTilde);
198  const auto XDotSlice = integer_divide_ceil(X - i_xtilde, XTilde);
199 
200  if(XDotSlice * YDotSlice <= 0)
201  {
202  continue;
203  }
204 
206  {
207  gemm_count++;
208  // Avoid array segfault
209  continue;
210  }
211 
212  tildes = {i_ytilde, i_xtilde};
213 
214  ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
221  tildes};
222 
223  auto grid_descs = conv_to_gemm_transformer
224  .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
225  GroupedConvTraitsType_::NDimSpatial>(1);
226 
227  a_grid_descs_m_k[gemm_count] = grid_descs.at(number<0>{});
228  b_grid_descs_n_k[gemm_count] = grid_descs.at(number<1>{});
229  c_grid_descs_m_n[gemm_count] = grid_descs.at(number<2>{});
230 
231  const index_t grid_size_grp =
232  TilePartitioner::GridSize(c_grid_descs_m_n[gemm_count].get_length(I0),
233  c_grid_descs_m_n[gemm_count].get_length(I1));
234 
236  block_ends[gemm_count] = grid_size_ + grid_size_grp;
237 
238  grid_size_ += grid_size_grp;
239 
240  ++gemm_count;
241  }
242  }
243  group_stride_a = args.K_; // A: Out NWGK
244  group_stride_b = args.K_ * args.C_ *
245  std::accumulate(args.filter_spatial_lengths_.begin(),
246  args.filter_spatial_lengths_.end(),
247  1,
248  std::multiplies<index_t>()); // B: Wei GKXC
249  group_stride_c = args.C_; // C: In NWGC
250 
251  GemmBatch = args.G_;
252  }
253 
254  template <
255  typename InLay = typename GroupedConvTraitsType_::InLayout,
256  typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
257  typename OutLay = typename GroupedConvTraitsType_::OutLayout,
258  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NDHWGC> &&
259  std::is_same_v<WeiLay, tensor_layout::convolution::GKZYXC> &&
260  std::is_same_v<OutLay, tensor_layout::convolution::NDHWGK>,
261  bool>::type = false>
263  {
264  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
265  static_cast<index_t>(args.N_),
266  static_cast<index_t>(args.C_),
267  static_cast<index_t>(args.input_spatial_lengths_[0]),
268  static_cast<index_t>(args.input_spatial_lengths_[1]),
269  static_cast<index_t>(args.input_spatial_lengths_[2])};
270  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
271  static_cast<index_t>(args.K_),
272  static_cast<index_t>(args.C_),
273  static_cast<index_t>(args.filter_spatial_lengths_[0]),
274  static_cast<index_t>(args.filter_spatial_lengths_[1]),
275  static_cast<index_t>(args.filter_spatial_lengths_[2])};
276  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
277  static_cast<index_t>(args.N_),
278  static_cast<index_t>(args.K_),
279  static_cast<index_t>(args.output_spatial_lengths_[0]),
280  static_cast<index_t>(args.output_spatial_lengths_[1]),
281  static_cast<index_t>(args.output_spatial_lengths_[2])};
282 
283  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
284  static_cast<index_t>(args.conv_filter_strides_[1]),
285  static_cast<index_t>(args.conv_filter_strides_[2])};
286  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
287  static_cast<index_t>(args.conv_filter_dilations_[1]),
288  static_cast<index_t>(args.conv_filter_dilations_[2])};
289  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
290  static_cast<index_t>(args.input_left_pads_[1]),
291  static_cast<index_t>(args.input_left_pads_[2])};
292  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
293  static_cast<index_t>(args.input_right_pads_[1]),
294  static_cast<index_t>(args.input_right_pads_[2])};
295 
296  k_batch = args.k_batch;
297 
298  in_ptr = args.in_ptr;
299  wei_ptr = args.wei_ptr;
300  for(index_t d = 0; d < NumDTensor; d++)
301  {
302  ds_ptr[d] = args.ds_ptr[d];
303  }
304  out_ptr = args.out_ptr;
305 
306  const index_t Z = wei_g_k_c_xs_lengths[3];
307  const index_t Y = wei_g_k_c_xs_lengths[4];
308  const index_t X = wei_g_k_c_xs_lengths[5];
309  const index_t ConvStrideD = conv_filter_strides[0];
310  const index_t ConvStrideH = conv_filter_strides[1];
311  const index_t ConvStrideW = conv_filter_strides[2];
312  const index_t ConvDilationD = conv_filter_dilations[0];
313  const index_t ConvDilationH = conv_filter_dilations[1];
314  const index_t ConvDilationW = conv_filter_dilations[2];
315  const auto GcdStrideDilationD = gcd(ConvStrideD, ConvDilationD);
316  const auto GcdStrideDilationH = gcd(ConvStrideH, ConvDilationH);
317  const auto GcdStrideDilationW = gcd(ConvStrideW, ConvDilationW);
318  const auto ZTilde = ConvStrideD / GcdStrideDilationD;
319  const auto YTilde = ConvStrideH / GcdStrideDilationH;
320  const auto XTilde = ConvStrideW / GcdStrideDilationW;
321 
322  for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde)
323  {
324  for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
325  {
326  for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
327  {
328  const auto ZDotSlice = integer_divide_ceil(Z - i_ztilde, ZTilde);
329  const auto YDotSlice = integer_divide_ceil(Y - i_ytilde, YTilde);
330  const auto XDotSlice = integer_divide_ceil(X - i_xtilde, XTilde);
331 
332  if(ZDotSlice * XDotSlice * YDotSlice <= 0)
333  {
334  continue;
335  }
336 
338  {
339  gemm_count++;
340  // Avoid array segfault
341  continue;
342  }
343 
344  tildes = {i_ztilde, i_ytilde, i_xtilde};
345 
346  ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
353  tildes};
354 
355  auto grid_descs = conv_to_gemm_transformer
356  .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
357  GroupedConvTraitsType_::NDimSpatial>(1);
358 
359  a_grid_descs_m_k[gemm_count] = grid_descs.at(number<0>{});
360  b_grid_descs_n_k[gemm_count] = grid_descs.at(number<1>{});
361  c_grid_descs_m_n[gemm_count] = grid_descs.at(number<2>{});
362 
363  const index_t grid_size_grp =
364  TilePartitioner::GridSize(c_grid_descs_m_n[gemm_count].get_length(I0),
365  c_grid_descs_m_n[gemm_count].get_length(I1));
366 
368  block_ends[gemm_count] = grid_size_ + grid_size_grp;
369 
370  grid_size_ += grid_size_grp;
371 
372  ++gemm_count;
373  }
374  }
375  }
376 
377  group_stride_a = args.K_; // A: Out NWGK
378  group_stride_b = args.K_ * args.C_ *
379  std::accumulate(args.filter_spatial_lengths_.begin(),
380  args.filter_spatial_lengths_.end(),
381  1,
382  std::multiplies<index_t>()); // B: Wei GKXC
383  group_stride_c = args.C_; // C: In NWGC
384 
385  GemmBatch = args.G_; // C: In NWGC
386  }
387 
388  static constexpr index_t MaxGroupedGemmGroupsNum = 128;
389 
392 
396 
397  static constexpr index_t NonSpatialDims = 3;
401 
407 
412 
413  const void* out_ptr;
414  void* in_ptr;
415  std::array<const void*, NumDTensor> ds_ptr;
416  const void* wei_ptr;
417 
421 
424 
428 };
429 
468 template <typename GroupedConvTraitsType_,
469  typename TilePartitioner_,
470  typename GemmPipeline_,
471  typename EpiloguePipeline_>
473 {
474  // Todo: Enable Vector Load Size > 1
475  static_assert(GroupedConvTraitsType_::VectorSizeA == 1 &&
476  GroupedConvTraitsType_::VectorSizeB == 1);
477 
478  static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial_;
480  GroupedConvTraitsType_::ConvSpecialization;
487 
492 
494  static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
495 
496  static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
497 
501 
503 
506  static constexpr index_t MaxGroupedGemmGroupsNum =
508 
509  // TODO: Enable this
510  static constexpr bool IsSplitKSupported = false;
511 
512  static constexpr auto I0 = number<0>();
513  static constexpr auto I1 = number<1>();
514  static constexpr auto I2 = number<2>();
515  static constexpr auto I3 = number<3>();
516 
517  static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
518  "Not supported!");
519  static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>, "Not supported!");
520  static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
521  // TODO: Change to and enable vector load
522  // static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>,
523  // "Not supported A GEMM layout!");
524  // static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::RowMajor>,
525  // "Not supported B GEMM layout!");
526  static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
527  "Not supported C GEMM layout!");
528 
529  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
530  {
531  // clang-format off
532  return concat('_', "grouped_convolution_backward_data", gemm_prec_str<InDataType, WeiDataType>, GemmPipeline::GetName());
533  // clang-format on
534  }
535 
537  {
538  // enable batched grouped gemm
539  return dim3(kargs.grid_size_, kargs.GemmBatch, kargs.k_batch);
540  }
541 
542  CK_TILE_HOST static constexpr auto BlockSize()
543  {
544  return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
545  }
546 
549  {
551  }
552 
554  {
555  return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
556  }
557 
558  CK_TILE_HOST static bool
560  {
561  if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
564  {
565  if(kargs.k_batch != 1)
566  {
567  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
568  {
569  CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
570  }
571  return false;
572  }
573  }
574 
576  {
577  return false;
578  }
579 
580  const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
581  const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];
582 
583  // check ConvSpecialization
585  {
586  // check if it's 1x1, stride=1 conv
587  for(index_t i = 0; i < NDimSpatial; ++i)
588  {
589  const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
590  const index_t ConvStride = kargs.conv_filter_strides[i];
591  const index_t LeftPad = kargs.input_left_pads[i];
592  const index_t RightPad = kargs.input_right_pads[i];
593 
594  if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
595  {
596  return false;
597  }
598  }
599  }
601  {
602  // check if it's 1x1 conv
603  for(index_t i = 0; i < NDimSpatial; ++i)
604  {
605  const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
606  const index_t LeftPad = kargs.input_left_pads[i];
607  const index_t RightPad = kargs.input_right_pads[i];
608 
609  if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0))
610  {
611  return false;
612  }
613  }
614  }
616  {
617  if(ConvC != 1)
618  {
619  return false;
620  }
621  for(index_t i = 0; i < NDimSpatial; ++i)
622  {
623  const index_t filter_spatial_dim = kargs.wei_g_k_c_xs_lengths[i + I3];
624 
625  if(filter_spatial_dim != I3)
626  {
627  return false;
628  }
629  }
630  }
631 
632  namespace ctc = tensor_layout::convolution;
633 
634  if constexpr(std::is_same_v<InLayout, ctc::NWGC> || std::is_same_v<InLayout, ctc::NHWGC> ||
635  std::is_same_v<InLayout, ctc::NDHWGC>)
636  {
637  // Check access per C
638  if(ConvC % GroupedConvTraitsType_::VectorSizeB != 0)
639  {
640  CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!");
641  return false;
642  }
643  }
644  else
645  {
646  CK_TILE_ERROR("Not supported input layout!");
647  return false;
648  }
649 
650  // FIXME: layout
651  if constexpr(std::is_same_v<WeiLayout, ctc::GKXC> ||
652  std::is_same_v<WeiLayout, ctc::GKYXC> ||
653  std::is_same_v<WeiLayout, ctc::GKZYXC>)
654  {
655  if(ConvC % GroupedConvTraitsType_::VectorSizeC != 0)
656  {
657  CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!");
658  return false;
659  }
660  }
661  else
662  {
663  CK_TILE_ERROR("Not supported weight layout!");
664  return false;
665  }
666 
667  if constexpr(std::is_same_v<OutLayout, ctc::NWGK> ||
668  std::is_same_v<OutLayout, ctc::NHWGK> ||
669  std::is_same_v<OutLayout, ctc::NDHWGK>)
670  {
671  if(ConvK % GroupedConvTraitsType_::VectorSizeA != 0)
672  {
673  CK_TILE_ERROR("Conv K is not a multiple of vector store size for output image!");
674  return false;
675  }
676  }
677  else
678  {
679  CK_TILE_ERROR("Not supported output layout!");
680  return false;
681  }
682 
683  return true;
684  }
685 
686  template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
687  CK_TILE_DEVICE static auto
689  const InDataType* b_ptr,
690  const std::array<const void*, NumDTensor>& ds_ptr,
691  WeiDataType* c_ptr,
693  const index_t group_id)
694  {
695  static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
696  static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!");
697  const auto& a_tensor_view = [&]() {
698  return make_tensor_view<address_space_enum::global>(
699  a_ptr,
700  kargs.a_grid_descs_m_k[group_id]); // A: out
701  }();
702 
703  const auto& b_tensor_view = [&]() {
704  return make_tensor_view<address_space_enum::global>(
705  b_ptr,
706  kargs.b_grid_descs_n_k[group_id]); // B: weight
707  }();
708 
709  const auto& c_tensor_view = [&]() {
710  return make_tensor_view<address_space_enum::global>(c_ptr,
711  kargs.c_grid_descs_m_n[group_id]);
712  }();
713 
714  const auto& ds_tensor_view = generate_tuple(
715  [&](auto i) {
716  static_assert(std::is_same_v<std::tuple_element_t<i, DsLayout>, OutLayout>,
717  "Not supported!");
718  static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
719  "Not supported!");
720  static_assert(std::is_same_v<std::tuple_element_t<i, DsDataType>, OutDataType>,
721  "Not supported!");
722 
723  return make_tensor_view<address_space_enum::global>(
724  static_cast<OutDataType*>(ds_ptr[i]), kargs.c_grid_descs_m_n[group_id]);
725  },
727 
728  return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view);
729  }
730 
731  template <typename TensorView>
732  CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
733  {
734  const auto& a_pad_view = [&]() {
735  const auto& a_tensor_view = views.at(I0);
736  return pad_tensor_view(a_tensor_view,
740  }();
741 
742  const auto& b_pad_view = [&]() {
743  const auto& b_tensor_view = views.at(I1);
744  return pad_tensor_view(b_tensor_view,
748  }();
749 
750  const auto& ds_tensor_view = views.at(I2);
751  const auto& ds_pad_view = generate_tuple(
752  [&](auto i) {
753  return pad_tensor_view(ds_tensor_view[i],
757  },
759 
760  const auto& c_pad_view = [&]() {
761  const auto& c_tensor_view = views.at(I3);
762  return pad_tensor_view(c_tensor_view,
766  }();
767 
768  return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view);
769  }
770 
771  template <typename PadView>
772  CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
773  const index_t i_m,
774  const index_t i_n,
775  const index_t i_k = 0)
776  {
777  const auto& a_pad_view = views.at(I0);
778  const auto& b_pad_view = views.at(I1);
779  const auto& ds_pad_view = views.at(I2);
780  const auto& c_pad_view = views.at(I3);
781 
782  const auto& a_block_window = [&]() {
783  return make_tile_window(a_pad_view,
786  {i_m, i_k});
787  }();
788 
789  const auto& b_block_window = [&]() {
790  return make_tile_window(b_pad_view,
793  {i_n, i_k});
794  }();
795 
796  const auto ds_block_window = generate_tuple(
797  [&](auto i) {
798  return make_tile_window(ds_pad_view[i],
801  {i_m, i_n});
802  },
804 
805  auto c_block_window = make_tile_window(
806  c_pad_view,
808  {i_m, i_n});
809 
810  return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window);
811  }
812 
825  CK_TILE_DEVICE static void RunGemm(const OutDataType* a_ptr,
826  const InDataType* b_ptr,
827  const std::array<const void*, NumDTensor>& ds_ptr,
828  WeiDataType* c_ptr,
829  void* smem_ptr_0,
831  const index_t block_idx_m,
832  const index_t block_idx_n,
833  const index_t group_id)
834  {
835  // Create Gemm tensor views, pad views and tile windows
836  const auto& gemm_tensor_views_tuple =
837  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
838  a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id);
839 
840  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
841  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
842 
843  const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(
844  gemm_pad_views.at(I0).get_tensor_descriptor().get_length(I1)));
845 
846  // Run GEMM cooperatively by whole workgroup.
847  const auto& a_block_window = gemm_tile_windows.at(I0);
848  const auto& b_block_window = gemm_tile_windows.at(I1);
849  const auto& d_block_window = gemm_tile_windows.at(I2);
850 
851  const auto& c_block_tile = GemmPipeline{}.template operator()(
852  a_block_window, b_block_window, num_loop, smem_ptr_0);
853 
854  // Run Epilogue Pipeline
855  auto& c_block_window = gemm_tile_windows.at(I3);
856 
857  EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
858  c_block_window, c_block_tile, d_block_window, smem_ptr_0);
859  }
860 
876  CK_TILE_DEVICE static void RunGemm2LDS(const OutDataType* a_ptr,
877  const InDataType* b_ptr,
878  const std::array<const void*, NumDTensor>& ds_ptr,
879  WeiDataType* c_ptr,
880  void* __restrict__ smem_ptr_0,
881  void* __restrict__ smem_ptr_1,
883  const index_t block_idx_m,
884  const index_t block_idx_n,
885  const index_t group_id)
886  {
887  // Create Gemm tensor views, pad views and tile windows
888  const auto& gemm_tensor_views_tuple =
889  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
890  a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id);
891  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
892  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
893 
894  const index_t num_loop = amd_wave_read_first_lane(
895  TilePartitioner::GetLoopNum(gemm_tile_windows.at(I0).get_length(I1)));
896 
897  // Run GEMM cooperatively by whole workgroup.
898  const auto& a_block_window = gemm_tile_windows.at(I0);
899  const auto& b_block_window = gemm_tile_windows.at(I1);
900  const auto& d_block_window = gemm_tile_windows.at(I2);
901 
902  const auto& c_block_tile = GemmPipeline{}.template operator()(
903  a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
904 
905  // Run Epilogue Pipeline
906  auto& c_block_window = gemm_tile_windows.at(I3);
907 
908  EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
909  c_block_window, c_block_tile, d_block_window, smem_ptr_0);
910  }
911 
913  index_t block_id) const
914  {
915  index_t left = 0;
916  index_t right = kargs.gemm_count;
917  index_t group_id = index_t((left + right) >> 1);
918 
919  while((!(block_id >= kargs.block_starts[group_id] &&
920  block_id < kargs.block_ends[group_id])) &&
921  left <= right)
922  {
923  if(block_id < kargs.block_starts[group_id])
924  {
925  right = group_id;
926  }
927  else
928  {
929  left = group_id;
930  }
931  group_id = index_t((left + right) >> 1);
932  }
933 
934  return group_id;
935  }
936 
938  {
939  const auto blockIdX = amd_wave_read_first_lane(blockIdx.x);
940  const index_t group_id = FindGroupId(kargs, blockIdX);
941 
943  kargs.block_starts[group_id],
944  kargs.c_grid_descs_m_n[group_id].get_length(I0),
945  kargs.c_grid_descs_m_n[group_id].get_length(I1));
946 
947  const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
948  const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
949 
950  const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);
951  const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
952  const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
953  const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);
954 
955  // options
956  // conv_bwd_data = Out * Weight = In
957  const OutDataType* a_ptr = static_cast<const OutDataType*>(kargs.out_ptr) + group_offset_a;
958  const WeiDataType* b_ptr = static_cast<const WeiDataType*>(kargs.wei_ptr) + group_offset_b;
959  InDataType* c_ptr = static_cast<InDataType*>(kargs.in_ptr) + group_offset_c;
960 
961  // allocate LDS
962  __shared__ char smem_ptr_0[GetSmemSize()];
963 
964  if constexpr(GemmPipeline::DoubleSmemBuffer == true)
965  {
966  __shared__ char smem_ptr_1[GetSmemSize()];
967  if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
968  GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
970  {
971  RunGemm2LDS(a_ptr,
972  b_ptr,
973  kargs.ds_ptr,
974  c_ptr,
975  smem_ptr_0,
976  smem_ptr_1,
977  kargs,
978  i_m,
979  i_n,
980  group_id);
981  }
982  }
983  else
984  {
985  if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
986  GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
988  {
989  RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, i_m, i_n, group_id);
990  }
991  }
992  }
993 };
994 
995 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:156
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
void CK_TILE_ERROR(Args &&... args) noexcept
Definition: env.hpp:12
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:2834
ConvolutionSpecialization
Definition: convolution_specialization.hpp:11
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE index_t gcd(index_t x, index_t y)
Definition: math.hpp:268
int64_t long_index_t
Definition: integer.hpp:11
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
constexpr bool is_same_v
Definition: type.hpp:283
__device__ X atomic_add(X *p_dst, const X &x)
The Grouped Convolution kernel device arguments.
Definition: grouped_convolution_backward_data_kernel.hpp:22
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > wei_g_k_c_xs_lengths
Definition: grouped_convolution_backward_data_kernel.hpp:399
static constexpr auto I1
Definition: grouped_convolution_backward_data_kernel.hpp:34
CK_TILE_HOST GroupedConvBwdDataKernelArgs(const GroupedConvBwdDataHostArgs &args)
Definition: grouped_convolution_backward_data_kernel.hpp:44
array< index_t, GroupedConvTraitsType_::NDimSpatial > conv_filter_dilations
Definition: grouped_convolution_backward_data_kernel.hpp:403
std::array< const void *, NumDTensor > ds_ptr
Definition: grouped_convolution_backward_data_kernel.hpp:415
array< index_t, GroupedConvTraitsType_::NDimSpatial > conv_filter_strides
Definition: grouped_convolution_backward_data_kernel.hpp:402
array< index_t, MaxGroupedGemmGroupsNum > block_starts
Definition: grouped_convolution_backward_data_kernel.hpp:422
array< index_t, GroupedConvTraitsType_::NDimSpatial > input_left_pads
Definition: grouped_convolution_backward_data_kernel.hpp:404
long_index_t group_stride_b
Definition: grouped_convolution_backward_data_kernel.hpp:426
long_index_t group_stride_c
Definition: grouped_convolution_backward_data_kernel.hpp:427
array< index_t, MaxGroupedGemmGroupsNum > block_ends
Definition: grouped_convolution_backward_data_kernel.hpp:423
const void * out_ptr
Definition: grouped_convolution_backward_data_kernel.hpp:413
remove_cvref_t< decltype(ConvToGemmTransformer{}.MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(1))> ABCGridDescs
Definition: grouped_convolution_backward_data_kernel.hpp:391
remove_cvref_t< decltype(ABCGridDescs{}[number< 1 >{}])> BGridDescNK
Definition: grouped_convolution_backward_data_kernel.hpp:394
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: grouped_convolution_backward_data_kernel.hpp:23
array< index_t, GroupedConvTraitsType_::NDimSpatial > tildes
Definition: grouped_convolution_backward_data_kernel.hpp:406
remove_cvref_t< decltype(ABCGridDescs{}[number< 0 >{}])> AGridDescMK
Definition: grouped_convolution_backward_data_kernel.hpp:393
const void * wei_ptr
Definition: grouped_convolution_backward_data_kernel.hpp:416
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > out_g_n_k_wos_lengths
Definition: grouped_convolution_backward_data_kernel.hpp:400
long_index_t group_stride_a
Definition: grouped_convolution_backward_data_kernel.hpp:425
index_t GemmBatch
Definition: grouped_convolution_backward_data_kernel.hpp:409
void * in_ptr
Definition: grouped_convolution_backward_data_kernel.hpp:414
index_t gemm_count
Definition: grouped_convolution_backward_data_kernel.hpp:411
array< CGridDescMN, MaxGroupedGemmGroupsNum > c_grid_descs_m_n
Definition: grouped_convolution_backward_data_kernel.hpp:420
index_t grid_size_
Definition: grouped_convolution_backward_data_kernel.hpp:410
array< index_t, GroupedConvTraitsType_::NDimSpatial > input_right_pads
Definition: grouped_convolution_backward_data_kernel.hpp:405
array< BGridDescNK, MaxGroupedGemmGroupsNum > b_grid_descs_n_k
Definition: grouped_convolution_backward_data_kernel.hpp:419
index_t k_batch
Definition: grouped_convolution_backward_data_kernel.hpp:408
static constexpr auto I0
Definition: grouped_convolution_backward_data_kernel.hpp:33
static constexpr index_t MaxGroupedGemmGroupsNum
Definition: grouped_convolution_backward_data_kernel.hpp:388
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > in_g_n_c_wis_lengths
Definition: grouped_convolution_backward_data_kernel.hpp:398
static constexpr index_t NumDTensor
Definition: grouped_convolution_backward_data_kernel.hpp:31
array< AGridDescMK, MaxGroupedGemmGroupsNum > a_grid_descs_m_k
Definition: grouped_convolution_backward_data_kernel.hpp:418
remove_cvref_t< decltype(ABCGridDescs{}[number< 2 >{}])> CGridDescMN
Definition: grouped_convolution_backward_data_kernel.hpp:395
static constexpr index_t NonSpatialDims
Definition: grouped_convolution_backward_data_kernel.hpp:397
The Grouped Conv kernel host arguments.
Definition: grouped_convolution_utils.hpp:19
index_t k_batch
Definition: grouped_convolution_utils.hpp:40
InPtr in_ptr
Definition: grouped_convolution_utils.hpp:36
WeiPtr wei_ptr
Definition: grouped_convolution_utils.hpp:37
OutPtr out_ptr
Definition: grouped_convolution_utils.hpp:39
const std::vector< const void * > ds_ptr
Definition: grouped_convolution_utils.hpp:38
The Grouped Convolution Backward Data kernel template.
Definition: grouped_convolution_backward_data_kernel.hpp:473
static constexpr index_t NDimSpatial
Definition: grouped_convolution_backward_data_kernel.hpp:478
static constexpr CK_TILE_HOST auto BlockSize()
Definition: grouped_convolution_backward_data_kernel.hpp:542
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition: grouped_convolution_backward_data_kernel.hpp:482
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n, const index_t i_k=0)
Definition: grouped_convolution_backward_data_kernel.hpp:772
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views)
Definition: grouped_convolution_backward_data_kernel.hpp:732
GroupedConvBwdDataKernelArgs< GroupedConvTraitsType_, TilePartitioner > GroupedConvBwdDataKernelArgsSpecialized
Definition: grouped_convolution_backward_data_kernel.hpp:505
remove_cvref_t< typename GemmPipeline::ADataType > InDataType
Definition: grouped_convolution_backward_data_kernel.hpp:498
static constexpr index_t MaxGroupedGemmGroupsNum
Definition: grouped_convolution_backward_data_kernel.hpp:506
static constexpr auto I1
Definition: grouped_convolution_backward_data_kernel.hpp:513
static constexpr auto I3
Definition: grouped_convolution_backward_data_kernel.hpp:515
remove_cvref_t< typename GroupedConvTraitsType_::OutLayout > OutLayout
Definition: grouped_convolution_backward_data_kernel.hpp:490
static constexpr ConvolutionSpecialization ConvSpecialization
Definition: grouped_convolution_backward_data_kernel.hpp:479
static constexpr CK_TILE_HOST GroupedConvBwdDataKernelArgsSpecialized MakeKernelArgs(const GroupedConvBwdDataHostArgs &hostArgs)
Definition: grouped_convolution_backward_data_kernel.hpp:548
static constexpr index_t NumDTensor
Definition: grouped_convolution_backward_data_kernel.hpp:494
remove_cvref_t< typename GemmPipeline::BDataType > WeiDataType
Definition: grouped_convolution_backward_data_kernel.hpp:499
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: grouped_convolution_backward_data_kernel.hpp:483
remove_cvref_t< typename EpiloguePipeline::ODataType > OutDataType
Definition: grouped_convolution_backward_data_kernel.hpp:502
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: grouped_convolution_backward_data_kernel.hpp:481
remove_cvref_t< typename GroupedConvTraitsType_::WeiLayout > WeiLayout
Definition: grouped_convolution_backward_data_kernel.hpp:489
static constexpr index_t kBlockSize
Definition: grouped_convolution_backward_data_kernel.hpp:496
static CK_TILE_HOST bool IsSupportedArgument(const GroupedConvBwdDataKernelArgsSpecialized &kargs)
Definition: grouped_convolution_backward_data_kernel.hpp:559
remove_cvref_t< typename GemmPipeline::BLayout > GemmBLayout
Definition: grouped_convolution_backward_data_kernel.hpp:485
remove_cvref_t< typename GroupedConvTraitsType_::DsLayout > DsLayout
Definition: grouped_convolution_backward_data_kernel.hpp:491
static constexpr auto I2
Definition: grouped_convolution_backward_data_kernel.hpp:514
static CK_TILE_DEVICE auto MakeGemmTensorViews(const OutDataType *a_ptr, const InDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, WeiDataType *c_ptr, const GroupedConvBwdDataKernelArgsSpecialized &kargs, const index_t group_id)
Definition: grouped_convolution_backward_data_kernel.hpp:688
static CK_TILE_HOST auto GridSize(const GroupedConvBwdDataKernelArgsSpecialized &kargs)
Definition: grouped_convolution_backward_data_kernel.hpp:536
remove_cvref_t< typename GemmPipeline::ALayout > GemmALayout
Definition: grouped_convolution_backward_data_kernel.hpp:484
remove_cvref_t< typename EpiloguePipeline::DsLayout > GemmDsLayout
Definition: grouped_convolution_backward_data_kernel.hpp:493
CK_TILE_DEVICE index_t FindGroupId(const GroupedConvBwdDataKernelArgsSpecialized &kargs, index_t block_id) const
Definition: grouped_convolution_backward_data_kernel.hpp:912
static CK_TILE_DEVICE void RunGemm(const OutDataType *a_ptr, const InDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, WeiDataType *c_ptr, void *smem_ptr_0, const GroupedConvBwdDataKernelArgsSpecialized &kargs, const index_t block_idx_m, const index_t block_idx_n, const index_t group_id)
Runs single GEMM problem cooperatively by whole workgroup.
Definition: grouped_convolution_backward_data_kernel.hpp:825
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: grouped_convolution_backward_data_kernel.hpp:553
static CK_TILE_DEVICE void RunGemm2LDS(const OutDataType *a_ptr, const InDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, WeiDataType *c_ptr, void *__restrict__ smem_ptr_0, void *__restrict__ smem_ptr_1, const GroupedConvBwdDataKernelArgsSpecialized &kargs, const index_t block_idx_m, const index_t block_idx_n, const index_t group_id)
Runs single GEMM problem cooperatively by whole workgroup.
Definition: grouped_convolution_backward_data_kernel.hpp:876
CK_TILE_DEVICE void operator()(GroupedConvBwdDataKernelArgsSpecialized kargs) const
Definition: grouped_convolution_backward_data_kernel.hpp:937
static constexpr bool IsSplitKSupported
Definition: grouped_convolution_backward_data_kernel.hpp:510
remove_cvref_t< typename GroupedConvTraitsType_::InLayout > InLayout
Definition: grouped_convolution_backward_data_kernel.hpp:488
remove_cvref_t< typename GemmPipeline::CLayout > GemmCLayout
Definition: grouped_convolution_backward_data_kernel.hpp:486
remove_cvref_t< typename EpiloguePipeline::DsDataType > DsDataType
Definition: grouped_convolution_backward_data_kernel.hpp:500
static CK_TILE_HOST const std::string GetName()
Definition: grouped_convolution_backward_data_kernel.hpp:529
static constexpr auto I0
Definition: grouped_convolution_backward_data_kernel.hpp:512
static CK_TILE_DEVICE auto GetOffsetedTileIndex(index_t block_start, index_t M, index_t N) noexcept -> const tuple< index_t, index_t >
The function subtracts the block's start (offset) from 1D raw-indexes.
Definition: gemm_tile_partitioner.hpp:192
Definition: transform_conv_bwd_data_to_gemm.hpp:22
CK_TILE_HOST auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N([[maybe_unused]] const index_t GemmKBatch) const
Definition: transform_conv_bwd_data_to_gemm.hpp:569
constexpr CK_TILE_HOST_DEVICE auto & at(index_t i)
Definition: array.hpp:110
Definition: integral_constant.hpp:13
std::vector< ck_tile::long_index_t > input_spatial_lengths_
Definition: convolution_parameter.hpp:130
ck_tile::long_index_t K_
Definition: convolution_parameter.hpp:126
std::vector< ck_tile::long_index_t > output_spatial_lengths_
Definition: convolution_parameter.hpp:131
std::vector< ck_tile::long_index_t > input_right_pads_
Definition: convolution_parameter.hpp:137
ck_tile::long_index_t G_
Definition: convolution_parameter.hpp:124
std::vector< ck_tile::long_index_t > conv_filter_strides_
Definition: convolution_parameter.hpp:133
std::vector< ck_tile::long_index_t > filter_spatial_lengths_
Definition: convolution_parameter.hpp:129
ck_tile::long_index_t C_
Definition: convolution_parameter.hpp:127
ck_tile::long_index_t N_
Definition: convolution_parameter.hpp:125
std::vector< ck_tile::long_index_t > input_left_pads_
Definition: convolution_parameter.hpp:136
std::vector< ck_tile::long_index_t > conv_filter_dilations_
Definition: convolution_parameter.hpp:134
Definition: type_traits.hpp:115
Definition: sequence.hpp:49
#define CK_TILE_ENV(name)
Definition: env.hpp:145