/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp Source File
grouped_convolution_backward_weight_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <string>
8 
9 #include "ck_tile/core.hpp"
10 #include "ck_tile/ops/common.hpp"
11 #include "ck_tile/host/concat.hpp"
16 
17 namespace ck_tile {
18 
// The Grouped Convolution backward-weight kernel device arguments.
// Translates the host-side convolution arguments into a batched GEMM
// problem: A/B/C grid descriptors (via TransformConvBwdWeightToGemm),
// per-group pointer strides, split-K factor, and GEMM extents.
// One SFINAE-selected constructor per spatial rank:
//   1D: NWGC in / GKXC wei / NWGK out
//   2D: NHWGC in / GKYXC wei / NHWGK out
//   3D: NDHWGC in / GKZYXC wei / NDHWGK out
// NOTE(review): this listing is a documentation extraction; lines that
// were hyperlinks in the original header (struct name, constructor
// signatures, member declarations, some argument lists) are missing.
// Each gap is flagged below — confirm against the real header.
20 template <typename GroupedConvTraitsType_>
22 {
23 
 // NOTE(review): original line 24 is missing here — presumably the
 // left-hand side `using ConvToGemmTransformer =` of this alias
 // (the name is used in every constructor below); confirm.
25  TransformConvBwdWeightToGemm<GroupedConvTraitsType_::NDimSpatial,
26  GroupedConvTraitsType_::ConvSpecialization>;
27  static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
28 
 // --- 1D constructor: enabled only for NWGC / GKXC / NWGK layouts ---
29  template <
30  typename InLay = typename GroupedConvTraitsType_::InLayout,
31  typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
32  typename OutLay = typename GroupedConvTraitsType_::OutLayout,
33  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NWGC> &&
34  std::is_same_v<WeiLay, tensor_layout::convolution::GKXC> &&
35  std::is_same_v<OutLay, tensor_layout::convolution::NWGK>,
36  bool>::type = false>
 // NOTE(review): constructor signature line (original line 37) missing;
 // per the member index it is
 //   CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args)
38  {
 // Pack G/N/C/K and the spatial lengths into the canonical
 // g_n_c_wis / g_k_c_xs / g_n_k_wos orderings the transformer expects.
39  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
40  static_cast<index_t>(args.N_),
41  static_cast<index_t>(args.C_),
42  static_cast<index_t>(args.input_spatial_lengths_[0])};
43  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
44  static_cast<index_t>(args.K_),
45  static_cast<index_t>(args.C_),
46  static_cast<index_t>(args.filter_spatial_lengths_[0])};
47  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
48  static_cast<index_t>(args.N_),
49  static_cast<index_t>(args.K_),
50  static_cast<index_t>(args.output_spatial_lengths_[0])};
51 
52  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0])};
53  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0])};
54  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0])};
55  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0])};
56 
57  k_batch = args.k_batch;
58 
59  in_ptr = args.in_ptr;
60  wei_ptr = args.wei_ptr;
61  for(index_t d = 0; d < NumDTensor; d++)
62  {
63  ds_ptr[d] = args.ds_ptr[d];
64  }
65  out_ptr = args.out_ptr;
66 
 // NOTE(review): original lines 68-73 (the remaining transformer
 // constructor arguments) are missing from this extraction.
67  ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
74 
75  // tuple
76  auto grid_descs =
77  conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
78  GroupedConvTraitsType_::NDimSpatial>();
79 
80  a_grid_desc_m_k = grid_descs.at(number<0>{});
81  b_grid_desc_n_k = grid_descs.at(number<1>{});
82  c_grid_desc_m_n = grid_descs.at(number<2>{});
83 
 // Per-group pointer strides: for channels-last layouts the group
 // dimension is strided by K (output), C (input), and K*C*prod(filter)
 // for the weight tensor.
84  group_stride_a = args.K_; // A: Out NWGK
85  group_stride_b = args.C_; // B: In NWGC
86  group_stride_c = args.K_ * args.C_ * // C: Wei GKXC
87  std::accumulate(args.filter_spatial_lengths_.begin(),
88  args.filter_spatial_lengths_.end(),
89  1,
90  std::multiplies<index_t>());
91 
92  GemmM = a_grid_desc_m_k.get_length(number<0>{});
93  GemmN = b_grid_desc_n_k.get_length(number<0>{});
94  GemmK = a_grid_desc_m_k.get_length(number<1>{});
95  GemmBatch = args.G_;
96  }
97 
 // --- 2D constructor: enabled only for NHWGC / GKYXC / NHWGK layouts ---
 // Same flow as the 1D constructor, extended to two spatial dimensions.
98  template <
99  typename InLay = typename GroupedConvTraitsType_::InLayout,
100  typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
101  typename OutLay = typename GroupedConvTraitsType_::OutLayout,
102  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NHWGC> &&
103  std::is_same_v<WeiLay, tensor_layout::convolution::GKYXC> &&
104  std::is_same_v<OutLay, tensor_layout::convolution::NHWGK>,
105  bool>::type = false>
 // NOTE(review): constructor signature line (original line 106) missing.
107  {
108  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
109  static_cast<index_t>(args.N_),
110  static_cast<index_t>(args.C_),
111  static_cast<index_t>(args.input_spatial_lengths_[0]),
112  static_cast<index_t>(args.input_spatial_lengths_[1])};
113  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
114  static_cast<index_t>(args.K_),
115  static_cast<index_t>(args.C_),
116  static_cast<index_t>(args.filter_spatial_lengths_[0]),
117  static_cast<index_t>(args.filter_spatial_lengths_[1])};
118  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
119  static_cast<index_t>(args.N_),
120  static_cast<index_t>(args.K_),
121  static_cast<index_t>(args.output_spatial_lengths_[0]),
122  static_cast<index_t>(args.output_spatial_lengths_[1])};
123 
124  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
125  static_cast<index_t>(args.conv_filter_strides_[1])};
126  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
127  static_cast<index_t>(args.conv_filter_dilations_[1])};
128  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
129  static_cast<index_t>(args.input_left_pads_[1])};
130  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
131  static_cast<index_t>(args.input_right_pads_[1])};
132 
133  k_batch = args.k_batch;
134 
135  in_ptr = args.in_ptr;
136  wei_ptr = args.wei_ptr;
137  for(index_t d = 0; d < NumDTensor; d++)
138  {
139  ds_ptr[d] = args.ds_ptr[d];
140  }
141  out_ptr = args.out_ptr;
142 
 // NOTE(review): original lines 144-149 (remaining transformer
 // constructor arguments) are missing from this extraction.
143  ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
150 
151  // tuple
152  auto grid_descs =
153  conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
154  GroupedConvTraitsType_::NDimSpatial>();
155 
156  a_grid_desc_m_k = grid_descs.at(number<0>{});
157  b_grid_desc_n_k = grid_descs.at(number<1>{});
158  c_grid_desc_m_n = grid_descs.at(number<2>{});
159 
160  group_stride_a = args.K_; // A: Out NHWGK
161  group_stride_b = args.C_; // B: In NHWGC
162  group_stride_c = args.K_ * args.C_ * // C: Wei GKYXC
163  std::accumulate(args.filter_spatial_lengths_.begin(),
164  args.filter_spatial_lengths_.end(),
165  1,
166  std::multiplies<index_t>());
167 
168  GemmM = a_grid_desc_m_k.get_length(number<0>{});
169  GemmN = b_grid_desc_n_k.get_length(number<0>{});
170  GemmK = a_grid_desc_m_k.get_length(number<1>{});
171  GemmBatch = args.G_;
172  }
173 
 // --- 3D constructor: enabled only for NDHWGC / GKZYXC / NDHWGK layouts ---
 // Same flow as the 1D/2D constructors, extended to three spatial dims.
174  template <
175  typename InLay = typename GroupedConvTraitsType_::InLayout,
176  typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
177  typename OutLay = typename GroupedConvTraitsType_::OutLayout,
178  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NDHWGC> &&
179  std::is_same_v<WeiLay, tensor_layout::convolution::GKZYXC> &&
180  std::is_same_v<OutLay, tensor_layout::convolution::NDHWGK>,
181  bool>::type = false>
 // NOTE(review): constructor signature line (original line 182) missing.
183  {
184  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
185  static_cast<index_t>(args.N_),
186  static_cast<index_t>(args.C_),
187  static_cast<index_t>(args.input_spatial_lengths_[0]),
188  static_cast<index_t>(args.input_spatial_lengths_[1]),
189  static_cast<index_t>(args.input_spatial_lengths_[2])};
190  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
191  static_cast<index_t>(args.K_),
192  static_cast<index_t>(args.C_),
193  static_cast<index_t>(args.filter_spatial_lengths_[0]),
194  static_cast<index_t>(args.filter_spatial_lengths_[1]),
195  static_cast<index_t>(args.filter_spatial_lengths_[2])};
196  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
197  static_cast<index_t>(args.N_),
198  static_cast<index_t>(args.K_),
199  static_cast<index_t>(args.output_spatial_lengths_[0]),
200  static_cast<index_t>(args.output_spatial_lengths_[1]),
201  static_cast<index_t>(args.output_spatial_lengths_[2])};
202 
203  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
204  static_cast<index_t>(args.conv_filter_strides_[1]),
205  static_cast<index_t>(args.conv_filter_strides_[2])};
206  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
207  static_cast<index_t>(args.conv_filter_dilations_[1]),
208  static_cast<index_t>(args.conv_filter_dilations_[2])};
209  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
210  static_cast<index_t>(args.input_left_pads_[1]),
211  static_cast<index_t>(args.input_left_pads_[2])};
212  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
213  static_cast<index_t>(args.input_right_pads_[1]),
214  static_cast<index_t>(args.input_right_pads_[2])};
215 
216  k_batch = args.k_batch;
217 
218  in_ptr = args.in_ptr;
219  wei_ptr = args.wei_ptr;
220  for(index_t d = 0; d < NumDTensor; d++)
221  {
222  ds_ptr[d] = args.ds_ptr[d];
223  }
224  out_ptr = args.out_ptr;
225 
 // NOTE(review): original lines 227-232 (remaining transformer
 // constructor arguments) are missing from this extraction.
226  ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
233 
234  // tuple
235  auto grid_descs =
236  conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
237  GroupedConvTraitsType_::NDimSpatial>();
238 
239  a_grid_desc_m_k = grid_descs.at(number<0>{});
240  b_grid_desc_n_k = grid_descs.at(number<1>{});
241  c_grid_desc_m_n = grid_descs.at(number<2>{});
242 
243  group_stride_a = args.K_; // A: Out NDHWGK
244  group_stride_b = args.C_; // B: In NDHWGC
245  group_stride_c = args.K_ * args.C_ * // C: Wei GKZYXC
246  std::accumulate(args.filter_spatial_lengths_.begin(),
247  args.filter_spatial_lengths_.end(),
248  1,
249  std::multiplies<index_t>());
250 
251  GemmM = a_grid_desc_m_k.get_length(number<0>{});
252  GemmN = b_grid_desc_n_k.get_length(number<0>{});
253  GemmK = a_grid_desc_m_k.get_length(number<1>{});
254  GemmBatch = args.G_;
255  }
256 
 // NOTE(review): most member declarations that follow (ABCGridDescs
 // alias, the A/B/C grid descriptors, the length/stride/pad arrays,
 // GemmM/N/K/Batch, k_batch, group_stride_a/b/c) were hyperlinks and
 // are missing from this extraction; see the member index at the
 // bottom of the dump for their full declarations.
259 
263 
264  static constexpr index_t NonSpatialDims = 3;
268 
273 
279 
 // Raw device pointers; wei_ptr is the (mutable) output of bwd-weight.
280  const void* out_ptr;
281  const void* in_ptr;
282  std::array<const void*, NumDTensor> ds_ptr;
283  void* wei_ptr;
284 
288 
292 };
293 
// The Grouped Convolution Backward Weight kernel template.
// Maps the backward-weight convolution onto a split-K batched GEMM:
//   A = output-gradient tensor, B = input tensor, C = weight gradient.
// Grid layout: x = M*N tiles, y = group index, z = split-K slice.
// NOTE(review): this listing is a documentation extraction; lines that
// were hyperlinks (type aliases, method signatures, several argument
// lists, parts of two `if constexpr` conditions) are missing and are
// flagged below — confirm against the real header.
332 template <typename GroupedConvTraitsType_,
333  typename TilePartitioner_,
334  typename GemmPipeline_,
335  typename EpiloguePipeline_>
337 {
338  static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial_;
 // NOTE(review): original line 339 (`static constexpr ConvolutionSpecialization
 // ConvSpecialization =`) is missing; only its initializer survives.
340  GroupedConvTraitsType_::ConvSpecialization;
 // NOTE(review): original lines 341-353 (TilePartitioner/GemmPipeline/
 // EpiloguePipeline aliases, gemm/conv layout aliases, GemmDsLayout)
 // are missing from this extraction; see the member index.
347 
352 
354  static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
355 
356  static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
357 
 // NOTE(review): original lines 358-365 (InDataType/WeiDataType/
 // OutDataType/DsDataType aliases and the kernel-args alias) missing.
361  // Below type is actually accumulation data type - the output of block GEMM.
363 
366 
367  // TODO: Enable this
368  static constexpr bool IsSplitKSupported = true;
369 
 // Compile-time indices used for tuple access throughout this class.
370  static constexpr auto I0 = number<0>();
371  static constexpr auto I1 = number<1>();
372  static constexpr auto I2 = number<2>();
373  static constexpr auto I3 = number<3>();
374 
 // This kernel only supports a fully padded row-major A / col-major B /
 // row-major C GEMM pipeline configuration.
375  static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
376  "Not supported!");
377  static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>, "Not supported!");
378  static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
379  static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
380 
 // Builds a human-readable kernel identifier from op name, precision
 // string and pipeline name, '_'-separated.
381  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
382  {
383  // clang-format off
384  return concat('_', "grouped_convolution_backward_weight", gemm_prec_str<InDataType, WeiDataType>, GemmPipeline::GetName());
385  // clang-format on
386  }
387 
 // Launch grid: x = number of M*N output tiles, y = conv groups,
 // z = split-K slices.  (Signature line 389, `GridSize(...)`, missing.)
388  CK_TILE_HOST static constexpr auto
390  {
391  return dim3(
392  TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch);
393  }
394 
395  CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
396 
 // NOTE(review): original lines 397-401 (MakeKernelArgs: wraps host args
 // into the device-args struct, per the member index) mostly missing.
399  {
401  }
402 
 // LDS requirement is the max of the GEMM pipeline's and the epilogue's.
 // (Signature line 403, `GetSmemSize()`, missing.)
404  {
405  return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
406  }
407 
 // Per-workgroup split-K bookkeeping: each z-slice (k_id) reads a
 // KRead-sized chunk of GemmK; the last slice takes the remainder.
 // (Struct head + ctor signature lines 408/410 missing; the member
 // index shows `SplitKBatchOffset(const ...KernelArgsSpecialized& kargs,
 // const std::size_t k_id = blockIdx.z)`.)
409  {
411  const std::size_t k_id = blockIdx.z)
412  {
 // K1 = warp-tile K extent; KRead = per-slice K chunk rounded up
 // to a multiple of K1.
413  constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
414  const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1);
415  const index_t KRead =
416  __builtin_amdgcn_readfirstlane((kargs.GemmK + K_t - 1) / K_t * K1);
417 
418  a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead);
419  b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead);
420 
 // All but the last slice process exactly KRead; the last slice
 // picks up whatever K remains.
421  if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
422  {
423  splitted_k = __builtin_amdgcn_readfirstlane(KRead);
424  }
425  else
426  {
427  splitted_k =
428  __builtin_amdgcn_readfirstlane(kargs.GemmK - KRead * (kargs.k_batch - 1));
429  }
430  }
431 
 // NOTE(review): member declarations (a_k_split_offset,
 // b_k_split_offset, splitted_k) on original lines 432-434 missing.
435  };
436 
 // Host-side preprocess hook: when split-K is active, zero the weight
 // buffer up front — presumably so the split-K slices can accumulate
 // partial results into it (atomic-add epilogue); confirm.  The
 // hipMemsetAsync error code is converted to a string and discarded.
 // (Signature line 437 missing.)
438  const stream_config& s)
439  {
440  return [&]() {
441  if(kargs.k_batch > 1)
442  hipGetErrorString(hipMemsetAsync(kargs.wei_ptr,
443  0,
444  kargs.GemmBatch * kargs.GemmM * kargs.GemmN *
445  sizeof(WeiDataType),
446  s.stream_id_));
447  };
448  }
449 
 // Validates kernel arguments: split-K legality, the selected
 // ConvSpecialization's structural constraints, and vectorized
 // load/store divisibility for the A/B/C tensors.
450  CK_TILE_HOST static bool
 // (Signature remainder on line 451 missing: `IsSupportedArgument(const
 // GroupedConvBwdWeightKernelArgsSpecialized& kargs)` per member index.)
452  {
 // NOTE(review): lines 454-455 (rest of this condition) are missing
 // from the extraction — the visible part gates split-K support on
 // the epilogue's C vector size parity.
453  if constexpr((EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
456  {
457  if(kargs.k_batch != 1)
458  {
459  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
460  {
461  CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
462  }
463  return false;
464  }
465  }
466 
467  const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
468  const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];
469 
470  // check ConvSpecialization
 // NOTE(review): the `if constexpr(ConvSpecialization == ...)` headers
 // on original lines 471/487/502 are missing; from the bodies these are
 // the Filter1x1Stride1Pad0 / Filter1x1Pad0 / (filter==3) branches.
472  {
473  // check if it's 1x1, stride=1 conv
474  for(index_t i = 0; i < NDimSpatial; ++i)
475  {
476  const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
477  const index_t ConvStride = kargs.conv_filter_strides[i];
478  const index_t LeftPad = kargs.input_left_pads[i];
479  const index_t RightPad = kargs.input_right_pads[i];
480 
481  if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
482  {
483  return false;
484  }
485  }
486  }
488  {
489  // check if it's 1x1 conv
490  for(index_t i = 0; i < NDimSpatial; ++i)
491  {
492  const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
493  const index_t LeftPad = kargs.input_left_pads[i];
494  const index_t RightPad = kargs.input_right_pads[i];
495 
496  if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0))
497  {
498  return false;
499  }
500  }
501  }
503  {
504  if(ConvC != 1)
505  {
506  return false;
507  }
508  for(index_t i = 0; i < NDimSpatial; ++i)
509  {
 // NOTE(review): I3 (== 3) doubles here as both the length-array
 // offset and the required filter extent — looks like a
 // filter-3x3 specialization check; confirm intent, and note the
 // inconsistency with the literal `3` used in the loops above.
510  const index_t filter_spatial_dim = kargs.wei_g_k_c_xs_lengths[i + I3];
511 
512  if(filter_spatial_dim != I3)
513  {
514  return false;
515  }
516  }
517  }
518 
519  namespace ctc = tensor_layout::convolution;
520 
 // B operand (input image): C must divide evenly into vector loads.
521  if constexpr(std::is_same_v<InLayout, ctc::NWGC> || std::is_same_v<InLayout, ctc::NHWGC> ||
522  std::is_same_v<InLayout, ctc::NDHWGC>)
523  {
524  // Check access per C
525  if(ConvC % GemmPipeline::GetVectorSizeB() != 0)
526  {
527  CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!");
528  return false;
529  }
530  }
531  else
532  {
533  CK_TILE_ERROR("Not supported input layout!");
534  return false;
535  }
536 
 // C operand (weight gradient): C must divide the epilogue vector size.
537  // check vector access of B
538  // FIXME: layout
539  if constexpr(std::is_same_v<WeiLayout, ctc::GKXC> ||
540  std::is_same_v<WeiLayout, ctc::GKYXC> ||
541  std::is_same_v<WeiLayout, ctc::GKZYXC>)
542  {
543  if(ConvC % EpiloguePipeline::GetVectorSizeC() != 0)
544  {
545  CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!");
546  return false;
547  }
548  }
549  else
550  {
551  CK_TILE_ERROR("Not supported weight layout!");
552  return false;
553  }
554 
 // A operand (output-gradient image): K must divide the A vector size.
555  // check vector access of E
556  if constexpr(std::is_same_v<OutLayout, ctc::NWGK> ||
557  std::is_same_v<OutLayout, ctc::NHWGK> ||
558  std::is_same_v<OutLayout, ctc::NDHWGK>)
559  {
560  if(ConvK % GemmPipeline::GetVectorSizeA() != 0)
561  {
562  CK_TILE_ERROR("Conv K is not a multiple of vector store size for output image!");
563  return false;
564  }
565  }
566  else
567  {
568  CK_TILE_ERROR("Not supported output layout!");
569  return false;
570  }
571 
572  return true;
573  }
574 
 // Builds global-memory tensor views for A (out-grad), B (in), the Ds
 // bias-like tensors, and C (weight grad, written with DstInMemOp).
 // (Parts of the signature on lines 577/581 — `MakeGemmTensorViews(`,
 // the kargs parameter — are missing from the extraction.)
575  template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
576  CK_TILE_DEVICE static auto
578  const InDataType* b_ptr,
579  const std::array<const void*, NumDTensor>& ds_ptr,
580  WeiDataType* c_ptr,
582  {
583  static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
584  static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!");
585  const auto& a_tensor_view = [&]() {
586  return make_tensor_view<address_space_enum::global>(a_ptr,
587  kargs.a_grid_desc_m_k); // A: out
588  }();
589 
590  const auto& b_tensor_view = [&]() {
591  return make_tensor_view<address_space_enum::global>(b_ptr,
592  kargs.b_grid_desc_n_k); // B: in
593  }();
594 
595  const auto& c_tensor_view = [&]() {
596  return make_tensor_view<address_space_enum::global, DstInMemOp>(
597  c_ptr,
598  kargs.c_grid_desc_m_n); // C: wei
599  }();
600 
 // Ds tensors must mirror C's layout/descriptor and use OutDataType.
601  const auto& ds_tensor_view = generate_tuple(
602  [&](auto i) {
603  static_assert(std::is_same_v<std::tuple_element_t<i, DsLayout>, OutLayout>,
604  "Not supported!");
605  static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
606  "Not supported!");
607  static_assert(std::is_same_v<std::tuple_element_t<i, DsDataType>, OutDataType>,
608  "Not supported!");
609 
 // NOTE(review): ds_ptr[i] is `const void*` but is cast to a
 // non-const OutDataType* — confirm constness in the real header.
610  return make_tensor_view<address_space_enum::global>(
611  static_cast<OutDataType*>(ds_ptr[i]), kargs.c_grid_desc_m_n);
612  },
 // NOTE(review): line 613 (`number<NumDTensor>{});`, closing the
 // generate_tuple call) is missing from the extraction.
614 
615  return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view);
616  }
617 
 // Pads each tensor view up to whole block-tile multiples.
 // NOTE(review): the pad_tensor_view argument lists (original lines
 // 624-626, 632-634, 641-645, 650-652) are missing from the extraction.
618  template <typename TensorView>
619  CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views, const index_t k_batch)
620  {
621  const auto& a_pad_view = [&]() {
622  const auto& a_tensor_view = views.at(I0);
623  return pad_tensor_view(a_tensor_view,
627  }();
628 
629  const auto& b_pad_view = [&]() {
630  const auto& b_tensor_view = views.at(I1);
631  return pad_tensor_view(b_tensor_view,
635  }();
636 
637  const auto& ds_tensor_view = views.at(I2);
638  const auto& ds_pad_view = generate_tuple(
639  [&](auto i) {
640  return pad_tensor_view(ds_tensor_view[i],
644  },
646 
647  const auto& c_pad_view = [&]() {
648  const auto& c_tensor_view = views.at(I3);
649  return pad_tensor_view(c_tensor_view,
653  }();
654 
655  return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view);
656  }
657 
 // Creates the block-local tile windows at this workgroup's (m, n, k)
 // tile origin: A at {i_m, i_k}, B at {i_n, i_k}, Ds and C at {i_m, i_n}.
 // NOTE(review): the tile-shape arguments (original lines 671-672,
 // 678-679, 686-687, 694) are missing from the extraction.
658  template <typename PadView>
659  CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
660  const index_t i_m,
661  const index_t i_n,
662  const index_t i_k)
663  {
664  const auto& a_pad_view = views.at(I0);
665  const auto& b_pad_view = views.at(I1);
666  const auto& ds_pad_view = views.at(I2);
667  const auto& c_pad_view = views.at(I3);
668 
669  const auto& a_block_window = [&]() {
670  return make_tile_window(a_pad_view,
673  {i_m, i_k});
674  }();
675 
676  const auto& b_block_window = [&]() {
677  return make_tile_window(b_pad_view,
680  {i_n, i_k});
681  }();
682 
683  const auto ds_block_window = generate_tuple(
684  [&](auto i) {
685  return make_tile_window(ds_pad_view[i],
688  {i_m, i_n});
689  },
691 
692  auto c_block_window = make_tile_window(
693  c_pad_view,
695  {i_m, i_n});
696 
697  return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window);
698  }
699 
 // Runs one GEMM problem cooperatively with a single LDS buffer:
 // build views/windows, run the GEMM pipeline, then the epilogue.
 // (The kargs parameter line, original 717, is missing.)
712  CK_TILE_DEVICE static void RunGemm(const OutDataType* a_ptr,
713  const InDataType* b_ptr,
714  const std::array<const void*, NumDTensor>& ds_ptr,
715  WeiDataType* c_ptr,
716  void* smem_ptr_0,
718  const index_t num_loop,
719  const index_t block_idx_m,
720  const index_t block_idx_n,
721  const index_t block_idx_k)
722  {
723  // Create Gemm tensor views, pad views and tile windows
724  const auto& gemm_tensor_views_tuple =
725  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
726  a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
727 
728  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple, kargs.k_batch);
729  auto gemm_tile_windows =
730  MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
731 
732  // Run GEMM cooperatively by whole workgroup.
733  const auto& a_block_window = gemm_tile_windows.at(I0);
734  const auto& b_block_window = gemm_tile_windows.at(I1);
735  const auto& d_block_window = gemm_tile_windows.at(I2);
736 
737  const auto& c_block_tile = GemmPipeline{}.template operator()(
738  a_block_window, b_block_window, num_loop, smem_ptr_0);
739 
740  // Run Epilogue Pipeline
741  auto& c_block_window = gemm_tile_windows.at(I3);
742 
743  EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
744  c_block_window, c_block_tile, d_block_window, smem_ptr_0);
745  }
746 
 // Same as RunGemm but with double-buffered LDS (ping-pong between
 // smem_ptr_0 and smem_ptr_1 inside the pipeline).
 // (The kargs parameter line, original 768, is missing.)
762  CK_TILE_DEVICE static void RunGemm2LDS(const OutDataType* a_ptr,
763  const InDataType* b_ptr,
764  const std::array<const void*, NumDTensor>& ds_ptr,
765  WeiDataType* c_ptr,
766  void* __restrict__ smem_ptr_0,
767  void* __restrict__ smem_ptr_1,
769  const index_t num_loop,
770  const index_t block_idx_m,
771  const index_t block_idx_n,
772  const index_t block_idx_k)
773  {
774  // Create Gemm tensor views, pad views and tile windows
775  const auto& gemm_tensor_views_tuple =
776  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
777  a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
778  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple, kargs.k_batch);
779  auto gemm_tile_windows =
780  MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
781 
782  // Run GEMM cooperatively by whole workgroup.
783  const auto& a_block_window = gemm_tile_windows.at(I0);
784  const auto& b_block_window = gemm_tile_windows.at(I1);
785  const auto& d_block_window = gemm_tile_windows.at(I2);
786 
787  const auto& c_block_tile = GemmPipeline{}.template operator()(
788  a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
789 
790  // Run Epilogue Pipeline
791  auto& c_block_window = gemm_tile_windows.at(I3);
792 
793  EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
794  c_block_window, c_block_tile, d_block_window, smem_ptr_0);
795  }
796 
 // Kernel entry point (operator(); signature line 797 missing).
 // Decodes the 3D launch grid into tile origin (x), conv group (y)
 // and split-K slice (z), offsets the group pointers, allocates LDS,
 // and dispatches to RunGemm / RunGemm2LDS.
798  {
 // x-dim: which M*N output tile this workgroup computes.
799  const auto blockIdX = __builtin_amdgcn_readfirstlane(blockIdx.x);
800  const auto [iM, iN] =
801  TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(blockIdX);
802  const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
803  const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
804 
 // z-dim: split-K slice; each slice iterates num_loop K-steps and
 // starts at its own K offset.
805  const auto blockIdZ = __builtin_amdgcn_readfirstlane(blockIdx.z);
806  const index_t num_loop = __builtin_amdgcn_readfirstlane(
807  ck_tile::integer_divide_ceil(kargs.GemmK, kargs.k_batch * TilePartitioner::KPerBlock));
808  const index_t i_k =
809  __builtin_amdgcn_readfirstlane(blockIdZ * num_loop * TilePartitioner::KPerBlock);
810 
 // y-dim: convolution group; advance each operand by its group stride.
811  const auto blockIdY = __builtin_amdgcn_readfirstlane(blockIdx.y);
812  const auto group_offset_a = __builtin_amdgcn_readfirstlane(kargs.group_stride_a * blockIdY);
813  const auto group_offset_b = __builtin_amdgcn_readfirstlane(kargs.group_stride_b * blockIdY);
814  const auto group_offset_c = __builtin_amdgcn_readfirstlane(kargs.group_stride_c * blockIdY);
815 
816  // options
817  // conv_bwd_weight = Out * In = Weight
818  const OutDataType* a_ptr = static_cast<const OutDataType*>(kargs.out_ptr) + group_offset_a;
819  const InDataType* b_ptr = static_cast<const InDataType*>(kargs.in_ptr) + group_offset_b;
820  WeiDataType* c_ptr = static_cast<WeiDataType*>(kargs.wei_ptr) + group_offset_c;
821 
822  // allocate LDS
823  __shared__ char smem_ptr_0[GetSmemSize()];
824 
825  if constexpr(GemmPipeline::DoubleSmemBuffer == true)
826  {
827  __shared__ char smem_ptr_1[GetSmemSize()];
 // NOTE(review): line 830 (the rest of this condition) is missing
 // from the extraction; the visible part excludes the
 // atomic-add + odd-C-vector configuration.
828  if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
829  EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
831  {
832  RunGemm2LDS(a_ptr,
833  b_ptr,
834  kargs.ds_ptr,
835  c_ptr,
836  smem_ptr_0,
837  smem_ptr_1,
838  kargs,
839  num_loop,
840  i_m,
841  i_n,
842  i_k);
843  }
844  }
845  else
846  {
 // NOTE(review): line 849 (rest of this condition) is likewise
 // missing from the extraction.
847  if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
848  EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
850  {
851  RunGemm(
852  a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, num_loop, i_m, i_n, i_k);
853  }
854  }
855  }
856 };
857 
858 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:156
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
void CK_TILE_ERROR(Args &&... args) noexcept
Definition: env.hpp:12
ConvolutionSpecialization
Definition: convolution_specialization.hpp:11
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
int64_t long_index_t
Definition: integer.hpp:11
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
constexpr bool is_same_v
Definition: type.hpp:283
__device__ X atomic_add(X *p_dst, const X &x)
unsigned int uint32_t
Definition: stdint.h:126
The Grouped Convolution kernel device arguments.
Definition: grouped_convolution_backward_weight_kernel.hpp:22
long_index_t group_stride_a
Definition: grouped_convolution_backward_weight_kernel.hpp:289
remove_cvref_t< decltype(ConvToGemmTransformer{}.MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N())> ABCGridDescs
Definition: grouped_convolution_backward_weight_kernel.hpp:258
array< index_t, GroupedConvTraitsType_::NDimSpatial > conv_filter_strides
Definition: grouped_convolution_backward_weight_kernel.hpp:269
remove_cvref_t< decltype(ABCGridDescs{}[number< 1 >{}])> BGridDescNK
Definition: grouped_convolution_backward_weight_kernel.hpp:261
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > wei_g_k_c_xs_lengths
Definition: grouped_convolution_backward_weight_kernel.hpp:266
void * wei_ptr
Definition: grouped_convolution_backward_weight_kernel.hpp:283
long_index_t group_stride_b
Definition: grouped_convolution_backward_weight_kernel.hpp:290
CGridDescMN c_grid_desc_m_n
Definition: grouped_convolution_backward_weight_kernel.hpp:287
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > in_g_n_c_wis_lengths
Definition: grouped_convolution_backward_weight_kernel.hpp:265
array< index_t, GroupedConvTraitsType_::NDimSpatial > conv_filter_dilations
Definition: grouped_convolution_backward_weight_kernel.hpp:270
index_t GemmN
Definition: grouped_convolution_backward_weight_kernel.hpp:276
index_t GemmBatch
Definition: grouped_convolution_backward_weight_kernel.hpp:278
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > out_g_n_k_wos_lengths
Definition: grouped_convolution_backward_weight_kernel.hpp:267
remove_cvref_t< decltype(ABCGridDescs{}[number< 0 >{}])> AGridDescMK
Definition: grouped_convolution_backward_weight_kernel.hpp:260
BGridDescNK b_grid_desc_n_k
Definition: grouped_convolution_backward_weight_kernel.hpp:286
CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs &args)
Definition: grouped_convolution_backward_weight_kernel.hpp:37
array< index_t, GroupedConvTraitsType_::NDimSpatial > input_left_pads
Definition: grouped_convolution_backward_weight_kernel.hpp:271
std::array< const void *, NumDTensor > ds_ptr
Definition: grouped_convolution_backward_weight_kernel.hpp:282
index_t GemmM
Definition: grouped_convolution_backward_weight_kernel.hpp:275
remove_cvref_t< decltype(ABCGridDescs{}[number< 2 >{}])> CGridDescMN
Definition: grouped_convolution_backward_weight_kernel.hpp:262
array< index_t, GroupedConvTraitsType_::NDimSpatial > input_right_pads
Definition: grouped_convolution_backward_weight_kernel.hpp:272
index_t GemmK
Definition: grouped_convolution_backward_weight_kernel.hpp:277
const void * in_ptr
Definition: grouped_convolution_backward_weight_kernel.hpp:281
index_t k_batch
Definition: grouped_convolution_backward_weight_kernel.hpp:274
static constexpr index_t NonSpatialDims
Definition: grouped_convolution_backward_weight_kernel.hpp:264
const void * out_ptr
Definition: grouped_convolution_backward_weight_kernel.hpp:280
AGridDescMK a_grid_desc_m_k
Definition: grouped_convolution_backward_weight_kernel.hpp:285
static constexpr index_t NumDTensor
Definition: grouped_convolution_backward_weight_kernel.hpp:27
long_index_t group_stride_c
Definition: grouped_convolution_backward_weight_kernel.hpp:291
The Grouped Conv kernel host arguments.
Definition: grouped_convolution_utils.hpp:19
index_t k_batch
Definition: grouped_convolution_utils.hpp:40
InPtr in_ptr
Definition: grouped_convolution_utils.hpp:36
WeiPtr wei_ptr
Definition: grouped_convolution_utils.hpp:37
OutPtr out_ptr
Definition: grouped_convolution_utils.hpp:39
const std::vector< const void * > ds_ptr
Definition: grouped_convolution_utils.hpp:38
Definition: grouped_convolution_backward_weight_kernel.hpp:409
index_t b_k_split_offset
Definition: grouped_convolution_backward_weight_kernel.hpp:433
index_t splitted_k
Definition: grouped_convolution_backward_weight_kernel.hpp:434
__device__ SplitKBatchOffset(const GroupedConvBwdWeightKernelArgsSpecialized &kargs, const std::size_t k_id=blockIdx.z)
Definition: grouped_convolution_backward_weight_kernel.hpp:410
index_t a_k_split_offset
Definition: grouped_convolution_backward_weight_kernel.hpp:432
The Grouped Convolution Backward Weight kernel template.
Definition: grouped_convolution_backward_weight_kernel.hpp:337
remove_cvref_t< typename EpiloguePipeline::DsLayout > GemmDsLayout
Definition: grouped_convolution_backward_weight_kernel.hpp:353
static constexpr index_t kBlockSize
Definition: grouped_convolution_backward_weight_kernel.hpp:356
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views, const index_t k_batch)
Definition: grouped_convolution_backward_weight_kernel.hpp:619
static constexpr CK_TILE_HOST auto BlockSize()
Definition: grouped_convolution_backward_weight_kernel.hpp:395
remove_cvref_t< typename GroupedConvTraitsType_::OutLayout > OutLayout
Definition: grouped_convolution_backward_weight_kernel.hpp:350
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: grouped_convolution_backward_weight_kernel.hpp:341
remove_cvref_t< typename GemmPipeline::ADataType > InDataType
Definition: grouped_convolution_backward_weight_kernel.hpp:358
remove_cvref_t< typename EpiloguePipeline::ODataType > OutDataType
Definition: grouped_convolution_backward_weight_kernel.hpp:362
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition: grouped_convolution_backward_weight_kernel.hpp:342
GroupedConvBwdWeightKernelArgs< GroupedConvTraitsType_ > GroupedConvBwdWeightKernelArgsSpecialized
Definition: grouped_convolution_backward_weight_kernel.hpp:365
static CK_TILE_HOST const std::string GetName()
Definition: grouped_convolution_backward_weight_kernel.hpp:381
static CK_TILE_DEVICE void RunGemm2LDS(const OutDataType *a_ptr, const InDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, WeiDataType *c_ptr, void *__restrict__ smem_ptr_0, void *__restrict__ smem_ptr_1, const GroupedConvBwdWeightKernelArgsSpecialized &kargs, const index_t num_loop, const index_t block_idx_m, const index_t block_idx_n, const index_t block_idx_k)
Runs single GEMM problem cooperatively by whole workgroup.
Definition: grouped_convolution_backward_weight_kernel.hpp:762
remove_cvref_t< typename GemmPipeline::CLayout > GemmCLayout
Definition: grouped_convolution_backward_weight_kernel.hpp:346
static constexpr auto I2
Definition: grouped_convolution_backward_weight_kernel.hpp:372
static constexpr CK_TILE_HOST GroupedConvBwdWeightKernelArgsSpecialized MakeKernelArgs(const GroupedConvBwdWeightHostArgs &hostArgs)
Definition: grouped_convolution_backward_weight_kernel.hpp:398
remove_cvref_t< typename GemmPipeline::ALayout > GemmALayout
Definition: grouped_convolution_backward_weight_kernel.hpp:344
static CK_TILE_HOST bool IsSupportedArgument(const GroupedConvBwdWeightKernelArgsSpecialized &kargs)
Definition: grouped_convolution_backward_weight_kernel.hpp:451
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: grouped_convolution_backward_weight_kernel.hpp:343
static constexpr ConvolutionSpecialization ConvSpecialization
Definition: grouped_convolution_backward_weight_kernel.hpp:339
remove_cvref_t< typename GroupedConvTraitsType_::WeiLayout > WeiLayout
Definition: grouped_convolution_backward_weight_kernel.hpp:349
static constexpr bool IsSplitKSupported
Definition: grouped_convolution_backward_weight_kernel.hpp:368
static CK_TILE_DEVICE void RunGemm(const OutDataType *a_ptr, const InDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, WeiDataType *c_ptr, void *smem_ptr_0, const GroupedConvBwdWeightKernelArgsSpecialized &kargs, const index_t num_loop, const index_t block_idx_m, const index_t block_idx_n, const index_t block_idx_k)
Runs single GEMM problem cooperatively by whole workgroup.
Definition: grouped_convolution_backward_weight_kernel.hpp:712
static constexpr index_t NDimSpatial
Definition: grouped_convolution_backward_weight_kernel.hpp:338
remove_cvref_t< typename GroupedConvTraitsType_::DsLayout > DsLayout
Definition: grouped_convolution_backward_weight_kernel.hpp:351
remove_cvref_t< typename GroupedConvTraitsType_::InLayout > InLayout
Definition: grouped_convolution_backward_weight_kernel.hpp:348
static CK_TILE_HOST auto Preprocess(const GroupedConvBwdWeightKernelArgsSpecialized &kargs, const stream_config &s)
Definition: grouped_convolution_backward_weight_kernel.hpp:437
static constexpr auto I3
Definition: grouped_convolution_backward_weight_kernel.hpp:373
static constexpr auto I0
Definition: grouped_convolution_backward_weight_kernel.hpp:370
CK_TILE_DEVICE void operator()(GroupedConvBwdWeightKernelArgsSpecialized kargs) const
Definition: grouped_convolution_backward_weight_kernel.hpp:797
static constexpr auto I1
Definition: grouped_convolution_backward_weight_kernel.hpp:371
remove_cvref_t< typename EpiloguePipeline::DsDataType > DsDataType
Definition: grouped_convolution_backward_weight_kernel.hpp:360
remove_cvref_t< typename GemmPipeline::BDataType > WeiDataType
Definition: grouped_convolution_backward_weight_kernel.hpp:359
static constexpr index_t NumDTensor
Definition: grouped_convolution_backward_weight_kernel.hpp:354
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: grouped_convolution_backward_weight_kernel.hpp:403
static CK_TILE_DEVICE auto MakeGemmTensorViews(const OutDataType *a_ptr, const InDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, WeiDataType *c_ptr, const GroupedConvBwdWeightKernelArgsSpecialized &kargs)
Definition: grouped_convolution_backward_weight_kernel.hpp:577
remove_cvref_t< typename GemmPipeline::BLayout > GemmBLayout
Definition: grouped_convolution_backward_weight_kernel.hpp:345
static constexpr CK_TILE_HOST auto GridSize(const GroupedConvBwdWeightKernelArgsSpecialized &kargs)
Definition: grouped_convolution_backward_weight_kernel.hpp:389
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n, const index_t i_k)
Definition: grouped_convolution_backward_weight_kernel.hpp:659
Definition: transform_conv_bwd_weight_to_gemm.hpp:19
CK_TILE_HOST auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N() const
Definition: transform_conv_bwd_weight_to_gemm.hpp:533
Definition: integral_constant.hpp:13
std::vector< ck_tile::long_index_t > input_spatial_lengths_
Definition: convolution_parameter.hpp:130
ck_tile::long_index_t K_
Definition: convolution_parameter.hpp:126
std::vector< ck_tile::long_index_t > output_spatial_lengths_
Definition: convolution_parameter.hpp:131
std::vector< ck_tile::long_index_t > input_right_pads_
Definition: convolution_parameter.hpp:137
ck_tile::long_index_t G_
Definition: convolution_parameter.hpp:124
std::vector< ck_tile::long_index_t > conv_filter_strides_
Definition: convolution_parameter.hpp:133
std::vector< ck_tile::long_index_t > filter_spatial_lengths_
Definition: convolution_parameter.hpp:129
ck_tile::long_index_t C_
Definition: convolution_parameter.hpp:127
ck_tile::long_index_t N_
Definition: convolution_parameter.hpp:125
std::vector< ck_tile::long_index_t > input_left_pads_
Definition: convolution_parameter.hpp:136
std::vector< ck_tile::long_index_t > conv_filter_dilations_
Definition: convolution_parameter.hpp:134
Definition: type_traits.hpp:115
Definition: sequence.hpp:49
Definition: stream_config.hpp:30
hipStream_t stream_id_
Definition: stream_config.hpp:31
#define CK_TILE_ENV(name)
Definition: env.hpp:145