/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp Source File
grouped_convolution_forward_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <string>
8 
9 #include "ck_tile/core.hpp"
10 #include "ck_tile/ops/common.hpp"
11 #include "ck_tile/host/concat.hpp"
16 
17 namespace ck_tile {
18 
20 template <typename GroupedConvTraitsType_>
22 {
23 
25  TransformConvFwdToGemm<GroupedConvTraitsType_::NDimSpatial,
26  GroupedConvTraitsType_::ConvSpecialization>;
27  static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
28 
29  template <
30  typename InLay = typename GroupedConvTraitsType_::InLayout,
31  typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
32  typename OutLay = typename GroupedConvTraitsType_::OutLayout,
33  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NWGC> &&
34  std::is_same_v<WeiLay, tensor_layout::convolution::GKXC> &&
35  std::is_same_v<OutLay, tensor_layout::convolution::NWGK>,
36  bool>::type = false>
38  {
39  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
40  static_cast<index_t>(args.N_),
41  static_cast<index_t>(args.C_),
42  static_cast<index_t>(args.input_spatial_lengths_[0])};
43  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
44  static_cast<index_t>(args.K_),
45  static_cast<index_t>(args.C_),
46  static_cast<index_t>(args.filter_spatial_lengths_[0])};
47  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
48  static_cast<index_t>(args.N_),
49  static_cast<index_t>(args.K_),
50  static_cast<index_t>(args.output_spatial_lengths_[0])};
51 
52  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0])};
53  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0])};
54  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0])};
55  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0])};
56 
57  k_batch = args.k_batch;
58 
59  GemmM = args.N_ * args.output_spatial_lengths_[0];
60  GemmN = args.K_;
61  GemmK = args.C_ * args.filter_spatial_lengths_[0];
62  GemmBatch = args.G_;
63 
64  in_ptr = args.in_ptr;
65  wei_ptr = args.wei_ptr;
66  for(index_t d = 0; d < NumDTensor; d++)
67  {
68  ds_ptr[d] = args.ds_ptr[d];
69  }
70  out_ptr = args.out_ptr;
71 
72  ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
79 
81  conv_to_gemm_transformer
82  .template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
84  conv_to_gemm_transformer
85  .template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
87  conv_to_gemm_transformer
88  .template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
89 
90  group_stride_a = args.C_;
91  group_stride_b = args.K_ * args.C_ *
92  std::accumulate(args.filter_spatial_lengths_.begin(),
93  args.filter_spatial_lengths_.end(),
94  1,
95  std::multiplies<index_t>());
96  group_stride_c = args.K_;
97  }
98 
99  template <
100  typename InLay = typename GroupedConvTraitsType_::InLayout,
101  typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
102  typename OutLay = typename GroupedConvTraitsType_::OutLayout,
103  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NHWGC> &&
104  std::is_same_v<WeiLay, tensor_layout::convolution::GKYXC> &&
105  std::is_same_v<OutLay, tensor_layout::convolution::NHWGK>,
106  bool>::type = false>
108  {
109  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
110  static_cast<index_t>(args.N_),
111  static_cast<index_t>(args.C_),
112  static_cast<index_t>(args.input_spatial_lengths_[0]),
113  static_cast<index_t>(args.input_spatial_lengths_[1])};
114  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
115  static_cast<index_t>(args.K_),
116  static_cast<index_t>(args.C_),
117  static_cast<index_t>(args.filter_spatial_lengths_[0]),
118  static_cast<index_t>(args.filter_spatial_lengths_[1])};
119  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
120  static_cast<index_t>(args.N_),
121  static_cast<index_t>(args.K_),
122  static_cast<index_t>(args.output_spatial_lengths_[0]),
123  static_cast<index_t>(args.output_spatial_lengths_[1])};
124 
125  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
126  static_cast<index_t>(args.conv_filter_strides_[1])};
127  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
128  static_cast<index_t>(args.conv_filter_dilations_[1])};
129  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
130  static_cast<index_t>(args.input_left_pads_[1])};
131  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
132  static_cast<index_t>(args.input_right_pads_[1])};
133 
134  k_batch = args.k_batch;
135 
136  GemmM = args.N_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
137  GemmN = args.K_;
138  GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1];
139  GemmBatch = args.G_;
140 
141  in_ptr = args.in_ptr;
142  wei_ptr = args.wei_ptr;
143  for(index_t d = 0; d < NumDTensor; d++)
144  {
145  ds_ptr[d] = args.ds_ptr[d];
146  }
147  out_ptr = args.out_ptr;
148 
149  ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
156 
158  conv_to_gemm_transformer
159  .template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
161  conv_to_gemm_transformer
162  .template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
164  conv_to_gemm_transformer
165  .template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
166 
167  group_stride_a = args.C_;
168  group_stride_b = args.K_ * args.C_ *
169  std::accumulate(args.filter_spatial_lengths_.begin(),
170  args.filter_spatial_lengths_.end(),
171  1,
172  std::multiplies<index_t>());
173  group_stride_c = args.K_;
174  }
175 
176  template <
177  typename InLay = typename GroupedConvTraitsType_::InLayout,
178  typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
179  typename OutLay = typename GroupedConvTraitsType_::OutLayout,
180  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NDHWGC> &&
181  std::is_same_v<WeiLay, tensor_layout::convolution::GKZYXC> &&
182  std::is_same_v<OutLay, tensor_layout::convolution::NDHWGK>,
183  bool>::type = false>
185  {
186  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
187  static_cast<index_t>(args.N_),
188  static_cast<index_t>(args.C_),
189  static_cast<index_t>(args.input_spatial_lengths_[0]),
190  static_cast<index_t>(args.input_spatial_lengths_[1]),
191  static_cast<index_t>(args.input_spatial_lengths_[2])};
192  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
193  static_cast<index_t>(args.K_),
194  static_cast<index_t>(args.C_),
195  static_cast<index_t>(args.filter_spatial_lengths_[0]),
196  static_cast<index_t>(args.filter_spatial_lengths_[1]),
197  static_cast<index_t>(args.filter_spatial_lengths_[2])};
198  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
199  static_cast<index_t>(args.N_),
200  static_cast<index_t>(args.K_),
201  static_cast<index_t>(args.output_spatial_lengths_[0]),
202  static_cast<index_t>(args.output_spatial_lengths_[1]),
203  static_cast<index_t>(args.output_spatial_lengths_[2])};
204 
205  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
206  static_cast<index_t>(args.conv_filter_strides_[1]),
207  static_cast<index_t>(args.conv_filter_strides_[2])};
208  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
209  static_cast<index_t>(args.conv_filter_dilations_[1]),
210  static_cast<index_t>(args.conv_filter_dilations_[2])};
211  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
212  static_cast<index_t>(args.input_left_pads_[1]),
213  static_cast<index_t>(args.input_left_pads_[2])};
214  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
215  static_cast<index_t>(args.input_right_pads_[1]),
216  static_cast<index_t>(args.input_right_pads_[2])};
217 
218  k_batch = args.k_batch;
219 
220  GemmM = args.N_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1] *
221  args.output_spatial_lengths_[2];
222  GemmN = args.K_;
223  GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1] *
224  args.filter_spatial_lengths_[2];
225  GemmBatch = args.G_;
226 
227  in_ptr = args.in_ptr;
228  wei_ptr = args.wei_ptr;
229  for(index_t d = 0; d < NumDTensor; d++)
230  {
231  ds_ptr[d] = args.ds_ptr[d];
232  }
233  out_ptr = args.out_ptr;
234 
235  ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
242 
244  conv_to_gemm_transformer
245  .template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
247  conv_to_gemm_transformer
248  .template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
250  conv_to_gemm_transformer
251  .template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
252 
253  group_stride_a = args.C_;
254  group_stride_b = args.K_ * args.C_ *
255  std::accumulate(args.filter_spatial_lengths_.begin(),
256  args.filter_spatial_lengths_.end(),
257  1,
258  std::multiplies<index_t>());
259  group_stride_c = args.K_;
260  }
261 
263  decltype(ConvToGemmFwdTransformer{}
264  .template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>())>;
266  decltype(ConvToGemmFwdTransformer{}
267  .template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>())>;
269  decltype(ConvToGemmFwdTransformer{}
270  .template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>())>;
271 
272  static constexpr index_t NonSpatialDims = 3;
276 
281 
287 
288  const void* in_ptr;
289  const void* wei_ptr;
290  std::array<const void*, NumDTensor> ds_ptr;
291  void* out_ptr;
292 
296 
300 };
301 
340 template <typename GroupedConvTraitsType_,
341  typename TilePartitioner_,
342  typename GemmPipeline_,
343  typename EpiloguePipeline_>
345 {
346  static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial;
348  GroupedConvTraitsType_::ConvSpecialization;
355 
360 
362  static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
363 
364  static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
365 
369 // The type below is actually the accumulation data type - the output of the block GEMM.
371 
373 
374 // TODO: Enable split-K support; IsSplitKSupported stays false until then.
375  static constexpr bool IsSplitKSupported = false;
376 
377  static constexpr auto I0 = number<0>();
378  static constexpr auto I1 = number<1>();
379  static constexpr auto I2 = number<2>();
380  static constexpr auto I3 = number<3>();
381 
382  static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
383  "Not supported!");
384  static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>, "Not supported!");
385  static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
386  static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
387 
388  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
389  {
390  // clang-format off
391  return concat('_', "grouped_convolution_forward", gemm_prec_str<InDataType, WeiDataType>, GemmPipeline::GetName());
392  // clang-format on
393  }
394 
395  CK_TILE_HOST static constexpr auto GridSize(const GroupedConvFwdKernelArgsSpecialized& kargs)
396  {
397  return dim3(
398  TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch);
399  }
400 
401  CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
402 
405  {
406  return GroupedConvFwdKernelArgsSpecialized(hostArgs);
407  }
408 
410  {
411  return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
412  }
413 
415  {
416  if constexpr((EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
419  {
420  if(kargs.k_batch != 1)
421  {
422  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
423  {
424  CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
425  }
426  return false;
427  }
428  }
429 
430  const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
431  const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];
432 
433  // check ConvolutionSpecialization
435  {
436  // check if it's 1x1, stride=1 conv
437  for(index_t i = 0; i < NDimSpatial; ++i)
438  {
439  const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
440  const index_t ConvStride = kargs.conv_filter_strides[i];
441  const index_t LeftPad = kargs.input_left_pads[i];
442  const index_t RightPad = kargs.input_right_pads[i];
443 
444  if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
445  {
446  return false;
447  }
448  }
449  }
451  {
452  // check if it's 1x1 conv
453  for(index_t i = 0; i < NDimSpatial; ++i)
454  {
455  const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
456  const index_t LeftPad = kargs.input_left_pads[i];
457  const index_t RightPad = kargs.input_right_pads[i];
458 
459  if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0))
460  {
461  return false;
462  }
463  }
464  }
466  {
467  if(ConvC != 1)
468  {
469  return false;
470  }
471  for(index_t i = 0; i < NDimSpatial; ++i)
472  {
473  const index_t filter_spatial_dim = kargs.wei_g_k_c_xs_lengths[i + I3];
474 
475  if(filter_spatial_dim != I3)
476  {
477  return false;
478  }
479  }
480  }
481 
482  namespace ctc = tensor_layout::convolution;
483 
484  if constexpr(std::is_same_v<InLayout, ctc::NWGC> || std::is_same_v<InLayout, ctc::NHWGC> ||
485  std::is_same_v<InLayout, ctc::NDHWGC>)
486  {
487  // Check access per C
488  if(ConvC % GemmPipeline::GetVectorSizeA() != 0)
489  {
490  CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!");
491  return false;
492  }
493  }
494  else
495  {
496  CK_TILE_ERROR("Not supported input layout!");
497  return false;
498  }
499 
500  // check vector access of B
501  // FIXME: layout
502  if constexpr(std::is_same_v<WeiLayout, ctc::GKXC> ||
503  std::is_same_v<WeiLayout, ctc::GKYXC> ||
504  std::is_same_v<WeiLayout, ctc::GKZYXC>)
505  {
506  if(ConvC % GemmPipeline::GetVectorSizeB() != 0)
507  {
508  CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!");
509  return false;
510  }
511  }
512  else
513  {
514  CK_TILE_ERROR("Not supported weight layout!");
515  return false;
516  }
517 
518  // check vector access of E
519  if constexpr(std::is_same_v<OutLayout, ctc::NWGK> ||
520  std::is_same_v<OutLayout, ctc::NHWGK> ||
521  std::is_same_v<OutLayout, ctc::NDHWGK>)
522  {
523  if(ConvK % EpiloguePipeline::GetVectorSizeC() != 0)
524  {
525  CK_TILE_ERROR("Conv K is not a multiple of vector store size for output image!");
526  return false;
527  }
528  }
529  else
530  {
531  CK_TILE_ERROR("Not supported output layout!");
532  return false;
533  }
534 
535  return true;
536  }
537 
538  template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
539  CK_TILE_DEVICE static auto
541  const WeiDataType* b_ptr,
542  const std::array<const void*, NumDTensor>& ds_ptr,
543  OutDataType* c_ptr,
545  {
546  static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
547  static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!");
548  const auto& a_tensor_view = [&]() {
549  return make_tensor_view<address_space_enum::global>(a_ptr, kargs.a_grid_desc_m_k);
550  }();
551 
552  const auto& b_tensor_view = [&]() {
553  return make_tensor_view<address_space_enum::global>(b_ptr, kargs.b_grid_desc_n_k);
554  }();
555 
556  // TODO: enable vector write for C in ColMajor
557  const auto& c_tensor_view = [&]() {
558  return make_tensor_view<address_space_enum::global>(c_ptr, kargs.c_grid_desc_m_n);
559  }();
560 
561  const auto& ds_tensor_view = generate_tuple(
562  [&](auto i) {
563  static_assert(std::is_same_v<std::tuple_element_t<i, DsLayout>, OutLayout>,
564  "Not supported!");
565  static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
566  "Not supported!");
567  static_assert(std::is_same_v<std::tuple_element_t<i, DsDataType>, OutDataType>,
568  "Not supported!");
569 
570  return make_tensor_view<address_space_enum::global>(
571  static_cast<OutDataType*>(ds_ptr[i]), kargs.c_grid_desc_m_n);
572  },
574 
575  return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view);
576  }
577 
578  template <typename TensorView>
579  CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
580  {
581  const auto& a_pad_view = [&]() {
582  const auto& a_tensor_view = views.at(I0);
583  return pad_tensor_view(a_tensor_view,
587  }();
588 
589  const auto& b_pad_view = [&]() {
590  const auto& b_tensor_view = views.at(I1);
591  return pad_tensor_view(b_tensor_view,
595  }();
596 
597  const auto& ds_tensor_view = views.at(I2);
598  const auto& ds_pad_view = generate_tuple(
599  [&](auto i) {
600  return pad_tensor_view(ds_tensor_view[i],
604  },
606 
607  const auto& c_pad_view = [&]() {
608  const auto& c_tensor_view = views.at(I3);
609  return pad_tensor_view(c_tensor_view,
613  }();
614 
615  return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view);
616  }
617 
618  template <typename PadView>
619  CK_TILE_DEVICE static auto
620  MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
621  {
622  const auto& a_pad_view = views.at(I0);
623  const auto& b_pad_view = views.at(I1);
624  const auto& ds_pad_view = views.at(I2);
625  const auto& c_pad_view = views.at(I3);
626 
627  const auto& a_block_window = [&]() {
628  return make_tile_window(a_pad_view,
631  {i_m, 0});
632  }();
633 
634  const auto& b_block_window = [&]() {
635  return make_tile_window(b_pad_view,
638  {i_n, 0});
639  }();
640 
641  const auto ds_block_window = generate_tuple(
642  [&](auto i) {
643  return make_tile_window(ds_pad_view[i],
646  {i_m, i_n});
647  },
649 
650  auto c_block_window = make_tile_window(
651  c_pad_view,
653  {i_m, i_n});
654 
655  return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window);
656  }
657 
670  CK_TILE_DEVICE static void RunGemm(const InDataType* a_ptr,
671  const WeiDataType* b_ptr,
672  const std::array<const void*, NumDTensor>& ds_ptr,
673  OutDataType* c_ptr,
674  void* smem_ptr_0,
676  const index_t block_idx_m,
677  const index_t block_idx_n)
678  {
679  // Create Gemm tensor views, pad views and tile windows
680  const auto& gemm_tensor_views_tuple =
681  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
682  a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
683 
684  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
685  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
686 
687  const index_t num_loop =
688  __builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(kargs.GemmK));
689 
690  // Run GEMM cooperatively by whole workgroup.
691  const auto& a_block_window = gemm_tile_windows.at(I0);
692  const auto& b_block_window = gemm_tile_windows.at(I1);
693  const auto& d_block_window = gemm_tile_windows.at(I2);
694 
695  const auto& c_block_tile = GemmPipeline{}.template operator()(
696  a_block_window, b_block_window, num_loop, smem_ptr_0);
697 
698  // Run Epilogue Pipeline
699  auto& c_block_window = gemm_tile_windows.at(I3);
700 
701  EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
702  c_block_window, c_block_tile, d_block_window, smem_ptr_0);
703  }
704 
720  CK_TILE_DEVICE static void RunGemm2LDS(const InDataType* a_ptr,
721  const WeiDataType* b_ptr,
722  const std::array<const void*, NumDTensor>& ds_ptr,
723  OutDataType* c_ptr,
724  void* __restrict__ smem_ptr_0,
725  void* __restrict__ smem_ptr_1,
727  const index_t block_idx_m,
728  const index_t block_idx_n)
729  {
730  // Create Gemm tensor views, pad views and tile windows
731  const auto& gemm_tensor_views_tuple =
732  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
733  a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
734  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
735  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
736 
737  const index_t num_loop =
738  __builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(kargs.GemmK));
739 
740  // Run GEMM cooperatively by whole workgroup.
741  const auto& a_block_window = gemm_tile_windows.at(I0);
742  const auto& b_block_window = gemm_tile_windows.at(I1);
743  const auto& d_block_window = gemm_tile_windows.at(I2);
744 
745  const auto& c_block_tile = GemmPipeline{}.template operator()(
746  a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
747 
748  // Run Epilogue Pipeline
749  auto& c_block_window = gemm_tile_windows.at(I3);
750 
751  EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
752  c_block_window, c_block_tile, d_block_window, smem_ptr_0);
753  }
754 
756  {
757  const auto blockIdX = __builtin_amdgcn_readfirstlane(blockIdx.x);
758  const auto [iM, iN] =
759  TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(blockIdX);
760  const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
761  const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
762 
763  const auto blockIdY = __builtin_amdgcn_readfirstlane(blockIdx.y);
764  const auto group_offset_a = __builtin_amdgcn_readfirstlane(kargs.group_stride_a * blockIdY);
765  const auto group_offset_b = __builtin_amdgcn_readfirstlane(kargs.group_stride_b * blockIdY);
766  const auto group_offset_c = __builtin_amdgcn_readfirstlane(kargs.group_stride_c * blockIdY);
767 
768 // Per-group base pointers: each operand pointer is advanced by its group offset (group index = blockIdx.y)
769  const InDataType* a_ptr = static_cast<const InDataType*>(kargs.in_ptr) + group_offset_a;
770  const WeiDataType* b_ptr = static_cast<const WeiDataType*>(kargs.wei_ptr) + group_offset_b;
771  OutDataType* c_ptr = static_cast<OutDataType*>(kargs.out_ptr) + group_offset_c;
772 
773  // allocate LDS
774  __shared__ char smem_ptr_0[GetSmemSize()];
775 
776  if constexpr(GemmPipeline::DoubleSmemBuffer == true)
777  {
778  __shared__ char smem_ptr_1[GetSmemSize()];
779  if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
780  EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
782  {
783  RunGemm2LDS(
784  a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, smem_ptr_1, kargs, i_m, i_n);
785  }
786  }
787  else
788  {
789  if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
790  EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
792  {
793  RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, i_m, i_n);
794  }
795  }
796  }
797 };
798 
799 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:156
void CK_TILE_ERROR(Args &&... args) noexcept
Definition: env.hpp:12
ConvolutionSpecialization
Definition: convolution_specialization.hpp:11
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
int64_t long_index_t
Definition: integer.hpp:11
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
constexpr bool is_same_v
Definition: type.hpp:283
__device__ X atomic_add(X *p_dst, const X &x)
The Grouped Convolution kernel device arguments.
Definition: grouped_convolution_forward_kernel.hpp:22
index_t GemmM
Definition: grouped_convolution_forward_kernel.hpp:283
array< index_t, GroupedConvTraitsType_::NDimSpatial > conv_filter_strides
Definition: grouped_convolution_forward_kernel.hpp:277
array< index_t, GroupedConvTraitsType_::NDimSpatial > input_left_pads
Definition: grouped_convolution_forward_kernel.hpp:279
long_index_t group_stride_c
Definition: grouped_convolution_forward_kernel.hpp:299
BGridDescNK b_grid_desc_n_k
Definition: grouped_convolution_forward_kernel.hpp:294
static constexpr index_t NonSpatialDims
Definition: grouped_convolution_forward_kernel.hpp:272
static constexpr index_t NumDTensor
Definition: grouped_convolution_forward_kernel.hpp:27
remove_cvref_t< decltype(ConvToGemmFwdTransformer{} .template MakeADescriptor_M_K< typename GroupedConvTraitsType_::InLayout >())> AGridDescMK
Definition: grouped_convolution_forward_kernel.hpp:264
array< index_t, GroupedConvTraitsType_::NDimSpatial > conv_filter_dilations
Definition: grouped_convolution_forward_kernel.hpp:278
remove_cvref_t< decltype(ConvToGemmFwdTransformer{} .template MakeBDescriptor_N_K< typename GroupedConvTraitsType_::WeiLayout >())> BGridDescNK
Definition: grouped_convolution_forward_kernel.hpp:267
const void * in_ptr
Definition: grouped_convolution_forward_kernel.hpp:288
long_index_t group_stride_b
Definition: grouped_convolution_forward_kernel.hpp:298
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > in_g_n_c_wis_lengths
Definition: grouped_convolution_forward_kernel.hpp:273
const void * wei_ptr
Definition: grouped_convolution_forward_kernel.hpp:289
index_t GemmK
Definition: grouped_convolution_forward_kernel.hpp:285
index_t GemmN
Definition: grouped_convolution_forward_kernel.hpp:284
remove_cvref_t< decltype(ConvToGemmFwdTransformer{} .template MakeCDescriptor_M_N< typename GroupedConvTraitsType_::OutLayout >())> CGridDescMN
Definition: grouped_convolution_forward_kernel.hpp:270
index_t k_batch
Definition: grouped_convolution_forward_kernel.hpp:282
CGridDescMN c_grid_desc_m_n
Definition: grouped_convolution_forward_kernel.hpp:295
std::array< const void *, NumDTensor > ds_ptr
Definition: grouped_convolution_forward_kernel.hpp:290
void * out_ptr
Definition: grouped_convolution_forward_kernel.hpp:291
AGridDescMK a_grid_desc_m_k
Definition: grouped_convolution_forward_kernel.hpp:293
long_index_t group_stride_a
Definition: grouped_convolution_forward_kernel.hpp:297
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > out_g_n_k_wos_lengths
Definition: grouped_convolution_forward_kernel.hpp:275
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs &args)
Definition: grouped_convolution_forward_kernel.hpp:37
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > wei_g_k_c_xs_lengths
Definition: grouped_convolution_forward_kernel.hpp:274
index_t GemmBatch
Definition: grouped_convolution_forward_kernel.hpp:286
array< index_t, GroupedConvTraitsType_::NDimSpatial > input_right_pads
Definition: grouped_convolution_forward_kernel.hpp:280
The Grouped Conv kernel host arguments.
Definition: grouped_convolution_utils.hpp:19
index_t k_batch
Definition: grouped_convolution_utils.hpp:40
InPtr in_ptr
Definition: grouped_convolution_utils.hpp:36
WeiPtr wei_ptr
Definition: grouped_convolution_utils.hpp:37
OutPtr out_ptr
Definition: grouped_convolution_utils.hpp:39
const std::vector< const void * > ds_ptr
Definition: grouped_convolution_utils.hpp:38
The Grouped Convolution Forward kernel template.
Definition: grouped_convolution_forward_kernel.hpp:345
remove_cvref_t< typename EpiloguePipeline::DsLayout > GemmDsLayout
Definition: grouped_convolution_forward_kernel.hpp:361
static constexpr CK_TILE_HOST auto GridSize(const GroupedConvFwdKernelArgsSpecialized &kargs)
Definition: grouped_convolution_forward_kernel.hpp:395
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition: grouped_convolution_forward_kernel.hpp:350
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: grouped_convolution_forward_kernel.hpp:349
static constexpr auto I1
Definition: grouped_convolution_forward_kernel.hpp:378
static constexpr auto I2
Definition: grouped_convolution_forward_kernel.hpp:379
remove_cvref_t< typename GroupedConvTraitsType_::OutLayout > OutLayout
Definition: grouped_convolution_forward_kernel.hpp:358
static constexpr CK_TILE_HOST GroupedConvFwdKernelArgsSpecialized MakeKernelArgs(const GroupedConvFwdHostArgs &hostArgs)
Definition: grouped_convolution_forward_kernel.hpp:404
static constexpr auto I0
Definition: grouped_convolution_forward_kernel.hpp:377
CK_TILE_DEVICE void operator()(GroupedConvFwdKernelArgsSpecialized kargs) const
Definition: grouped_convolution_forward_kernel.hpp:755
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: grouped_convolution_forward_kernel.hpp:409
remove_cvref_t< typename GroupedConvTraitsType_::WeiLayout > WeiLayout
Definition: grouped_convolution_forward_kernel.hpp:357
remove_cvref_t< typename EpiloguePipeline::ODataType > OutDataType
Definition: grouped_convolution_forward_kernel.hpp:370
static CK_TILE_DEVICE auto MakeGemmTensorViews(const InDataType *a_ptr, const WeiDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, OutDataType *c_ptr, const GroupedConvFwdKernelArgsSpecialized &kargs)
Definition: grouped_convolution_forward_kernel.hpp:540
GroupedConvFwdKernelArgs< GroupedConvTraitsType_ > GroupedConvFwdKernelArgsSpecialized
Definition: grouped_convolution_forward_kernel.hpp:372
remove_cvref_t< typename GroupedConvTraitsType_::DsLayout > DsLayout
Definition: grouped_convolution_forward_kernel.hpp:359
static constexpr index_t kBlockSize
Definition: grouped_convolution_forward_kernel.hpp:364
remove_cvref_t< typename EpiloguePipeline::DsDataType > DsDataType
Definition: grouped_convolution_forward_kernel.hpp:368
static CK_TILE_DEVICE void RunGemm2LDS(const InDataType *a_ptr, const WeiDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, OutDataType *c_ptr, void *__restrict__ smem_ptr_0, void *__restrict__ smem_ptr_1, const GroupedConvFwdKernelArgsSpecialized &kargs, const index_t block_idx_m, const index_t block_idx_n)
Runs single GEMM problem cooperatively by whole workgroup.
Definition: grouped_convolution_forward_kernel.hpp:720
remove_cvref_t< typename GemmPipeline::BLayout > GemmBLayout
Definition: grouped_convolution_forward_kernel.hpp:353
static constexpr index_t NDimSpatial
Definition: grouped_convolution_forward_kernel.hpp:346
static constexpr auto I3
Definition: grouped_convolution_forward_kernel.hpp:380
static CK_TILE_HOST const std::string GetName()
Definition: grouped_convolution_forward_kernel.hpp:388
static CK_TILE_HOST bool IsSupportedArgument(const GroupedConvFwdKernelArgsSpecialized &kargs)
Definition: grouped_convolution_forward_kernel.hpp:414
remove_cvref_t< typename GemmPipeline::BDataType > WeiDataType
Definition: grouped_convolution_forward_kernel.hpp:367
static constexpr index_t NumDTensor
Definition: grouped_convolution_forward_kernel.hpp:362
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n)
Definition: grouped_convolution_forward_kernel.hpp:620
remove_cvref_t< typename GemmPipeline::ALayout > GemmALayout
Definition: grouped_convolution_forward_kernel.hpp:352
static constexpr bool IsSplitKSupported
Definition: grouped_convolution_forward_kernel.hpp:375
remove_cvref_t< typename GemmPipeline::CLayout > GemmCLayout
Definition: grouped_convolution_forward_kernel.hpp:354
static CK_TILE_DEVICE void RunGemm(const InDataType *a_ptr, const WeiDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, OutDataType *c_ptr, void *smem_ptr_0, const GroupedConvFwdKernelArgsSpecialized &kargs, const index_t block_idx_m, const index_t block_idx_n)
Runs single GEMM problem cooperatively by whole workgroup.
Definition: grouped_convolution_forward_kernel.hpp:670
static constexpr CK_TILE_HOST auto BlockSize()
Definition: grouped_convolution_forward_kernel.hpp:401
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views)
Definition: grouped_convolution_forward_kernel.hpp:579
remove_cvref_t< typename GroupedConvTraitsType_::InLayout > InLayout
Definition: grouped_convolution_forward_kernel.hpp:356
static constexpr ConvolutionSpecialization ConvSpecialization
Definition: grouped_convolution_forward_kernel.hpp:347
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: grouped_convolution_forward_kernel.hpp:351
remove_cvref_t< typename GemmPipeline::ADataType > InDataType
Definition: grouped_convolution_forward_kernel.hpp:366
Definition: transform_conv_fwd_to_gemm.hpp:19
Definition: integral_constant.hpp:13
std::vector< ck_tile::long_index_t > input_spatial_lengths_
Definition: convolution_parameter.hpp:130
ck_tile::long_index_t K_
Definition: convolution_parameter.hpp:126
std::vector< ck_tile::long_index_t > output_spatial_lengths_
Definition: convolution_parameter.hpp:131
std::vector< ck_tile::long_index_t > input_right_pads_
Definition: convolution_parameter.hpp:137
ck_tile::long_index_t G_
Definition: convolution_parameter.hpp:124
std::vector< ck_tile::long_index_t > conv_filter_strides_
Definition: convolution_parameter.hpp:133
std::vector< ck_tile::long_index_t > filter_spatial_lengths_
Definition: convolution_parameter.hpp:129
ck_tile::long_index_t C_
Definition: convolution_parameter.hpp:127
ck_tile::long_index_t N_
Definition: convolution_parameter.hpp:125
std::vector< ck_tile::long_index_t > input_left_pads_
Definition: convolution_parameter.hpp:136
std::vector< ck_tile::long_index_t > conv_filter_dilations_
Definition: convolution_parameter.hpp:134
Definition: type_traits.hpp:115
Definition: sequence.hpp:49
#define CK_TILE_ENV(name)
Definition: env.hpp:145