Composable Kernel: include/ck/wrapper/operations/gemm.hpp (source listing)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

// Disable from doxygen docs generation
namespace ck {
namespace wrapper {

// Disable from doxygen docs generation
namespace {
namespace detail {
template <index_t K1, typename TileLayout>
__device__ constexpr auto GetBlockDescriptor()
{
    using TileLayoutShape      = typename TileLayout::LayoutShape;
    using TileLayoutDescriptor = typename TileLayout::LayoutUnrolledDescriptorType;

    constexpr auto K0PerBlock = Number<size<1>(TileLayoutShape{})>{} / Number<K1>{};
    // MPerBlock or NPerBlock
    constexpr auto Dim0 = Number<size<0>(TileLayoutShape{})>{};

    constexpr auto a_block_desc_k0_m_k1 = transform_tensor_descriptor(
        TileLayoutDescriptor{},
        make_tuple(make_unmerge_transform(make_tuple(K0PerBlock, Number<K1>{})),
                   make_pass_through_transform(Dim0)),
        make_tuple(Sequence<1>{}, Sequence<0>{}),
        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

    return a_block_desc_k0_m_k1;
}
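
// Illustrative note (not part of the original header): for a 2-D tile of
// shape (Dim0, K), the descriptor built above remaps indices as
// (k0, dim0, k1) -> (dim0, k0 * K1 + k1): the K axis is unmerged into
// (K0PerBlock, K1) while dim 0 (MPerBlock or NPerBlock) passes through
// unchanged.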

} // namespace detail
} // namespace

/// Perform blockwise GEMM (xdlops) on tensors stored in LDS; the result is
/// stored in a VGPR register tensor.
template <typename DataType,
          index_t BlockSize,
          typename GemmTraits,
          typename ATensorType,
          typename BTensorType,
          typename CTensorType>
__device__ void blockwise_gemm_xdl(const ATensorType& a_local_tile_tensor,
                                   const BTensorType& b_local_tile_tensor,
                                   CTensorType& c_reg_tensor)
{
    constexpr auto I3 = Number<3>{};

    static_assert(ATensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds);
    static_assert(BTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds);
    static_assert(CTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Vgpr);
    static_assert(is_same_v<DataType, typename ATensorType::TensorElementType>);
    static_assert(is_same_v<DataType, typename BTensorType::TensorElementType>);

    constexpr bool is_integer =
        is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>;
    using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;

    using ATileLayout = remove_cvref_t<decltype(layout(a_local_tile_tensor))>;
    using BTileLayout = remove_cvref_t<decltype(layout(b_local_tile_tensor))>;

    static_assert(typename ATileLayout::LayoutShape{}.Size() ==
                  typename BTileLayout::LayoutShape{}.Size());
    constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3;

    using ABlockDesc_K0_M_K1_Type =
        conditional_t<is_3d_desc,
                      typename ATileLayout::LayoutUnrolledDescriptorType,
                      decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>())>;
    using BBlockDesc_K0_N_K1_Type =
        conditional_t<is_3d_desc,
                      typename BTileLayout::LayoutUnrolledDescriptorType,
                      decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>())>;

    BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                        DataType,
                                                        DataType,
                                                        GemmAccDataType,
                                                        ABlockDesc_K0_M_K1_Type,
                                                        BBlockDesc_K0_N_K1_Type,
                                                        GemmTraits::MPerXDL,
                                                        GemmTraits::NPerXDL,
                                                        GemmTraits::MXdlPerWave,
                                                        GemmTraits::NXdlPerWave,
                                                        GemmTraits::K1>
        blockwise_gemm_xdl_op{};

    blockwise_gemm_xdl_op.Run(
        a_local_tile_tensor.GetBuffer(), b_local_tile_tensor.GetBuffer(), c_reg_tensor.GetBuffer());
}
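
// Usage sketch (illustrative, not part of this header): with A and B tiles
// already staged in LDS and a VGPR accumulator created by
// make_blockwise_gemm_xdl_c_vgpr (defined below), a call could look like the
// following. a_lds_tile, b_lds_tile, ALdsTileLayout, BLdsTileLayout and
// MyGemmTraits are placeholder names; MyGemmTraits must provide MPerXDL,
// NPerXDL, MXdlPerWave, NXdlPerWave and K1.
//
//   auto c_vgpr = make_blockwise_gemm_xdl_c_vgpr<half_t,
//                                                ALdsTileLayout,
//                                                BLdsTileLayout,
//                                                /*BlockSize=*/256,
//                                                MyGemmTraits>();
//   blockwise_gemm_xdl<half_t, /*BlockSize=*/256, MyGemmTraits>(
//       a_lds_tile, b_lds_tile, c_vgpr);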

/// Create local partition per thread for C tensor.
template <typename DataType,
          typename ATileLayout,
          typename BTileLayout,
          index_t BlockSize,
          typename GemmTraits,
          typename CTensorType>
__host__ __device__ constexpr auto
make_blockwise_gemm_xdl_c_local_partition(CTensorType& c_local_tile_tensor)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};
    constexpr auto I5 = Number<5>{};
    constexpr auto I6 = Number<6>{};
    constexpr auto I7 = Number<7>{};

    static_assert(typename ATileLayout::LayoutShape{}.Size() ==
                  typename BTileLayout::LayoutShape{}.Size());

    constexpr bool is_integer =
        is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>;
    using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;

    constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3;
    using ABlockDesc_K0_M_K1_Type =
        conditional_t<is_3d_desc,
                      typename ATileLayout::LayoutUnrolledDescriptorType,
                      decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>())>;
    using BBlockDesc_K0_N_K1_Type =
        conditional_t<is_3d_desc,
                      typename BTileLayout::LayoutUnrolledDescriptorType,
                      decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>())>;

    using BlockwiseGemmXdlops =
        BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                            DataType,
                                                            DataType,
                                                            GemmAccDataType,
                                                            ABlockDesc_K0_M_K1_Type,
                                                            BBlockDesc_K0_N_K1_Type,
                                                            GemmTraits::MPerXDL,
                                                            GemmTraits::NPerXDL,
                                                            GemmTraits::MXdlPerWave,
                                                            GemmTraits::NXdlPerWave,
                                                            GemmTraits::K1>;

    constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
        BlockwiseGemmXdlops::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
    constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
    constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
    constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
    constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
    constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
    constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
    constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
    constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);

    // Calculate offset on grid
    const auto c_thread_mtx_on_block =
        BlockwiseGemmXdlops::CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

    const index_t m_thread_data_on_grid =
        c_local_tile_tensor.GetMultiIdxOffsets()[I0] + c_thread_mtx_on_block[I0];

    const index_t n_thread_data_on_grid =
        c_local_tile_tensor.GetMultiIdxOffsets()[I1] + c_thread_mtx_on_block[I1];

    const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = make_single_stage_tensor_adaptor(
        make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
        make_tuple(Sequence<0, 1, 2, 3, 4>{}),
        make_tuple(Sequence<0>{}));

    const auto m_thread_data_on_grid_idx =
        m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
            make_multi_index(m_thread_data_on_grid));

    const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
        make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
        make_tuple(Sequence<0, 1, 2>{}),
        make_tuple(Sequence<0>{}));

    const auto n_thread_data_on_grid_idx =
        n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
            make_multi_index(n_thread_data_on_grid));
    // Create partition shape based on descriptor dims.
    const auto partition_shape = make_tuple(M0, N0, I1, I1, M2, I1, M4, I1);

    const auto partition_desc = BlockwiseGemmXdlops::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(
        layout(c_local_tile_tensor).GetUnrolledDescriptor());

    const auto lower_upper_dims =
        generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<8>{});

    auto sliced_desc = transform_tensor_descriptor(
        partition_desc,
        make_tuple(
            make_slice_transform(partition_shape.At(Number<0>{}),
                                 m_thread_data_on_grid_idx[I0],
                                 partition_shape.At(Number<0>{}) + m_thread_data_on_grid_idx[I0]),
            make_slice_transform(partition_shape.At(Number<1>{}),
                                 n_thread_data_on_grid_idx[I0],
                                 partition_shape.At(Number<1>{}) + n_thread_data_on_grid_idx[I0]),
            make_slice_transform(partition_shape.At(Number<2>{}),
                                 m_thread_data_on_grid_idx[I1],
                                 partition_shape.At(Number<2>{}) + m_thread_data_on_grid_idx[I1]),
            make_slice_transform(partition_shape.At(Number<3>{}),
                                 n_thread_data_on_grid_idx[I1],
                                 partition_shape.At(Number<3>{}) + n_thread_data_on_grid_idx[I1]),
            make_slice_transform(partition_shape.At(Number<4>{}),
                                 m_thread_data_on_grid_idx[I2],
                                 partition_shape.At(Number<4>{}) + m_thread_data_on_grid_idx[I2]),
            make_slice_transform(partition_shape.At(Number<5>{}),
                                 m_thread_data_on_grid_idx[I3],
                                 partition_shape.At(Number<5>{}) + m_thread_data_on_grid_idx[I3]),
            make_slice_transform(partition_shape.At(Number<6>{}),
                                 m_thread_data_on_grid_idx[I4],
                                 partition_shape.At(Number<6>{}) + m_thread_data_on_grid_idx[I4]),
            make_slice_transform(partition_shape.At(Number<7>{}),
                                 n_thread_data_on_grid_idx[I2],
                                 partition_shape.At(Number<7>{}) + n_thread_data_on_grid_idx[I2])),
        lower_upper_dims,
        lower_upper_dims);

    const auto partition_layout =
        Layout<remove_reference_t<decltype(partition_shape)>, decltype(sliced_desc)>(
            partition_shape, sliced_desc);
    auto partition_tensor = make_tensor<CTensorType::TensorBufferAddressSpace>(
        c_local_tile_tensor.GetPointer(), partition_layout);
    return partition_tensor;
}
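
// Usage sketch (illustrative, not part of this header): after the blockwise
// GEMM, each thread can write its accumulators to the global C tile through
// the per-thread partition created here. `copy` refers to the wrapper's
// tensor copy utility; c_global_tile, c_vgpr, ALdsTileLayout, BLdsTileLayout
// and MyGemmTraits are placeholder names:
//
//   auto c_partition = make_blockwise_gemm_xdl_c_local_partition<
//       half_t, ALdsTileLayout, BLdsTileLayout, /*BlockSize=*/256,
//       MyGemmTraits>(c_global_tile);
//   copy(c_vgpr, c_partition);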

/// Create per-thread C accumulator tensor stored in VGPR registers.
template <typename DataType,
          typename ATileLayout,
          typename BTileLayout,
          index_t BlockSize,
          typename GemmTraits>
__host__ __device__ constexpr auto make_blockwise_gemm_xdl_c_vgpr()
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};
    constexpr auto I5 = Number<5>{};
    constexpr auto I6 = Number<6>{};
    constexpr auto I7 = Number<7>{};

    static_assert(typename ATileLayout::LayoutShape{}.Size() ==
                  typename BTileLayout::LayoutShape{}.Size());

    constexpr bool is_integer =
        is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>;
    using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;

    constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3;
    using ABlockDesc_K0_M_K1_Type =
        conditional_t<is_3d_desc,
                      typename ATileLayout::LayoutUnrolledDescriptorType,
                      decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>())>;
    using BBlockDesc_K0_N_K1_Type =
        conditional_t<is_3d_desc,
                      typename BTileLayout::LayoutUnrolledDescriptorType,
                      decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>())>;

    using BlockwiseGemmXdlops =
        BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                            DataType,
                                                            DataType,
                                                            GemmAccDataType,
                                                            ABlockDesc_K0_M_K1_Type,
                                                            BBlockDesc_K0_N_K1_Type,
                                                            GemmTraits::MPerXDL,
                                                            GemmTraits::NPerXDL,
                                                            GemmTraits::MXdlPerWave,
                                                            GemmTraits::NXdlPerWave,
                                                            GemmTraits::K1>;
    // Calculate descriptor, shape and layout
    constexpr auto vgpr_desc = BlockwiseGemmXdlops::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
    const auto vgpr_shape = make_tuple(vgpr_desc.GetLengths()[I0],
                                       vgpr_desc.GetLengths()[I1],
                                       vgpr_desc.GetLengths()[I2],
                                       vgpr_desc.GetLengths()[I3],
                                       vgpr_desc.GetLengths()[I4],
                                       vgpr_desc.GetLengths()[I5],
                                       vgpr_desc.GetLengths()[I6],
                                       vgpr_desc.GetLengths()[I7]);
    const auto vgpr_layout = Layout<remove_reference_t<decltype(vgpr_shape)>, decltype(vgpr_desc)>(
        vgpr_shape, vgpr_desc);
    // Get vector type for Vgpr
    constexpr index_t ScalarPerVector = BlockwiseGemmXdlops::xdlops_gemm.GetRegSizePerXdlops();
    using VgprVectorType = typename vector_type<GemmAccDataType, ScalarPerVector>::type;
    return ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, VgprVectorType>(
        vgpr_layout);
}
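
// End-to-end sketch of a block-tile K loop built from the helpers in this
// header (hedged: LDS staging is elided; K, KPerBlock and the *_tile names
// are placeholders; block_sync_lds() is CK's LDS barrier):
//
//   auto c_vgpr = make_blockwise_gemm_xdl_c_vgpr<half_t, ALdsTileLayout,
//                                                BLdsTileLayout, 256,
//                                                MyGemmTraits>();
//   for(index_t k = 0; k < K; k += KPerBlock)
//   {
//       // ... copy the next A and B tiles into LDS ...
//       block_sync_lds();
//       blockwise_gemm_xdl<half_t, 256, MyGemmTraits>(a_lds_tile, b_lds_tile, c_vgpr);
//       block_sync_lds();
//   }
//   auto c_partition = make_blockwise_gemm_xdl_c_local_partition<
//       half_t, ALdsTileLayout, BLdsTileLayout, 256, MyGemmTraits>(c_global_tile);
//   copy(c_vgpr, c_partition);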

} // namespace wrapper
} // namespace ck