// blockwise_gemm_pipeline_wmmaops_v1.hpp (fragmentary listing; elided lines are
// marked "// ...")

// Primary template (fragment of its parameter list; AccDataType restored from
// the generated index):
// ...
          typename ComputeTypeA,
          typename ComputeTypeB,
          typename AccDataType,
          typename AWmmaTileDesc,
          typename BWmmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
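
// Specialization for BlockGemmPipelineScheduler::Intrawave. Only a fragment of
// its template-parameter list survives below; per the generated index the full
// list also carries BlockSize, ADataType, BDataType, MPerBlock/NPerBlock/
// KPerBlock, MPerWmma/NPerWmma, MRepeat/NRepeat and KPack.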
// ...
          typename ComputeTypeA,
          typename ComputeTypeB,
          typename AccDataType,
          typename AWmmaTileDesc,
          typename BWmmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          // ...
// ... (base-class template arguments; three occurrences survive)
            ABlockTransferSrcScalarPerVector,
            BBlockTransferSrcScalarPerVector,
            // ...
            ABlockTransferSrcScalarPerVector,
            BBlockTransferSrcScalarPerVector,
            // ...
            ABlockTransferSrcScalarPerVector,
            BBlockTransferSrcScalarPerVector,
// ...
    using Base::wmma_gemm;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::
        GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
    using Base::GetCThreadBuffer;
    using Base::
        GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;

    using Base::a_block_desc_k0_m0_m1_m2_k1;
    using Base::b_block_desc_k0_n0_n1_n2_k1;
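
    // Run() drives the full pipeline: global->LDS prefetch, B-scale loads, the
    // WMMA hot loop, and a tail GEMM on the last prefetched tile.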
    // AGridDesc/ABlockDesc and BGridDesc/BBlockDesc are elided in this listing
    // and restored here from the Run() signature; one further elided parameter
    // is marked "// ...".
    template <bool HasMainLoop,
              // ...
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer,
              typename BScaleStruct>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        BScaleStruct& b_scale_struct,
                        index_t num_loop, // restored from the generated index
                        index_t num_loop_per_scale) const
    {
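        // Per-thread VGPR buffers holding the A/B fragments fed to wmma_gemm.Run().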
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
            b_thread_desc_.GetElementSpaceSize());
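
        // Global prefetch: issue the first A/B tile loads, then advance the
        // source windows.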
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
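
        // Prefetch the first block of B dequantization scales.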
        b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
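
        // Local prefill: commit the prefetched A/B tiles to LDS.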
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        c_thread_buf.Clear();
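
        // One K-tile of block-level GEMM: copy A/B fragments from LDS into VGPRs,
        // apply the per-block B scale, and issue the WMMA instructions.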
        auto blockwise_gemm_func = [&]() {
            // ... (static_for loops over the K/M/N repeats elided; the surviving
            //      fragments are the LDS-side descriptors passed to the thread
            //      copies and the B-scale operand)
                    a_block_desc_k0_m0_m1_m2_k1,
                    // ...
                    b_block_desc_k0_n0_n1_n2_k1,
                    // ...
                    b_block_desc_k0_n0_n1_n2_k1,
                    // ...
                    b_scale_struct.b_scale_thread_bufs(
                        I0)[Number<n0 * BScaleStruct::num_scale_k_block +
                                   k0 / BScaleStruct::num_scale_krepeat>{}],
                    // ...
                    vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
                    vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;

                    static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
                        a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                            // ... (read from a_thread_buf elided)
                    });

                    static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
                        b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                            // ... (read from b_thread_buf elided)
                    });

                    using wmma_input_type_a =
                        typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
                    using wmma_input_type_b =
                        typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

                    // c_offset declaration reconstructed from the surrounding pattern:
                    constexpr index_t c_offset =
                        c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

                    wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                                  b_thread_vec.template AsType<wmma_input_type_b>(),
                                  // accumulator argument reconstructed; elided in this listing:
                                  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
            // ... (closing braces of the static_for loops elided)
        };
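
        // Hot loop: overlap the next global prefetch with the GEMM on the tile
        // already resident in LDS.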
        if constexpr(HasMainLoop)
        {
            // Loop scaffolding (index declaration, "do {") reconstructed; it is
            // elided in this listing.
            index_t i = 0;
            do
            {
                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                block_sync_lds(); // reconstructed sync point
                blockwise_gemm_func();
                block_sync_lds(); // reconstructed sync point

                b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                i += 1;
            } while(i < (num_loop - 1));
        }

        // Tail: one more block GEMM on the last prefetched tile.
        // ...
        blockwise_gemm_func();
    }
    // ...
    using Base::a_thread_copy_;
    using Base::a_thread_desc_;
    using Base::b_thread_copy_;
    using Base::b_thread_desc_;
    using Base::c_thread_desc_;
};
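
// Specialization for BlockGemmPipelineScheduler::Interwave. Same data flow as
// the Intrawave version above, but blockwise_gemm_func() adds explicit
// scheduling (__builtin_amdgcn_sched_barrier / s_setprio / s_barrier) to
// interleave the MAC clusters of concurrent waves.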
// ...
          typename ComputeTypeA,
          typename ComputeTypeB,
          typename AccDataType,
          typename AWmmaTileDesc,
          typename BWmmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          // ...

// ... (base-class template arguments; three occurrences survive)
            ABlockTransferSrcScalarPerVector,
            BBlockTransferSrcScalarPerVector,
            // ...
            ABlockTransferSrcScalarPerVector,
            BBlockTransferSrcScalarPerVector,
            // ...
            ABlockTransferSrcScalarPerVector,
            BBlockTransferSrcScalarPerVector,
// ...
    using Base::wmma_gemm;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::
        GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
    using Base::GetCThreadBuffer;
    using Base::
        GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;

    using Base::a_block_desc_k0_m0_m1_m2_k1;
    using Base::b_block_desc_k0_n0_n1_n2_k1;
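
    // Run() is structurally identical to the Intrawave version; the differences
    // live in blockwise_gemm_func() below.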
    // AGridDesc/ABlockDesc and BGridDesc/BBlockDesc are elided in this listing
    // and restored here from the Run() signature; one further elided parameter
    // is marked "// ...".
    template <bool HasMainLoop,
              // ...
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer,
              typename BScaleStruct>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        BScaleStruct& b_scale_struct,
                        index_t num_loop, // restored from the generated index
                        index_t num_loop_per_scale) const
    {
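        // Prologue identical to the Intrawave version: VGPR buffers, global
        // prefetch, first B-scale load, LDS prefill, C initialization.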
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
            b_thread_desc_.GetElementSpaceSize());

        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);

        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        c_thread_buf.Clear();
        auto blockwise_gemm_func = [&]() {
            // ... (static_for loops over k0_offset / m0 / n0 / k0_inner elided;
            //      the surviving fragments are the LDS-side descriptors passed to
            //      the thread copies and the B-scale operand)
                    a_block_desc_k0_m0_m1_m2_k1,
                    // ...
                    b_block_desc_k0_n0_n1_n2_k1,
                    // ...
                    b_block_desc_k0_n0_n1_n2_k1,
                    // ...
                    b_scale_struct.b_scale_thread_bufs(
                        I0)[Number<n0 * BScaleStruct::num_scale_k_block +
                                   (k0_offset + k0_inner) / BScaleStruct::num_scale_krepeat>{}],
                    // ...
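
                // __builtin_amdgcn_sched_barrier(0) is a compiler fence: with a
                // mask of 0, no instruction may be reordered across it.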
                __builtin_amdgcn_sched_barrier(0);
                // ...
                if constexpr(k0_offset != 0 || KRepeat == 1)
                {
                    // Wait for the other waves before touching the next LDS tile.
                    __builtin_amdgcn_s_barrier();
                    __builtin_amdgcn_sched_barrier(0);
                }
                    vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
                    vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;

                    static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
                        a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                            a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                // ... (index tuple elided)
                    });

                    static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
                        b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                            b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                // ... (index tuple elided)
                    });

                    using wmma_input_type_a =
                        typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
                    using wmma_input_type_b =
                        typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

                    // c_offset declaration reconstructed from the surrounding pattern:
                    constexpr index_t c_offset =
                        c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
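
                    // Scheduling hints around the WMMA itself: re-synchronize
                    // before the last WMMA of the cluster, and bump wave priority
                    // for the first one.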
                    if constexpr(k0_offset + k0_inner == KRepeat - 1 && m0 == MRepeat - 1 &&
                                 n0 == NRepeat - 1) // condition tail and braces reconstructed
                    {
                        __builtin_amdgcn_sched_barrier(0);
                        // ...
                        __builtin_amdgcn_sched_barrier(0);
                    }

                    wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                                  b_thread_vec.template AsType<wmma_input_type_b>(),
                                  // accumulator argument reconstructed; elided in this listing:
                                  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));

                    if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0)
                    {
                        __builtin_amdgcn_sched_barrier(0);
                        __builtin_amdgcn_s_setprio(1);
                        __builtin_amdgcn_sched_barrier(0);
                    }
            // ... (closing braces of the static_for loops elided)

                // Reset wave priority once the cluster's WMMAs have been issued.
                __builtin_amdgcn_sched_barrier(0);
                __builtin_amdgcn_s_setprio(0);
                __builtin_amdgcn_sched_barrier(0);
        };
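
        // Hot loop: as in the Intrawave version, overlap the next prefetch with
        // the GEMM on the resident tile.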
        if constexpr(HasMainLoop)
        {
            // Loop scaffolding (index declaration, "do {") reconstructed; it is
            // elided in this listing.
            index_t i = 0;
            do
            {
                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                block_sync_lds(); // reconstructed sync point
                blockwise_gemm_func();

                b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                i += 1;
            } while(i < (num_loop - 1));
        }

        // Tail: one more block GEMM on the last prefetched tile.
        // ...
        blockwise_gemm_func();
    }
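
    // Private VGPR staging. Unlike the Intrawave specialization, which reuses the
    // base-class thread descriptors, this one stages KRepeatPerCluster k-slices
    // per MAC cluster.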
    // Only fragments of the descriptor lengths/strides survive in this listing;
    // the make_naive_tensor_descriptor(lengths, strides) shape is inferred from
    // the generated index.
    static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
        make_tuple(Number<KRepeatPerCluster>{} /* , ... */),
        make_tuple(/* ..., */ Number<KPack / A_KRow * MRepeat>{} /* , ... */));
    static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
        make_tuple(Number<KRepeatPerCluster>{} /* , ... */),
        make_tuple(/* ..., */ Number<KPack / B_KRow * NRepeat>{} /* , ... */));
    // The surviving fragments match the ThreadwiseTensorSliceTransfer_v4
    // instantiation recorded in the generated index; the remaining arguments are
    // restored from it (and, for B, by symmetry).
    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
                                                         ComputeTypeA,
                                                         decltype(a_block_desc_k0_m0_m1_m2_k1),
                                                         decltype(a_thread_desc_),
                                                         Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
                                                         Sequence<0, 1, 2, 3, 4, 5>,
                                                         5,
                                                         A_K1,
                                                         A_K1>;

    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
                                                         ComputeTypeB,
                                                         decltype(b_block_desc_k0_n0_n1_n2_k1),
                                                         decltype(b_thread_desc_),
                                                         Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
                                                         Sequence<0, 1, 2, 3, 4, 5>,
                                                         5,
                                                         B_K1,
                                                         B_K1>;
    AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
    BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};

    using Base::c_thread_desc_;
};