Composable Kernel — source listing for
include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp
(develop branch, built on Read the Docs at
/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/).

grouped_convolution_backward_weight_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <string>
8 
9 #include "ck_tile/core.hpp"
10 #include "ck_tile/ops/common.hpp"
11 #include "ck_tile/host/concat.hpp"
16 
17 namespace ck_tile {
18 
20 template <typename GroupedConvTraitsType_>
22 {
23 
25  TransformConvBwdWeightToGemm<GroupedConvTraitsType_::NDimSpatial,
26  GroupedConvTraitsType_::ConvSpecialization,
27  GroupedConvTraitsType_::VectorSizeA,
28  GroupedConvTraitsType_::VectorSizeB,
29  GroupedConvTraitsType_::VectorSizeC>;
30  static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
31 
32  template <
33  typename InLay = typename GroupedConvTraitsType_::InLayout,
34  typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
35  typename OutLay = typename GroupedConvTraitsType_::OutLayout,
36  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NWGC> &&
37  std::is_same_v<WeiLay, tensor_layout::convolution::GKXC> &&
38  std::is_same_v<OutLay, tensor_layout::convolution::NWGK>,
39  bool>::type = false>
41  {
42  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
43  static_cast<index_t>(args.N_),
44  static_cast<index_t>(args.C_),
45  static_cast<index_t>(args.input_spatial_lengths_[0])};
46  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
47  static_cast<index_t>(args.K_),
48  static_cast<index_t>(args.C_),
49  static_cast<index_t>(args.filter_spatial_lengths_[0])};
50  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
51  static_cast<index_t>(args.N_),
52  static_cast<index_t>(args.K_),
53  static_cast<index_t>(args.output_spatial_lengths_[0])};
54 
55  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0])};
56  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0])};
57  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0])};
58  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0])};
59 
60  k_batch = args.k_batch;
61 
62  in_ptr = args.in_ptr;
63  wei_ptr = args.wei_ptr;
64  for(index_t d = 0; d < NumDTensor; d++)
65  {
66  ds_ptr[d] = args.ds_ptr[d];
67  }
68  out_ptr = args.out_ptr;
69 
70  ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
77 
78  // tuple
79  auto grid_descs =
80  conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
81  GroupedConvTraitsType_::NDimSpatial>();
82 
83  a_grid_desc_m_k = grid_descs.at(number<0>{});
84  b_grid_desc_n_k = grid_descs.at(number<1>{});
85  c_grid_desc_m_n = grid_descs.at(number<2>{});
86 
87  group_stride_a = args.K_; // A: Out NWGK
88  group_stride_b = args.C_; // B: In NWGC
89  group_stride_c = args.K_ * args.C_ * // C: Wei GKXC
90  std::accumulate(args.filter_spatial_lengths_.begin(),
91  args.filter_spatial_lengths_.end(),
92  1,
93  std::multiplies<index_t>());
94 
95  GemmM = a_grid_desc_m_k.get_length(number<0>{});
96  GemmN = b_grid_desc_n_k.get_length(number<0>{});
97  GemmK = a_grid_desc_m_k.get_length(number<1>{});
98  GemmBatch = args.G_;
99  }
100 
101  template <
102  typename InLay = typename GroupedConvTraitsType_::InLayout,
103  typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
104  typename OutLay = typename GroupedConvTraitsType_::OutLayout,
105  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NHWGC> &&
106  std::is_same_v<WeiLay, tensor_layout::convolution::GKYXC> &&
107  std::is_same_v<OutLay, tensor_layout::convolution::NHWGK>,
108  bool>::type = false>
110  {
111  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
112  static_cast<index_t>(args.N_),
113  static_cast<index_t>(args.C_),
114  static_cast<index_t>(args.input_spatial_lengths_[0]),
115  static_cast<index_t>(args.input_spatial_lengths_[1])};
116  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
117  static_cast<index_t>(args.K_),
118  static_cast<index_t>(args.C_),
119  static_cast<index_t>(args.filter_spatial_lengths_[0]),
120  static_cast<index_t>(args.filter_spatial_lengths_[1])};
121  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
122  static_cast<index_t>(args.N_),
123  static_cast<index_t>(args.K_),
124  static_cast<index_t>(args.output_spatial_lengths_[0]),
125  static_cast<index_t>(args.output_spatial_lengths_[1])};
126 
127  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
128  static_cast<index_t>(args.conv_filter_strides_[1])};
129  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
130  static_cast<index_t>(args.conv_filter_dilations_[1])};
131  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
132  static_cast<index_t>(args.input_left_pads_[1])};
133  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
134  static_cast<index_t>(args.input_right_pads_[1])};
135 
136  k_batch = args.k_batch;
137 
138  in_ptr = args.in_ptr;
139  wei_ptr = args.wei_ptr;
140  for(index_t d = 0; d < NumDTensor; d++)
141  {
142  ds_ptr[d] = args.ds_ptr[d];
143  }
144  out_ptr = args.out_ptr;
145 
146  ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
153 
154  // tuple
155  auto grid_descs =
156  conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
157  GroupedConvTraitsType_::NDimSpatial>();
158 
159  a_grid_desc_m_k = grid_descs.at(number<0>{});
160  b_grid_desc_n_k = grid_descs.at(number<1>{});
161  c_grid_desc_m_n = grid_descs.at(number<2>{});
162 
163  group_stride_a = args.K_; // A: Out NHWGK
164  group_stride_b = args.C_; // B: In NHWGC
165  group_stride_c = args.K_ * args.C_ * // C: Wei GKYXC
166  std::accumulate(args.filter_spatial_lengths_.begin(),
167  args.filter_spatial_lengths_.end(),
168  1,
169  std::multiplies<index_t>());
170 
171  GemmM = a_grid_desc_m_k.get_length(number<0>{});
172  GemmN = b_grid_desc_n_k.get_length(number<0>{});
173  GemmK = a_grid_desc_m_k.get_length(number<1>{});
174  GemmBatch = args.G_;
175  }
176 
177  template <
178  typename InLay = typename GroupedConvTraitsType_::InLayout,
179  typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
180  typename OutLay = typename GroupedConvTraitsType_::OutLayout,
181  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NDHWGC> &&
182  std::is_same_v<WeiLay, tensor_layout::convolution::GKZYXC> &&
183  std::is_same_v<OutLay, tensor_layout::convolution::NDHWGK>,
184  bool>::type = false>
186  {
187  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
188  static_cast<index_t>(args.N_),
189  static_cast<index_t>(args.C_),
190  static_cast<index_t>(args.input_spatial_lengths_[0]),
191  static_cast<index_t>(args.input_spatial_lengths_[1]),
192  static_cast<index_t>(args.input_spatial_lengths_[2])};
193  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
194  static_cast<index_t>(args.K_),
195  static_cast<index_t>(args.C_),
196  static_cast<index_t>(args.filter_spatial_lengths_[0]),
197  static_cast<index_t>(args.filter_spatial_lengths_[1]),
198  static_cast<index_t>(args.filter_spatial_lengths_[2])};
199  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
200  static_cast<index_t>(args.N_),
201  static_cast<index_t>(args.K_),
202  static_cast<index_t>(args.output_spatial_lengths_[0]),
203  static_cast<index_t>(args.output_spatial_lengths_[1]),
204  static_cast<index_t>(args.output_spatial_lengths_[2])};
205 
206  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
207  static_cast<index_t>(args.conv_filter_strides_[1]),
208  static_cast<index_t>(args.conv_filter_strides_[2])};
209  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
210  static_cast<index_t>(args.conv_filter_dilations_[1]),
211  static_cast<index_t>(args.conv_filter_dilations_[2])};
212  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
213  static_cast<index_t>(args.input_left_pads_[1]),
214  static_cast<index_t>(args.input_left_pads_[2])};
215  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
216  static_cast<index_t>(args.input_right_pads_[1]),
217  static_cast<index_t>(args.input_right_pads_[2])};
218 
219  k_batch = args.k_batch;
220 
221  in_ptr = args.in_ptr;
222  wei_ptr = args.wei_ptr;
223  for(index_t d = 0; d < NumDTensor; d++)
224  {
225  ds_ptr[d] = args.ds_ptr[d];
226  }
227  out_ptr = args.out_ptr;
228 
229  ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
236 
237  // tuple
238  auto grid_descs =
239  conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
240  GroupedConvTraitsType_::NDimSpatial>();
241 
242  a_grid_desc_m_k = grid_descs.at(number<0>{});
243  b_grid_desc_n_k = grid_descs.at(number<1>{});
244  c_grid_desc_m_n = grid_descs.at(number<2>{});
245 
246  group_stride_a = args.K_; // A: Out NDHWGK
247  group_stride_b = args.C_; // B: In NDHWGC
248  group_stride_c = args.K_ * args.C_ * // C: wEI GKZYXC
249  std::accumulate(args.filter_spatial_lengths_.begin(),
250  args.filter_spatial_lengths_.end(),
251  1,
252  std::multiplies<index_t>());
253 
254  GemmM = a_grid_desc_m_k.get_length(number<0>{});
255  GemmN = b_grid_desc_n_k.get_length(number<0>{});
256  GemmK = a_grid_desc_m_k.get_length(number<1>{});
257  GemmBatch = args.G_;
258  }
259 
262 
266 
267  static constexpr index_t NonSpatialDims = 3;
271 
276 
282 
283  const void* out_ptr;
284  const void* in_ptr;
285  std::array<const void*, NumDTensor> ds_ptr;
286  void* wei_ptr;
287 
291 
295 };
296 
335 template <typename GroupedConvTraitsType_,
336  typename TilePartitioner_,
337  typename GemmPipeline_,
338  typename EpiloguePipeline_>
340 {
341  // Todo: Enable Vector Load Size > 1
342  static_assert(GroupedConvTraitsType_::VectorSizeA == 1 &&
343  GroupedConvTraitsType_::VectorSizeB == 1);
344 
345  static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial_;
347  GroupedConvTraitsType_::ConvSpecialization;
354 
359 
361  static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
362 
363  static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
364 
369 
372 
373  // TODO: Enable this
374  static constexpr bool IsSplitKSupported = true;
375 
376  static constexpr auto I0 = number<0>();
377  static constexpr auto I1 = number<1>();
378  static constexpr auto I2 = number<2>();
379  static constexpr auto I3 = number<3>();
380 
381  static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
382  "Not supported!");
383  static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>, "Not supported!");
384  static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
385  // TODO: Change to and enable vector load
386  // static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::ColumnMajor>, "Not
387  // supported!"); static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::RowMajor>, "Not
388  // supported!");
389  static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
390 
    /// @brief Builds a human-readable kernel identifier.
    ///
    /// Joins, with '_': the fixed op tag "grouped_convolution_backward_weight",
    /// the precision string derived from InDataType/WeiDataType, and the GEMM
    /// pipeline's own name. Used for logging/benchmark reporting.
    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
    {
        // clang-format off
        return concat('_', "grouped_convolution_backward_weight", gemm_prec_str<InDataType, WeiDataType>, GemmPipeline::GetName());
        // clang-format on
    }
397 
398  CK_TILE_HOST static constexpr auto
400  {
401  return dim3(
402  TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch);
403  }
404 
405  CK_TILE_HOST static constexpr auto BlockSize()
406  {
407  return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
408  }
409 
412  {
414  }
415 
417  {
418  return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
419  }
420 
422  {
424  const std::size_t k_id = blockIdx.z)
425  {
426  constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
427  const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1);
428  const index_t KRead = amd_wave_read_first_lane((kargs.GemmK + K_t - 1) / K_t * K1);
429 
432 
433  if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
434  {
436  }
437  else
438  {
439  splitted_k = amd_wave_read_first_lane(kargs.GemmK - KRead * (kargs.k_batch - 1));
440  }
441  }
442 
446  };
447 
449  const stream_config& s)
450  {
451  return [&]() {
452  if(kargs.k_batch > 1)
453  hipGetErrorString(hipMemsetAsync(kargs.wei_ptr,
454  0,
455  kargs.GemmBatch * kargs.GemmM * kargs.GemmN *
456  sizeof(WeiDataType),
457  s.stream_id_));
458  };
459  }
460 
461  CK_TILE_HOST static bool
463  {
464  if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
467  {
468  if(kargs.k_batch != 1)
469  {
470  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
471  {
472  CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
473  }
474  return false;
475  }
476  }
477 
478  const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
479  const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];
480 
481  // check ConvSpecialization
483  {
484  // check if it's 1x1, stride=1 conv
485  for(index_t i = 0; i < NDimSpatial; ++i)
486  {
487  const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
488  const index_t ConvStride = kargs.conv_filter_strides[i];
489  const index_t LeftPad = kargs.input_left_pads[i];
490  const index_t RightPad = kargs.input_right_pads[i];
491 
492  if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
493  {
494  return false;
495  }
496  }
497  }
499  {
500  // check if it's 1x1 conv
501  for(index_t i = 0; i < NDimSpatial; ++i)
502  {
503  const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
504  const index_t LeftPad = kargs.input_left_pads[i];
505  const index_t RightPad = kargs.input_right_pads[i];
506 
507  if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0))
508  {
509  return false;
510  }
511  }
512  }
514  {
515  if(ConvC != 1)
516  {
517  return false;
518  }
519  for(index_t i = 0; i < NDimSpatial; ++i)
520  {
521  const index_t filter_spatial_dim = kargs.wei_g_k_c_xs_lengths[i + I3];
522 
523  if(filter_spatial_dim != I3)
524  {
525  return false;
526  }
527  }
528  }
529 
530  namespace ctc = tensor_layout::convolution;
531 
532  if constexpr(std::is_same_v<InLayout, ctc::NWGC> || std::is_same_v<InLayout, ctc::NHWGC> ||
533  std::is_same_v<InLayout, ctc::NDHWGC>)
534  {
535  // Check access per C
536  if(ConvC % GroupedConvTraitsType_::VectorSizeB != 0)
537  {
538  CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!");
539  return false;
540  }
541  }
542  else
543  {
544  CK_TILE_ERROR("Not supported input layout!");
545  return false;
546  }
547 
548  if constexpr(std::is_same_v<WeiLayout, ctc::GKXC> ||
549  std::is_same_v<WeiLayout, ctc::GKYXC> ||
550  std::is_same_v<WeiLayout, ctc::GKZYXC>)
551  {
552  if(ConvC % GroupedConvTraitsType_::VectorSizeC != 0)
553  {
554  CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!");
555  return false;
556  }
557  }
558  else
559  {
560  CK_TILE_ERROR("Not supported weight layout!");
561  return false;
562  }
563 
564  if constexpr(std::is_same_v<OutLayout, ctc::NWGK> ||
565  std::is_same_v<OutLayout, ctc::NHWGK> ||
566  std::is_same_v<OutLayout, ctc::NDHWGK>)
567  {
568  if(ConvK % GroupedConvTraitsType_::VectorSizeA != 0)
569  {
570  CK_TILE_ERROR("Conv K is not a multiple of vector store size for output image!");
571  return false;
572  }
573  }
574  else
575  {
576  CK_TILE_ERROR("Not supported output layout!");
577  return false;
578  }
579 
580  return true;
581  }
582 
583  template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
584  CK_TILE_DEVICE static auto
586  const InDataType* b_ptr,
587  const std::array<const void*, NumDTensor>& ds_ptr,
588  WeiDataType* c_ptr,
590  {
591  static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
592  static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!");
593  const auto& a_tensor_view = [&]() {
594  return make_tensor_view<address_space_enum::global>(a_ptr,
595  kargs.a_grid_desc_m_k); // A: out
596  }();
597 
598  const auto& b_tensor_view = [&]() {
599  return make_tensor_view<address_space_enum::global>(b_ptr,
600  kargs.b_grid_desc_n_k); // B: in
601  }();
602 
603  const auto& c_tensor_view = [&]() {
604  return make_tensor_view<address_space_enum::global, DstInMemOp>(c_ptr,
605  kargs.c_grid_desc_m_n);
606  }();
607 
608  const auto& ds_tensor_view = generate_tuple(
609  [&](auto i) {
610  static_assert(std::is_same_v<std::tuple_element_t<i, DsLayout>, OutLayout>,
611  "Not supported!");
612  static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
613  "Not supported!");
614  static_assert(std::is_same_v<std::tuple_element_t<i, DsDataType>, WeiDataType>,
615  "Not supported!");
616 
617  return make_tensor_view<address_space_enum::global>(
618  static_cast<WeiDataType*>(ds_ptr[i]), kargs.c_grid_desc_m_n);
619  },
621 
622  return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view);
623  }
624 
625  template <typename TensorView>
626  CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views, const index_t k_batch)
627  {
628  const auto& a_pad_view = [&]() {
629  const auto& a_tensor_view = views.at(I0);
630  return pad_tensor_view(a_tensor_view,
634  }();
635 
636  const auto& b_pad_view = [&]() {
637  const auto& b_tensor_view = views.at(I1);
638  return pad_tensor_view(b_tensor_view,
642  }();
643 
644  const auto& ds_tensor_view = views.at(I2);
645  const auto& ds_pad_view = generate_tuple(
646  [&](auto i) {
647  return pad_tensor_view(ds_tensor_view[i],
651  },
653 
654  const auto& c_pad_view = [&]() {
655  const auto& c_tensor_view = views.at(I3);
656  return pad_tensor_view(c_tensor_view,
660  }();
661 
662  return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view);
663  }
664 
665  template <typename PadView>
666  CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
667  const index_t i_m,
668  const index_t i_n,
669  const index_t i_k)
670  {
671  const auto& a_pad_view = views.at(I0);
672  const auto& b_pad_view = views.at(I1);
673  const auto& ds_pad_view = views.at(I2);
674  const auto& c_pad_view = views.at(I3);
675 
676  const auto& a_block_window = [&]() {
677  return make_tile_window(a_pad_view,
680  {i_m, i_k});
681  }();
682 
683  const auto& b_block_window = [&]() {
684  return make_tile_window(b_pad_view,
687  {i_n, i_k});
688  }();
689 
690  const auto ds_block_window = generate_tuple(
691  [&](auto i) {
692  return make_tile_window(ds_pad_view[i],
695  {i_m, i_n});
696  },
698 
699  auto c_block_window = make_tile_window(
700  c_pad_view,
702  {i_m, i_n});
703 
704  return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window);
705  }
706 
719  CK_TILE_DEVICE static void RunGemm(const OutDataType* a_ptr,
720  const InDataType* b_ptr,
721  const std::array<const void*, NumDTensor>& ds_ptr,
722  WeiDataType* c_ptr,
723  void* smem_ptr_0,
725  const index_t num_loop,
726  const index_t block_idx_m,
727  const index_t block_idx_n,
728  const index_t block_idx_k)
729  {
730  // Create Gemm tensor views, pad views and tile windows
731  const auto& gemm_tensor_views_tuple =
732  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
733  a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
734 
735  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple, kargs.k_batch);
736  auto gemm_tile_windows =
737  MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
738 
739  // Run GEMM cooperatively by whole workgroup.
740  const auto& a_block_window = gemm_tile_windows.at(I0);
741  const auto& b_block_window = gemm_tile_windows.at(I1);
742  const auto& d_block_window = gemm_tile_windows.at(I2);
743 
744  const auto& c_block_tile = GemmPipeline{}.template operator()(
745  a_block_window, b_block_window, num_loop, smem_ptr_0);
746 
747  // Run Epilogue Pipeline
748  auto& c_block_window = gemm_tile_windows.at(I3);
749 
750  EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
751  c_block_window, c_block_tile, d_block_window, smem_ptr_0);
752  }
753 
769  CK_TILE_DEVICE static void RunGemm2LDS(const OutDataType* a_ptr,
770  const InDataType* b_ptr,
771  const std::array<const void*, NumDTensor>& ds_ptr,
772  WeiDataType* c_ptr,
773  void* __restrict__ smem_ptr_0,
774  void* __restrict__ smem_ptr_1,
776  const index_t num_loop,
777  const index_t block_idx_m,
778  const index_t block_idx_n,
779  const index_t block_idx_k)
780  {
781  // Create Gemm tensor views, pad views and tile windows
782  const auto& gemm_tensor_views_tuple =
783  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
784  a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
785  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple, kargs.k_batch);
786  auto gemm_tile_windows =
787  MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
788 
789  // Run GEMM cooperatively by whole workgroup.
790  const auto& a_block_window = gemm_tile_windows.at(I0);
791  const auto& b_block_window = gemm_tile_windows.at(I1);
792  const auto& d_block_window = gemm_tile_windows.at(I2);
793 
794  const auto& c_block_tile = GemmPipeline{}.template operator()(
795  a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
796 
797  // Run Epilogue Pipeline
798  auto& c_block_window = gemm_tile_windows.at(I3);
799 
800  EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
801  c_block_window, c_block_tile, d_block_window, smem_ptr_0);
802  }
803 
805  {
806  const auto blockIdX = amd_wave_read_first_lane(blockIdx.x);
807  const auto [iM, iN] =
808  TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(blockIdX);
809  const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
810  const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
811 
812  const auto blockIdZ = amd_wave_read_first_lane(blockIdx.z);
813  const index_t num_loop = amd_wave_read_first_lane(
814  ck_tile::integer_divide_ceil(kargs.GemmK, kargs.k_batch * TilePartitioner::KPerBlock));
815  const index_t i_k =
816  amd_wave_read_first_lane(blockIdZ * num_loop * TilePartitioner::KPerBlock);
817 
818  const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);
819  const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
820  const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
821  const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);
822 
823  // options
824  // conv_bwd_weight = Out * In = Weight
825  const OutDataType* a_ptr = static_cast<const OutDataType*>(kargs.out_ptr) + group_offset_a;
826  const InDataType* b_ptr = static_cast<const InDataType*>(kargs.in_ptr) + group_offset_b;
827  WeiDataType* c_ptr = static_cast<WeiDataType*>(kargs.wei_ptr) + group_offset_c;
828 
829  // allocate LDS
830  __shared__ char smem_ptr_0[GetSmemSize()];
831 
832  if constexpr(GemmPipeline::DoubleSmemBuffer == true)
833  {
834  __shared__ char smem_ptr_1[GetSmemSize()];
835  if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
836  GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
838  {
839  RunGemm2LDS(a_ptr,
840  b_ptr,
841  kargs.ds_ptr,
842  c_ptr,
843  smem_ptr_0,
844  smem_ptr_1,
845  kargs,
846  num_loop,
847  i_m,
848  i_n,
849  i_k);
850  }
851  }
852  else
853  {
854  if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
855  GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
857  {
858  RunGemm(
859  a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, num_loop, i_m, i_n, i_k);
860  }
861  }
862  }
863 };
864 
865 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:156
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
void CK_TILE_ERROR(Args &&... args) noexcept
Definition: env.hpp:12
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:2834
ConvolutionSpecialization
Definition: convolution_specialization.hpp:11
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
int64_t long_index_t
Definition: integer.hpp:11
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
constexpr bool is_same_v
Definition: type.hpp:283
__device__ X atomic_add(X *p_dst, const X &x)
unsigned int uint32_t
Definition: stdint.h:126
The Grouped Convolution kernel device arguments.
Definition: grouped_convolution_backward_weight_kernel.hpp:22
long_index_t group_stride_a
Definition: grouped_convolution_backward_weight_kernel.hpp:292
remove_cvref_t< decltype(ConvToGemmTransformer{}.MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N())> ABCGridDescs
Definition: grouped_convolution_backward_weight_kernel.hpp:261
array< index_t, GroupedConvTraitsType_::NDimSpatial > conv_filter_strides
Definition: grouped_convolution_backward_weight_kernel.hpp:272
remove_cvref_t< decltype(ABCGridDescs{}[number< 1 >{}])> BGridDescNK
Definition: grouped_convolution_backward_weight_kernel.hpp:264
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > wei_g_k_c_xs_lengths
Definition: grouped_convolution_backward_weight_kernel.hpp:269
void * wei_ptr
Definition: grouped_convolution_backward_weight_kernel.hpp:286
long_index_t group_stride_b
Definition: grouped_convolution_backward_weight_kernel.hpp:293
CGridDescMN c_grid_desc_m_n
Definition: grouped_convolution_backward_weight_kernel.hpp:290
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > in_g_n_c_wis_lengths
Definition: grouped_convolution_backward_weight_kernel.hpp:268
array< index_t, GroupedConvTraitsType_::NDimSpatial > conv_filter_dilations
Definition: grouped_convolution_backward_weight_kernel.hpp:273
index_t GemmN
Definition: grouped_convolution_backward_weight_kernel.hpp:279
index_t GemmBatch
Definition: grouped_convolution_backward_weight_kernel.hpp:281
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > out_g_n_k_wos_lengths
Definition: grouped_convolution_backward_weight_kernel.hpp:270
remove_cvref_t< decltype(ABCGridDescs{}[number< 0 >{}])> AGridDescMK
Definition: grouped_convolution_backward_weight_kernel.hpp:263
BGridDescNK b_grid_desc_n_k
Definition: grouped_convolution_backward_weight_kernel.hpp:289
CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs &args)
Definition: grouped_convolution_backward_weight_kernel.hpp:40
array< index_t, GroupedConvTraitsType_::NDimSpatial > input_left_pads
Definition: grouped_convolution_backward_weight_kernel.hpp:274
std::array< const void *, NumDTensor > ds_ptr
Definition: grouped_convolution_backward_weight_kernel.hpp:285
index_t GemmM
Definition: grouped_convolution_backward_weight_kernel.hpp:278
remove_cvref_t< decltype(ABCGridDescs{}[number< 2 >{}])> CGridDescMN
Definition: grouped_convolution_backward_weight_kernel.hpp:265
array< index_t, GroupedConvTraitsType_::NDimSpatial > input_right_pads
Definition: grouped_convolution_backward_weight_kernel.hpp:275
index_t GemmK
Definition: grouped_convolution_backward_weight_kernel.hpp:280
const void * in_ptr
Definition: grouped_convolution_backward_weight_kernel.hpp:284
index_t k_batch
Definition: grouped_convolution_backward_weight_kernel.hpp:277
static constexpr index_t NonSpatialDims
Definition: grouped_convolution_backward_weight_kernel.hpp:267
const void * out_ptr
Definition: grouped_convolution_backward_weight_kernel.hpp:283
AGridDescMK a_grid_desc_m_k
Definition: grouped_convolution_backward_weight_kernel.hpp:288
static constexpr index_t NumDTensor
Definition: grouped_convolution_backward_weight_kernel.hpp:30
long_index_t group_stride_c
Definition: grouped_convolution_backward_weight_kernel.hpp:294
The Grouped Conv kernel host arguments.
Definition: grouped_convolution_utils.hpp:19
index_t k_batch
Definition: grouped_convolution_utils.hpp:40
InPtr in_ptr
Definition: grouped_convolution_utils.hpp:36
WeiPtr wei_ptr
Definition: grouped_convolution_utils.hpp:37
OutPtr out_ptr
Definition: grouped_convolution_utils.hpp:39
const std::vector< const void * > ds_ptr
Definition: grouped_convolution_utils.hpp:38
Definition: grouped_convolution_backward_weight_kernel.hpp:422
index_t b_k_split_offset
Definition: grouped_convolution_backward_weight_kernel.hpp:444
index_t splitted_k
Definition: grouped_convolution_backward_weight_kernel.hpp:445
__device__ SplitKBatchOffset(const GroupedConvBwdWeightKernelArgsSpecialized &kargs, const std::size_t k_id=blockIdx.z)
Definition: grouped_convolution_backward_weight_kernel.hpp:423
index_t a_k_split_offset
Definition: grouped_convolution_backward_weight_kernel.hpp:443
The Grouped Convolution Backward Weight kernel template.
Definition: grouped_convolution_backward_weight_kernel.hpp:340
remove_cvref_t< typename EpiloguePipeline::DsLayout > GemmDsLayout
Definition: grouped_convolution_backward_weight_kernel.hpp:360
static constexpr index_t kBlockSize
Definition: grouped_convolution_backward_weight_kernel.hpp:363
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views, const index_t k_batch)
Definition: grouped_convolution_backward_weight_kernel.hpp:626
static constexpr CK_TILE_HOST auto BlockSize()
Definition: grouped_convolution_backward_weight_kernel.hpp:405
remove_cvref_t< typename GroupedConvTraitsType_::OutLayout > OutLayout
Definition: grouped_convolution_backward_weight_kernel.hpp:357
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: grouped_convolution_backward_weight_kernel.hpp:348
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition: grouped_convolution_backward_weight_kernel.hpp:349
GroupedConvBwdWeightKernelArgs< GroupedConvTraitsType_ > GroupedConvBwdWeightKernelArgsSpecialized
Definition: grouped_convolution_backward_weight_kernel.hpp:371
static CK_TILE_HOST const std::string GetName()
Definition: grouped_convolution_backward_weight_kernel.hpp:391
static CK_TILE_DEVICE void RunGemm2LDS(const OutDataType *a_ptr, const InDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, WeiDataType *c_ptr, void *__restrict__ smem_ptr_0, void *__restrict__ smem_ptr_1, const GroupedConvBwdWeightKernelArgsSpecialized &kargs, const index_t num_loop, const index_t block_idx_m, const index_t block_idx_n, const index_t block_idx_k)
Runs single GEMM problem cooperatively by whole workgroup.
Definition: grouped_convolution_backward_weight_kernel.hpp:769
remove_cvref_t< typename GemmPipeline::CLayout > GemmCLayout
Definition: grouped_convolution_backward_weight_kernel.hpp:353
static constexpr auto I2
Definition: grouped_convolution_backward_weight_kernel.hpp:378
static constexpr CK_TILE_HOST GroupedConvBwdWeightKernelArgsSpecialized MakeKernelArgs(const GroupedConvBwdWeightHostArgs &hostArgs)
Definition: grouped_convolution_backward_weight_kernel.hpp:411
remove_cvref_t< typename GemmPipeline::ALayout > GemmALayout
Definition: grouped_convolution_backward_weight_kernel.hpp:351
static CK_TILE_HOST bool IsSupportedArgument(const GroupedConvBwdWeightKernelArgsSpecialized &kargs)
Definition: grouped_convolution_backward_weight_kernel.hpp:462
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: grouped_convolution_backward_weight_kernel.hpp:350
static constexpr ConvolutionSpecialization ConvSpecialization
Definition: grouped_convolution_backward_weight_kernel.hpp:346
remove_cvref_t< typename GroupedConvTraitsType_::WeiLayout > WeiLayout
Definition: grouped_convolution_backward_weight_kernel.hpp:356
static constexpr bool IsSplitKSupported
Definition: grouped_convolution_backward_weight_kernel.hpp:374
static CK_TILE_DEVICE void RunGemm(const OutDataType *a_ptr, const InDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, WeiDataType *c_ptr, void *smem_ptr_0, const GroupedConvBwdWeightKernelArgsSpecialized &kargs, const index_t num_loop, const index_t block_idx_m, const index_t block_idx_n, const index_t block_idx_k)
Runs single GEMM problem cooperatively by whole workgroup.
Definition: grouped_convolution_backward_weight_kernel.hpp:719
static constexpr index_t NDimSpatial
Definition: grouped_convolution_backward_weight_kernel.hpp:345
remove_cvref_t< typename GroupedConvTraitsType_::DsLayout > DsLayout
Definition: grouped_convolution_backward_weight_kernel.hpp:358
remove_cvref_t< typename GroupedConvTraitsType_::InLayout > InLayout
Definition: grouped_convolution_backward_weight_kernel.hpp:355
static CK_TILE_HOST auto Preprocess(const GroupedConvBwdWeightKernelArgsSpecialized &kargs, const stream_config &s)
Definition: grouped_convolution_backward_weight_kernel.hpp:448
remove_cvref_t< typename EpiloguePipeline::ODataType > WeiDataType
Definition: grouped_convolution_backward_weight_kernel.hpp:368
static constexpr auto I3
Definition: grouped_convolution_backward_weight_kernel.hpp:379
static constexpr auto I0
Definition: grouped_convolution_backward_weight_kernel.hpp:376
CK_TILE_DEVICE void operator()(GroupedConvBwdWeightKernelArgsSpecialized kargs) const
Definition: grouped_convolution_backward_weight_kernel.hpp:804
static constexpr auto I1
Definition: grouped_convolution_backward_weight_kernel.hpp:377
remove_cvref_t< typename EpiloguePipeline::DsDataType > DsDataType
Definition: grouped_convolution_backward_weight_kernel.hpp:367
static constexpr index_t NumDTensor
Definition: grouped_convolution_backward_weight_kernel.hpp:361
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: grouped_convolution_backward_weight_kernel.hpp:416
remove_cvref_t< typename GemmPipeline::ADataType > OutDataType
Definition: grouped_convolution_backward_weight_kernel.hpp:365
static CK_TILE_DEVICE auto MakeGemmTensorViews(const OutDataType *a_ptr, const InDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, WeiDataType *c_ptr, const GroupedConvBwdWeightKernelArgsSpecialized &kargs)
Definition: grouped_convolution_backward_weight_kernel.hpp:585
remove_cvref_t< typename GemmPipeline::BDataType > InDataType
Definition: grouped_convolution_backward_weight_kernel.hpp:366
remove_cvref_t< typename GemmPipeline::BLayout > GemmBLayout
Definition: grouped_convolution_backward_weight_kernel.hpp:352
static constexpr CK_TILE_HOST auto GridSize(const GroupedConvBwdWeightKernelArgsSpecialized &kargs)
Definition: grouped_convolution_backward_weight_kernel.hpp:399
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n, const index_t i_k)
Definition: grouped_convolution_backward_weight_kernel.hpp:666
Definition: transform_conv_bwd_weight_to_gemm.hpp:22
CK_TILE_HOST auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N() const
Definition: transform_conv_bwd_weight_to_gemm.hpp:551
Definition: integral_constant.hpp:13
std::vector< ck_tile::long_index_t > input_spatial_lengths_
Definition: convolution_parameter.hpp:130
ck_tile::long_index_t K_
Definition: convolution_parameter.hpp:126
std::vector< ck_tile::long_index_t > output_spatial_lengths_
Definition: convolution_parameter.hpp:131
std::vector< ck_tile::long_index_t > input_right_pads_
Definition: convolution_parameter.hpp:137
ck_tile::long_index_t G_
Definition: convolution_parameter.hpp:124
std::vector< ck_tile::long_index_t > conv_filter_strides_
Definition: convolution_parameter.hpp:133
std::vector< ck_tile::long_index_t > filter_spatial_lengths_
Definition: convolution_parameter.hpp:129
ck_tile::long_index_t C_
Definition: convolution_parameter.hpp:127
ck_tile::long_index_t N_
Definition: convolution_parameter.hpp:125
std::vector< ck_tile::long_index_t > input_left_pads_
Definition: convolution_parameter.hpp:136
std::vector< ck_tile::long_index_t > conv_filter_dilations_
Definition: convolution_parameter.hpp:134
Definition: type_traits.hpp:115
Definition: sequence.hpp:49
Definition: stream_config.hpp:30
hipStream_t stream_id_
Definition: stream_config.hpp:31
#define CK_TILE_ENV(name)
Definition: env.hpp:145