4 #ifndef CK_AMD_WMMA_HPP 
    5 #define CK_AMD_WMMA_HPP 
   12 #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \ 
   13     defined(__gfx1103__) || defined(__gfx11_generic__) 
   17 #if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__) 
   24 template <index_t MPerWave, index_t NPerWave>
 
   30     template <
class FloatC>
 
   37 #if defined(__gfx11__) 
   38         reg_c.template AsType<float8_t>()(
Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
 
   39             reg_a, reg_b, reg_c.template AsType<float8_t>()[
Number<0>{}]);
 
   49 template <index_t MPerWave, index_t NPerWave>
 
   55     template <
class FloatC>
 
   58 #if defined(__gfx11__) 
   59         reg_c.template AsType<float8_t>()(
Number<0>{}) =
 
   60             __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(
 
   61                 reg_a, reg_b, reg_c.template AsType<float8_t>()[
Number<0>{}]);
 
   71 template <index_t MPerWave, index_t NPerWave, index_t Opsel>
 
   74 template <index_t Opsel>
 
   77     template <
class FloatC>
 
   83 #if defined(__gfx11__) 
   84         reg_c.template AsType<half16_t>()(
Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(
 
   85             reg_a, reg_b, reg_c.template AsType<half16_t>()[
Number<0>{}], Opsel);
 
   95 template <index_t MPerWave, index_t NPerWave, index_t Opsel>
 
   98 template <index_t Opsel>
 
  101     template <
class FloatC>
 
  107 #if defined(__gfx11__) 
  108         reg_c.template AsType<bhalf16_t>()(
Number<0>{}) =
 
  109             __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32(
 
  110                 reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[
Number<0>{}], Opsel);
 
  120 template <index_t MPerWave, index_t NPerWave, 
bool neg_a, 
bool neg_b, 
bool clamp>
 
  123 template <
bool neg_a, 
bool neg_b, 
bool clamp>
 
  126     template <
class FloatC>
 
  129 #if defined(__gfx11__) 
  130         reg_c.template AsType<int32x8_t>()(
Number<0>{}) =
 
  131             __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
 
  133                 bit_cast<int32x4_t>(reg_a),
 
  135                 bit_cast<int32x4_t>(reg_b),
 
  136                 reg_c.template AsType<int32x8_t>()[
Number<0>{}],
 
  148 template <index_t MPerWave, index_t NPerWave>
 
  154     template <
class FloatC>
 
  157 #if defined(__gfx11__) 
  158         reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64(
 
  159             reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}]);
 
  169 template <index_t MPerWave, index_t NPerWave>
 
  175     template <
class FloatC>
 
  178 #if defined(__gfx11__) 
  179         reg_c.template AsType<float4_t>()(
Number<0>{}) =
 
  180             __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64(
 
  181                 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}]);
 
  191 template <index_t MPerWave, index_t NPerWave, index_t Opsel>
 
  194 template <index_t Opsel>
 
  197     template <
class FloatC>
 
  203 #if defined(__gfx11__) 
  204         reg_c.template AsType<half8_t>()(
Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w64(
 
  205             reg_a, reg_b, reg_c.template AsType<half8_t>()[
Number<0>{}], Opsel);
 
  215 template <index_t MPerWave, index_t NPerWave, index_t Opsel>
 
  218 template <index_t Opsel>
 
  221     template <
class FloatC>
 
  227 #if defined(__gfx11__) 
  228         reg_c.template AsType<bhalf8_t>()(
Number<0>{}) =
 
  229             __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64(
 
  230                 reg_a, reg_b, reg_c.template AsType<bhalf8_t>()[
Number<0>{}], Opsel);
 
  240 template <index_t MPerWave, index_t NPerWave, 
bool neg_a, 
bool neg_b, 
bool clamp>
 
  243 template <
bool neg_a, 
bool neg_b, 
bool clamp>
 
  246     template <
class FloatC>
 
  249 #if defined(__gfx11__) 
  250         reg_c.template AsType<int32x4_t>()(
Number<0>{}) =
 
  251             __builtin_amdgcn_wmma_i32_16x16x16_iu8_w64(
 
  253                 bit_cast<int32x4_t>(reg_a),
 
  255                 bit_cast<int32x4_t>(reg_b),
 
  256                 reg_c.template AsType<int32x4_t>()[
Number<0>{}],
 
  270 template <index_t MPerWave, index_t NPerWave>
 
  276     template <
class FloatC>
 
  283 #if defined(__gfx12__) 
  284         reg_c.template AsType<float8_t>()(
Number<0>{}) =
 
  285             __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(
 
  286                 reg_a, reg_b, reg_c.template AsType<float8_t>()[
Number<0>{}]);
 
  296 template <index_t MPerWave, index_t NPerWave>
 
  302     template <
class FloatC>
 
  305 #if defined(__gfx12__) 
  306         reg_c.template AsType<float8_t>()(
Number<0>{}) =
 
  307             __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(
 
  308                 reg_a, reg_b, reg_c.template AsType<float8_t>()[
Number<0>{}]);
 
  318 template <index_t MPerWave, index_t NPerWave, 
bool neg_a, 
bool neg_b, 
bool clamp>
 
  321 template <
bool neg_a, 
bool neg_b, 
bool clamp>
 
  324     template <
class FloatC>
 
  327 #if defined(__gfx12__) 
  328         reg_c.template AsType<int32x8_t>()(
Number<0>{}) =
 
  329             __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
 
  331                 bit_cast<int32x2_t>(reg_a),
 
  333                 bit_cast<int32x2_t>(reg_b),
 
  334                 reg_c.template AsType<int32x8_t>()[
Number<0>{}],
 
  345 template <index_t MPerWave, index_t NPerWave>
 
  351     template <
class FloatC>
 
  352     __device__ 
static void Run(
const f8x8_t& reg_a, 
const f8x8_t& reg_b, FloatC& reg_c)
 
  354 #if defined(__gfx12__) 
  355         reg_c.template AsType<float8_t>()(
Number<0>{}) =
 
  356             __builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12(
 
  357                 bit_cast<int32x2_t>(reg_a),
 
  358                 bit_cast<int32x2_t>(reg_b),
 
  359                 reg_c.template AsType<float8_t>()[
Number<0>{}]);
 
  369 template <index_t MPerWave, index_t NPerWave>
 
  375     template <
class FloatC>
 
  376     __device__ 
static void Run(
const f8x8_t& reg_a, 
const bf8x8_t& reg_b, FloatC& reg_c)
 
  378 #if defined(__gfx12__) 
  379         reg_c.template AsType<float8_t>()(
Number<0>{}) =
 
  380             __builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12(
 
  381                 bit_cast<int32x2_t>(reg_a),
 
  382                 bit_cast<int32x2_t>(reg_b),
 
  383                 reg_c.template AsType<float8_t>()[
Number<0>{}]);
 
  393 template <index_t MPerWave, index_t NPerWave>
 
  399     template <
class FloatC>
 
  400     __device__ 
static void Run(
const bf8x8_t& reg_a, 
const f8x8_t& reg_b, FloatC& reg_c)
 
  402 #if defined(__gfx12__) 
  403         reg_c.template AsType<float8_t>()(
Number<0>{}) =
 
  404             __builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12(
 
  405                 bit_cast<int32x2_t>(reg_a),
 
  406                 bit_cast<int32x2_t>(reg_b),
 
  407                 reg_c.template AsType<float8_t>()[
Number<0>{}]);
 
  417 template <index_t MPerWave, index_t NPerWave>
 
  423     template <
class FloatC>
 
  426 #if defined(__gfx12__) 
  427         reg_c.template AsType<float8_t>()(
Number<0>{}) =
 
  428             __builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12(
 
  429                 bit_cast<int32x2_t>(reg_a),
 
  430                 bit_cast<int32x2_t>(reg_b),
 
  431                 reg_c.template AsType<float8_t>()[
Number<0>{}]);
 
__host__ constexpr __device__ T clamp(const T &x, const T &lowerbound, const T &upperbound)
Definition: math.hpp:148
 
bf8_t bf8x8_t
Definition: vector_type.hpp:238
 
typename vector_type< bhalf_t, 8 >::type bhalf8_t
Definition: dtype_vector.hpp:2162
 
typename vector_type< int8_t, 8 >::type int8x8_t
Definition: dtype_vector.hpp:2178
 
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
 
typename vector_type< int8_t, 16 >::type int8x16_t
Definition: dtype_vector.hpp:2179
 
typename vector_type< bhalf_t, 16 >::type bhalf16_t
Definition: dtype_vector.hpp:2163
 
typename vector_type< half_t, 16 >::type half16_t
Definition: dtype_vector.hpp:2156
 
typename vector_type< half_t, 8 >::type half8_t
Definition: dtype_vector.hpp:2155
 
Definition: integral_constant.hpp:20
 
static __device__ void Run(const bhalf16_t &reg_a, const bhalf16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:102
 
Definition: amd_wmma.hpp:96
 
static __device__ void Run(const bhalf16_t &reg_a, const bhalf16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:222
 
Definition: amd_wmma.hpp:216
 
static __device__ void Run(const half16_t &reg_a, const half16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:78
 
Definition: amd_wmma.hpp:72
 
static __device__ void Run(const half16_t &reg_a, const half16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:198
 
Definition: amd_wmma.hpp:192
 
static __device__ void Run(const bhalf16_t &reg_a, const bhalf16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:56
 
static __device__ void Run(const bhalf8_t &reg_a, const bhalf8_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:303
 
Definition: amd_wmma.hpp:297
 
Definition: amd_wmma.hpp:50
 
static __device__ void Run(const bhalf16_t &reg_a, const bhalf16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:176
 
Definition: amd_wmma.hpp:170
 
static __device__ void Run(const bf8x8_t &reg_a, const bf8x8_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:424
 
Definition: amd_wmma.hpp:418
 
static __device__ void Run(const bf8x8_t &reg_a, const f8x8_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:400
 
Definition: amd_wmma.hpp:394
 
static __device__ void Run(const half16_t &reg_a, const half16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:31
 
static __device__ void Run(const half8_t &reg_a, const half8_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:277
 
Definition: amd_wmma.hpp:271
 
Definition: amd_wmma.hpp:25
 
static __device__ void Run(const half16_t &reg_a, const half16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:155
 
Definition: amd_wmma.hpp:149
 
static __device__ void Run(const f8x8_t &reg_a, const bf8x8_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:376
 
Definition: amd_wmma.hpp:370
 
static __device__ void Run(const f8x8_t &reg_a, const f8x8_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:352
 
Definition: amd_wmma.hpp:346
 
static __device__ void Run(const int8x16_t &reg_a, const int8x16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:127
 
static __device__ void Run(const int8x8_t &reg_a, const int8x8_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:325
 
Definition: amd_wmma.hpp:319
 
Definition: amd_wmma.hpp:121
 
static __device__ void Run(const int8x16_t &reg_a, const int8x16_t &reg_b, FloatC &reg_c)
Definition: amd_wmma.hpp:247
 
Definition: amd_wmma.hpp:241