/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp Source File
threadwise_gemm_dlops_v3.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #ifndef CK_THREADWISE_GEMM_DLOPS_V3_HPP
5 #define CK_THREADWISE_GEMM_DLOPS_V3_HPP
6 
7 #include "common_header.hpp"
8 #include "math.hpp"
9 
10 namespace ck {
11 
12 // C[K, N, Ho, Wo] += sum over (E1, E2) of A[E1, K, E2] * B[E1, N, Ho, Wo, E2] (generic GEMM form: C[M, N] += transpose(A[K, M]) * B[K, N])
13 // Element of matrix can be vectorized data
14 // Assume:
15 // 1. AThreadDesc_E1_K_E2, BThreadDesc_E1_N_Ho_Wo_E2, CThreadDesc_K_N_Ho_Wo are known at
16 // compile-time
17 // 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
// Thread-level GEMM-style accumulation over thread descriptors that are known at
// compile time. For every (k, h, w, e1, e2) taken from the descriptor lengths, Run()
// computes the offset of the A element at (e1, k, e2), the B element at
// (e1, 0, h, w, e2) and the C element at (k, 0, h, w), each shifted by its origin
// index; the N coordinate is always the literal 0 in the offsets built here.
// The trailing enable_if parameter removes this template from consideration unless
// all three descriptors report IsKnownAtCompileTime().
// NOTE(review): this text is an extracted documentation listing; listing numbers
// 28, 50-52, 56-58 and 121 are skipped, so some code lines are not visible here.
18 template <typename FloatA,
19  typename FloatB,
20  typename FloatC,
21  typename AThreadDesc_E1_K_E2,
22  typename BThreadDesc_E1_N_Ho_Wo_E2,
23  typename CThreadDesc_K_N_Ho_Wo,
24  typename enable_if<AThreadDesc_E1_K_E2::IsKnownAtCompileTime() &&
25  BThreadDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() &&
26  CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(),
27  bool>::type = false>
// NOTE(review): listing line 28 is missing; presumably it is the `struct`
// declaration that the brace below opens - confirm against the real header.
29 {
30 
// Run: update c_buf from the contents of a_buf and b_buf.
//   a_buf, b_buf - inputs, taken by const reference and indexed with
//                  operator[] using Number<offset> compile-time indices.
//   c_buf        - taken by non-const reference and indexed with operator()
//                  using Number<offset> compile-time indices.
//   AOriginIdx, BOriginIdx, COriginIdx - unnamed parameters; only their types are
//                  used (default-constructed and passed to to_multi_index) to form
//                  the origin that is added to every index before CalculateOffset.
31  template <typename ABuffer,
32  typename AOriginIdx,
33  typename BBuffer,
34  typename BOriginIdx,
35  typename CBuffer,
36  typename COriginIdx>
37  __device__ static void Run(const ABuffer& a_buf,
38  AOriginIdx,
39  const BBuffer& b_buf,
40  BOriginIdx,
41  CBuffer& c_buf,
42  COriginIdx)
43  {
44 
// Repeats the compile-time requirement already expressed by the enable_if above.
45  static_assert(AThreadDesc_E1_K_E2::IsKnownAtCompileTime() &&
46  BThreadDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() &&
47  CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(),
48  "wrong! Desc should be known at compile-time");
49 
// NOTE(review): listing lines 50-52 (the static_assert condition that owns the
// message below) are missing from this extract.
53  "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
54 
// NOTE(review): listing lines 56-58 (the condition of this static_assert) are
// missing from this extract; only the message string is visible.
55  static_assert(
59  "wrong! inconsistent type");
60 
// Compile-time dimension selectors used with GetLength below.
61  constexpr auto I0 = Number<0>{};
62  constexpr auto I1 = Number<1>{};
63  constexpr auto I2 = Number<2>{};
64  constexpr auto I3 = Number<3>{};
65 
// Loop extents: E1, K, E2 are dimensions 0, 1, 2 of the A descriptor.
66  constexpr auto E1 = AThreadDesc_E1_K_E2{}.GetLength(I0);
67  constexpr auto K = AThreadDesc_E1_K_E2{}.GetLength(I1);
68  constexpr auto E2 = AThreadDesc_E1_K_E2{}.GetLength(I2);
69 
// Ho, Wo are dimensions 2 and 3 of the B descriptor; the N length (dimension 1)
// is never read in this function.
70  constexpr auto Ho = BThreadDesc_E1_N_Ho_Wo_E2{}.GetLength(I2);
71  constexpr auto Wo = BThreadDesc_E1_N_Ho_Wo_E2{}.GetLength(I3);
72 
// Origins are built from default-constructed index types, so they are constants.
73  constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
74  constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
75  constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
76 
// Both spatial extents even: step h and w by SubHW (2) and, per iteration,
// compute one A offset plus four B and four C offsets covering the 2x2 patch
// (h, w), (h, w + 1), (h + 1, w), (h + 1, w + 1).
77  if constexpr((Ho % 2 == 0) && (Wo % 2 == 0))
78  {
79  constexpr auto SubHW = 2;
80 
81  static_for<0, K, 1>{}([&](auto k) {
82  static_for<0, Ho, SubHW>{}([&](auto h) {
83  static_for<0, Wo, SubHW>{}([&](auto w) {
84  static_for<0, E1, 1>{}([&](auto e1) {
85  static_for<0, E2, 1>{}([&](auto e2) {
86  constexpr index_t a_offset = AThreadDesc_E1_K_E2{}.CalculateOffset(
87  a_origin_idx + make_tuple(e1, k, e2));
88 
89  constexpr index_t b0_offset =
90  BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
91  b_origin_idx + make_tuple(e1, 0, h, w, e2));
92 
93  constexpr index_t b1_offset =
94  BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
95  b_origin_idx + make_tuple(e1, 0, h, w + 1, e2));
96 
97  constexpr index_t b2_offset =
98  BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
99  b_origin_idx + make_tuple(e1, 0, h + 1, w, e2));
100 
101  constexpr index_t b3_offset =
102  BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
103  b_origin_idx + make_tuple(e1, 0, h + 1, w + 1, e2));
104 
105  constexpr index_t c0_offset =
106  CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx +
107  make_tuple(k, 0, h, w));
108 
109  constexpr index_t c1_offset =
110  CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
111  c_origin_idx + make_tuple(k, 0, h, w + 1));
112 
113  constexpr index_t c2_offset =
114  CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
115  c_origin_idx + make_tuple(k, 0, h + 1, w));
116 
117  constexpr index_t c3_offset =
118  CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
119  c_origin_idx + make_tuple(k, 0, h + 1, w + 1));
120 
// NOTE(review): listing line 121 (the head of the call that takes the arguments
// below) is missing. The page's reference list names
// amd_assembly_outer_product_1x4(a, b0, b1, b2, b3, c0, c1, c2, c3), so this is
// presumably that call with a_buf[Number<a_offset>{}] as its first argument -
// confirm against the real header.
122  b_buf[Number<b0_offset>{}],
123  b_buf[Number<b1_offset>{}],
124  b_buf[Number<b2_offset>{}],
125  b_buf[Number<b3_offset>{}],
126  c_buf(Number<c0_offset>{}),
127  c_buf(Number<c1_offset>{}),
128  c_buf(Number<c2_offset>{}),
129  c_buf(Number<c3_offset>{}));
130  });
131  });
132  });
133  });
134  });
135  }
// Otherwise (Ho or Wo odd): visit every (h, w) individually and issue one
// inner_product call per (k, h, w, e1, e2).
136  else
137  {
138 
139  static_for<0, K, 1>{}([&](auto k) {
140  static_for<0, Ho, 1>{}([&](auto h) {
141  static_for<0, Wo, 1>{}([&](auto w) {
142  static_for<0, E1, 1>{}([&](auto e1) {
143  static_for<0, E2, 1>{}([&](auto e2) {
144  constexpr index_t a_offset = AThreadDesc_E1_K_E2{}.CalculateOffset(
145  a_origin_idx + make_tuple(e1, k, e2));
146 
147  constexpr index_t b_offset =
148  BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
149  b_origin_idx + make_tuple(e1, 0, h, w, e2));
150 
151  constexpr index_t c_offset =
152  CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx +
153  make_tuple(k, 0, h, w));
154 
// presumably inner_product accumulates a * b into the c argument; its
// definition is not part of this file - verify before relying on it.
155  inner_product<FloatA, FloatB, FloatC>(a_buf[Number<a_offset>{}],
156  b_buf[Number<b_offset>{}],
157  c_buf(Number<c_offset>{}));
158  });
159  });
160  });
161  });
162  });
163  }
164  }
165 };
166 
167 } // namespace ck
168 #endif
Definition: ck.hpp:267
__device__ void amd_assembly_outer_product_1x4(float a, float b0, float b1, float b2, float b3, float &c0, float &c1, float &c2, float &c3)
Definition: amd_inline_asm.hpp:106
__host__ constexpr __device__ auto to_multi_index(const T &x)
Definition: array_multi_index.hpp:28
std::enable_if< B, T > enable_if
Definition: enable_if.hpp:24
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
int32_t index_t
Definition: ck.hpp:298
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1350
Definition: threadwise_gemm_dlops_v3.hpp:29
static __device__ void Run(const ABuffer &a_buf, AOriginIdx, const BBuffer &b_buf, BOriginIdx, CBuffer &c_buf, COriginIdx)
Definition: threadwise_gemm_dlops_v3.hpp:37
Definition: integral_constant.hpp:20
Definition: is_known_at_compile_time.hpp:14
Definition: type.hpp:177
Definition: functional2.hpp:33