/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp Source File
blockwise_gemm_dlops_v3.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #ifndef CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
5 #define CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
6 
7 #include "common_header.hpp"
9 
10 namespace ck {
11 
12 template <index_t BlockSize,
13  typename FloatA,
14  typename FloatB,
15  typename FloatC,
16  typename ABlockDesc_E1_K1_E2,
17  typename BBlockDesc_E1_N_Ho_Wo_E2,
18  typename CThreadDesc_K_N_Ho_Wo,
19  index_t EPerThreadLoop,
20  index_t KPerThreadLoop>
22 {
// Compile-time index constants Number<0>..Number<4>, used in this struct as dimension indices.
23  static constexpr auto I0 = Number<0>{};
24  static constexpr auto I1 = Number<1>{};
25  static constexpr auto I2 = Number<2>{};
26  static constexpr auto I3 = Number<3>{};
27  static constexpr auto I4 = Number<4>{};
28 
32 
// Lengths of dimensions 0, 1 and 2 of the A block descriptor type.
33  static constexpr auto E1 = ABlockDesc_E1_K1_E2{}.GetLength(I0);
34  static constexpr auto KPerBlock = ABlockDesc_E1_K1_E2{}.GetLength(I1);
35  static constexpr auto E2 = ABlockDesc_E1_K1_E2{}.GetLength(I2);
36 
// Lengths of dimensions 2 and 3 of the B block descriptor type.
37  static constexpr auto HoPerBlock = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I2);
38  static constexpr auto WoPerBlock = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I3);
39 
// Lengths of dimensions 0, 2 and 3 of the C thread descriptor type (dimension 1 is not read here).
40  static constexpr auto KPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I0);
41  static constexpr auto HoPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I2);
42  static constexpr auto WoPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I3);
43 
46 
47  static constexpr auto b_thread_mtx_ =
49  Number<1>{},
50  Number<HoPerThread>{},
51  Number<WoPerThread>{},
52  Number<E2>{}));
53 
55  Number<KPerThreadLoop>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
56 
// Constructor initializer list and body.
// NOTE(review): the constructor's signature line is missing from this extract.
// c_thread_origin_data_idx_ is obtained from GetBeginOfCThreadDesc_K_N_Ho_Wo for the current
// thread id; a_thread_copy_ is constructed from the tuple
// (0, c_thread_origin_data_idx_[I0] * KPerThread, 0).
58  : c_thread_origin_data_idx_{GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id())},
59  a_thread_copy_{make_tuple(0, c_thread_origin_data_idx_[I0] * KPerThread, 0)}
60  {
    // All three descriptor types must be fully known at compile time.
61  static_assert(ABlockDesc_E1_K1_E2::IsKnownAtCompileTime() &&
62  BBlockDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() &&
63  CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(),
64  "wrong! Desc should be known at compile-time");
65 
    // A and B must agree on dimension 0 of both, and on A's dimension 2 vs B's dimension 4.
66  static_assert(
67  ABlockDesc_E1_K1_E2{}.GetLength(I0) == BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I0) &&
68  ABlockDesc_E1_K1_E2{}.GetLength(I2) == BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I4),
69  "wrong! E dimension not consistent\n");
70 
    // The per-iteration step sizes must evenly divide the ranges they step over.
71  static_assert(E1 % EPerThreadLoop == 0, "");
72  static_assert(KPerThread % KPerThreadLoop == 0, "");
73 
    // Per-block extents must be whole multiples of the per-thread extents.
74  static_assert(KPerBlock % KPerThread == 0 && HoPerBlock % HoPerThread == 0 &&
75  WoPerBlock % WoPerThread == 0,
76  "wrong! Cannot evenly divide work among\n");
77 
    // Cluster lengths along K, Ho and Wo: per-block extent divided by per-thread extent.
78  constexpr auto KThreadCluster = KPerBlock / KPerThread;
79  constexpr auto HThreadCluster = HoPerBlock / HoPerThread;
80  constexpr auto WThreadCluster = WoPerBlock / WoPerThread;
81 
    // BlockSize must equal the product of the three cluster lengths.
82  static_assert(BlockSize == KThreadCluster * HThreadCluster * WThreadCluster,
83  "wrong! wrong blocksize\n");
84  }
85 
// Static constexpr device function taking no arguments.
// NOTE(review): the body's return statement is missing from this extract, so the returned
// value cannot be stated from what is shown here; consult the original header.
86  __device__ static constexpr auto GetCThreadDesc_K_N_Ho_WoLengths()
87  {
89  }
90 
// Converts a thread id into a CIndex by calling CalculateBottomIndex on an adaptor whose
// transform tuple holds a single merge transform over (K0, N0, H0, W0).
// Presumably the result is the thread's (k, n, h, w) position in the cluster - TODO confirm.
91  __device__ static CIndex GetBeginOfCThreadDesc_K_N_Ho_Wo(index_t thread_id)
92  {
    // Cluster lengths: per-block extent divided by per-thread extent; N0 is fixed to I1.
93  constexpr auto K0 = KPerBlock / KPerThread;
94  constexpr auto N0 = I1;
95  constexpr auto H0 = HoPerBlock / HoPerThread;
96  constexpr auto W0 = WoPerBlock / WoPerThread;
97 
    // NOTE(review): parts of the adaptor construction (the call that consumes the transform
    // tuple below, and its remaining arguments) are missing from this extract.
98  constexpr auto c_threadid_to_k_n_h_w_thread_cluster_adaptor =
100  make_tuple(make_merge_transform(make_tuple(K0, N0, H0, W0))),
103 
    // Evaluate the adaptor on the one-element multi-index holding thread_id.
104  const auto c_k_n_h_w_thread_cluster_idx =
105  c_threadid_to_k_n_h_w_thread_cluster_adaptor.CalculateBottomIndex(
106  make_multi_index(thread_id));
107 
108  return c_k_n_h_w_thread_cluster_idx;
109  }
110 
// Runs the blockwise computation for this thread: for every (e_begin, k_begin) step it first
// calls a_thread_copy_.Run to fill a local A buffer from a_block_buf, then calls
// threadwise_gemm.Run on that A buffer, b_thread_buf and c_thread_buf.
//   a_block_buf  - source buffer read by a_thread_copy_.Run
//   b_thread_buf - B operand passed to threadwise_gemm.Run, offset by e_begin in dimension 0
//   c_thread_buf - C operand passed to threadwise_gemm.Run (non-const), offset by k_begin
//                  in dimension 0
111  template <typename ABlockBuffer, typename BThreadBuffer, typename CThreadBuffer>
112  __device__ void Run(const ABlockBuffer& a_block_buf,
113  const BThreadBuffer& b_thread_buf,
114  CThreadBuffer& c_thread_buf) const
115  {
    // NOTE(review): the asserted condition is missing from this extract; only the message
    // string is visible.
116  static_assert(
120  "wrong! inconsistent type");
121 
122  constexpr auto a_block_mtx = ABlockDesc_E1_K1_E2{};
123 
124  // thread A buffer for GEMM
    // Declared with AddressSpaceEnum::Vgpr and sized by a_thread_mtx_.GetElementSpaceSize().
125  StaticBuffer<AddressSpaceEnum::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true>
126  a_thread_buf;
127 
    // Threadwise GEMM functor parameterised on the three element types and the
    // a/b/c thread descriptor types.
128  constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatA,
129  FloatB,
130  FloatC,
131  decltype(a_thread_mtx_),
132  decltype(b_thread_mtx_),
133  decltype(c_thread_mtx_)>{};
134 
    // Compile-time loops: e_begin steps over [0, E1) by EPerThreadLoop, and k_begin steps
    // over [0, KPerThread) by KPerThreadLoop.
135  static_for<0, E1, EPerThreadLoop>{}([&](auto e_begin) {
136  static_for<0, KPerThread, KPerThreadLoop>{}([&](auto k_begin) {
    // Copy from a_block_buf at displacement (e_begin, k_begin, 0) into a_thread_buf at
    // origin (0, 0, 0).
    // NOTE(review): one argument line between a_block_buf and the destination origin is
    // missing from this extract - presumably the destination descriptor; verify.
137  a_thread_copy_.Run(a_block_mtx,
138  make_tuple(e_begin, k_begin, I0),
139  a_block_buf,
141  make_tuple(I0, I0, I0),
142  a_thread_buf);
143 
    // A is read from origin (0, 0, 0); B is offset by e_begin in its first dimension;
    // C is offset by k_begin in its first dimension.
144  threadwise_gemm.Run(a_thread_buf,
145  make_tuple(I0, I0, I0),
146  b_thread_buf,
147  make_tuple(e_begin, I0, I0, I0, I0),
148  c_thread_buf,
149  make_tuple(k_begin, I0, I0, I0));
150  });
151  });
152  }
153 
// Forwards a_block_slice_move_step_idx to a_thread_copy_.MoveSrcSliceWindow, passing a
// default-constructed ABlockDesc_E1_K1_E2 as the descriptor argument. Non-const: it calls a
// non-const member on a_thread_copy_.
154  template <typename ABlockSliceMoveStepIdx>
155  __device__ void MoveABlockSliceWindow(const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx)
156  {
157  a_thread_copy_.MoveSrcSliceWindow(ABlockDesc_E1_K1_E2{}, a_block_slice_move_step_idx);
158  }
159 
160  private:
161  using AThreadCopy =
163  FloatA,
164  ABlockDesc_E1_K1_E2,
165  decltype(a_thread_mtx_),
168  2,
169  E2,
170  E2>;
171 
172  CIndex c_thread_origin_data_idx_;
173 
174  AThreadCopy a_thread_copy_;
175 };
176 
177 } // namespace ck
178 #endif
Definition: ck.hpp:267
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
int32_t index_t
Definition: ck.hpp:298
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:52
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1350
Definition: array.hpp:14
Definition: blockwise_gemm_dlops_v3.hpp:22
static constexpr auto E1
Definition: blockwise_gemm_dlops_v3.hpp:33
static constexpr auto b_thread_mtx_
Definition: blockwise_gemm_dlops_v3.hpp:47
__device__ void MoveABlockSliceWindow(const ABlockSliceMoveStepIdx &a_block_slice_move_step_idx)
Definition: blockwise_gemm_dlops_v3.hpp:155
static constexpr auto I4
Definition: blockwise_gemm_dlops_v3.hpp:27
static constexpr auto E2
Definition: blockwise_gemm_dlops_v3.hpp:35
static constexpr auto WoPerBlock
Definition: blockwise_gemm_dlops_v3.hpp:38
static constexpr auto KPerBlock
Definition: blockwise_gemm_dlops_v3.hpp:34
__device__ void Run(const ABlockBuffer &a_block_buf, const BThreadBuffer &b_thread_buf, CThreadBuffer &c_thread_buf) const
Definition: blockwise_gemm_dlops_v3.hpp:112
static constexpr __device__ auto GetCThreadDesc_K_N_Ho_WoLengths()
Definition: blockwise_gemm_dlops_v3.hpp:86
static constexpr auto I1
Definition: blockwise_gemm_dlops_v3.hpp:24
static constexpr auto HoPerBlock
Definition: blockwise_gemm_dlops_v3.hpp:37
static constexpr auto c_thread_mtx_
Definition: blockwise_gemm_dlops_v3.hpp:54
static constexpr auto a_thread_mtx_
Definition: blockwise_gemm_dlops_v3.hpp:44
static constexpr auto HoPerThread
Definition: blockwise_gemm_dlops_v3.hpp:41
static constexpr auto I0
Definition: blockwise_gemm_dlops_v3.hpp:23
static __device__ CIndex GetBeginOfCThreadDesc_K_N_Ho_Wo(index_t thread_id)
Definition: blockwise_gemm_dlops_v3.hpp:91
__device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3()
Definition: blockwise_gemm_dlops_v3.hpp:57
static constexpr auto KPerThread
Definition: blockwise_gemm_dlops_v3.hpp:40
static constexpr auto I3
Definition: blockwise_gemm_dlops_v3.hpp:26
MultiIndex< 4 > CIndex
Definition: blockwise_gemm_dlops_v3.hpp:31
static constexpr auto I2
Definition: blockwise_gemm_dlops_v3.hpp:25
static constexpr auto WoPerThread
Definition: blockwise_gemm_dlops_v3.hpp:42
Definition: sequence.hpp:43
Definition: static_buffer.hpp:16
Definition: threadwise_gemm_dlops_v3.hpp:29
Definition: threadwise_tensor_slice_transfer.hpp:1260
__device__ void Run(const SrcDesc &, const SrcRefToOriginDisplacement &, const SrcBuffer &src_buf, const DstDesc &, const DstOriginIdx &, DstBuffer &dst_buf) const
Definition: threadwise_tensor_slice_transfer.hpp:1293
__device__ void MoveSrcSliceWindow(const SrcDesc &, const SrcSliceMoveStepIdx &src_slice_move_step_idx)
Definition: threadwise_tensor_slice_transfer.hpp:1683
Definition: integral_constant.hpp:20
Definition: type.hpp:177
Definition: functional2.hpp:33