20 typename ComputeDataType,
24 typename AMmaTileDesc,
25 typename BMmaTileDesc,
26 index_t ABlockTransferSrcScalarPerVector,
27 index_t BBlockTransferSrcScalarPerVector,
36 struct BlockwiseGemmXdlops_pipeline_v4
43 typename ComputeDataType,
47 typename AMmaTileDesc,
48 typename BMmaTileDesc,
49 index_t ABlockTransferSrcScalarPerVector,
50 index_t BBlockTransferSrcScalarPerVector,
71 ABlockTransferSrcScalarPerVector,
72 BBlockTransferSrcScalarPerVector,
90 ABlockTransferSrcScalarPerVector,
91 BBlockTransferSrcScalarPerVector,
111 ABlockTransferSrcScalarPerVector,
112 BBlockTransferSrcScalarPerVector,
124 using Base::xdlops_gemm;
127 using Base::CalculateCThreadOriginDataIndex;
128 using Base::CalculateCThreadOriginDataIndex8D;
129 using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
130 using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
131 using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
132 using Base::GetCThreadBuffer;
133 using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
134 using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
135 using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
136 using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
137 using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
139 using Base::a_block_desc_m0_m1_m2_k;
140 using Base::b_block_desc_n0_n1_n2_k;
142 using Base::AMmaKStride;
143 using Base::BMmaKStride;
152 return num_loop > PrefetchStages;
157 if(num_loop % HotloopUnroll == 1)
171 constexpr
auto num_ds_read_inst_a =
175 constexpr
auto num_ds_read_inst_b =
181 constexpr
auto num_dswrite_per_issue_a =
183 constexpr
auto num_dsread_per_issue_a = num_ds_read_inst_a / num_issue_a;
186 constexpr
auto num_dswrite_per_issue_b =
188 constexpr
auto num_dsread_per_issue_b = num_ds_read_inst_b / num_issue_b;
190 constexpr
auto num_mfma_per_issue =
// Visible fragment of the hot-loop instruction-scheduling hints. The enclosing
// function header and the loop scaffolding around these statements are not part
// of this excerpt.
// Mask values, per the LLVM AMDGPU sched_group_barrier mask table:
//   0x008 = MFMA, 0x020 = VMEM read, 0x100 = DS read, 0x200 = DS write.
// Each call below has the form (mask, group size, sync id).
// A side: a group of one DS read followed by a group of one MFMA.
197 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0);
198 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
// A side: a group of one DS write followed by a group of one MFMA.
203 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0);
204 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
// A side: one VMEM read, then an MFMA group sized as the per-issue MFMA count
// minus the A-side DS read and DS write counts.
207 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0);
208 __builtin_amdgcn_sched_group_barrier(0x008,
209 num_mfma_per_issue - num_dsread_per_issue_a -
210 num_dswrite_per_issue_a,
// B side: same sequence of groups as the A side above.
218 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0);
219 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
224 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0);
225 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
228 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0);
// NOTE(review): this B-side remainder subtracts num_dsread_per_issue_a but
// num_dswrite_per_issue_b, whereas the A-side expression uses the _a counts for
// both. This looks like a copy-paste slip; presumably num_dsread_per_issue_b was
// intended. Confirm against the full function before changing it, since the two
// counts may differ when the A and B read instruction counts differ.
229 __builtin_amdgcn_sched_group_barrier(0x008,
230 num_mfma_per_issue - num_dsread_per_issue_a -
231 num_dswrite_per_issue_b,
// sched_barrier with mask 0: per the LLVM AMDGPU documentation, no instructions
// may be scheduled across this point.
234 __builtin_amdgcn_sched_barrier(0);
237 template <
bool HasMainLoop,
241 typename ABlockTransfer,
242 typename AGridBuffer,
243 typename ABlockBuffer,
244 typename ABlockTransferStep,
247 typename BBlockTransfer,
248 typename BGridBuffer,
249 typename BBlockBuffer,
250 typename BBlockTransferStep,
251 typename CThreadBuffer>
252 __device__
void Run(
const AGridDesc& a_grid_desc,
253 const ABlockDesc& a_block_desc,
254 ABlockTransfer& a_blockwise_copy,
255 const AGridBuffer& a_grid_buf,
256 ABlockBuffer& a_block_buf,
257 const ABlockTransferStep& a_block_copy_step,
258 const BGridDesc& b_grid_desc,
259 const BBlockDesc& b_block_desc,
260 BBlockTransfer& b_blockwise_copy,
261 const BGridBuffer& b_grid_buf,
262 BBlockBuffer& b_block_buf,
263 const BBlockTransferStep& b_block_copy_step,
264 CThreadBuffer& c_thread_buf,
267 auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
269 auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
276 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
277 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
279 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
280 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
283 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(
I0));
284 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(
I0));
308 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
309 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
311 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
312 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
315 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(
I1));
316 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(
I1));
319 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
320 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
322 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
323 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
326 c_thread_buf.Clear();
329 if constexpr(HasMainLoop)
335 auto LoopFunc = [&](
auto lds_read_buf,
336 auto lds_read_reg_buf,
345 a_block_buf.At(lds_read_buf),
348 a_thread_bufs(lds_read_reg_buf));
353 b_block_buf.At(lds_read_buf),
356 b_thread_bufs(lds_read_reg_buf));
360 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf));
361 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf));
363 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
364 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
366 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
367 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
376 a_thread_vec.template AsType<ComputeDataType>()(ik) =
377 a_thread_bufs[mfma_reg_buf]
380 b_thread_vec.template AsType<ComputeDataType>()(ik) =
381 b_thread_bufs[mfma_reg_buf]
386 using mfma_input_type =
394 a_thread_vec.template AsType<mfma_input_type>(),
395 b_thread_vec.template AsType<mfma_input_type>(),
408 }
while(i < (num_loop - PrefetchStages));
411 auto ReadWriteCompFunc = [&](
auto lds_read_buf,
412 auto lds_read_reg_buf,
421 a_block_buf.At(lds_read_buf),
424 a_thread_bufs(lds_read_reg_buf));
429 b_block_buf.At(lds_read_buf),
432 b_thread_bufs(lds_read_reg_buf));
436 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf));
437 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf));
446 a_thread_vec.template AsType<ComputeDataType>()(ik) =
449 b_thread_vec.template AsType<ComputeDataType>()(ik) =
454 using mfma_input_type =
460 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
461 b_thread_vec.template AsType<mfma_input_type>(),
470 auto ReadCompFunc = [&](
auto lds_read_buf,
auto lds_read_reg_buf,
auto mfma_reg_buf) {
477 a_block_buf.At(lds_read_buf),
480 a_thread_bufs(lds_read_reg_buf));
485 b_block_buf.At(lds_read_buf),
488 b_thread_bufs(lds_read_reg_buf));
499 a_thread_vec.template AsType<ComputeDataType>()(ik) =
502 b_thread_vec.template AsType<ComputeDataType>()(ik) =
507 using mfma_input_type =
513 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
514 b_thread_vec.template AsType<mfma_input_type>(),
523 auto CompFunc = [&](
auto mfma_reg_buf) {
531 a_thread_vec.template AsType<ComputeDataType>()(ik) =
534 b_thread_vec.template AsType<ComputeDataType>()(ik) =
539 using mfma_input_type =
545 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
546 b_thread_vec.template AsType<mfma_input_type>(),
567 using Base::a_thread_copy_;
568 using Base::a_thread_desc_;
569 using Base::b_thread_copy_;
570 using Base::b_thread_desc_;
571 using Base::c_thread_desc_;
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition: statically_indexed_array.hpp:45
TailNumber
Definition: blkgemmpipe_scheduler.hpp:31
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:298
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:35
Definition: blockwise_gemm_pipeline_xdlops.hpp:34
static constexpr index_t B_LDS_Write_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:46
static constexpr index_t A_LDS_Read_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:49
static constexpr index_t A_LDS_Read_Width
Definition: blkgemmpipe_scheduler.hpp:82
static constexpr index_t B_LDS_Read_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:51
static constexpr index_t A_LDS_Write_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:44
static constexpr index_t C_MFMA_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:54
static constexpr index_t A_Buffer_Load_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:39
static constexpr index_t B_Buffer_Load_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:41
static constexpr index_t B_LDS_Read_Width
Definition: blkgemmpipe_scheduler.hpp:83
ck::BlockwiseGemmXdlops_pipeline_v4< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::BlockHasHotloop __host__ static constexpr __device__ bool BlockHasHotloop(index_t num_loop)
Definition: blockwise_gemm_pipeline_xdlops_v4.hpp:150
ck::BlockwiseGemmXdlops_pipeline_v4< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::HotLoopScheduler static constexpr __device__ void HotLoopScheduler()
Definition: blockwise_gemm_pipeline_xdlops_v4.hpp:167
ck::BlockwiseGemmXdlops_pipeline_v4< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::BlockLoopTailNum __host__ static constexpr __device__ TailNumber BlockLoopTailNum(index_t num_loop)
Definition: blockwise_gemm_pipeline_xdlops_v4.hpp:155
ck::BlockwiseGemmXdlops_pipeline_v4< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::Run __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, index_t num_loop) const
Definition: blockwise_gemm_pipeline_xdlops_v4.hpp:252
Definition: blockwise_gemm_pipeline_xdlops.hpp:103
static constexpr auto I1
Definition: blockwise_gemm_pipeline_xdlops.hpp:105
static constexpr auto b_thread_desc_
Definition: blockwise_gemm_pipeline_xdlops.hpp:963
static constexpr __device__ auto HotLoopScheduler()
Definition: blockwise_gemm_pipeline_xdlops.hpp:375
static constexpr auto c_thread_desc_
Definition: blockwise_gemm_pipeline_xdlops.hpp:969
BThreadCopy b_thread_copy_
Definition: blockwise_gemm_pipeline_xdlops.hpp:993
static constexpr auto I0
Definition: blockwise_gemm_pipeline_xdlops.hpp:104
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k
Definition: blockwise_gemm_pipeline_xdlops.hpp:455
AThreadCopy a_thread_copy_
Definition: blockwise_gemm_pipeline_xdlops.hpp:992
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k
Definition: blockwise_gemm_pipeline_xdlops.hpp:456
static constexpr auto a_thread_desc_
Definition: blockwise_gemm_pipeline_xdlops.hpp:957
static constexpr auto xdlops_gemm
Definition: blockwise_gemm_pipeline_xdlops.hpp:122
__device__ void Run(const SrcDesc &, const SrcRefToOriginDisplacement &, const SrcBuffer &src_buf, const DstDesc &, const DstOriginIdx &, DstBuffer &dst_buf) const
Definition: threadwise_tensor_slice_transfer.hpp:1293
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33
Definition: dtype_vector.hpp:10