Composable Kernel: include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <functional>
#include <iostream>
#include <numeric>
#include <string>

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"
#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp"
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp"

namespace ck_tile {

/// @brief The Grouped Convolution kernel device arguments.
template <typename GroupedConvTraitsType_>
struct GroupedConvFwdKernelArgs
{
    using ConvToGemmFwdTransformer =
        TransformConvFwdToGemm<GroupedConvTraitsType_::NDimSpatial,
                               GroupedConvTraitsType_::ConvSpecialization,
                               GroupedConvTraitsType_::VectorSizeA,
                               GroupedConvTraitsType_::VectorSizeB,
                               GroupedConvTraitsType_::VectorSizeC,
                               true>; // Split N enabled
    static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;

    template <
        typename InLay  = typename GroupedConvTraitsType_::InLayout,
        typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
        typename OutLay = typename GroupedConvTraitsType_::OutLayout,
        typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NWGC> &&
                                    std::is_same_v<WeiLay, tensor_layout::convolution::GKXC> &&
                                    std::is_same_v<OutLay, tensor_layout::convolution::NWGK>,
                                bool>::type = false>
    CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args)
    {
        in_g_n_c_wis_lengths  = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.N_),
                                 static_cast<index_t>(args.C_),
                                 static_cast<index_t>(args.input_spatial_lengths_[0])};
        wei_g_k_c_xs_lengths  = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.K_),
                                 static_cast<index_t>(args.C_),
                                 static_cast<index_t>(args.filter_spatial_lengths_[0])};
        out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.N_),
                                 static_cast<index_t>(args.K_),
                                 static_cast<index_t>(args.output_spatial_lengths_[0])};

        conv_filter_strides   = {static_cast<index_t>(args.conv_filter_strides_[0])};
        conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0])};
        input_left_pads       = {static_cast<index_t>(args.input_left_pads_[0])};
        input_right_pads      = {static_cast<index_t>(args.input_right_pads_[0])};

        k_batch = args.k_batch;

        // GemmM will be set after Split-N calculation
        GemmN     = args.K_;
        GemmK     = args.C_ * args.filter_spatial_lengths_[0];
        GemmBatch = args.G_;

        in_ptr  = args.in_ptr;
        wei_ptr = args.wei_ptr;
        for(index_t d = 0; d < NumDTensor; d++)
        {
            ds_ptr[d] = args.ds_ptr[d];
        }
        out_ptr = args.out_ptr;

        ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
                                                          wei_g_k_c_xs_lengths,
                                                          out_g_n_k_wos_lengths,
                                                          conv_filter_strides,
                                                          conv_filter_dilations,
                                                          input_left_pads,
                                                          input_right_pads};

        a_grid_desc_m_k =
            conv_to_gemm_transformer
                .template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
        b_grid_desc_n_k =
            conv_to_gemm_transformer
                .template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
        c_grid_desc_m_n =
            conv_to_gemm_transformer
                .template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();

        group_stride_a = args.C_;
        group_stride_b = args.K_ * args.C_ *
                         std::accumulate(args.filter_spatial_lengths_.begin(),
                                         args.filter_spatial_lengths_.end(),
                                         1,
                                         std::multiplies<index_t>());
        group_stride_c = args.K_;

        // Initialize Split-N support fields for 1D convolution (NWGC layout)
        // Get the actual split N from transformer
        n_per_split = conv_to_gemm_transformer.GetN();
        original_n  = conv_to_gemm_transformer.GetOriginalN();
        n_splits    = integer_divide_ceil(original_n, n_per_split);

        // Calculate batch strides for NWGC layout
        input_batch_stride  = args.C_ * args.input_spatial_lengths_[0];
        output_batch_stride = args.K_ * args.output_spatial_lengths_[0];

        // Update GemmM to use split N (not original N)
        GemmM = n_per_split * args.output_spatial_lengths_[0];
    }

    template <
        typename InLay  = typename GroupedConvTraitsType_::InLayout,
        typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
        typename OutLay = typename GroupedConvTraitsType_::OutLayout,
        typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NHWGC> &&
                                    std::is_same_v<WeiLay, tensor_layout::convolution::GKYXC> &&
                                    std::is_same_v<OutLay, tensor_layout::convolution::NHWGK>,
                                bool>::type = false>
    CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args)
    {
        in_g_n_c_wis_lengths  = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.N_),
                                 static_cast<index_t>(args.C_),
                                 static_cast<index_t>(args.input_spatial_lengths_[0]),
                                 static_cast<index_t>(args.input_spatial_lengths_[1])};
        wei_g_k_c_xs_lengths  = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.K_),
                                 static_cast<index_t>(args.C_),
                                 static_cast<index_t>(args.filter_spatial_lengths_[0]),
                                 static_cast<index_t>(args.filter_spatial_lengths_[1])};
        out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.N_),
                                 static_cast<index_t>(args.K_),
                                 static_cast<index_t>(args.output_spatial_lengths_[0]),
                                 static_cast<index_t>(args.output_spatial_lengths_[1])};

        conv_filter_strides   = {static_cast<index_t>(args.conv_filter_strides_[0]),
                                 static_cast<index_t>(args.conv_filter_strides_[1])};
        conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
                                 static_cast<index_t>(args.conv_filter_dilations_[1])};
        input_left_pads       = {static_cast<index_t>(args.input_left_pads_[0]),
                                 static_cast<index_t>(args.input_left_pads_[1])};
        input_right_pads      = {static_cast<index_t>(args.input_right_pads_[0]),
                                 static_cast<index_t>(args.input_right_pads_[1])};

        k_batch = args.k_batch;

        // Note: GemmM will be set after Split-N calculation
        GemmN     = args.K_;
        GemmK     = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1];
        GemmBatch = args.G_;

        in_ptr  = args.in_ptr;
        wei_ptr = args.wei_ptr;
        for(index_t d = 0; d < NumDTensor; d++)
        {
            ds_ptr[d] = args.ds_ptr[d];
        }
        out_ptr = args.out_ptr;

        ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
                                                          wei_g_k_c_xs_lengths,
                                                          out_g_n_k_wos_lengths,
                                                          conv_filter_strides,
                                                          conv_filter_dilations,
                                                          input_left_pads,
                                                          input_right_pads};

        a_grid_desc_m_k =
            conv_to_gemm_transformer
                .template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
        b_grid_desc_n_k =
            conv_to_gemm_transformer
                .template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
        c_grid_desc_m_n =
            conv_to_gemm_transformer
                .template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();

        group_stride_a = args.C_;
        group_stride_b = args.K_ * args.C_ *
                         std::accumulate(args.filter_spatial_lengths_.begin(),
                                         args.filter_spatial_lengths_.end(),
                                         1,
                                         std::multiplies<index_t>());
        group_stride_c = args.K_;

        // Initialize Split-N support fields for 2D convolution (NHWGC layout)
        // Get the actual split N from transformer
        n_per_split = conv_to_gemm_transformer.GetN();
        original_n  = conv_to_gemm_transformer.GetOriginalN();
        n_splits    = integer_divide_ceil(original_n, n_per_split);

        // Calculate batch strides for NHWGC layout
        input_batch_stride =
            args.C_ * args.input_spatial_lengths_[0] * args.input_spatial_lengths_[1];
        output_batch_stride =
            args.K_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];

        // Update GemmM to use split N (not original N)
        GemmM = n_per_split * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
    }

    template <
        typename InLay  = typename GroupedConvTraitsType_::InLayout,
        typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
        typename OutLay = typename GroupedConvTraitsType_::OutLayout,
        typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NDHWGC> &&
                                    std::is_same_v<WeiLay, tensor_layout::convolution::GKZYXC> &&
                                    std::is_same_v<OutLay, tensor_layout::convolution::NDHWGK>,
                                bool>::type = false>
    CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args)
    {
        in_g_n_c_wis_lengths  = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.N_),
                                 static_cast<index_t>(args.C_),
                                 static_cast<index_t>(args.input_spatial_lengths_[0]),
                                 static_cast<index_t>(args.input_spatial_lengths_[1]),
                                 static_cast<index_t>(args.input_spatial_lengths_[2])};
        wei_g_k_c_xs_lengths  = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.K_),
                                 static_cast<index_t>(args.C_),
                                 static_cast<index_t>(args.filter_spatial_lengths_[0]),
                                 static_cast<index_t>(args.filter_spatial_lengths_[1]),
                                 static_cast<index_t>(args.filter_spatial_lengths_[2])};
        out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.N_),
                                 static_cast<index_t>(args.K_),
                                 static_cast<index_t>(args.output_spatial_lengths_[0]),
                                 static_cast<index_t>(args.output_spatial_lengths_[1]),
                                 static_cast<index_t>(args.output_spatial_lengths_[2])};

        conv_filter_strides   = {static_cast<index_t>(args.conv_filter_strides_[0]),
                                 static_cast<index_t>(args.conv_filter_strides_[1]),
                                 static_cast<index_t>(args.conv_filter_strides_[2])};
        conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
                                 static_cast<index_t>(args.conv_filter_dilations_[1]),
                                 static_cast<index_t>(args.conv_filter_dilations_[2])};
        input_left_pads       = {static_cast<index_t>(args.input_left_pads_[0]),
                                 static_cast<index_t>(args.input_left_pads_[1]),
                                 static_cast<index_t>(args.input_left_pads_[2])};
        input_right_pads      = {static_cast<index_t>(args.input_right_pads_[0]),
                                 static_cast<index_t>(args.input_right_pads_[1]),
                                 static_cast<index_t>(args.input_right_pads_[2])};

        k_batch = args.k_batch;

        // Note: GemmM will be set after Split-N calculation
        GemmN = args.K_;
        GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1] *
                args.filter_spatial_lengths_[2];
        GemmBatch = args.G_;

        in_ptr  = args.in_ptr;
        wei_ptr = args.wei_ptr;
        for(index_t d = 0; d < NumDTensor; d++)
        {
            ds_ptr[d] = args.ds_ptr[d];
        }
        out_ptr = args.out_ptr;

        ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
                                                          wei_g_k_c_xs_lengths,
                                                          out_g_n_k_wos_lengths,
                                                          conv_filter_strides,
                                                          conv_filter_dilations,
                                                          input_left_pads,
                                                          input_right_pads};

        a_grid_desc_m_k =
            conv_to_gemm_transformer
                .template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
        b_grid_desc_n_k =
            conv_to_gemm_transformer
                .template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
        c_grid_desc_m_n =
            conv_to_gemm_transformer
                .template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();

        group_stride_a = args.C_;
        group_stride_b = args.K_ * args.C_ *
                         std::accumulate(args.filter_spatial_lengths_.begin(),
                                         args.filter_spatial_lengths_.end(),
                                         1,
                                         std::multiplies<index_t>());
        group_stride_c = args.K_;

        // Initialize Split-N support fields for 3D convolution (NDHWGC layout)
        // Get the actual split N from transformer
        n_per_split = conv_to_gemm_transformer.GetN();
        original_n  = conv_to_gemm_transformer.GetOriginalN();
        n_splits    = integer_divide_ceil(original_n, n_per_split);

        // Calculate batch strides for NDHWGC layout
        input_batch_stride = args.C_ * args.input_spatial_lengths_[0] *
                             args.input_spatial_lengths_[1] * args.input_spatial_lengths_[2];
        output_batch_stride = args.K_ * args.output_spatial_lengths_[0] *
                              args.output_spatial_lengths_[1] * args.output_spatial_lengths_[2];

        // Update GemmM to use split N (not original N)
        GemmM = n_per_split * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1] *
                args.output_spatial_lengths_[2];
    }

    using AGridDescMK = remove_cvref_t<
        decltype(ConvToGemmFwdTransformer{}
                     .template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>())>;
    using BGridDescNK = remove_cvref_t<
        decltype(ConvToGemmFwdTransformer{}
                     .template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>())>;
    using CGridDescMN = remove_cvref_t<
        decltype(ConvToGemmFwdTransformer{}
                     .template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>())>;

    static constexpr index_t NonSpatialDims = 3;
    array<index_t, NonSpatialDims + GroupedConvTraitsType_::NDimSpatial> in_g_n_c_wis_lengths;
    array<index_t, NonSpatialDims + GroupedConvTraitsType_::NDimSpatial> wei_g_k_c_xs_lengths;
    array<index_t, NonSpatialDims + GroupedConvTraitsType_::NDimSpatial> out_g_n_k_wos_lengths;

    array<index_t, GroupedConvTraitsType_::NDimSpatial> conv_filter_strides;
    array<index_t, GroupedConvTraitsType_::NDimSpatial> conv_filter_dilations;
    array<index_t, GroupedConvTraitsType_::NDimSpatial> input_left_pads;
    array<index_t, GroupedConvTraitsType_::NDimSpatial> input_right_pads;

    index_t k_batch;
    index_t GemmM;
    index_t GemmN;
    index_t GemmK;
    index_t GemmBatch;

    const void* in_ptr;
    const void* wei_ptr;
    std::array<const void*, NumDTensor> ds_ptr;
    void* out_ptr;

    AGridDescMK a_grid_desc_m_k;
    BGridDescNK b_grid_desc_n_k;
    CGridDescMN c_grid_desc_m_n;

    long_index_t group_stride_a;
    long_index_t group_stride_b;
    long_index_t group_stride_c;

    // Split-N support fields - initialize to safe defaults
    index_t n_splits            = 1; // Number of batch splits (e.g., 2 for 128→64×2)
    index_t n_per_split         = 1; // Batches per split (N_ from transformer)
    index_t original_n          = 1; // Original batch size before splitting
    index_t input_batch_stride  = 0; // Stride to next batch in input tensor
    index_t output_batch_stride = 0; // Stride to next batch in output tensor
};

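// Worked example of the implied GEMM view (illustrative values, not from the source): for a
// 2D NHWGC problem with G = 4, N = 128, K = 64, C = 32, 3x3 filters and 28x28 output, the
// constructor above yields GemmN = K = 64, GemmK = C * Y * X = 32 * 3 * 3 = 288 and
// GemmBatch = G = 4; once the transformer reports the split batch (e.g. n_per_split = 64
// for two splits), GemmM = n_per_split * Ho * Wo = 64 * 28 * 28 = 50176.
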
/// @brief The Grouped Convolution Forward kernel template.
template <typename GroupedConvTraitsType_,
          typename TilePartitioner_,
          typename GemmPipeline_,
          typename EpiloguePipeline_>
struct GroupedConvolutionForwardKernel
{
    static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial;
    static constexpr ConvolutionSpecialization ConvSpecialization =
        GroupedConvTraitsType_::ConvSpecialization;
    using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
    using GemmPipeline     = remove_cvref_t<GemmPipeline_>;
    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
    using GemmALayout      = remove_cvref_t<typename GemmPipeline::ALayout>;
    using GemmBLayout      = remove_cvref_t<typename GemmPipeline::BLayout>;
    using GemmCLayout      = remove_cvref_t<typename GemmPipeline::CLayout>;

    using InLayout  = remove_cvref_t<typename GroupedConvTraitsType_::InLayout>;
    using WeiLayout = remove_cvref_t<typename GroupedConvTraitsType_::WeiLayout>;
    using OutLayout = remove_cvref_t<typename GroupedConvTraitsType_::OutLayout>;
    using DsLayout  = remove_cvref_t<typename GroupedConvTraitsType_::DsLayout>;

    using GemmDsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;

    static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;

    static constexpr index_t kBlockSize = GemmPipeline::BlockSize;

    using InDataType  = remove_cvref_t<typename GemmPipeline::ADataType>;
    using WeiDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
    using DsDataType  = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
    // Below type is actually accumulation data type - the output of block GEMM.
    using OutDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;

    using GroupedConvFwdKernelArgsSpecialized = GroupedConvFwdKernelArgs<GroupedConvTraitsType_>;

    // TODO: Enable this
    static constexpr bool IsSplitKSupported = false;

    static constexpr auto I0 = number<0>();
    static constexpr auto I1 = number<1>();
    static constexpr auto I2 = number<2>();
    static constexpr auto I3 = number<3>();

    static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
                  "Not supported!");
    static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>, "Not supported!");
    static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
    static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>, "Not supported!");

    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
    {
        // clang-format off
        return concat('_', "grouped_convolution_forward", gemm_prec_str<InDataType, WeiDataType>, GemmPipeline::GetName());
        // clang-format on
    }

    CK_TILE_HOST static auto GridSize(const GroupedConvFwdKernelArgsSpecialized& kargs)
    {
        return dim3(
            TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.n_splits);
    }

    CK_TILE_HOST static auto BlockSize()
    {
        return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
    }

    CK_TILE_HOST static constexpr GroupedConvFwdKernelArgsSpecialized
    MakeKernelArgs(const GroupedConvFwdHostArgs& hostArgs)
    {
        return GroupedConvFwdKernelArgsSpecialized(hostArgs);
    }

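    // Host-side configuration sketch (illustrative; "Kernel" stands for a concrete
    // instantiation of this template and "host_args" for a filled GroupedConvFwdHostArgs,
    // neither of which is defined in this file):
    //
    //   auto kargs = Kernel::MakeKernelArgs(host_args);
    //   const dim3 grid  = Kernel::GridSize(kargs); // x: M/N tiles, y: groups, z: N-splits
    //   const dim3 block = Kernel::BlockSize();     // halved block size on wave32 devices
    //
    // GridSize places the GemmM x GemmN tile grid on blockIdx.x, the GemmBatch groups on
    // blockIdx.y and the Split-N factor on blockIdx.z, which is why IsSupportedArgument
    // below rejects Split-K combined with Split-N.
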
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }

    CK_TILE_HOST static bool IsSupportedArgument(const GroupedConvFwdKernelArgsSpecialized& kargs)
    {
        if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
                      is_any_of<OutDataType, fp16_t, bf16_t>::value) ||
                     !IsSplitKSupported)
        {
            if(kargs.k_batch != 1)
            {
                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
                {
                    CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
                }
                return false;
            }
        }

        // Check Split-K and Split-N conflict (both use blockIdx.z)
        if(kargs.k_batch > 1 && kargs.n_splits > 1)
        {
            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
            {
                CK_TILE_ERROR(
                    "Cannot use both Split-K and Split-N simultaneously (both use blockIdx.z)!");
            }
            return false;
        }

        const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
        const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];

        // check ConvolutionSpecialization
        if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0)
        {
            // check if it's 1x1, stride=1 conv
            for(index_t i = 0; i < NDimSpatial; ++i)
            {
                const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
                const index_t ConvStride = kargs.conv_filter_strides[i];
                const index_t LeftPad    = kargs.input_left_pads[i];
                const index_t RightPad   = kargs.input_right_pads[i];

                if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
                {
                    return false;
                }
            }
        }
        else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Pad0)
        {
            // check if it's 1x1 conv
            for(index_t i = 0; i < NDimSpatial; ++i)
            {
                const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
                const index_t LeftPad    = kargs.input_left_pads[i];
                const index_t RightPad   = kargs.input_right_pads[i];

                if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0))
                {
                    return false;
                }
            }
        }
        else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3)
        {
            if(ConvC != 1)
            {
                return false;
            }
            for(index_t i = 0; i < NDimSpatial; ++i)
            {
                const index_t filter_spatial_dim = kargs.wei_g_k_c_xs_lengths[i + I3];

                if(filter_spatial_dim != I3)
                {
                    return false;
                }
            }
        }

        namespace ctc = tensor_layout::convolution;

        if constexpr(std::is_same_v<InLayout, ctc::NWGC> || std::is_same_v<InLayout, ctc::NHWGC> ||
                     std::is_same_v<InLayout, ctc::NDHWGC>)
        {
            // Check access per C
            if(ConvC % GroupedConvTraitsType_::VectorSizeA != 0)
            {
                CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!");
                return false;
            }
        }
        else
        {
            CK_TILE_ERROR("Not supported input layout!");
            return false;
        }

        // check vector access of B
        // FIXME: layout
        if constexpr(std::is_same_v<WeiLayout, ctc::GKXC> ||
                     std::is_same_v<WeiLayout, ctc::GKYXC> ||
                     std::is_same_v<WeiLayout, ctc::GKZYXC>)
        {
            if(ConvC % GroupedConvTraitsType_::VectorSizeB != 0)
            {
                CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!");
                return false;
            }
        }
        else
        {
            CK_TILE_ERROR("Not supported weight layout!");
            return false;
        }

        // check vector access of E
        if constexpr(std::is_same_v<OutLayout, ctc::NWGK> ||
                     std::is_same_v<OutLayout, ctc::NHWGK> ||
                     std::is_same_v<OutLayout, ctc::NDHWGK>)
        {
            if(ConvK % GroupedConvTraitsType_::VectorSizeC != 0)
            {
                CK_TILE_ERROR("Conv K is not a multiple of vector store size for output image!");
                return false;
            }
        }
        else
        {
            CK_TILE_ERROR("Not supported output layout!");
            return false;
        }

        return true;
    }

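    // Worked example of the vector-access checks above (illustrative numbers): with
    // VectorSizeA = VectorSizeB = 8 and VectorSizeC = 4, a problem with C = 32 and K = 64
    // passes (32 % 8 == 0 and 64 % 4 == 0), while C = 20 would fail the input and weight
    // checks (20 % 8 != 0). The layout checks are compile-time: only channel-last
    // NWGC/NHWGC/NDHWGC inputs with GKXC/GKYXC/GKZYXC weights and NWGK/NHWGK/NDHWGK
    // outputs are accepted.
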
    template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
    CK_TILE_DEVICE static auto
    MakeGemmTensorViews(const InDataType* a_ptr,
                        const WeiDataType* b_ptr,
                        const std::array<const void*, NumDTensor>& ds_ptr,
                        OutDataType* c_ptr,
                        const GroupedConvFwdKernelArgsSpecialized& kargs)
    {
        static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
        static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!");
        const auto& a_tensor_view = [&]() {
            return make_tensor_view<address_space_enum::global>(a_ptr, kargs.a_grid_desc_m_k);
        }();

        const auto& b_tensor_view = [&]() {
            return make_tensor_view<address_space_enum::global>(b_ptr, kargs.b_grid_desc_n_k);
        }();

        // TODO: enable vector write for C in ColMajor
        const auto& c_tensor_view = [&]() {
            return make_tensor_view<address_space_enum::global>(c_ptr, kargs.c_grid_desc_m_n);
        }();

        const auto& ds_tensor_view = generate_tuple(
            [&](auto i) {
                static_assert(std::is_same_v<std::tuple_element_t<i, DsLayout>, OutLayout>,
                              "Not supported!");
                static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
                              "Not supported!");
                static_assert(std::is_same_v<std::tuple_element_t<i, DsDataType>, OutDataType>,
                              "Not supported!");

                return make_tensor_view<address_space_enum::global>(
                    static_cast<const OutDataType*>(ds_ptr[i]), kargs.c_grid_desc_m_n);
            },
            number<NumDTensor>{});

        return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view);
    }

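    // Note: each D tensor view above deliberately reuses c_grid_desc_m_n, so the D
    // (e.g. bias/residual) tensors must match the output tensor's layout and data type -
    // exactly what the three static_asserts inside the generate_tuple lambda enforce at
    // compile time.
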
    template <typename TensorView>
    CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
    {
        const auto& a_pad_view = [&]() {
            const auto& a_tensor_view = views.at(I0);
            return pad_tensor_view(a_tensor_view,
                                   make_tuple(number<TilePartitioner::MPerBlock>{},
                                              number<TilePartitioner::KPerBlock>{}),
                                   sequence<GemmPipeline::kPadM, GemmPipeline::kPadK>{});
        }();

        const auto& b_pad_view = [&]() {
            const auto& b_tensor_view = views.at(I1);
            return pad_tensor_view(b_tensor_view,
                                   make_tuple(number<TilePartitioner::NPerBlock>{},
                                              number<TilePartitioner::KPerBlock>{}),
                                   sequence<GemmPipeline::kPadN, GemmPipeline::kPadK>{});
        }();

        const auto& ds_tensor_view = views.at(I2);
        const auto& ds_pad_view    = generate_tuple(
            [&](auto i) {
                return pad_tensor_view(ds_tensor_view[i],
                                       make_tuple(number<TilePartitioner::MPerBlock>{},
                                                  number<TilePartitioner::NPerBlock>{}),
                                       sequence<GemmPipeline::kPadM, GemmPipeline::kPadN>{});
            },
            number<NumDTensor>{});

        const auto& c_pad_view = [&]() {
            const auto& c_tensor_view = views.at(I3);
            return pad_tensor_view(c_tensor_view,
                                   make_tuple(number<TilePartitioner::MPerBlock>{},
                                              number<TilePartitioner::NPerBlock>{}),
                                   sequence<GemmPipeline::kPadM, GemmPipeline::kPadN>{});
        }();

        return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view);
    }

    template <typename PadView>
    CK_TILE_DEVICE static auto
    MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
    {
        const auto& a_pad_view  = views.at(I0);
        const auto& b_pad_view  = views.at(I1);
        const auto& ds_pad_view = views.at(I2);
        const auto& c_pad_view  = views.at(I3);

        const auto& a_block_window = [&]() {
            return make_tile_window(a_pad_view,
                                    make_tuple(number<TilePartitioner::MPerBlock>{},
                                               number<TilePartitioner::KPerBlock>{}),
                                    {i_m, 0});
        }();

        const auto& b_block_window = [&]() {
            return make_tile_window(b_pad_view,
                                    make_tuple(number<TilePartitioner::NPerBlock>{},
                                               number<TilePartitioner::KPerBlock>{}),
                                    {i_n, 0});
        }();

        const auto ds_block_window = generate_tuple(
            [&](auto i) {
                return make_tile_window(ds_pad_view[i],
                                        make_tuple(number<TilePartitioner::MPerBlock>{},
                                                   number<TilePartitioner::NPerBlock>{}),
                                        {i_m, i_n});
            },
            number<NumDTensor>{});

        auto c_block_window = make_tile_window(
            c_pad_view,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
            {i_m, i_n});

        return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window);
    }

    /// @brief Runs single GEMM problem cooperatively by whole workgroup.
    CK_TILE_DEVICE static void RunGemm(const InDataType* a_ptr,
                                       const WeiDataType* b_ptr,
                                       const std::array<const void*, NumDTensor>& ds_ptr,
                                       OutDataType* c_ptr,
                                       void* smem_ptr_0,
                                       const GroupedConvFwdKernelArgsSpecialized& kargs,
                                       const index_t block_idx_m,
                                       const index_t block_idx_n)
    {
        // Create Gemm tensor views, pad views and tile windows
        const auto& gemm_tensor_views_tuple =
            MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
                a_ptr, b_ptr, ds_ptr, c_ptr, kargs);

        const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
        auto gemm_tile_windows     = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);

        const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(kargs.GemmK));

        // Run GEMM cooperatively by whole workgroup.
        const auto& a_block_window = gemm_tile_windows.at(I0);
        const auto& b_block_window = gemm_tile_windows.at(I1);
        const auto& d_block_window = gemm_tile_windows.at(I2);

        const auto& c_block_tile = GemmPipeline{}.template operator()(
            a_block_window, b_block_window, num_loop, smem_ptr_0);

        // Run Epilogue Pipeline
        auto& c_block_window = gemm_tile_windows.at(I3);

        EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
            c_block_window, c_block_tile, d_block_window, smem_ptr_0);
    }

    /// @brief Runs single GEMM problem cooperatively by whole workgroup, using two LDS buffers.
    CK_TILE_DEVICE static void RunGemm2LDS(const InDataType* a_ptr,
                                           const WeiDataType* b_ptr,
                                           const std::array<const void*, NumDTensor>& ds_ptr,
                                           OutDataType* c_ptr,
                                           void* __restrict__ smem_ptr_0,
                                           void* __restrict__ smem_ptr_1,
                                           const GroupedConvFwdKernelArgsSpecialized& kargs,
                                           const index_t block_idx_m,
                                           const index_t block_idx_n)
    {
        // Create Gemm tensor views, pad views and tile windows
        const auto& gemm_tensor_views_tuple =
            MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
                a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
        const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
        auto gemm_tile_windows     = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);

        const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(kargs.GemmK));

        // Run GEMM cooperatively by whole workgroup.
        const auto& a_block_window = gemm_tile_windows.at(I0);
        const auto& b_block_window = gemm_tile_windows.at(I1);
        const auto& d_block_window = gemm_tile_windows.at(I2);

        const auto& c_block_tile = GemmPipeline{}.template operator()(
            a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);

        // Run Epilogue Pipeline
        auto& c_block_window = gemm_tile_windows.at(I3);

        EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
            c_block_window, c_block_tile, d_block_window, smem_ptr_0);
    }

    CK_TILE_DEVICE void operator()(GroupedConvFwdKernelArgsSpecialized kargs) const
    {
        const auto blockIdX = amd_wave_read_first_lane(blockIdx.x);
        const auto [iM, iN] =
            TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(blockIdX);
        const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
        const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);

        const auto blockIdY       = amd_wave_read_first_lane(blockIdx.y);
        const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
        const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
        const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);

        // Split-N handling: Get which split this workgroup handles
        const auto blockIdZ = amd_wave_read_first_lane(blockIdx.z);

        // Calculate batch offset for this split
        const index_t batch_offset = amd_wave_read_first_lane(blockIdZ * kargs.n_per_split);

        // Calculate memory offsets for this split
        const long_index_t input_batch_offset =
            static_cast<long_index_t>(batch_offset) *
            static_cast<long_index_t>(kargs.input_batch_stride);
        const long_index_t output_batch_offset =
            static_cast<long_index_t>(batch_offset) *
            static_cast<long_index_t>(kargs.output_batch_stride);

        // Adjust pointers: combine group offset and batch offset
        const InDataType* a_ptr =
            static_cast<const InDataType*>(kargs.in_ptr) + group_offset_a + input_batch_offset;
        const WeiDataType* b_ptr = static_cast<const WeiDataType*>(kargs.wei_ptr) +
                                   group_offset_b; // No batch offset for weights!
        OutDataType* c_ptr =
            static_cast<OutDataType*>(kargs.out_ptr) + group_offset_c + output_batch_offset;

        // allocate LDS
        __shared__ char smem_ptr_0[GetSmemSize()];

        if constexpr(GemmPipeline::DoubleSmemBuffer == true)
        {
            __shared__ char smem_ptr_1[GetSmemSize()];
            if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
                           GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
                           is_any_of<OutDataType, fp16_t, bf16_t>::value))
            {
                RunGemm2LDS(
                    a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, smem_ptr_1, kargs, i_m, i_n);
            }
        }
        else
        {
            if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
                           GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
                           is_any_of<OutDataType, fp16_t, bf16_t>::value))
            {
                RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, i_m, i_n);
            }
        }
    }
};

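// End-to-end launch sketch (illustrative; "Kernel" is a concrete instantiation of the
// template above and "host_args" a filled GroupedConvFwdHostArgs; the launch helpers
// follow the usual ck_tile host API - adapt the plumbing to the launcher actually used
// in your build):
//
//   auto kargs = Kernel::MakeKernelArgs(host_args);
//   if(Kernel::IsSupportedArgument(kargs))
//   {
//       ck_tile::launch_kernel(ck_tile::stream_config{},
//                              ck_tile::make_kernel(Kernel{},
//                                                   Kernel::GridSize(kargs),
//                                                   Kernel::BlockSize(),
//                                                   0, // dynamic LDS bytes
//                                                   kargs));
//   }
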
} // namespace ck_tile