          typename ComputeTypeA,
          typename ComputeTypeB,
          typename AWmmaTileDesc,
          typename BWmmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          typename ComputeTypeA,
          typename ComputeTypeB,
          typename AWmmaTileDesc,
          typename BWmmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          ABlockTransferSrcScalarPerVector,
          BBlockTransferSrcScalarPerVector,

          ABlockTransferSrcScalarPerVector,
          BBlockTransferSrcScalarPerVector,

          ABlockTransferSrcScalarPerVector,
          BBlockTransferSrcScalarPerVector,
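    // Members inherited from the shared WMMA pipeline base: the blockwise
    // GEMM primitive, C-tile descriptor/origin helpers, and the LDS block
    // descriptors for A and B.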
    using Base::wmma_gemm;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::
        GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
    using Base::GetCThreadBuffer;
    using Base::
        GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;

    using Base::a_block_desc_k0_m0_m1_m2_k1;
    using Base::b_block_desc_k0_n0_n1_n2_k1;
    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }
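    // A hot loop exists only when there are more K-loop iterations than
    // prefetch stages; shorter problems are handled entirely by the tail.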
    template <typename ABlockBuffer,
              typename AThreadBuffer,
              typename BBlockBuffer,
              typename BThreadBuffer,
              typename BScaleStruct>
    __device__ inline void LocalLoad(ABlockBuffer& a_block_buf,
                                     AThreadBuffer& a_thread_buf,
                                     BBlockBuffer& b_block_buf,
                                     BThreadBuffer& b_thread_buf,
                                     BScaleStruct& b_scale_struct) const
    {
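        // A tile: LDS -> VGPR copy through a_thread_copy_.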
            a_block_desc_k0_m0_m1_m2_k1,

        if constexpr(ck::is_same_v<BScaleStruct, Empty>)

                b_block_desc_k0_n0_n1_n2_k1,

                b_block_desc_k0_n0_n1_n2_k1,
                b_scale_struct.b_scale_thread_bufs(
                    I0)[Number<n0 * BScaleStruct::num_scale_k_block +
                               k0 / BScaleStruct::num_scale_krepeat>{}],
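        // With scales present, each B fragment (n0, k0) is multiplied during
        // the LDS -> VGPR copy by scale element
        // n0 * num_scale_k_block + k0 / num_scale_krepeat from the first
        // scale thread buffer.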
    template <bool HasMainLoop,
              TailNumber TailNum,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer,
              typename BScaleStruct>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        BScaleStruct& b_scale_struct,
                        index_t num_loop,
                        index_t num_loop_per_scale) const
    {
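        // Pipeline: two global prefetches in the prologue, then a hot loop
        // that overlaps WMMA on the tile held in registers with the LDS write
        // of the next tile and the global read of the one after.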
        __builtin_amdgcn_sched_barrier(0);

        // Thread-local register buffers holding the A/B fragments fed to WMMA.
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
            b_thread_desc_.GetElementSpaceSize());
        // Global prefetch 1: read the first A/B tiles and advance the windows.
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
        b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
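        // A new scale block is needed only when the loop crosses a
        // num_loop_per_scale boundary (cf. the (i + 2) % num_loop_per_scale
        // predicate in the hot loop below).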
        // Write tile 0 to LDS, then immediately start global prefetch 2.
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
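        // Zero the accumulators before the first multiply-accumulate.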
        c_thread_buf.Clear();

        block_sync_lds();
        LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);

        __builtin_amdgcn_sched_barrier(0);
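        // Mask-0 sched_barrier: the compiler's instruction scheduler may not
        // move any instruction across this point, keeping the prologue
        // separate from the software-pipelined hot loop.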
        // Main body: software-pipelined hot loop.
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                block_sync_lds();

                // Write the prefetched tile to LDS, then issue the next global read.
                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
                b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
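                // Inner product over the K repeats: gather the A/B fragments
                // into contiguous vectors, then issue one WMMA per (m0, n0).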
                    vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
                    vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;

                    static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
                        a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                            a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                    static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
                        b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                            b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                    using wmma_input_type_a =
                        typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
                    using wmma_input_type_b =
                        typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

                    constexpr index_t c_offset =
                        c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

                    wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                                  b_thread_vec.template AsType<wmma_input_type_b>(),
                                  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                block_sync_lds();
                LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);

                HotLoopScheduler();
                __builtin_amdgcn_sched_barrier(0);

                i += 1;
            } while(i < (num_loop - 1));
        }
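        // Tail: the final tile is already in registers, so no further global
        // or LDS traffic is issued; only the remaining WMMAs execute.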
            vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
            vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;

            static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
                a_thread_vec.template AsType<ComputeTypeA>()(ik) =
            static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
                b_thread_vec.template AsType<ComputeTypeB>()(ik) =
            using wmma_input_type_a =
                typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
            using wmma_input_type_b =
                typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

            constexpr index_t c_offset =
                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

            wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                          b_thread_vec.template AsType<wmma_input_type_b>(),
                          c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
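            // Results remain in c_thread_buf; the caller's epilogue writes the
            // C tile back to global memory.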
    using Base::a_thread_copy_;
    using Base::a_thread_desc_;
    using Base::b_thread_copy_;
    using Base::b_thread_desc_;
    using Base::c_thread_desc_;
};
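// A minimal usage sketch (hypothetical values and names, not part of this
// header): a kernel computes the K-loop trip count, queries the loop
// structure, and drives the pipeline.
//
//   const index_t num_loop = ...; // e.g. derived from K / KPerBlock
//   const bool has_hot_loop = PipelineType::BlockHasHotloop(num_loop);
//   const TailNumber tail   = PipelineType::BlockLoopTailNum(num_loop);
//   pipeline.template Run<true, TailNumber::Full>(
//       a_grid_desc, a_block_desc, a_blockwise_copy, a_grid_buf, a_block_buf,
//       a_block_copy_step, b_grid_desc, b_block_desc, b_blockwise_copy,
//       b_grid_buf, b_block_buf, b_block_copy_step, c_thread_buf,
//       b_scale_struct, num_loop, num_loop_per_scale);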