4 #ifndef CK_THREADWISE_GEMM_DLOPS_V3_HPP
5 #define CK_THREADWISE_GEMM_DLOPS_V3_HPP
18 template <
typename FloatA,
21 typename AThreadDesc_E1_K_E2,
22 typename BThreadDesc_E1_N_Ho_Wo_E2,
23 typename CThreadDesc_K_N_Ho_Wo,
24 typename enable_if<AThreadDesc_E1_K_E2::IsKnownAtCompileTime() &&
25 BThreadDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() &&
26 CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(),
31 template <
typename ABuffer,
37 __device__
static void Run(
const ABuffer& a_buf,
45 static_assert(AThreadDesc_E1_K_E2::IsKnownAtCompileTime() &&
46 BThreadDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() &&
47 CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(),
48 "wrong! Desc should be known at compile-time");
53 "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
59 "wrong! inconsistent type");
66 constexpr
auto E1 = AThreadDesc_E1_K_E2{}.GetLength(I0);
67 constexpr
auto K = AThreadDesc_E1_K_E2{}.GetLength(I1);
68 constexpr
auto E2 = AThreadDesc_E1_K_E2{}.GetLength(I2);
70 constexpr
auto Ho = BThreadDesc_E1_N_Ho_Wo_E2{}.GetLength(I2);
71 constexpr
auto Wo = BThreadDesc_E1_N_Ho_Wo_E2{}.GetLength(I3);
77 if constexpr((Ho % 2 == 0) && (Wo % 2 == 0))
79 constexpr
auto SubHW = 2;
86 constexpr
index_t a_offset = AThreadDesc_E1_K_E2{}.CalculateOffset(
90 BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
94 BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
95 b_origin_idx +
make_tuple(e1, 0, h, w + 1, e2));
98 BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
99 b_origin_idx +
make_tuple(e1, 0, h + 1, w, e2));
102 BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
103 b_origin_idx +
make_tuple(e1, 0, h + 1, w + 1, e2));
106 CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx +
110 CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
114 CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
118 CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
119 c_origin_idx +
make_tuple(k, 0, h + 1, w + 1));
144 constexpr
index_t a_offset = AThreadDesc_E1_K_E2{}.CalculateOffset(
148 BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
152 CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx +
__device__ void amd_assembly_outer_product_1x4(float a, float b0, float b1, float b2, float b3, float &c0, float &c1, float &c2, float &c3)
Definition: amd_inline_asm.hpp:106
__host__ constexpr __device__ auto to_multi_index(const T &x)
Definition: array_multi_index.hpp:28
std::enable_if< B, T > enable_if
Definition: enable_if.hpp:24
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
int32_t index_t
Definition: ck.hpp:298
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1350
Definition: threadwise_gemm_dlops_v3.hpp:29
static __device__ void Run(const ABuffer &a_buf, AOriginIdx, const BBuffer &b_buf, BOriginIdx, CBuffer &c_buf, COriginIdx)
Definition: threadwise_gemm_dlops_v3.hpp:37
Definition: integral_constant.hpp:20
Definition: is_known_at_compile_time.hpp:14
Definition: functional2.hpp:33