/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp Source File
grouped_convolution_backward_data_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <string>
8 
9 #include "ck_tile/core.hpp"
10 #include "ck_tile/ops/common.hpp"
11 #include "ck_tile/host/concat.hpp"
16 
17 namespace ck_tile {
18 
20 template <typename GroupedConvTraitsType_, typename TilePartitioner_>
22 {
24 
26  TransformConvBwdDataToGemm<GroupedConvTraitsType_::NDimSpatial,
27  GroupedConvTraitsType_::ConvSpecialization>;
28  static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
29 
30  static constexpr auto I0 = number<0>();
31  static constexpr auto I1 = number<1>();
32 
33  template <
34  typename InLay = typename GroupedConvTraitsType_::InLayout,
35  typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
36  typename OutLay = typename GroupedConvTraitsType_::OutLayout,
37  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NWGC> &&
38  std::is_same_v<WeiLay, tensor_layout::convolution::GKXC> &&
39  std::is_same_v<OutLay, tensor_layout::convolution::NWGK>,
40  bool>::type = false>
// Constructor overload SFINAE-selected for 1D spatial layouts (NWGC input,
// GKXC weight, NWGK output). NOTE(review): the signature line (Doxygen line 41)
// was dropped by the doc extraction; per the member index below this is
// CK_TILE_HOST GroupedConvBwdDataKernelArgs(const GroupedConvBwdDataHostArgs& args).
42  {
// Pack host-side dims into fixed arrays: [G, N, C, W], [G, K, C, X], [G, N, K, Wo].
43  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
44  static_cast<index_t>(args.N_),
45  static_cast<index_t>(args.C_),
46  static_cast<index_t>(args.input_spatial_lengths_[0])};
47  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
48  static_cast<index_t>(args.K_),
49  static_cast<index_t>(args.C_),
50  static_cast<index_t>(args.filter_spatial_lengths_[0])};
51  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
52  static_cast<index_t>(args.N_),
53  static_cast<index_t>(args.K_),
54  static_cast<index_t>(args.output_spatial_lengths_[0])};
55 
56  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0])};
57  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0])};
58  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0])};
59  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0])};
60 
61  k_batch = args.k_batch;
62 
// Raw device pointers are copied as-is; ds_ptr is a fixed-size array sized by NumDTensor.
63  in_ptr = args.in_ptr;
64  wei_ptr = args.wei_ptr;
65  for(index_t d = 0; d < NumDTensor; d++)
66  {
67  ds_ptr[d] = args.ds_ptr[d];
68  }
69  out_ptr = args.out_ptr;
70 
// Backward-data decomposition: split the filter into XTilde = StrideW / gcd(StrideW, DilationW)
// sub-problems; each non-empty slice becomes one GEMM in the grouped-GEMM batch.
71  const index_t X = wei_g_k_c_xs_lengths[3];
72  const index_t ConvStrideW = conv_filter_strides[0];
73  const index_t ConvDilationW = conv_filter_dilations[0];
74  const auto GcdStrideDilationW = gcd(ConvStrideW, ConvDilationW);
75  const auto XTilde = ConvStrideW / GcdStrideDilationW;
76 
77  for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
78  {
79  const auto XDotSlice = integer_divide_ceil(X - i_xtilde, XTilde);
80 
// Empty filter slice contributes nothing — skip without consuming a GEMM slot.
81  if(XDotSlice <= 0)
82  {
83  continue;
84  }
85 
// NOTE(review): the guard condition (Doxygen line 86) was dropped by the doc
// extraction; given the "Avoid array segfault" comment it bounds gemm_count
// against the descriptor arrays (presumably gemm_count >= MaxGroupedGemmGroupsNum) — confirm in repo.
87  {
88  gemm_count++;
89  // Avoid array segfault
90  continue;
91  }
92 
93  tildes = {i_xtilde};
94 
// NOTE(review): ctor argument lines 96-101 were dropped by the extraction
// (likely the remaining length/stride/pad arrays) — confirm in repo.
95  ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
102  tildes};
103 
104  auto grid_descs =
105  conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
106  GroupedConvTraitsType_::NDimSpatial>(1);
107 
108  a_grid_descs_m_k[gemm_count] = grid_descs.at(number<0>{});
109  b_grid_descs_n_k[gemm_count] = grid_descs.at(number<1>{});
110  c_grid_descs_m_n[gemm_count] = grid_descs.at(number<2>{});
111 
// Each sub-GEMM owns a contiguous [start, end) slice of the flattened grid;
// NOTE(review): the block_starts assignment (Doxygen line 116) was dropped by the extraction.
112  const index_t grid_size_grp =
113  TilePartitioner::GridSize(c_grid_descs_m_n[gemm_count].get_length(I0),
114  c_grid_descs_m_n[gemm_count].get_length(I1));
115 
117  block_ends[gemm_count] = grid_size_ + grid_size_grp;
118 
119  grid_size_ += grid_size_grp;
120 
121  ++gemm_count;
122  }
// Per-group (blockIdx.y) pointer strides: G is the fastest-varying packed dim
// before C/K in these layouts, so stepping one group moves by K (A), K*C*prod(filter) (B), C (C).
123  group_stride_a = args.K_; // A: Out NWGK
124  group_stride_b = args.K_ * args.C_ *
125  std::accumulate(args.filter_spatial_lengths_.begin(),
126  args.filter_spatial_lengths_.end(),
127  1,
128  std::multiplies<index_t>()); // B: Wei GKXC
129  group_stride_c = args.C_; // C: In NWGC
130 
131  GemmBatch = args.G_;
132  }
133 
134  template <
135  typename InLay = typename GroupedConvTraitsType_::InLayout,
136  typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
137  typename OutLay = typename GroupedConvTraitsType_::OutLayout,
138  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NHWGC> &&
139  std::is_same_v<WeiLay, tensor_layout::convolution::GKYXC> &&
140  std::is_same_v<OutLay, tensor_layout::convolution::NHWGK>,
141  bool>::type = false>
// Constructor overload SFINAE-selected for 2D spatial layouts (NHWGC input,
// GKYXC weight, NHWGK output). Mirrors the 1D overload with an extra H/Y dim.
// NOTE(review): the signature line (Doxygen line 142) was dropped by the doc extraction.
143  {
// Pack host-side dims: [G, N, C, H, W], [G, K, C, Y, X], [G, N, K, Ho, Wo].
144  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
145  static_cast<index_t>(args.N_),
146  static_cast<index_t>(args.C_),
147  static_cast<index_t>(args.input_spatial_lengths_[0]),
148  static_cast<index_t>(args.input_spatial_lengths_[1])};
149  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
150  static_cast<index_t>(args.K_),
151  static_cast<index_t>(args.C_),
152  static_cast<index_t>(args.filter_spatial_lengths_[0]),
153  static_cast<index_t>(args.filter_spatial_lengths_[1])};
154  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
155  static_cast<index_t>(args.N_),
156  static_cast<index_t>(args.K_),
157  static_cast<index_t>(args.output_spatial_lengths_[0]),
158  static_cast<index_t>(args.output_spatial_lengths_[1])};
159 
160  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
161  static_cast<index_t>(args.conv_filter_strides_[1])};
162  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
163  static_cast<index_t>(args.conv_filter_dilations_[1])};
164  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
165  static_cast<index_t>(args.input_left_pads_[1])};
166  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
167  static_cast<index_t>(args.input_right_pads_[1])};
168 
169  k_batch = args.k_batch;
170 
171  in_ptr = args.in_ptr;
172  wei_ptr = args.wei_ptr;
173  for(index_t d = 0; d < NumDTensor; d++)
174  {
175  ds_ptr[d] = args.ds_ptr[d];
176  }
177  out_ptr = args.out_ptr;
178 
// Backward-data decomposition over both spatial dims: YTilde * XTilde sub-GEMMs.
179  const index_t Y = wei_g_k_c_xs_lengths[3];
180  const index_t X = wei_g_k_c_xs_lengths[4];
181  const index_t ConvStrideH = conv_filter_strides[0];
182  const index_t ConvStrideW = conv_filter_strides[1];
183  const index_t ConvDilationH = conv_filter_dilations[0];
184  const index_t ConvDilationW = conv_filter_dilations[1];
185  const auto GcdStrideDilationH = gcd(ConvStrideH, ConvDilationH);
186  const auto GcdStrideDilationW = gcd(ConvStrideW, ConvDilationW);
187  const auto YTilde = ConvStrideH / GcdStrideDilationH;
188  const auto XTilde = ConvStrideW / GcdStrideDilationW;
189 
190  for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
191  {
192  for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
193  {
194  const auto YDotSlice = integer_divide_ceil(Y - i_ytilde, YTilde);
195  const auto XDotSlice = integer_divide_ceil(X - i_xtilde, XTilde);
196 
// Skip empty slices without consuming a GEMM slot.
197  if(XDotSlice * YDotSlice <= 0)
198  {
199  continue;
200  }
201 
// NOTE(review): guard condition (Doxygen line 202) dropped by the extraction;
// presumably bounds gemm_count against MaxGroupedGemmGroupsNum — confirm in repo.
203  {
204  gemm_count++;
205  // Avoid array segfault
206  continue;
207  }
208 
209  tildes = {i_ytilde, i_xtilde};
210 
// NOTE(review): ctor argument lines 212-217 dropped by the extraction.
211  ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
218  tildes};
219 
220  auto grid_descs = conv_to_gemm_transformer
221  .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
222  GroupedConvTraitsType_::NDimSpatial>(1);
223 
224  a_grid_descs_m_k[gemm_count] = grid_descs.at(number<0>{});
225  b_grid_descs_n_k[gemm_count] = grid_descs.at(number<1>{});
226  c_grid_descs_m_n[gemm_count] = grid_descs.at(number<2>{});
227 
// NOTE(review): the block_starts assignment (Doxygen line 232) was dropped by the extraction.
228  const index_t grid_size_grp =
229  TilePartitioner::GridSize(c_grid_descs_m_n[gemm_count].get_length(I0),
230  c_grid_descs_m_n[gemm_count].get_length(I1));
231 
233  block_ends[gemm_count] = grid_size_ + grid_size_grp;
234 
235  grid_size_ += grid_size_grp;
236 
237  ++gemm_count;
238  }
239  }
// Per-group (blockIdx.y) pointer strides for the 2D layouts.
240  group_stride_a = args.K_; // A: Out NHWGK
241  group_stride_b = args.K_ * args.C_ *
242  std::accumulate(args.filter_spatial_lengths_.begin(),
243  args.filter_spatial_lengths_.end(),
244  1,
245  std::multiplies<index_t>()); // B: Wei GKYXC
246  group_stride_c = args.C_; // C: In NHWGC
247 
248  GemmBatch = args.G_;
249  }
250 
251  template <
252  typename InLay = typename GroupedConvTraitsType_::InLayout,
253  typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
254  typename OutLay = typename GroupedConvTraitsType_::OutLayout,
255  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NDHWGC> &&
256  std::is_same_v<WeiLay, tensor_layout::convolution::GKZYXC> &&
257  std::is_same_v<OutLay, tensor_layout::convolution::NDHWGK>,
258  bool>::type = false>
// Constructor overload SFINAE-selected for 3D spatial layouts (NDHWGC input,
// GKZYXC weight, NDHWGK output). Mirrors the 1D/2D overloads with an extra D/Z dim.
// NOTE(review): the signature line (Doxygen line 259) was dropped by the doc extraction.
260  {
// Pack host-side dims: [G, N, C, D, H, W], [G, K, C, Z, Y, X], [G, N, K, Do, Ho, Wo].
261  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
262  static_cast<index_t>(args.N_),
263  static_cast<index_t>(args.C_),
264  static_cast<index_t>(args.input_spatial_lengths_[0]),
265  static_cast<index_t>(args.input_spatial_lengths_[1]),
266  static_cast<index_t>(args.input_spatial_lengths_[2])};
267  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
268  static_cast<index_t>(args.K_),
269  static_cast<index_t>(args.C_),
270  static_cast<index_t>(args.filter_spatial_lengths_[0]),
271  static_cast<index_t>(args.filter_spatial_lengths_[1]),
272  static_cast<index_t>(args.filter_spatial_lengths_[2])};
273  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
274  static_cast<index_t>(args.N_),
275  static_cast<index_t>(args.K_),
276  static_cast<index_t>(args.output_spatial_lengths_[0]),
277  static_cast<index_t>(args.output_spatial_lengths_[1]),
278  static_cast<index_t>(args.output_spatial_lengths_[2])};
279 
280  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
281  static_cast<index_t>(args.conv_filter_strides_[1]),
282  static_cast<index_t>(args.conv_filter_strides_[2])};
283  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
284  static_cast<index_t>(args.conv_filter_dilations_[1]),
285  static_cast<index_t>(args.conv_filter_dilations_[2])};
286  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
287  static_cast<index_t>(args.input_left_pads_[1]),
288  static_cast<index_t>(args.input_left_pads_[2])};
289  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
290  static_cast<index_t>(args.input_right_pads_[1]),
291  static_cast<index_t>(args.input_right_pads_[2])};
292 
293  k_batch = args.k_batch;
294 
295  in_ptr = args.in_ptr;
296  wei_ptr = args.wei_ptr;
297  for(index_t d = 0; d < NumDTensor; d++)
298  {
299  ds_ptr[d] = args.ds_ptr[d];
300  }
301  out_ptr = args.out_ptr;
302 
// Backward-data decomposition over D/H/W: ZTilde * YTilde * XTilde sub-GEMMs.
303  const index_t Z = wei_g_k_c_xs_lengths[3];
304  const index_t Y = wei_g_k_c_xs_lengths[4];
305  const index_t X = wei_g_k_c_xs_lengths[5];
306  const index_t ConvStrideD = conv_filter_strides[0];
307  const index_t ConvStrideH = conv_filter_strides[1];
308  const index_t ConvStrideW = conv_filter_strides[2];
309  const index_t ConvDilationD = conv_filter_dilations[0];
310  const index_t ConvDilationH = conv_filter_dilations[1];
311  const index_t ConvDilationW = conv_filter_dilations[2];
312  const auto GcdStrideDilationD = gcd(ConvStrideD, ConvDilationD);
313  const auto GcdStrideDilationH = gcd(ConvStrideH, ConvDilationH);
314  const auto GcdStrideDilationW = gcd(ConvStrideW, ConvDilationW);
315  const auto ZTilde = ConvStrideD / GcdStrideDilationD;
316  const auto YTilde = ConvStrideH / GcdStrideDilationH;
317  const auto XTilde = ConvStrideW / GcdStrideDilationW;
318 
319  for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde)
320  {
321  for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
322  {
323  for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
324  {
325  const auto ZDotSlice = integer_divide_ceil(Z - i_ztilde, ZTilde);
326  const auto YDotSlice = integer_divide_ceil(Y - i_ytilde, YTilde);
327  const auto XDotSlice = integer_divide_ceil(X - i_xtilde, XTilde);
328 
// Skip empty slices without consuming a GEMM slot.
329  if(ZDotSlice * XDotSlice * YDotSlice <= 0)
330  {
331  continue;
332  }
333 
// NOTE(review): guard condition (Doxygen line 334) dropped by the extraction;
// presumably bounds gemm_count against MaxGroupedGemmGroupsNum — confirm in repo.
335  {
336  gemm_count++;
337  // Avoid array segfault
338  continue;
339  }
340 
341  tildes = {i_ztilde, i_ytilde, i_xtilde};
342 
// NOTE(review): ctor argument lines 344-349 dropped by the extraction.
343  ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
350  tildes};
351 
352  auto grid_descs = conv_to_gemm_transformer
353  .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
354  GroupedConvTraitsType_::NDimSpatial>(1);
355 
356  a_grid_descs_m_k[gemm_count] = grid_descs.at(number<0>{});
357  b_grid_descs_n_k[gemm_count] = grid_descs.at(number<1>{});
358  c_grid_descs_m_n[gemm_count] = grid_descs.at(number<2>{});
359 
// NOTE(review): the block_starts assignment (Doxygen line 364) was dropped by the extraction.
360  const index_t grid_size_grp =
361  TilePartitioner::GridSize(c_grid_descs_m_n[gemm_count].get_length(I0),
362  c_grid_descs_m_n[gemm_count].get_length(I1));
363 
365  block_ends[gemm_count] = grid_size_ + grid_size_grp;
366 
367  grid_size_ += grid_size_grp;
368 
369  ++gemm_count;
370  }
371  }
372  }
373 
// Per-group (blockIdx.y) pointer strides for the 3D layouts.
374  group_stride_a = args.K_; // A: Out NDHWGK
375  group_stride_b = args.K_ * args.C_ *
376  std::accumulate(args.filter_spatial_lengths_.begin(),
377  args.filter_spatial_lengths_.end(),
378  1,
379  std::multiplies<index_t>()); // B: Wei GKZYXC
380  group_stride_c = args.C_; // C: In NDHWGC
381 
382  GemmBatch = args.G_; // number of conv groups = GEMM batch count (gridDim.y)
383  }
384 
385  static constexpr index_t MaxGroupedGemmGroupsNum = 128;
386 
389 
393 
394  static constexpr index_t NonSpatialDims = 3;
398 
404 
409 
410  const void* out_ptr;
411  void* in_ptr;
412  std::array<const void*, NumDTensor> ds_ptr;
413  const void* wei_ptr;
414 
418 
421 
425 };
426 
465 template <typename GroupedConvTraitsType_,
466  typename TilePartitioner_,
467  typename GemmPipeline_,
468  typename EpiloguePipeline_>
470 {
471  static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial_;
473  GroupedConvTraitsType_::ConvSpecialization;
480 
485 
487  static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
488 
489  static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
490 
494 
496 
499  static constexpr index_t MaxGroupedGemmGroupsNum =
501 
502  // TODO: Enable this
503  static constexpr bool IsSplitKSupported = false;
504 
505  static constexpr auto I0 = number<0>();
506  static constexpr auto I1 = number<1>();
507  static constexpr auto I2 = number<2>();
508  static constexpr auto I3 = number<3>();
509 
510  static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
511  "Not supported!");
512  static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>,
513  "Not supported A GEMM layout!");
514  static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::ColumnMajor>,
515  "Not supported B GEMM layout!");
516  static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
517  "Not supported C GEMM layout!");
518 
// Human-readable kernel identifier: op name + precision string + pipeline name,
// joined with '_' (used for logging/benchmark labeling).
519  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
520  {
521  // clang-format off
522  return concat('_', "grouped_convolution_backward_data", gemm_prec_str<InDataType, WeiDataType>, GemmPipeline::GetName());
523  // clang-format on
524  }
525 
// Launch grid: x = sum of per-sub-GEMM grids, y = conv groups (GemmBatch), z = k_batch.
// NOTE(review): the GridSize signature line (Doxygen line 526) was dropped by the
// doc extraction; it takes the kernel args struct (kargs) — confirm in repo.
527  {
528  // enable batched grouped gemm
529  return dim3(kargs.grid_size_, kargs.GemmBatch, kargs.k_batch);
530  }
531 
// Workgroup size is fixed by the GEMM pipeline's BlockSize (kBlockSize threads).
532  CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
533 
536  {
538  }
539 
// LDS requirement: GEMM main loop and epilogue reuse the same buffer, so take the max.
// NOTE(review): the GetSmemSize signature line (Doxygen line 540) was dropped by the extraction.
541  {
542  return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
543  }
544 
// Host-side argument validation: rejects unsupported k_batch, convolution
// specializations, and vector-access sizes before launch. Returns true iff the
// kernel can run the given problem. NOTE(review): several condition lines were
// dropped by the doc extraction (noted inline) — consult the repo for the exact guards.
545  CK_TILE_HOST static bool
// NOTE(review): signature tail (Doxygen line 546, the kargs parameter) dropped by the extraction.
547  {
// NOTE(review): condition tail (lines 549-550) dropped; by the pattern used in
// operator() this also checks the atomic_add memory operation / data-type case.
548  if constexpr((EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
551  {
// SplitK (k_batch > 1) is rejected for the odd-vector-size epilogue case.
552  if(kargs.k_batch != 1)
553  {
554  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
555  {
556  CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
557  }
558  return false;
559  }
560  }
561 
// NOTE(review): condition (line 562) dropped by the extraction — likely a gemm_count /
// split-k support guard; confirm in repo.
563  {
564  return false;
565  }
566 
567  const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
568  const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];
569 
570  // check ConvSpecialization
// NOTE(review): the ConvSpecialization comparison (line 571) was dropped; from the
// body this branch validates the Filter1x1Stride1Pad0 specialization.
572  {
573  // check if it's 1x1, stride=1 conv
574  for(index_t i = 0; i < NDimSpatial; ++i)
575  {
576  const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
577  const index_t ConvStride = kargs.conv_filter_strides[i];
578  const index_t LeftPad = kargs.input_left_pads[i];
579  const index_t RightPad = kargs.input_right_pads[i];
580 
581  if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
582  {
583  return false;
584  }
585  }
586  }
// NOTE(review): the else-if condition (line 587) was dropped; from the body this
// branch validates the Filter1x1Pad0 specialization.
588  {
589  // check if it's 1x1 conv
590  for(index_t i = 0; i < NDimSpatial; ++i)
591  {
592  const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
593  const index_t LeftPad = kargs.input_left_pads[i];
594  const index_t RightPad = kargs.input_right_pads[i];
595 
596  if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0))
597  {
598  return false;
599  }
600  }
601  }
// NOTE(review): the else-if condition (line 602) was dropped; the body requires
// C == 1 and every filter dim == 3 — presumably the Filter3x3 specialization.
603  {
604  if(ConvC != 1)
605  {
606  return false;
607  }
608  for(index_t i = 0; i < NDimSpatial; ++i)
609  {
610  const index_t filter_spatial_dim = kargs.wei_g_k_c_xs_lengths[i + I3];
611 
612  if(filter_spatial_dim != I3)
613  {
614  return false;
615  }
616  }
617  }
618 
619  namespace ctc = tensor_layout::convolution;
620 
// Vector-access checks: C must divide the B-side load vector for the input image...
621  if constexpr(std::is_same_v<InLayout, ctc::NWGC> || std::is_same_v<InLayout, ctc::NHWGC> ||
622  std::is_same_v<InLayout, ctc::NDHWGC>)
623  {
624  // Check access per C
625  if(ConvC % GemmPipeline::GetVectorSizeB() != 0)
626  {
627  CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!");
628  return false;
629  }
630  }
631  else
632  {
633  CK_TILE_ERROR("Not supported input layout!");
634  return false;
635  }
636 
637  // check vector access of B
638  // FIXME: layout
// ...C must divide the epilogue store vector for the weight...
639  if constexpr(std::is_same_v<WeiLayout, ctc::GKXC> ||
640  std::is_same_v<WeiLayout, ctc::GKYXC> ||
641  std::is_same_v<WeiLayout, ctc::GKZYXC>)
642  {
643  if(ConvC % EpiloguePipeline::GetVectorSizeC() != 0)
644  {
645  CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!");
646  return false;
647  }
648  }
649  else
650  {
651  CK_TILE_ERROR("Not supported weight layout!");
652  return false;
653  }
654 
655  // check vector access of E
// ...and K must divide the A-side load vector for the output image.
656  if constexpr(std::is_same_v<OutLayout, ctc::NWGK> ||
657  std::is_same_v<OutLayout, ctc::NHWGK> ||
658  std::is_same_v<OutLayout, ctc::NDHWGK>)
659  {
660  if(ConvK % GemmPipeline::GetVectorSizeA() != 0)
661  {
662  CK_TILE_ERROR("Conv K is not a multiple of vector store size for output image!");
663  return false;
664  }
665  }
666  else
667  {
668  CK_TILE_ERROR("Not supported output layout!");
669  return false;
670  }
671 
672  return true;
673  }
674 
// Builds global-memory tensor views for one sub-GEMM: A = conv output (M x K),
// B = weight (N x K), Ds = elementwise-add operands, C = conv input gradient (M x N),
// each using the grid descriptor precomputed for group_id.
675  template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
676  CK_TILE_DEVICE static auto
// NOTE(review): signature lines 677 (function name + a_ptr) and 681 (kargs param)
// were dropped by the doc extraction — confirm exact parameter list in repo.
678  const InDataType* b_ptr,
679  const std::array<const void*, NumDTensor>& ds_ptr,
680  WeiDataType* c_ptr,
682  const index_t group_id)
683  {
684  static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
685  static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!");
686  const auto& a_tensor_view = [&]() {
687  return make_tensor_view<address_space_enum::global>(
688  a_ptr,
689  kargs.a_grid_descs_m_k[group_id]); // A: out
690  }();
691 
692  const auto& b_tensor_view = [&]() {
693  return make_tensor_view<address_space_enum::global>(
694  b_ptr,
695  kargs.b_grid_descs_n_k[group_id]); // B: weight
696  }();
697 
698  const auto& c_tensor_view = [&]() {
699  return make_tensor_view<address_space_enum::global>(c_ptr,
700  kargs.c_grid_descs_m_n[group_id]);
701  }();
702 
// D tensors are constrained at compile time to share C's layout/descriptor,
// so they reuse c_grid_descs_m_n.
703  const auto& ds_tensor_view = generate_tuple(
704  [&](auto i) {
705  static_assert(std::is_same_v<std::tuple_element_t<i, DsLayout>, OutLayout>,
706  "Not supported!");
707  static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
708  "Not supported!");
709  static_assert(std::is_same_v<std::tuple_element_t<i, DsDataType>, OutDataType>,
710  "Not supported!");
711 
712  return make_tensor_view<address_space_enum::global>(
713  static_cast<OutDataType*>(ds_ptr[i]), kargs.c_grid_descs_m_n[group_id]);
714  },
// NOTE(review): the generate_tuple count argument (line 715, presumably number<NumDTensor>{})
// was dropped by the doc extraction.
716 
717  return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view);
718  }
719 
// Wraps the raw tensor views in padded views so partial tiles at the M/N/K edges
// are valid to access (kPadM/kPadN/kPadK are statically asserted above).
// NOTE(review): the pad_tensor_view tile-length / do-pad argument lines
// (Doxygen 726-728, 734-736, 743-745, 752-754) were dropped by the doc extraction.
720  template <typename TensorView>
721  CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
722  {
723  const auto& a_pad_view = [&]() {
724  const auto& a_tensor_view = views.at(I0);
725  return pad_tensor_view(a_tensor_view,
729  }();
730 
731  const auto& b_pad_view = [&]() {
732  const auto& b_tensor_view = views.at(I1);
733  return pad_tensor_view(b_tensor_view,
737  }();
738 
739  const auto& ds_tensor_view = views.at(I2);
740  const auto& ds_pad_view = generate_tuple(
741  [&](auto i) {
742  return pad_tensor_view(ds_tensor_view[i],
746  },
748 
749  const auto& c_pad_view = [&]() {
750  const auto& c_tensor_view = views.at(I3);
751  return pad_tensor_view(c_tensor_view,
755  }();
756 
757  return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view);
758  }
759 
// Creates per-block tile windows over the padded views, positioned at this
// block's output tile: A at (i_m, i_k), B at (i_n, i_k), Ds and C at (i_m, i_n).
// NOTE(review): the make_tile_window window-length argument lines
// (Doxygen 773-774, 780-781, 788-789, 796) were dropped by the doc extraction.
760  template <typename PadView>
761  CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
762  const index_t i_m,
763  const index_t i_n,
764  const index_t i_k = 0)
765  {
766  const auto& a_pad_view = views.at(I0);
767  const auto& b_pad_view = views.at(I1);
768  const auto& ds_pad_view = views.at(I2);
769  const auto& c_pad_view = views.at(I3);
770 
771  const auto& a_block_window = [&]() {
772  return make_tile_window(a_pad_view,
775  {i_m, i_k});
776  }();
777 
778  const auto& b_block_window = [&]() {
779  return make_tile_window(b_pad_view,
782  {i_n, i_k});
783  }();
784 
785  const auto ds_block_window = generate_tuple(
786  [&](auto i) {
787  return make_tile_window(ds_pad_view[i],
790  {i_m, i_n});
791  },
793 
794  auto c_block_window = make_tile_window(
795  c_pad_view,
797  {i_m, i_n});
798 
799  return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window);
800  }
801 
// Runs one block's GEMM for a single sub-problem (single-LDS-buffer variant):
// builds views/windows for group_id, executes the GEMM pipeline over num_loop
// K-iterations, then the epilogue writes C (with optional D elementwise inputs).
// NOTE(review): the kargs parameter line (Doxygen line 819) was dropped by the doc extraction.
814  CK_TILE_DEVICE static void RunGemm(const OutDataType* a_ptr,
815  const InDataType* b_ptr,
816  const std::array<const void*, NumDTensor>& ds_ptr,
817  WeiDataType* c_ptr,
818  void* smem_ptr_0,
820  const index_t block_idx_m,
821  const index_t block_idx_n,
822  const index_t group_id)
823  {
824  // Create Gemm tensor views, pad views and tile windows
825  const auto& gemm_tensor_views_tuple =
826  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
827  a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id);
828 
829  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
830  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
831 
// Loop count derives from A's K extent; readfirstlane keeps it uniform (scalar) per wave.
832  const index_t num_loop = __builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(
833  gemm_pad_views.at(I0).get_tensor_descriptor().get_length(I1)));
834 
835  // Run GEMM cooperatively by whole workgroup.
836  const auto& a_block_window = gemm_tile_windows.at(I0);
837  const auto& b_block_window = gemm_tile_windows.at(I1);
838  const auto& d_block_window = gemm_tile_windows.at(I2);
839 
840  const auto& c_block_tile = GemmPipeline{}.template operator()(
841  a_block_window, b_block_window, num_loop, smem_ptr_0);
842 
843  // Run Epilogue Pipeline
844  auto& c_block_window = gemm_tile_windows.at(I3);
845 
846  EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
847  c_block_window, c_block_tile, d_block_window, smem_ptr_0);
848  }
849 
// Double-LDS-buffer variant of RunGemm: the GEMM pipeline ping-pongs between
// smem_ptr_0 and smem_ptr_1 to overlap global loads with MFMA work; the epilogue
// reuses smem_ptr_0. NOTE(review): the kargs parameter line (Doxygen line 871)
// was dropped by the doc extraction.
865  CK_TILE_DEVICE static void RunGemm2LDS(const OutDataType* a_ptr,
866  const InDataType* b_ptr,
867  const std::array<const void*, NumDTensor>& ds_ptr,
868  WeiDataType* c_ptr,
869  void* __restrict__ smem_ptr_0,
870  void* __restrict__ smem_ptr_1,
872  const index_t block_idx_m,
873  const index_t block_idx_n,
874  const index_t group_id)
875  {
876  // Create Gemm tensor views, pad views and tile windows
877  const auto& gemm_tensor_views_tuple =
878  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
879  a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id);
880  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
881  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
882 
// Loop count from the A window's K extent, uniform per wave via readfirstlane.
883  const index_t num_loop = __builtin_amdgcn_readfirstlane(
884  TilePartitioner::GetLoopNum(gemm_tile_windows.at(I0).get_length(I1)));
885 
886  // Run GEMM cooperatively by whole workgroup.
887  const auto& a_block_window = gemm_tile_windows.at(I0);
888  const auto& b_block_window = gemm_tile_windows.at(I1);
889  const auto& d_block_window = gemm_tile_windows.at(I2);
890 
891  const auto& c_block_tile = GemmPipeline{}.template operator()(
892  a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
893 
894  // Run Epilogue Pipeline
895  auto& c_block_window = gemm_tile_windows.at(I3);
896 
897  EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
898  c_block_window, c_block_tile, d_block_window, smem_ptr_0);
899  }
900 
// Binary search mapping a flat block index to the sub-GEMM whose
// [block_starts, block_ends) range contains it. Assumes the ranges are
// contiguous and cover every launched block_id; an out-of-range block_id
// would not satisfy the loop exit condition via the range test.
// NOTE(review): the leading signature line (Doxygen line 901, with the kargs
// parameter) was dropped by the doc extraction.
902  index_t block_id) const
903  {
904  index_t left = 0;
905  index_t right = kargs.gemm_count;
906  index_t group_id = index_t((left + right) >> 1);
907 
// Loop until group_id's range contains block_id (or the bounds cross).
908  while((!(block_id >= kargs.block_starts[group_id] &&
909  block_id < kargs.block_ends[group_id])) &&
910  left <= right)
911  {
912  if(block_id < kargs.block_starts[group_id])
913  {
914  right = group_id;
915  }
916  else
917  {
// Note: left is set to group_id (not group_id + 1) — termination relies on the
// ranges being contiguous so a valid block_id is always found before left == right.
918  left = group_id;
919  }
920  group_id = index_t((left + right) >> 1);
921  }
922 
923  return group_id;
924  }
925 
// Kernel entry point. blockIdx.x selects a tile within one sub-GEMM (located via
// FindGroupId), blockIdx.y selects the convolution group (pointer offset by the
// group strides), blockIdx.z is the split-K index. NOTE(review): the operator()
// signature line (Doxygen line 926, taking kargs) was dropped by the doc extraction.
927  {
928  const auto blockIdX = __builtin_amdgcn_readfirstlane(blockIdx.x);
929  const index_t group_id = FindGroupId(kargs, blockIdX);
930 
// NOTE(review): line 931 was dropped by the extraction — from the arguments it is
// the tile-partitioner call producing (iM, iN) from the block index relative to
// block_starts[group_id]; confirm in repo.
932  kargs.block_starts[group_id],
933  kargs.c_grid_descs_m_n[group_id].get_length(I0),
934  kargs.c_grid_descs_m_n[group_id].get_length(I1));
935 
// Tile coordinates in elements; readfirstlane keeps them scalar per wave.
936  const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
937  const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
938 
939  const auto blockIdY = __builtin_amdgcn_readfirstlane(blockIdx.y);
940  const auto group_offset_a = __builtin_amdgcn_readfirstlane(kargs.group_stride_a * blockIdY);
941  const auto group_offset_b = __builtin_amdgcn_readfirstlane(kargs.group_stride_b * blockIdY);
942  const auto group_offset_c = __builtin_amdgcn_readfirstlane(kargs.group_stride_c * blockIdY);
943 
944  // options
945  // conv_bwd_data = Out * Weight = In
946  const OutDataType* a_ptr = static_cast<const OutDataType*>(kargs.out_ptr) + group_offset_a;
947  const WeiDataType* b_ptr = static_cast<const WeiDataType*>(kargs.wei_ptr) + group_offset_b;
948  InDataType* c_ptr = static_cast<InDataType*>(kargs.in_ptr) + group_offset_c;
949 
950  // allocate LDS
951  __shared__ char smem_ptr_0[GetSmemSize()];
952 
// Double-buffered pipelines get a second LDS allocation and the 2-LDS run path.
953  if constexpr(GemmPipeline::DoubleSmemBuffer == true)
954  {
955  __shared__ char smem_ptr_1[GetSmemSize()];
// NOTE(review): the condition tail (line 958) was dropped by the extraction —
// it completes the atomic_add / odd-vector-size exclusion; confirm in repo.
956  if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
957  EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
959  {
960  RunGemm2LDS(a_ptr,
961  b_ptr,
962  kargs.ds_ptr,
963  c_ptr,
964  smem_ptr_0,
965  smem_ptr_1,
966  kargs,
967  i_m,
968  i_n,
969  group_id);
970  }
971  }
972  else
973  {
// NOTE(review): condition tail (line 976) dropped by the extraction, mirroring the branch above.
974  if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
975  EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
977  {
978  RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, i_m, i_n, group_id);
979  }
980  }
981  }
982 };
983 
984 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:156
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
void CK_TILE_ERROR(Args &&... args) noexcept
Definition: env.hpp:12
ConvolutionSpecialization
Definition: convolution_specialization.hpp:11
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE index_t gcd(index_t x, index_t y)
Definition: math.hpp:268
int64_t long_index_t
Definition: integer.hpp:11
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
constexpr bool is_same_v
Definition: type.hpp:283
__device__ X atomic_add(X *p_dst, const X &x)
The Grouped Convolution kernel device arguments.
Definition: grouped_convolution_backward_data_kernel.hpp:22
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > wei_g_k_c_xs_lengths
Definition: grouped_convolution_backward_data_kernel.hpp:396
static constexpr auto I1
Definition: grouped_convolution_backward_data_kernel.hpp:31
CK_TILE_HOST GroupedConvBwdDataKernelArgs(const GroupedConvBwdDataHostArgs &args)
Definition: grouped_convolution_backward_data_kernel.hpp:41
array< index_t, GroupedConvTraitsType_::NDimSpatial > conv_filter_dilations
Definition: grouped_convolution_backward_data_kernel.hpp:400
std::array< const void *, NumDTensor > ds_ptr
Definition: grouped_convolution_backward_data_kernel.hpp:412
array< index_t, GroupedConvTraitsType_::NDimSpatial > conv_filter_strides
Definition: grouped_convolution_backward_data_kernel.hpp:399
array< index_t, MaxGroupedGemmGroupsNum > block_starts
Definition: grouped_convolution_backward_data_kernel.hpp:419
array< index_t, GroupedConvTraitsType_::NDimSpatial > input_left_pads
Definition: grouped_convolution_backward_data_kernel.hpp:401
long_index_t group_stride_b
Definition: grouped_convolution_backward_data_kernel.hpp:423
long_index_t group_stride_c
Definition: grouped_convolution_backward_data_kernel.hpp:424
array< index_t, MaxGroupedGemmGroupsNum > block_ends
Definition: grouped_convolution_backward_data_kernel.hpp:420
const void * out_ptr
Definition: grouped_convolution_backward_data_kernel.hpp:410
remove_cvref_t< decltype(ConvToGemmTransformer{}.MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(1))> ABCGridDescs
Definition: grouped_convolution_backward_data_kernel.hpp:388
remove_cvref_t< decltype(ABCGridDescs{}[number< 1 >{}])> BGridDescNK
Definition: grouped_convolution_backward_data_kernel.hpp:391
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: grouped_convolution_backward_data_kernel.hpp:23
array< index_t, GroupedConvTraitsType_::NDimSpatial > tildes
Definition: grouped_convolution_backward_data_kernel.hpp:403
remove_cvref_t< decltype(ABCGridDescs{}[number< 0 >{}])> AGridDescMK
Definition: grouped_convolution_backward_data_kernel.hpp:390
const void * wei_ptr
Definition: grouped_convolution_backward_data_kernel.hpp:413
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > out_g_n_k_wos_lengths
Definition: grouped_convolution_backward_data_kernel.hpp:397
long_index_t group_stride_a
Definition: grouped_convolution_backward_data_kernel.hpp:422
index_t GemmBatch
Definition: grouped_convolution_backward_data_kernel.hpp:406
void * in_ptr
Definition: grouped_convolution_backward_data_kernel.hpp:411
index_t gemm_count
Definition: grouped_convolution_backward_data_kernel.hpp:408
array< CGridDescMN, MaxGroupedGemmGroupsNum > c_grid_descs_m_n
Definition: grouped_convolution_backward_data_kernel.hpp:417
index_t grid_size_
Definition: grouped_convolution_backward_data_kernel.hpp:407
array< index_t, GroupedConvTraitsType_::NDimSpatial > input_right_pads
Definition: grouped_convolution_backward_data_kernel.hpp:402
array< BGridDescNK, MaxGroupedGemmGroupsNum > b_grid_descs_n_k
Definition: grouped_convolution_backward_data_kernel.hpp:416
index_t k_batch
Definition: grouped_convolution_backward_data_kernel.hpp:405
static constexpr auto I0
Definition: grouped_convolution_backward_data_kernel.hpp:30
static constexpr index_t MaxGroupedGemmGroupsNum
Definition: grouped_convolution_backward_data_kernel.hpp:385
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > in_g_n_c_wis_lengths
Definition: grouped_convolution_backward_data_kernel.hpp:395
static constexpr index_t NumDTensor
Definition: grouped_convolution_backward_data_kernel.hpp:28
array< AGridDescMK, MaxGroupedGemmGroupsNum > a_grid_descs_m_k
Definition: grouped_convolution_backward_data_kernel.hpp:415
remove_cvref_t< decltype(ABCGridDescs{}[number< 2 >{}])> CGridDescMN
Definition: grouped_convolution_backward_data_kernel.hpp:392
static constexpr index_t NonSpatialDims
Definition: grouped_convolution_backward_data_kernel.hpp:394
The Grouped Conv kernel host arguments.
Definition: grouped_convolution_utils.hpp:19
index_t k_batch
Definition: grouped_convolution_utils.hpp:40
InPtr in_ptr
Definition: grouped_convolution_utils.hpp:36
WeiPtr wei_ptr
Definition: grouped_convolution_utils.hpp:37
OutPtr out_ptr
Definition: grouped_convolution_utils.hpp:39
const std::vector< const void * > ds_ptr
Definition: grouped_convolution_utils.hpp:38
The Grouped Convolution Backward Data kernel template.
Definition: grouped_convolution_backward_data_kernel.hpp:470
static constexpr index_t NDimSpatial
Definition: grouped_convolution_backward_data_kernel.hpp:471
static constexpr CK_TILE_HOST auto BlockSize()
Definition: grouped_convolution_backward_data_kernel.hpp:532
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition: grouped_convolution_backward_data_kernel.hpp:475
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n, const index_t i_k=0)
Definition: grouped_convolution_backward_data_kernel.hpp:761
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views)
Definition: grouped_convolution_backward_data_kernel.hpp:721
GroupedConvBwdDataKernelArgs< GroupedConvTraitsType_, TilePartitioner > GroupedConvBwdDataKernelArgsSpecialized
Definition: grouped_convolution_backward_data_kernel.hpp:498
remove_cvref_t< typename GemmPipeline::ADataType > InDataType
Definition: grouped_convolution_backward_data_kernel.hpp:491
static constexpr index_t MaxGroupedGemmGroupsNum
Definition: grouped_convolution_backward_data_kernel.hpp:499
static constexpr auto I1
Definition: grouped_convolution_backward_data_kernel.hpp:506
static constexpr auto I3
Definition: grouped_convolution_backward_data_kernel.hpp:508
remove_cvref_t< typename GroupedConvTraitsType_::OutLayout > OutLayout
Definition: grouped_convolution_backward_data_kernel.hpp:483
static constexpr ConvolutionSpecialization ConvSpecialization
Definition: grouped_convolution_backward_data_kernel.hpp:472
static constexpr CK_TILE_HOST GroupedConvBwdDataKernelArgsSpecialized MakeKernelArgs(const GroupedConvBwdDataHostArgs &hostArgs)
Definition: grouped_convolution_backward_data_kernel.hpp:535
static constexpr index_t NumDTensor
Definition: grouped_convolution_backward_data_kernel.hpp:487
remove_cvref_t< typename GemmPipeline::BDataType > WeiDataType
Definition: grouped_convolution_backward_data_kernel.hpp:492
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: grouped_convolution_backward_data_kernel.hpp:476
remove_cvref_t< typename EpiloguePipeline::ODataType > OutDataType
Definition: grouped_convolution_backward_data_kernel.hpp:495
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: grouped_convolution_backward_data_kernel.hpp:474
remove_cvref_t< typename GroupedConvTraitsType_::WeiLayout > WeiLayout
Definition: grouped_convolution_backward_data_kernel.hpp:482
static constexpr index_t kBlockSize
Definition: grouped_convolution_backward_data_kernel.hpp:489
static CK_TILE_HOST bool IsSupportedArgument(const GroupedConvBwdDataKernelArgsSpecialized &kargs)
Definition: grouped_convolution_backward_data_kernel.hpp:546
remove_cvref_t< typename GemmPipeline::BLayout > GemmBLayout
Definition: grouped_convolution_backward_data_kernel.hpp:478
remove_cvref_t< typename GroupedConvTraitsType_::DsLayout > DsLayout
Definition: grouped_convolution_backward_data_kernel.hpp:484
static constexpr auto I2
Definition: grouped_convolution_backward_data_kernel.hpp:507
static CK_TILE_DEVICE auto MakeGemmTensorViews(const OutDataType *a_ptr, const InDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, WeiDataType *c_ptr, const GroupedConvBwdDataKernelArgsSpecialized &kargs, const index_t group_id)
Definition: grouped_convolution_backward_data_kernel.hpp:677
static CK_TILE_HOST auto GridSize(const GroupedConvBwdDataKernelArgsSpecialized &kargs)
Definition: grouped_convolution_backward_data_kernel.hpp:526
remove_cvref_t< typename GemmPipeline::ALayout > GemmALayout
Definition: grouped_convolution_backward_data_kernel.hpp:477
remove_cvref_t< typename EpiloguePipeline::DsLayout > GemmDsLayout
Definition: grouped_convolution_backward_data_kernel.hpp:486
CK_TILE_DEVICE index_t FindGroupId(const GroupedConvBwdDataKernelArgsSpecialized &kargs, index_t block_id) const
Definition: grouped_convolution_backward_data_kernel.hpp:901
static CK_TILE_DEVICE void RunGemm(const OutDataType *a_ptr, const InDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, WeiDataType *c_ptr, void *smem_ptr_0, const GroupedConvBwdDataKernelArgsSpecialized &kargs, const index_t block_idx_m, const index_t block_idx_n, const index_t group_id)
Runs single GEMM problem cooperatively by whole workgroup.
Definition: grouped_convolution_backward_data_kernel.hpp:814
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: grouped_convolution_backward_data_kernel.hpp:540
static CK_TILE_DEVICE void RunGemm2LDS(const OutDataType *a_ptr, const InDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, WeiDataType *c_ptr, void *__restrict__ smem_ptr_0, void *__restrict__ smem_ptr_1, const GroupedConvBwdDataKernelArgsSpecialized &kargs, const index_t block_idx_m, const index_t block_idx_n, const index_t group_id)
Runs single GEMM problem cooperatively by whole workgroup.
Definition: grouped_convolution_backward_data_kernel.hpp:865
CK_TILE_DEVICE void operator()(GroupedConvBwdDataKernelArgsSpecialized kargs) const
Definition: grouped_convolution_backward_data_kernel.hpp:926
static constexpr bool IsSplitKSupported
Definition: grouped_convolution_backward_data_kernel.hpp:503
remove_cvref_t< typename GroupedConvTraitsType_::InLayout > InLayout
Definition: grouped_convolution_backward_data_kernel.hpp:481
remove_cvref_t< typename GemmPipeline::CLayout > GemmCLayout
Definition: grouped_convolution_backward_data_kernel.hpp:479
remove_cvref_t< typename EpiloguePipeline::DsDataType > DsDataType
Definition: grouped_convolution_backward_data_kernel.hpp:493
static CK_TILE_HOST const std::string GetName()
Definition: grouped_convolution_backward_data_kernel.hpp:519
static constexpr auto I0
Definition: grouped_convolution_backward_data_kernel.hpp:505
static CK_TILE_DEVICE auto GetOffsetedTileIndex(index_t block_start, index_t M, index_t N) noexcept -> const tuple< index_t, index_t >
The function subtracts the block's start (offset) from 1D raw-indexes.
Definition: gemm_tile_partitioner.hpp:192
Definition: transform_conv_bwd_data_to_gemm.hpp:19
CK_TILE_HOST auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N([[maybe_unused]] const index_t GemmKBatch) const
Definition: transform_conv_bwd_data_to_gemm.hpp:546
constexpr CK_TILE_HOST_DEVICE auto & at(index_t i)
Definition: array.hpp:110
Definition: integral_constant.hpp:13
std::vector< ck_tile::long_index_t > input_spatial_lengths_
Definition: convolution_parameter.hpp:130
ck_tile::long_index_t K_
Definition: convolution_parameter.hpp:126
std::vector< ck_tile::long_index_t > output_spatial_lengths_
Definition: convolution_parameter.hpp:131
std::vector< ck_tile::long_index_t > input_right_pads_
Definition: convolution_parameter.hpp:137
ck_tile::long_index_t G_
Definition: convolution_parameter.hpp:124
std::vector< ck_tile::long_index_t > conv_filter_strides_
Definition: convolution_parameter.hpp:133
std::vector< ck_tile::long_index_t > filter_spatial_lengths_
Definition: convolution_parameter.hpp:129
ck_tile::long_index_t C_
Definition: convolution_parameter.hpp:127
ck_tile::long_index_t N_
Definition: convolution_parameter.hpp:125
std::vector< ck_tile::long_index_t > input_left_pads_
Definition: convolution_parameter.hpp:136
std::vector< ck_tile::long_index_t > conv_filter_dilations_
Definition: convolution_parameter.hpp:134
Definition: type_traits.hpp:115
Definition: sequence.hpp:49
#define CK_TILE_ENV(name)
Definition: env.hpp:145