/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/amd_smfmac.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/amd_smfmac.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/amd_smfmac.hpp Source File
amd_smfmac.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #include "ck/ck.hpp"
5 #pragma once
6 
7 namespace ck {
8 
9 template <index_t MPerWave, index_t NPerWave>
11 
12 // For every smfmac instruction: if CBSZ[1:0] == 0, ABID[1:0] selects one of the four 8-bit sets of
13 // sparse indices held in reg_idx.
14 template <>
16 {
17  template <class FloatC, index_t abid = 0>
18  __device__ static void
19  Run(const half4_t& reg_a, const half8_t& reg_b, const index_t& reg_idx, FloatC& reg_c)
20  {
21 #if defined(__gfx94__)
22  reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_f16(
23  reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, abid);
24 #else
25  ignore = reg_a;
26  ignore = reg_b;
27  ignore = reg_c;
28  ignore = reg_idx;
29 #endif
30  }
31 };
32 
33 template <index_t MPerWave, index_t NPerWave>
35 
36 template <>
38 {
39  template <class FloatC, index_t abid = 0>
40  __device__ static void
41  Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const index_t& reg_idx, FloatC& reg_c)
42  {
43 #if defined(__gfx94__)
44  reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_bf16(
45  reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, abid);
46 #else
47  ignore = reg_a;
48  ignore = reg_b;
49  ignore = reg_c;
50  ignore = reg_idx;
51 #endif
52  }
53 };
54 
55 template <index_t MPerWave, index_t NPerWave>
57 
58 template <>
60 {
61  template <class FloatC, index_t abid = 0>
62  __device__ static void
63  Run(const half4_t& reg_a, const half8_t& reg_b, const index_t& reg_idx, FloatC& reg_c)
64  {
65 #if defined(__gfx94__)
66  reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_f16(
67  reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, abid);
68 #else
69  ignore = reg_a;
70  ignore = reg_b;
71  ignore = reg_c;
72  ignore = reg_idx;
73 #endif
74  }
75 };
76 
77 template <index_t MPerWave, index_t NPerWave>
79 
80 template <>
82 {
83  template <class FloatC, index_t abid = 0>
84  __device__ static void
85  Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const index_t& reg_idx, FloatC& reg_c)
86  {
87 #if defined(__gfx94__)
88  reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_bf16(
89  reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, abid);
90 #else
91  ignore = reg_a;
92  ignore = reg_b;
93  ignore = reg_c;
94  ignore = reg_idx;
95 #endif
96  }
97 };
98 
99 } // namespace ck
Definition: ck.hpp:267
typename vector_type< bhalf_t, 4 >::type bhalf4_t
Definition: dtype_vector.hpp:2147
typename vector_type< bhalf_t, 8 >::type bhalf8_t
Definition: dtype_vector.hpp:2148
typename vector_type< half_t, 4 >::type half4_t
Definition: dtype_vector.hpp:2140
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
int32_t index_t
Definition: ck.hpp:298
typename vector_type< half_t, 8 >::type half8_t
Definition: dtype_vector.hpp:2141
Definition: integral_constant.hpp:20
static __device__ void Run(const bhalf4_t &reg_a, const bhalf8_t &reg_b, const index_t &reg_idx, FloatC &reg_c)
Definition: amd_smfmac.hpp:41
Definition: amd_smfmac.hpp:34
static __device__ void Run(const half4_t &reg_a, const half8_t &reg_b, const index_t &reg_idx, FloatC &reg_c)
Definition: amd_smfmac.hpp:19
Definition: amd_smfmac.hpp:10
static __device__ void Run(const bhalf4_t &reg_a, const bhalf8_t &reg_b, const index_t &reg_idx, FloatC &reg_c)
Definition: amd_smfmac.hpp:85
Definition: amd_smfmac.hpp:78
static __device__ void Run(const half4_t &reg_a, const half8_t &reg_b, const index_t &reg_idx, FloatC &reg_c)
Definition: amd_smfmac.hpp:63
Definition: amd_smfmac.hpp:56