4 #ifndef CK_AMD_WMMA_HPP
5 #define CK_AMD_WMMA_HPP
12 #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \
13 defined(__gfx1103__) || defined(__gfx1150__) || defined(__gfx1151__) || \
14 defined(__gfx1152__) || defined(__gfx1153__) || defined(__gfx11_generic__)
18 #if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__)
25 template <index_t MPerWave, index_t NPerWave>
31 template <
class FloatC>
38 #if defined(__gfx11__)
39 reg_c.template AsType<float8_t>()(
Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
40 reg_a, reg_b, reg_c.template AsType<float8_t>()[
Number<0>{}]);
50 template <index_t MPerWave, index_t NPerWave>
56 template <
class FloatC>
59 #if defined(__gfx11__)
60 reg_c.template AsType<float8_t>()(
Number<0>{}) =
61 __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(
62 reg_a, reg_b, reg_c.template AsType<float8_t>()[
Number<0>{}]);
72 template <index_t MPerWave, index_t NPerWave, index_t Opsel>
75 template <index_t Opsel>
78 template <
class FloatC>
84 #if defined(__gfx11__)
85 reg_c.template AsType<half16_t>()(
Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(
86 reg_a, reg_b, reg_c.template AsType<half16_t>()[
Number<0>{}], Opsel);
96 template <index_t MPerWave, index_t NPerWave, index_t Opsel>
99 template <index_t Opsel>
102 template <
class FloatC>
108 #if defined(__gfx11__)
109 reg_c.template AsType<bhalf16_t>()(
Number<0>{}) =
110 __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32(
111 reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[
Number<0>{}], Opsel);
121 template <index_t MPerWave, index_t NPerWave,
bool neg_a,
bool neg_b,
bool clamp>
124 template <
bool neg_a,
bool neg_b,
bool clamp>
127 template <
class FloatC>
130 #if defined(__gfx11__)
131 reg_c.template AsType<int32x8_t>()(
Number<0>{}) =
132 __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
134 bit_cast<int32x4_t>(reg_a),
136 bit_cast<int32x4_t>(reg_b),
137 reg_c.template AsType<int32x8_t>()[
Number<0>{}],
149 template <index_t MPerWave, index_t NPerWave>
155 template <
class FloatC>
158 #if defined(__gfx11__)
159 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64(
160 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}]);
170 template <index_t MPerWave, index_t NPerWave>
176 template <
class FloatC>
179 #if defined(__gfx11__)
180 reg_c.template AsType<float4_t>()(
Number<0>{}) =
181 __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64(
182 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}]);
192 template <index_t MPerWave, index_t NPerWave, index_t Opsel>
195 template <index_t Opsel>
198 template <
class FloatC>
204 #if defined(__gfx11__)
205 reg_c.template AsType<half8_t>()(
Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w64(
206 reg_a, reg_b, reg_c.template AsType<half8_t>()[
Number<0>{}], Opsel);
216 template <index_t MPerWave, index_t NPerWave, index_t Opsel>
219 template <index_t Opsel>
222 template <
class FloatC>
228 #if defined(__gfx11__)
229 reg_c.template AsType<bhalf8_t>()(
Number<0>{}) =
230 __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64(
231 reg_a, reg_b, reg_c.template AsType<bhalf8_t>()[
Number<0>{}], Opsel);
241 template <index_t MPerWave, index_t NPerWave,
bool neg_a,
bool neg_b,
bool clamp>
244 template <
bool neg_a,
bool neg_b,
bool clamp>
247 template <
class FloatC>
250 #if defined(__gfx11__)
251 reg_c.template AsType<int32x4_t>()(
Number<0>{}) =
252 __builtin_amdgcn_wmma_i32_16x16x16_iu8_w64(
254 bit_cast<int32x4_t>(reg_a),
256 bit_cast<int32x4_t>(reg_b),
257 reg_c.template AsType<int32x4_t>()[
Number<0>{}],
271 template <index_t MPerWave, index_t NPerWave>
277 template <
class FloatC>
284 #if defined(__gfx12__)
285 reg_c.template AsType<float8_t>()(
Number<0>{}) =
286 __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(
287 reg_a, reg_b, reg_c.template AsType<float8_t>()[
Number<0>{}]);
297 template <index_t MPerWave, index_t NPerWave>
303 template <
class FloatC>
306 #if defined(__gfx12__)
307 reg_c.template AsType<float8_t>()(
Number<0>{}) =
308 __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(
309 reg_a, reg_b, reg_c.template AsType<float8_t>()[
Number<0>{}]);
319 template <index_t MPerWave, index_t NPerWave,
bool neg_a,
bool neg_b,
bool clamp>
322 template <
bool neg_a,
bool neg_b,
bool clamp>
325 template <
class FloatC>
328 #if defined(__gfx12__)
329 reg_c.template AsType<int32x8_t>()(
Number<0>{}) =
330 __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
332 bit_cast<int32x2_t>(reg_a),
334 bit_cast<int32x2_t>(reg_b),
335 reg_c.template AsType<int32x8_t>()[
Number<0>{}],
346 template <index_t MPerWave, index_t NPerWave>
352 template <
class FloatC>
353 __device__
static void Run(
const f8x8_t& reg_a,
const f8x8_t& reg_b, FloatC& reg_c)
355 #if defined(__gfx12__)
356 reg_c.template AsType<float8_t>()(
Number<0>{}) =
357 __builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12(
358 bit_cast<int32x2_t>(reg_a),
359 bit_cast<int32x2_t>(reg_b),
360 reg_c.template AsType<float8_t>()[
Number<0>{}]);
370 template <index_t MPerWave, index_t NPerWave>
376 template <
class FloatC>
377 __device__
static void Run(
const f8x8_t& reg_a,
const bf8x8_t& reg_b, FloatC& reg_c)
379 #if defined(__gfx12__)
380 reg_c.template AsType<float8_t>()(
Number<0>{}) =
381 __builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12(
382 bit_cast<int32x2_t>(reg_a),
383 bit_cast<int32x2_t>(reg_b),
384 reg_c.template AsType<float8_t>()[
Number<0>{}]);
394 template <index_t MPerWave, index_t NPerWave>
400 template <
class FloatC>
401 __device__
static void Run(
const bf8x8_t& reg_a,
const f8x8_t& reg_b, FloatC& reg_c)
403 #if defined(__gfx12__)
404 reg_c.template AsType<float8_t>()(
Number<0>{}) =
405 __builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12(
406 bit_cast<int32x2_t>(reg_a),
407 bit_cast<int32x2_t>(reg_b),
408 reg_c.template AsType<float8_t>()[
Number<0>{}]);
418 template <index_t MPerWave, index_t NPerWave>
424 template <
class FloatC>
427 #if defined(__gfx12__)
428 reg_c.template AsType<float8_t>()(
Number<0>{}) =
429 __builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12(
430 bit_cast<int32x2_t>(reg_a),
431 bit_cast<int32x2_t>(reg_b),
432 reg_c.template AsType<float8_t>()[
Number<0>{}]);
__host__ constexpr __device__ T clamp(const T &x, const T &lowerbound, const T &upperbound)
Definition: math.hpp:148
bf8_t bf8x8_t
Definition: vector_type.hpp:239
typename vector_type< bhalf_t, 8 >::type bhalf8_t
Definition: dtype_vector.hpp:2163
typename vector_type< int8_t, 8 >::type int8x8_t
Definition: dtype_vector.hpp:2179
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
typename vector_type< int8_t, 16 >::type int8x16_t
Definition: dtype_vector.hpp:2180
typename vector_type< bhalf_t, 16 >::type bhalf16_t
Definition: dtype_vector.hpp:2164
typename vector_type< half_t, 16 >::type half16_t
Definition: dtype_vector.hpp:2157
typename vector_type< half_t, 8 >::type half8_t
Definition: dtype_vector.hpp:2156
Definition: integral_constant.hpp:20
static __device__ void Run(const bhalf16_t &reg_a, const bhalf16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:103
Definition: amd_wmma.hpp:97
static __device__ void Run(const bhalf16_t &reg_a, const bhalf16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:223
Definition: amd_wmma.hpp:217
static __device__ void Run(const half16_t &reg_a, const half16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:79
Definition: amd_wmma.hpp:73
static __device__ void Run(const half16_t &reg_a, const half16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:199
Definition: amd_wmma.hpp:193
static __device__ void Run(const bhalf16_t &reg_a, const bhalf16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:57
static __device__ void Run(const bhalf8_t &reg_a, const bhalf8_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:304
Definition: amd_wmma.hpp:298
Definition: amd_wmma.hpp:51
static __device__ void Run(const bhalf16_t &reg_a, const bhalf16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:177
Definition: amd_wmma.hpp:171
static __device__ void Run(const bf8x8_t &reg_a, const bf8x8_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:425
Definition: amd_wmma.hpp:419
static __device__ void Run(const bf8x8_t &reg_a, const f8x8_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:401
Definition: amd_wmma.hpp:395
static __device__ void Run(const half16_t &reg_a, const half16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:32
static __device__ void Run(const half8_t &reg_a, const half8_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:278
Definition: amd_wmma.hpp:272
Definition: amd_wmma.hpp:26
static __device__ void Run(const half16_t &reg_a, const half16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:156
Definition: amd_wmma.hpp:150
static __device__ void Run(const f8x8_t &reg_a, const bf8x8_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:377
Definition: amd_wmma.hpp:371
static __device__ void Run(const f8x8_t &reg_a, const f8x8_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:353
Definition: amd_wmma.hpp:347
static __device__ void Run(const int8x16_t &reg_a, const int8x16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:128
static __device__ void Run(const int8x8_t &reg_a, const int8x8_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:326
Definition: amd_wmma.hpp:320
Definition: amd_wmma.hpp:122
static __device__ void Run(const int8x16_t &reg_a, const int8x16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:248
Definition: amd_wmma.hpp:242