#define ENABLE_ASM_MARKER 1

#if ENABLE_ASM_MARKER
#define ASM_MARKER(marker)                     \
    __builtin_amdgcn_sched_barrier(0);         \
    asm volatile("; [POYENC] " #marker);       \
    __builtin_amdgcn_sched_barrier(0);
#else
#define ASM_MARKER(marker)
#endif

#define ADD_SBARRIER_FOR_PHASE0 1

#if !defined(CK_TILE_DISABLE_PACKED_FP32)
#define CK_TILE_DISABLE_PACKED_FP32 0
#endif

#define ENABLE_DEBUG_STMTS 1
#if ENABLE_DEBUG_STMTS
#define DEBUG_STMTS \
    if(get_block_1d_id() == 0 && get_warp_id() == WARP_ID && get_lane_id() == LANE_ID)
#else
#define DEBUG_STMTS if constexpr(false)
#endif
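// Added usage sketch (not from the original source; `gemm_0` and `o_acc` stand in
// for any device-side statement). ASM_MARKER drops a named comment into the
// generated ISA, fenced by sched_barrier(0) so the compiler cannot move
// instructions across it -- grep the disassembly for "[POYENC]" to locate a
// region. DEBUG_STMTS gates the following statement to one chosen thread:
//
//     ASM_MARKER(gemm0_begin);
//     gemm_0(/* ... */);
//     ASM_MARKER(gemm0_end);
//     DEBUG_STMTS { printf("o_acc[0] = %f\n", o_acc.thread_buf_[0]); }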
template <typename PipelineProblem, bool kIsMasking>
// ... (scheduler primary template)

template <typename PipelineProblem>
// ... (scheduler partial specialization)
    template <ck_tile::index_t WaveGroup, ck_tile::index_t Phase>
    static constexpr CK_TILE_DEVICE void schedule(ck_tile::number<WaveGroup>, ck_tile::number<Phase>)
    {
        if constexpr(WaveGroup == 0)
        {
            if constexpr(Phase == 0)
            {
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
                __builtin_amdgcn_sched_group_barrier(0x200, 2, 0);
                __builtin_amdgcn_sched_group_barrier(0x002, 2, 0);
            }
            else if constexpr(Phase == 1)
            {
                __builtin_amdgcn_sched_group_barrier(0x002, 2, 0);
                __builtin_amdgcn_sched_group_barrier(0x004, 4, 0);
            }
            else if constexpr(Phase == 2)
            {
#if !CK_TILE_DISABLE_PACKED_FP32
                __builtin_amdgcn_sched_group_barrier(0x002, 4, 0);
#else
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
                __builtin_amdgcn_sched_group_barrier(0x002, 4, 0);
#endif
            }
            else if constexpr(Phase == 3)
            {
                __builtin_amdgcn_sched_group_barrier(0x002, 2, 0);
                __builtin_amdgcn_sched_group_barrier(0x004, 4, 0);
            }
        }
        else
        {
            if constexpr(Phase == 0)
            {
                __builtin_amdgcn_sched_group_barrier(0x002, 2, 0);
                __builtin_amdgcn_sched_group_barrier(0x004, 4, 0);
            }
            else if constexpr(Phase == 1)
            {
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
                __builtin_amdgcn_sched_group_barrier(0x200, 2, 0);
                __builtin_amdgcn_sched_group_barrier(0x002, 2, 0);
            }
            else if constexpr(Phase == 2)
            {
                __builtin_amdgcn_sched_group_barrier(0x002, 2, 0);
                __builtin_amdgcn_sched_group_barrier(0x004, 4, 0);
            }
            else if constexpr(Phase == 3)
            {
#if !CK_TILE_DISABLE_PACKED_FP32
                __builtin_amdgcn_sched_group_barrier(0x002, 4, 0);
#else
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
                __builtin_amdgcn_sched_group_barrier(0x002, 4, 0);
#endif
            }
        }
    }
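    // Added note, not in the original source: per the LLVM AMDGPU builtin docs,
    // __builtin_amdgcn_sched_group_barrier(mask, size, sync_id) asks the scheduler
    // to place `size` instructions matching `mask` at this point, in order. Masks
    // used above: 0x002 = VALU, 0x004 = SALU, 0x008 = MFMA/WMMA, 0x200 = LDS write.
    // E.g. Phase 0 for WaveGroup 0 requests 1 MFMA, then 2 LDS writes, then 2 VALU
    // instructions per barrier group.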
template <typename PipelineProblem>
// ... (the other kIsMasking specialization)

    template <ck_tile::index_t WaveGroup, ck_tile::index_t Phase>
    static constexpr CK_TILE_DEVICE void schedule(ck_tile::number<WaveGroup>, ck_tile::number<Phase>)
    {
        if constexpr(WaveGroup == 0)
        {
            if constexpr(Phase == 0)
            {
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
                __builtin_amdgcn_sched_group_barrier(0x200, 2, 0);
                __builtin_amdgcn_sched_group_barrier(0x002, 2, 0);
            }
            else if constexpr(Phase == 1)
            {
                __builtin_amdgcn_sched_group_barrier(0x002, 2, 0);
                __builtin_amdgcn_sched_group_barrier(0x004, 4, 0);
            }
            else if constexpr(Phase == 2)
            {
#if !CK_TILE_DISABLE_PACKED_FP32
                __builtin_amdgcn_sched_group_barrier(0x002, 4, 0);
#else
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
                __builtin_amdgcn_sched_group_barrier(0x002, 4, 0);
#endif
            }
            else if constexpr(Phase == 3)
            {
                __builtin_amdgcn_sched_group_barrier(0x002, 2, 0);
                __builtin_amdgcn_sched_group_barrier(0x004, 4, 0);
            }
        }
        else
        {
            if constexpr(Phase == 0)
            {
                __builtin_amdgcn_sched_group_barrier(0x002, 2, 0);
                __builtin_amdgcn_sched_group_barrier(0x004, 4, 0);
            }
            else if constexpr(Phase == 1)
            {
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
                __builtin_amdgcn_sched_group_barrier(0x200, 2, 0);
                __builtin_amdgcn_sched_group_barrier(0x002, 2, 0);
            }
            else if constexpr(Phase == 2)
            {
                __builtin_amdgcn_sched_group_barrier(0x002, 2, 0);
                __builtin_amdgcn_sched_group_barrier(0x004, 4, 0);
            }
            else if constexpr(Phase == 3)
            {
#if !CK_TILE_DISABLE_PACKED_FP32
                __builtin_amdgcn_sched_group_barrier(0x002, 4, 0);
#else
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
                __builtin_amdgcn_sched_group_barrier(0x002, 4, 0);
#endif
            }
        }
    }
CK_TILE_DEVICE float fma_impl_vsv(float a, float b, float c)
{
#if CK_TILE_DISABLE_PACKED_FP32
    return __builtin_fmaf(a, b, c);
#else
    float result;
    asm volatile("v_fma_f32 %[result], %[a], %[b], %[c]"
                 : [result] "=v"(result)
                 : [a] "v"(a), [b] "s"(b), [c] "v"(c));
    return result;
#endif
}
CK_TILE_DEVICE float add_impl_vv(float lhs, float rhs)
{
    float result;
    asm volatile("v_add_f32_e32 %[result], %[lhs], %[rhs]"
                 : [result] "=v"(result)
                 : [lhs] "v"(lhs), [rhs] "v"(rhs));
    return result;
}
CK_TILE_DEVICE float mul_impl_vv(float lhs, float rhs)
{
    float result;
    asm volatile("v_mul_f32_e32 %[result], %[lhs], %[rhs]"
                 : [result] "=v"(result)
                 : [lhs] "v"(lhs), [rhs] "v"(rhs));
    return result;
}
CK_TILE_DEVICE fp16x2_t cvt_pk_fp16_f32(float a, float b)
{
    fp16x2_t result;
    asm volatile("v_cvt_pk_f16_f32 %[result], %[a], %[b]"
                 : [result] "=v"(result)
                 : [a] "v"(a), [b] "v"(b));
    return result;
}
CK_TILE_DEVICE bf16x2_t cvt_pk_bf16_f32(float a, float b)
{
    bf16x2_t result;
    asm volatile("v_cvt_pk_bf16_f32 %[result], %[a], %[b]"
                 : [result] "=v"(result)
                 : [a] "v"(a), [b] "v"(b));
    return result;
}
CK_TILE_DEVICE fp32x2_t pk_mul_f32(fp32x2_t lhs, fp32x2_t rhs)
{
    fp32x2_t result;
    asm volatile("v_pk_mul_f32 %[result], %[lhs], %[rhs]"
                 : [result] "=v"(result)
                 : [lhs] "v"(lhs), [rhs] "v"(rhs));
    return result;
}
template <typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
struct BlockFmhaFwdV3Pipeline
{
    using Problem             = ck_tile::remove_cvref_t<Problem_>;
    using Policy              = ck_tile::remove_cvref_t<Policy_>;
    using QDataType           = ck_tile::remove_cvref_t<typename Problem::QDataType>;
    using KDataType           = ck_tile::remove_cvref_t<typename Problem::KDataType>;
    using VDataType           = ck_tile::remove_cvref_t<typename Problem::VDataType>;
    using SaccDataType        = ck_tile::remove_cvref_t<typename Problem::SaccDataType>;
    using SMPLComputeDataType = ck_tile::remove_cvref_t<typename Problem::SMPLComputeDataType>;
    using LSEDataType         = ck_tile::remove_cvref_t<typename Problem::LSEDataType>;
    using PDataType           = ck_tile::remove_cvref_t<typename Problem::PDataType>;
    using OaccDataType        = ck_tile::remove_cvref_t<typename Problem::OaccDataType>;
    using ODataType           = ck_tile::remove_cvref_t<typename Problem::ODataType>;
    using FmhaMask            = ck_tile::remove_cvref_t<typename Problem::FmhaMask>;
    using BlockFmhaShape      = ck_tile::remove_cvref_t<typename Problem::BlockFmhaShape>;
    using VLayout             = remove_cvref_t<typename BlockFmhaShape::VLayout>;
    static_assert(is_generic_attention_mask_v<FmhaMask>);

    static_assert(std::is_same_v<SaccDataType, SMPLComputeDataType>,
                  "we will use the same dist tensor 'sp_compute' for both gemm0 & softmax");

    static_assert(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>);

    // ... (block-shape constants kM0/kN0/kK0/kN1/kK1, kQKHeaddim, kSubQKHeaddim)

    static_assert(kQKHeaddim == 128 && kSubQKHeaddim == 128, "only supports hdim=hdim_v=128");
    static constexpr bool kIsGroupMode      = Problem::kIsGroupMode;
    static constexpr bool kPadSeqLenQ       = Problem::kPadSeqLenQ;
    static constexpr bool kPadSeqLenK       = Problem::kPadSeqLenK;
    static constexpr bool kPadHeadDimQ      = Problem::kPadHeadDimQ;
    static constexpr bool kPadHeadDimV      = Problem::kPadHeadDimV;
    static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
    static constexpr auto BiasEnum          = Problem::BiasEnum;
    static constexpr bool kStoreLSE         = Problem::kStoreLSE;
    static constexpr bool kHasDropout       = Problem::kHasDropout;
    static constexpr auto QScaleEnum        = Problem::QScaleEnum;
    static constexpr bool kSkipMinSeqlenQ   = Problem::kSkipMinSeqlenQ;

    static_assert(/* ... */ !kStoreLSE && !kHasDropout &&
                      /* ... */,
                  "enable unsupported features");
    static constexpr index_t kAlignmentQ =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
    static constexpr index_t kAlignmentK =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
    static constexpr index_t kAlignmentV =
        kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
    static constexpr index_t kAlignmentO =
        kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
    static constexpr index_t kBlockPerCu = []() {
        if constexpr(Problem::kBlockPerCu != -1)
            return Problem::kBlockPerCu;
        // ... (default otherwise)
    }();

    static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
    {
        return Policy::template GetSmemSize<Problem>() +
               /* ... */;
    }
    template <ck_tile::index_t MPerBlock, ck_tile::index_t NPerBlock>
    static constexpr CK_TILE_DEVICE auto MakeSimpleLdsDesc()
    {
        constexpr auto lds_block_desc =
            make_naive_tensor_descriptor(make_tuple(number<MPerBlock>{}, number<NPerBlock>{}),
                                         make_tuple(number<NPerBlock>{}, number<1>{}));

        return lds_block_desc;
    }

    template <ck_tile::index_t MPerBlock>
    static constexpr CK_TILE_DEVICE auto MakeSimpleLdsDesc1D()
    {
        constexpr auto lds_block_desc = make_naive_tensor_descriptor(
            make_tuple(number<MPerBlock>{}), make_tuple(number<1>{}));

        return lds_block_desc;
    }
    template <typename DataType, typename Descriptor>
    static constexpr CK_TILE_DEVICE auto make_lds_tile_window(void* base, const Descriptor& desc)
    {
        auto view =
            make_tensor_view<address_space_enum::lds>(reinterpret_cast<DataType*>(base), desc);
        // ... (return a tile window over `view`)
    }
    template <uint16_t Vmcnt, uint8_t Lgkmcnt, uint8_t Expcnt = 7>
    static constexpr CK_TILE_DEVICE void s_waitcnt()
    {
        __builtin_amdgcn_s_waitcnt((((0b110000 & Vmcnt) << (14 - 4)) | (0b1111 & Vmcnt)) |
                                   ((0b111 & Expcnt) << 4) | ((0b1111 & Lgkmcnt) << 8));
    }

    template <uint16_t Vmcnt>
    static constexpr CK_TILE_DEVICE void s_waitcnt_vmcnt()
    {
        s_waitcnt<Vmcnt, 15>();
    }

    template <uint8_t Lgkmcnt>
    static constexpr CK_TILE_DEVICE void s_waitcnt_lgkmcnt()
    {
        s_waitcnt<63, Lgkmcnt>();
    }
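    // Added worked example (not in the original source): the bit packing above
    // matches the gfx9-class s_waitcnt encoding -- vmcnt split across bits [3:0]
    // and [15:14], expcnt in [6:4], lgkmcnt in [11:8]. So s_waitcnt_vmcnt<0>()
    // packs vmcnt=0, expcnt=7, lgkmcnt=15 into 0x0F70: "wait until no
    // vector-memory ops are outstanding, don't wait on the export or LDS/SGPR
    // counters". Likewise s_waitcnt_lgkmcnt<0>() waits only on LDS traffic.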
    template <typename QDramBlockWindowTmp,
              typename KDramBlockWindowTmp,
              typename VDramBlockWindowTmp,
              typename LSEDramBlockWindowTmp,
              typename QElementFunction,
              typename KElementFunction,
              typename VElementFunction,
              typename LSEElementFunction,
              typename SAccElementFunction,
              typename PComputeElementFunction,
              typename OAccElementFunction>
    CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
                                   const QElementFunction& q_element_func,
                                   const KDramBlockWindowTmp& k_dram_block_window_tmp,
                                   [[maybe_unused]] const KElementFunction& k_element_func,
                                   const VDramBlockWindowTmp& v_dram_block_window_tmp,
                                   [[maybe_unused]] const VElementFunction& v_element_func,
                                   LSEDramBlockWindowTmp& lse_dram_window_tmp,
                                   const LSEElementFunction& lse_element_func,
                                   [[maybe_unused]] const SAccElementFunction& s_acc_element_func,
                                   const PComputeElementFunction& p_compute_element_func,
                                   const OAccElementFunction& o_acc_element_func,
                                   FmhaMask mask,
                                   float scale_s,
                                   void* smem_ptr) const
    {
        static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
                          kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
                      "wrong!");

        static_assert(sizeof(SaccDataType) * kM0 * kN0 <= GetSmemSize());
        auto s_lds = make_tensor_view<address_space_enum::lds>(
            reinterpret_cast<SaccDataType*>(static_cast<char*>(smem_ptr)),
            MakeSimpleLdsDesc<kM0, kN0>());
        [[maybe_unused]] auto s_lds_window =
            make_tile_window(s_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});

        auto p_lds = make_tensor_view<address_space_enum::lds>(
            reinterpret_cast<PDataType*>(static_cast<char*>(smem_ptr) +
                                         Policy::template GetSmemSize<Problem>()),
            MakeSimpleLdsDesc<kM0, kN0>());
        [[maybe_unused]] auto p_lds_window =
            make_tile_window(p_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});

        auto o_lds = make_tensor_view<address_space_enum::lds>(
            reinterpret_cast<PDataType*>(static_cast<char*>(smem_ptr)),
            MakeSimpleLdsDesc<kM0, kN1>());
        [[maybe_unused]] auto o_lds_window =
            make_tile_window(o_lds, make_tuple(number<kM0>{}, number<kN1>{}), {0, 0});

        auto m_lds = make_tensor_view<address_space_enum::lds>(
            reinterpret_cast<SMPLComputeDataType*>(static_cast<char*>(smem_ptr) +
                                                   Policy::template GetSmemSize<Problem>()),
            MakeSimpleLdsDesc1D<kM0>());
        [[maybe_unused]] auto m_lds_window =
            make_tile_window(m_lds, make_tuple(number<kM0>{}), {0});
        const index_t warp_group_id = get_warp_id() / 4;
        constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
        constexpr auto gemm_1 = Policy::template GetPVBlockGemm<Problem>();

        auto q_dram_window = make_tile_window(
            q_dram_block_window_tmp, Policy::template MakeQRegTileDistribution<Problem>());
        const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
        const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
        auto k_lds_window_store = [&](auto i_buf) {
            return make_lds_tile_window<KDataType>(
                smem_ptr, Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf));
        };
        auto v_lds_window_store = [&](auto i_buf) {
            return make_lds_tile_window<KDataType>(
                smem_ptr, Policy::template MakeVLdsStoreBlockDescriptor<Problem>(i_buf));
        };

        statically_indexed_array<
            decltype(make_tile_window(
                make_lds_tile_window<KDataType>(
                    smem_ptr, Policy::template MakeKLdsLoadBlockDescriptor<Problem>()),
                Policy::template MakeKRegTileDistribution<Problem>())),
            2>
            k_lds_window_load;
        statically_indexed_array<
            decltype(make_tile_window(
                make_lds_tile_window<VDataType>(
                    smem_ptr, Policy::template MakeVLdsLoadBlockDescriptor<Problem>()),
                Policy::template MakeVRegTileDistribution<Problem>())),
            2>
            v_lds_window_load;

        decltype(make_static_distributed_tensor<QDataType>(
            Policy::template MakeQRegTileDistribution<Problem>())) q_tile;
        // ... (K/V register tile container `kv_tile` declared here)

        union sp_compute_type
        {
            decltype(gemm_0.MakeCBlockTile()) sp_compute;
            decltype(make_static_distributed_tensor<PDataType>(
                Policy::template MakePRegTileDistribution<Problem>())) p;
        };
        statically_indexed_array<sp_compute_type, 2> sp;

        decltype(gemm_1.MakeCBlockTile()) o_acc;
        constexpr index_t fmha_alu_D_reg_cnt = 6;
        static_assert(fmha_alu_D_reg_cnt <= o_acc.thread_buf_.size());
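        // Added note (interpretation, not from the source): the union overlays the
        // fp32 gemm-0 accumulator (sp_compute) with the packed fp16/bf16 gemm-1
        // operand (p), so the softmax result can be converted in place and both
        // tiles share one set of VGPRs; the two `sp` entries act as the ping-pong
        // buffers the core loop alternates between.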
        decltype(block_tile_reduce<SMPLComputeDataType>(
            sp(number<0>{}).sp_compute, sequence<1>{}, f_max, SMPLComputeDataType{0})) m, m_old,
            l;
        float o_acc_scale;

        static_for<0, 2, 1>{}([&](auto idx) {
            k_lds_window_load(idx) = make_tile_window(
                make_lds_tile_window<KDataType>(
                    static_cast<char*>(smem_ptr) + (idx)*Policy::template GetSmemSizeKV<Problem>(),
                    Policy::template MakeKLdsLoadBlockDescriptor<Problem>()),
                Policy::template MakeKRegTileDistribution<Problem>());

            v_lds_window_load(idx) = make_tile_window(
                make_lds_tile_window<VDataType>(
                    static_cast<char*>(smem_ptr) +
                        (idx + 2) * Policy::template GetSmemSizeKV<Problem>(),
                    Policy::template MakeVLdsLoadBlockDescriptor<Problem>()),
                Policy::template MakeVRegTileDistribution<Problem>());
        });
        auto origin_q            = load_tile(q_dram_window);
        const auto transformed_q = tile_elementwise_in(q_element_func, origin_q);
        q_tile                   = transformed_q;

        clear_tile(o_acc);
        clear_tile(l);
        set_tile(m, bit_cast<float>(0xff7fffff)); // start the running max at -FLT_MAX
        const auto q_origin = q_dram_window.get_window_origin();
        const auto [seqlen_k_start, seqlen_k_end] =
            mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});

        const index_t num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
        index_t i_total_loops        = 0;
        index_t kv_token_start       = seqlen_k_start;

        // check early exit if masked and there is no work to do
        if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
        {
            if(num_total_loop <= 0)
            {
                if constexpr(kStoreLSE)
                {
                    auto lse =
                        make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
                    // ... (store -inf LSE)
                }
                // ... (store a zero-filled O tile and return)
            }
        }

        auto k_dram_window = make_tile_window(
            k_dram_block_window_tmp.get_bottom_tensor_view(),
            k_dram_block_window_tmp.get_window_lengths(),
            {seqlen_k_start, 0},
            Policy::template MakeKDramTileDistribution<Problem>());
        k_dram_window.init_raw();

        auto v_dram_window = make_tile_window(
            v_dram_block_window_tmp.get_bottom_tensor_view(),
            v_dram_block_window_tmp.get_window_lengths(),
            {seqlen_k_start, 0},
            Policy::template MakeVDramTileDistribution<Problem>());
        v_dram_window.init_raw();
        constexpr index_t k0_loops = kQKHeaddim / kK0;
        constexpr index_t k1_loops = kN0 / kK1;
        static_assert(1 == k0_loops);
        static_assert(1 == k1_loops);
        static_assert(kN0 == kK1);

        constexpr index_t NumWarpGroups = Problem::kBlockSize / Policy::NumThreadPerWarpGroup;
        static_assert(NumWarpGroups == 2);
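        // Added note: the asserts above pin this pipeline to single-subtile gemms
        // (the whole head dimension fits one kK0 slice and the whole kN0 tile fits
        // one kK1 slice), so each gemm dispatch below issues a complete gemm, and
        // to exactly two warp groups (e.g. a 512-thread block, 4 waves per group,
        // consistent with warp_group_id = get_warp_id() / 4 above).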
        [[maybe_unused]] auto print_dist_tensor = [&](const auto& dist_tensor,
                                                      const char* name) {
            printf("[POYENC] %s (size=%d): %5.2f",
                   name,
                   decltype(dist_tensor.thread_buf_)::size(),
                   ck_tile::type_convert<float>(dist_tensor.thread_buf_[0]));
            static_for<1, decltype(dist_tensor.thread_buf_)::size(), 1>{}([&](auto i) {
                printf(", %5.2f", ck_tile::type_convert<float>(dist_tensor.thread_buf_[i]));
            });
            printf("\n");
        };
        [[maybe_unused]] auto print_lds = [&](auto lds_tile_window, const char* name) {
            const auto num_rows = lds_tile_window.get_window_lengths().at(number<0>{});
            const auto num_cols = lds_tile_window.get_window_lengths().at(number<1>{});

            auto desc = lds_tile_window.get_bottom_tensor_view().desc_;
            auto data = lds_tile_window.get_bottom_tensor_view().buf_.p_data_;

            if constexpr(true || num_rows < num_cols)
            {
                for(int row = 0; row < num_rows; ++row)
                {
                    const auto offset = desc.calculate_offset(make_tuple(row, 0));
                    printf("[DEVICE] %s[%3d] = %5.2f",
                           name,
                           row,
                           ck_tile::type_convert<float>(data[offset]));
                    for(int col = 1; col < num_cols; ++col)
                    {
                        const auto offset = desc.calculate_offset(make_tuple(row, col));
                        printf("%5.2f", ck_tile::type_convert<float>(data[offset]));
                    }
                    printf("\n");
                }
            }
            else
            {
                for(int col = 0; col < num_cols; ++col)
                {
                    const auto offset = desc.calculate_offset(make_tuple(0, col));
                    printf("[DEVICE] %s[%3d] = %5.2f",
                           name,
                           col,
                           ck_tile::type_convert<float>(data[offset]));
                    for(int row = 1; row < num_rows; ++row)
                    {
                        const auto offset = desc.calculate_offset(make_tuple(row, col));
                        printf("%5.2f", ck_tile::type_convert<float>(data[offset]));
                    }
                    printf("\n");
                }
            }
        };
        [[maybe_unused]] auto print_lds_1d = [&](auto lds_tile_window, const char* name) {
            const auto num_elems = lds_tile_window.get_window_lengths().at(number<0>{});

            auto desc = lds_tile_window.get_bottom_tensor_view().desc_;
            auto data = lds_tile_window.get_bottom_tensor_view().buf_.p_data_;

            const auto offset = desc.calculate_offset(make_tuple(0));
            printf("[DEVICE] %s = %5.2f", name, ck_tile::type_convert<float>(data[offset]));
            for(int e = 1; e < num_elems; ++e)
            {
                const auto offset = desc.calculate_offset(make_tuple(e));
                printf("%5.2f", ck_tile::type_convert<float>(data[offset]));
            }
            printf("\n");
        };
        constexpr int K_mem_su_ld_insts = k_dram_window.get_num_of_access();
        constexpr int V_mem_su_ld_insts = v_dram_window.get_num_of_access();

        auto K_mem_load = [&](auto k_lds_write_idx) {
            // ... (async-copy the next K tile into LDS buffer `k_lds_write_idx`,
            //      then advance k_dram_window)
        };
        auto K_lds_load = [&](auto k_lds_read_idx) {
            kv_tile.k_tile = load_tile(k_lds_window_load(k_lds_read_idx));
        };
        auto V_mem_load = [&](auto v_lds_write_idx) {
            // ... (async-copy the next V tile into LDS buffer `v_lds_write_idx`,
            //      then advance v_dram_window)
        };
        auto V_lds_load = [&](auto v_lds_read_idx) {
            kv_tile.v_tile = load_tile(v_lds_window_load(v_lds_read_idx));
        };
        auto fmha_alu0 = [&](auto sp_reg_idx) {
            // rowmax
            static_assert(m.thread_buf_.size() == 1,
                          "assuming that each thread holds 1 rowmax value");
            auto m_latest = block_tile_reduce<SMPLComputeDataType>(
                sp(sp_reg_idx).sp_compute, sequence<1>{}, f_max, m.thread_buf_[0]);
#if defined(__gfx950__)
            const auto swapped_regs = bit_cast<int32x2_t>(
                __builtin_amdgcn_permlane32_swap(bit_cast<int32_t>(m_latest.thread_buf_[0]),
                                                 bit_cast<int32_t>(m_latest.thread_buf_[0]),
                                                 false,
                                                 false));
            m_latest.thread_buf_[0] = f_max(bit_cast<SMPLComputeDataType>(swapped_regs.x),
                                            bit_cast<SMPLComputeDataType>(swapped_regs.y));
#else
            block_tile_reduce_sync(m_latest, f_max, bool_constant<false>{});
#endif
            m_old = m;
            m     = m_latest;

            constexpr auto p_spans =
                std::decay_t<decltype(sp(sp_reg_idx).sp_compute)>::get_distributed_spans();
            sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
                sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
                    sp(sp_reg_idx).sp_compute(i_j_idx) = fma_impl_vsv(
                        sp(sp_reg_idx).sp_compute(i_j_idx), scale_s, -scale_s * m(i_j_idx));
                });
            });
        };
        auto fmha_alu1 = [&](auto sp_reg_idx) {
            constexpr auto p_spans =
                std::decay_t<decltype(sp(sp_reg_idx).sp_compute)>::get_distributed_spans();
            sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
                sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
                    sp(sp_reg_idx).sp_compute(i_j_idx) =
                        ck_tile::exp2(sp(sp_reg_idx).sp_compute(i_j_idx));
                });
            });

            // rowsum
            auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
                sp(sp_reg_idx).sp_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0});
            static_assert(rowsum_p.thread_buf_.size() == 1,
                          "assuming that each thread holds 1 rowsum value");
#if defined(__gfx950__)
            const auto swapped_regs = bit_cast<int32x2_t>(
                __builtin_amdgcn_permlane32_swap(bit_cast<int32_t>(rowsum_p.thread_buf_[0]),
                                                 bit_cast<int32_t>(rowsum_p.thread_buf_[0]),
                                                 false,
                                                 false));
            rowsum_p.thread_buf_[0] = f_sum(bit_cast<SMPLComputeDataType>(swapped_regs.x),
                                            bit_cast<SMPLComputeDataType>(swapped_regs.y));
#else
            block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
#endif
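            // Added note (interpretation): on gfx950, v_permlane32_swap_b32
            // exchanges the upper 32 lanes of one operand with the lower 32 lanes
            // of the other. Passing the same VGPR for both sides hands each lane
            // its partner's partial result from the other half of the 64-lane
            // wave, so a single f_max / f_sum combines the two halves without an
            // LDS round-trip; other targets presumably take the generic
            // block_tile_reduce_sync path instead.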
            // l / o_acc rescale
            constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
            sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);
                const auto tmp       = ck_tile::exp2(scale_s * (m_old[i_idx] - m[i_idx]));
                l(i_idx)             = tmp * l[i_idx] + rowsum_p[i_idx];
                // ...
            });

            // convert sp_compute (fp32) in place into the packed PDataType tile `p`
            static_assert(sp(sp_reg_idx).p.thread_buf_.size() % 2 == 0);
            static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 2>{}([&](auto idx) {
                float x = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx]);
                float y = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 1]);
                if constexpr(std::is_same_v<PDataType, fp16_t>)
                {
                    const auto casted = cvt_pk_fp16_f32(x, y);

                    sp(sp_reg_idx).p.thread_buf_[idx]     = casted.x;
                    sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y;
                }
                else
                {
                    const auto casted = cvt_pk_bf16_f32(x, y);

                    sp(sp_reg_idx).p.thread_buf_[idx]     = casted.x;
                    sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y;
                }
            });
        };
        auto gemm = [&](auto sp_reg_idx, auto gemm_idx) {
            if constexpr(gemm_idx == 0)
            {
                gemm_0(sp(sp_reg_idx).sp_compute,
                       get_slice_tile(q_tile,
                                      sequence<0, (k0_loops - 1) * kK0>{},
                                      sequence<kM0, k0_loops * kK0>{}),
                       get_slice_tile(kv_tile.k_tile,
                                      sequence<0, (k0_loops - 1) * kK0>{},
                                      sequence<kN0, k0_loops * kK0>{}));
            }
            else
            {
                gemm_1(o_acc,
                       get_slice_tile(sp(sp_reg_idx).p,
                                      sequence<0, (k1_loops - 1) * kK1>{},
                                      sequence<kM0, k1_loops * kK1>{}),
                       get_slice_tile(kv_tile.v_tile,
                                      sequence<0, (k1_loops - 1) * kK1>{},
                                      sequence<kN1, k1_loops * kK1>{}));
            }
        };

        auto cl_calc = [&](auto sp_reg_idx, auto gemm_idx) {
            if constexpr(gemm_idx == 0)
            {
                gemm_0(sp(sp_reg_idx).sp_compute,
                       get_slice_tile(q_tile,
                                      sequence<0, (k0_loops - 1) * kK0>{},
                                      sequence<kM0, k0_loops * kK0>{}),
                       get_slice_tile(kv_tile.k_tile,
                                      sequence<0, (k0_loops - 1) * kK0>{},
                                      sequence<kN0, k0_loops * kK0>{}));
            }
            else
            {
                gemm_1(o_acc,
                       get_slice_tile(sp(sp_reg_idx).p,
                                      sequence<0, (k1_loops - 1) * kK1>{},
                                      sequence<kM0, k1_loops * kK1>{}),
                       get_slice_tile(kv_tile.v_tile,
                                      sequence<0, (k1_loops - 1) * kK1>{},
                                      sequence<kN1, k1_loops * kK1>{}));
            }
        };
        auto fmha_alu_D_upd = [&] {
            o_acc_scale = ck_tile::exp2(scale_s * (m_old.thread_buf_[0] - m.thread_buf_[0]));

            fp32x2_t pk_o_acc_scale;
            pk_o_acc_scale.x = o_acc_scale;
            pk_o_acc_scale.y = o_acc_scale;

            static_assert((o_acc.thread_buf_.size() - fmha_alu_D_reg_cnt) % 2 == 0);
#if CK_TILE_DISABLE_PACKED_FP32
            static_assert(fmha_alu_D_reg_cnt + 2 <= o_acc.thread_buf_.size());
            static_for<fmha_alu_D_reg_cnt, fmha_alu_D_reg_cnt + 2, 1>{}(
                [&](auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; });
#endif

            constexpr auto issued_D_reg_cnt =
#if CK_TILE_DISABLE_PACKED_FP32
                fmha_alu_D_reg_cnt + 2
#else
                fmha_alu_D_reg_cnt
#endif
                ;
            static_for<issued_D_reg_cnt, o_acc.thread_buf_.size(), 2>{}([&](auto idx) {
                fp32x2_t input;
                input.x = o_acc.thread_buf_[idx];
                input.y = o_acc.thread_buf_[idx + 1];

                const auto output = pk_mul_f32(input, pk_o_acc_scale);

                o_acc.thread_buf_[idx]     = output.x;
                o_acc.thread_buf_[idx + 1] = output.y;
            });
        };
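        // Added arithmetic note (sizes are illustrative): with, say, a 16-register
        // o_acc and fmha_alu_D_reg_cnt = 6, the packed loop rescales registers
        // [6, 16) with five v_pk_mul_f32 issues; the first six registers appear to
        // be handled by scalar phases elsewhere in the loop, and when
        // CK_TILE_DISABLE_PACKED_FP32 is set, two extra scalar multiplies cover
        // indices 6 and 7 before the packed loop starts at index 8.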
        auto fmha_mask = [&](auto sp_reg_idx) {
            if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
            {
                bool need_perpixel_check = mask.IsEdgeTile(
                    q_origin.at(number<0>{}), kv_token_start, number<kM0>{}, number<kN0>{});
                if(need_perpixel_check)
                {
                    set_tile_if(sp(sp_reg_idx).sp_compute,
                                -numeric<SMPLComputeDataType>::infinity(),
                                [&](auto tile_idx) {
                                    const auto row =
                                        q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
                                    const auto col = kv_token_start + tile_idx.at(number<1>{});
                                    return mask.IsOutOfBound(row, col);
                                });
                }
            }
        };
        auto cl_load = [&](auto load_type, auto mem_wr_idx, auto lds_rd_idx) {
            if constexpr(load_type == 0)
            {
                V_mem_load(mem_wr_idx);
                K_lds_load(lds_rd_idx);
            }
            else
            {
                K_mem_load(mem_wr_idx);
                V_lds_load(lds_rd_idx);
            }
        };
        constexpr auto memV  = number<0>{};
        constexpr auto memK  = number<1>{};
        constexpr auto gemm0 = number<0>{};
        constexpr auto gemm1 = number<1>{};
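        // Added note (interpretation): each cl_load pairs a memory prefetch with
        // the opposite operand's LDS read -- e.g. cl_load(memK, wr, rd) starts the
        // async DRAM->LDS copy of a future K tile while pulling the current V tile
        // from LDS into registers, so global and LDS traffic overlap with the MFMA
        // phases between the barriers below.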
        auto core_loop = [&](auto cl_p) {
            // ...
            auto iteration = [&](auto pi) {
                auto xdl_SP_p01_reg_idx = number<1>{} - pi;
                auto xdl_SP_p23_reg_idx = pi;

                // ... (K_w0_lds_wr_idx defined here)
                auto V_w0_lds_wr_idx = pi;
                auto K_w0_lds_rd_idx = pi;
                auto V_w0_lds_rd_idx = pi;
                // ... (w4 write/read indices for the second wave group)
                auto V_w4_lds_rd_idx = pi;

                if constexpr(cl_p == 0)
                {
#if ADD_SBARRIER_FOR_PHASE0
                    __builtin_amdgcn_sched_barrier(0);
                    __builtin_amdgcn_s_barrier();
#endif
                    __builtin_amdgcn_sched_barrier(0);

                    if constexpr(pi == 0)
                    {
                        // ...
                    }

                    s_waitcnt_lgkmcnt<0>();
                    __builtin_amdgcn_sched_barrier(0);
                    cl_calc(xdl_SP_p01_reg_idx, gemm0);
                    fmha_alu1(xdl_SP_p23_reg_idx);
                    // ...
                    __builtin_amdgcn_sched_barrier(0);

                    s_waitcnt_vmcnt<K_mem_su_ld_insts + V_mem_su_ld_insts>();
                    __builtin_amdgcn_sched_barrier(0);
                    __builtin_amdgcn_s_barrier();
                    __builtin_amdgcn_sched_barrier(0);
                    cl_load(memK, K_w0_lds_wr_idx, V_w0_lds_rd_idx);
                    // ...
                    fmha_mask(xdl_SP_p01_reg_idx);
                    // ...
                    __builtin_amdgcn_sched_barrier(0);

                    s_waitcnt_lgkmcnt<0>();
                    __builtin_amdgcn_sched_barrier(0);
                    __builtin_amdgcn_s_barrier();
                    __builtin_amdgcn_sched_barrier(0);
                    asm volatile("s_nop 0");
                    __builtin_amdgcn_sched_barrier(0);
                    cl_calc(xdl_SP_p23_reg_idx, gemm1);
                    // ...
                    __builtin_amdgcn_sched_barrier(0);
                    // ...
                    __builtin_amdgcn_sched_barrier(0);

                    s_waitcnt_vmcnt<K_mem_su_ld_insts + V_mem_su_ld_insts>();
                    __builtin_amdgcn_sched_barrier(0);
                    __builtin_amdgcn_s_barrier();
                    __builtin_amdgcn_sched_barrier(0);
                    cl_load(memV, V_w0_lds_wr_idx, K_w0_lds_rd_idx);
                    // ...
                    kv_token_start += kN0;
                    if(num_total_loop <= ++i_total_loops)
                    {
                        // ... (leave the loop)
                    }
                }
                else
                {
#if ADD_SBARRIER_FOR_PHASE0
                    __builtin_amdgcn_sched_barrier(0);
                    __builtin_amdgcn_s_barrier();
#endif
                    __builtin_amdgcn_sched_barrier(0);

                    if constexpr(pi == 0)
                    {
                        // ...
                    }

                    cl_load(memV, V_w4_lds_wr_idx, K_w4_lds_rd_idx);
                    // ...
                    __builtin_amdgcn_sched_barrier(0);

                    s_waitcnt<K_mem_su_ld_insts + V_mem_su_ld_insts, 0>();
                    __builtin_amdgcn_sched_barrier(0);
                    __builtin_amdgcn_s_barrier();
                    __builtin_amdgcn_sched_barrier(0);
                    asm volatile("s_nop 1");
                    __builtin_amdgcn_sched_barrier(0);
                    cl_calc(xdl_SP_p01_reg_idx, gemm0);
                    fmha_alu1(xdl_SP_p23_reg_idx);
                    // ...
                    __builtin_amdgcn_sched_barrier(0);
                    // ...
                    __builtin_amdgcn_s_barrier();
                    __builtin_amdgcn_sched_barrier(0);
                    cl_load(memK, K_w4_lds_wr_idx, V_w4_lds_rd_idx);
                    // ...
                    fmha_mask(xdl_SP_p01_reg_idx);
                    // ...
                    kv_token_start += kN0;
                    if(num_total_loop <= ++i_total_loops)
                    {
                        // ... (leave the loop)
                    }
                    __builtin_amdgcn_sched_barrier(0);

                    s_waitcnt<K_mem_su_ld_insts + V_mem_su_ld_insts, 0>();
                    __builtin_amdgcn_sched_barrier(0);
                    __builtin_amdgcn_s_barrier();
                    __builtin_amdgcn_sched_barrier(0);
                    asm volatile("s_nop 1");
                    __builtin_amdgcn_sched_barrier(0);
                    cl_calc(xdl_SP_p23_reg_idx, gemm1);
                    // ...
                    __builtin_amdgcn_sched_barrier(0);
                }
            };
            // ... (drive `iteration` over pi = 0, 1 until the trip count is hit)
        };
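        // Added summary (an interpretation of the control flow above, not
        // authoritative): the loop ping-pongs between two register buffers (pi and
        // 1 - pi) and two LDS buffers so that one wave group's MFMA phases
        // (cl_calc) overlap the other group's softmax ALU work (fmha_alu1 /
        // fmha_mask) and its K/V prefetches (cl_load). Every phase follows the
        // same skeleton -- s_waitcnt on the counters it depends on, s_barrier to
        // line the wave groups up, then the phase's work -- with sched_barrier(0)
        // fences keeping the compiler from blending adjacent phases.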
        auto fmha_post_process = [&](auto d) {
            // ... (ps_pi derived from d)
            auto V_lds_rd_idx = ps_pi;
            // ...
            if(1 < num_total_loop)
            {
                s_waitcnt_vmcnt<K_mem_su_ld_insts>();
            }
            else
            {
                s_waitcnt_vmcnt<0>();
            }
            __builtin_amdgcn_s_barrier();

            V_lds_load(V_lds_rd_idx);
            // ...
            s_waitcnt_lgkmcnt<0>();

            auto xdl_SP_p23_reg_idx = ps_pi;
            // ... (final gemm_1 on the last tile)
        };

        // ...
        s_waitcnt_vmcnt<0>();
        __builtin_amdgcn_s_barrier();
        // ...
        s_waitcnt_lgkmcnt<0>();
        __builtin_amdgcn_s_barrier();
        // ...
        if(1 < num_total_loop)
        {
            // ...
        }
        kv_token_start += kN0;
        // ...
        if(num_total_loop <= i_total_loops)
        {
            goto label_main_loops_exit;
        }

        if(2 < num_total_loop)
        {
            // ...
            s_waitcnt_vmcnt<K_mem_su_ld_insts + V_mem_su_ld_insts>();
            __builtin_amdgcn_s_barrier();
        }
        // ...
        if(1 < num_total_loop)
        {
            if(warp_group_id == 0)
            {
                // ...
                __builtin_amdgcn_s_setprio(0);
                __builtin_amdgcn_s_barrier();
            }
            // ... (run core_loop for this wave group)
            if(warp_group_id != 0)
            {
                __builtin_amdgcn_s_setprio(1);
                __builtin_amdgcn_s_barrier();
            }
            // ...
        }

    label_main_loops_exit:
        if(num_total_loop % 2)
        {
            // ...
        }
        if(!(num_total_loop % 2))
        {
            // ...
        }
        if constexpr(kStoreLSE)
        {
            auto lse = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());

            constexpr auto lse_spans = decltype(lse)::get_distributed_spans();
            sweep_tile_span(lse_spans[number<0>{}], [&](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);
                lse(i_idx)           = m[i_idx] / C_LOG2E + log(l[i_idx]);
            });

            store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
        }
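        // Added math note: C_LOG2E is log2(e), so m / C_LOG2E equals m * ln(2);
        // combined with log(l), this converts the base-2 softmax bookkeeping above
        // (exp2-based rowmax m and rowsum l) into a natural-log LSE, the form
        // consumers of the LSE tensor expect.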
        // final o_acc normalization
        constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
        sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
            constexpr auto i_idx = make_tuple(idx0);
            const auto tmp       = [&]() {
                if constexpr(FmhaMask::IsMasking)
                {
                    return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
                }
                else
                {
                    return 1 / l[i_idx];
                }
            }();
            sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
                constexpr auto i_j_idx = make_tuple(idx0, idx1);
                o_acc(i_j_idx) *= tmp;
            });
        });

        o_acc = tile_elementwise_in(o_acc_element_func, o_acc);

        return o_acc;
    }
    template <typename QDramBlockWindowTmp,
              typename KDramBlockWindowTmp,
              typename VDramBlockWindowTmp,
              typename LSEDramBlockWindowTmp>
    CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
                                   const KDramBlockWindowTmp& k_dram_block_window_tmp,
                                   const VDramBlockWindowTmp& v_dram_block_window_tmp,
                                   LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
                                   FmhaMask mask,
                                   float scale_s,
                                   void* smem_ptr) const
    {
        return operator()(q_dram_block_window_tmp,
                          identity{},
                          k_dram_block_window_tmp,
                          identity{},
                          v_dram_block_window_tmp,
                          identity{},
                          lse_dram_block_window_tmp,
                          identity{},
                          identity{},
                          identity{},
                          identity{},
                          mask,
                          scale_s,
                          smem_ptr);
    }
};