/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/matrix_padder.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/matrix_padder.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/matrix_padder.hpp Source File
matrix_padder.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
10 
11 namespace ck {
12 namespace tensor_operation {
13 namespace device {
14 
// Builds a descriptor in which each selected dimension of `desc` is right-padded up to the
// next multiple of its tile length.
//   desc         - descriptor whose per-dimension lengths are read with GetLength(idim)
//   tile_lengths - per-dimension tile length, indexed as tile_lengths[idim]
//   DoPads       - compile-time per-dimension flags, read with DoPads::At(idim)
// Returns the result of transform_tensor_descriptor(desc, ...) with exactly one transform per
// dimension; dimension i of the input maps to dimension i of the output.
15 template <typename TensorDesc,
16  typename TileLengths, // Tuple<...>
17  typename DoPads> // Sequence<bool, bool, ...>
18 __host__ __device__ constexpr auto
19 PadTensorDescriptor(const TensorDesc& desc, const TileLengths& tile_lengths, DoPads)
20 {
21  constexpr index_t num_dim = DoPads::Size();
22 
23  static_assert(num_dim == TileLengths::Size() && num_dim == TensorDesc::GetNumOfDimension(),
24  "wrong! inconsistent # of dimensions");
25 
26  // transforms: one per dimension, either a right-pad or a pass-through
27  const auto transforms = generate_tuple(
28  [&](auto idim) {
29  const auto MRaw = desc.GetLength(idim);
30 
31  const auto MPerTile = tile_lengths[idim];
32 
    // round the raw length up to a whole number of tiles
33  const auto M = math::integer_divide_ceil(MRaw, MPerTile) * MPerTile;
34 
35  const auto MPad = M - MRaw;
36 
37  const bool DoPadM = DoPads::At(idim);
38 
    // flag set -> right-pad by MPad; flag clear -> leave the dimension untouched
39  const auto MTransform = conditional_expr<DoPadM>(make_right_pad_transform(MRaw, MPad),
40  make_pass_through_transform(MRaw));
41 
42  return MTransform;
43  },
44  Number<num_dim>{});
45 
46  // lower dimension Id: Sequence<i> for the i-th transform
47  const auto lower_dimss =
48  generate_tuple([&](auto idim) { return Sequence<idim.value>{}; }, Number<num_dim>{});
49 
50  // upper dimension Id: identical to the lower ids, so dimension order is preserved
51  const auto upper_dimss = lower_dimss;
52 
53  return transform_tensor_descriptor(desc, transforms, lower_dimss, upper_dimss);
54 }
55 
56 // M/N/K/OPerTileType could be index_t or Number<>
57 template <GemmSpecialization GemmSpec,
58  typename MPerTileType,
59  typename NPerTileType,
60  typename KPerTileType,
61  typename OPerTileType>
63 {
64  // TODO: hard to scale; use mask instead
65  static constexpr bool PadM =
70  static constexpr bool PadN =
75  static constexpr bool PadK =
80  static constexpr bool PadO =
85 
86  // A[M, K]: forwards to PadTensorDescriptor with tile lengths (MPerTile_, KPerTile_) and pad flags <PadM, PadK>
87  template <typename ADesc_MRaw_KRaw>
88  __host__ __device__ constexpr auto
89  PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const
90  {
91  return PadTensorDescriptor(
92  a_desc_mraw_kraw, make_tuple(MPerTile_, KPerTile_), Sequence<PadM, PadK>{});
93  }
94 
95  // B[N, K] (descriptor order: dim 0 is N, dim 1 is K): forwards to PadTensorDescriptor with tile lengths (NPerTile_, KPerTile_) and pad flags <PadN, PadK>
96  template <typename BDesc_NRaw_KRaw>
97  __host__ __device__ constexpr auto
98  PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const
99  {
100  return PadTensorDescriptor(
101  b_desc_nraw_kraw, make_tuple(NPerTile_, KPerTile_), Sequence<PadN, PadK>{});
102  }
103 
104  // B1[Gemm1N, Gemm1K] = B1[O, N]: forwards to PadTensorDescriptor with tile lengths (OPerTile_, NPerTile_) and pad flags <PadO, PadN>
105  template <typename B1Desc_NRaw_KRaw>
106  __host__ __device__ constexpr auto
107  PadB1Descriptor_N_K(const B1Desc_NRaw_KRaw& b1_desc_nraw_kraw) const
108  {
109  return PadTensorDescriptor(
110  b1_desc_nraw_kraw, make_tuple(OPerTile_, NPerTile_), Sequence<PadO, PadN>{});
111  }
112 
113  // C[M, Gemm1N] = C[M, O]: forwards to PadTensorDescriptor with tile lengths (MPerTile_, OPerTile_) and pad flags <PadM, PadO>
114  template <typename CDesc_MRaw_NRaw>
115  __host__ __device__ constexpr auto
116  PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const
117  {
118  return PadTensorDescriptor(
119  c_desc_mraw_nraw, make_tuple(MPerTile_, OPerTile_), Sequence<PadM, PadO>{});
120  }
121 
122  MPerTileType MPerTile_;
123  NPerTileType NPerTile_;
124  KPerTileType KPerTile_;
125  OPerTileType OPerTile_;
126 };
127 
128 // M/N/KPerTileType could be index_t or Number<>
129 template <GemmSpecialization GemmSpec,
130  typename MPerTileType,
131  typename NPerTileType,
132  typename KPerTileType>
134 {
135  static constexpr bool PadM =
136  (GemmSpec == GemmSpecialization::MPadding || GemmSpec == GemmSpecialization::MNPadding ||
138  static constexpr bool PadN =
139  (GemmSpec == GemmSpecialization::NPadding || GemmSpec == GemmSpecialization::MNPadding ||
141  static constexpr bool PadK =
142  (GemmSpec == GemmSpecialization::KPadding || GemmSpec == GemmSpecialization::MKPadding ||
144 
145  template <typename ADesc_MRaw_KRaw>
146  __host__ __device__ constexpr auto
147  PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const
148  {
    // A[M, K]: dim 0 is paired with MPerTile_/PadM, dim 1 with KPerTile_/PadK;
    // the padding itself is done by PadTensorDescriptor.
149  return PadTensorDescriptor(
150  a_desc_mraw_kraw, make_tuple(MPerTile_, KPerTile_), Sequence<PadM, PadK>{});
151  }
152 
153  template <typename BDesc_NRaw_KRaw>
154  __host__ __device__ constexpr auto
155  PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const
156  {
    // B[N, K]: dim 0 is paired with NPerTile_/PadN, dim 1 with KPerTile_/PadK;
    // the padding itself is done by PadTensorDescriptor.
157  return PadTensorDescriptor(
158  b_desc_nraw_kraw, make_tuple(NPerTile_, KPerTile_), Sequence<PadN, PadK>{});
159  }
160 
161  template <typename CDesc_MRaw_NRaw>
162  __host__ __device__ constexpr auto
163  PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const
164  {
    // C[M, N]: dim 0 is paired with MPerTile_/PadM, dim 1 with NPerTile_/PadN;
    // the padding itself is done by PadTensorDescriptor.
165  return PadTensorDescriptor(
166  c_desc_mraw_nraw, make_tuple(MPerTile_, NPerTile_), Sequence<PadM, PadN>{});
167  }
168 
169  MPerTileType MPerTile_;
170  NPerTileType NPerTile_;
171  KPerTileType KPerTile_;
172 };
173 
174 // Alias of GemmPadder (an empty struct that publicly derives from it and adds no members); to deprecate
175 template <GemmSpecialization GemmSpec,
176  typename MPerTileType,
177  typename NPerTileType,
178  typename KPerTileType>
179 struct MatrixPadder : public GemmPadder<GemmSpec, MPerTileType, NPerTileType, KPerTileType>
180 {
181 };
182 
183 // function to take in a struct of type MatrixPadder and call the appropriate function to get
184 // the output descriptor at runtime for codegen
// Forwards conv_desc to matrix_padder.PadCDescriptor_M_N and returns the resulting descriptor.
// Both arguments are taken by value.
185 template <GemmSpecialization GemmSpec,
186  typename MPerTileType,
187  typename NPerTileType,
188  typename KPerTileType,
189  typename CDesc_MRaw_NRaw>
190 auto grid_desc(MatrixPadder<GemmSpec, MPerTileType, NPerTileType, KPerTileType> matrix_padder,
191  CDesc_MRaw_NRaw conv_desc)
192 {
193  auto res = matrix_padder.PadCDescriptor_M_N(conv_desc);
194  return res;
195 }
196 // M/N/KPerTileType could be index_t or Number<>
197 template <bool PadM,
198  bool PadN,
199  bool PadK,
200  typename MPerTileType,
201  typename NPerTileType,
202  typename KPerTileType>
204 {
205  template <typename ADesc_MRaw_KRaw>
206  __host__ __device__ constexpr auto
207  PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const
208  {
    // A[M, K]: PadM/PadK here are the struct's bool template parameters; they are passed
    // to PadTensorDescriptor together with the tile lengths (MPerTile_, KPerTile_).
209  return PadTensorDescriptor(
210  a_desc_mraw_kraw, make_tuple(MPerTile_, KPerTile_), Sequence<PadM, PadK>{});
211  }
212 
213  template <typename BDesc_NRaw_KRaw>
214  __host__ __device__ constexpr auto
215  PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const
216  {
    // B[N, K]: PadN/PadK here are the struct's bool template parameters; they are passed
    // to PadTensorDescriptor together with the tile lengths (NPerTile_, KPerTile_).
217  return PadTensorDescriptor(
218  b_desc_nraw_kraw, make_tuple(NPerTile_, KPerTile_), Sequence<PadN, PadK>{});
219  }
220 
221  template <typename CDesc_MRaw_NRaw>
222  __host__ __device__ constexpr auto
223  PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const
224  {
    // C[M, N]: PadM/PadN here are the struct's bool template parameters; they are passed
    // to PadTensorDescriptor together with the tile lengths (MPerTile_, NPerTile_).
225  return PadTensorDescriptor(
226  c_desc_mraw_nraw, make_tuple(MPerTile_, NPerTile_), Sequence<PadM, PadN>{});
227  }
228 
229  MPerTileType MPerTile_;
230  NPerTileType NPerTile_;
231  KPerTileType KPerTile_;
232 };
233 
234 // M/N/KPerTileType could be index_t or Number<>
235 template <bool PadM,
236  bool PadN,
237  bool PadK,
238  typename MPerTileType,
239  typename NPerTileType,
240  typename KPerTileType>
242 {
243  static constexpr auto I0 = Number<0>{};
244  static constexpr auto I1 = Number<1>{};
245  static constexpr auto I2 = Number<2>{};
246  static constexpr auto I3 = Number<3>{};
247 
248  template <typename ADesc_MRaw_KRaw>
249  __host__ __device__ constexpr auto
250  PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const
251  {
    // A[M, K]: computes the right-pad amounts that round MRaw/KRaw up to multiples of
    // MPerTile_/KPerTile_, then picks one of four branches at compile time from PadM/PadK.
    // NOTE(review): this listing omits source lines 265, 267-268, 273, 275-277, 282 and
    // 284-286, so the transform_tensor_descriptor calls below are incomplete as shown;
    // restore them from the original header before compiling.
252  const auto MRaw = a_desc_mraw_kraw.GetLength(I0);
253  const auto KRaw = a_desc_mraw_kraw.GetLength(I1);
254 
255  const auto M = math::integer_divide_ceil(MRaw, MPerTile_) * MPerTile_;
256  const auto K = math::integer_divide_ceil(KRaw, KPerTile_) * KPerTile_;
257 
258  const auto MPad = M - MRaw;
259  const auto KPad = K - KRaw;
260 
261  if constexpr(PadM && PadK)
262  {
263  // pad both M and K
264  return transform_tensor_descriptor(a_desc_mraw_kraw,
266  make_right_pad_transform(KRaw, KPad)),
269  }
270  else if constexpr(PadM && (!PadK))
271  {
272  // pad M, but not K
274  a_desc_mraw_kraw,
278  }
279  else if constexpr((!PadM) && PadK)
280  {
281  // pad K, but not M
283  a_desc_mraw_kraw,
287  }
288  else
289  {
290  // not pad M or K: the input descriptor is returned unchanged
291  return a_desc_mraw_kraw;
292  }
293  }
294 
295  template <typename BDesc_NRaw_KRaw>
296  __host__ __device__ constexpr auto
297  PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const
298  {
    // B[N, K]: computes the right-pad amounts that round NRaw/KRaw up to multiples of
    // NPerTile_/KPerTile_, then picks one of four branches at compile time from PadN/PadK.
    // NOTE(review): this listing omits source lines 312, 314-315, 320, 322-324, 329 and
    // 331-333, so the transform_tensor_descriptor calls below are incomplete as shown;
    // restore them from the original header before compiling.
299  const auto NRaw = b_desc_nraw_kraw.GetLength(I0);
300  const auto KRaw = b_desc_nraw_kraw.GetLength(I1);
301 
302  const auto N = math::integer_divide_ceil(NRaw, NPerTile_) * NPerTile_;
303  const auto K = math::integer_divide_ceil(KRaw, KPerTile_) * KPerTile_;
304 
305  const auto NPad = N - NRaw;
306  const auto KPad = K - KRaw;
307 
308  if constexpr(PadN && PadK)
309  {
310  // pad both N and K
311  return transform_tensor_descriptor(b_desc_nraw_kraw,
313  make_right_pad_transform(KRaw, KPad)),
316  }
317  else if constexpr(PadN && (!PadK))
318  {
319  // pad N, but not K
321  b_desc_nraw_kraw,
325  }
326  else if constexpr((!PadN) && PadK)
327  {
328  // pad K, but not N
330  b_desc_nraw_kraw,
334  }
335  else
336  {
337  // not pad N or K: the input descriptor is returned unchanged
338  return b_desc_nraw_kraw;
339  }
340  }
341 
342  template <typename CDesc_MRaw_NRaw>
343  __host__ __device__ constexpr auto
344  PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const
345  {
    // C[M, N]: computes the right-pad amounts that round MRaw/NRaw up to multiples of
    // MPerTile_/NPerTile_, then picks one of four branches at compile time from PadM/PadN.
    // NOTE(review): this listing omits source lines 359, 361-362, 367, 369-371, 376 and
    // 378-380, so the transform_tensor_descriptor calls below are incomplete as shown;
    // restore them from the original header before compiling.
346  const auto MRaw = c_desc_mraw_nraw.GetLength(I0);
347  const auto NRaw = c_desc_mraw_nraw.GetLength(I1);
348 
349  const auto M = math::integer_divide_ceil(MRaw, MPerTile_) * MPerTile_;
350  const auto N = math::integer_divide_ceil(NRaw, NPerTile_) * NPerTile_;
351 
352  const auto MPad = M - MRaw;
353  const auto NPad = N - NRaw;
354 
355  if constexpr(PadM && PadN)
356  {
357  // pad M and N
358  return transform_tensor_descriptor(c_desc_mraw_nraw,
360  make_right_pad_transform(NRaw, NPad)),
363  }
364  else if constexpr(PadM && (!PadN))
365  {
366  // pad M, but not N
368  c_desc_mraw_nraw,
372  }
373  else if constexpr((!PadM) && PadN)
374  {
375  // pad N, but not M
377  c_desc_mraw_nraw,
381  }
382  else
383  {
384  // not pad M or N: the input descriptor is returned unchanged
385  return c_desc_mraw_nraw;
386  }
387  }
388 
389  MPerTileType MPerTile_;
390  NPerTileType NPerTile_;
391  KPerTileType KPerTile_;
392 };
393 } // namespace device
394 } // namespace tensor_operation
395 } // namespace ck
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
auto grid_desc(MatrixPadder< GemmSpec, MPerTileType, NPerTileType, KPerTileType > matrix_padder, CDesc_MRaw_NRaw conv_desc)
Definition: matrix_padder.hpp:190
__host__ constexpr __device__ auto PadTensorDescriptor(const TensorDesc &desc, const TileLengths &tile_lengths, DoPads)
Definition: matrix_padder.hpp:19
GemmSpecialization
Definition: gemm_specialization.hpp:11
Definition: ck.hpp:266
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:297
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
Definition: sequence.hpp:43
Definition: integral_constant.hpp:20
Definition: matrix_padder.hpp:63
KPerTileType KPerTile_
Definition: matrix_padder.hpp:124
OPerTileType OPerTile_
Definition: matrix_padder.hpp:125
static constexpr bool PadM
Definition: matrix_padder.hpp:65
MPerTileType MPerTile_
Definition: matrix_padder.hpp:122
static constexpr bool PadN
Definition: matrix_padder.hpp:70
__host__ constexpr __device__ auto PadCDescriptor_M_N(const CDesc_MRaw_NRaw &c_desc_mraw_nraw) const
Definition: matrix_padder.hpp:116
__host__ constexpr __device__ auto PadADescriptor_M_K(const ADesc_MRaw_KRaw &a_desc_mraw_kraw) const
Definition: matrix_padder.hpp:89
NPerTileType NPerTile_
Definition: matrix_padder.hpp:123
static constexpr bool PadO
Definition: matrix_padder.hpp:80
__host__ constexpr __device__ auto PadB1Descriptor_N_K(const B1Desc_NRaw_KRaw &b1_desc_nraw_kraw) const
Definition: matrix_padder.hpp:107
static constexpr bool PadK
Definition: matrix_padder.hpp:75
__host__ constexpr __device__ auto PadBDescriptor_N_K(const BDesc_NRaw_KRaw &b_desc_nraw_kraw) const
Definition: matrix_padder.hpp:98
Definition: matrix_padder.hpp:204
__host__ constexpr __device__ auto PadADescriptor_M_K(const ADesc_MRaw_KRaw &a_desc_mraw_kraw) const
Definition: matrix_padder.hpp:207
MPerTileType MPerTile_
Definition: matrix_padder.hpp:229
NPerTileType NPerTile_
Definition: matrix_padder.hpp:230
__host__ constexpr __device__ auto PadBDescriptor_N_K(const BDesc_NRaw_KRaw &b_desc_nraw_kraw) const
Definition: matrix_padder.hpp:215
__host__ constexpr __device__ auto PadCDescriptor_M_N(const CDesc_MRaw_NRaw &c_desc_mraw_nraw) const
Definition: matrix_padder.hpp:223
KPerTileType KPerTile_
Definition: matrix_padder.hpp:231
Definition: matrix_padder.hpp:134
NPerTileType NPerTile_
Definition: matrix_padder.hpp:170
MPerTileType MPerTile_
Definition: matrix_padder.hpp:169
static constexpr bool PadK
Definition: matrix_padder.hpp:141
KPerTileType KPerTile_
Definition: matrix_padder.hpp:171
__host__ constexpr __device__ auto PadADescriptor_M_K(const ADesc_MRaw_KRaw &a_desc_mraw_kraw) const
Definition: matrix_padder.hpp:147
__host__ constexpr __device__ auto PadBDescriptor_N_K(const BDesc_NRaw_KRaw &b_desc_nraw_kraw) const
Definition: matrix_padder.hpp:155
__host__ constexpr __device__ auto PadCDescriptor_M_N(const CDesc_MRaw_NRaw &c_desc_mraw_nraw) const
Definition: matrix_padder.hpp:163
static constexpr bool PadM
Definition: matrix_padder.hpp:135
static constexpr bool PadN
Definition: matrix_padder.hpp:138
Definition: matrix_padder.hpp:242
KPerTileType KPerTile_
Definition: matrix_padder.hpp:391
__host__ constexpr __device__ auto PadBDescriptor_N_K(const BDesc_NRaw_KRaw &b_desc_nraw_kraw) const
Definition: matrix_padder.hpp:297
static constexpr auto I2
Definition: matrix_padder.hpp:245
MPerTileType MPerTile_
Definition: matrix_padder.hpp:389
__host__ constexpr __device__ auto PadCDescriptor_M_N(const CDesc_MRaw_NRaw &c_desc_mraw_nraw) const
Definition: matrix_padder.hpp:344
static constexpr auto I3
Definition: matrix_padder.hpp:246
static constexpr auto I0
Definition: matrix_padder.hpp:243
NPerTileType NPerTile_
Definition: matrix_padder.hpp:390
__host__ constexpr __device__ auto PadADescriptor_M_K(const ADesc_MRaw_KRaw &a_desc_mraw_kraw) const
Definition: matrix_padder.hpp:250
static constexpr auto I1
Definition: matrix_padder.hpp:244
Definition: matrix_padder.hpp:180