typename ComputeTypeA,
typename ComputeTypeB,
typename AWmmaTileDesc,
typename BWmmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
bool TransposeC = false,
bool BSkipLDS   = false>
// ...
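// The parameters below belong to the partial specialization of
// BlockwiseGemmWmmaops_pipeline_v3 for BlockGemmPipelineScheduler::Intrawave
// (with BSkipLDS = false); they mirror the primary template above.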
typename ComputeTypeA,
typename ComputeTypeB,
typename AWmmaTileDesc,
typename BWmmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
// ...
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
// ...
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
// ...
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
// ...
using Base::wmma_gemm;
// ...
using Base::CalculateCThreadOriginDataIndex;
using Base::
    GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
using Base::GetCThreadBuffer;
using Base::
    GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
using Base::
    GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs;

using Base::a_block_desc_k0_m0_m1_m2_k1;
using Base::b_block_desc_k0_n0_n1_n2_k1;

using typename Base::Empty;
// ...
__host__ static constexpr __device__ bool BlockHasHotloop(index_t num_loop)
{
    return num_loop > PrefetchStages;
}
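// BlockLoopTailNum(): pick the TailNumber variant used to drain the loop tail.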
__host__ static constexpr __device__ TailNumber BlockLoopTailNum(index_t num_loop)
{
    if(BlockHasHotloop(num_loop))
        // ...
}

// ...
template <typename ABlockBuffer,
          typename AThreadBuffer,
          typename BBlockBuffer,
          typename BThreadBuffer,
          typename BScaleStruct>
__device__ inline void LocalLoad(ABlockBuffer& a_block_buf,
                                 AThreadBuffer& a_thread_buf,
                                 BBlockBuffer& b_block_buf,
                                 BThreadBuffer& b_thread_buf,
                                 BScaleStruct& b_scale_struct) const
{
    // ...
    a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
                       // ...

    if constexpr(ck::is_same_v<BScaleStruct, Empty>)
    {
        // copy B from LDS to VGPRs without scaling
        b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
                           // ...
    }
    else
    {
        // copy B from LDS to VGPRs with the per-block scale passed to the transfer
        b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
                           // ...
                           b_scale_struct.scale_thread_bufs(
                               I0)[Number<n0 * BScaleStruct::num_scale_k_block +
                                          k0 / BScaleStruct::num_scale_krepeat>{}],
                           // ...
    }
    // ...
}
template <bool HasMainLoop,
          // ...
          typename AGridDesc,
          typename ABlockDesc,
          typename ABlockTransfer,
          typename AGridBuffer,
          typename ABlockBuffer,
          typename ABlockTransferStep,
          typename BGridDesc,
          typename BBlockDesc,
          typename BBlockTransfer,
          typename BGridBuffer,
          typename BBlockBuffer,
          typename BBlockTransferStep,
          typename CThreadBuffer,
          typename AScaleStruct,
          typename BScaleStruct,
          // ...
__device__ void Run(const AGridDesc& a_grid_desc,
                    const ABlockDesc& a_block_desc,
                    ABlockTransfer& a_blockwise_copy,
                    const AGridBuffer& a_grid_buf,
                    ABlockBuffer& a_block_buf,
                    const ABlockTransferStep& a_block_copy_step,
                    const BGridDesc& b_grid_desc,
                    const BBlockDesc& b_block_desc,
                    BBlockTransfer& b_blockwise_copy,
                    const BGridBuffer& b_grid_buf,
                    BBlockBuffer& b_block_buf,
                    const BBlockTransferStep& b_block_copy_step,
                    CThreadBuffer& c_thread_buf,
                    AScaleStruct&,
                    BScaleStruct& b_scale_struct,
                    index_t num_loop,
                    index_t num_loop_per_scale) const
{
    __builtin_amdgcn_sched_barrier(0);

    constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();

    auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
        a_thread_desc_.GetElementSpaceSize());
    auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
        b_thread_desc_.GetElementSpaceSize());
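    // Global prefetch 1: issue the first A/B reads from global memory and
    // advance the source slice windows.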
    a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
    b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

    a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
    b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
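    // Load the first B-scale block, then prefill LDS with the first A/B tiles.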
    b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);

    a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
    b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
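    // Global prefetch 2: start the next A/B reads so they overlap with compute
    // on the tiles just written to LDS.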
    a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
    b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

    a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
    b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
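    // Initialize the accumulator and preload the first K slice from LDS into
    // the per-thread A/B buffers.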
    c_thread_buf.Clear();

    // ...
    LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);

    __builtin_amdgcn_sched_barrier(0);
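    // Hot (main) loop: each iteration writes the prefetched tiles to LDS,
    // issues the next global reads and scale loads, runs the WMMA compute on
    // the previously loaded thread tiles, and preloads the next K slice.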
    if constexpr(HasMainLoop)
    {
        index_t i = 0;
        do
        {
            // ...
            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
            b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

            a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
            b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

            b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);

            // ...
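            // For each output repeat (m0, n0) and K step, gather KPack-wide A/B
            // fragments from the thread buffers and issue wmma_gemm on them.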
            vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
            vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;

            static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
                constexpr index_t kk = ik + k_inner * KPerWaveBlock;
                a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                        // ...
            });

            static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
                constexpr index_t kk = ik + k_inner * KPerWaveBlock;
                b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                        // ...
            });

            using wmma_input_type_a =
                typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
            using wmma_input_type_b =
                typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

            // ...
                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

            // ... (wmma_gemm.Run)
                a_thread_vec.template AsType<wmma_input_type_a>(),
                b_thread_vec.template AsType<wmma_input_type_b>(),
                // ...

            LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);

            __builtin_amdgcn_sched_barrier(0);

            i += 1;
        } while(i < (num_loop - 2));
    }

    // ...
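    // Tail: the remaining prefetched tiles are written to LDS and computed
    // outside the hot loop (selected by the TailNumber variant).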
    a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
    b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

    b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);

    // ...
    vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
    vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;

    static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
        constexpr index_t kk = ik + k_inner * KPerWaveBlock;
        a_thread_vec.template AsType<ComputeTypeA>()(ik) =
            a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                // ...
    });

    static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
        constexpr index_t kk = ik + k_inner * KPerWaveBlock;
        b_thread_vec.template AsType<ComputeTypeB>()(ik) =
            b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                // ...
    });

    using wmma_input_type_a =
        typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
    using wmma_input_type_b =
        typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

    // ...
        c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

    wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                  b_thread_vec.template AsType<wmma_input_type_b>(),
                  // ...

    // ...
    LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);

    __builtin_amdgcn_sched_barrier(0);

    // ...
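    // Final tail compute on the tile just loaded from LDS.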
    vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
    vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;

    static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
        constexpr index_t kk = ik + k_inner * KPerWaveBlock;
        a_thread_vec.template AsType<ComputeTypeA>()(ik) =
            a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                // ...
    });

    static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
        constexpr index_t kk = ik + k_inner * KPerWaveBlock;
        b_thread_vec.template AsType<ComputeTypeB>()(ik) =
            b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                // ...
    });

    using wmma_input_type_a =
        typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
    using wmma_input_type_b =
        typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

    // ...
        c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

    wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                  b_thread_vec.template AsType<wmma_input_type_b>(),
                  // ...

    // ...
}
template <bool HasMainLoop,
          // ...
          typename AGridDesc,
          typename ABlockDesc,
          typename ABlockTransfer,
          typename AGridBuffer,
          typename ABlockBuffer,
          typename ABlockTransferStep,
          typename BGridDesc,
          typename BBlockDesc,
          typename BBlockTransfer,
          typename BGridBuffer,
          typename BBlockBuffer,
          typename BBlockTransferStep,
          typename CThreadBuffer,
          typename AScaleStruct,
          typename BScaleStruct,
          // ... (this overload is enabled only when the scale structs are not Empty)
          !ck::is_same_v<BScaleStruct, Empty>,
          // ...
__device__ void Run(const AGridDesc& a_grid_desc,
                    const ABlockDesc& a_block_desc,
                    ABlockTransfer& a_blockwise_copy,
                    const AGridBuffer& a_grid_buf,
                    ABlockBuffer& a_block_buf,
                    const ABlockTransferStep& a_block_copy_step,
                    const BGridDesc& b_grid_desc,
                    const BBlockDesc& b_block_desc,
                    BBlockTransfer& b_blockwise_copy,
                    const BGridBuffer& b_grid_buf,
                    BBlockBuffer& b_block_buf,
                    const BBlockTransferStep& b_block_copy_step,
                    CThreadBuffer& c_thread_buf,
                    AScaleStruct& a_scale_struct,
                    BScaleStruct& b_scale_struct,
                    index_t num_loop,
                    index_t num_loop_per_scale) const
{
    __builtin_amdgcn_sched_barrier(0);

    constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
    static constexpr auto NumScaleKBlock =
        // ...

    auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
        a_thread_desc_.GetElementSpaceSize());
    auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
        b_thread_desc_.GetElementSpaceSize());

    using CScaleStruct = typename Base::template CScale<AScaleStruct, BScaleStruct>;
    auto c_scale_struct = CScaleStruct{};
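    // Helper that accumulates WMMA partial results per scale block
    // (c_thread_buf_per_scale) and folds them into c_thread_buf via
    // UpdateCThreadBuf, presumably applying the combined A/B scales.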
    a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
    b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

    a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
    b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
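    // Prefetch the first A/B scale blocks, combine them in the C-scale helper,
    // prefill LDS with the first tiles, and start the second global prefetch.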
    a_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
    b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);

    c_scale_struct.Load(a_scale_struct, b_scale_struct);

    a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
    b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

    a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
    b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

    a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
    b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

    c_thread_buf.Clear();

    // LDS -> VGPR load of the A and B tiles (arguments elided)
    auto local_load_func = [&]() {
        // ...
        a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
                           // ...
        b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
                           // ...
    };

    // ...
    __builtin_amdgcn_sched_barrier(0);
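    // Hot loop of the A- and B-scaled pipeline: same structure as the
    // B-scale-only overload above, but partial results are accumulated per
    // scale block before being merged into c_thread_buf.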
    if constexpr(HasMainLoop)
    {
        index_t i = 0;
        do
        {
            // ...
            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
            b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

            a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
            b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

            a_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
            b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);

            // ...
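            // For each scale block kscale0, accumulate the WMMA results of
            // KRepeat / NumScaleKBlock K-steps into the per-scale buffer, then
            // fold them into c_thread_buf and load the next A/B scales.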
            c_scale_struct.Clear();

            static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
                // ...
                vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
                vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;

                static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
                    constexpr index_t kk = ik + k_inner * KPerWaveBlock;
                    // ...
                        kscale0 * (KRepeat / NumScaleKBlock) + k0;
                    a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                        a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                            // ...
                });

                static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
                    constexpr index_t kk = ik + k_inner * KPerWaveBlock;
                    // ...
                        kscale0 * (KRepeat / NumScaleKBlock) + k0;
                    b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                        b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                            // ...
                });

                using wmma_input_type_a =
                    typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
                using wmma_input_type_b =
                    typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

                wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                              b_thread_vec.template AsType<wmma_input_type_b>(),
                              c_scale_struct.c_thread_buf_per_scale
                                  // ...
            });

            c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);

            // ...
            c_scale_struct.Load(a_scale_struct, b_scale_struct);

            // ...
            __builtin_amdgcn_sched_barrier(0);

            i += 1;
        } while(i < (num_loop - 2));
    }

    // ...
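    // Tail of the scaled pipeline: drain the remaining prefetched tiles.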
    a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
    b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

    a_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
    b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);

    // ...
    c_scale_struct.Clear();

    static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
        // ...
        vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
        vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;

        static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
            constexpr index_t kk = ik + k_inner * KPerWaveBlock;
            // ...
                kscale0 * (KRepeat / NumScaleKBlock) + k0;
            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                    // ...
        });

        static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
            constexpr index_t kk = ik + k_inner * KPerWaveBlock;
            // ...
                kscale0 * (KRepeat / NumScaleKBlock) + k0;
            b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                    // ...
        });

        using wmma_input_type_a =
            typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
        using wmma_input_type_b =
            typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

        // ... (wmma_gemm.Run)
            a_thread_vec.template AsType<wmma_input_type_a>(),
            b_thread_vec.template AsType<wmma_input_type_b>(),
            c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(
                // ...
    });

    c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);

    // ...
    c_scale_struct.Load(a_scale_struct, b_scale_struct);

    // ...
    __builtin_amdgcn_sched_barrier(0);

    // ...
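    // Final tail compute of the scaled pipeline.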
    c_scale_struct.Clear();

    static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
        vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
        vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;

        static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
            constexpr index_t kk = ik + k_inner * KPerWaveBlock;
            // ...
                kscale0 * (KRepeat / NumScaleKBlock) + k0;
            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                    // ...
        });

        static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
            constexpr index_t kk = ik + k_inner * KPerWaveBlock;
            // ...
                kscale0 * (KRepeat / NumScaleKBlock) + k0;
            b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                    // ...
        });

        using wmma_input_type_a =
            typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
        using wmma_input_type_b =
            typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

        // ... (wmma_gemm.Run)
            a_thread_vec.template AsType<wmma_input_type_a>(),
            b_thread_vec.template AsType<wmma_input_type_b>(),
            c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(
                // ...
    });

    c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);

    // ...
}
using Base::a_thread_copy_;
using Base::a_thread_desc_;
using Base::b_thread_copy_;
using Base::b_thread_desc_;
using Base::c_thread_desc_;