4 #ifndef CK_AMD_WMMA_HPP
5 #define CK_AMD_WMMA_HPP
12 #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \
13 defined(__gfx1103__) || defined(__gfx11_generic__)
17 #if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__)
24 template <index_t MPerWave, index_t NPerWave>
30 template <
class FloatC>
37 #if defined(__gfx11__)
38 reg_c.template AsType<float8_t>()(
Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
39 reg_a, reg_b, reg_c.template AsType<float8_t>()[
Number<0>{}]);
49 template <index_t MPerWave, index_t NPerWave>
55 template <
class FloatC>
58 #if defined(__gfx11__)
59 reg_c.template AsType<float8_t>()(
Number<0>{}) =
60 __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(
61 reg_a, reg_b, reg_c.template AsType<float8_t>()[
Number<0>{}]);
71 template <index_t MPerWave, index_t NPerWave, index_t Opsel>
74 template <index_t Opsel>
77 template <
class FloatC>
83 #if defined(__gfx11__)
84 reg_c.template AsType<half16_t>()(
Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(
85 reg_a, reg_b, reg_c.template AsType<half16_t>()[
Number<0>{}], Opsel);
95 template <index_t MPerWave, index_t NPerWave, index_t Opsel>
98 template <index_t Opsel>
101 template <
class FloatC>
107 #if defined(__gfx11__)
108 reg_c.template AsType<bhalf16_t>()(
Number<0>{}) =
109 __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32(
110 reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[
Number<0>{}], Opsel);
120 template <index_t MPerWave, index_t NPerWave,
bool neg_a,
bool neg_b,
bool clamp>
123 template <
bool neg_a,
bool neg_b,
bool clamp>
126 template <
class FloatC>
129 #if defined(__gfx11__)
130 reg_c.template AsType<int32x8_t>()(
Number<0>{}) =
131 __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
133 bit_cast<int32x4_t>(reg_a),
135 bit_cast<int32x4_t>(reg_b),
136 reg_c.template AsType<int32x8_t>()[
Number<0>{}],
148 template <index_t MPerWave, index_t NPerWave>
154 template <
class FloatC>
157 #if defined(__gfx11__)
158 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64(
159 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}]);
169 template <index_t MPerWave, index_t NPerWave>
175 template <
class FloatC>
178 #if defined(__gfx11__)
179 reg_c.template AsType<float4_t>()(
Number<0>{}) =
180 __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64(
181 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}]);
191 template <index_t MPerWave, index_t NPerWave, index_t Opsel>
194 template <index_t Opsel>
197 template <
class FloatC>
203 #if defined(__gfx11__)
204 reg_c.template AsType<half8_t>()(
Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w64(
205 reg_a, reg_b, reg_c.template AsType<half8_t>()[
Number<0>{}], Opsel);
215 template <index_t MPerWave, index_t NPerWave, index_t Opsel>
218 template <index_t Opsel>
221 template <
class FloatC>
227 #if defined(__gfx11__)
228 reg_c.template AsType<bhalf8_t>()(
Number<0>{}) =
229 __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64(
230 reg_a, reg_b, reg_c.template AsType<bhalf8_t>()[
Number<0>{}], Opsel);
240 template <index_t MPerWave, index_t NPerWave,
bool neg_a,
bool neg_b,
bool clamp>
243 template <
bool neg_a,
bool neg_b,
bool clamp>
246 template <
class FloatC>
249 #if defined(__gfx11__)
250 reg_c.template AsType<int32x4_t>()(
Number<0>{}) =
251 __builtin_amdgcn_wmma_i32_16x16x16_iu8_w64(
253 bit_cast<int32x4_t>(reg_a),
255 bit_cast<int32x4_t>(reg_b),
256 reg_c.template AsType<int32x4_t>()[
Number<0>{}],
270 template <index_t MPerWave, index_t NPerWave>
276 template <
class FloatC>
283 #if defined(__gfx12__)
284 reg_c.template AsType<float8_t>()(
Number<0>{}) =
285 __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(
286 reg_a, reg_b, reg_c.template AsType<float8_t>()[
Number<0>{}]);
296 template <index_t MPerWave, index_t NPerWave>
302 template <
class FloatC>
305 #if defined(__gfx12__)
306 reg_c.template AsType<float8_t>()(
Number<0>{}) =
307 __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(
308 reg_a, reg_b, reg_c.template AsType<float8_t>()[
Number<0>{}]);
318 template <index_t MPerWave, index_t NPerWave,
bool neg_a,
bool neg_b,
bool clamp>
321 template <
bool neg_a,
bool neg_b,
bool clamp>
324 template <
class FloatC>
327 #if defined(__gfx12__)
328 reg_c.template AsType<int32x8_t>()(
Number<0>{}) =
329 __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
331 bit_cast<int32x2_t>(reg_a),
333 bit_cast<int32x2_t>(reg_b),
334 reg_c.template AsType<int32x8_t>()[
Number<0>{}],
345 template <index_t MPerWave, index_t NPerWave>
351 template <
class FloatC>
352 __device__
static void Run(
const f8x8_t& reg_a,
const f8x8_t& reg_b, FloatC& reg_c)
354 #if defined(__gfx12__)
355 reg_c.template AsType<float8_t>()(
Number<0>{}) =
356 __builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12(
357 bit_cast<int32x2_t>(reg_a),
358 bit_cast<int32x2_t>(reg_b),
359 reg_c.template AsType<float8_t>()[
Number<0>{}]);
369 template <index_t MPerWave, index_t NPerWave>
375 template <
class FloatC>
376 __device__
static void Run(
const f8x8_t& reg_a,
const bf8x8_t& reg_b, FloatC& reg_c)
378 #if defined(__gfx12__)
379 reg_c.template AsType<float8_t>()(
Number<0>{}) =
380 __builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12(
381 bit_cast<int32x2_t>(reg_a),
382 bit_cast<int32x2_t>(reg_b),
383 reg_c.template AsType<float8_t>()[
Number<0>{}]);
393 template <index_t MPerWave, index_t NPerWave>
399 template <
class FloatC>
400 __device__
static void Run(
const bf8x8_t& reg_a,
const f8x8_t& reg_b, FloatC& reg_c)
402 #if defined(__gfx12__)
403 reg_c.template AsType<float8_t>()(
Number<0>{}) =
404 __builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12(
405 bit_cast<int32x2_t>(reg_a),
406 bit_cast<int32x2_t>(reg_b),
407 reg_c.template AsType<float8_t>()[
Number<0>{}]);
417 template <index_t MPerWave, index_t NPerWave>
423 template <
class FloatC>
426 #if defined(__gfx12__)
427 reg_c.template AsType<float8_t>()(
Number<0>{}) =
428 __builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12(
429 bit_cast<int32x2_t>(reg_a),
430 bit_cast<int32x2_t>(reg_b),
431 reg_c.template AsType<float8_t>()[
Number<0>{}]);
__host__ constexpr __device__ T clamp(const T &x, const T &lowerbound, const T &upperbound)
Definition: math.hpp:148
bf8_t bf8x8_t
Definition: vector_type.hpp:227
typename vector_type< bhalf_t, 8 >::type bhalf8_t
Definition: dtype_vector.hpp:2148
typename vector_type< int8_t, 8 >::type int8x8_t
Definition: dtype_vector.hpp:2164
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
typename vector_type< int8_t, 16 >::type int8x16_t
Definition: dtype_vector.hpp:2165
typename vector_type< bhalf_t, 16 >::type bhalf16_t
Definition: dtype_vector.hpp:2149
typename vector_type< half_t, 16 >::type half16_t
Definition: dtype_vector.hpp:2142
typename vector_type< half_t, 8 >::type half8_t
Definition: dtype_vector.hpp:2141
Definition: integral_constant.hpp:20
static __device__ void Run(const bhalf16_t &reg_a, const bhalf16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:102
Definition: amd_wmma.hpp:96
static __device__ void Run(const bhalf16_t &reg_a, const bhalf16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:222
Definition: amd_wmma.hpp:216
static __device__ void Run(const half16_t &reg_a, const half16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:78
Definition: amd_wmma.hpp:72
static __device__ void Run(const half16_t &reg_a, const half16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:198
Definition: amd_wmma.hpp:192
static __device__ void Run(const bhalf16_t &reg_a, const bhalf16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:56
static __device__ void Run(const bhalf8_t &reg_a, const bhalf8_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:303
Definition: amd_wmma.hpp:297
Definition: amd_wmma.hpp:50
static __device__ void Run(const bhalf16_t &reg_a, const bhalf16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:176
Definition: amd_wmma.hpp:170
static __device__ void Run(const bf8x8_t &reg_a, const bf8x8_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:424
Definition: amd_wmma.hpp:418
static __device__ void Run(const bf8x8_t &reg_a, const f8x8_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:400
Definition: amd_wmma.hpp:394
static __device__ void Run(const half16_t &reg_a, const half16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:31
static __device__ void Run(const half8_t &reg_a, const half8_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:277
Definition: amd_wmma.hpp:271
Definition: amd_wmma.hpp:25
static __device__ void Run(const half16_t &reg_a, const half16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:155
Definition: amd_wmma.hpp:149
static __device__ void Run(const f8x8_t &reg_a, const bf8x8_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:376
Definition: amd_wmma.hpp:370
static __device__ void Run(const f8x8_t &reg_a, const f8x8_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:352
Definition: amd_wmma.hpp:346
static __device__ void Run(const int8x16_t &reg_a, const int8x16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:127
static __device__ void Run(const int8x8_t &reg_a, const int8x8_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:325
Definition: amd_wmma.hpp:319
Definition: amd_wmma.hpp:121
static __device__ void Run(const int8x16_t &reg_a, const int8x16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:247
Definition: amd_wmma.hpp:241