Composable Kernel: include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp Source File

grouped_convolution_forward_kernel.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <array>
#include <functional>
#include <iostream>
#include <numeric>
#include <string>

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"

namespace ck_tile {

/// @brief The Grouped Convolution kernel device arguments.
template <typename GroupedConvTraitsType_, typename CDElementwise_>
struct GroupedConvFwdKernelArgs
{
    using ConvToGemmFwdTransformer =
        TransformConvFwdToGemm<GroupedConvTraitsType_::NDimSpatial,
                               GroupedConvTraitsType_::ConvSpecialization,
                               GroupedConvTraitsType_::VectorSizeA,
                               GroupedConvTraitsType_::VectorSizeB,
                               GroupedConvTraitsType_::VectorSizeC,
                               GroupedConvTraitsType_::NumGroupsToMerge,
                               true>; // Split N enabled
    using CDElementwise = CDElementwise_;
    static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;

    template <
        typename InLay  = typename GroupedConvTraitsType_::InLayout,
        typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
        typename OutLay = typename GroupedConvTraitsType_::OutLayout,
        typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NWGC> &&
                                    std::is_same_v<WeiLay, tensor_layout::convolution::GKXC> &&
                                    std::is_same_v<OutLay, tensor_layout::convolution::NWGK>,
                                bool>::type = false>
    CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs<CDElementwise>& args)
        : elfunc(args.elfunc)
    {
        in_g_n_c_wis_lengths  = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.N_),
                                 static_cast<index_t>(args.C_),
                                 static_cast<index_t>(args.input_spatial_lengths_[0])};
        wei_g_k_c_xs_lengths  = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.K_),
                                 static_cast<index_t>(args.C_),
                                 static_cast<index_t>(args.filter_spatial_lengths_[0])};
        out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.N_),
                                 static_cast<index_t>(args.K_),
                                 static_cast<index_t>(args.output_spatial_lengths_[0])};

        conv_filter_strides   = {static_cast<index_t>(args.conv_filter_strides_[0])};
        conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0])};
        input_left_pads       = {static_cast<index_t>(args.input_left_pads_[0])};
        input_right_pads      = {static_cast<index_t>(args.input_right_pads_[0])};

        k_batch = args.k_batch;

        // Note: GemmM will be set after the Split-N calculation
        GemmN     = args.K_;
        GemmK     = args.C_ * args.filter_spatial_lengths_[0];
        GemmBatch = args.G_;

        in_ptr  = args.in_ptr;
        wei_ptr = args.wei_ptr;
        for(index_t d = 0; d < NumDTensor; d++)
        {
            ds_ptr[d] = args.ds_ptr[d];
        }
        out_ptr = args.out_ptr;

        // Create and store the transformer (for split-image support)
        transformer_ = ConvToGemmFwdTransformer{in_g_n_c_wis_lengths,
                                                wei_g_k_c_xs_lengths,
                                                out_g_n_k_wos_lengths,
                                                conv_filter_strides,
                                                conv_filter_dilations,
                                                input_left_pads,
                                                input_right_pads};

        a_grid_desc_m_k =
            transformer_.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
        b_grid_desc_n_k =
            transformer_.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
        c_grid_desc_m_n =
            transformer_.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();

        group_stride_a = args.C_;
        group_stride_b = args.K_ * args.C_ *
                         std::accumulate(args.filter_spatial_lengths_.begin(),
                                         args.filter_spatial_lengths_.end(),
                                         1,
                                         std::multiplies<index_t>());
        group_stride_c = args.K_;

        // Initialize Split-N support fields for 1D convolution (NWGC layout)
        // Get the actual split N from the transformer
        original_n  = transformer_.GetOriginalN();
        n_per_split = transformer_.GetN();
        n_splits    = integer_divide_ceil(original_n, n_per_split);

        // Calculate batch strides using the original argument dimensions. These are the
        // dimensions passed to the constructor, not yet modified by the invoker (the invoker
        // modifies args after calling MakeKernelArgs). G_ must be included: the NWGC layout
        // stores all groups within each batch.
        input_batch_stride  = args.G_ * args.C_ * args.input_spatial_lengths_[0];
        output_batch_stride = args.G_ * args.K_ * args.output_spatial_lengths_[0];

        // Update GemmM to use the split N (not the original N)
        GemmM = n_per_split * args.output_spatial_lengths_[0];
    }
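    // Worked example of the implicit-GEMM mapping above (hypothetical sizes): G=2, N=128,
    // C=64, K=256, X=3, Wo=100 gives GemmN = K = 256, GemmK = C * X = 192, GemmBatch = G = 2;
    // with no Split-N (n_per_split = N = 128) the final GemmM = 128 * 100 = 12800.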

    template <
        typename InLay  = typename GroupedConvTraitsType_::InLayout,
        typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
        typename OutLay = typename GroupedConvTraitsType_::OutLayout,
        typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NHWGC> &&
                                    std::is_same_v<WeiLay, tensor_layout::convolution::GKYXC> &&
                                    std::is_same_v<OutLay, tensor_layout::convolution::NHWGK>,
                                bool>::type = false>
    CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs<CDElementwise>& args)
        : elfunc(args.elfunc)
    {
        in_g_n_c_wis_lengths  = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.N_),
                                 static_cast<index_t>(args.C_),
                                 static_cast<index_t>(args.input_spatial_lengths_[0]),
                                 static_cast<index_t>(args.input_spatial_lengths_[1])};
        wei_g_k_c_xs_lengths  = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.K_),
                                 static_cast<index_t>(args.C_),
                                 static_cast<index_t>(args.filter_spatial_lengths_[0]),
                                 static_cast<index_t>(args.filter_spatial_lengths_[1])};
        out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.N_),
                                 static_cast<index_t>(args.K_),
                                 static_cast<index_t>(args.output_spatial_lengths_[0]),
                                 static_cast<index_t>(args.output_spatial_lengths_[1])};

        conv_filter_strides   = {static_cast<index_t>(args.conv_filter_strides_[0]),
                                 static_cast<index_t>(args.conv_filter_strides_[1])};
        conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
                                 static_cast<index_t>(args.conv_filter_dilations_[1])};
        input_left_pads       = {static_cast<index_t>(args.input_left_pads_[0]),
                                 static_cast<index_t>(args.input_left_pads_[1])};
        input_right_pads      = {static_cast<index_t>(args.input_right_pads_[0]),
                                 static_cast<index_t>(args.input_right_pads_[1])};

        k_batch = args.k_batch;

        // Note: GemmM will be set after the Split-N calculation
        GemmN     = args.K_;
        GemmK     = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1];
        GemmBatch = args.G_;

        in_ptr  = args.in_ptr;
        wei_ptr = args.wei_ptr;
        for(index_t d = 0; d < NumDTensor; d++)
        {
            ds_ptr[d] = args.ds_ptr[d];
        }
        out_ptr = args.out_ptr;

        // Create and store the transformer (for split-image support)
        transformer_ = ConvToGemmFwdTransformer{in_g_n_c_wis_lengths,
                                                wei_g_k_c_xs_lengths,
                                                out_g_n_k_wos_lengths,
                                                conv_filter_strides,
                                                conv_filter_dilations,
                                                input_left_pads,
                                                input_right_pads};

        a_grid_desc_m_k =
            transformer_.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
        b_grid_desc_n_k =
            transformer_.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
        c_grid_desc_m_n =
            transformer_.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();

        group_stride_a = args.C_;
        group_stride_b = args.K_ * args.C_ *
                         std::accumulate(args.filter_spatial_lengths_.begin(),
                                         args.filter_spatial_lengths_.end(),
                                         1,
                                         std::multiplies<index_t>());
        group_stride_c = args.K_;

        // Initialize Split-N support fields for 2D convolution (NHWGC layout)
        // Get the actual split N from the transformer
        original_n  = transformer_.GetOriginalN();
        n_per_split = transformer_.GetN();
        n_splits    = integer_divide_ceil(original_n, n_per_split);

        // Calculate batch strides for the NHWGC layout. G_ must be included: the layout
        // stores all groups within each batch.
        input_batch_stride =
            args.G_ * args.C_ * args.input_spatial_lengths_[0] * args.input_spatial_lengths_[1];
        output_batch_stride =
            args.G_ * args.K_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];

        // Update GemmM to use the split N (not the original N)
        GemmM = n_per_split * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
    }

    template <
        typename InLay  = typename GroupedConvTraitsType_::InLayout,
        typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
        typename OutLay = typename GroupedConvTraitsType_::OutLayout,
        typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NDHWGC> &&
                                    std::is_same_v<WeiLay, tensor_layout::convolution::GKZYXC> &&
                                    std::is_same_v<OutLay, tensor_layout::convolution::NDHWGK>,
                                bool>::type = false>
    CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs<CDElementwise>& args)
        : elfunc(args.elfunc)
    {
        in_g_n_c_wis_lengths  = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.N_),
                                 static_cast<index_t>(args.C_),
                                 static_cast<index_t>(args.input_spatial_lengths_[0]),
                                 static_cast<index_t>(args.input_spatial_lengths_[1]),
                                 static_cast<index_t>(args.input_spatial_lengths_[2])};
        wei_g_k_c_xs_lengths  = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.K_),
                                 static_cast<index_t>(args.C_),
                                 static_cast<index_t>(args.filter_spatial_lengths_[0]),
                                 static_cast<index_t>(args.filter_spatial_lengths_[1]),
                                 static_cast<index_t>(args.filter_spatial_lengths_[2])};
        out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.N_),
                                 static_cast<index_t>(args.K_),
                                 static_cast<index_t>(args.output_spatial_lengths_[0]),
                                 static_cast<index_t>(args.output_spatial_lengths_[1]),
                                 static_cast<index_t>(args.output_spatial_lengths_[2])};

        conv_filter_strides   = {static_cast<index_t>(args.conv_filter_strides_[0]),
                                 static_cast<index_t>(args.conv_filter_strides_[1]),
                                 static_cast<index_t>(args.conv_filter_strides_[2])};
        conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
                                 static_cast<index_t>(args.conv_filter_dilations_[1]),
                                 static_cast<index_t>(args.conv_filter_dilations_[2])};
        input_left_pads       = {static_cast<index_t>(args.input_left_pads_[0]),
                                 static_cast<index_t>(args.input_left_pads_[1]),
                                 static_cast<index_t>(args.input_left_pads_[2])};
        input_right_pads      = {static_cast<index_t>(args.input_right_pads_[0]),
                                 static_cast<index_t>(args.input_right_pads_[1]),
                                 static_cast<index_t>(args.input_right_pads_[2])};

        k_batch = args.k_batch;

        // Note: GemmM will be set after the Split-N calculation
        GemmN     = args.K_;
        GemmK     = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1] *
                    args.filter_spatial_lengths_[2];
        GemmBatch = args.G_;

        in_ptr  = args.in_ptr;
        wei_ptr = args.wei_ptr;
        for(index_t d = 0; d < NumDTensor; d++)
        {
            ds_ptr[d] = args.ds_ptr[d];
        }
        out_ptr = args.out_ptr;

        // Create and store the transformer (for split-image support)
        transformer_ = ConvToGemmFwdTransformer{in_g_n_c_wis_lengths,
                                                wei_g_k_c_xs_lengths,
                                                out_g_n_k_wos_lengths,
                                                conv_filter_strides,
                                                conv_filter_dilations,
                                                input_left_pads,
                                                input_right_pads};

        a_grid_desc_m_k =
            transformer_.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
        b_grid_desc_n_k =
            transformer_.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
        c_grid_desc_m_n =
            transformer_.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();

        group_stride_a = args.C_;
        group_stride_b = args.K_ * args.C_ *
                         std::accumulate(args.filter_spatial_lengths_.begin(),
                                         args.filter_spatial_lengths_.end(),
                                         1,
                                         std::multiplies<index_t>());
        group_stride_c = args.K_;

        // Initialize Split-N support fields for 3D convolution (NDHWGC layout)
        // Get the actual split N from the transformer
        original_n  = transformer_.GetOriginalN();
        n_per_split = transformer_.GetN();
        n_splits    = integer_divide_ceil(original_n, n_per_split);

        // Calculate batch strides for the NDHWGC layout. G_ must be included: the layout
        // stores all groups within each batch.
        input_batch_stride  = args.G_ * args.C_ * args.input_spatial_lengths_[0] *
                              args.input_spatial_lengths_[1] * args.input_spatial_lengths_[2];
        output_batch_stride = args.G_ * args.K_ * args.output_spatial_lengths_[0] *
                              args.output_spatial_lengths_[1] * args.output_spatial_lengths_[2];

        // Update GemmM to use the split N (not the original N)
        GemmM = n_per_split * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1] *
                args.output_spatial_lengths_[2];
    }

    using AGridDescMK = remove_cvref_t<
        decltype(ConvToGemmFwdTransformer{}
                     .template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>())>;
    using BGridDescNK = remove_cvref_t<
        decltype(ConvToGemmFwdTransformer{}
                     .template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>())>;
    using CGridDescMN = remove_cvref_t<
        decltype(ConvToGemmFwdTransformer{}
                     .template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>())>;

    static constexpr index_t NonSpatialDims = 3;
    array<index_t, NonSpatialDims + GroupedConvTraitsType_::NDimSpatial> in_g_n_c_wis_lengths;
    array<index_t, NonSpatialDims + GroupedConvTraitsType_::NDimSpatial> wei_g_k_c_xs_lengths;
    array<index_t, NonSpatialDims + GroupedConvTraitsType_::NDimSpatial> out_g_n_k_wos_lengths;

    array<index_t, GroupedConvTraitsType_::NDimSpatial> conv_filter_strides;
    array<index_t, GroupedConvTraitsType_::NDimSpatial> conv_filter_dilations;
    array<index_t, GroupedConvTraitsType_::NDimSpatial> input_left_pads;
    array<index_t, GroupedConvTraitsType_::NDimSpatial> input_right_pads;

    index_t k_batch;
    index_t GemmM;
    index_t GemmN;
    index_t GemmK;
    index_t GemmBatch;

    const void* in_ptr;
    const void* wei_ptr;
    std::array<const void*, NumDTensor> ds_ptr;
    const CDElementwise elfunc;
    void* out_ptr;

    AGridDescMK a_grid_desc_m_k;
    BGridDescNK b_grid_desc_n_k;
    CGridDescMN c_grid_desc_m_n;

    long_index_t group_stride_a;
    long_index_t group_stride_b;
    long_index_t group_stride_c;

    // Split-N support fields - initialized to safe defaults
    index_t n_splits            = 1; // Number of batch splits (e.g., 2 for 128→64×2)
    index_t n_per_split         = 1; // Batches per split (N_ from the transformer)
    index_t original_n          = 1; // Original batch size before splitting
    index_t input_batch_stride  = 0; // Stride to the next batch in the input tensor
    index_t output_batch_stride = 0; // Stride to the next batch in the output tensor

    // Split-image support - spatial offsets (applied per batch in operator())
    long_index_t spatial_offset_in  = 0; // Spatial offset for the input (e.g., W/2 for a 1D split)
    long_index_t spatial_offset_out = 0; // Spatial offset for the output (e.g., W/2 for a 1D split)

    // Split-image support - transformer instance
    ConvToGemmFwdTransformer transformer_;

    // Convenience aliases for the descriptor types defined above
    using AGridDescMK_t = AGridDescMK;
    using CGridDescMN_t = CGridDescMN;

    // Split-image support: common data for all pieces
    struct SplitImageInfo
    {
        // Common dimensions (same for all pieces)
        index_t total_d = 1, total_h = 1, total_w = 1; // Total tensor dimensions
        index_t total_spatial = 1;                     // Pre-calculated: total_d * total_h * total_w
        index_t num_d_pieces = 1, num_h_pieces = 1, num_w_pieces = 1; // Split factors

        // Minimal per-piece data (only unique values)
        struct PieceInfo
        {
            index_t block_start;               // Starting block index for this piece
            index_t block_end;                 // Ending block index (exclusive)
            index_t d_start, h_start, w_start; // Piece starting position in OUTPUT space
            index_t d_size, h_size, w_size;    // Piece size in OUTPUT space
        };

        static constexpr index_t MaxPieces = 64; // Max pieces: 4 (1D), 16 (2D), 64 (3D)
        std::array<PieceInfo, MaxPieces> pieces; // Array of minimal piece descriptors
    };
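    // Worked example (hypothetical numbers): a 2D output with total_h = total_w = 1024 split
    // into num_h_pieces = num_w_pieces = 2 gives total_spatial = 1048576 and four PieceInfo
    // entries; piece 0 might cover (h_start, w_start) = (0, 0) with h_size = w_size = 512 and
    // own the contiguous workgroup range [block_start, block_end).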

    index_t num_spatial_pieces = 1; // Number of spatial pieces (1 = no split)
    SplitImageInfo split_image;     // Nested structure with common + per-piece data
};

/// @brief The Grouped Convolution Forward kernel template.
template <typename GroupedConvTraitsType_,
          typename TilePartitioner_,
          typename GemmPipeline_,
          typename EpiloguePipeline_>
struct GroupedConvFwdKernel
{
    static constexpr bool EnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
    static constexpr index_t NDimSpatial   = GroupedConvTraitsType_::NDimSpatial;
    static constexpr ConvolutionSpecialization ConvSpecialization =
        GroupedConvTraitsType_::ConvSpecialization;
    using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
    using GemmPipeline     = remove_cvref_t<GemmPipeline_>;
    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
    using GemmALayout      = remove_cvref_t<typename GemmPipeline::ALayout>;
    using GemmBLayout      = remove_cvref_t<typename GemmPipeline::BLayout>;
    using GemmCLayout      = remove_cvref_t<typename GemmPipeline::CLayout>;

    using InLayout  = remove_cvref_t<typename GroupedConvTraitsType_::InLayout>;
    using WeiLayout = remove_cvref_t<typename GroupedConvTraitsType_::WeiLayout>;
    using OutLayout = remove_cvref_t<typename GroupedConvTraitsType_::OutLayout>;
    using DsLayout  = remove_cvref_t<typename GroupedConvTraitsType_::DsLayout>;

    using GemmDsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
    static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;

    static constexpr index_t kBlockSize = GemmPipeline::BlockSize;

    using InDataType  = remove_cvref_t<typename GemmPipeline::ADataType>;
    using WeiDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
    using DsDataType  = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
    // The type below is actually the accumulation data type - the output of the block GEMM.
    using OutDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;

    using CDElementwise = typename EpiloguePipeline::CDElementwise;

    using GroupedConvFwdKernelArgsSpecialized =
        GroupedConvFwdKernelArgs<GroupedConvTraitsType_, CDElementwise>;

    static constexpr bool IsSplitKSupported = false;

    static constexpr auto I0 = number<0>();
    static constexpr auto I1 = number<1>();
    static constexpr auto I2 = number<2>();
    static constexpr auto I3 = number<3>();

    static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
                  "Not supported!");
    static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>, "Not supported!");
    static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
    static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>, "Not supported!");

    // Helper struct for spatial coordinates
    struct SpatialCoords
    {
        index_t d, h, w;
    };

    // Helper: Convert a flat spatial index to (d,h,w) coordinates
    CK_TILE_DEVICE static SpatialCoords
    UnflattenSpatial(index_t flat, index_t h_size, index_t w_size)
    {
        if constexpr(NDimSpatial == 1)
        {
            return SpatialCoords{0, 0, flat};
        }
        else if constexpr(NDimSpatial == 2)
        {
            return SpatialCoords{0, flat / w_size, flat % w_size};
        }
        else // NDimSpatial == 3
        {
            const index_t hw        = h_size * w_size;
            const index_t d         = flat / hw;
            const index_t remainder = flat % hw;
            return SpatialCoords{d, remainder / w_size, remainder % w_size};
        }
    }

    // Helper: Convert (d,h,w) to a flat spatial index
    CK_TILE_DEVICE static index_t
    FlattenSpatial(index_t d, index_t h, index_t w, index_t total_h, index_t total_w)
    {
        if constexpr(NDimSpatial == 1)
        {
            return w;
        }
        else if constexpr(NDimSpatial == 2)
        {
            return h * total_w + w;
        }
        else // NDimSpatial == 3
        {
            return (d * total_h + h) * total_w + w;
        }
    }
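    // Worked example: with NDimSpatial == 3, total_h = 4, total_w = 8, the coordinate
    // (d, h, w) = (2, 1, 3) flattens to (2 * 4 + 1) * 8 + 3 = 75, and UnflattenSpatial(75, 4, 8)
    // recovers d = 75 / 32 = 2, h = (75 % 32) / 8 = 1, w = (75 % 32) % 8 = 3. The two helpers
    // are exact inverses for in-range coordinates.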

    // Helper: Find which piece owns a block using binary search
    template <typename SplitImageInfo>
    CK_TILE_DEVICE static index_t
    FindPieceId(index_t block_id, const SplitImageInfo& split_info, index_t num_pieces)
    {
        index_t left     = 0;
        index_t right    = num_pieces - 1;
        index_t piece_id = (left + right) / 2;

        while(!(block_id >= split_info.pieces[piece_id].block_start &&
                block_id < split_info.pieces[piece_id].block_end) &&
              left <= right)
        {
            if(block_id < split_info.pieces[piece_id].block_start)
            {
                right = piece_id - 1;
            }
            else
            {
                left = piece_id + 1;
            }
            piece_id = (left + right) / 2;
        }
        return piece_id;
    }
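    // Worked example: with pieces owning block ranges [0,10), [10,25), [25,40), the call
    // FindPieceId(12, split_info, 3) starts at piece_id = 1, sees 10 <= 12 < 25, and returns 1
    // without iterating; block 30 first probes piece 1, then moves left to 2 and returns
    // piece 2.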

    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
    {
        // clang-format off
        return concat('_', "grouped_convolution_forward",
                      gemm_prec_str<InDataType, WeiDataType>(),
                      "gemm",
                      GemmPipeline::GetName(),
                      "epilogue",
                      EpiloguePipeline::GetName());
        // clang-format on
    }

    CK_TILE_HOST static auto GridSize(const GroupedConvFwdKernelArgsSpecialized& kargs)
    {
        return dim3(
            TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.n_splits);
    }

    CK_TILE_HOST static auto BlockSize()
    {
        return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
    }

    CK_TILE_HOST static constexpr GroupedConvFwdKernelArgsSpecialized
    MakeKernelArgs(const GroupedConvFwdHostArgs<CDElementwise>& hostArgs)
    {
        auto kargs = GroupedConvFwdKernelArgsSpecialized(hostArgs);
        return kargs;
    }

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }

    CK_TILE_HOST static bool IsSupportedArgument(const GroupedConvFwdKernelArgsSpecialized& kargs)
    {
        if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
                      std::is_same_v<OutDataType, bf16_t>) ||
                     !IsSplitKSupported)
        {
            if(kargs.k_batch != 1)
            {
                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
                {
                    CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
                }
                return false;
            }
        }

        const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
        const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];

        // check ConvolutionSpecialization
        if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0)
        {
            // check if it's a 1x1, stride=1 conv
            for(index_t i = 0; i < NDimSpatial; ++i)
            {
                const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
                const index_t ConvStride = kargs.conv_filter_strides[i];
                const index_t LeftPad    = kargs.input_left_pads[i];
                const index_t RightPad   = kargs.input_right_pads[i];

                if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
                {
                    return false;
                }
            }
        }
        else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Pad0)
        {
            // check if it's a 1x1 conv
            for(index_t i = 0; i < NDimSpatial; ++i)
            {
                const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
                const index_t LeftPad    = kargs.input_left_pads[i];
                const index_t RightPad   = kargs.input_right_pads[i];

                if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0))
                {
                    return false;
                }
            }
        }
        else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3)
        {
            if(ConvC != 1)
            {
                return false;
            }
            for(index_t i = 0; i < NDimSpatial; ++i)
            {
                const index_t filter_spatial_dim = kargs.wei_g_k_c_xs_lengths[i + I3];

                if(filter_spatial_dim != I3)
                {
                    return false;
                }
            }
        }

        namespace ctc = tensor_layout::convolution;

        if constexpr(std::is_same_v<InLayout, ctc::NWGC> || std::is_same_v<InLayout, ctc::NHWGC> ||
                     std::is_same_v<InLayout, ctc::NDHWGC>)
        {
            // Check access per C
            if(ConvC % GroupedConvTraitsType_::VectorSizeA != 0)
            {
                CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!");
                return false;
            }
        }
        else
        {
            CK_TILE_ERROR("Not supported input layout!");
            return false;
        }

        // check vector access of B
        // FIXME: layout
        if constexpr(std::is_same_v<WeiLayout, ctc::GKXC> ||
                     std::is_same_v<WeiLayout, ctc::GKYXC> ||
                     std::is_same_v<WeiLayout, ctc::GKZYXC>)
        {
            if(ConvC % GroupedConvTraitsType_::VectorSizeB != 0)
            {
                CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!");
                return false;
            }
        }
        else
        {
            CK_TILE_ERROR("Not supported weight layout!");
            return false;
        }

        // check vector access of E
        if constexpr(std::is_same_v<OutLayout, ctc::NWGK> ||
                     std::is_same_v<OutLayout, ctc::NHWGK> ||
                     std::is_same_v<OutLayout, ctc::NDHWGK>)
        {
            if(ConvK % GroupedConvTraitsType_::VectorSizeC != 0)
            {
                CK_TILE_ERROR("Conv K is not a multiple of vector store size for output image!");
                return false;
            }
        }
        else
        {
            CK_TILE_ERROR("Not supported output layout!");
            return false;
        }

        return true;
    }

    template <memory_operation_enum DstInMemOp = memory_operation_enum::set,
              typename ADescType,
              typename BDescType,
              typename CDescType>
    CK_TILE_DEVICE static auto
    MakeGemmTensorViews(const InDataType* a_ptr,
                        const WeiDataType* b_ptr,
                        const std::array<const void*, NumDTensor>& ds_ptr,
                        OutDataType* c_ptr,
                        const ADescType& a_desc,
                        const BDescType& b_desc,
                        const CDescType& c_desc)
    {
        static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
        static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!");
        const auto& a_tensor_view = [&]() {
            return make_tensor_view<address_space_enum::global>(a_ptr, a_desc);
        }();

        const auto& b_tensor_view = [&]() {
            return make_tensor_view<address_space_enum::global>(b_ptr, b_desc);
        }();

        // TODO: enable vector write for C in ColMajor
        const auto& c_tensor_view = [&]() {
            return make_tensor_view<address_space_enum::global>(c_ptr, c_desc);
        }();

        const auto& ds_tensor_view = generate_tuple(
            [&](auto i) {
                static_assert(std::is_same_v<std::tuple_element_t<i, DsLayout>, OutLayout>,
                              "Not supported!");
                static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
                              "Not supported!");
                static_assert(std::is_same_v<std::tuple_element_t<i, DsDataType>, OutDataType>,
                              "Not supported!");

                return make_tensor_view<address_space_enum::global>(
                    static_cast<const OutDataType*>(ds_ptr[i]), c_desc);
            },
            number<NumDTensor>{});

        return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view);
    }

    template <typename TensorView>
    CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
    {
        const auto& a_pad_view = [&]() {
            const auto& a_tensor_view = views.at(I0);
            return pad_tensor_view(a_tensor_view,
                                   make_tuple(number<TilePartitioner::MPerBlock>{},
                                              number<TilePartitioner::KPerBlock>{}),
                                   sequence<GemmPipeline::kPadM, GemmPipeline::kPadK>{});
        }();

        const auto& b_pad_view = [&]() {
            const auto& b_tensor_view = views.at(I1);
            return pad_tensor_view(b_tensor_view,
                                   make_tuple(number<TilePartitioner::NPerBlock>{},
                                              number<TilePartitioner::KPerBlock>{}),
                                   sequence<GemmPipeline::kPadN, GemmPipeline::kPadK>{});
        }();

        const auto& ds_tensor_view = views.at(I2);
        const auto& ds_pad_view    = generate_tuple(
            [&](auto i) {
                return pad_tensor_view(ds_tensor_view[i],
                                       make_tuple(number<TilePartitioner::MPerBlock>{},
                                                  number<TilePartitioner::NPerBlock>{}),
                                       sequence<GemmPipeline::kPadM, GemmPipeline::kPadN>{});
            },
            number<NumDTensor>{});

        const auto& c_pad_view = [&]() {
            const auto& c_tensor_view = views.at(I3);
            return pad_tensor_view(c_tensor_view,
                                   make_tuple(number<TilePartitioner::MPerBlock>{},
                                              number<TilePartitioner::NPerBlock>{}),
                                   sequence<GemmPipeline::kPadM, GemmPipeline::kPadN>{});
        }();

        return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view);
    }

    template <typename PadView>
    CK_TILE_DEVICE static auto
    MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
    {
        const auto& a_pad_view  = views.at(I0);
        const auto& b_pad_view  = views.at(I1);
        const auto& ds_pad_view = views.at(I2);
        const auto& c_pad_view  = views.at(I3);

        const auto& a_block_window = [&]() {
            return make_tile_window(a_pad_view,
                                    make_tuple(number<TilePartitioner::MPerBlock>{},
                                               number<TilePartitioner::KPerBlock>{}),
                                    {i_m, 0});
        }();

        const auto& b_block_window = [&]() {
            return make_tile_window(b_pad_view,
                                    make_tuple(number<TilePartitioner::NPerBlock>{},
                                               number<TilePartitioner::KPerBlock>{}),
                                    {i_n, 0});
        }();

        const auto ds_block_window = generate_tuple(
            [&](auto i) {
                return make_tile_window(ds_pad_view[i],
                                        make_tuple(number<TilePartitioner::MPerBlock>{},
                                                   number<TilePartitioner::NPerBlock>{}),
                                        {i_m, i_n});
            },
            number<NumDTensor>{});

        auto c_block_window = make_tile_window(
            c_pad_view,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
            {i_m, i_n});

        return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window);
    }

    /// @brief Runs single GEMM problem cooperatively by whole workgroup.
    template <typename ADescType, typename BDescType, typename CDescType>
    CK_TILE_DEVICE static void RunGemm(const InDataType* a_ptr,
                                       const WeiDataType* b_ptr,
                                       const std::array<const void*, NumDTensor>& ds_ptr,
                                       OutDataType* c_ptr,
                                       void* smem_ptr_0,
                                       const ADescType& a_desc,
                                       const BDescType& b_desc,
                                       const CDescType& c_desc,
                                       const index_t gemm_k,
                                       const index_t block_idx_m,
                                       const index_t block_idx_n)
    {
        // Create GEMM tensor views, pad views, and tile windows
        const auto& gemm_tensor_views_tuple =
            MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
                a_ptr, b_ptr, ds_ptr, c_ptr, a_desc, b_desc, c_desc);

        const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
        auto gemm_tile_windows     = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);

        const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(gemm_k));

        // Run the GEMM cooperatively by the whole workgroup.
        const auto& a_block_window = gemm_tile_windows.at(I0);
        const auto& b_block_window = gemm_tile_windows.at(I1);
        const auto& d_block_window = gemm_tile_windows.at(I2);

        const auto& c_block_tile = GemmPipeline{}.template operator()(
            a_block_window, b_block_window, num_loop, smem_ptr_0);

        // Run the epilogue pipeline
        auto& c_block_window = gemm_tile_windows.at(I3);

        EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
            c_block_window, c_block_tile, d_block_window, smem_ptr_0);
    }

    /// @brief Runs single GEMM problem cooperatively by whole workgroup.
    template <typename ADescType, typename BDescType, typename CDescType>
    CK_TILE_DEVICE static void RunGemm2LDS(const InDataType* a_ptr,
                                           const WeiDataType* b_ptr,
                                           const std::array<const void*, NumDTensor>& ds_ptr,
                                           OutDataType* c_ptr,
                                           void* __restrict__ smem_ptr_0,
                                           void* __restrict__ smem_ptr_1,
                                           const ADescType& a_desc,
                                           const BDescType& b_desc,
                                           const CDescType& c_desc,
                                           const index_t gemm_k,
                                           const index_t block_idx_m,
                                           const index_t block_idx_n)
    {
        // Create GEMM tensor views, pad views, and tile windows
        const auto& gemm_tensor_views_tuple =
            MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
                a_ptr, b_ptr, ds_ptr, c_ptr, a_desc, b_desc, c_desc);
        const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
        auto gemm_tile_windows     = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);

        const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(gemm_k));

        // Run the GEMM cooperatively by the whole workgroup.
        const auto& a_block_window = gemm_tile_windows.at(I0);
        const auto& b_block_window = gemm_tile_windows.at(I1);
        const auto& d_block_window = gemm_tile_windows.at(I2);

        const auto& c_block_tile = GemmPipeline{}.template operator()(
            a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);

        // Run the epilogue pipeline
        auto& c_block_window = gemm_tile_windows.at(I3);

        EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
            c_block_window, c_block_tile, d_block_window, smem_ptr_0);
    }

    CK_TILE_DEVICE void operator()(GroupedConvFwdKernelArgsSpecialized kargs) const
    {
        const auto blockIdX = amd_wave_read_first_lane(blockIdx.x);
        const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);

        const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
        const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
        const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);

        // Split-N handling: get which split this workgroup handles
        const auto blockIdZ = amd_wave_read_first_lane(blockIdx.z);

        // Calculate the batch offset for this split
        const index_t batch_offset = amd_wave_read_first_lane(blockIdZ * kargs.n_per_split);

        // Calculate memory offsets for this split
        const long_index_t input_batch_offset = static_cast<long_index_t>(batch_offset) *
                                                static_cast<long_index_t>(kargs.input_batch_stride);
        const long_index_t output_batch_offset =
            static_cast<long_index_t>(batch_offset) *
            static_cast<long_index_t>(kargs.output_batch_stride);

        // Calculate base pointers with group and batch offsets
        const InDataType* base_a_ptr =
            static_cast<const InDataType*>(kargs.in_ptr) + group_offset_a + input_batch_offset;
        const WeiDataType* b_ptr = static_cast<const WeiDataType*>(kargs.wei_ptr) +
                                   group_offset_b; // No batch offset for weights!
        OutDataType* base_c_ptr =
            static_cast<OutDataType*>(kargs.out_ptr) + group_offset_c + output_batch_offset;

        // =====================================================================
        // Split-image: map the local block to a global tile index (if enabled)
        // =====================================================================
        const InDataType* a_ptr;
        OutDataType* c_ptr;
        index_t i_m = 0;
        index_t i_n = 0;

        // Pre-calculate block_id (used in both the split-image and non-split paths)
        const index_t block_id = static_cast<index_t>(blockIdX);

        if constexpr(EnableSplitImage)
        {
            // Add spatial offsets for split-image (constexpr optimization)
            a_ptr = base_a_ptr + kargs.spatial_offset_in;
            c_ptr = base_c_ptr + kargs.spatial_offset_out;

            // Find which piece owns this block using binary search
            // Reference: device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp
            const index_t piece_id =
                FindPieceId(block_id, kargs.split_image, kargs.num_spatial_pieces);
            const auto& piece      = kargs.split_image.pieces[piece_id];
            const auto& split_info = kargs.split_image;

            // Calculate the local block ID and tile indices
            const index_t local_block_id = block_id - piece.block_start;
            const index_t local_gemm_m =
                kargs.n_per_split * piece.d_size * piece.h_size * piece.w_size;
            const auto [local_tile_m, local_tile_n] =
                TilePartitioner{local_gemm_m, kargs.GemmN}.GetOutputTileIndex(local_block_id);

            // Extract batch and spatial coordinates from the local tile
            const index_t local_m_start      = local_tile_m * TilePartitioner::MPerBlock;
            const index_t spatial_per_batch  = piece.d_size * piece.h_size * piece.w_size;
            const index_t local_n            = local_m_start / spatial_per_batch;
            const index_t local_spatial_flat = local_m_start % spatial_per_batch;

            // Convert to local spatial coordinates
            const auto local_coords =
                UnflattenSpatial(local_spatial_flat, piece.h_size, piece.w_size);

            // Convert to global spatial coordinates
            const index_t global_n = local_n;
            const index_t global_d = piece.d_start + local_coords.d;
            const index_t global_h = piece.h_start + local_coords.h;
            const index_t global_w = piece.w_start + local_coords.w;

            // Convert to a global M index
            const index_t global_spatial_per_batch = split_info.total_spatial; // Pre-calculated
            const index_t global_spatial_flat      = FlattenSpatial(
                global_d, global_h, global_w, split_info.total_h, split_info.total_w);
            const index_t global_m = global_n * global_spatial_per_batch + global_spatial_flat;

            // Set the tile indices for the GEMM operation
            i_m = amd_wave_read_first_lane(global_m);
            i_n = amd_wave_read_first_lane(local_tile_n * TilePartitioner::NPerBlock);
        }
        else
        {
            // No spatial offsets needed on the regular path
            a_ptr = base_a_ptr;
            c_ptr = base_c_ptr;

            // No split-image: use the standard tile partitioning
            const auto [iM, iN] =
                TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(block_id);
            i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
            i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
        }

        // Use the global descriptors for all cases
        const auto& a_desc = kargs.a_grid_desc_m_k;
        const auto& b_desc = kargs.b_grid_desc_n_k;
        const auto& c_desc = kargs.c_grid_desc_m_n;

        // Allocate LDS
        __shared__ char smem_ptr_0[GetSmemSize()];

        if constexpr(GemmPipeline::DoubleSmemBuffer == true)
        {
            __shared__ char smem_ptr_1[GetSmemSize()];
            if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
                           GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
                           std::is_same_v<OutDataType, bf16_t>))
            {
                RunGemm2LDS(a_ptr,
                            b_ptr,
                            kargs.ds_ptr,
                            c_ptr,
                            smem_ptr_0,
                            smem_ptr_1,
                            a_desc,
                            b_desc,
                            c_desc,
                            kargs.GemmK,
                            i_m,
                            i_n);
            }
        }
        else
        {
            if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
                           GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
                           std::is_same_v<OutDataType, bf16_t>))
            {
                RunGemm(a_ptr,
                        b_ptr,
                        kargs.ds_ptr,
                        c_ptr,
                        smem_ptr_0,
                        a_desc,
                        b_desc,
                        c_desc,
                        kargs.GemmK,
                        i_m,
                        i_n);
            }
        }
    }
};

} // namespace ck_tile
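
The typical host-side flow for this header is: build host arguments, convert them with MakeKernelArgs, validate with IsSupportedArgument, then launch over the grid reported by GridSize. The sketch below illustrates that flow under stated assumptions: ConvTraits, Partitioner, Pipeline, and Epilogue stand for user-provided types satisfying the GroupedConvTraitsType_, TilePartitioner_, GemmPipeline_, and EpiloguePipeline_ interfaces; host_args is a populated GroupedConvFwdHostArgs; and launch_grouped_conv is a hypothetical wrapper around a HIP kernel launch. Only the four kernel entry points come from this header.

// Minimal usage sketch (assumed types are marked below; not part of this header).
using Kernel = ck_tile::GroupedConvFwdKernel<ConvTraits,   // assumed traits type
                                             Partitioner,  // assumed tile partitioner
                                             Pipeline,     // assumed GEMM pipeline
                                             Epilogue>;    // assumed epilogue pipeline

// Convert host arguments to device kernel arguments; the layout-specific constructor
// also derives the GEMM descriptors and the Split-N fields.
auto kargs = Kernel::MakeKernelArgs(host_args);

// Reject unsupported vector sizes, layouts, specializations, and k_batch values.
if(Kernel::IsSupportedArgument(kargs))
{
    // Grid: x = M*N tile count, y = group count (GemmBatch), z = Split-N count (n_splits).
    const dim3 grids  = Kernel::GridSize(kargs);
    const dim3 blocks = Kernel::BlockSize();
    launch_grouped_conv(Kernel{}, grids, blocks, kargs); // hypothetical launch wrapper
}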