blockwise_gemm_dlops_v2r2.hpp

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP
#define CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP

#include "common_header.hpp"
#include "tensor_adaptor.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_contraction_dlops.hpp"

namespace ck {
// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1]
// A and B are visible to the whole block, C is distributed among each thread
// Assume:
//   1. A:
//     1. AKMBlockDesc is known at compile-time
//     2. ABlockBuffer is DynamicBuffer
//   2. B:
//     1. BKNBlockDesc is known at compile-time
//     2. BBlockBuffer is DynamicBuffer
//   3. C:
//     1. CM0M1N0N1ThreadDesc is known at compile-time
//     2. CThreadBuffer is StaticBuffer
// Also assume:
//   M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
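// In other words, each thread owns a C sub-tile indexed by (m0, m11, n0, n11) and
// Run() below accumulates the K-contraction
//   C[m0, m11, n0, n11] += sum_k A[k, m0, m11] * B[k, n0, n11]
// over A and B sub-tiles staged in thread-private registers.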
template <index_t BlockSize,
          typename FloatA,
          typename FloatB,
          typename FloatC,
          typename AKMBlockDesc,
          typename BKNBlockDesc,
          index_t M1PerThreadM11,
          index_t N1PerThreadN11,
          index_t KPerThread,
          index_t M1N1ThreadClusterM100,
          index_t M1N1ThreadClusterN100,
          index_t M1N1ThreadClusterM101,
          index_t M1N1ThreadClusterN101,
          index_t AThreadCopyScalarPerVector_M11,
          index_t BThreadCopyScalarPerVector_N11,
          typename enable_if<AKMBlockDesc::IsKnownAtCompileTime() && BKNBlockDesc::IsKnownAtCompileTime(),
                             bool>::type = false>
struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
{
    using CIndex = MultiIndex<4>;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    static constexpr index_t K = AKMBlockDesc{}.GetLength(I0);
    static constexpr index_t M = AKMBlockDesc{}.GetLength(I1);
    static constexpr index_t N = BKNBlockDesc{}.GetLength(I1);

    static constexpr index_t M100 = M1N1ThreadClusterM100;
    static constexpr index_t N100 = M1N1ThreadClusterN100;

    static constexpr index_t M101 = M1N1ThreadClusterM101;
    static constexpr index_t N101 = M1N1ThreadClusterN101;

    static constexpr index_t M11 = M1PerThreadM11;
    static constexpr index_t N11 = N1PerThreadN11;

    static constexpr index_t M1 = M1N1ThreadClusterM100 * M1N1ThreadClusterM101 * M1PerThreadM11;
    static constexpr index_t N1 = M1N1ThreadClusterN100 * M1N1ThreadClusterN101 * N1PerThreadN11;

    static constexpr index_t M0 = M / M1;
    static constexpr index_t N0 = N / N1;

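    // For instance, with hypothetical parameters BlockSize = 256, M = N = 128,
    // M1N1ThreadClusterM100 = M1N1ThreadClusterN100 = 2,
    // M1N1ThreadClusterM101 = M1N1ThreadClusterN101 = 8, and
    // M1PerThreadM11 = N1PerThreadN11 = 4:
    //   M1 = 2 * 8 * 4 = 64, N1 = 64, M0 = 128 / 64 = 2, N0 = 2,
    // which satisfies both BlockSize == M100 * M101 * N100 * N101 == 256 and the
    // M0 == N0 == 2 restriction asserted in the constructor.
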
    __host__ __device__ static constexpr auto
    MakeAKM0M1BlockDescriptor(const AKMBlockDesc& /* a_k_m_block_desc */)
    {
        const auto a_k_m0_m1_block_desc = transform_tensor_descriptor(
            AKMBlockDesc{},
            make_tuple(make_pass_through_transform(Number<K>{}),
                       make_unmerge_transform(make_tuple(Number<M0>{}, Number<M1>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}));

        return a_k_m0_m1_block_desc;
    }

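    // The unmerge above splits the flat M dimension into (M0, M1), i.e. a block-level
    // row index m decomposes as m = m0 * M1 + m1; MakeBKN0N1BlockDescriptor below
    // performs the same (N0, N1) split on the N dimension of B.
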
    __host__ __device__ static constexpr auto
    MakeBKN0N1BlockDescriptor(const BKNBlockDesc& /* b_k_n_block_desc */)
    {
        const auto b_k_n0_n1_block_desc = transform_tensor_descriptor(
            BKNBlockDesc{},
            make_tuple(make_pass_through_transform(Number<K>{}),
                       make_unmerge_transform(make_tuple(Number<N0>{}, Number<N1>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}));

        return b_k_n0_n1_block_desc;
    }

    __host__ __device__ static constexpr auto MakeCM0M100M101M11N0N100N101N11ToMNBlockAdaptor()
    {
        // upper: [M0, M100, M101, M11, N0, N100, N101, N11]
        // lower: [M, N]
        constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_unmerge_transform(make_tuple(
                               Number<M0>{}, Number<M100>{}, Number<M101>{}, Number<M11>{})),
                           make_unmerge_transform(make_tuple(
                               Number<N0>{}, Number<N100>{}, Number<N101>{}, Number<N11>{}))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{}));

        return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor;
    }

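    // Under this adaptor a C coordinate maps to the flat block tile as
    //   m = ((m0 * M100 + m100) * M101 + m101) * M11 + m11
    //   n = ((n0 * N100 + n100) * N101 + n101) * N11 + n11,
    // matching the M0/M100/M101/M11 (resp. N0/N100/N101/N11) thread-cluster hierarchy.
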
    __host__ __device__ static constexpr auto
    MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor()
    {
        // upper: [M0, M100, M101, M11, N0, N100, N101, N11]
        // lower: [M0, M1, N0, N1]
        constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_pass_through_transform(Number<M0>{}),
                           make_unmerge_transform(
                               make_tuple(Number<M100>{}, Number<M101>{}, Number<M11>{})),
                           make_pass_through_transform(Number<N0>{}),
                           make_unmerge_transform(
                               make_tuple(Number<N100>{}, Number<N101>{}, Number<N11>{}))),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, Sequence<5, 6, 7>{}));

        return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor;
    }

    __host__ __device__ static constexpr auto GetCM0M1N0N1ThreadTensorLengths()
    {
        return Sequence<M0, M11, N0, N11>{};
    }

    static constexpr auto a_k_m0_m1_block_desc_ = MakeAKM0M1BlockDescriptor(AKMBlockDesc{});
    static constexpr auto b_k_n0_n1_block_desc_ = MakeBKN0N1BlockDescriptor(BKNBlockDesc{});

    public:
    __device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2()
        : c_thread_origin_data_idx_{CalculateCM0M1N0N1ThreadOriginOnBlock(
              get_thread_local_1d_id())},
          a_thread_copy_{
              make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])},
          b_thread_copy_{
              make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3])}
    {
        static_assert(AKMBlockDesc::IsKnownAtCompileTime() && BKNBlockDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        static_assert(BlockSize == M101 * M100 * N101 * N100,
                      "wrong! blocksize and cluster size not consistent");

        static_assert(M % M1 == 0 && N % N1 == 0, "wrong!");

        static_assert(AKMBlockDesc{}.GetLength(I0) == BKNBlockDesc{}.GetLength(I0),
                      "wrong! K dimension not consistent");

        // TODO: remove this restriction
        static_assert(M0 == 2 && N0 == 2, "wrong");
    }

    __device__ static CIndex CalculateCM0M1N0N1ThreadOriginOnBlock(index_t thread_id)
    {
        // lower: [M0, M1, N0, N1]
        // upper: [M0, M100, M101, M11, N0, N100, N101, N11]
        constexpr auto adaptor0 = MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor();

        // lower: [M0, M100, M101, M11, N0, N100, N101, N11]
        // upper: [Tid, M0, M11, N0, N11]
        constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(
                           Number<M100>{}, Number<N100>{}, Number<M101>{}, Number<N101>{})),
                       make_pass_through_transform(Number<M0>{}),
                       make_pass_through_transform(Number<M11>{}),
                       make_pass_through_transform(Number<N0>{}),
                       make_pass_through_transform(Number<N11>{})),
            make_tuple(
                Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

        constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1);

        return adaptor.CalculateBottomIndex(make_multi_index(thread_id, 0, 0, 0, 0));
    }

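    // Example (same hypothetical parameters as above: M100 = N100 = 2, M101 = N101 = 8,
    // M11 = N11 = 4): the merge decomposes thread_id into (m100, n100, m101, n101), and
    // the returned block origin is
    //   (0, (m100 * M101 + m101) * M11, 0, (n100 * N101 + n101) * N11),
    // so e.g. thread_id 0 starts its C sub-tiles at (0, 0, 0, 0).
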
    __host__ __device__ static constexpr index_t GetABlockAlignment() { return M1PerThreadM11; }

    __host__ __device__ static constexpr auto GetBBlockAlignment() { return N1PerThreadN11; }

    template <typename CM0M1N0N1ThreadDesc,
              typename ABlockBuffer,
              typename BBlockBuffer,
              typename CThreadBuffer>
    __device__ void Run(const CM0M1N0N1ThreadDesc& /* c_m0_m1_n0_n1_thread_desc */,
                        const ABlockBuffer& a_block_buf,
                        const BBlockBuffer& b_block_buf,
                        CThreadBuffer& c_thread_buf) const
    {
        static_assert(CM0M1N0N1ThreadDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        // TODO: remove this restriction
        static_assert(M0 == 2 && N0 == 2 && CM0M1N0N1ThreadDesc{}.GetLength(I0) == M0 &&
                          CM0M1N0N1ThreadDesc{}.GetLength(I2) == N0,
                      "wrong");

        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
            a_k_m0_m1_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
            b_k_n0_n1_thread_desc_.GetElementSpaceSize());

        constexpr auto threadwise_gemm =
            ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1<FloatA,
                                                     FloatB,
                                                     FloatC,
                                                     decltype(a_k_m0_m1_thread_desc_),
                                                     decltype(b_k_n0_n1_thread_desc_),
                                                     CM0M1N0N1ThreadDesc,
                                                     Sequence<KPerThread>,
                                                     Sequence<1, M1PerThreadM11>,
                                                     Sequence<1, N1PerThreadN11>>{};

        // read A_sub_0
        a_thread_copy_.Run(a_k_m0_m1_block_desc_,
                           make_tuple(I0, I0, I0),
                           a_block_buf,
                           a_k_m0_m1_thread_desc_,
                           make_tuple(I0, I0, I0),
                           a_thread_buf);

        // read B_sub_0
        b_thread_copy_.Run(b_k_n0_n1_block_desc_,
                           make_tuple(I0, I0, I0),
                           b_block_buf,
                           b_k_n0_n1_thread_desc_,
                           make_tuple(I0, I0, I0),
                           b_thread_buf);

        // read B_sub_1
        b_thread_copy_.Run(b_k_n0_n1_block_desc_,
                           make_tuple(I0, I1, I0),
                           b_block_buf,
                           b_k_n0_n1_thread_desc_,
                           make_tuple(I0, I1, I0),
                           b_thread_buf);

        // read A_sub_1
        a_thread_copy_.Run(a_k_m0_m1_block_desc_,
                           make_tuple(I0, I1, I0),
                           a_block_buf,
                           a_k_m0_m1_thread_desc_,
                           make_tuple(I0, I1, I0),
                           a_thread_buf);

        // C_sub_00 += transpose(A_sub_0) * B_sub_0
        threadwise_gemm.Run(a_thread_buf,
                            make_tuple(I0, I0, I0),
                            b_thread_buf,
                            make_tuple(I0, I0, I0),
                            c_thread_buf,
                            make_tuple(I0, I0, I0, I0));

        // C_sub_01 += transpose(A_sub_0) * B_sub_1
        threadwise_gemm.Run(a_thread_buf,
                            make_tuple(I0, I0, I0),
                            b_thread_buf,
                            make_tuple(I0, I1, I0),
                            c_thread_buf,
                            make_tuple(I0, I0, I1, I0));

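        // The 2x2 "ABBA" pipeline: the four sub-tile reads above (A_sub_0, B_sub_0,
        // B_sub_1, A_sub_1) are interleaved with the four C sub-tile FMAs so that, in
        // the k-loop below, each read for the next k-slice is issued before the math
        // that still depends on the current one, letting loads overlap compute.
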
        // loop over rest of k
        static_for<KPerThread, K, KPerThread>{}([&](auto k) {
            // read A_sub_0
            a_thread_copy_.Run(a_k_m0_m1_block_desc_,
                               make_tuple(k, I0, I0),
                               a_block_buf,
                               a_k_m0_m1_thread_desc_,
                               make_tuple(I0, I0, I0),
                               a_thread_buf);

            // C_sub_10 += transpose(A_sub_1) * B_sub_0
            threadwise_gemm.Run(a_thread_buf,
                                make_tuple(I0, I1, I0),
                                b_thread_buf,
                                make_tuple(I0, I0, I0),
                                c_thread_buf,
                                make_tuple(I1, I0, I0, I0));

            // read B_sub_0
            b_thread_copy_.Run(b_k_n0_n1_block_desc_,
                               make_tuple(k, I0, I0),
                               b_block_buf,
                               b_k_n0_n1_thread_desc_,
                               make_tuple(I0, I0, I0),
                               b_thread_buf);

            // C_sub_11 += transpose(A_sub_1) * B_sub_1
            threadwise_gemm.Run(a_thread_buf,
                                make_tuple(I0, I1, I0),
                                b_thread_buf,
                                make_tuple(I0, I1, I0),
                                c_thread_buf,
                                make_tuple(I1, I0, I1, I0));

            // read B_sub_1
            b_thread_copy_.Run(b_k_n0_n1_block_desc_,
                               make_tuple(k, I1, I0),
                               b_block_buf,
                               b_k_n0_n1_thread_desc_,
                               make_tuple(I0, I1, I0),
                               b_thread_buf);

            // read A_sub_1
            a_thread_copy_.Run(a_k_m0_m1_block_desc_,
                               make_tuple(k, I1, I0),
                               a_block_buf,
                               a_k_m0_m1_thread_desc_,
                               make_tuple(I0, I1, I0),
                               a_thread_buf);

            // C_sub_00 += transpose(A_sub_0) * B_sub_0
            threadwise_gemm.Run(a_thread_buf,
                                make_tuple(I0, I0, I0),
                                b_thread_buf,
                                make_tuple(I0, I0, I0),
                                c_thread_buf,
                                make_tuple(I0, I0, I0, I0));

            // C_sub_01 += transpose(A_sub_0) * B_sub_1
            threadwise_gemm.Run(a_thread_buf,
                                make_tuple(I0, I0, I0),
                                b_thread_buf,
                                make_tuple(I0, I1, I0),
                                c_thread_buf,
                                make_tuple(I0, I0, I1, I0));
        });

        // C_sub_10 += transpose(A_sub_1) * B_sub_0
        threadwise_gemm.Run(a_thread_buf,
                            make_tuple(I0, I1, I0),
                            b_thread_buf,
                            make_tuple(I0, I0, I0),
                            c_thread_buf,
                            make_tuple(I1, I0, I0, I0));

        // C_sub_11 += transpose(A_sub_1) * B_sub_1
        threadwise_gemm.Run(a_thread_buf,
                            make_tuple(I0, I1, I0),
                            b_thread_buf,
                            make_tuple(I0, I1, I0),
                            c_thread_buf,
                            make_tuple(I1, I0, I1, I0));
    }

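    // Usage sketch (hypothetical block descriptors and tuning parameters, for
    // illustration only; real kernels wire this up through their grid-wise GEMM):
    //
    //   using ABlockDesc = decltype(make_naive_tensor_descriptor_packed(
    //       make_tuple(Number<8>{}, Number<128>{})));   // A[K = 8, M = 128]
    //   using BBlockDesc = decltype(make_naive_tensor_descriptor_packed(
    //       make_tuple(Number<8>{}, Number<128>{})));   // B[K = 8, N = 128]
    //
    //   using BlockwiseGemm = BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2<
    //       256,                    // BlockSize
    //       float, float, float,    // FloatA, FloatB, FloatC
    //       ABlockDesc, BBlockDesc,
    //       4, 4,                   // M1PerThreadM11, N1PerThreadN11
    //       8,                      // KPerThread
    //       2, 2, 8, 8,             // M1N1ThreadCluster{M100, N100, M101, N101}
    //       4, 4>;                  // {A,B}ThreadCopyScalarPerVector
    //
    //   const auto blockwise_gemm = BlockwiseGemm{};
    //   blockwise_gemm.Run(c_m0_m1_n0_n1_thread_desc, a_block_buf, b_block_buf, c_thread_buf);
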
    private:
    // A[K, M0, M1]
    static constexpr auto a_k_m0_m1_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<KPerThread>{}, Number<M0>{}, Number<M1PerThreadM11>{}));

    // B[K, N0, N1]
    static constexpr auto b_k_n0_n1_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<KPerThread>{}, Number<N0>{}, Number<N1PerThreadN11>{}));

    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
                                                         FloatA,
                                                         decltype(a_k_m0_m1_block_desc_),
                                                         decltype(a_k_m0_m1_thread_desc_),
                                                         Sequence<KPerThread, 1, M1PerThreadM11>,
                                                         Sequence<0, 1, 2>,
                                                         2,
                                                         AThreadCopyScalarPerVector_M11,
                                                         1>;

    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
                                                         FloatB,
                                                         decltype(b_k_n0_n1_block_desc_),
                                                         decltype(b_k_n0_n1_thread_desc_),
                                                         Sequence<KPerThread, 1, N1PerThreadN11>,
                                                         Sequence<0, 1, 2>,
                                                         2,
                                                         BThreadCopyScalarPerVector_N11,
                                                         1>;

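    // Both copies are vectorized along dimension 2 (the M11 resp. N11 axis,
    // SrcVectorDim = 2), moving {A,B}ThreadCopyScalarPerVector elements per access
    // while each Run() call transfers a [KPerThread, 1, M11/N11] slice.
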
    CIndex c_thread_origin_data_idx_;

    AThreadCopy a_thread_copy_;
    BThreadCopy b_thread_copy_;
};

} // namespace ck
#endif