/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp Source File
transform_contraction_to_gemm.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
10 
11 namespace ck {
12 namespace tensor_operation {
13 
14 // assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]; returns std::pair{G_MRaw_NRaw, MRaw_NRaw} (both unpadded); throws std::runtime_error if either vector's size != NumDimG + NumDimM + NumDimN
15 template <index_t NumDimG,
16  index_t NumDimM,
17  index_t NumDimN,
19 static auto MakeGridDescriptorPair(const std::vector<index_t>& gs_ms_ns_lengths_vec,
20  const std::vector<index_t>& gs_ms_ns_strides_vec)
21 {
22  if(!(gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
23  gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN))
24  {
25  throw std::runtime_error("wrong! dimension must match input lengths");
26  }
27  // to_tuple: copy vec[start, end) into a tuple whose size (end - start) is a compile-time Number
28  const auto to_tuple = [&](auto& vec, auto start, auto end) {
29  return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
30  };
31 
32  const auto gs_ms_ns_lengths =
33  to_tuple(gs_ms_ns_lengths_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
34  const auto gs_ms_ns_strides =
35  to_tuple(gs_ms_ns_strides_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
36 
37  // dimension Ids for G0, G1, ...
38  constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{};
39 
40  // dimension Ids for M0, M1, ...
41  constexpr auto mDimIds =
43 
44  // dimension Ids for N0, N1, ...
45  constexpr auto nDimIds =
47 
48  // lengths for G0, G1, ...
49  const auto gLengths = get_container_subset(gs_ms_ns_lengths, gDimIds);
50 
51  // lengths for M0, M1, ...
52  const auto mLengths = get_container_subset(gs_ms_ns_lengths, mDimIds);
53 
54  // lengths for N0, N1, ...
55  const auto nLengths = get_container_subset(gs_ms_ns_lengths, nDimIds);
56  // Packed: each of the G / M / N groups is assumed contiguous, so every flattened dim uses only the innermost stride of its group
57  if constexpr(TensorSpec == device::TensorSpecialization::Packed)
58  {
59  auto G = container_reduce(gLengths, math::multiplies{}, Number<1>{});
60  auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
61  auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
62  const auto grid_desc_g_mraw_nraw = make_naive_tensor_descriptor(
63  make_tuple(G, M, N),
64  make_tuple(gs_ms_ns_strides[Number<NumDimG - 1>{}],
65  gs_ms_ns_strides[Number<NumDimG + NumDimM - 1>{}],
66  gs_ms_ns_strides[Number<NumDimG + NumDimM + NumDimN - 1>{}]));
67 
68  const auto grid_desc_mraw_nraw = make_naive_tensor_descriptor(
69  make_tuple(M, N),
70  make_tuple(gs_ms_ns_strides[Number<NumDimG + NumDimM - 1>{}],
71  gs_ms_ns_strides[Number<NumDimG + NumDimM + NumDimN - 1>{}]));
72 
73  return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw);
74  }
75  else // general strides: build the full N-d naive descriptor and merge each dimension group
76  {
77  // naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
78  const auto grid_desc_gs_ms_ns =
79  make_naive_tensor_descriptor(gs_ms_ns_lengths, gs_ms_ns_strides);
80 
81  // transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
82  // N2 * ...]
83  // Note: This does not require padding as it only provides G offset calculation. Technically
84  // descriptor for only G is needed. Here we opt for backward compatibility purpose to return
85  // G_M_N
86  const auto grid_desc_g_mraw_nraw =
87  transform_tensor_descriptor(grid_desc_gs_ms_ns,
89  make_merge_transform(mLengths),
90  make_merge_transform(nLengths)),
91  make_tuple(gDimIds, mDimIds, nDimIds),
92  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
93 
94  const auto c_ms_ns_lengths = to_tuple(
95  gs_ms_ns_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
96  const auto c_ms_ns_strides = to_tuple(
97  gs_ms_ns_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
98 
99  // transformed tensor C[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
100  // N2 * ...] (built from the M/N entries only; the G dims are dropped, hence the NumDimG shift below)
101  const auto grid_desc_ms_ns = make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides);
102 
103  const auto grid_desc_mraw_nraw = transform_tensor_descriptor(
104  grid_desc_ms_ns,
106  make_tuple(mDimIds - Number<NumDimG>{}, nDimIds - Number<NumDimG>{}),
107  make_tuple(Sequence<0>{}, Sequence<1>{}));
108 
109  return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw);
110  }
111 }
112 
113 template <typename NumDims_G_M_N_K_O, // Sequence<>
114  typename PerBlock_M_N_K_O, // Sequence<>
121 {
122  static constexpr auto I0 = Number<0>{};
123  static constexpr auto I1 = Number<1>{};
124  static constexpr auto I2 = Number<2>{};
125  static constexpr auto I3 = Number<3>{};
126  static constexpr auto I4 = Number<4>{};
127  // number of G / M / N / K / O dimensions, read in that order from the NumDims_G_M_N_K_O Sequence<>
128  static constexpr index_t NumDimG = NumDims_G_M_N_K_O::At(I0);
129  static constexpr index_t NumDimM = NumDims_G_M_N_K_O::At(I1);
130  static constexpr index_t NumDimN = NumDims_G_M_N_K_O::At(I2);
131  static constexpr index_t NumDimK = NumDims_G_M_N_K_O::At(I3);
132  static constexpr index_t NumDimO = NumDims_G_M_N_K_O::At(I4);
133  // per-block tile sizes for M / N / K / O, read in that order from the PerBlock_M_N_K_O Sequence<>
134  static constexpr index_t MPerBlock = PerBlock_M_N_K_O::At(I0);
135  static constexpr index_t NPerBlock = PerBlock_M_N_K_O::At(I1);
136  static constexpr index_t KPerBlock = PerBlock_M_N_K_O::At(I2);
137  static constexpr index_t OPerBlock = PerBlock_M_N_K_O::At(I3);
138 
139  static constexpr auto matrix_padder =
142 
143  //
144  // A: A[G0.., M0.., K0..]; MakeAGridDescriptorPair returns {G_MRaw_KRaw, MRaw_KRaw}, both unpadded
145  //
146  static auto MakeAGridDescriptorPair(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
147  const std::vector<index_t>& a_gs_ms_ks_strides_vec)
148  {
149  return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimK, ASpec>(a_gs_ms_ks_lengths_vec,
150  a_gs_ms_ks_strides_vec);
151  }
152 
153  // Unpadded (G, MRaw, KRaw) view of A (.first of the pair). TODO: rename to G_MRaw_KRaw
154  static auto MakeAGridDescriptor_G_M_K(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
155  const std::vector<index_t>& a_gs_ms_ks_strides_vec)
156  {
157  return MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).first;
158  }
159  static auto MakeAGridDescriptor_M_K(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
160  const std::vector<index_t>& a_gs_ms_ks_strides_vec)
161  { // (MRaw, KRaw) view of A (.second of the pair), padded by matrix_padder.PadADescriptor_M_K
162  return matrix_padder.PadADescriptor_M_K(
163  MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).second);
164  }
165 
166  template <typename AGridDesc_M_K, typename Number>
167  __host__ __device__ static constexpr auto
168  MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k, const Number& AK1)
169  {
170  const auto M = a_grid_desc_m_k.GetLength(I0);
171  const auto K = a_grid_desc_m_k.GetLength(I1);
172 
173  const auto AK0 = K / AK1;
174 
175  return transform_tensor_descriptor(a_grid_desc_m_k,
180  }
181 
182  //
183  // B (alias of B0): B0[G0.., N0.., K0..]; MakeB0GridDescriptorPair returns {G_NRaw_KRaw, NRaw_KRaw}, both unpadded
184  //
185  static auto MakeB0GridDescriptorPair(const std::vector<index_t>& b0_gs_ns_ks_lengths_vec,
186  const std::vector<index_t>& b0_gs_ns_ks_strides_vec)
187  {
188  return MakeGridDescriptorPair<NumDimG, NumDimN, NumDimK, B0Spec>(b0_gs_ns_ks_lengths_vec,
189  b0_gs_ns_ks_strides_vec);
190  }
191 
192  // Unpadded (G, NRaw, KRaw) view of B0 (.first of the pair). TODO: rename to G_NRaw_KRaw
193  static auto MakeB0GridDescriptor_G_N_K(const std::vector<index_t>& b0_gs_ns_ks_lengths_vec,
194  const std::vector<index_t>& b0_gs_ns_ks_strides_vec)
195  {
196  return MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).first;
197  }
198  static auto MakeB0GridDescriptor_N_K(const std::vector<index_t>& b0_gs_ns_ks_lengths_vec,
199  const std::vector<index_t>& b0_gs_ns_ks_strides_vec)
200  {
201  // (NRaw, KRaw) view of B0 (.second of the pair), padded by matrix_padder; PadBDescriptor_N_K serves as the alias of matrix_padder.PadB0Descriptor_N_K
202  return matrix_padder.PadBDescriptor_N_K(
203  MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).second);
204  }
205 
206  template <typename BGridDesc_N_K, typename Number>
207  __host__ __device__ static constexpr auto
208  MakeB0GridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k, const Number& BK1)
209  {
210  const auto N = b_grid_desc_n_k.GetLength(I0);
211  const auto K = b_grid_desc_n_k.GetLength(I1);
212 
213  const auto BK0 = K / BK1;
214 
215  return transform_tensor_descriptor(b_grid_desc_n_k,
220  }
221 
222  //
223  // B1: B1[G0.., O0.., N0..]; MakeB1GridDescriptorPair returns {G_ORaw_NRaw, ORaw_NRaw}, both unpadded
224  //
225  static auto MakeB1GridDescriptorPair(const std::vector<index_t>& b1_gs_os_ns_lengths_vec,
226  const std::vector<index_t>& b1_gs_os_ns_strides_vec)
227  {
228  return MakeGridDescriptorPair<NumDimG, NumDimO, NumDimN, B1Spec>(b1_gs_os_ns_lengths_vec,
229  b1_gs_os_ns_strides_vec);
230  }
231 
232  // Unpadded (G, ORaw, NRaw) view of B1 (.first of the pair), named here with GEMM N/K. TODO: rename to G_NRaw_KRaw
233  static auto MakeB1GridDescriptor_G_N_K(const std::vector<index_t>& b1_gs_os_ns_lengths_vec,
234  const std::vector<index_t>& b1_gs_os_ns_strides_vec)
235  {
236  return MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).first;
237  }
238  static auto MakeB1GridDescriptor_N_K(const std::vector<index_t>& b1_gs_os_ns_lengths_vec,
239  const std::vector<index_t>& b1_gs_os_ns_strides_vec)
240  {
241  // (ORaw, NRaw) view of B1 (.second of the pair), padded by matrix_padder; PadB1Descriptor_N_K serves as the alias of matrix_padder.PadB1Descriptor_O_N
242  return matrix_padder.PadB1Descriptor_N_K(
243  MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).second);
244  }
245 
246  template <typename B1GridDesc_N_K, typename Number>
247  __host__ __device__ static constexpr auto
248  MakeB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K& b1_grid_desc_n_k, const Number& B1K1)
249  {
250  const auto N = b1_grid_desc_n_k.GetLength(I0);
251  const auto K = b1_grid_desc_n_k.GetLength(I1);
252 
253  const auto B1K0 = K / B1K1;
254 
256  b1_grid_desc_n_k,
261  }
262 
263  //
264  // C: C[G0.., M0.., O0..]; MakeCGridDescriptorPair returns {G_MRaw_ORaw, MRaw_ORaw}, both unpadded
265  //
266  static auto MakeCGridDescriptorPair(const std::vector<index_t>& c_gs_ms_os_lengths_vec,
267  const std::vector<index_t>& c_gs_ms_os_strides_vec)
268  {
269  return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimO, CSpec>(c_gs_ms_os_lengths_vec,
270  c_gs_ms_os_strides_vec);
271  }
272 
273  // Unpadded (G, MRaw, ORaw) view of C (.first of the pair), named here with GEMM M/N. TODO: rename to G_MRaw_NRaw
274  static auto MakeCGridDescriptor_G_M_N(const std::vector<index_t>& c_gs_ms_os_lengths_vec,
275  const std::vector<index_t>& c_gs_ms_os_strides_vec)
276  {
277  return MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).first;
278  }
279  static auto MakeCGridDescriptor_M_N(const std::vector<index_t>& c_gs_ms_os_lengths_vec,
280  const std::vector<index_t>& c_gs_ms_os_strides_vec)
281  { // (MRaw, ORaw) view of C (.second of the pair), padded by matrix_padder.PadCDescriptor_M_N
282  return matrix_padder.PadCDescriptor_M_N(
283  MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).second);
284  }
285 };
286 
287 } // namespace tensor_operation
288 } // namespace ck
TensorSpecialization
Definition: tensor_specialization.hpp:11
GemmSpecialization
Definition: gemm_specialization.hpp:11
__host__ __device__ multiplies() -> multiplies< void, void >
FIXME: create macro to replace 'host device' and nothing more.
Definition: ck.hpp:267
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:298
__host__ constexpr __device__ auto container_reduce(const Container &x, Reduce reduce, Init init, Number< IBegin >=Number< 0 >{}, Number< IEnd >=Number< Container::Size()>{}, Number< IStep >=Number< 1 >{})
Definition: container_helper.hpp:111
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
integral_constant< index_t, N > Number
Definition: number.hpp:12
__host__ constexpr __device__ auto get_container_subset(const Array< T, N > &arr, Sequence< Is... >)
Definition: container_helper.hpp:346
Definition: sequence.hpp:43
typename conditional< kHasContent, type0, type1 >::type type
Definition: sequence.hpp:271
Definition: integral_constant.hpp:20
static constexpr index_t NPerBlock
Definition: transform_contraction_to_gemm.hpp:135
static constexpr index_t NumDimM
Definition: transform_contraction_to_gemm.hpp:129
__host__ static constexpr __device__ auto MakeB0GridDescriptor_BK0_N_BK1(const BGridDesc_N_K &b_grid_desc_n_k, const Number &BK1)
Definition: transform_contraction_to_gemm.hpp:208
__host__ static constexpr __device__ auto MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K &a_grid_desc_m_k, const Number &AK1)
Definition: transform_contraction_to_gemm.hpp:168
static constexpr auto I4
Definition: transform_contraction_to_gemm.hpp:126
static auto MakeB0GridDescriptor_N_K(const std::vector< index_t > &b0_gs_ns_ks_lengths_vec, const std::vector< index_t > &b0_gs_ns_ks_strides_vec)
Definition: transform_contraction_to_gemm.hpp:198
static constexpr index_t MPerBlock
Definition: transform_contraction_to_gemm.hpp:134
static constexpr auto I2
Definition: transform_contraction_to_gemm.hpp:124
static auto MakeAGridDescriptor_G_M_K(const std::vector< index_t > &a_gs_ms_ks_lengths_vec, const std::vector< index_t > &a_gs_ms_ks_strides_vec)
Definition: transform_contraction_to_gemm.hpp:154
static constexpr index_t OPerBlock
Definition: transform_contraction_to_gemm.hpp:137
__host__ static constexpr __device__ auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K &b1_grid_desc_n_k, const Number &B1K1)
Definition: transform_contraction_to_gemm.hpp:248
static auto MakeB0GridDescriptor_G_N_K(const std::vector< index_t > &b0_gs_ns_ks_lengths_vec, const std::vector< index_t > &b0_gs_ns_ks_strides_vec)
Definition: transform_contraction_to_gemm.hpp:193
static constexpr index_t KPerBlock
Definition: transform_contraction_to_gemm.hpp:136
static auto MakeB1GridDescriptorPair(const std::vector< index_t > &b1_gs_os_ns_lengths_vec, const std::vector< index_t > &b1_gs_os_ns_strides_vec)
Definition: transform_contraction_to_gemm.hpp:225
static constexpr index_t NumDimO
Definition: transform_contraction_to_gemm.hpp:132
static constexpr auto I3
Definition: transform_contraction_to_gemm.hpp:125
static constexpr auto I0
Definition: transform_contraction_to_gemm.hpp:122
static auto MakeB0GridDescriptorPair(const std::vector< index_t > &b0_gs_ns_ks_lengths_vec, const std::vector< index_t > &b0_gs_ns_ks_strides_vec)
Definition: transform_contraction_to_gemm.hpp:185
static auto MakeAGridDescriptor_M_K(const std::vector< index_t > &a_gs_ms_ks_lengths_vec, const std::vector< index_t > &a_gs_ms_ks_strides_vec)
Definition: transform_contraction_to_gemm.hpp:159
static constexpr auto matrix_padder
Definition: transform_contraction_to_gemm.hpp:139
static constexpr index_t NumDimK
Definition: transform_contraction_to_gemm.hpp:131
static auto MakeAGridDescriptorPair(const std::vector< index_t > &a_gs_ms_ks_lengths_vec, const std::vector< index_t > &a_gs_ms_ks_strides_vec)
Definition: transform_contraction_to_gemm.hpp:146
static auto MakeCGridDescriptor_G_M_N(const std::vector< index_t > &c_gs_ms_os_lengths_vec, const std::vector< index_t > &c_gs_ms_os_strides_vec)
Definition: transform_contraction_to_gemm.hpp:274
static constexpr index_t NumDimG
Definition: transform_contraction_to_gemm.hpp:128
static auto MakeB1GridDescriptor_G_N_K(const std::vector< index_t > &b1_gs_os_ns_lengths_vec, const std::vector< index_t > &b1_gs_os_ns_strides_vec)
Definition: transform_contraction_to_gemm.hpp:233
static auto MakeB1GridDescriptor_N_K(const std::vector< index_t > &b1_gs_os_ns_lengths_vec, const std::vector< index_t > &b1_gs_os_ns_strides_vec)
Definition: transform_contraction_to_gemm.hpp:238
static auto MakeCGridDescriptorPair(const std::vector< index_t > &c_gs_ms_os_lengths_vec, const std::vector< index_t > &c_gs_ms_os_strides_vec)
Definition: transform_contraction_to_gemm.hpp:266
static auto MakeCGridDescriptor_M_N(const std::vector< index_t > &c_gs_ms_os_lengths_vec, const std::vector< index_t > &c_gs_ms_os_strides_vec)
Definition: transform_contraction_to_gemm.hpp:279
static constexpr auto I1
Definition: transform_contraction_to_gemm.hpp:123
static constexpr index_t NumDimN
Definition: transform_contraction_to_gemm.hpp:130
Definition: matrix_padder.hpp:63