typename ComputeTypeA,
typename ComputeTypeB,
typename AWmmaTileDesc,
typename BWmmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
bool TransposeC = false,
bool BSkipLDS = false>
typename ComputeTypeA,
typename ComputeTypeB,
typename AWmmaTileDesc,
typename BWmmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
using Base::wmma_gemm;
using Base::CalculateCThreadOriginDataIndex;
using Base::GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
using Base::GetCThreadBuffer;
using Base::GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
using Base::a_block_desc_k0_m0_m1_m2_k1;
using Base::b_block_desc_k0_n0_n1_n2_k1;
using typename Base::Empty;
template <bool HasMainLoop,
          typename ABlockTransfer,
          typename AGridBuffer,
          typename ABlockBuffer,
          typename ABlockTransferStep,
          typename BBlockTransfer,
          typename BGridBuffer,
          typename BBlockBuffer,
          typename BBlockTransferStep,
          typename CThreadBuffer,
          typename AScaleStruct,
          typename BScaleStruct,
__device__ void Run(const AGridDesc& a_grid_desc,
                    const ABlockDesc& a_block_desc,
                    ABlockTransfer& a_blockwise_copy,
                    const AGridBuffer& a_grid_buf,
                    ABlockBuffer& a_block_buf,
                    const ABlockTransferStep& a_block_copy_step,
                    const BGridDesc& b_grid_desc,
                    const BBlockDesc& b_block_desc,
                    BBlockTransfer& b_blockwise_copy,
                    const BGridBuffer& b_grid_buf,
                    BBlockBuffer& b_block_buf,
                    const BBlockTransferStep& b_block_copy_step,
                    CThreadBuffer& c_thread_buf,
                    BScaleStruct& b_scale_struct,
                    index_t num_loop_per_scale) const
constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
    a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
    b_thread_desc_.GetElementSpaceSize());
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);

a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

c_thread_buf.Clear();
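// Annotation (not part of the original listing): the lines above are the pipeline
// prologue of this v1 blockwise GEMM: one global prefetch of the A/B tiles into
// registers, a move of both source slice windows to the next K chunk, a first
// per-block scale load, the write of the prefetched tiles into LDS, and a clear of
// the C accumulator. A minimal sketch of the overall control flow under those
// assumptions (names below are illustrative, not the exact CK API):
//
//   global_prefetch(0);                    // global -> registers
//   write_to_lds(0);                       // registers -> LDS
//   clear(c_acc);
//   for(index_t i = 0; i < num_loop - 1; ++i)
//   {
//       global_prefetch(i + 1);            // overlap the next global read ...
//       gemm_on_lds_tile(c_acc);           // ... with WMMA work on the current tile
//       block_sync_lds();
//       write_to_lds(i + 1);
//   }
//   gemm_on_lds_tile(c_acc);               // tail: last tile is already in LDS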
auto blockwise_gemm_func = [&]() {
    a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
    if constexpr(m0 == I0)
    b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
                       b_block_desc_k0_n0_n1_n2_k1,
                       b_scale_struct.scale_thread_bufs(I0)[Number<n0 * BScaleStruct::num_scale_k_block +
                                                                   k0 / BScaleStruct::num_scale_krepeat>{}],
    vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
    vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;

    static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
        constexpr index_t kk = ik + k_inner * KPerWaveBlock;
        a_thread_vec.template AsType<ComputeTypeA>()(ik) =
            a_thread_buf[Number<a_thread_desc_.CalculateOffset(
    static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
        constexpr index_t kk = ik + k_inner * KPerWaveBlock;
        b_thread_vec.template AsType<ComputeTypeB>()(ik) =
            b_thread_buf[Number<b_thread_desc_.CalculateOffset(

    using wmma_input_type_a =
        typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
    using wmma_input_type_b =
        typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

    c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

    wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                  b_thread_vec.template AsType<wmma_input_type_b>(),
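// Annotation: each WMMA issue gathers KPack / A_KRow / KInner scalars per operand
// from the per-thread VGPR buffers into a vector_type, then reinterprets that storage
// as the wide register type the intrinsic consumes (WmmaK / A_KRow lanes). A
// self-contained sketch of the same packing idiom, with the element count and type
// made up for illustration only:
//
//   vector_type<half_t, 8> a_vec;                               // 8 K-elements per issue (assumed)
//   static_for<0, 8, 1>{}([&](auto ik) {
//       a_vec.template AsType<half_t>()(ik) = a_thread_buf[Number<ik>{}];
//   });
//   using wmma_input_t = typename vector_type<half_t, 8>::type; // packed register type
//   // wmma_gemm.Run(a_vec.template AsType<wmma_input_t>(), b_vec..., c_acc_ref);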
if constexpr(HasMainLoop)

a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

blockwise_gemm_func();

b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);

a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
constexpr index_t num_ds_write_inst =
    HotLoopInstList::A_LDS_Write_Inst_Num + HotLoopInstList::B_LDS_Write_Inst_Num;

constexpr index_t num_buffer_load_inst = HotLoopInstList::A_Buffer_Load_Inst_Num +
                                         HotLoopInstList::B_Buffer_Load_Inst_Num;

__builtin_amdgcn_sched_group_barrier(0x020, 1, 0);
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0);
if constexpr(m0 == I0)
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0);
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0);
} while(i < (num_loop - 1));

blockwise_gemm_func();
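// Annotation: __builtin_amdgcn_sched_group_barrier(mask, count, group_id) asks the
// compiler's scheduler to place `count` instructions of the kind selected by `mask`
// at this point. Per the LLVM AMDGPU builtin documentation (worth re-checking against
// the toolchain in use), the masks used above commonly decode as:
//
//   0x008  MFMA/WMMA instructions
//   0x020  VMEM read (e.g. buffer_load)
//   0x100  DS read  (LDS load)
//   0x200  DS write (LDS store)
//
// so the group barriers in the hot loop interleave LDS reads/writes and global loads
// between bursts of WMMA math instead of letting them cluster at the loop boundaries.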
template <bool HasMainLoop,
          typename ABlockTransfer,
          typename AGridBuffer,
          typename ABlockBuffer,
          typename ABlockTransferStep,
          typename BBlockTransfer,
          typename BGridBuffer,
          typename BBlockBuffer,
          typename BBlockTransferStep,
          typename CThreadBuffer,
          typename AScaleStruct,
          typename BScaleStruct,
          !ck::is_same_v<BScaleStruct, Empty>,
__device__ void Run(const AGridDesc& a_grid_desc,
                    const ABlockDesc& a_block_desc,
                    ABlockTransfer& a_blockwise_copy,
                    const AGridBuffer& a_grid_buf,
                    ABlockBuffer& a_block_buf,
                    const ABlockTransferStep& a_block_copy_step,
                    const BGridDesc& b_grid_desc,
                    const BBlockDesc& b_block_desc,
                    BBlockTransfer& b_blockwise_copy,
                    const BGridBuffer& b_grid_buf,
                    BBlockBuffer& b_block_buf,
                    const BBlockTransferStep& b_block_copy_step,
                    CThreadBuffer& c_thread_buf,
                    AScaleStruct& a_scale_struct,
                    BScaleStruct& b_scale_struct,
                    index_t num_loop_per_scale) const
constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
static constexpr auto NumScaleKBlock =
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
    Base::a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
    Base::b_thread_desc_.GetElementSpaceSize());

using CScaleStruct = typename Base::template CScale<AScaleStruct, BScaleStruct>;
auto c_scale_struct = CScaleStruct{};

a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

a_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);

c_scale_struct.Load(a_scale_struct, b_scale_struct);

a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

c_thread_buf.Clear();
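// Annotation: this overload is selected when BScaleStruct is not Base::Empty, i.e. for
// block-scaled GEMM. a_scale_struct / b_scale_struct appear to stage per-K-block scale
// factors: GlobalLoad<0>(...) refreshes them from global memory on a cadence of
// num_loop_per_scale main-loop iterations (the boolean argument gates behaviour at the
// scale boundary), and CScaleStruct combines the A and B scales into the per-tile
// factors applied to the accumulator. The exact semantics of the boolean flag are not
// visible in these fragments.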
auto blockwise_gemm_func = [&]() {
    Base::a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
                             Base::a_thread_desc_,
    Base::b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
                             Base::b_thread_desc_,

    c_scale_struct.Clear();
    static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
        vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
        vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;

        static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
            constexpr index_t kk = ik + k_inner * KPerWaveBlock;
                kscale0 * (KRepeat / NumScaleKBlock) + k0;
            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                a_thread_buf[Number<Base::a_thread_desc_.CalculateOffset(
        static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
            constexpr index_t kk = ik + k_inner * KPerWaveBlock;
                kscale0 * (KRepeat / NumScaleKBlock) + k0;
            b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                b_thread_buf[Number<Base::b_thread_desc_.CalculateOffset(

        using wmma_input_type_a =
            typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
        using wmma_input_type_b =
            typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

            a_thread_vec.template AsType<wmma_input_type_a>(),
            b_thread_vec.template AsType<wmma_input_type_b>(),
            c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(

    c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
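// Annotation: the scaled variant accumulates the WMMA results for one scale block into
// c_scale_struct.c_thread_buf_per_scale (cleared at the top of each kscale0 group) and
// only then folds that partial sum into the main accumulator via
// UpdateCThreadBuf<kscale0, m0, n0>. Conceptually, for each (m0, n0) tile this appears
// to compute
//
//   c_thread_buf(m0, n0) += a_scale(m0, kscale0) * b_scale(n0, kscale0)
//                           * sum over the kscale0 block of A(m0, k) * B(k, n0)
//
// which keeps the scale multiply out of the inner WMMA loop (one multiply per scale
// block instead of one per K step).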
if constexpr(HasMainLoop)

a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

a_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);

blockwise_gemm_func();

c_scale_struct.Load(a_scale_struct, b_scale_struct);

a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
} while(i < (num_loop - 1));

blockwise_gemm_func();
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, 1, A_K1>,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, 1, B_K1>,
AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
using Base::c_thread_desc_;
typename ComputeTypeA,
typename ComputeTypeB,
typename AccDataType,
typename AWmmaTileDesc,
typename BWmmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
using Base::wmma_gemm;
using Base::CalculateCThreadOriginDataIndex;
using Base::GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
using Base::GetCThreadBuffer;
using Base::GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
using Base::a_block_desc_k0_m0_m1_m2_k1;
using Base::b_block_desc_k0_n0_n1_n2_k1;
using typename Base::Empty;
template <typename AScaleStruct, typename BScaleStruct>
static constexpr auto KRepeatNoScale = 1;
static constexpr auto NumScaleKBlock =
static constexpr auto KRepeatPerNumScaleKBlock = KRepeatPerCluster / NumScaleKBlock;
static constexpr index_t KRepeatNoScale = KRepeatPerCluster;
static constexpr index_t KRepeatPerNumScaleKBlock = 1;
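// Annotation: these two groups of constants look like the scaled and unscaled
// specializations of a small helper trait. With scale structs present, the
// KRepeatPerCluster iterations are split into NumScaleKBlock groups of
// KRepeatPerNumScaleKBlock K-repeats, one scale factor per group; with Base::Empty
// scale structs the grouping degenerates (KRepeatNoScale = KRepeatPerCluster,
// KRepeatPerNumScaleKBlock = 1). The surrounding declarations that would confirm this
// are not present in these fragments.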
template <bool HasMainLoop,
          typename ABlockTransfer,
          typename AGridBuffer,
          typename ABlockBuffer,
          typename ABlockTransferStep,
          typename BBlockTransfer,
          typename BGridBuffer,
          typename BBlockBuffer,
          typename BBlockTransferStep,
          typename CThreadBuffer,
          typename AScaleStruct,
          typename BScaleStruct,
__device__ void Run(const AGridDesc& a_grid_desc,
                    const ABlockDesc& a_block_desc,
                    ABlockTransfer& a_blockwise_copy,
                    const AGridBuffer& a_grid_buf,
                    ABlockBuffer& a_block_buf,
                    const ABlockTransferStep& a_block_copy_step,
                    const BGridDesc& b_grid_desc,
                    const BBlockDesc& b_block_desc,
                    BBlockTransfer& b_blockwise_copy,
                    const BGridBuffer& b_grid_buf,
                    BBlockBuffer& b_block_buf,
                    const BBlockTransferStep& b_block_copy_step,
                    CThreadBuffer& c_thread_buf,
                    BScaleStruct& b_scale_struct,
                    index_t num_loop_per_scale) const
constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
    a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
    b_thread_desc_.GetElementSpaceSize());

a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);

a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

c_thread_buf.Clear();
auto blockwise_gemm_func = [&]() {
    a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
                       make_tuple(I0, m0, k0_offset + k0_inner, I0, I0, I0, I0),
        b_block_desc_k0_n0_n1_n2_k1,
        make_tuple(I0, n0, k0_offset + k0_inner, I0, I0, I0, I0),
        b_block_desc_k0_n0_n1_n2_k1,
        make_tuple(I0, n0, k0_offset + k0_inner, I0, I0, I0, I0),
        b_scale_struct.scale_thread_bufs(I0)[Number<n0 * BScaleStruct::num_scale_k_block +
                                                    (k0_offset + k0_inner) / BScaleStruct::num_scale_krepeat>{}],

    __builtin_amdgcn_sched_barrier(0);

    if constexpr(k0_offset != 0 || KRepeat == 1)
    __builtin_amdgcn_s_barrier();
    __builtin_amdgcn_sched_barrier(0);
    vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
    vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;

    static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
        constexpr index_t kk = ik + k_inner * KPerWaveBlock;
        a_thread_vec.template AsType<ComputeTypeA>()(ik) =
            a_thread_buf[Number<a_thread_desc_.CalculateOffset(
    static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
        constexpr index_t kk = ik + k_inner * KPerWaveBlock;
        b_thread_vec.template AsType<ComputeTypeB>()(ik) =
            b_thread_buf[Number<b_thread_desc_.CalculateOffset(

    using wmma_input_type_a =
        typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
    using wmma_input_type_b =
        typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

    c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

    if constexpr(k0_offset + k0_inner == KRepeat - 1 &&
                 m0 == MRepeat - 1 && n0 == NRepeat - 1)
    __builtin_amdgcn_sched_barrier(0);
    __builtin_amdgcn_sched_barrier(0);

    a_thread_vec.template AsType<wmma_input_type_a>(),
    b_thread_vec.template AsType<wmma_input_type_b>(),

    if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0)
    __builtin_amdgcn_sched_barrier(0);
    __builtin_amdgcn_s_setprio(1);
    __builtin_amdgcn_sched_barrier(0);

    __builtin_amdgcn_sched_barrier(0);
    __builtin_amdgcn_s_setprio(0);
    __builtin_amdgcn_sched_barrier(0);
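// Annotation: __builtin_amdgcn_s_setprio(1) raises the issuing wave's hardware
// scheduling priority and __builtin_amdgcn_s_setprio(0) restores the default, so the
// wave whose WMMA operands are ready gets preference on the SIMD while it drains the
// math burst. The surrounding __builtin_amdgcn_sched_barrier(0) calls tell the
// compiler not to move any instruction across those points, which pins the priority
// window to exactly this span of the emitted code. This reading follows the usual use
// of these builtins; the elided lines around the fragments may add further conditions.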
if constexpr(HasMainLoop)

a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

blockwise_gemm_func();

b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);

a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
} while(i < (num_loop - 1));

blockwise_gemm_func();
static constexpr auto a_thread_desc_ =
    Number<KRepeatPerCluster>{},
    Number<KPack / A_KRow * MRepeat>{},

static constexpr auto b_thread_desc_ =
    Number<KRepeatPerCluster>{},
    Number<KPack / B_KRow * NRepeat>{},

decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, 1, A_K1>,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, 1, B_K1>,
using Base::c_thread_desc_;
typename ComputeTypeA,
typename ComputeTypeB,
typename AccDataType,
typename AWmmaTileDesc,
typename BWmmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
using Base::WaveSize;
using Base::KRepeat;
using Base::wmma_gemm;
using Base::CalculateCThreadOriginDataIndex;
using Base::GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
using Base::GetCThreadBuffer;
using Base::GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
using Base::a_block_desc_k0_m0_m1_m2_k1;
using Base::b_block_desc_k0_n0_n1_n2_k1;
using typename Base::Empty;
constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves;
constexpr auto wmma_interleave = 2;

if constexpr(MPerBlock >= 128 && NPerBlock >= 128)
__builtin_amdgcn_sched_group_barrier(0x008, 2 * wmma_interleave, 0);
__builtin_amdgcn_sched_group_barrier(0x008, wmma_interleave, 0);
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0);

__builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0);
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0);

__builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0);
template <bool HasMainLoop,
          typename ABlockDesc,
          typename ABlockTransfer,
          typename AGridBuffer,
          typename ABlockBuffer,
          typename ABlockTransferStep,
          typename BBlockDesc,
          typename BBlockTransfer,
          typename BGridBuffer,
          typename BBlockBuffer,
          typename BBlockTransferStep,
          typename CThreadBuffer,
          typename AScaleStruct,
          typename BScaleStruct,
__device__ void Run(const AGridDesc& a_grid_desc,
                    const ABlockDesc& a_block_desc,
                    ABlockTransfer& a_blockwise_copy,
                    const AGridBuffer& a_grid_buf,
                    ABlockBuffer& a_block_buf,
                    const ABlockTransferStep& a_block_copy_step,
                    const BGridDesc& b_grid_desc,
                    BBlockTransfer& b_blockwise_copy,
                    const BGridBuffer& b_grid_buf,
                    const BBlockTransferStep& b_block_copy_step,
                    CThreadBuffer& c_thread_buf,
__builtin_amdgcn_sched_barrier(0);
constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();

auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
    a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
    b_thread_desc_.GetElementSpaceSize());

constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.Run(b_grid_desc,
                     b_block_desc_k0_n0_n1_n2_k1,

a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
__builtin_amdgcn_sched_barrier(0);

a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);

a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);

a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,

c_thread_buf.Clear();

__builtin_amdgcn_sched_barrier(0);
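// Annotation: in this BSkipLDS = true specialization the B operand never touches LDS:
// b_blockwise_copy.Run(...) copies the B tile straight from the grid buffer into the
// per-thread register buffers (b_thread_bufs), which is presumably why the BBlockDesc
// and BBlockBuffer parameters of this Run overload go unused in the full source. The A
// operand still goes global -> LDS -> VGPR and is double buffered, hence
// RunWrite/RunRead taking an extra buffer index (I0/I1) in the fragments below.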
if constexpr(HasMainLoop)

auto LoopFunc = [&](auto wmma_reg_buf, auto local_read_buf) {
    b_blockwise_copy.Run(b_grid_desc,
                         b_block_desc_k0_n0_n1_n2_k1,
                         b_thread_bufs(local_read_buf));

    b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

    a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, wmma_reg_buf);

    a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
    a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);

    vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
    vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;

    static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
        constexpr index_t kk = ik + k_inner * KPerWaveBlock;
        a_thread_vec.template AsType<ComputeTypeA>()(ik) =
            a_thread_buf[Number<a_thread_desc_.CalculateOffset(
    static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
        constexpr index_t kk = ik + k_inner * KPerWaveBlock;
        b_thread_vec.template AsType<ComputeTypeB>()(ik) =
            b_thread_bufs[wmma_reg_buf][Number<b_thread_desc_.CalculateOffset(

    using wmma_input_type_a =
        typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
    using wmma_input_type_b =
        typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

    c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

    a_thread_vec.template AsType<wmma_input_type_a>(),
    b_thread_vec.template AsType<wmma_input_type_b>(),

    a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,

    __builtin_amdgcn_sched_barrier(0);
} while(i < (num_loop - 2));
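// Annotation: the elided loop body presumably invokes LoopFunc twice per trip with the
// two register-buffer indices swapped, the usual ping-pong for this kind of pipeline:
// compute on one buffer while the next B tile streams into the other. A plausible
// reconstruction, inferred from the two-buffer tails that follow (not verbatim source):
//
//   do
//   {
//       LoopFunc(I0, I1);   // WMMA on buffer 0, global->register read into buffer 1
//       LoopFunc(I1, I0);   // WMMA on buffer 1, global->register read into buffer 0
//       i += 2;
//   } while(i < (num_loop - 2));
//
// which would also explain the (num_loop - 2) bound: the final one or two iterations
// are handled by the explicit tails below.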
b_blockwise_copy.Run(b_grid_desc,
                     b_block_desc_k0_n0_n1_n2_k1,

a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;

static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
    constexpr index_t kk = ik + k_inner * KPerWaveBlock;
    a_thread_vec.template AsType<ComputeTypeA>()(ik) =
        a_thread_buf[Number<a_thread_desc_.CalculateOffset(
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
    constexpr index_t kk = ik + k_inner * KPerWaveBlock;
    b_thread_vec.template AsType<ComputeTypeB>()(ik) =
        b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(

using wmma_input_type_a =
    typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
    typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
              b_thread_vec.template AsType<wmma_input_type_b>(),

a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,

__builtin_amdgcn_sched_barrier(0);
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;

static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
    constexpr index_t kk = ik + k_inner * KPerWaveBlock;
    a_thread_vec.template AsType<ComputeTypeA>()(ik) =
        a_thread_buf[Number<a_thread_desc_.CalculateOffset(
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
    constexpr index_t kk = ik + k_inner * KPerWaveBlock;
    b_thread_vec.template AsType<ComputeTypeB>()(ik) =
        b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(

using wmma_input_type_a =
    typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
    typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
              b_thread_vec.template AsType<wmma_input_type_b>(),
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;

static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
    constexpr index_t kk = ik + k_inner * KPerWaveBlock;
    a_thread_vec.template AsType<ComputeTypeA>()(ik) =
        a_thread_buf[Number<a_thread_desc_.CalculateOffset(
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
    constexpr index_t kk = ik + k_inner * KPerWaveBlock;
    b_thread_vec.template AsType<ComputeTypeB>()(ik) =
        b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(

using wmma_input_type_a =
    typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
    typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
              b_thread_vec.template AsType<wmma_input_type_b>(),
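// Annotation: the three accumulation blocks above appear to be the loop tails. The
// first two (consuming b_thread_bufs[I0] and then b_thread_bufs[I1]) handle the case
// where two iterations remain after the ping-pong loop; the last one (b_thread_bufs[I0]
// only) handles the case where a single iteration remains, matching an even/odd tail
// split on num_loop.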
template <bool HasMainLoop,
          typename ABlockDesc,
          typename ABlockTransfer,
          typename AGridBuffer,
          typename ABlockBuffer,
          typename ABlockTransferStep,
          typename BBlockDesc,
          typename BBlockTransfer,
          typename BGridBuffer,
          typename BBlockBuffer,
          typename BBlockTransferStep,
          typename CThreadBuffer,
          typename AScaleStruct,
          typename BScaleStruct,
          !ck::is_same_v<BScaleStruct, Empty>,
          bool>::type = false>
__device__ void Run(const AGridDesc& a_grid_desc,
                    const ABlockDesc& a_block_desc,
                    ABlockTransfer& a_blockwise_copy,
                    const AGridBuffer& a_grid_buf,
                    ABlockBuffer& a_block_buf,
                    const ABlockTransferStep& a_block_copy_step,
                    const BGridDesc& b_grid_desc,
                    BBlockTransfer& b_blockwise_copy,
                    const BGridBuffer& b_grid_buf,
                    const BBlockTransferStep& b_block_copy_step,
                    CThreadBuffer& c_thread_buf,
                    AScaleStruct& a_scale_struct,
                    BScaleStruct& b_scale_struct,
                    index_t num_loop_per_scale) const
__builtin_amdgcn_sched_barrier(0);
constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
static constexpr auto NumScaleKBlock =

auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
    a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
    b_thread_desc_.GetElementSpaceSize());

constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);

using CScaleStruct = typename Base::template CScale<AScaleStruct, BScaleStruct>;
auto c_scale_struct = CScaleStruct{};
auto gemm_core_func = [&](auto reg_buf) {
    c_scale_struct.Clear();
    static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
        vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
        vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;

        static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
            constexpr index_t kk = ik + k_inner * KPerWaveBlock;
                kscale0 * (KRepeat / NumScaleKBlock) + k0;
            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
        static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
            constexpr index_t kk = ik + k_inner * KPerWaveBlock;
                kscale0 * (KRepeat / NumScaleKBlock) + k0;
            b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                b_thread_bufs[reg_buf][Number<b_thread_desc_.CalculateOffset(

        using wmma_input_type_a =
            typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
        using wmma_input_type_b =
            typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

            a_thread_vec.template AsType<wmma_input_type_a>(),
            b_thread_vec.template AsType<wmma_input_type_b>(),
            c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(

    c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
auto a_local_prefetch_func = [&]() {
    a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,

a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.Run(b_grid_desc,
                     b_block_desc_k0_n0_n1_n2_k1,

a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

a_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);

__builtin_amdgcn_sched_barrier(0);

c_scale_struct.Load(a_scale_struct, b_scale_struct);

a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);

a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);

a_local_prefetch_func();

c_thread_buf.Clear();

__builtin_amdgcn_sched_barrier(0);
if constexpr(HasMainLoop)

auto LoopFunc = [&](auto wmma_reg_buf, auto local_read_buf) {
    b_blockwise_copy.Run(b_grid_desc,
                         b_block_desc_k0_n0_n1_n2_k1,
                         b_thread_bufs(local_read_buf));

    b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

    a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, wmma_reg_buf);

    a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
    a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);

    a_scale_struct.template GlobalLoad<0>(
        (i + 2 + wmma_reg_buf) % num_loop_per_scale == 0);
    b_scale_struct.template GlobalLoad<0>(
        (i + 2 + wmma_reg_buf) % num_loop_per_scale == 0);

    gemm_core_func(wmma_reg_buf);

    a_local_prefetch_func();

    c_scale_struct.Load(a_scale_struct, b_scale_struct);

    __builtin_amdgcn_sched_barrier(0);
} while(i < (num_loop - 2));
b_blockwise_copy.Run(b_grid_desc,
                     b_block_desc_k0_n0_n1_n2_k1,

a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);

a_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);

a_local_prefetch_func();

c_scale_struct.Load(a_scale_struct, b_scale_struct);

__builtin_amdgcn_sched_barrier(0);

static constexpr auto b_thread_desc_ =

using Base::a_thread_copy_;
using Base::a_thread_desc_;
using Base::c_thread_desc_;