/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/inner_product_dpp8.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/inner_product_dpp8.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/inner_product_dpp8.hpp Source File
inner_product_dpp8.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "amd_gemm_dpp.hpp"
7 #include "data_type.hpp"
8 #include "type_convert.hpp"
9 
10 namespace ck {
11 
12 namespace dpp8 {
13 
15 constexpr index_t lane_group_size = 8;
16 
17 template <int SrcLaneIdx>
18 __device__ void inline_v_dot2c_dpp8_instr(const half2_t& a, const half2_t& b, float& c);
19 
20 // clang-format off
21 template <>
22 __device__ void inline_v_dot2c_dpp8_instr<0>(const half2_t& a, const half2_t& b, float& c){
23  asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[0, 0, 0, 0, 0, 0, 0, 0]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
24 }
25 template <>
26 __device__ void inline_v_dot2c_dpp8_instr<1>(const half2_t& a, const half2_t& b, float& c){
27  asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[1, 1, 1, 1, 1, 1, 1, 1]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
28 }
29 template <>
30 __device__ void inline_v_dot2c_dpp8_instr<2>(const half2_t& a, const half2_t& b, float& c){
31  asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[2, 2, 2, 2, 2, 2, 2, 2]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
32 }
33 template <>
34 __device__ void inline_v_dot2c_dpp8_instr<3>(const half2_t& a, const half2_t& b, float& c){
35  asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[3, 3, 3, 3, 3, 3, 3, 3]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
36 }
37 template <>
38 __device__ void inline_v_dot2c_dpp8_instr<4>(const half2_t& a, const half2_t& b, float& c){
39  asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[4, 4, 4, 4, 4, 4, 4, 4]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
40 }
41 template <>
42 __device__ void inline_v_dot2c_dpp8_instr<5>(const half2_t& a, const half2_t& b, float& c){
43  asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[5, 5, 5, 5, 5, 5, 5, 5]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
44 }
45 template <>
46 __device__ void inline_v_dot2c_dpp8_instr<6>(const half2_t& a, const half2_t& b, float& c){
47  asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[6, 6, 6, 6, 6, 6, 6, 6]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
48 }
49 template <>
50 __device__ void inline_v_dot2c_dpp8_instr<7>(const half2_t& a, const half2_t& b, float& c){
51  asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[7, 7, 7, 7, 7, 7, 7, 7]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
52 }
53 // clang-format on
54 
58 template <int SrcLaneIdx, bool ShareA>
59 __device__ void inline_v_dot2c_dpp8(const half2_t& a, const half2_t& b, float& c)
60 {
61  static_assert(SrcLaneIdx >= 0 && SrcLaneIdx < dpp8::lane_group_size,
62  "DPP8 src broadcast lane out of range <0, 7>.");
63  if constexpr(ShareA)
64  {
65  inline_v_dot2c_dpp8_instr<SrcLaneIdx>(a, b, c);
66  }
67  else
68  {
69  inline_v_dot2c_dpp8_instr<SrcLaneIdx>(b, a, c);
70  }
71 }
72 
77 constexpr std::array<int, dpp8::lane_group_size> IntrinsicMaskDpp8 = {
78  0, // 0, 0, 0, 0, 0, 0, 0, 0
79  2396745, // 1, 1, 1, 1, 1, 1, 1, 1
80  4793490, // 2, 2, 2, 2, 2, 2, 2, 2
81  7190235, // 3, 3, 3, 3, 3, 3, 3, 3
82  9586980, // 4, 4, 4, 4, 4, 4, 4, 4
83  11983725, // 5, 5, 5, 5, 5, 5, 5, 5
84  14380470, // 6, 6, 6, 6, 6, 6, 6, 6
85  16777215, // 7, 7, 7, 7, 7, 7, 7, 7
86 };
87 
91 template <int SrcLaneIdx>
93 {
94  static_assert(SrcLaneIdx >= 0 && SrcLaneIdx < dpp8::lane_group_size,
95  "DPP8 src broadcast lane out of range <0, 7>.");
96  return IntrinsicMaskDpp8[SrcLaneIdx];
97 }
98 
99 template <int SrcLaneIdx>
100 __device__ void intrinsic_fdot2_impl(const half2_t& a, const half2_t& b, float& c)
101 {
102  constexpr int sel_mask = get_dpp_sel_mask_broadcast<SrcLaneIdx>();
103  const half2_t val_from_other_lane =
104  bit_cast<half2_t>(__builtin_amdgcn_mov_dpp8(bit_cast<int>(a), sel_mask));
105  c = __builtin_amdgcn_fdot2(val_from_other_lane, b, c, false);
106 }
107 
111 template <int SrcLaneIdx, bool ShareA>
112 __device__ void intrinsic_fdot2(const half2_t& a, const half2_t& b, float& c)
113 {
114  if constexpr(ShareA)
115  {
116  intrinsic_fdot2_impl<SrcLaneIdx>(a, b, c);
117  }
118  else
119  {
120  intrinsic_fdot2_impl<SrcLaneIdx>(b, a, c);
121  }
122 }
123 
134 template <typename TA, typename TB, typename TC, int SrcLaneIdx, bool ShareA>
135 __device__ void inner_product_dpp(const TA& a, const TB& b, TC& c)
136 {
137 #if CK_USE_AMD_V_DOT_DPP8_INLINE_ASM
138  inline_v_dot2c_dpp8<SrcLaneIdx, ShareA>(a, b, c);
139 #else
140  intrinsic_fdot2<SrcLaneIdx, ShareA>(a, b, c);
141 #endif
142 }
143 
144 } // namespace dpp8
145 
146 } // namespace ck
__device__ void inline_v_dot2c_dpp8_instr< 5 >(const half2_t &a, const half2_t &b, float &c)
Definition: inner_product_dpp8.hpp:42
__device__ void inline_v_dot2c_dpp8_instr< 4 >(const half2_t &a, const half2_t &b, float &c)
Definition: inner_product_dpp8.hpp:38
__device__ void inner_product_dpp(const TA &a, const TB &b, TC &c)
Definition: inner_product_dpp8.hpp:135
__device__ void inline_v_dot2c_dpp8(const half2_t &a, const half2_t &b, float &c)
Definition: inner_product_dpp8.hpp:59
__device__ void inline_v_dot2c_dpp8_instr(const half2_t &a, const half2_t &b, float &c)
__device__ void inline_v_dot2c_dpp8_instr< 3 >(const half2_t &a, const half2_t &b, float &c)
Definition: inner_product_dpp8.hpp:34
__device__ void inline_v_dot2c_dpp8_instr< 2 >(const half2_t &a, const half2_t &b, float &c)
Definition: inner_product_dpp8.hpp:30
constexpr std::array< int, dpp8::lane_group_size > IntrinsicMaskDpp8
Definition: inner_product_dpp8.hpp:77
__device__ void intrinsic_fdot2(const half2_t &a, const half2_t &b, float &c)
Definition: inner_product_dpp8.hpp:112
__device__ void intrinsic_fdot2_impl(const half2_t &a, const half2_t &b, float &c)
Definition: inner_product_dpp8.hpp:100
__device__ void inline_v_dot2c_dpp8_instr< 7 >(const half2_t &a, const half2_t &b, float &c)
Definition: inner_product_dpp8.hpp:50
__device__ void inline_v_dot2c_dpp8_instr< 6 >(const half2_t &a, const half2_t &b, float &c)
Definition: inner_product_dpp8.hpp:46
constexpr int get_dpp_sel_mask_broadcast()
Definition: inner_product_dpp8.hpp:92
__device__ void inline_v_dot2c_dpp8_instr< 0 >(const half2_t &a, const half2_t &b, float &c)
Definition: inner_product_dpp8.hpp:22
constexpr index_t lane_group_size
Number of lanes that can share data using DPP8 modifiers.
Definition: inner_product_dpp8.hpp:15
__device__ void inline_v_dot2c_dpp8_instr< 1 >(const half2_t &a, const half2_t &b, float &c)
Definition: inner_product_dpp8.hpp:26
Definition: ck.hpp:267
typename vector_type< half_t, 2 >::type half2_t
Definition: dtype_vector.hpp:2139
int32_t index_t
Definition: ck.hpp:298
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1249