// Excerpt: blockwise_gemm_pipeline_xdlops_v3.hpp (Composable Kernel)

// Primary template parameter list (excerpt; gaps in the listing marked "// ..."):
          typename ComputeDataType,
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          // ...
// Specialization parameter list for BlockGemmPipelineScheduler::Intrawave
// (mirrors the primary template without the scheduler tag; excerpt):
          typename ComputeDataType,
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          // ...
struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
                                       // ...
                                       ABlockTransferSrcScalarPerVector,
                                       BBlockTransferSrcScalarPerVector,
                                       // ...
                                       >
    : BlockwiseGemmXdlops_pipeline_base<// ...
                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,
                                        // ...
                                        >
{
    using Base = BlockwiseGemmXdlops_pipeline_base<// ...
                                                   ABlockTransferSrcScalarPerVector,
                                                   BBlockTransferSrcScalarPerVector,
                                                   // ...
                                                   >;
    using Base::xdlops_gemm;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::CalculateCThreadOriginDataIndex8D;
    using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::GetCThreadBuffer;
    using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;

    using Base::a_block_desc_m0_m1_m2_k;
    using Base::b_block_desc_n0_n1_n2_k;

    using Base::AMmaKStride;
    using Base::BMmaKStride;
    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }
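    // PrefetchStages is inherited from the base pipeline (2 for this two-stage
    // design is the usual value; an assumption, it is not shown in this excerpt).
    // A companion BlockLoopTailNum(num_loop), elided here, picks the TailNumber
    // variant used by the epilogue.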
    static constexpr __device__ auto HotLoopScheduler()
    {
        // Manual instruction scheduling below targets MFMA architectures; it is
        // skipped on the gfx11/gfx12 (WMMA) targets.
#if !defined(__gfx11__) && !defined(__gfx12__)
        // When the LDS read width is below 16 bytes, the compiler is expected to
        // merge pairs of reads into ds_read2 instructions, halving the count.
        constexpr auto num_ds_read_inst_a =
            HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
                ? HotLoopInstList::A_LDS_Read_Inst_Num
                : HotLoopInstList::A_LDS_Read_Inst_Num / 2;
        constexpr auto num_ds_read_inst_b =
            HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
                ? HotLoopInstList::B_LDS_Read_Inst_Num
                : HotLoopInstList::B_LDS_Read_Inst_Num / 2;
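        // Example (illustrative numbers only): an 8-wide LDS read of fp16 moves
        // 8 * 2 = 16 bytes, so the full instruction count is kept; a 4-wide fp16
        // read (8 bytes) is assumed to pair into ds_read2 and counts as half.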
        constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
        constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;

        constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
        constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;

        constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
        constexpr auto mfma_cycle    = HotLoopInstList::C_MFMA_Inst_Cycle;
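        // For orientation (assumed typical values, not defined in this excerpt):
        // C_MFMA_Inst_Cycle is the cycle cost of one MFMA, on the order of 32
        // cycles for 16x16 and 64 cycles for 32x32 XDL shapes.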
        constexpr auto ds_read_a_issue_cycle =
            HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
        constexpr auto ds_read_b_issue_cycle =
            HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
        constexpr auto ds_read_a_mfma_rate =
            (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
        constexpr auto ds_read_b_mfma_rate =
            (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
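        // The rate is a ceiling division, ceil((mfma_cycle - 4) / (2 * issue_cycle)):
        // how many DS reads fit under one MFMA minus a 4-cycle margin.
        // E.g. with mfma_cycle = 32 and a 16-byte read (issue cycle 8):
        // (32 - 4 + 16 - 1) / 16 = 2 DS reads per MFMA slot.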
        // Number of MFMA slots needed to cover all DS reads at that rate:
        constexpr auto num_dsread_a_mfma =
            (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
        constexpr auto num_dsread_b_mfma =
            (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
        constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
        constexpr auto num_mfma_per_issue =
            num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
        constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
        constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
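        // The sched_group_barrier masks used below are the LLVM AMDGPU instruction
        // class bits: 0x008 = MFMA, 0x020 = VMEM read, 0x100 = DS read,
        // 0x200 = DS write.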
        // Stage 1: group each A buffer_load with its share of DS writes and MFMAs.
        static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
            ignore = i;
            static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
                ignore = idswrite;
                __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            });
            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            __builtin_amdgcn_sched_group_barrier(
                0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
        });
        // Same grouping for the B-side buffer loads.
        static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
            ignore = i;
            static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
                ignore = idswrite;
                __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            });
            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            __builtin_amdgcn_sched_group_barrier(
                0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA
        });
        // Stage 2: pair the remaining DS reads with MFMAs, `rate` reads per slot;
        // the last slot takes whatever remainder is left.
        static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
            if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
                         ds_read_a_mfma_rate)
            {
                __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
            }
            else
            {
                __builtin_amdgcn_sched_group_barrier(0x100,
                                                     num_ds_read_inst_a -
                                                         (num_dsread_a_mfma - 1) *
                                                             ds_read_a_mfma_rate,
                                                     0); // DS read
            }
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
        });
        static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
            if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
                         ds_read_b_mfma_rate)
            {
                __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
            }
            else
            {
                __builtin_amdgcn_sched_group_barrier(0x100,
                                                     num_ds_read_inst_b -
                                                         (num_dsread_b_mfma - 1) *
                                                             ds_read_b_mfma_rate,
                                                     0); // DS read
            }
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
        });
#endif
    }
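    // Net effect: every MFMA in one K iteration is explicitly paired with either a
    // DS-write/buffer-load group (stage 1) or a DS-read group (stage 2), so the
    // backend cannot clump the memory operations at one end of the loop body.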
    template <bool HasMainLoop,
              TailNumber TailNum,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        index_t num_loop) const
    {
        __builtin_amdgcn_sched_barrier(0);

        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_thread_desc_.GetElementSpaceSize());

        // Global prefetch 1
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Local prefill 1
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        // Global prefetch 2
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Initialize C
        c_thread_buf.Clear();

        // Local prefetch 1: LDS -> VGPR (static_for loops over M/N/K repeats and
        // the source/destination index tuples are elided in this excerpt).
        block_sync_lds();
        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, /* ... */ a_thread_buf);
        b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, /* ... */ b_thread_buf);

        __builtin_amdgcn_sched_barrier(0);
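        // Two RunRead/MoveSrcSliceWindow rounds before the loop realize a
        // two-stage pipeline: tile i is consumed from LDS/VGPRs while tile i+1
        // is still in flight from global memory.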
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                block_sync_lds();

                // Commit the prefetched tile to LDS, then start the next global load.
                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                // Compute on the current tile (the static_for loops over m0/n0/k
                // and the index tuples are elided in this excerpt):
                vector_type<ComputeDataType, KPack> a_thread_vec;
                vector_type<ComputeDataType, KPack> b_thread_vec;

                a_thread_vec.template AsType<ComputeDataType>()(ik) =
                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(/* ... */)>{}];
                b_thread_vec.template AsType<ComputeDataType>()(ik) =
                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(/* ... */)>{}];

                using mfma_input_type =
                    typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

                constexpr index_t c_offset =
                    c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                b_thread_vec.template AsType<mfma_input_type>(),
                                c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));

                // Local prefetch of the next tile (index tuples elided):
                block_sync_lds();
                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, /* ... */ a_thread_buf);
                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, /* ... */ b_thread_buf);

                HotLoopScheduler();
                __builtin_amdgcn_sched_barrier(0);

                i += 1;
            } while(i < (num_loop - 1));
        }
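        // The do/while body executes num_loop - 1 times; the final tile is
        // computed in the tail below, so the loop never issues a global load
        // past the end of the K dimension.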
        // Tail: drain the last tile, which is already resident in LDS/VGPRs
        // (loops, vector declarations and index tuples elided as above).
        a_thread_vec.template AsType<ComputeDataType>()(ik) =
            a_thread_buf[Number<a_thread_desc_.CalculateOffset(/* ... */)>{}];
        b_thread_vec.template AsType<ComputeDataType>()(ik) =
            b_thread_buf[Number<b_thread_desc_.CalculateOffset(/* ... */)>{}];

        using mfma_input_type =
            typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

        constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                        b_thread_vec.template AsType<mfma_input_type>(),
                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
    }
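    // Logically, each xdlops_gemm.Run call accumulates, across the wavefront,
    //   C[m0, n0] += sum_k A[m0, k] * B[n0, k]
    // over one KPack-wide fragment; the cross-lane register layout is handled by
    // the MFMA intrinsic and is not visible at this level.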
    protected:
    using Base::a_thread_copy_;
    using Base::a_thread_desc_;
    using Base::b_thread_copy_;
    using Base::b_thread_desc_;
    using Base::c_thread_desc_;
};
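// Usage sketch (illustrative only; names outside this header are hypothetical):
// a grid-level GEMM kernel that owns the descriptors and copy engines drives one
// block's K loop roughly as
//
//   const bool has_hot_loop = BlockwiseGemmPipe::BlockHasHotloop(num_k_loop);
//   const auto tail_num     = BlockwiseGemmPipe::BlockLoopTailNum(num_k_loop);
//   // dispatched at compile time to matching <HasMainLoop, TailNum> arguments:
//   blockwise_gemm_pipeline.template Run<true, TailNumber::Full>(
//       a_grid_desc, a_block_desc, a_blockwise_copy, a_grid_buf, a_block_buf,
//       a_block_copy_step, b_grid_desc, b_block_desc, b_blockwise_copy,
//       b_grid_buf, b_block_buf, b_block_copy_step, c_thread_buf, num_k_loop);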