/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp Source File
grouped_convolution_forward_kernel.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <string>
8 #include <tuple>
9 
10 #include "ck_tile/core.hpp"
13 #include "ck_tile/ops/common.hpp"
14 #include "ck_tile/host/concat.hpp"
20 
21 #ifdef CK_EXPERIMENTAL_BUILDER
22 #include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_forward.hpp"
23 #endif
24 
25 namespace ck_tile {
26 
28 template <typename GroupedConvTraitsType_, typename CDElementwise_>
30 {
32  TransformConvFwdToGemm<GroupedConvTraitsType_::NDimSpatial,
33  GroupedConvTraitsType_::ConvSpecialization,
34  GroupedConvTraitsType_::VectorSizeA,
35  GroupedConvTraitsType_::VectorSizeB,
36  GroupedConvTraitsType_::VectorSizeC,
37  GroupedConvTraitsType_::NumGroupsToMerge,
38  true>; // Split N enabled
39  using CDElementwise = CDElementwise_;
40  static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
41 
42  static_assert(!GroupedConvTraitsType_::ExplicitGemm ||
43  GroupedConvTraitsType_::NumGroupsToMerge == 1,
44  "Explicit GEMM does not support merging convolution groups!");
45 
46  template <
47  typename InLay = typename GroupedConvTraitsType_::InLayout,
48  typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
49  typename OutLay = typename GroupedConvTraitsType_::OutLayout,
50  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NWGC> &&
51  std::is_same_v<WeiLay, tensor_layout::convolution::GKXC> &&
52  std::is_same_v<OutLay, tensor_layout::convolution::NWGK>,
53  bool>::type = false>
55  : elfunc(args.elfunc)
56  {
57  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
58  static_cast<index_t>(args.N_),
59  static_cast<index_t>(args.C_),
60  static_cast<index_t>(args.input_spatial_lengths_[0])};
61  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
62  static_cast<index_t>(args.K_),
63  static_cast<index_t>(args.C_),
64  static_cast<index_t>(args.filter_spatial_lengths_[0])};
65  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
66  static_cast<index_t>(args.N_),
67  static_cast<index_t>(args.K_),
68  static_cast<index_t>(args.output_spatial_lengths_[0])};
69 
70  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0])};
71  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0])};
72  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0])};
73  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0])};
74 
75  k_batch = args.k_batch;
76 
77  in_ptr = args.in_ptr;
78  wei_ptr = args.wei_ptr;
79  for(index_t d = 0; d < NumDTensor; d++)
80  {
81  ds_ptr[d] = args.ds_ptr[d];
82  }
83  out_ptr = args.out_ptr;
84 
85  // Create and STORE transformer (for split-image support)
93 
95  transformer_.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
97  transformer_.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
99  transformer_.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
100 
101  NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
103  group_stride_b = args.K_ * args.C_ * NumGroupsToMerge *
104  std::accumulate(args.filter_spatial_lengths_.begin(),
105  args.filter_spatial_lengths_.end(),
106  1,
107  std::multiplies<index_t>());
109 
110  // Initialize Split-N support fields for 1D convolution (NWGC layout)
111  // Get the actual split N from transformer
115 
116  // Calculate batch strides using the original argument dimensions.
117  // These are the original dimensions passed to the constructor, not modified by the invoker
118  // yet. (The invoker modifies args after calling MakeKernelArgs.) VERIFIED: G_ MUST be
119  // included - NWGC layout has all groups within each batch
120  input_batch_stride = args.G_ * args.C_ * args.input_spatial_lengths_[0];
121  output_batch_stride = args.G_ * args.K_ * args.output_spatial_lengths_[0];
122 
123  GemmM = a_grid_desc_m_k.get_length(number<0>{});
124  GemmN = b_grid_desc_n_k.get_length(number<0>{});
125  GemmK = a_grid_desc_m_k.get_length(number<1>{});
127 
128  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
129  {
130  std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
131  << ", GemmBatch: " << GemmBatch << ", N per split: " << n_per_split
132  << ", number of N splits: " << n_splits
133  << ", input_batch_stride: " << input_batch_stride
134  << ", output_batch_stride: " << output_batch_stride
135  << ", NumGroupsToMerge: " << NumGroupsToMerge << std::endl;
136  }
137  }
138 
139  template <
140  typename InLay = typename GroupedConvTraitsType_::InLayout,
141  typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
142  typename OutLay = typename GroupedConvTraitsType_::OutLayout,
143  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NHWGC> &&
144  std::is_same_v<WeiLay, tensor_layout::convolution::GKYXC> &&
145  std::is_same_v<OutLay, tensor_layout::convolution::NHWGK>,
146  bool>::type = false>
148  : elfunc(args.elfunc)
149  {
150  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
151  static_cast<index_t>(args.N_),
152  static_cast<index_t>(args.C_),
153  static_cast<index_t>(args.input_spatial_lengths_[0]),
154  static_cast<index_t>(args.input_spatial_lengths_[1])};
155  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
156  static_cast<index_t>(args.K_),
157  static_cast<index_t>(args.C_),
158  static_cast<index_t>(args.filter_spatial_lengths_[0]),
159  static_cast<index_t>(args.filter_spatial_lengths_[1])};
160  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
161  static_cast<index_t>(args.N_),
162  static_cast<index_t>(args.K_),
163  static_cast<index_t>(args.output_spatial_lengths_[0]),
164  static_cast<index_t>(args.output_spatial_lengths_[1])};
165 
166  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
167  static_cast<index_t>(args.conv_filter_strides_[1])};
168  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
169  static_cast<index_t>(args.conv_filter_dilations_[1])};
170  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
171  static_cast<index_t>(args.input_left_pads_[1])};
172  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
173  static_cast<index_t>(args.input_right_pads_[1])};
174 
175  k_batch = args.k_batch;
176 
177  in_ptr = args.in_ptr;
178  wei_ptr = args.wei_ptr;
179  for(index_t d = 0; d < NumDTensor; d++)
180  {
181  ds_ptr[d] = args.ds_ptr[d];
182  }
183  out_ptr = args.out_ptr;
184 
185  // Create and STORE transformer (for split-image support)
193 
195  transformer_.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
197  transformer_.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
199  transformer_.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
200 
201  NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
203  group_stride_b = args.K_ * args.C_ * NumGroupsToMerge *
204  std::accumulate(args.filter_spatial_lengths_.begin(),
205  args.filter_spatial_lengths_.end(),
206  1,
207  std::multiplies<index_t>());
209 
210  // Initialize Split-N support fields for 2D convolution (NHWGC layout)
211  // Get the actual split N from transformer
215 
216  // Calculate batch strides for NHWGC layout
217  // VERIFIED: G_ MUST be included - NHWGC layout has all groups within each batch
219  args.G_ * args.C_ * args.input_spatial_lengths_[0] * args.input_spatial_lengths_[1];
221  args.G_ * args.K_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
222 
223  GemmM = a_grid_desc_m_k.get_length(number<0>{});
224  GemmN = b_grid_desc_n_k.get_length(number<0>{});
225  GemmK = a_grid_desc_m_k.get_length(number<1>{});
227 
228  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
229  {
230  std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
231  << ", GemmBatch: " << GemmBatch << ", N per split: " << n_per_split
232  << ", number of N splits: " << n_splits
233  << ", input_batch_stride: " << input_batch_stride
234  << ", output_batch_stride: " << output_batch_stride
235  << ", NumGroupsToMerge: " << NumGroupsToMerge << std::endl;
236  }
237  }
238 
239  template <
240  typename InLay = typename GroupedConvTraitsType_::InLayout,
241  typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
242  typename OutLay = typename GroupedConvTraitsType_::OutLayout,
243  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NDHWGC> &&
244  std::is_same_v<WeiLay, tensor_layout::convolution::GKZYXC> &&
245  std::is_same_v<OutLay, tensor_layout::convolution::NDHWGK>,
246  bool>::type = false>
248  : elfunc(args.elfunc)
249  {
250  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
251  static_cast<index_t>(args.N_),
252  static_cast<index_t>(args.C_),
253  static_cast<index_t>(args.input_spatial_lengths_[0]),
254  static_cast<index_t>(args.input_spatial_lengths_[1]),
255  static_cast<index_t>(args.input_spatial_lengths_[2])};
256  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
257  static_cast<index_t>(args.K_),
258  static_cast<index_t>(args.C_),
259  static_cast<index_t>(args.filter_spatial_lengths_[0]),
260  static_cast<index_t>(args.filter_spatial_lengths_[1]),
261  static_cast<index_t>(args.filter_spatial_lengths_[2])};
262  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
263  static_cast<index_t>(args.N_),
264  static_cast<index_t>(args.K_),
265  static_cast<index_t>(args.output_spatial_lengths_[0]),
266  static_cast<index_t>(args.output_spatial_lengths_[1]),
267  static_cast<index_t>(args.output_spatial_lengths_[2])};
268 
269  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
270  static_cast<index_t>(args.conv_filter_strides_[1]),
271  static_cast<index_t>(args.conv_filter_strides_[2])};
272  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
273  static_cast<index_t>(args.conv_filter_dilations_[1]),
274  static_cast<index_t>(args.conv_filter_dilations_[2])};
275  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
276  static_cast<index_t>(args.input_left_pads_[1]),
277  static_cast<index_t>(args.input_left_pads_[2])};
278  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
279  static_cast<index_t>(args.input_right_pads_[1]),
280  static_cast<index_t>(args.input_right_pads_[2])};
281 
282  k_batch = args.k_batch;
283 
284  in_ptr = args.in_ptr;
285  wei_ptr = args.wei_ptr;
286  for(index_t d = 0; d < NumDTensor; d++)
287  {
288  ds_ptr[d] = args.ds_ptr[d];
289  }
290  out_ptr = args.out_ptr;
291 
292  // Create and STORE transformer (for split-image support)
300 
302  transformer_.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
304  transformer_.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
306  transformer_.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
307 
308  NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
310  group_stride_b = args.K_ * args.C_ * NumGroupsToMerge *
311  std::accumulate(args.filter_spatial_lengths_.begin(),
312  args.filter_spatial_lengths_.end(),
313  1,
314  std::multiplies<index_t>());
316 
317  // Initialize Split-N support fields for 3D convolution (NDHWGC layout)
318  // Get the actual split N from transformer
322 
323  // Calculate batch strides for NDHWGC layout
324  // VERIFIED: G_ MUST be included - NDHWGC layout has all groups within each batch
325  input_batch_stride = args.G_ * args.C_ * args.input_spatial_lengths_[0] *
327  output_batch_stride = args.G_ * args.K_ * args.output_spatial_lengths_[0] *
329 
330  GemmM = a_grid_desc_m_k.get_length(number<0>{});
331  GemmN = b_grid_desc_n_k.get_length(number<0>{});
332  GemmK = a_grid_desc_m_k.get_length(number<1>{});
334 
335  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
336  {
337  std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
338  << ", GemmBatch: " << GemmBatch << ", N per split: " << n_per_split
339  << ", number of N splits: " << n_splits
340  << ", input_batch_stride: " << input_batch_stride
341  << ", output_batch_stride: " << output_batch_stride
342  << ", NumGroupsToMerge: " << NumGroupsToMerge << std::endl;
343  }
344  }
346  decltype(ConvToGemmFwdTransformer{}
347  .template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>())>;
349  decltype(ConvToGemmFwdTransformer{}
350  .template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>())>;
352  decltype(ConvToGemmFwdTransformer{}
353  .template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>())>;
354 
355  static constexpr index_t NonSpatialDims = 3;
359 
364 
371 
372  const void* in_ptr;
373  const void* wei_ptr;
374  std::array<const void*, NumDTensor> ds_ptr;
376  void* out_ptr;
377 
381 
385 
386  // Split-N support fields - initialize to safe defaults
387  index_t n_splits = 1; // Number of batch splits (e.g., 2 for 128→64×2)
388  index_t n_per_split = 1; // Batches per split (N_ from transformer)
389  index_t original_n = 1; // Original batch size before splitting
390  index_t input_batch_stride = 0; // Stride to next batch in input tensor
391  index_t output_batch_stride = 0; // Stride to next batch in output tensor
392 
393  // Split-image support - spatial offsets (applied per-batch in operator())
394  long_index_t spatial_offset_in = 0; // Spatial offset for input (e.g., W/2 for 1D split)
395  long_index_t spatial_offset_out = 0; // Spatial offset for output (e.g., W/2 for 1D split)
396 
397  // Split-image support - transformer instance
399 
400  // Forward declare descriptor types (will be defined after using declarations)
404 
405  // Split-image support: Common data for all pieces
407  {
408  // Common dimensions (same for all pieces)
409  index_t total_d = 1, total_h = 1, total_w = 1; // Total tensor dimensions
410  index_t total_spatial = 1; // Pre-calculated: total_d * total_h * total_w
411  index_t num_d_pieces = 1, num_h_pieces = 1, num_w_pieces = 1; // Split factors
412 
413  // Minimal per-piece data (only unique values)
414  struct PieceInfo
415  {
416  index_t block_start; // Starting block index for this piece
417  index_t block_end; // Ending block index (exclusive)
418  index_t d_start, h_start, w_start; // Piece starting position in OUTPUT space
419  index_t d_size, h_size, w_size; // Piece size in OUTPUT space
420  };
421 
422  static constexpr index_t MaxPieces = 64; // Max pieces: 4 (1D), 16 (2D), 64 (3D)
423  std::array<PieceInfo, MaxPieces> pieces; // Array of minimal piece descriptors
424  };
425 
426  index_t num_spatial_pieces = 1; // Number of spatial pieces (1 = no split)
427  SplitImageInfo split_image; // Nested structure with common + per-piece data
428 };
429 
468 template <typename GroupedConvTraitsType_,
469  typename TilePartitioner_,
470  typename GemmPipeline_,
471  typename EpiloguePipeline_>
473 {
474  static constexpr bool EnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
475  static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial;
477  GroupedConvTraitsType_::ConvSpecialization;
484 
489 
491  static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
492 
493  static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
494 
498  // Below type is actually accumulation data type - the output of block GEMM.
500 
501  using CDElementwise = typename EpiloguePipeline::CDElementwise;
502 
505 
506  static constexpr bool IsSplitKSupported = false;
507 
508  static constexpr auto I0 = number<0>();
509  static constexpr auto I1 = number<1>();
510  static constexpr auto I2 = number<2>();
511  static constexpr auto I3 = number<3>();
512 
513  static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
514  "Not supported!");
515  static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>, "Not supported!");
516  static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
517  static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
518  static_assert(GroupedConvTraitsType_::ExplicitGemm == false ||
519  GroupedConvTraitsType_::NumGroupsToMerge == 1,
520  "Not supported!");
521 
522  // Helper struct for spatial coordinates
524  {
526  };
527 
528  // Helper: Convert flat spatial index to (d,h,w) coordinates
530  UnflattenSpatial(index_t flat, index_t h_size, index_t w_size)
531  {
532  if constexpr(NDimSpatial == 1)
533  {
534  return SpatialCoords{0, 0, flat};
535  }
536  else if constexpr(NDimSpatial == 2)
537  {
538  return SpatialCoords{0, flat / w_size, flat % w_size};
539  }
540  else // NDimSpatial == 3
541  {
542  const index_t hw = h_size * w_size;
543  const index_t d = flat / hw;
544  const index_t remainder = flat % hw;
545  return SpatialCoords{d, remainder / w_size, remainder % w_size};
546  }
547  }
548 
549  // Helper: Convert (d,h,w) to flat spatial index
550  CK_TILE_DEVICE static index_t
552  {
553  if constexpr(NDimSpatial == 1)
554  {
555  return w;
556  }
557  else if constexpr(NDimSpatial == 2)
558  {
559  return h * total_w + w;
560  }
561  else // NDimSpatial == 3
562  {
563  return (d * total_h + h) * total_w + w;
564  }
565  }
566 
567  // Helper: Find which piece owns a block using binary search
568  template <typename SplitImageInfo>
569  CK_TILE_DEVICE static index_t
570  FindPieceId(index_t block_id, const SplitImageInfo& split_info, index_t num_pieces)
571  {
572  index_t left = 0;
573  index_t right = num_pieces - 1;
574  index_t piece_id = (left + right) / 2;
575 
576  while(!(block_id >= split_info.pieces[piece_id].block_start &&
577  block_id < split_info.pieces[piece_id].block_end) &&
578  left <= right)
579  {
580  if(block_id < split_info.pieces[piece_id].block_start)
581  {
582  right = piece_id - 1;
583  }
584  else
585  {
586  left = piece_id + 1;
587  }
588  piece_id = (left + right) / 2;
589  }
590  return piece_id;
591  }
592 
593  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
594  {
595  constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
596  // clang-format off
597  return concat('_', "grouped_convolution_forward",
598  gemm_prec_str<InDataType, WeiDataType>(),
599  InLayout::name,
600  WeiLayout::name,
601  OutLayout::name,
602  "gemm",
603  GemmPipeline::GetName(),
604  "epilogue",
605  EpiloguePipeline::GetName(),
607  "MergedGroups",
608  NumGroupsToMerge,
609  "SplitImage",
611  "ExplicitGemm",
612  GroupedConvTraitsType_::ExplicitGemm
613  );
614  // clang-format on
615  }
616 
617  [[nodiscard]] CK_TILE_HOST static const std::string GetTypeString() { return GetName(); }
618 
#ifdef CK_EXPERIMENTAL_BUILDER
/// @brief Builder-reflection string that identifies this exact kernel instantiation.
///
/// Compilation fails with a descriptive message if no instance_traits
/// specialization is available for this kernel's template parameters.
CK_TILE_HOST std::string GetInstanceString() const
{
    using ThisKernel = GroupedConvolutionForwardKernel;

    static_assert(ck_tile::reflect::HasInstanceTraits<ThisKernel>,
                  "Specialization of instance_traits not found. Please check that a "
                  "specialization exists in file "
                  "ck_tile/builder/reflect/"
                  "instance_traits_tile_grouped_convolution_forward.hpp "
                  "for the given template parameters.");

    return ck_tile::reflect::instance_string<ThisKernel>();
}
#endif
631 
633  {
634  return dim3(
635  TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.n_splits);
636  }
637 
638  CK_TILE_HOST static auto BlockSize()
639  {
640  return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
641  }
642 
645  {
646  auto kargs = GroupedConvFwdKernelArgsSpecialized(hostArgs);
647  return kargs;
648  }
649 
651  {
652  return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
653  }
654 
656  {
657  if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
660  {
661  if(kargs.k_batch != 1)
662  {
663  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
664  {
665  CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
666  }
667  return false;
668  }
669  }
670 
671  const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
672  const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];
673 
674  // check ConvolutionSpecialization
676  {
677  // check if it's 1x1, stride=1 conv
678  for(index_t i = 0; i < NDimSpatial; ++i)
679  {
680  const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
681  const index_t ConvStride = kargs.conv_filter_strides[i];
682  const index_t LeftPad = kargs.input_left_pads[i];
683  const index_t RightPad = kargs.input_right_pads[i];
684 
685  if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
686  {
687  return false;
688  }
689  }
690  }
692  {
693  // check if it's 1x1 conv
694  for(index_t i = 0; i < NDimSpatial; ++i)
695  {
696  const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
697  const index_t LeftPad = kargs.input_left_pads[i];
698  const index_t RightPad = kargs.input_right_pads[i];
699 
700  if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0))
701  {
702  return false;
703  }
704  }
705  }
707  {
708  if(ConvC != 1)
709  {
710  return false;
711  }
712  for(index_t i = 0; i < NDimSpatial; ++i)
713  {
714  const index_t filter_spatial_dim = kargs.wei_g_k_c_xs_lengths[i + I3];
715 
716  if(filter_spatial_dim != I3)
717  {
718  return false;
719  }
720  }
721  }
722 
723  if constexpr(GroupedConvTraitsType_::ExplicitGemm &&
725  {
727  "Explicit Gemm is supported only for Filter1x1Stride1Pad0 specialization!");
728  return false;
729  }
730 
731  namespace ctc = tensor_layout::convolution;
732 
733  if constexpr(std::is_same_v<InLayout, ctc::NWGC> || std::is_same_v<InLayout, ctc::NHWGC> ||
734  std::is_same_v<InLayout, ctc::NDHWGC>)
735  {
736  // Check access per C
737  if(ConvC % GroupedConvTraitsType_::VectorSizeA != 0)
738  {
739  CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!");
740  return false;
741  }
742  }
743  else
744  {
745  CK_TILE_ERROR("Not supported input layout!");
746  return false;
747  }
748 
749  // check vector access of B
750  // FIXME: layout
751  if constexpr(std::is_same_v<WeiLayout, ctc::GKXC> ||
752  std::is_same_v<WeiLayout, ctc::GKYXC> ||
753  std::is_same_v<WeiLayout, ctc::GKZYXC>)
754  {
755  if(ConvC % GroupedConvTraitsType_::VectorSizeB != 0)
756  {
757  CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!");
758  return false;
759  }
760  }
761  else
762  {
763  CK_TILE_ERROR("Not supported weight layout!");
764  return false;
765  }
766 
767  // check vector access of E
768  if constexpr(std::is_same_v<OutLayout, ctc::NWGK> ||
769  std::is_same_v<OutLayout, ctc::NHWGK> ||
770  std::is_same_v<OutLayout, ctc::NDHWGK>)
771  {
772  if(ConvK % GroupedConvTraitsType_::VectorSizeC != 0)
773  {
774  CK_TILE_ERROR("Conv K is not a multiple of vector store size for output image!");
775  return false;
776  }
777  }
778  else
779  {
780  CK_TILE_ERROR("Not supported output layout!");
781  return false;
782  }
783 
784  if constexpr(GroupedConvTraitsType_::NumGroupsToMerge > 1)
785  {
786  const index_t ConvG = kargs.wei_g_k_c_xs_lengths[number<0>{}];
787  if(ConvG % GroupedConvTraitsType_::NumGroupsToMerge != 0)
788  {
789  CK_TILE_ERROR("ConvG must be a multiple of NumGroupsToMerge!");
790  return false;
791  }
792  }
793 
794  return true;
795  }
796 
797  template <memory_operation_enum DstInMemOp = memory_operation_enum::set,
798  typename ADescType,
799  typename BDescType,
800  typename CDescType>
801  CK_TILE_DEVICE static auto
803  const WeiDataType* b_ptr,
804  const std::array<const void*, NumDTensor>& ds_ptr,
805  OutDataType* c_ptr,
806  const ADescType& a_desc,
807  const BDescType& b_desc,
808  const CDescType& c_desc)
809  {
810  static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
811  static_assert(!GemmPipeline::BlockGemmShape::PermuteB, "Not implemented!");
812  const auto& a_tensor_view = [&]() {
813  return make_tensor_view<address_space_enum::global>(a_ptr, a_desc);
814  }();
815 
816  const auto& b_tensor_view = [&]() {
817  return make_tensor_view<address_space_enum::global>(b_ptr, b_desc);
818  }();
819 
820  // TODO: enable vector write for C in ColMajor
821  const auto& c_tensor_view = [&]() {
822  return make_tensor_view<address_space_enum::global>(c_ptr, c_desc);
823  }();
824 
825  const auto& ds_tensor_view = generate_tuple(
826  [&](auto i) {
827  static_assert(std::is_same_v<std::tuple_element_t<i, DsLayout>, OutLayout>,
828  "Not supported!");
829  static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
830  "Not supported!");
831  static_assert(std::is_same_v<std::tuple_element_t<i, DsDataType>, OutDataType>,
832  "Not supported!");
833 
834  return make_tensor_view<address_space_enum::global>(
835  static_cast<const OutDataType*>(ds_ptr[i]), c_desc);
836  },
838 
839  return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view);
840  }
841 
842  template <typename TensorView>
843  CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
844  {
845  const auto& a_pad_view = [&]() {
846  const auto& a_tensor_view = views.at(I0);
847  return pad_tensor_view(a_tensor_view,
851  }();
852 
853  const auto& b_pad_view = [&]() {
854  const auto& b_tensor_view = views.at(I1);
855  return pad_tensor_view(b_tensor_view,
859  }();
860 
861  const auto& ds_tensor_view = views.at(I2);
862  const auto& ds_pad_view = generate_tuple(
863  [&](auto i) {
864  return pad_tensor_view(ds_tensor_view[i],
868  },
870 
871  const auto& c_pad_view = [&]() {
872  const auto& c_tensor_view = views.at(I3);
873  return pad_tensor_view(c_tensor_view,
877  }();
878 
879  return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view);
880  }
881 
882  template <typename PadView>
883  CK_TILE_DEVICE static auto
884  MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
885  {
886  const auto& a_pad_view = views.at(I0);
887  const auto& b_pad_view = views.at(I1);
888  const auto& ds_pad_view = views.at(I2);
889  const auto& c_pad_view = views.at(I3);
890 
891  const auto& a_block_window = [&]() {
892  return make_tile_window(a_pad_view,
895  {i_m, 0});
896  }();
897 
898  const auto& b_block_window = [&]() {
899  return make_tile_window(b_pad_view,
902  {i_n, 0});
903  }();
904 
905  const auto ds_block_window = generate_tuple(
906  [&](auto i) {
907  return make_tile_window(ds_pad_view[i],
910  {i_m, i_n});
911  },
913 
914  auto c_block_window = make_tile_window(
915  c_pad_view,
917  {i_m, i_n});
918 
919  return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window);
920  }
921 
938  template <typename ADescType, typename BDescType, typename CDescType>
939  CK_TILE_DEVICE static void RunGemm(const InDataType* a_ptr,
940  const WeiDataType* b_ptr,
941  const std::array<const void*, NumDTensor>& ds_ptr,
942  OutDataType* c_ptr,
943  void* smem_ptr_0,
944  const ADescType& a_desc,
945  const BDescType& b_desc,
946  const CDescType& c_desc,
947  const index_t gemm_k,
948  const index_t block_idx_m,
949  const index_t block_idx_n,
950  const CDElementwise& elfunc)
951  {
952  // Create Gemm tensor views, pad views and tile windows
953  const auto& gemm_tensor_views_tuple =
954  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
955  a_ptr, b_ptr, ds_ptr, c_ptr, a_desc, b_desc, c_desc);
956 
957  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
958  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
959 
960  const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(gemm_k));
961 
962  // Run GEMM cooperatively by whole workgroup.
963  const auto& a_block_window = gemm_tile_windows.at(I0);
964  const auto& b_block_window = gemm_tile_windows.at(I1);
965  const auto& d_block_window = gemm_tile_windows.at(I2);
966 
967  const auto& c_block_tile = GemmPipeline{}.template operator()(
968  a_block_window, b_block_window, num_loop, smem_ptr_0);
969 
970  // Run Epilogue Pipeline
971  auto& c_block_window = gemm_tile_windows.at(I3);
972 
973  EpiloguePipeline{elfunc}
974  .template operator()<decltype(c_block_window), decltype(c_block_tile)>(
975  c_block_window, c_block_tile, d_block_window, smem_ptr_0);
976  }
977 
997  template <typename ADescType, typename BDescType, typename CDescType>
998  CK_TILE_DEVICE static void RunGemm2LDS(const InDataType* a_ptr,
999  const WeiDataType* b_ptr,
1000  const std::array<const void*, NumDTensor>& ds_ptr,
1001  OutDataType* c_ptr,
1002  void* __restrict__ smem_ptr_0,
1003  void* __restrict__ smem_ptr_1,
1004  const ADescType& a_desc,
1005  const BDescType& b_desc,
1006  const CDescType& c_desc,
1007  const index_t gemm_k,
1008  const index_t block_idx_m,
1009  const index_t block_idx_n,
1010  const CDElementwise& elfunc)
1011  {
1012  // Create Gemm tensor views, pad views and tile windows
1013  const auto& gemm_tensor_views_tuple =
1014  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
1015  a_ptr, b_ptr, ds_ptr, c_ptr, a_desc, b_desc, c_desc);
1016  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
1017  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
1018 
1019  const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(gemm_k));
1020 
1021  // Run GEMM cooperatively by whole workgroup.
1022  const auto& a_block_window = gemm_tile_windows.at(I0);
1023  const auto& b_block_window = gemm_tile_windows.at(I1);
1024  const auto& d_block_window = gemm_tile_windows.at(I2);
1025 
1026  const auto& c_block_tile = GemmPipeline{}.template operator()(
1027  a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
1028 
1029  // Run Epilogue Pipeline
1030  auto& c_block_window = gemm_tile_windows.at(I3);
1031 
1032  EpiloguePipeline{elfunc}
1033  .template operator()<decltype(c_block_window), decltype(c_block_tile)>(
1034  c_block_window, c_block_tile, d_block_window, smem_ptr_0);
1035  }
1036 
1038  {
1039  static_assert(NumDTensor == 0, "Not supported!");
1040  using ExplicitBatchedGemmKernel =
1042  const auto batched_gemm_kargs = typename ExplicitBatchedGemmKernel::BatchedGemmKernelArgs{
1043  {{kargs.in_ptr},
1044  {kargs.wei_ptr},
1045  {},
1046  kargs.out_ptr,
1047  kargs.GemmM,
1048  kargs.GemmN,
1049  kargs.GemmK,
1050  {kargs.GemmK * kargs.GemmBatch},
1051  {kargs.GemmK},
1052  {},
1053  kargs.GemmBatch * kargs.GemmN,
1054  kargs.k_batch},
1055  kargs.GemmK,
1056  kargs.GemmN * kargs.GemmK,
1057  kargs.GemmN,
1058  kargs.GemmBatch};
1059  ExplicitBatchedGemmKernel{}(batched_gemm_kargs);
1060  }
1061 
1063  {  // Body of operator()(kargs); its signature (source line 1062) is not shown in this rendering.
1064  if constexpr(GroupedConvTraitsType_::ExplicitGemm)  // Explicit-GEMM configurations bypass the implicit-GEMM path below entirely.
1065  {
1066  CallExplicitGemm(kargs);
1067  }
1068  else
1069  {
1070  const auto blockIdX = amd_wave_read_first_lane(blockIdx.x);  // blockIdx.x selects the output tile (fed to GetOutputTileIndex below).
1071  const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);  // blockIdx.y scales the group_stride_* offsets, i.e. acts as the group index.
1072 
1073  const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
1074  const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
1075  const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);
1076 
1077  // Split-N handling: determine which split this workgroup handles (blockIdx.z)
1078  const auto blockIdZ = amd_wave_read_first_lane(blockIdx.z);
1079 
1080  // Calculate batch offset for this split: the first batch (N) index it covers
1081  const index_t batch_offset = amd_wave_read_first_lane(blockIdZ * kargs.n_per_split);
1082 
1083  // Calculate memory offsets for this split
1084  const long_index_t input_batch_offset =  // Both operands are widened to long_index_t (int64_t) so the product is computed in 64 bits.
1085  static_cast<long_index_t>(batch_offset) *
1086  static_cast<long_index_t>(kargs.input_batch_stride);
1087  const long_index_t output_batch_offset =  // Same 64-bit widening as for the input offset.
1088  static_cast<long_index_t>(batch_offset) *
1089  static_cast<long_index_t>(kargs.output_batch_stride);
1090 
1091  // Calculate base pointers with group and batch offsets
1092  const InDataType* base_a_ptr =
1093  static_cast<const InDataType*>(kargs.in_ptr) + group_offset_a + input_batch_offset;
1094  const WeiDataType* b_ptr = static_cast<const WeiDataType*>(kargs.wei_ptr) +
1095  group_offset_b; // No batch offset for weights!
1096  OutDataType* base_c_ptr =
1097  static_cast<OutDataType*>(kargs.out_ptr) + group_offset_c + output_batch_offset;
1098 
1099  // Apply group offsets to D tensors (they reuse the output's group and batch offsets)
1100  std::array<const void*, NumDTensor> ds_ptr_with_offsets;
1101  static_for<0, NumDTensor, 1>{}([&](auto d) {  // Compile-time loop, so each D tensor is cast to its own element type.
1102  using DType = std::tuple_element_t<d, DsDataType>;
1103  ds_ptr_with_offsets[d] = static_cast<const DType*>(kargs.ds_ptr[d]) +
1104  group_offset_c + output_batch_offset;  // NOTE(review): unlike c_ptr in the split-image branch, kargs.spatial_offset_out is not added here -- confirm this is intended.
1105  });
1106 
1107  // =====================================================================
1108  // Split-image: Map local block to global tile index (if enabled)
1109  // =====================================================================
1110  const InDataType* a_ptr;  // Assigned in exactly one branch of the if constexpr below.
1111  OutDataType* c_ptr;  // Assigned in exactly one branch of the if constexpr below.
1112  index_t i_m = 0;  // Starting GEMM row (M) of this workgroup's tile.
1113  index_t i_n = 0;  // Starting GEMM column (N) of this workgroup's tile.
1114 
1115  // Pre-calculate block_id (used in both split-image and non-split paths)
1116  const index_t block_id = static_cast<index_t>(blockIdX);
1117 
1118  if constexpr(EnableSplitImage)
1119  {
1120  // Add spatial offsets for split-image (constexpr optimization)
1121  a_ptr = base_a_ptr + kargs.spatial_offset_in;
1122  c_ptr = base_c_ptr + kargs.spatial_offset_out;
1123 
1124  // Find which piece owns this block using binary search
1125  // Reference: device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp
1126  const index_t piece_id =
1127  FindPieceId(block_id, kargs.split_image, kargs.num_spatial_pieces);
1128  const auto& piece = kargs.split_image.pieces[piece_id];
1129  const auto& split_info = kargs.split_image;
1130 
1131  // Calculate local block ID and tile indices
1132  const index_t local_block_id = block_id - piece.block_start;  // Block index relative to the first block of this piece.
1133  const index_t local_gemm_m =  // GEMM M extent of this piece: n_per_split batches times the piece's d*h*w size.
1134  kargs.n_per_split * piece.d_size * piece.h_size * piece.w_size;
1135  const auto [local_tile_m, local_tile_n] =
1136  TilePartitioner{local_gemm_m, kargs.GemmN}.GetOutputTileIndex(local_block_id);
1137 
1138  // Extract batch and spatial coordinates from local tile
1139  const index_t local_m_start = local_tile_m * TilePartitioner::MPerBlock;  // First piece-local M row of the tile.
1140  const index_t spatial_per_batch = piece.d_size * piece.h_size * piece.w_size;
1141  const index_t local_n = local_m_start / spatial_per_batch;  // Batch index of that row within this split.
1142  const index_t local_spatial_flat = local_m_start % spatial_per_batch;  // Flat d*h*w offset of that row within the piece.
1143 
1144  // Convert to local spatial coordinates
1145  const auto local_coords =
1146  UnflattenSpatial(local_spatial_flat, piece.h_size, piece.w_size);
1147 
1148  // Convert to global spatial coordinates
1149  const index_t global_n = local_n;  // Batch index is unchanged: pieces split only the spatial dimensions.
1150  const index_t global_d = piece.d_start + local_coords.d;
1151  const index_t global_h = piece.h_start + local_coords.h;
1152  const index_t global_w = piece.w_start + local_coords.w;
1153 
1154  // Convert to global M index
1155  const index_t global_spatial_per_batch = split_info.total_spatial; // Pre-calculated
1156  const index_t global_spatial_flat = FlattenSpatial(
1157  global_d, global_h, global_w, split_info.total_h, split_info.total_w);
1158  const index_t global_m = global_n * global_spatial_per_batch + global_spatial_flat;  // NOTE(review): only the tile's starting row is translated to global M; confirm the tile's remaining rows are contiguous in the global layout.
1159 
1160  // Set tile indices for GEMM operation
1161  i_m = amd_wave_read_first_lane(global_m);
1162  i_n = amd_wave_read_first_lane(local_tile_n * TilePartitioner::NPerBlock);
1163  }
1164  else
1165  {
1166  // No spatial offsets needed for regular path
1167  a_ptr = base_a_ptr;
1168  c_ptr = base_c_ptr;
1169 
1170  // No split-image: use standard tile partitioning
1171  const auto [iM, iN] =
1172  TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(block_id);
1173  i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
1174  i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
1175  }
1176 
1177  // Use global descriptors for all cases
1178  const auto& a_desc = kargs.a_grid_desc_m_k;
1179  const auto& b_desc = kargs.b_grid_desc_n_k;
1180  const auto& c_desc = kargs.c_grid_desc_m_n;
1181 
1182  // allocate LDS (__shared__ scratch sized by GetSmemSize())
1183  __shared__ char smem_ptr_0[GetSmemSize()];
1184 
1185  if constexpr(GemmPipeline::DoubleSmemBuffer == true)
1186  {
1187  __shared__ char smem_ptr_1[GetSmemSize()];  // Second LDS buffer, allocated only for double-buffered pipelines.
1188  if constexpr(!(EpiloguePipeline::MemoryOperation ==  // NOTE(review): source lines 1189 and 1191 of this predicate are missing from this rendering; as shown, the GEMM is compiled out for some combination involving MemoryOperation and an odd VectorSizeC -- confirm against the header.
1190  GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
1192  {
1193  RunGemm2LDS(a_ptr,  // Double-LDS variant of the GEMM + epilogue.
1194  b_ptr,
1195  ds_ptr_with_offsets,
1196  c_ptr,
1197  smem_ptr_0,
1198  smem_ptr_1,
1199  a_desc,
1200  b_desc,
1201  c_desc,
1202  kargs.GemmK,
1203  i_m,
1204  i_n,
1205  kargs.elfunc);
1206  }
1207  }
1208  else
1209  {
1210  if constexpr(!(EpiloguePipeline::MemoryOperation ==  // NOTE(review): source lines 1211 and 1213 of this predicate are missing from this rendering; presumably the same guard as the double-LDS branch -- confirm.
1212  GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
1214  {
1215  RunGemm(a_ptr,  // Single-LDS variant of the GEMM + epilogue.
1216  b_ptr,
1217  ds_ptr_with_offsets,
1218  c_ptr,
1219  smem_ptr_0,
1220  a_desc,
1221  b_desc,
1222  c_desc,
1223  kargs.GemmK,
1224  i_m,
1225  i_n,
1226  kargs.elfunc);
1227  }
1228  }
1229  }
1230  }
1231 };
1232 
1233 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:49
#define CK_TILE_HOST
Definition: config.hpp:48
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:50
Definition: cluster_descriptor.hpp:13
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:156
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:145
void CK_TILE_ERROR(Args &&... args) noexcept
Definition: env.hpp:12
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:35
ConvolutionSpecialization
Definition: convolution_specialization.hpp:11
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:545
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
int64_t long_index_t
Definition: integer.hpp:11
CK_TILE_HOST std::string getConvSpecializationString(const ConvolutionSpecialization &s)
Definition: convolution_specialization.hpp:18
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:157
constexpr bool is_same_v
Definition: type.hpp:283
__device__ X atomic_add(X *p_dst, const X &x)
Definition: batched_gemm_kernel.hpp:62
Definition: grouped_convolution_forward_kernel.hpp:415
index_t w_size
Definition: grouped_convolution_forward_kernel.hpp:419
index_t h_start
Definition: grouped_convolution_forward_kernel.hpp:418
index_t w_start
Definition: grouped_convolution_forward_kernel.hpp:418
index_t d_size
Definition: grouped_convolution_forward_kernel.hpp:419
index_t h_size
Definition: grouped_convolution_forward_kernel.hpp:419
index_t block_start
Definition: grouped_convolution_forward_kernel.hpp:416
index_t block_end
Definition: grouped_convolution_forward_kernel.hpp:417
index_t d_start
Definition: grouped_convolution_forward_kernel.hpp:418
Definition: grouped_convolution_forward_kernel.hpp:407
index_t num_d_pieces
Definition: grouped_convolution_forward_kernel.hpp:411
index_t total_w
Definition: grouped_convolution_forward_kernel.hpp:409
index_t total_d
Definition: grouped_convolution_forward_kernel.hpp:409
std::array< PieceInfo, MaxPieces > pieces
Definition: grouped_convolution_forward_kernel.hpp:423
static constexpr index_t MaxPieces
Definition: grouped_convolution_forward_kernel.hpp:422
index_t total_spatial
Definition: grouped_convolution_forward_kernel.hpp:410
index_t num_w_pieces
Definition: grouped_convolution_forward_kernel.hpp:411
index_t total_h
Definition: grouped_convolution_forward_kernel.hpp:409
index_t num_h_pieces
Definition: grouped_convolution_forward_kernel.hpp:411
The Grouped Convolution kernel device arguments.
Definition: grouped_convolution_forward_kernel.hpp:30
long_index_t group_stride_c
Definition: grouped_convolution_forward_kernel.hpp:384
index_t input_batch_stride
Definition: grouped_convolution_forward_kernel.hpp:390
static constexpr index_t NonSpatialDims
Definition: grouped_convolution_forward_kernel.hpp:355
remove_cvref_t< decltype(ConvToGemmFwdTransformer{} .template MakeADescriptor_M_K< typename GroupedConvTraitsType_::InLayout >())> AGridDescMK
Definition: grouped_convolution_forward_kernel.hpp:347
index_t n_per_split
Definition: grouped_convolution_forward_kernel.hpp:388
const CDElementwise elfunc
Definition: grouped_convolution_forward_kernel.hpp:375
AGridDescMK a_grid_desc_m_k
Definition: grouped_convolution_forward_kernel.hpp:378
CGridDescMN CGridDescMN_t
Definition: grouped_convolution_forward_kernel.hpp:403
const void * in_ptr
Definition: grouped_convolution_forward_kernel.hpp:372
index_t GemmM
Definition: grouped_convolution_forward_kernel.hpp:366
remove_cvref_t< decltype(ConvToGemmFwdTransformer{} .template MakeCDescriptor_M_N< typename GroupedConvTraitsType_::OutLayout >())> CGridDescMN
Definition: grouped_convolution_forward_kernel.hpp:353
index_t original_n
Definition: grouped_convolution_forward_kernel.hpp:389
long_index_t group_stride_b
Definition: grouped_convolution_forward_kernel.hpp:383
CGridDescMN c_grid_desc_m_n
Definition: grouped_convolution_forward_kernel.hpp:380
CDElementwise_ CDElementwise
Definition: grouped_convolution_forward_kernel.hpp:39
index_t n_splits
Definition: grouped_convolution_forward_kernel.hpp:387
std::array< const void *, NumDTensor > ds_ptr
Definition: grouped_convolution_forward_kernel.hpp:374
array< index_t, GroupedConvTraitsType_::NDimSpatial > input_left_pads
Definition: grouped_convolution_forward_kernel.hpp:362
AGridDescMK AGridDescMK_t
Definition: grouped_convolution_forward_kernel.hpp:402
const void * wei_ptr
Definition: grouped_convolution_forward_kernel.hpp:373
BGridDescNK b_grid_desc_n_k
Definition: grouped_convolution_forward_kernel.hpp:379
index_t num_spatial_pieces
Definition: grouped_convolution_forward_kernel.hpp:426
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > out_g_n_k_wos_lengths
Definition: grouped_convolution_forward_kernel.hpp:358
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > wei_g_k_c_xs_lengths
Definition: grouped_convolution_forward_kernel.hpp:357
index_t GemmN
Definition: grouped_convolution_forward_kernel.hpp:367
index_t NumGroupsToMerge
Definition: grouped_convolution_forward_kernel.hpp:370
long_index_t spatial_offset_in
Definition: grouped_convolution_forward_kernel.hpp:394
SplitImageInfo split_image
Definition: grouped_convolution_forward_kernel.hpp:427
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs< CDElementwise > &args)
Definition: grouped_convolution_forward_kernel.hpp:54
array< index_t, GroupedConvTraitsType_::NDimSpatial > input_right_pads
Definition: grouped_convolution_forward_kernel.hpp:363
index_t output_batch_stride
Definition: grouped_convolution_forward_kernel.hpp:391
long_index_t group_stride_a
Definition: grouped_convolution_forward_kernel.hpp:382
index_t GemmK
Definition: grouped_convolution_forward_kernel.hpp:368
void * out_ptr
Definition: grouped_convolution_forward_kernel.hpp:376
ConvToGemmFwdTransformer transformer_
Definition: grouped_convolution_forward_kernel.hpp:398
index_t GemmBatch
Definition: grouped_convolution_forward_kernel.hpp:369
long_index_t spatial_offset_out
Definition: grouped_convolution_forward_kernel.hpp:395
TransformConvFwdToGemm< GroupedConvTraitsType_::NDimSpatial, GroupedConvTraitsType_::ConvSpecialization, GroupedConvTraitsType_::VectorSizeA, GroupedConvTraitsType_::VectorSizeB, GroupedConvTraitsType_::VectorSizeC, GroupedConvTraitsType_::NumGroupsToMerge, true > ConvToGemmFwdTransformer
Definition: grouped_convolution_forward_kernel.hpp:38
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > in_g_n_c_wis_lengths
Definition: grouped_convolution_forward_kernel.hpp:356
static constexpr index_t NumDTensor
Definition: grouped_convolution_forward_kernel.hpp:40
array< index_t, GroupedConvTraitsType_::NDimSpatial > conv_filter_dilations
Definition: grouped_convolution_forward_kernel.hpp:361
index_t k_batch
Definition: grouped_convolution_forward_kernel.hpp:365
remove_cvref_t< decltype(ConvToGemmFwdTransformer{} .template MakeBDescriptor_N_K< typename GroupedConvTraitsType_::WeiLayout >())> BGridDescNK
Definition: grouped_convolution_forward_kernel.hpp:350
array< index_t, GroupedConvTraitsType_::NDimSpatial > conv_filter_strides
Definition: grouped_convolution_forward_kernel.hpp:360
The Grouped Conv kernel host arguments.
Definition: grouped_convolution_utils.hpp:27
InPtr in_ptr
Definition: grouped_convolution_utils.hpp:46
OutPtr out_ptr
Definition: grouped_convolution_utils.hpp:49
WeiPtr wei_ptr
Definition: grouped_convolution_utils.hpp:47
index_t k_batch
Definition: grouped_convolution_utils.hpp:50
const std::vector< const void * > ds_ptr
Definition: grouped_convolution_utils.hpp:48
Definition: grouped_convolution_forward_kernel.hpp:524
index_t h
Definition: grouped_convolution_forward_kernel.hpp:525
index_t d
Definition: grouped_convolution_forward_kernel.hpp:525
index_t w
Definition: grouped_convolution_forward_kernel.hpp:525
The Grouped Convolution Forward kernel template.
Definition: grouped_convolution_forward_kernel.hpp:473
static CK_TILE_DEVICE index_t FindPieceId(index_t block_id, const SplitImageInfo &split_info, index_t num_pieces)
Definition: grouped_convolution_forward_kernel.hpp:570
remove_cvref_t< typename EpiloguePipeline::DsLayout > GemmDsLayout
Definition: grouped_convolution_forward_kernel.hpp:490
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition: grouped_convolution_forward_kernel.hpp:479
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: grouped_convolution_forward_kernel.hpp:478
typename EpiloguePipeline::CDElementwise CDElementwise
Definition: grouped_convolution_forward_kernel.hpp:501
static constexpr auto I1
Definition: grouped_convolution_forward_kernel.hpp:509
static constexpr auto I2
Definition: grouped_convolution_forward_kernel.hpp:510
static CK_TILE_DEVICE index_t FlattenSpatial(index_t d, index_t h, index_t w, index_t total_h, index_t total_w)
Definition: grouped_convolution_forward_kernel.hpp:551
remove_cvref_t< typename GroupedConvTraitsType_::OutLayout > OutLayout
Definition: grouped_convolution_forward_kernel.hpp:487
GroupedConvFwdKernelArgs< GroupedConvTraitsType_, CDElementwise > GroupedConvFwdKernelArgsSpecialized
Definition: grouped_convolution_forward_kernel.hpp:504
static CK_TILE_DEVICE auto MakeGemmTensorViews(const InDataType *a_ptr, const WeiDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, OutDataType *c_ptr, const ADescType &a_desc, const BDescType &b_desc, const CDescType &c_desc)
Definition: grouped_convolution_forward_kernel.hpp:802
CK_TILE_DEVICE void operator()(GroupedConvFwdKernelArgsSpecialized &kargs) const
Definition: grouped_convolution_forward_kernel.hpp:1062
static CK_TILE_HOST const std::string GetTypeString()
Definition: grouped_convolution_forward_kernel.hpp:617
static constexpr auto I0
Definition: grouped_convolution_forward_kernel.hpp:508
static constexpr bool EnableSplitImage
Definition: grouped_convolution_forward_kernel.hpp:474
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: grouped_convolution_forward_kernel.hpp:650
remove_cvref_t< typename GroupedConvTraitsType_::WeiLayout > WeiLayout
Definition: grouped_convolution_forward_kernel.hpp:486
remove_cvref_t< typename EpiloguePipeline::ODataType > OutDataType
Definition: grouped_convolution_forward_kernel.hpp:499
static CK_TILE_DEVICE void RunGemm2LDS(const InDataType *a_ptr, const WeiDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, OutDataType *c_ptr, void *__restrict__ smem_ptr_0, void *__restrict__ smem_ptr_1, const ADescType &a_desc, const BDescType &b_desc, const CDescType &c_desc, const index_t gemm_k, const index_t block_idx_m, const index_t block_idx_n, const CDElementwise &elfunc)
Runs a single GEMM problem cooperatively across the whole workgroup.
Definition: grouped_convolution_forward_kernel.hpp:998
remove_cvref_t< typename GroupedConvTraitsType_::DsLayout > DsLayout
Definition: grouped_convolution_forward_kernel.hpp:488
static constexpr index_t kBlockSize
Definition: grouped_convolution_forward_kernel.hpp:493
remove_cvref_t< typename EpiloguePipeline::DsDataType > DsDataType
Definition: grouped_convolution_forward_kernel.hpp:497
CK_TILE_DEVICE void CallExplicitGemm(GroupedConvFwdKernelArgsSpecialized &kargs) const
Definition: grouped_convolution_forward_kernel.hpp:1037
remove_cvref_t< typename GemmPipeline::BLayout > GemmBLayout
Definition: grouped_convolution_forward_kernel.hpp:482
static constexpr index_t NDimSpatial
Definition: grouped_convolution_forward_kernel.hpp:475
static CK_TILE_HOST auto BlockSize()
Definition: grouped_convolution_forward_kernel.hpp:638
static constexpr auto I3
Definition: grouped_convolution_forward_kernel.hpp:511
static CK_TILE_HOST const std::string GetName()
Definition: grouped_convolution_forward_kernel.hpp:593
static CK_TILE_HOST bool IsSupportedArgument(const GroupedConvFwdKernelArgsSpecialized &kargs)
Definition: grouped_convolution_forward_kernel.hpp:655
remove_cvref_t< typename GemmPipeline::BDataType > WeiDataType
Definition: grouped_convolution_forward_kernel.hpp:496
static constexpr CK_TILE_HOST GroupedConvFwdKernelArgsSpecialized MakeKernelArgs(const GroupedConvFwdHostArgs< CDElementwise > &hostArgs)
Definition: grouped_convolution_forward_kernel.hpp:644
static constexpr index_t NumDTensor
Definition: grouped_convolution_forward_kernel.hpp:491
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n)
Definition: grouped_convolution_forward_kernel.hpp:884
remove_cvref_t< typename GemmPipeline::ALayout > GemmALayout
Definition: grouped_convolution_forward_kernel.hpp:481
static constexpr bool IsSplitKSupported
Definition: grouped_convolution_forward_kernel.hpp:506
remove_cvref_t< typename GemmPipeline::CLayout > GemmCLayout
Definition: grouped_convolution_forward_kernel.hpp:483
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views)
Definition: grouped_convolution_forward_kernel.hpp:843
static CK_TILE_DEVICE void RunGemm(const InDataType *a_ptr, const WeiDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, OutDataType *c_ptr, void *smem_ptr_0, const ADescType &a_desc, const BDescType &b_desc, const CDescType &c_desc, const index_t gemm_k, const index_t block_idx_m, const index_t block_idx_n, const CDElementwise &elfunc)
Runs a single GEMM problem cooperatively across the whole workgroup.
Definition: grouped_convolution_forward_kernel.hpp:939
remove_cvref_t< typename GroupedConvTraitsType_::InLayout > InLayout
Definition: grouped_convolution_forward_kernel.hpp:485
static constexpr ConvolutionSpecialization ConvSpecialization
Definition: grouped_convolution_forward_kernel.hpp:476
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: grouped_convolution_forward_kernel.hpp:480
static CK_TILE_DEVICE SpatialCoords UnflattenSpatial(index_t flat, index_t h_size, index_t w_size)
Definition: grouped_convolution_forward_kernel.hpp:530
static CK_TILE_HOST auto GridSize(const GroupedConvFwdKernelArgsSpecialized &kargs)
Definition: grouped_convolution_forward_kernel.hpp:632
remove_cvref_t< typename GemmPipeline::ADataType > InDataType
Definition: grouped_convolution_forward_kernel.hpp:495
constexpr CK_TILE_HOST IndexType GetOriginalN() const
Definition: transform_conv_fwd_to_gemm.hpp:264
constexpr CK_TILE_HOST IndexType GetN() const
Definition: transform_conv_fwd_to_gemm.hpp:263
Definition: integral_constant.hpp:13
std::vector< ck_tile::long_index_t > input_spatial_lengths_
Definition: convolution_parameter.hpp:130
ck_tile::long_index_t K_
Definition: convolution_parameter.hpp:126
std::vector< ck_tile::long_index_t > output_spatial_lengths_
Definition: convolution_parameter.hpp:131
std::vector< ck_tile::long_index_t > input_right_pads_
Definition: convolution_parameter.hpp:137
ck_tile::long_index_t G_
Definition: convolution_parameter.hpp:124
std::vector< ck_tile::long_index_t > conv_filter_strides_
Definition: convolution_parameter.hpp:133
std::vector< ck_tile::long_index_t > filter_spatial_lengths_
Definition: convolution_parameter.hpp:129
ck_tile::long_index_t C_
Definition: convolution_parameter.hpp:127
ck_tile::long_index_t N_
Definition: convolution_parameter.hpp:125
std::vector< ck_tile::long_index_t > input_left_pads_
Definition: convolution_parameter.hpp:136
std::vector< ck_tile::long_index_t > conv_filter_dilations_
Definition: convolution_parameter.hpp:134
Definition: type_traits.hpp:115
Definition: sequence.hpp:49
Definition: functional.hpp:43
#define CK_TILE_ENV(name)
Definition: env.hpp:145