82 template <WmmaInstr Instr, index_t WaveSize,
typename =
void>
88 template <index_t WaveSize>
98 static constexpr
index_t src_a_data_size = 2;
99 static constexpr
index_t src_b_data_size = 2;
103 static constexpr
index_t num_thread_per_subgroups = n_per_wmma;
108 static constexpr
index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
109 static constexpr
index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
112 static constexpr
index_t num_acc_vgprs_per_wave =
113 m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
114 static constexpr
index_t num_subgroups = wave_size / num_thread_per_subgroups;
116 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
117 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
119 if constexpr(wave_size == 32)
123 else if constexpr(wave_size == 64)
130 template <index_t WaveSize>
143 static constexpr
index_t num_thread_per_subgroups = n_per_wmma;
147 static constexpr
index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
148 static constexpr
index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
149 static constexpr
index_t num_acc_vgprs_per_wave =
150 m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
151 static constexpr
index_t num_subgroups = wave_size / num_thread_per_subgroups;
153 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
154 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
156 if constexpr(wave_size == 32)
160 else if constexpr(wave_size == 64)
167 template <index_t WaveSize>
180 static constexpr
index_t num_thread_per_subgroups = n_per_wmma;
184 static constexpr
index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
185 static constexpr
index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
186 static constexpr
index_t num_acc_vgprs_per_wave =
187 m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
188 static constexpr
index_t num_subgroups = wave_size / num_thread_per_subgroups;
190 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
191 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
193 if constexpr(wave_size == 32)
197 else if constexpr(wave_size == 64)
203 template <index_t WaveSize>
216 static constexpr
index_t num_thread_per_subgroups = n_per_wmma;
220 static constexpr
index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
221 static constexpr
index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
222 static constexpr
index_t num_acc_vgprs_per_wave =
223 m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
224 static constexpr
index_t num_subgroups = wave_size / num_thread_per_subgroups;
232 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
234 if constexpr(wave_size == 32)
238 else if constexpr(wave_size == 64)
245 template <index_t WaveSize>
258 static constexpr
index_t num_thread_per_subgroups = n_per_wmma;
262 static constexpr
index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
263 static constexpr
index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
264 static constexpr
index_t num_acc_vgprs_per_wave =
265 m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
266 static constexpr
index_t num_subgroups = wave_size / num_thread_per_subgroups;
276 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
278 if constexpr(wave_size == 32)
283 else if constexpr(wave_size == 64)
294 template <index_t WaveSize>
310 static constexpr
index_t num_thread_per_subgroups = n_per_wmma;
319 static constexpr
index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
320 static constexpr
index_t num_subgroups = wave_size / num_thread_per_subgroups;
322 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
323 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
325 static_assert(wave_size == 32,
"only support wave32 for gfx12 wmma");
326 if constexpr(wave_size == 32)
333 template <index_t WaveSize>
346 static constexpr
index_t num_thread_per_subgroups = n_per_wmma;
352 static constexpr
index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
353 static constexpr
index_t num_subgroups = wave_size / num_thread_per_subgroups;
355 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
356 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
358 static_assert(wave_size == 32,
"only support wave32 for gfx12 wmma");
359 if constexpr(wave_size == 32)
366 template <index_t WaveSize>
379 static constexpr
index_t num_thread_per_subgroups = n_per_wmma;
385 static constexpr
index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
386 static constexpr
index_t num_subgroups = wave_size / num_thread_per_subgroups;
396 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
398 static_assert(wave_size == 32,
"only support wave32 for gfx12 wmma");
399 if constexpr(wave_size == 32)
407 template <index_t WaveSize>
418 static constexpr
index_t num_thread_per_subgroups = n_per_wmma;
422 static constexpr
index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
423 static constexpr
index_t num_subgroups = wave_size / num_thread_per_subgroups;
425 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
426 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
428 static_assert(wave_size == 32,
"only support wave32 for gfx12 wmma");
429 if constexpr(wave_size == 32)
442 template <index_t WaveSize>
453 static constexpr
index_t num_thread_per_subgroups = n_per_wmma;
457 static constexpr
index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
458 static constexpr
index_t num_subgroups = wave_size / num_thread_per_subgroups;
460 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
461 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
463 static_assert(wave_size == 32,
"only support wave32 for gfx12 wmma");
464 if constexpr(wave_size == 32)
477 template <index_t WaveSize>
488 static constexpr
index_t num_thread_per_subgroups = n_per_wmma;
492 static constexpr
index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
493 static constexpr
index_t num_subgroups = wave_size / num_thread_per_subgroups;
495 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
496 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
498 static_assert(wave_size == 32,
"only support wave32 for gfx12 wmma");
499 if constexpr(wave_size == 32)
512 template <index_t WaveSize>
523 static constexpr
index_t num_thread_per_subgroups = n_per_wmma;
527 static constexpr
index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
528 static constexpr
index_t num_subgroups = wave_size / num_thread_per_subgroups;
530 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
531 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
533 static_assert(wave_size == 32,
"only support wave32 for gfx12 wmma");
534 if constexpr(wave_size == 32)
547 template <
typename src_type_a,
554 template <
typename src_type_a_,
555 typename src_type_b_,
562 constexpr
auto GetWmma<half_t, half_t, float, 16, 16>()
572 constexpr
auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>()
582 constexpr
auto GetWmma<half_t, half_t, half_t, 16, 16>()
588 constexpr
auto GetWmma<bhalf_t, bhalf_t, bhalf_t, 16, 16>()
594 constexpr
auto GetWmma<int8_t, int8_t, int, 16, 16>()
603 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
605 constexpr
auto GetWmma<int4_t, int4_t, int, 16, 16>()
612 constexpr
auto GetWmma<f8_t, f8_t, float, 16, 16>()
618 constexpr
auto GetWmma<f8_t, bf8_t, float, 16, 16>()
624 constexpr
auto GetWmma<bf8_t, f8_t, float, 16, 16>()
630 constexpr
auto GetWmma<bf8_t, bf8_t, float, 16, 16>()
641 static_assert(
selected_wmma.m_per_wmma == 16,
"WRONG! WMMA_M must equal to 16");
643 static_assert(
selected_wmma.m_per_wmma == 16,
"WRONG! WMMA_M must equal to 16");
645 static_assert(
selected_wmma.k_per_wmma == 16,
"WRONG! WMMA_M must equal to 16");
650 "WRONG! Invalid Number of Accumulator Register");
654 template <
typename src_type_a,
660 bool TransposeC =
false,
661 bool AssemblyBackend =
false>
676 static_assert(NPerWmma == 16 && MPerWmma == 16,
677 "Only support GemmNPerWmma == 16 and GemmMPerWmma == 16 for wmma");
679 static_assert(KPack %
wmma_instr.k_per_wmma == 0,
"KPack should be multiple of k_per_wmma");
685 template <
typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
686 __host__ __device__
static constexpr
auto
688 const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA&
689 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
691 const auto MBlockxRepeat =
692 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(
I0);
693 const auto NBlockxRepeat =
694 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(
I3);
696 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(
I1);
698 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(
I4);
701 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
725 template <
typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
726 __host__ __device__
static constexpr
auto
728 const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA&
729 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
731 const auto MBlockxRepeat =
732 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(
I0);
733 const auto NBlockxRepeat =
734 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(
I3);
736 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(
I1);
738 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(
I4);
741 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
771 template <
class FloatA,
class FloatB,
class FloatC>
772 __device__
void Run(
const FloatA& p_a_wave,
const FloatB& p_b_wave, FloatC& p_c_thread)
const
788 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
793 "base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), "
794 "((f8 or bf8, f8 or bf8), float), (int8, int32) or (int4, int32)!");
796 if constexpr(!TransposeC)
798 wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave[k], p_b_wave[k], p_c_thread);
802 wmma_instr.template run<MPerWmma, NPerWmma>(p_b_wave[k], p_a_wave[k], p_c_thread);
849 return TransposeC ?
CIndex{n_offset, m_offset} :
CIndex{m_offset, n_offset};
864 __host__ __device__
static constexpr
auto
__host__ constexpr __device__ T clamp(const T &x, const T &lowerbound, const T &upperbound)
Definition: math.hpp:148
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:298
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:27
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:52
WmmaInstr
Definition: wmma_gemm.hpp:13
@ wmma_f32_16x16x16_bf16_gfx12
@ wmma_i32_16x16x16_iu8_gfx12
@ wmma_f32_16x16x16_bf8f8_gfx12
@ wmma_f32_16x16x16_f16_gfx12
@ wmma_f32_16x16x16_bf8bf8_gfx12
@ wmma_f32_16x16x16_f8f8_gfx12
@ wmma_bf16_16x16x16_bf16
@ wmma_f32_16x16x16_f8bf8_gfx12
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1249
Definition: sequence.hpp:43
Definition: wmma_gemm.hpp:663
static constexpr auto I0
Definition: wmma_gemm.hpp:664
static __device__ auto GetLaneId()
Definition: wmma_gemm.hpp:807
__device__ void Run(const FloatA &p_a_wave, const FloatB &p_b_wave, FloatC &p_c_thread) const
Definition: wmma_gemm.hpp:772
static constexpr __device__ index_t GetWaveSize()
Definition: wmma_gemm.hpp:769
static constexpr auto wmma
Definition: wmma_gemm.hpp:860
__host__ static constexpr __device__ auto GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()
Definition: wmma_gemm.hpp:865
__host__ static __device__ auto CalculateAThreadOriginDataIndex()
Definition: wmma_gemm.hpp:826
static __device__ auto GetSubGroupId()
Definition: wmma_gemm.hpp:809
static __device__ auto GetSwizzledLaneIdLow()
Definition: wmma_gemm.hpp:821
static constexpr auto I3
Definition: wmma_gemm.hpp:667
static constexpr auto I5
Definition: wmma_gemm.hpp:669
__host__ static __device__ auto CalculateBThreadOriginDataIndex()
Definition: wmma_gemm.hpp:835
__host__ static constexpr __device__ auto MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA &c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
Definition: wmma_gemm.hpp:727
static __device__ CIndex GetBeginOfThreadBlk()
Definition: wmma_gemm.hpp:844
static constexpr auto I4
Definition: wmma_gemm.hpp:668
static constexpr __device__ index_t GetRegSizePerWmma()
Definition: wmma_gemm.hpp:764
__host__ static constexpr __device__ auto MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA &c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
Definition: wmma_gemm.hpp:687
__host__ constexpr __device__ WmmaGemm()
Definition: wmma_gemm.hpp:674
static constexpr auto I2
Definition: wmma_gemm.hpp:666
static __device__ CIndex3D GetBeginOfThreadBlk3D()
Definition: wmma_gemm.hpp:852
static constexpr auto I1
Definition: wmma_gemm.hpp:665
static __device__ auto GetLaneIdUnderSubGroup()
Definition: wmma_gemm.hpp:817
static constexpr auto wmma_instr
Definition: wmma_gemm.hpp:862
Definition: wmma_gemm.hpp:553
static constexpr auto selected_wmma
Definition: wmma_gemm.hpp:636
__host__ constexpr __device__ WmmaSelector()
Definition: wmma_gemm.hpp:639
static constexpr auto GetWmma()
Definition: integral_constant.hpp:20
Definition: amd_wmma.hpp:96
Definition: amd_wmma.hpp:216
Definition: amd_wmma.hpp:72
Definition: amd_wmma.hpp:192
Definition: amd_wmma.hpp:297
Definition: amd_wmma.hpp:50
Definition: amd_wmma.hpp:170
Definition: amd_wmma.hpp:418
Definition: amd_wmma.hpp:394
Definition: amd_wmma.hpp:271
Definition: amd_wmma.hpp:25
Definition: amd_wmma.hpp:149
Definition: amd_wmma.hpp:370
Definition: amd_wmma.hpp:346
Definition: amd_wmma.hpp:319
Definition: amd_wmma.hpp:121
Definition: amd_wmma.hpp:241
Definition: functional2.hpp:33
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:232
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:191
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:154
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:356
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:531
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:496
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:117
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:323
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:461
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:426
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:276
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:396
Definition: wmma_gemm.hpp:84