47 template <DppInstr instr>
54 static constexpr
index_t lanegroup_size = 8;
57 static constexpr
index_t m_per_lanegroup = 8;
58 static constexpr
index_t n_per_lanegroup = 8;
59 static constexpr
index_t m_per_thread = 8;
60 static constexpr
index_t n_per_thread = 1;
62 static constexpr
bool share_a =
true;
65 template <index_t MPerDpp, index_t NPerDpp,
class ADataType,
class BDataType,
class CDataType>
66 __device__
void run(
const ADataType&
a,
const BDataType& b, CDataType& reg_c)
const
84 static constexpr
index_t lanegroup_size = 8;
87 static constexpr
index_t m_per_lanegroup = 8;
88 static constexpr
index_t n_per_lanegroup = 8;
89 static constexpr
index_t m_per_thread = 8;
90 static constexpr
index_t n_per_thread = 1;
92 static constexpr
bool share_a =
true;
95 template <index_t MPerDpp, index_t NPerDpp,
class ADataType,
class BDataType,
class CDataType>
96 __device__
void run(
const ADataType&
a,
const BDataType& b, CDataType& reg_c)
const
122 static constexpr
bool share_a =
true;
125 template <index_t MPerDpp, index_t NPerDpp,
class ADataType,
class BDataType,
class CDataType>
126 __device__
void run(
const ADataType&
a,
const BDataType& b, CDataType& reg_c)
const
152 static constexpr
bool share_a =
true;
155 template <index_t MPerDpp, index_t NPerDpp,
class ADataType,
class BDataType,
class CDataType>
156 __device__
void run(
const ADataType&
a,
const BDataType& b, CDataType& reg_c)
const
182 static constexpr
bool share_a =
true;
185 template <index_t MPerDpp, index_t NPerDpp,
class ADataType,
class BDataType,
class CDataType>
186 __device__
void run(
const ADataType&
a,
const BDataType& b, CDataType& reg_c)
const
212 static constexpr
bool share_a =
true;
215 template <index_t MPerDpp, index_t NPerDpp,
class ADataType,
class BDataType,
class CDataType>
216 __device__
void run(
const ADataType&
a,
const BDataType& b, CDataType& reg_c)
const
242 static constexpr
bool share_a =
true;
245 template <index_t MPerDpp, index_t NPerDpp,
class ADataType,
class BDataType,
class CDataType>
246 __device__
void run(
const ADataType&
a,
const BDataType& b, CDataType& reg_c)
const
272 static constexpr
bool share_a =
true;
275 template <index_t MPerDpp, index_t NPerDpp,
class ADataType,
class BDataType,
class CDataType>
276 __device__
void run(
const ADataType&
a,
const BDataType& b, CDataType& reg_c)
const
302 static constexpr
bool share_a =
true;
305 template <index_t MPerDpp, index_t NPerDpp,
class ADataType,
class BDataType,
class CDataType>
306 __device__
void run(
const ADataType&
a,
const BDataType& b, CDataType& reg_c)
const
320 template <
typename BaseType, index_t MPerDpp, index_t NPerDpp>
323 template <
typename BaseType_, index_t MPerDpp_, index_t NPerDpp_>
327 constexpr
auto GetDpp<half_t, 8, 32>()
333 constexpr
auto GetDpp<half_t, 8, 16>()
339 constexpr
auto GetDpp<half_t, 16, 16>()
345 constexpr
auto GetDpp<half_t, 32, 8>()
351 constexpr
auto GetDpp<half_t, 1, 32>()
357 constexpr
auto GetDpp<half_t, 2, 32>()
363 constexpr
auto GetDpp<half_t, 2, 16>()
369 constexpr
auto GetDpp<half_t, 4, 16>()
375 constexpr
auto GetDpp<half_t, 4, 32>()
392 constexpr
index_t num_dpp_c_elems =
394 static_assert(num_wave_c_elems % num_dpp_c_elems == 0);
395 static_assert(num_dpp_per_wave == num_wave_c_elems / num_dpp_c_elems);
424 template <
typename BaseType, index_t MPerDpp, index_t NPerDpp, index_t KPack>
439 static_assert(KPack %
dpp_instr.k_per_dpp == 0,
"KPack must be divisible by k_per_dpp.");
444 return MPerDpp * NPerDpp /
dpp_instr.wave_size;
447 template <
class ADataType,
class BDataType,
class CDataType>
449 Run(
const ADataType& p_a_wave,
const BDataType& p_b_wave, CDataType& p_c_thread)
const
454 "base BaseType must be double, float, half, bfloat16, and int8_t!");
457 dpp_instr.template run<MPerDpp, NPerDpp>(p_a_wave[k], p_b_wave[k], p_c_thread);
489 const auto dpp_idx = lanegroup_idx_1d_to_dpp_idx_2d_adaptor.CalculateBottomIndex(
492 const auto m_dpp_idx = dpp_idx[
I0];
493 const auto n_dpp_idx = dpp_idx[
I1];
501 const auto wave_row = laneId /
dpp_instr.n_per_wave;
516 const auto m_dpp_op_idx = dpp_op_idx[
I0];
517 const auto n_dpp_op_idx = dpp_op_idx[
I1];
522 return CIndex{m_offset, n_offset};
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
_Float16 half_t
Definition: data_type.hpp:30
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
DppInstr
Definition: dpp_gemm.hpp:13
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:298
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:52
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1249
Definition: dpp_gemm.hpp:426
__host__ static __device__ auto CalculateBThreadOriginDataIndex_K_N()
Definition: dpp_gemm.hpp:506
__device__ void Run(const ADataType &p_a_wave, const BDataType &p_b_wave, CDataType &p_c_thread) const
Definition: dpp_gemm.hpp:449
static constexpr auto dpp_instr
Definition: dpp_gemm.hpp:527
__host__ static constexpr __device__ auto GetCMNThreadBlkLengths()
Definition: dpp_gemm.hpp:532
__host__ constexpr __device__ DppGemm()
Definition: dpp_gemm.hpp:437
static constexpr auto I3
Definition: dpp_gemm.hpp:430
static constexpr auto I1
Definition: dpp_gemm.hpp:428
static __device__ auto GetWaveId()
Definition: dpp_gemm.hpp:466
static constexpr auto I5
Definition: dpp_gemm.hpp:432
static constexpr __device__ index_t GetRegSizePerDpp()
Definition: dpp_gemm.hpp:442
static __device__ auto GetLaneGroupIdInWave()
Definition: dpp_gemm.hpp:473
static __device__ CIndex GetBeginOfThreadBlk()
Definition: dpp_gemm.hpp:512
static constexpr auto I4
Definition: dpp_gemm.hpp:431
static constexpr auto I2
Definition: dpp_gemm.hpp:429
static __device__ auto GetLaneIdInLaneGroup()
Definition: dpp_gemm.hpp:468
static constexpr auto K1PerDpp
Definition: dpp_gemm.hpp:530
static constexpr auto dpp
Definition: dpp_gemm.hpp:525
__host__ static __device__ auto CalculateAThreadOriginDataIndex_K_M()
Definition: dpp_gemm.hpp:498
static constexpr auto I0
Definition: dpp_gemm.hpp:427
static __device__ auto GetDppOpIdx()
Definition: dpp_gemm.hpp:478
static __device__ auto GetLaneIdInWave()
Definition: dpp_gemm.hpp:461
static constexpr auto K0PerDpp
Definition: dpp_gemm.hpp:529
Definition: dpp_gemm.hpp:322
static constexpr index_t GetK1PerDpp()
Definition: dpp_gemm.hpp:421
static constexpr auto selected_dpp
Definition: dpp_gemm.hpp:380
static constexpr auto GetDpp()
__host__ constexpr __device__ DppSelector()
Definition: dpp_gemm.hpp:382
Definition: sequence.hpp:43
Definition: amd_gemm_dpp.hpp:37
__device__ void Run(const AVecDataType &a_vec, const BVecDataType &b_vec, CVecDataType &c_vec)
Definition: amd_gemm_dpp.hpp:43
__device__ void run(const ADataType &a, const BDataType &b, CDataType ®_c) const
Definition: dpp_gemm.hpp:156
half_t BaseType
Definition: dpp_gemm.hpp:153
__device__ void run(const ADataType &a, const BDataType &b, CDataType ®_c) const
Definition: dpp_gemm.hpp:246
half_t BaseType
Definition: dpp_gemm.hpp:243
half_t BaseType
Definition: dpp_gemm.hpp:303
__device__ void run(const ADataType &a, const BDataType &b, CDataType ®_c) const
Definition: dpp_gemm.hpp:306
__device__ void run(const ADataType &a, const BDataType &b, CDataType ®_c) const
Definition: dpp_gemm.hpp:276
half_t BaseType
Definition: dpp_gemm.hpp:273
half_t BaseType
Definition: dpp_gemm.hpp:63
__device__ void run(const ADataType &a, const BDataType &b, CDataType ®_c) const
Definition: dpp_gemm.hpp:66
__device__ void run(const ADataType &a, const BDataType &b, CDataType ®_c) const
Definition: dpp_gemm.hpp:216
half_t BaseType
Definition: dpp_gemm.hpp:213
__device__ void run(const ADataType &a, const BDataType &b, CDataType ®_c) const
Definition: dpp_gemm.hpp:186
half_t BaseType
Definition: dpp_gemm.hpp:183
__device__ void run(const ADataType &a, const BDataType &b, CDataType ®_c) const
Definition: dpp_gemm.hpp:126
half_t BaseType
Definition: dpp_gemm.hpp:123
half_t BaseType
Definition: dpp_gemm.hpp:93
__device__ void run(const ADataType &a, const BDataType &b, CDataType ®_c) const
Definition: dpp_gemm.hpp:96
Definition: dpp_gemm.hpp:48
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33