Composable Kernel — Doxygen source listing for
include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp
(develop branch; NOTE: this is an extracted documentation page, not the
original header — hyperlinked source lines were dropped during extraction,
so the listing below has gaps in the doxygen line numbering.)

blockwise_gemm_dpp.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
10 
11 namespace ck {
12 
22 template <index_t BlockSize,
23  typename ABDataType,
24  typename AccDataType,
25  typename AK0MK1BlockDesc,
26  typename BK0NK1BlockDesc,
27  index_t MPerDpp,
28  index_t NPerDpp,
29  index_t MRepeat,
30  index_t NRepeat,
31  index_t KPack>
33 {
34  static constexpr auto I0 = Number<0>{};
35  static constexpr auto I1 = Number<1>{};
36  static constexpr auto I2 = Number<2>{};
37  static constexpr auto I3 = Number<3>{};
38 
40 
41  static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
42  static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
43  static constexpr index_t KPerBlock =
44  BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);
45 
46  static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerDpp);
47  static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerDpp);
48  static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
49 
50  static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
51  static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);
52  static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
53  static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
54 
56 
57  static constexpr index_t KPerThread = KPerBlock / dpp_gemm.K0PerDpp;
58 
60  AccDataType,
61  MRepeat * NRepeat,
62  dpp_gemm.GetRegSizePerDpp(),
63  true>
65 
// Returns a mutable reference to the thread-local C accumulator buffer
// (c_thread_buf_, a VGPR-resident StaticBufferTupleOfVector per the member
// index below), so callers can read back / write out the GEMM results.
 66  __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
67 
68  __device__ static auto GetWaveIdx()
69  {
70  const index_t thread_id = ThisThreadBlock::GetThreadId();
71 
72  constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
76 
77  return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
78  }
79 
81  {
82  const auto wave_idx = GetWaveIdx();
83  const auto waveId_m = wave_idx[I0];
84  const auto dpp_a_idx = dpp_gemm.CalculateAThreadOriginDataIndex_K_M();
85  const auto dpp_a_idx_k = dpp_a_idx[I0];
86  const auto dpp_a_idx_m = dpp_a_idx[I1];
87  return make_tuple(0, waveId_m, dpp_a_idx_m, KPerThread * dpp_a_idx_k);
88  }
89 
91  {
92  const auto wave_idx = GetWaveIdx();
93  const auto waveId_n = wave_idx[I1];
94  const auto dpp_b_idx = dpp_gemm.CalculateBThreadOriginDataIndex_K_N();
95  const auto dpp_b_idx_k = dpp_b_idx[I0];
96  const auto dpp_b_idx_n = dpp_b_idx[I1];
97  return make_tuple(0, waveId_n, dpp_b_idx_n, KPerThread * dpp_b_idx_k);
98  }
99 
100  template <index_t m0, index_t n0>
102  {
103  const auto wave_idx = GetWaveIdx();
104  const auto waveId_m = wave_idx[I0];
105  const auto waveId_n = wave_idx[I1];
106 
107  const auto blk_idx = dpp_gemm.GetBeginOfThreadBlk();
108  const auto blk_m_offset = blk_idx[I0];
109  const auto blk_n_offset = blk_idx[I1];
110 
111  constexpr auto mrepeat_mwave_MPerDpp_to_m_adaptor = make_single_stage_tensor_adaptor(
112  make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerDpp))),
115 
116  constexpr auto nrepeat_nwave_NPerDpp_to_n_adaptor = make_single_stage_tensor_adaptor(
117  make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerDpp))),
120 
121  const index_t c_thread_m = mrepeat_mwave_MPerDpp_to_m_adaptor.CalculateBottomIndex(
122  make_tuple(m0, waveId_m, blk_m_offset))[I0];
123  const index_t c_thread_n = nrepeat_nwave_NPerDpp_to_n_adaptor.CalculateBottomIndex(
124  make_tuple(n0, waveId_n, blk_n_offset))[I0];
125 
126  return make_tuple(c_thread_m, c_thread_n);
127  }
128 
130  {
131  static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
132  BK0NK1BlockDesc::IsKnownAtCompileTime(),
133  "Wrong! Block descriptors should be known at the time of compilation.");
134 
135 #if defined(__HIP_DEVICE_COMPILE__)
136  // Host wave size can be different than the device one and this assert could fail for host,
137  // but it does matter only for device.
138  static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
139  "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
140 #endif
141 
142  static_assert(MPerBlock % (MPerDpp * MRepeat) == 0,
143  "Invalid parameters. MPerBlock must be divisible by MPerDpp * MRepeat.");
144  static_assert(NPerBlock % (NPerDpp * NRepeat) == 0,
145  "Invalid parameters. NPerBlock must be divisible by NPerDpp * NRepeat.");
146  }
147 
148  __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2()
149  {
150  constexpr auto c_m_n_tblk_lens = dpp_gemm.GetCMNThreadBlkLengths();
151  constexpr auto M = c_m_n_tblk_lens[I0];
152  constexpr auto N = c_m_n_tblk_lens[I1];
153 
156  }
157 
158  __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_N2()
159  {
160  constexpr auto c_m_n_tblk_lens = dpp_gemm.GetCMNThreadBlkLengths();
161  constexpr auto M = c_m_n_tblk_lens[I0];
162  constexpr auto N = c_m_n_tblk_lens[I1];
163 
166  }
167 
168  __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2()
169  {
170  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
172  Number<NRepeat>{},
173  Number<MWaves>{},
174  Number<NWaves>{},
175  Number<MPerDpp>{},
176  Number<NPerDpp>{}));
177 
178  return c_block_desc_m0_n0_m1_n1_m2_n2;
179  }
180 
181  __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_N2()
182  {
183  constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
185  Number<MRepeat>{},
186  Number<NRepeat>{},
187  Number<MWaves>{},
188  Number<NWaves>{},
189  Number<MPerDpp>{},
190  Number<NPerDpp>{}));
191  return c_block_desc_g_m0_n0_m1_n1_m2_n2;
192  }
193 
194  template <typename CGridDesc_M_N>
195  __host__ __device__ static constexpr auto
196  MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(const CGridDesc_M_N& c_grid_desc_m_n)
197  {
198  const auto M = c_grid_desc_m_n.GetLength(I0);
199  const auto N = c_grid_desc_m_n.GetLength(I1);
200 
201  const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
202  c_grid_desc_m_n,
203  make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerDpp), MWaves, MPerDpp)),
204  make_unmerge_transform(make_tuple(N / (NWaves * NPerDpp), NWaves, NPerDpp))),
207 
208  return c_grid_desc_m0_n0_m1_n1_m2_n2;
209  }
210 
211  template <typename CGridDesc_G_M_N>
212  __host__ __device__ static constexpr auto
213  MakeCGridDescriptor_G_M0_N0_M1_N1_M2_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
214  {
215  const auto G = c_grid_desc_g_m_n.GetLength(I0);
216  const auto M = c_grid_desc_g_m_n.GetLength(I1);
217  const auto N = c_grid_desc_g_m_n.GetLength(I2);
218 
219  const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
220  c_grid_desc_g_m_n,
222  make_unmerge_transform(make_tuple(M / (MWaves * MPerDpp), MWaves, MPerDpp)),
223  make_unmerge_transform(make_tuple(N / (NWaves * NPerDpp), NWaves, NPerDpp))),
226 
227  return c_grid_desc_g_m0_n0_m1_n1_m2_n2;
228  }
229 
230  __host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K()
231  {
233  AK0MK1BlockDesc{},
234  make_tuple(
240  }
241 
242  __host__ __device__ static constexpr auto MakeBBlockDescriptor_N0_N1_N2_K()
243  {
245  BK0NK1BlockDesc{},
246  make_tuple(
252  }
253 
256 
257  template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
258  __device__ void Run(const ABlockBuffer& a_block_buf,
259  const BBlockBuffer& b_block_buf,
260  CThreadBuffer& c_thread_buf) const
261  {
262  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ABDataType>(
263  a_thread_desc_.GetElementSpaceSize());
264  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ABDataType>(
265  b_thread_desc_.GetElementSpaceSize());
266 
267  static_for<0, MRepeat, 1>{}([&](auto m0) {
268  // read A
270  make_tuple(m0, I0, I0, I0),
271  a_block_buf,
273  make_tuple(I0, I0, I0, I0),
274  a_thread_buf);
275 
276  static_for<0, NRepeat, 1>{}([&](auto n0) {
277  // read B
279  make_tuple(n0, I0, I0, I0),
280  b_block_buf,
282  make_tuple(I0, I0, I0, I0),
283  b_thread_buf);
284 
285  static_for<0, KPerThread, KPack>{}([&](auto k) {
286  vector_type<ABDataType, KPack> a_thread_vec;
287  vector_type<ABDataType, KPack> b_thread_vec;
288 
289  static_for<0, KPack, 1>{}([&](auto i) {
290  a_thread_vec.template AsType<ABDataType>()(i) = a_thread_buf
291  [Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
292  b_thread_vec.template AsType<ABDataType>()(i) = b_thread_buf
293  [Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
294  });
295 
296  using dpp_input_type =
297  typename vector_type<ABDataType, dpp_gemm.K1PerDpp>::type;
298 
299  constexpr index_t c_offset =
300  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
301 
302  dpp_gemm.Run(a_thread_vec.template AsType<dpp_input_type>(),
303  b_thread_vec.template AsType<dpp_input_type>(),
304  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
305  });
306  });
307  });
308  }
309 
310  protected:
311  // A[M0, M1, M2, KPerThread]
312  static constexpr auto a_thread_desc_ =
314 
315  // B[N0, N1, N2, KPerThread]
316  static constexpr auto b_thread_desc_ =
318 
319  // C[M, N, NumRegDpp]
321  make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, dpp_gemm.GetRegSizePerDpp()));
322 
324  ABDataType,
325  decltype(a_block_desc_m0_m1_m2_k),
326  decltype(a_thread_desc_),
329  3,
330  A_K1,
331  A_K1>;
332 
334  ABDataType,
335  decltype(b_block_desc_n0_n1_n2_k),
336  decltype(b_thread_desc_),
339  3,
340  B_K1,
341  B_K1>;
342 
345 };
346 
347 } // namespace ck
Definition: ck.hpp:267
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:84
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:298
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
Definition: blockwise_gemm_dpp.hpp:33
static constexpr index_t KPerBlock
Definition: blockwise_gemm_dpp.hpp:43
static constexpr index_t NWaves
Definition: blockwise_gemm_dpp.hpp:47
static constexpr index_t B_K1
Definition: blockwise_gemm_dpp.hpp:53
BThreadCopy b_thread_copy_
Definition: blockwise_gemm_dpp.hpp:344
static constexpr index_t A_K0
Definition: blockwise_gemm_dpp.hpp:50
static __device__ auto CalculateAThreadOriginDataIndex_M0_M1_M2_K()
Definition: blockwise_gemm_dpp.hpp:80
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: blockwise_gemm_dpp.hpp:39
static constexpr auto c_thread_desc_
Definition: blockwise_gemm_dpp.hpp:320
static constexpr auto I2
Definition: blockwise_gemm_dpp.hpp:36
__host__ static constexpr __device__ auto MakeBBlockDescriptor_N0_N1_N2_K()
Definition: blockwise_gemm_dpp.hpp:242
static constexpr auto I3
Definition: blockwise_gemm_dpp.hpp:37
static __device__ auto GetWaveIdx()
Definition: blockwise_gemm_dpp.hpp:68
__host__ static constexpr __device__ auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2()
Definition: blockwise_gemm_dpp.hpp:148
static constexpr auto I1
Definition: blockwise_gemm_dpp.hpp:35
static constexpr index_t MPerBlock
Definition: blockwise_gemm_dpp.hpp:41
__host__ static constexpr __device__ auto MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(const CGridDesc_M_N &c_grid_desc_m_n)
Definition: blockwise_gemm_dpp.hpp:196
__host__ static constexpr __device__ auto MakeABlockDescriptor_M0_M1_M2_K()
Definition: blockwise_gemm_dpp.hpp:230
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >)
Definition: blockwise_gemm_dpp.hpp:101
__host__ static constexpr __device__ auto MakeCGridDescriptor_G_M0_N0_M1_N1_M2_N2(const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition: blockwise_gemm_dpp.hpp:213
static constexpr index_t WaveSize
Definition: blockwise_gemm_dpp.hpp:48
static constexpr auto I0
Definition: blockwise_gemm_dpp.hpp:34
static constexpr auto b_thread_desc_
Definition: blockwise_gemm_dpp.hpp:316
__host__ static constexpr __device__ auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_N2()
Definition: blockwise_gemm_dpp.hpp:158
static constexpr auto b_block_desc_n0_n1_n2_k
Definition: blockwise_gemm_dpp.hpp:255
static constexpr index_t NPerBlock
Definition: blockwise_gemm_dpp.hpp:42
__host__ static constexpr __device__ auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_N2()
Definition: blockwise_gemm_dpp.hpp:181
StaticBufferTupleOfVector< AddressSpaceEnum::Vgpr, AccDataType, MRepeat *NRepeat, dpp_gemm.GetRegSizePerDpp(), true > c_thread_buf_
Definition: blockwise_gemm_dpp.hpp:64
static constexpr index_t MWaves
Definition: blockwise_gemm_dpp.hpp:46
static constexpr index_t KPerThread
Definition: blockwise_gemm_dpp.hpp:57
static __device__ auto CalculateBThreadOriginDataIndex_N0_N1_N2_K()
Definition: blockwise_gemm_dpp.hpp:90
__host__ constexpr __device__ auto & GetCThreadBuffer()
Definition: blockwise_gemm_dpp.hpp:66
__device__ void Run(const ABlockBuffer &a_block_buf, const BBlockBuffer &b_block_buf, CThreadBuffer &c_thread_buf) const
Definition: blockwise_gemm_dpp.hpp:258
static constexpr auto a_block_desc_m0_m1_m2_k
Definition: blockwise_gemm_dpp.hpp:254
static constexpr auto dpp_gemm
Definition: blockwise_gemm_dpp.hpp:55
static constexpr auto a_thread_desc_
Definition: blockwise_gemm_dpp.hpp:312
__host__ static constexpr __device__ auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2()
Definition: blockwise_gemm_dpp.hpp:168
static constexpr index_t B_K0
Definition: blockwise_gemm_dpp.hpp:51
__host__ __device__ BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2()
Definition: blockwise_gemm_dpp.hpp:129
static constexpr index_t A_K1
Definition: blockwise_gemm_dpp.hpp:52
AThreadCopy a_thread_copy_
Definition: blockwise_gemm_dpp.hpp:343
Definition: dpp_gemm.hpp:426
Definition: sequence.hpp:43
Definition: static_buffer.hpp:75
static __device__ index_t GetThreadId()
Definition: thread_group.hpp:19
static constexpr __device__ index_t GetNumOfThread()
Definition: thread_group.hpp:15
__device__ void Run(const SrcDesc &, const SrcRefToOriginDisplacement &, const SrcBuffer &src_buf, const DstDesc &, const DstOriginIdx &, DstBuffer &dst_buf) const
Definition: threadwise_tensor_slice_transfer.hpp:1293
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33
Definition: dtype_vector.hpp:10