/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp Source File
transform_conv_bwd_weight_to_gemm.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
12 
13 namespace ck {
14 namespace tensor_operation {
15 
16 template <index_t NDimSpatial,
17  index_t MPerBlock,
18  index_t NPerBlock,
19  index_t GemmK1Number,
20  index_t K0PerBlock,
21  device::ConvolutionBackwardWeightSpecialization ConvBackwardWeightSpecialization>
23 {
// Shorthand constants: I0 holds Number<0>{} and I1 holds Number<1>{}.
// NOTE(review): neither is referenced on the lines visible in this listing;
// confirm usage in the full header before touching them.
24  static constexpr auto I0 = Number<0>{};
25  static constexpr auto I1 = Number<1>{};
26 
// Output-tensor descriptor helper, enabled only when NDim == 2.
// Listing line 29 (function name and first parameter) is missing from this
// rendering; the page's cross-reference index gives the signature as
// make_out_grid_desc(N, Ho, Wo, K, output_strides).
// Returns a naive 2-D descriptor with lengths (N * Ho * Wo, K) and strides
// (output_strides[4], 1).
// NOTE(review): folding N, Ho and Wo into one dimension addressed by a single
// stride presumably relies on the caller's output layout making that valid
// - confirm against callers.
27  template <index_t NDim, typename enable_if<NDim == 2, bool>::type = false>
28  constexpr static auto
30  const index_t Ho,
31  const index_t Wo,
32  const index_t K,
33  const std::array<index_t, NDimSpatial + 3>& output_strides)
34  {
// Only the Wo stride is read from output_strides; K gets a compile-time unit stride.
35  const index_t WoStride = output_strides[4];
36  const auto KStride = Number<1>{};
37  return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, K),
38  make_tuple(WoStride, KStride));
39  }
40 
// Input-tensor descriptor helper, enabled only when NDim == 2.
// Listing line 43 (function name and first parameter) is missing from this
// rendering; the page's cross-reference index gives the signature as
// make_in_grid_desc(N, Hi, Wi, C, input_strides).
// Strides are read from input_strides as: [1] -> N, [2] -> C, [3] -> Hi, [4] -> Wi.
// NOTE(review): index 0 is never read here; presumably it is a group stride - confirm.
// Return value depends on ConvBackwardWeightSpecialization:
//  - matching branch: 2-D descriptor, lengths (N * Hi * Wi, C), strides (WiStride, CStride);
//  - otherwise:       4-D descriptor, lengths (N, Hi, Wi, C),
//                     strides (NStride, HiStride, WiStride, CStride).
41  template <index_t NDim, typename enable_if<NDim == 2, bool>::type = false>
42  constexpr static auto
44  const index_t Hi,
45  const index_t Wi,
46  const index_t C,
47  const std::array<index_t, NDimSpatial + 3>& input_strides)
48  {
49  const index_t NStride = input_strides[1];
50  const index_t HiStride = input_strides[3];
51  const index_t WiStride = input_strides[4];
52  const auto CStride = input_strides[2];
// The enum value compared against (listing line 54) is missing from this rendering.
53  if constexpr(ConvBackwardWeightSpecialization ==
55  {
// Spatial dimensions are folded into a single dimension addressed by WiStride.
56  return make_naive_tensor_descriptor(make_tuple(N * Hi * Wi, C),
57  make_tuple(WiStride, CStride));
58  }
59  else
60  {
// General case: keep N, Hi, Wi, C as separate dimensions, each with its own stride.
61  return make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, C),
62  make_tuple(NStride, HiStride, WiStride, CStride));
63  }
64  }
65 
// Weight-tensor descriptor helper, enabled only when NDim == 2.
// Listing line 68 (function name and first parameter) is missing from this
// rendering; the page's cross-reference index gives the signature as
// make_wei_grid_desc(K, Y, X, C, weights_strides).
// Returns a naive 2-D descriptor with lengths (K, Y * X * C) and strides
// (weights_strides[1], 1).
// NOTE(review): the unit stride on the merged Y * X * C dimension presumably
// assumes those three are stored contiguously - confirm against callers.
66  template <index_t NDim, typename enable_if<NDim == 2, bool>::type = false>
67  constexpr static auto
69  const index_t Y,
70  const index_t X,
71  const index_t C,
72  const std::array<index_t, NDimSpatial + 3>& weights_strides)
73  {
// The merged dimension gets a compile-time unit stride; K uses the stride at index 1.
74  const auto CStride = Number<1>{};
75  const auto KStride = weights_strides[1];
76  return make_naive_tensor_descriptor(make_tuple(K, Y * X * C), make_tuple(KStride, CStride));
77  }
78 
// Output-tensor descriptor helper, enabled only when NDim == 3.
// Listing line 81 (function name and first parameter) is missing from this
// rendering; the page's cross-reference index gives the signature as
// make_out_grid_desc(N, Do, Ho, Wo, K, output_strides).
// Returns a naive 2-D descriptor with lengths (N * Do * Ho * Wo, K) and
// strides (output_strides[5], 1).
// NOTE(review): folding N, Do, Ho and Wo into one dimension addressed by a
// single stride presumably relies on the caller's output layout - confirm.
79  template <index_t NDim, typename enable_if<NDim == 3, bool>::type = false>
80  constexpr static auto
82  const index_t Do,
83  const index_t Ho,
84  const index_t Wo,
85  const index_t K,
86  const std::array<index_t, NDimSpatial + 3>& output_strides)
87  {
// Only the Wo stride is read from output_strides; K gets a compile-time unit stride.
88  const index_t WoStride = output_strides[5];
89  const auto KStride = Number<1>{};
90  return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K),
91  make_tuple(WoStride, KStride));
92  }
93 
// Input-tensor descriptor helper, enabled only when NDim == 3.
// Listing line 96 (function name and first parameter) is missing from this
// rendering; the page's cross-reference index gives the signature as
// make_in_grid_desc(N, Di, Hi, Wi, C, input_strides).
// Strides are read from input_strides as:
// [1] -> N, [2] -> C, [3] -> Di, [4] -> Hi, [5] -> Wi.
// NOTE(review): index 0 is never read here; presumably it is a group stride - confirm.
// Return value depends on ConvBackwardWeightSpecialization:
//  - matching branch: 2-D descriptor, lengths (N * Di * Hi * Wi, C), strides (WiStride, CStride);
//  - otherwise:       5-D descriptor, lengths (N, Di, Hi, Wi, C),
//                     strides (NStride, DiStride, HiStride, WiStride, CStride).
94  template <index_t NDim, typename enable_if<NDim == 3, bool>::type = false>
95  constexpr static auto
97  const index_t Di,
98  const index_t Hi,
99  const index_t Wi,
100  const index_t C,
101  const std::array<index_t, NDimSpatial + 3>& input_strides)
102  {
103  const index_t NStride = input_strides[1];
104  const index_t DiStride = input_strides[3];
105  const index_t HiStride = input_strides[4];
106  const index_t WiStride = input_strides[5];
107  const auto CStride = input_strides[2];
// The enum value compared against (listing line 109) is missing from this rendering.
108  if constexpr(ConvBackwardWeightSpecialization ==
110  {
// Spatial dimensions are folded into a single dimension addressed by WiStride.
111  return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, C),
112  make_tuple(WiStride, CStride));
113  }
114  else
115  {
// Listing line 116 (the start of this return statement) is missing from this rendering;
// the two visible lines are the length tuple and the stride tuple.
117  make_tuple(N, Di, Hi, Wi, C),
118  make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
119  }
120  }
121 
// Weight-tensor descriptor helper, enabled only when NDim == 3.
// Listing line 124 (function name and first parameter) is missing from this
// rendering; the page's cross-reference index gives the signature as
// make_wei_grid_desc(K, Z, Y, X, C, weights_strides).
// Returns a naive 2-D descriptor with lengths (K, Z * Y * X * C) and strides
// (weights_strides[1], 1).
// NOTE(review): the unit stride on the merged Z * Y * X * C dimension presumably
// assumes those four are stored contiguously - confirm against callers.
122  template <index_t NDim, typename enable_if<NDim == 3, bool>::type = false>
123  constexpr static auto
125  const index_t Z,
126  const index_t Y,
127  const index_t X,
128  const index_t C,
129  const std::array<index_t, NDimSpatial + 3>& weights_strides)
130  {
// The merged dimension gets a compile-time unit stride; K uses the stride at index 1.
131  const auto CStride = Number<1>{};
132  const auto KStride = weights_strides[1];
133  return make_naive_tensor_descriptor(make_tuple(K, Z * Y * X * C),
134  make_tuple(KStride, CStride));
135  }
136 
// 1D (NDim == 1) overload that builds the three GEMM grid descriptors.
// Listing line 138 (the function name) is missing from this rendering; the
// page's cross-reference index gives it as
// static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(...).
//
// GEMM problem sizes derived here:
//   GemmKTotal = N * Wo, GemmM = K, GemmN = C * X.
// Parameters:
//   N, K, C                     - sizes used in the formulas above.
//   input/filter/output_spatial_lengths - element [0] supplies Wi / X / Wo.
//   input/weights/output strides - unnamed (commented out), i.e. unused in this overload.
//   conv_filter_strides, conv_filter_dilations, input_left_pads, input_right_pads
//                               - element [0] supplies the W-direction value.
//   batch_k                     - becomes GemmKBatch, the leading length of the K split.
// Returns make_tuple(A, B, C) where, per the in-body comments, A is built from
// the output tensor, B from the input tensor and C from the weight tensor.
// NOTE(review): many lines of this body (base descriptor construction and the
// dimension-mapping arguments of transform_tensor_descriptor) are missing from
// this rendering, so the statements below appear truncated.
137  template <index_t NDim, typename enable_if<NDim == 1, bool>::type = false>
139  const index_t N,
140  const index_t K,
141  const index_t C,
142  const std::array<index_t, NDimSpatial>& input_spatial_lengths,
143  const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
144  const std::array<index_t, NDimSpatial>& output_spatial_lengths,
145  const std::array<index_t, NDimSpatial + 3>& /* input_strides */,
146  const std::array<index_t, NDimSpatial + 3>& /* weights_strides */,
147  const std::array<index_t, NDimSpatial + 3>& /* output_strides */,
148  const std::array<index_t, NDimSpatial>& conv_filter_strides,
149  const std::array<index_t, NDimSpatial>& conv_filter_dilations,
150  const std::array<index_t, NDimSpatial>& input_left_pads,
151  const std::array<index_t, NDimSpatial>& input_right_pads,
152  const index_t batch_k)
153  {
154  using namespace ck;
155 
156  const index_t Wi = input_spatial_lengths[0];
157  const index_t Wo = output_spatial_lengths[0];
158  const index_t X = filter_spatial_lengths[0];
159  const index_t ConvStrideW = conv_filter_strides[0];
160  const index_t ConvDilationW = conv_filter_dilations[0];
161  const index_t InLeftPadW = input_left_pads[0];
162  const index_t InRightPadW = input_right_pads[0];
163 
164  const index_t GemmKTotal = N * Wo;
165  const index_t GemmM = K;
166  const index_t GemmN = C * X;
167 
// Amount needed to round GemmM / GemmN up to the next multiple of
// MPerBlock / NPerBlock (0 when already a multiple).
168  const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
169  const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
170 
// GemmK0 is a multiple of K0PerBlock chosen so that
// GemmKBatch * GemmK0 * GemmK1Number (= GemmKPad) is >= GemmKTotal,
// assuming integer_divide_ceil is ceiling division, as its name suggests.
171  const index_t GemmKBatch = batch_k;
172  const index_t GemmK0 =
173  math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
174  K0PerBlock;
175  const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
176 
// The enum value compared against (listing line 178) is missing from this rendering.
177  if constexpr(ConvBackwardWeightSpecialization ==
179  {
180  // A: output tensor
// The initializer (listing line 182) is missing from this rendering.
181  const auto out_gemmktotal_gemmm_grid_desc =
183 
// Right-pad the K dimension from GemmKTotal up to GemmKPad.
184  const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
185  out_gemmktotal_gemmm_grid_desc,
186  make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
190 
// Split padded K into (GemmKBatch, GemmK0, GemmK1Number) and right-pad GemmM by PadGemmM.
191  const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
192  out_gemmkpad_gemmm_grid_desc,
193  make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
194  make_right_pad_transform(GemmM, PadGemmM)),
197 
198  // B: input tensor
// The initializer (listing line 200) is missing from this rendering.
199  const auto in_gemmktotal_gemmn_grid_desc =
201 
// Same K padding and K split as for A, with GemmN right-padded by PadGemmN.
202  const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
203  in_gemmktotal_gemmn_grid_desc,
204  make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
208 
209  const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
210  in_gemmkpad_gemmn_grid_desc,
211  make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
212  make_right_pad_transform(GemmN, PadGemmN)),
215 
216  // C: weight tensor
// The initializer (listing line 218) is missing from this rendering.
217  const auto wei_gemmm_gemmn_grid_desc =
219 
220  // Right-pad GemmM by PadGemmM and GemmN by PadGemmN
221  const auto wei_gemmm_gemmn_pad_grid_desc =
222  transform_tensor_descriptor(wei_gemmm_gemmn_grid_desc,
223  make_tuple(make_right_pad_transform(GemmM, PadGemmM),
224  make_right_pad_transform(GemmN, PadGemmN)),
227 
228  return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
229  in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
230  wei_gemmm_gemmn_pad_grid_desc);
231  }
232  else
233  {
// Both initializers (listing lines 235 and 237) are missing from this rendering.
234  const auto out_gemmktotal_gemmm_grid_desc =
236  const auto in_n_wi_c_grid_desc =
238 
239  // A: output tensor
// Right-pad the K dimension from GemmKTotal up to GemmKPad.
240  const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
241  out_gemmktotal_gemmm_grid_desc,
242  make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
246 
// Split padded K into (GemmKBatch, GemmK0, GemmK1Number) and right-pad GemmM by PadGemmM.
247  const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
248  out_gemmkpad_gemmm_grid_desc,
249  make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
250  make_right_pad_transform(GemmM, PadGemmM)),
253 
254  // B: input tensor
// Step 1: pad Wi on the left/right by InLeftPadW / InRightPadW.
255  const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
256  in_n_wi_c_grid_desc,
258  make_pad_transform(Wi, InLeftPadW, InRightPadW),
262 
// Step 2: embed the padded W dimension as (X, Wo) with coefficients
// (ConvDilationW, ConvStrideW).
263  const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
264  in_n_wip_c_grid_desc,
265  make_tuple(
267  make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
271 
// Step 3: the transforms producing the (GemmKTotal, GemmN) view (listing lines 274-277)
// are missing from this rendering.
272  const auto in_gemmktotal_gemmn_grid_desc =
273  transform_tensor_descriptor(in_n_x_wo_c_grid_desc,
278 
// Step 4: same K padding and K split as for A, with GemmN right-padded by PadGemmN.
279  const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
280  in_gemmktotal_gemmn_grid_desc,
281  make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
285 
286  const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
287  in_gemmkpad_gemmn_grid_desc,
288  make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
289  make_right_pad_transform(GemmN, PadGemmN)),
292 
293  // C: weight tensor
// The initializer (listing line 295) is missing from this rendering.
294  const auto wei_gemmm_gemmn_grid_desc =
296 
297  // Right-pad GemmM by PadGemmM and GemmN by PadGemmN
298  const auto wei_gemmm_gemmn_pad_grid_desc =
299  transform_tensor_descriptor(wei_gemmm_gemmn_grid_desc,
300  make_tuple(make_right_pad_transform(GemmM, PadGemmM),
301  make_right_pad_transform(GemmN, PadGemmN)),
304 
305  return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
306  in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
307  wei_gemmm_gemmn_pad_grid_desc);
308  }
309  }
310 
// 2D (NDim == 2) overload that builds the three GEMM grid descriptors.
// Listing line 312 (the function name) is missing from this rendering; the
// page's cross-reference index gives it as
// static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(...).
//
// GEMM problem sizes derived here:
//   GemmKTotal = N * Ho * Wo, GemmM = K, GemmN = C * X * Y.
// Parameters:
//   N, K, C                     - sizes used in the formulas above.
//   input/filter/output_spatial_lengths - [0] is the H-direction value (Hi / Y / Ho),
//                                 [1] the W-direction value (Wi / X / Wo).
//   input_strides, weights_strides, output_strides
//                               - forwarded to make_in_grid_desc / make_wei_grid_desc /
//                                 make_out_grid_desc.
//   conv_filter_strides, conv_filter_dilations, input_left_pads, input_right_pads
//                               - [0] is the H-direction value, [1] the W-direction value.
//   batch_k                     - becomes GemmKBatch, the leading length of the K split.
// Returns make_tuple(A, B, C) where, per the in-body comments, A is built from
// the output tensor and B from the input tensor; C is built from wei_grid_desc.
// NOTE(review): the dimension-mapping arguments of most transform_tensor_descriptor
// calls are missing from this rendering, so those statements appear truncated.
311  template <index_t NDim, typename enable_if<NDim == 2, bool>::type = false>
313  const index_t N,
314  const index_t K,
315  const index_t C,
316  const std::array<index_t, NDimSpatial>& input_spatial_lengths,
317  const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
318  const std::array<index_t, NDimSpatial>& output_spatial_lengths,
319  const std::array<index_t, NDimSpatial + 3>& input_strides,
320  const std::array<index_t, NDimSpatial + 3>& weights_strides,
321  const std::array<index_t, NDimSpatial + 3>& output_strides,
322  const std::array<index_t, NDimSpatial>& conv_filter_strides,
323  const std::array<index_t, NDimSpatial>& conv_filter_dilations,
324  const std::array<index_t, NDimSpatial>& input_left_pads,
325  const std::array<index_t, NDimSpatial>& input_right_pads,
326  const index_t batch_k)
327  {
328  using namespace ck;
329 
330  const index_t Hi = input_spatial_lengths[0];
331  const index_t Wi = input_spatial_lengths[1];
332 
333  const index_t Ho = output_spatial_lengths[0];
334  const index_t Wo = output_spatial_lengths[1];
335 
336  const index_t Y = filter_spatial_lengths[0];
337  const index_t X = filter_spatial_lengths[1];
338 
339  const index_t ConvStrideH = conv_filter_strides[0];
340  const index_t ConvStrideW = conv_filter_strides[1];
341 
342  const index_t ConvDilationH = conv_filter_dilations[0];
343  const index_t ConvDilationW = conv_filter_dilations[1];
344 
345  const index_t InLeftPadH = input_left_pads[0];
346  const index_t InLeftPadW = input_left_pads[1];
347 
348  const index_t InRightPadH = input_right_pads[0];
349  const index_t InRightPadW = input_right_pads[1];
350 
351  const index_t GemmKTotal = N * Ho * Wo;
352  const index_t GemmM = K;
353  const index_t GemmN = C * X * Y;
354 
// Amount needed to round GemmM / GemmN up to the next multiple of
// MPerBlock / NPerBlock (0 when already a multiple).
355  const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
356  const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
357 
// GemmK0 is a multiple of K0PerBlock chosen so that
// GemmKBatch * GemmK0 * GemmK1Number (= GemmKPad) is >= GemmKTotal,
// assuming integer_divide_ceil is ceiling division, as its name suggests.
358  const index_t GemmKBatch = batch_k;
359  const index_t GemmK0 =
360  math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
361  K0PerBlock;
362  const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
363 
// Base descriptors come from the NDim-dispatched helpers of this struct.
364  const auto out_grid_desc = make_out_grid_desc<NDim>(N, Ho, Wo, K, output_strides);
365  const auto in_grid_desc = make_in_grid_desc<NDim>(N, Hi, Wi, C, input_strides);
366  const auto wei_grid_desc = make_wei_grid_desc<NDim>(K, Y, X, C, weights_strides);
367 
// The enum value compared against (listing line 369) is missing from this rendering.
368  if constexpr(ConvBackwardWeightSpecialization ==
370  {
371  // A: output tensor
// Right-pad the K dimension from GemmKTotal up to GemmKPad.
372  const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
373  out_grid_desc,
374  make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
378 
// Split padded K into (GemmKBatch, GemmK0, GemmK1Number) and right-pad GemmM by PadGemmM.
379  const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
380  out_gemmkpad_gemmm_grid_desc,
381  make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
382  make_right_pad_transform(GemmM, PadGemmM)),
385 
386  // B: input tensor
// in_grid_desc is padded directly here, without the pad/embed/merge steps of the else branch.
387  const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
388  in_grid_desc,
389  make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
393 
394  const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
395  in_gemmkpad_gemmn_grid_desc,
396  make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
397  make_right_pad_transform(GemmN, PadGemmN)),
400 
401  // Right-pad GemmM by PadGemmM and GemmN by PadGemmN on the weight descriptor
402  const auto wei_gemmm_gemmn_pad_grid_desc =
403  transform_tensor_descriptor(wei_grid_desc,
404  make_tuple(make_right_pad_transform(GemmM, PadGemmM),
405  make_right_pad_transform(GemmN, PadGemmN)),
408 
409  return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
410  in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
411  wei_gemmm_gemmn_pad_grid_desc);
412  }
413  else
414  {
415  // A: output tensor
// Right-pad the K dimension from GemmKTotal up to GemmKPad.
416  const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
417  out_grid_desc,
418  make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
422 
// Split padded K into (GemmKBatch, GemmK0, GemmK1Number) and right-pad GemmM by PadGemmM.
423  const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
424  out_gemmkpad_gemmm_grid_desc,
425  make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
426  make_right_pad_transform(GemmM, PadGemmM)),
429 
430  // B: input tensor
// Step 1: pad Hi and Wi with their left/right padding amounts.
431  const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
432  in_grid_desc,
434  make_pad_transform(Hi, InLeftPadH, InRightPadH),
435  make_pad_transform(Wi, InLeftPadW, InRightPadW),
439 
// Step 2: embed padded H as (Y, Ho) and padded W as (X, Wo), each with
// coefficients (dilation, stride).
440  const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
441  in_n_hip_wip_c_grid_desc,
442  make_tuple(
444  make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
445  make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
449 
// Step 3: merge (N, Ho, Wo) into one dimension; the other transform of this call
// (listing line 452) is missing from this rendering.
450  const auto in_gemmktotal_gemmn_grid_desc =
451  transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
453  make_merge_transform(make_tuple(N, Ho, Wo))),
456 
// Step 4: same K padding and K split as for A, with GemmN right-padded by PadGemmN.
457  const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
458  in_gemmktotal_gemmn_grid_desc,
459  make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
463 
464  const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
465  in_gemmkpad_gemmn_grid_desc,
466  make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
467  make_right_pad_transform(GemmN, PadGemmN)),
470 
471  // Right-pad GemmM by PadGemmM and GemmN by PadGemmN on the weight descriptor
472  const auto wei_gemmm_gemmn_pad_grid_desc =
473  transform_tensor_descriptor(wei_grid_desc,
474  make_tuple(make_right_pad_transform(GemmM, PadGemmM),
475  make_right_pad_transform(GemmN, PadGemmN)),
478 
479  return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
480  in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
481  wei_gemmm_gemmn_pad_grid_desc);
482  }
483  }
484 
// 3D (NDim == 3) overload that builds the three GEMM grid descriptors.
// NOTE(review): listing line 486 (the function name) is missing from this rendering
// and this overload is not listed in the page's cross-reference index; presumably it
// is the NDim == 3 overload of MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N,
// since its parameter list matches the 1D and 2D overloads - confirm in the header.
//
// GEMM problem sizes derived here:
//   GemmKTotal = N * Do * Ho * Wo, GemmM = K, GemmN = C * Z * X * Y.
// Parameters:
//   N, K, C                     - sizes used in the formulas above.
//   input/filter/output_spatial_lengths - [0] is the D-direction value (Di / Z / Do),
//                                 [1] the H-direction value (Hi / Y / Ho),
//                                 [2] the W-direction value (Wi / X / Wo).
//   input_strides, weights_strides, output_strides
//                               - forwarded to make_in_grid_desc / make_wei_grid_desc /
//                                 make_out_grid_desc.
//   conv_filter_strides, conv_filter_dilations, input_left_pads, input_right_pads
//                               - [0], [1], [2] are the D-, H-, W-direction values.
//   batch_k                     - becomes GemmKBatch, the leading length of the K split.
// Returns make_tuple(A, B, C) where, per the in-body comments, A is built from
// the output tensor and B from the input tensor; C is built from wei_grid_desc.
// NOTE(review): the dimension-mapping arguments of most transform_tensor_descriptor
// calls are missing from this rendering, so those statements appear truncated.
485  template <index_t NDim, typename enable_if<NDim == 3, bool>::type = false>
487  const index_t N,
488  const index_t K,
489  const index_t C,
490  const std::array<index_t, NDimSpatial>& input_spatial_lengths,
491  const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
492  const std::array<index_t, NDimSpatial>& output_spatial_lengths,
493  const std::array<index_t, NDimSpatial + 3>& input_strides,
494  const std::array<index_t, NDimSpatial + 3>& weights_strides,
495  const std::array<index_t, NDimSpatial + 3>& output_strides,
496  const std::array<index_t, NDimSpatial>& conv_filter_strides,
497  const std::array<index_t, NDimSpatial>& conv_filter_dilations,
498  const std::array<index_t, NDimSpatial>& input_left_pads,
499  const std::array<index_t, NDimSpatial>& input_right_pads,
500  const index_t batch_k)
501  {
502  using namespace ck;
503 
504  const index_t Di = input_spatial_lengths[0];
505  const index_t Hi = input_spatial_lengths[1];
506  const index_t Wi = input_spatial_lengths[2];
507 
508  const index_t Do = output_spatial_lengths[0];
509  const index_t Ho = output_spatial_lengths[1];
510  const index_t Wo = output_spatial_lengths[2];
511 
512  const index_t Z = filter_spatial_lengths[0];
513  const index_t Y = filter_spatial_lengths[1];
514  const index_t X = filter_spatial_lengths[2];
515 
516  const index_t ConvStrideD = conv_filter_strides[0];
517  const index_t ConvStrideH = conv_filter_strides[1];
518  const index_t ConvStrideW = conv_filter_strides[2];
519 
520  const index_t ConvDilationD = conv_filter_dilations[0];
521  const index_t ConvDilationH = conv_filter_dilations[1];
522  const index_t ConvDilationW = conv_filter_dilations[2];
523 
524  const index_t InLeftPadD = input_left_pads[0];
525  const index_t InLeftPadH = input_left_pads[1];
526  const index_t InLeftPadW = input_left_pads[2];
527 
528  const index_t InRightPadD = input_right_pads[0];
529  const index_t InRightPadH = input_right_pads[1];
530  const index_t InRightPadW = input_right_pads[2];
531 
532  const index_t GemmKTotal = N * Do * Ho * Wo;
533  const index_t GemmM = K;
534  const index_t GemmN = C * Z * X * Y;
535 
// Amount needed to round GemmM / GemmN up to the next multiple of
// MPerBlock / NPerBlock (0 when already a multiple).
536  const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
537  const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
538 
// GemmK0 is a multiple of K0PerBlock chosen so that
// GemmKBatch * GemmK0 * GemmK1Number (= GemmKPad) is >= GemmKTotal,
// assuming integer_divide_ceil is ceiling division, as its name suggests.
539  const index_t GemmKBatch = batch_k;
540  const index_t GemmK0 =
541  math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
542  K0PerBlock;
543  const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
544 
// Base descriptors come from the NDim-dispatched helpers of this struct.
545  const auto out_grid_desc = make_out_grid_desc<NDim>(N, Do, Ho, Wo, K, output_strides);
546  const auto in_grid_desc = make_in_grid_desc<NDim>(N, Di, Hi, Wi, C, input_strides);
547  const auto wei_grid_desc = make_wei_grid_desc<NDim>(K, Z, Y, X, C, weights_strides);
548 
// The enum value compared against (listing line 550) is missing from this rendering.
549  if constexpr(ConvBackwardWeightSpecialization ==
551  {
552  // A: output tensor
// Right-pad the K dimension from GemmKTotal up to GemmKPad.
553  const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
554  out_grid_desc,
555  make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
559 
// Split padded K into (GemmKBatch, GemmK0, GemmK1Number) and right-pad GemmM by PadGemmM.
560  const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
561  out_gemmkpad_gemmm_grid_desc,
562  make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
563  make_right_pad_transform(GemmM, PadGemmM)),
566 
567  // B: input tensor
// in_grid_desc is padded directly here, without the pad/embed/merge steps of the else branch.
568  const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
569  in_grid_desc,
570  make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
574 
575  const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
576  in_gemmkpad_gemmn_grid_desc,
577  make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
578  make_right_pad_transform(GemmN, PadGemmN)),
581 
582  // Right-pad GemmM by PadGemmM and GemmN by PadGemmN on the weight descriptor
583  const auto wei_gemmm_gemmn_pad_grid_desc =
584  transform_tensor_descriptor(wei_grid_desc,
585  make_tuple(make_right_pad_transform(GemmM, PadGemmM),
586  make_right_pad_transform(GemmN, PadGemmN)),
589 
590  return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
591  in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
592  wei_gemmm_gemmn_pad_grid_desc);
593  }
594  else
595  {
596  // A: output tensor
// Right-pad the K dimension from GemmKTotal up to GemmKPad.
597  const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
598  out_grid_desc,
599  make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
603 
// Split padded K into (GemmKBatch, GemmK0, GemmK1Number) and right-pad GemmM by PadGemmM.
604  const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
605  out_gemmkpad_gemmm_grid_desc,
606  make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
607  make_right_pad_transform(GemmM, PadGemmM)),
610 
611  // B: input tensor
// Step 1: pad Di, Hi and Wi with their left/right padding amounts.
612  const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
613  in_grid_desc,
615  make_pad_transform(Di, InLeftPadD, InRightPadD),
616  make_pad_transform(Hi, InLeftPadH, InRightPadH),
617  make_pad_transform(Wi, InLeftPadW, InRightPadW),
619  make_tuple(
621  make_tuple(
623 
// Step 2: embed padded D as (Z, Do), padded H as (Y, Ho) and padded W as (X, Wo),
// each with coefficients (dilation, stride).
624  const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
625  in_n_dip_hip_wip_c_grid_desc,
626  make_tuple(
628  make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
629  make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
630  make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
632  make_tuple(
635  Sequence<1, 2>{},
636  Sequence<3, 4>{},
637  Sequence<5, 6>{},
638  Sequence<7>{}));
639 
// Step 3: merge (N, Do, Ho, Wo) into one dimension; the other transform of this call
// (listing line 642) is missing from this rendering.
640  const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor(
641  in_n_z_do_y_ho_x_wo_c_grid_desc,
643  make_merge_transform(make_tuple(N, Do, Ho, Wo))),
646 
// Step 4: same K padding and K split as for A, with GemmN right-padded by PadGemmN.
647  const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
648  in_gemmktotal_gemmn_grid_desc,
649  make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
653 
654  const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
655  in_gemmkpad_gemmn_grid_desc,
656  make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
657  make_right_pad_transform(GemmN, PadGemmN)),
660 
661  // Right-pad GemmM by PadGemmM and GemmN by PadGemmN on the weight descriptor
662  const auto wei_gemmm_gemmn_pad_grid_desc =
663  transform_tensor_descriptor(wei_grid_desc,
664  make_tuple(make_right_pad_transform(GemmM, PadGemmM),
665  make_right_pad_transform(GemmN, PadGemmN)),
668 
669  return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
670  in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
671  wei_gemmm_gemmn_pad_grid_desc);
672  }
673  } // function end
674 };
675 
676 } // namespace tensor_operation
677 } // namespace ck
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
ConvolutionBackwardWeightSpecialization
Definition: convolution_backward_weight_specialization.hpp:13
Definition: ck.hpp:270
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition: multi_index_transform_helper.hpp:48
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:301
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__host__ constexpr __device__ auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:19
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
Definition: sequence.hpp:43
Definition: integral_constant.hpp:20
Definition: transform_conv_bwd_weight_to_gemm.hpp:23
static constexpr auto I0
Definition: transform_conv_bwd_weight_to_gemm.hpp:24
constexpr static auto make_out_grid_desc(const index_t N, const index_t Do, const index_t Ho, const index_t Wo, const index_t K, const std::array< index_t, NDimSpatial+3 > &output_strides)
Definition: transform_conv_bwd_weight_to_gemm.hpp:81
constexpr static auto make_out_grid_desc(const index_t N, const index_t Ho, const index_t Wo, const index_t K, const std::array< index_t, NDimSpatial+3 > &output_strides)
Definition: transform_conv_bwd_weight_to_gemm.hpp:29
static constexpr auto I1
Definition: transform_conv_bwd_weight_to_gemm.hpp:25
constexpr static auto make_in_grid_desc(const index_t N, const index_t Di, const index_t Hi, const index_t Wi, const index_t C, const std::array< index_t, NDimSpatial+3 > &input_strides)
Definition: transform_conv_bwd_weight_to_gemm.hpp:96
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(const index_t N, const index_t K, const index_t C, const std::array< index_t, NDimSpatial > &input_spatial_lengths, const std::array< index_t, NDimSpatial > &filter_spatial_lengths, const std::array< index_t, NDimSpatial > &output_spatial_lengths, const std::array< index_t, NDimSpatial+3 > &input_strides, const std::array< index_t, NDimSpatial+3 > &weights_strides, const std::array< index_t, NDimSpatial+3 > &output_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const index_t batch_k)
Definition: transform_conv_bwd_weight_to_gemm.hpp:312
constexpr static auto make_wei_grid_desc(const index_t K, const index_t Y, const index_t X, const index_t C, const std::array< index_t, NDimSpatial+3 > &weights_strides)
Definition: transform_conv_bwd_weight_to_gemm.hpp:68
constexpr static auto make_in_grid_desc(const index_t N, const index_t Hi, const index_t Wi, const index_t C, const std::array< index_t, NDimSpatial+3 > &input_strides)
Definition: transform_conv_bwd_weight_to_gemm.hpp:43
constexpr static auto make_wei_grid_desc(const index_t K, const index_t Z, const index_t Y, const index_t X, const index_t C, const std::array< index_t, NDimSpatial+3 > &weights_strides)
Definition: transform_conv_bwd_weight_to_gemm.hpp:124
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(const index_t N, const index_t K, const index_t C, const std::array< index_t, NDimSpatial > &input_spatial_lengths, const std::array< index_t, NDimSpatial > &filter_spatial_lengths, const std::array< index_t, NDimSpatial > &output_spatial_lengths, const std::array< index_t, NDimSpatial+3 > &, const std::array< index_t, NDimSpatial+3 > &, const std::array< index_t, NDimSpatial+3 > &, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const index_t batch_k)
Definition: transform_conv_bwd_weight_to_gemm.hpp:138