31 bool AEnableLds =
true,
32 bool BEnableLds =
true,
33 bool TransposeC =
false>
52 struct BlockwiseGemmWMMA
54 static constexpr
auto I0 = Number<0>{};
55 static constexpr
auto I1 = Number<1>{};
56 static constexpr
auto I2 = Number<2>{};
57 static constexpr
auto I3 = Number<3>{};
58 static constexpr
auto I4 = Number<4>{};
59 static constexpr
auto I5 = Number<5>{};
60 static constexpr
auto WmmaK = Number<16>{};
77 WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};
79 static constexpr
index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
80 static constexpr
index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
100 return threadid_to_wave_idx_adaptor.CalculateBottomIndex(
make_multi_index(thread_id));
106 if constexpr(AEnableLds)
109 const auto waveId_m = wave_idx[
I0];
110 const auto WMMA_a_idx =
wmma_gemm.CalculateAThreadOriginDataIndex();
123 if constexpr(BEnableLds)
126 const auto waveId_n = wave_idx[
I1];
127 const auto WMMA_b_idx =
wmma_gemm.CalculateBThreadOriginDataIndex();
138 template <index_t m0, index_t n0>
143 const auto waveId_m = wave_idx[
I0];
144 const auto waveId_n = wave_idx[
I1];
146 const auto blk_idx =
wmma_gemm.GetBeginOfThreadBlk();
158 const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex(
160 const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex(
166 template <index_t m0, index_t n0>
171 const auto waveId_m = wave_idx[
I0];
172 const auto waveId_n = wave_idx[
I1];
174 const auto blk_idx =
wmma_gemm.GetBeginOfThreadBlk3D();
177 Number<m0>{}, waveId_m, blk_idx[
I0], Number<n0>{}, waveId_n, blk_idx[
I1], blk_idx[
I2]);
185 static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
186 "wrong! Desc should be known at compile-time");
189 "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
191 static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
192 NPerBlock % (NPerWMMA * NRepeat) == 0,
197 __host__ __device__
static constexpr
auto
200 constexpr
auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
201 wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
203 constexpr
auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[
I2];
212 __host__ __device__
static constexpr
auto
215 constexpr
auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
216 wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
218 constexpr
auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[
I2];
219 constexpr
auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[
I3];
224 make_tuple(Number<NRepeat>{} * MAccVgprs * AccStride,
225 Number<NRepeat>{} * MAccVgprs * AccStride,
226 Number<NRepeat>{} * MAccVgprs * AccStride,
227 MAccVgprs * AccStride,
228 MAccVgprs * AccStride,
229 MAccVgprs * AccStride,
233 template <
typename CGr
idDesc_M_N>
234 __host__ __device__
static constexpr
auto
236 const CGridDesc_M_N& c_grid_desc_m_n)
238 const auto M = c_grid_desc_m_n.GetLength(
I0);
239 const auto N = c_grid_desc_m_n.GetLength(
I1);
241 const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma =
248 make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
251 .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
252 c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
256 __host__ __device__
static constexpr
auto
259 constexpr
auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
265 Number<NPerWMMA>{}));
268 .MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(
269 c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
273 __host__ __device__
static constexpr
auto
276 constexpr
auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
282 Number<NPerWMMA>{}));
285 .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
286 c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
294 template <
typename ABlockBuffer,
typename BBlockBuffer,
typename CThreadBuffer>
295 __device__
void Run(
const ABlockBuffer& a_block_buf,
296 const BBlockBuffer& b_block_buf,
297 CThreadBuffer& c_thread_buf)
const
299 auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
301 auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
304 static_assert(KPack % (
A_K1 *
A_KRow) == 0,
"");
305 static_assert(KPack % (
B_K1 *
B_KRow) == 0,
"");
308 if constexpr(MRepeat < NRepeat)
310 static_for<0, KPerBlock / KPack, 1>{}(
312 static_for<0, MRepeat, 1>{}([&](
auto m0) {
322 static_for<0, NRepeat, 1>{}([&](
auto n0) {
332 vector_type<FloatA, KPack /
A_KRow> a_thread_vec;
333 vector_type<FloatB, KPack /
B_KRow> b_thread_vec;
335 static_for<0, KPack /
A_KRow, 1>{}([&](
auto i) {
336 a_thread_vec.template AsType<FloatA>()(i) =
341 static_for<0, KPack /
B_KRow, 1>{}([&](
auto i) {
342 b_thread_vec.template AsType<FloatB>()(i) =
347 using wmma_input_type_a =
348 typename vector_type<FloatA,
WmmaK /
A_KRow>::type;
349 using wmma_input_type_b =
350 typename vector_type<FloatB,
WmmaK /
B_KRow>::type;
356 a_thread_vec.template AsType<wmma_input_type_a>(),
357 b_thread_vec.template AsType<wmma_input_type_b>(),
358 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
365 static_for<0, NRepeat, 1>{}([&](
auto n0) {
366 static_for<0, MRepeat, 1>{}([&](
auto m0) {
367 static_for<0, KPerBlock / KPack, 1>{}([&](
auto k) {
386 vector_type<FloatA, KPack /
A_KRow> a_thread_vec;
387 vector_type<FloatB, KPack /
B_KRow> b_thread_vec;
389 static_for<0, KPack /
A_KRow, 1>{}([&](
auto i) {
390 a_thread_vec.template AsType<FloatA>()(i) =
395 static_for<0, KPack /
B_KRow, 1>{}([&](
auto i) {
396 b_thread_vec.template AsType<FloatB>()(i) =
401 using wmma_input_type_a =
402 typename vector_type<FloatA,
WmmaK /
A_KRow>::type;
403 using wmma_input_type_b =
404 typename vector_type<FloatB,
WmmaK /
B_KRow>::type;
410 a_thread_vec.template AsType<wmma_input_type_a>(),
411 b_thread_vec.template AsType<wmma_input_type_b>(),
412 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
421 make_tuple(Number<KPack / A_K1 / A_KRow>{}, Number<MRepeat>{},
I1,
I1,
I1, Number<A_K1>{}),
430 make_tuple(Number<KPack / B_K1 / B_KRow>{}, Number<NRepeat>{},
I1,
I1,
I1, Number<B_K1>{}),
442 template <
bool EnableLds>
443 struct AThreadCopySelector;
446 struct AThreadCopySelector<true>
449 ThreadwiseTensorSliceTransfer_v4<FloatA,
454 Sequence<0, 1, 2, 3, 4, 5>,
461 struct AThreadCopySelector<false>
463 using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow<
468 tensor_operation::element_wise::PassThrough,
470 Sequence<0, 1, 2, 3, 4, 5>,
476 template <
bool EnableLds>
477 struct BThreadCopySelector;
480 struct BThreadCopySelector<true>
483 ThreadwiseTensorSliceTransfer_v4<FloatB,
488 Sequence<0, 1, 2, 3, 4, 5>,
495 struct BThreadCopySelector<false>
497 using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow<
502 tensor_operation::element_wise::PassThrough,
504 Sequence<0, 1, 2, 3, 4, 5>,
528 bool AEnableLds =
true,
529 bool BEnableLds =
true,
530 bool TransposeC =
false>
596 return threadid_to_wave_idx_adaptor.CalculateBottomIndex(
make_multi_index(thread_id));
602 if constexpr(AEnableLds)
605 const auto waveId_m = wave_idx[
I0];
606 const auto WMMA_a_idx =
wmma_gemm.CalculateAThreadOriginDataIndex();
609 return make_tuple(0, 0, waveId_m, 0, WMMA_a_idx, 0);
619 if constexpr(BEnableLds)
622 const auto waveId_n = wave_idx[
I1];
623 const auto WMMA_b_idx =
wmma_gemm.CalculateBThreadOriginDataIndex();
626 return make_tuple(0, 0, waveId_n, 0, WMMA_b_idx, 0);
634 template <index_t m0, index_t n0>
639 const auto waveId_m = wave_idx[
I0];
640 const auto waveId_n = wave_idx[
I1];
642 const auto blk_idx =
wmma_gemm.GetBeginOfThreadBlk();
654 const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex(
656 const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex(
662 template <index_t m0, index_t n0>
667 const auto waveId_m = wave_idx[
I0];
668 const auto waveId_n = wave_idx[
I1];
670 const auto blk_idx =
wmma_gemm.GetBeginOfThreadBlk3D();
681 static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
682 "wrong! Desc should be known at compile-time");
685 "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
687 static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
688 NPerBlock % (NPerWMMA * NRepeat) == 0,
693 __host__ __device__
static constexpr
auto
696 constexpr
auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
697 wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
699 constexpr
auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[
I2];
708 __host__ __device__
static constexpr
auto
711 constexpr
auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
712 wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
714 constexpr
auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[
I2];
715 constexpr
auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[
I3];
723 MAccVgprs * AccStride,
724 MAccVgprs * AccStride,
725 MAccVgprs * AccStride,
729 template <
typename CGr
idDesc_M_N>
730 __host__ __device__
static constexpr
auto
732 const CGridDesc_M_N& c_grid_desc_m_n)
734 const auto M = c_grid_desc_m_n.GetLength(
I0);
735 const auto N = c_grid_desc_m_n.GetLength(
I1);
737 const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma =
747 .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
748 c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
752 __host__ __device__
static constexpr
auto
755 constexpr
auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
764 .MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(
765 c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
769 __host__ __device__
static constexpr
auto
772 constexpr
auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
781 .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
782 c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
790 template <
typename ABlockBuffer,
typename BBlockBuffer,
typename CThreadBuffer>
791 __device__
void Run(
const ABlockBuffer& a_block_buf,
792 const BBlockBuffer& b_block_buf,
793 CThreadBuffer& c_thread_buf)
const
795 auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
797 auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
801 if constexpr(MRepeat < NRepeat)
829 a_thread_vec.template AsType<FloatA>()(i) =
837 b_thread_vec.template AsType<FloatB>()(i) =
853 wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
854 b_thread_vec.template AsType<wmma_input_type_b>(),
864 static_for<0, KPerBlock / KPack, 1>{}([&](
auto k) {
887 b_thread_vec.template AsType<FloatB>()(i) =
895 a_thread_vec.template AsType<FloatA>()(i) =
911 wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
912 b_thread_vec.template AsType<wmma_input_type_b>(),
930 Number<A_K1 * A_KRow>{},
944 Number<B_K1 * B_KRow>{},
953 template <
bool EnableLds>
986 TransposeC ? false :
true>;
989 template <
bool EnableLds>
1022 TransposeC ? true :
false>;
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:298
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
integral_constant< index_t, N > Number
Definition: number.hpp:12
Definition: blockwise_gemm_wmma.hpp:954
Definition: blockwise_gemm_wmma.hpp:990
Definition: blockwise_gemm_wmma.hpp:550
__host__ static constexpr __device__ auto MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(const CGridDesc_M_N &c_grid_desc_m_n)
Definition: blockwise_gemm_wmma.hpp:731
__host__ static constexpr __device__ auto GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
Definition: blockwise_gemm_wmma.hpp:709
StaticBufferTupleOfVector< AddressSpaceEnum::Vgpr, FloatAcc, MRepeat *NRepeat, wmma_gemm.GetRegSizePerWmma(), true > c_thread_buf_
Definition: blockwise_gemm_wmma.hpp:583
static constexpr index_t NWaves
Definition: blockwise_gemm_wmma.hpp:576
static constexpr index_t A_KRow
Definition: blockwise_gemm_wmma.hpp:567
static constexpr auto b_thread_desc_
Definition: blockwise_gemm_wmma.hpp:935
static constexpr index_t B_K1
Definition: blockwise_gemm_wmma.hpp:570
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: blockwise_gemm_wmma.hpp:559
__device__ void Run(const ABlockBuffer &a_block_buf, const BBlockBuffer &b_block_buf, CThreadBuffer &c_thread_buf) const
Definition: blockwise_gemm_wmma.hpp:791
static constexpr index_t A_K1
Definition: blockwise_gemm_wmma.hpp:569
static constexpr auto I0
Definition: blockwise_gemm_wmma.hpp:551
static constexpr auto I5
Definition: blockwise_gemm_wmma.hpp:556
static constexpr index_t B_KRow
Definition: blockwise_gemm_wmma.hpp:568
__host__ static constexpr __device__ auto GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
Definition: blockwise_gemm_wmma.hpp:770
BThreadCopySelector< BEnableLds >::type b_thread_copy_
Definition: blockwise_gemm_wmma.hpp:1026
__host__ constexpr __device__ auto & GetCThreadBuffer()
Definition: blockwise_gemm_wmma.hpp:585
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >)
Definition: blockwise_gemm_wmma.hpp:635
static constexpr auto I1
Definition: blockwise_gemm_wmma.hpp:552
static constexpr auto I3
Definition: blockwise_gemm_wmma.hpp:554
static constexpr index_t MWaves
Definition: blockwise_gemm_wmma.hpp:575
decltype(CalculateAThreadOriginDataIndex()) Tuple6
Definition: blockwise_gemm_wmma.hpp:676
static constexpr auto a_thread_desc_
Definition: blockwise_gemm_wmma.hpp:921
static constexpr index_t WaveSize
Definition: blockwise_gemm_wmma.hpp:562
static __device__ auto CalculateAThreadOriginDataIndex()
Definition: blockwise_gemm_wmma.hpp:600
__host__ static constexpr __device__ auto GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
Definition: blockwise_gemm_wmma.hpp:753
static constexpr auto c_thread_desc_
Definition: blockwise_gemm_wmma.hpp:950
static constexpr auto WmmaK
Definition: blockwise_gemm_wmma.hpp:557
static constexpr auto I4
Definition: blockwise_gemm_wmma.hpp:555
static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1
Definition: blockwise_gemm_wmma.hpp:787
static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1
Definition: blockwise_gemm_wmma.hpp:788
static __device__ auto GetWaveIdx()
Definition: blockwise_gemm_wmma.hpp:587
__host__ static constexpr __device__ auto GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
Definition: blockwise_gemm_wmma.hpp:694
static constexpr auto wmma_gemm
Definition: blockwise_gemm_wmma.hpp:572
static __device__ auto CalculateBThreadOriginDataIndex()
Definition: blockwise_gemm_wmma.hpp:617
__host__ __device__ BlockwiseGemmWMMA(Tuple6 a_origin=CalculateAThreadOriginDataIndex(), Tuple6 b_origin=CalculateBThreadOriginDataIndex())
Definition: blockwise_gemm_wmma.hpp:677
static constexpr auto I2
Definition: blockwise_gemm_wmma.hpp:553
AThreadCopySelector< AEnableLds >::type a_thread_copy_
Definition: blockwise_gemm_wmma.hpp:1025
static __device__ auto CalculateCThreadOriginDataIndex7D(Number< m0 >, Number< n0 >)
Definition: blockwise_gemm_wmma.hpp:663
Definition: sequence.hpp:43
Definition: static_buffer.hpp:75
static __device__ index_t GetThreadId()
Definition: thread_group.hpp:19
static constexpr __device__ index_t GetNumOfThread()
Definition: thread_group.hpp:15
Definition: threadwise_tensor_slice_transfer.hpp:1877
Definition: threadwise_tensor_slice_transfer.hpp:1260
Definition: wmma_gemm.hpp:663
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33
Definition: unary_element_wise_operation.hpp:334
Definition: dtype_vector.hpp:10