12 #define ENABLE_ASM_MARKER 1
14 #define ASM_MARKER(marker) \
15 __builtin_amdgcn_sched_barrier(0); \
16 asm volatile("; [POYENC] " #marker); \
17 __builtin_amdgcn_sched_barrier(0);
19 #define ASM_MARKER(marker)
22 #define ADD_SBARRIER_FOR_PHASE0 1
23 #if !defined(CK_TILE_DISABLE_PACKED_FP32)
24 #define CK_TILE_DISABLE_PACKED_FP32 0
30 #define ENABLE_DEBUG_STMTS 1
31 #if ENABLE_DEBUG_STMTS
33 if(get_block_1d_id() == 0 && get_warp_id() == WARP_ID && get_lane_id() == LANE_ID)
35 #define DEBUG_STMTS if constexpr(false)
40 template <
typename PipelineProblem,
bool kIsMasking>
43 template <
typename PipelineProblem>
46 template <ck_tile::index_t WaveGroup, ck_tile::index_t Phase>
52 if constexpr(WaveGroup == 0)
54 if constexpr(Phase == 0)
57 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
58 __builtin_amdgcn_sched_group_barrier(0x200, 2, 0);
59 __builtin_amdgcn_sched_group_barrier(0x002, 2, 0);
62 else if constexpr(Phase == 1)
64 __builtin_amdgcn_sched_group_barrier(0x002, 2, 0);
65 __builtin_amdgcn_sched_group_barrier(0x004, 4, 0);
67 else if constexpr(Phase == 2)
69 #if !CK_TILE_DISABLE_PACKED_FP32
70 __builtin_amdgcn_sched_group_barrier(0x002, 4, 0);
73 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
74 __builtin_amdgcn_sched_group_barrier(0x002, 4, 0);
77 else if constexpr(Phase == 3)
79 __builtin_amdgcn_sched_group_barrier(0x002, 2, 0);
80 __builtin_amdgcn_sched_group_barrier(0x004, 4, 0);
85 if constexpr(Phase == 0)
87 __builtin_amdgcn_sched_group_barrier(0x002, 2, 0);
88 __builtin_amdgcn_sched_group_barrier(0x004, 4, 0);
90 else if constexpr(Phase == 1)
93 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
94 __builtin_amdgcn_sched_group_barrier(0x200, 2, 0);
95 __builtin_amdgcn_sched_group_barrier(0x002, 2, 0);
98 else if constexpr(Phase == 2)
100 __builtin_amdgcn_sched_group_barrier(0x002, 2, 0);
101 __builtin_amdgcn_sched_group_barrier(0x004, 4, 0);
103 else if constexpr(Phase == 3)
105 #if !CK_TILE_DISABLE_PACKED_FP32
106 __builtin_amdgcn_sched_group_barrier(0x002, 4, 0);
109 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
110 __builtin_amdgcn_sched_group_barrier(0x002, 4, 0);
117 template <
typename PipelineProblem>
120 template <ck_tile::index_t WaveGroup, ck_tile::index_t Phase>
126 if constexpr(WaveGroup == 0)
128 if constexpr(Phase == 0)
131 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
132 __builtin_amdgcn_sched_group_barrier(0x200, 2, 0);
133 __builtin_amdgcn_sched_group_barrier(0x002, 2, 0);
136 else if constexpr(Phase == 1)
138 __builtin_amdgcn_sched_group_barrier(0x002, 2, 0);
139 __builtin_amdgcn_sched_group_barrier(0x004, 4, 0);
141 else if constexpr(Phase == 2)
143 #if !CK_TILE_DISABLE_PACKED_FP32
144 __builtin_amdgcn_sched_group_barrier(0x002, 4, 0);
147 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
148 __builtin_amdgcn_sched_group_barrier(0x002, 4, 0);
151 else if constexpr(Phase == 3)
153 __builtin_amdgcn_sched_group_barrier(0x002, 2, 0);
154 __builtin_amdgcn_sched_group_barrier(0x004, 4, 0);
159 if constexpr(Phase == 0)
161 __builtin_amdgcn_sched_group_barrier(0x002, 2, 0);
162 __builtin_amdgcn_sched_group_barrier(0x004, 4, 0);
164 else if constexpr(Phase == 1)
167 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
168 __builtin_amdgcn_sched_group_barrier(0x200, 2, 0);
169 __builtin_amdgcn_sched_group_barrier(0x002, 2, 0);
172 else if constexpr(Phase == 2)
174 __builtin_amdgcn_sched_group_barrier(0x002, 2, 0);
175 __builtin_amdgcn_sched_group_barrier(0x004, 4, 0);
177 else if constexpr(Phase == 3)
179 #if !CK_TILE_DISABLE_PACKED_FP32
180 __builtin_amdgcn_sched_group_barrier(0x002, 4, 0);
183 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
184 __builtin_amdgcn_sched_group_barrier(0x002, 4, 0);
194 #if CK_TILE_DISABLE_PACKED_FP32
198 asm volatile(
"v_fma_f32 %[result], %[a], %[b], %[c]"
199 : [result]
"=v"(result)
200 : [
a]
"v"(
a), [b]
"s"(b), [c]
"v"(c));
208 asm volatile(
"v_add_f32_e32 %[result], %[lhs], %[rhs]"
209 : [result]
"=v"(result)
210 : [lhs]
"v"(lhs), [rhs]
"v"(rhs));
217 asm volatile(
"v_mul_f32_e32 %[result], %[lhs], %[rhs]"
218 : [result]
"=v"(result)
219 : [lhs]
"v"(lhs), [rhs]
"v"(rhs));
226 asm volatile(
"v_cvt_pk_f16_f32 %[result], %[a], %[b]"
227 : [result]
"=v"(result)
228 : [
a]
"v"(
a), [b]
"v"(b));
235 asm volatile(
"v_cvt_pk_bf16_f32 %[result], %[a], %[b]"
236 : [result]
"=v"(result)
237 : [
a]
"v"(
a), [b]
"v"(b));
244 asm volatile(
"v_pk_mul_f32 %[result], %[lhs], %[rhs]"
245 : [result]
"=v"(result)
246 : [lhs]
"v"(lhs), [rhs]
"v"(rhs));
253 template <
typename Problem_,
typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
269 static_assert(is_generic_attention_mask_v<FmhaMask>);
271 static_assert(std::is_same_v<SaccDataType, SMPLComputeDataType>,
272 "we will the same dist tensor 'sp_compute' for both gemm0 & softmax");
277 static_assert(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>);
289 static_assert(kQKHeaddim == 128 && kSubQKHeaddim == 128,
"only supports hdim=hdim_v=128");
291 static constexpr
bool kIsGroupMode = Problem::kIsGroupMode;
292 static constexpr
bool kPadSeqLenQ = Problem::kPadSeqLenQ;
293 static constexpr
bool kPadSeqLenK = Problem::kPadSeqLenK;
294 static constexpr
bool kPadHeadDimQ = Problem::kPadHeadDimQ;
295 static constexpr
bool kPadHeadDimV = Problem::kPadHeadDimV;
296 static constexpr
bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
297 static constexpr
auto BiasEnum = Problem::BiasEnum;
298 static constexpr
bool kStoreLSE = Problem::kStoreLSE;
299 static constexpr
bool kHasDropout = Problem::kHasDropout;
300 static constexpr
auto QScaleEnum = Problem::QScaleEnum;
301 static constexpr
bool kSkipMinSeqlenQ = Problem::kSkipMinSeqlenQ;
305 "enable unsupported features");
310 kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
312 kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
314 kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
317 kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
320 if constexpr(Problem::kBlockPerCu != -1)
321 return Problem::kBlockPerCu;
332 Policy::template GetSmemSize<Problem>() +
337 template <ck_tile::index_t MPerBlock, ck_tile::index_t NPerBlock>
341 constexpr
auto lds_block_desc =
347 return lds_block_desc;
351 template <ck_tile::index_t MPerBlock>
358 return lds_block_desc;
361 template <
typename DataType,
typename Descriptor>
367 make_tensor_view<address_space_enum::lds>(
reinterpret_cast<DataType*
>(base), desc);
372 template <u
int16_t Vmcnt, u
int8_t Lgkmcnt, u
int8_t Expcnt = 7>
378 __builtin_amdgcn_s_waitcnt((((0b110000 & Vmcnt) << (14 - 4)) | (0b1111 & Vmcnt)) |
379 ((0b111 & Expcnt) << 4) | ((0b1111 & Lgkmcnt) << 8));
382 template <u
int16_t Vmcnt>
385 s_waitcnt<Vmcnt, 15>();
388 template <u
int8_t Lgkmcnt>
391 s_waitcnt<63, Lgkmcnt>();
394 template <
typename QDramBlockWindowTmp,
395 typename KDramBlockWindowTmp,
396 typename VDramBlockWindowTmp,
397 typename LSEDramBlockWindowTmp,
398 typename QElementFunction,
399 typename KElementFunction,
400 typename VElementFunction,
401 typename LSEElementFunction,
402 typename SAccElementFunction,
403 typename PComputeElementFunction,
404 typename OAccElementFunction,
405 typename AttentionVariantParams,
406 typename BlockIndices>
408 const QElementFunction& q_element_func,
409 const KDramBlockWindowTmp& k_dram_block_window_tmp,
410 [[maybe_unused]]
const KElementFunction& k_element_func,
411 const VDramBlockWindowTmp& v_dram_block_window_tmp,
412 [[maybe_unused]]
const VElementFunction& v_element_func,
413 LSEDramBlockWindowTmp& lse_dram_window_tmp,
414 const LSEElementFunction& lse_element_func,
415 [[maybe_unused]]
const SAccElementFunction& s_acc_element_func,
416 const PComputeElementFunction& p_compute_element_func,
417 const OAccElementFunction& o_acc_element_func,
421 const AttentionVariantParams& variant_params,
422 const BlockIndices& block_indices,
423 void* smem_ptr)
const
433 static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[
number<0>{}] &&
434 kN0 == KDramBlockWindowTmp{}.get_window_lengths()[
number<0>{}] &&
435 kK0 == KDramBlockWindowTmp{}.get_window_lengths()[
number<1>{}] &&
436 kK1 == VDramBlockWindowTmp{}.get_window_lengths()[
number<0>{}] &&
437 kN1 == VDramBlockWindowTmp{}.get_window_lengths()[
number<1>{}],
440 static_assert(
sizeof(
SaccDataType) * kM0 * kN0 <= GetSmemSize());
441 auto s_lds = make_tensor_view<address_space_enum::lds>(
442 reinterpret_cast<SaccDataType*
>(
static_cast<char*
>(smem_ptr)),
443 MakeSimpleLdsDesc<kM0, kN0>());
444 [[maybe_unused]]
auto s_lds_window =
447 auto p_lds = make_tensor_view<address_space_enum::lds>(
448 reinterpret_cast<PDataType*
>(
static_cast<char*
>(smem_ptr) +
449 Policy::template GetSmemSize<Problem>()),
450 MakeSimpleLdsDesc<kM0, kN0>());
451 [[maybe_unused]]
auto p_lds_window =
454 auto o_lds = make_tensor_view<address_space_enum::lds>(
455 reinterpret_cast<PDataType*
>(
static_cast<char*
>(smem_ptr)),
456 MakeSimpleLdsDesc<kM0, kN1>());
457 [[maybe_unused]]
auto o_lds_window =
460 auto m_lds = make_tensor_view<address_space_enum::lds>(
462 Policy::template GetSmemSize<Problem>()),
463 MakeSimpleLdsDesc1D<kM0>());
464 [[maybe_unused]]
auto m_lds_window =
467 const index_t warp_group_id = get_warp_id() / 4;
470 constexpr
auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
471 constexpr
auto gemm_1 = Policy::template GetPVBlockGemm<Problem>();
474 q_dram_block_window_tmp, Policy::template MakeQRegTileDistribution<Problem>());
477 const auto f_max = [](
auto e0,
auto e1) {
return max(e0, e1); };
478 const auto f_sum = [](
auto e0,
auto e1) {
return e0 + e1; };
482 return make_lds_tile_window<KDataType>(
483 smem_ptr, Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf));
489 return make_lds_tile_window<KDataType>(
490 smem_ptr, Policy::template MakeVLdsStoreBlockDescriptor<Problem>(i_buf));
495 make_lds_tile_window<KDataType>(
497 Policy::template MakeKLdsLoadBlockDescriptor<Problem>()),
498 Policy::template MakeKRegTileDistribution<Problem>())),
503 make_lds_tile_window<VDataType>(
505 Policy::template MakeVLdsLoadBlockDescriptor<Problem>()),
506 Policy::template MakeVRegTileDistribution<Problem>())),
510 decltype(make_static_distributed_tensor<QDataType>(
511 Policy::template MakeQRegTileDistribution<Problem>())) q_tile;
522 union sp_compute_type
526 decltype(gemm_0.MakeCBlockTile()) sp_compute;
527 decltype(make_static_distributed_tensor<PDataType>(
528 Policy::template MakePRegTileDistribution<Problem>())) p;
532 decltype(gemm_1.MakeCBlockTile()) o_acc;
533 constexpr
index_t fmha_alu_D_reg_cnt = 6;
535 static_assert(fmha_alu_D_reg_cnt <= o_acc.thread_buf_.size());
537 decltype(block_tile_reduce<SMPLComputeDataType>(
544 make_lds_tile_window<KDataType>(
545 static_cast<char*
>(smem_ptr) + (idx)*Policy::template GetSmemSizeKV<Problem>(),
546 Policy::template MakeKLdsLoadBlockDescriptor<Problem>()),
547 Policy::template MakeKRegTileDistribution<Problem>());
551 v_lds_window_load(idx) =
553 static_cast<char*
>(smem_ptr) +
554 (idx + 2) * Policy::template GetSmemSizeKV<Problem>(),
555 Policy::template MakeVLdsLoadBlockDescriptor<Problem>()),
556 Policy::template MakeVRegTileDistribution<Problem>());
560 auto origin_q =
load_tile(q_dram_window);
563 q_tile = transformed_q;
567 set_tile(m, bit_cast<float>(0xff7fffff));
570 const auto q_origin = q_dram_window.get_window_origin();
571 const auto [seqlen_k_start, seqlen_k_end] =
575 index_t kv_token_start = seqlen_k_start;
578 if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
580 if(num_total_loop <= 0)
582 if constexpr(kStoreLSE)
585 make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
600 k_dram_block_window_tmp.get_window_lengths(),
602 Policy::template MakeKDramTileDistribution<Problem>());
603 k_dram_window.init_raw();
607 v_dram_block_window_tmp.get_window_lengths(),
609 Policy::template MakeVDramTileDistribution<Problem>());
610 v_dram_window.init_raw();
614 constexpr
index_t k0_loops = kQKHeaddim / kK0;
615 constexpr
index_t k1_loops = kN0 / kK1;
616 static_assert(1 == k0_loops);
617 static_assert(1 == k1_loops);
618 static_assert(kN0 == kK1);
620 constexpr
index_t NumWarpGroups = Problem::kBlockSize / Policy::NumThreadPerWarpGroup;
621 static_assert(NumWarpGroups == 2);
623 [[maybe_unused]]
auto print_dist_tensor = [&](
const auto& dist_tensor,
const char* name) {
624 printf(
"[POYENC] %s (size=%d): %5.2f",
626 decltype(dist_tensor.thread_buf_)::size(),
627 ck_tile::type_convert<float>(dist_tensor.thread_buf_[0]));
628 static_for<1, decltype(dist_tensor.thread_buf_)::size(), 1>{}([&](
auto i) {
629 printf(
", %5.2f", ck_tile::type_convert<float>(dist_tensor.thread_buf_[i]));
634 [[maybe_unused]]
auto print_lds = [&](
auto lds_tile_window,
const char* name) {
635 const auto num_rows = lds_tile_window.get_window_lengths().at(
number<0>{});
636 const auto num_cols = lds_tile_window.get_window_lengths().at(
number<1>{});
638 auto desc = lds_tile_window.get_bottom_tensor_view().desc_;
639 auto data = lds_tile_window.get_bottom_tensor_view().buf_.p_data_;
641 if constexpr(
true || num_rows < num_cols)
643 for(
int row = 0; row < num_rows; ++row)
646 printf(
"[DEVICE] %s[%3d] = %5.2f",
649 ck_tile::type_convert<float>(data[
offset]));
650 for(
int col = 1; col < num_cols; ++col)
654 printf(
"%5.2f", ck_tile::type_convert<float>(data[
offset]));
661 for(
int col = 0; col < num_cols; ++col)
664 printf(
"[DEVICE] %s[%3d] = %5.2f",
667 ck_tile::type_convert<float>(data[
offset]));
668 for(
int row = 1; row < num_rows; ++row)
672 printf(
"%5.2f", ck_tile::type_convert<float>(data[
offset]));
679 [[maybe_unused]]
auto print_lds_1d = [&](
auto lds_tile_window,
const char* name) {
680 const auto num_elems = lds_tile_window.get_window_lengths().at(
number<0>{});
682 auto desc = lds_tile_window.get_bottom_tensor_view().desc_;
683 auto data = lds_tile_window.get_bottom_tensor_view().buf_.p_data_;
686 printf(
"[DEVICE] %s = %5.2f", name, ck_tile::type_convert<float>(data[
offset]));
687 for(
int e = 1; e < num_elems; ++e)
691 printf(
"%5.2f", ck_tile::type_convert<float>(data[
offset]));
698 constexpr
int K_mem_su_ld_insts = k_dram_window.get_num_of_access();
699 constexpr
int V_mem_su_ld_insts = v_dram_window.get_num_of_access();
701 auto K_mem_load = [&](
auto k_lds_write_idx) {
709 auto K_lds_load = [&](
auto k_lds_read_idx) {
710 kv_tile.k_tile =
load_tile(k_lds_window_load(k_lds_read_idx));
713 auto V_mem_load = [&](
auto v_lds_write_idx) {
720 auto V_lds_load = [&](
auto v_lds_read_idx) {
729 auto fmha_logits_trans = [&](
auto sp_reg_idx) {
730 if constexpr(kHasLogitsSoftCap)
732 auto apply_logits_transform = [&variant, &variant_params, &block_indices](
734 logits = variant.LogitsTransform(variant_params,
735 variant.QueryTransform(variant_params, logits),
736 block_indices.batch_idx,
737 block_indices.qo_head_idx,
738 block_indices.kv_head_idx);
745 auto fmha_alu0 = [&](
auto sp_reg_idx) {
747 static_assert(m.thread_buf_.size() == 1,
748 "assuming that each thread holds 1 rowmax value");
749 auto m_latest = block_tile_reduce<SMPLComputeDataType>(
750 sp(sp_reg_idx).sp_compute,
sequence<1>{}, f_max, m.thread_buf_[0]);
751 #if defined(__gfx950__)
754 __builtin_amdgcn_permlane32_swap(bit_cast<int32_t>(m_latest.thread_buf_[0]),
755 bit_cast<int32_t>(m_latest.thread_buf_[0]),
759 m_latest.thread_buf_[0] = f_max(bit_cast<SMPLComputeDataType>(swapped_regs.x),
760 bit_cast<SMPLComputeDataType>(swapped_regs.y));
766 constexpr
auto p_spans =
767 std::decay_t<decltype(sp(sp_reg_idx).sp_compute)>::get_distributed_spans();
770 constexpr
auto i_j_idx =
make_tuple(idx0, idx1);
771 if constexpr(kHasLogitsSoftCap)
773 sp_delta(sp_reg_idx)(i_j_idx) =
774 sp(sp_reg_idx).sp_compute(i_j_idx) - m(i_j_idx);
779 sp(sp_reg_idx).sp_compute(i_j_idx), scale_s, -scale_s * m(i_j_idx));
786 auto fmha_alu1 = [&](
auto sp_reg_idx) {
787 constexpr
auto p_spans =
788 std::decay_t<decltype(sp(sp_reg_idx).sp_compute)>::get_distributed_spans();
791 constexpr
auto i_j_idx =
make_tuple(idx0, idx1);
792 sp(sp_reg_idx).sp_compute(i_j_idx) =
797 auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
798 sp(sp_reg_idx).sp_compute,
802 static_assert(rowsum_p.thread_buf_.size() == 1,
803 "assuming that each thread holds 1 rowsum value");
804 #if defined(__gfx950__)
807 __builtin_amdgcn_permlane32_swap(bit_cast<int32_t>(rowsum_p.thread_buf_[0]),
808 bit_cast<int32_t>(rowsum_p.thread_buf_[0]),
811 rowsum_p.thread_buf_[0] = f_sum(bit_cast<SMPLComputeDataType>(swapped_regs.x),
812 bit_cast<SMPLComputeDataType>(swapped_regs.y));
822 constexpr
auto o_spans = decltype(o_acc)::get_distributed_spans();
825 const auto tmp = [&] {
826 if constexpr(kHasLogitsSoftCap)
847 static_assert(sp(sp_reg_idx).p.thread_buf_.size() % 2 == 0);
848 static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 2>{}([&](
auto idx) {
849 float x = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx]);
850 float y = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 1]);
851 if constexpr(std::is_same_v<PDataType, fp16_t>)
854 sp(sp_reg_idx).p.thread_buf_[idx] = casted.x;
855 sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y;
860 sp(sp_reg_idx).p.thread_buf_[idx] = casted.x;
861 sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y;
870 auto gemm = [&](
auto sp_reg_idx,
auto gemm_idx) {
871 if constexpr(gemm_idx == 0)
874 gemm_0(sp(sp_reg_idx).sp_compute,
876 sequence<0, (k0_loops - 1) * kK0>{},
879 sequence<0, (k0_loops - 1) * kK0>{},
886 sequence<0, (k1_loops - 1) * kK1>{},
889 sequence<0, (k1_loops - 1) * kK1>{},
894 auto cl_calc = [&](
auto sp_reg_idx,
auto gemm_idx) {
895 if constexpr(gemm_idx == 0)
898 gemm_0(sp(sp_reg_idx).sp_compute,
900 sequence<0, (k0_loops - 1) * kK0>{},
903 sequence<0, (k0_loops - 1) * kK0>{},
910 sequence<0, (k1_loops - 1) * kK1>{},
913 sequence<0, (k1_loops - 1) * kK1>{},
919 auto fmha_alu_D_upd = [&] {
921 if constexpr(kHasLogitsSoftCap)
923 return ck_tile::exp2(m_old.thread_buf_[0] - m.thread_buf_[0]);
927 return ck_tile::exp2(scale_s * (m_old.thread_buf_[0] - m.thread_buf_[0]));
932 pk_o_acc_scale.x = o_acc_scale;
933 pk_o_acc_scale.y = o_acc_scale;
935 static_assert((o_acc.thread_buf_.size() - fmha_alu_D_reg_cnt) % 2 == 0);
936 #if CK_TILE_DISABLE_PACKED_FP32
937 static_assert(fmha_alu_D_reg_cnt + 2 <= o_acc.thread_buf_.size());
939 [&](
auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; });
942 constexpr
auto issued_D_reg_cnt =
943 #if CK_TILE_DISABLE_PACKED_FP32
944 fmha_alu_D_reg_cnt + 2
952 static_for<issued_D_reg_cnt, o_acc.thread_buf_.size(), 2>{}([&](
auto idx) {
954 input.x = o_acc.thread_buf_[idx];
955 input.y = o_acc.thread_buf_[idx + 1];
959 o_acc.thread_buf_[idx] = output.x;
960 o_acc.thread_buf_[idx + 1] = output.y;
964 auto fmha_mask = [&](
auto sp_reg_idx) {
965 if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
967 bool need_perpixel_check = mask.IsEdgeTile(
969 if(need_perpixel_check)
976 const auto col = kv_token_start + tile_idx.at(
number<1>{});
977 return !variant.LogitsMask(variant_params,
978 block_indices.batch_idx,
981 block_indices.qo_head_idx,
982 block_indices.kv_head_idx);
988 auto cl_load = [&](
auto load_type,
auto mem_wr_idx,
auto lds_rd_idx) {
989 if constexpr(load_type == 0)
991 V_mem_load(mem_wr_idx);
992 K_lds_load(lds_rd_idx);
996 K_mem_load(mem_wr_idx);
997 V_lds_load(lds_rd_idx);
1001 auto core_loop = [&](
auto cl_p) {
1010 auto iteration = [&](
auto pi) {
1011 auto xdl_SP_p01_reg_idx =
number<1>{} - pi;
1012 auto xdl_SP_p23_reg_idx = pi;
1014 auto K_w0_lds_wr_idx =
number<1>{} - pi;
1015 auto V_w0_lds_wr_idx = pi;
1016 auto K_w0_lds_rd_idx = pi;
1017 auto V_w0_lds_rd_idx = pi;
1019 auto K_w4_lds_wr_idx =
number<1>{} - pi;
1020 auto V_w4_lds_wr_idx =
number<1>{} - pi;
1021 auto K_w4_lds_rd_idx =
number<1>{} - pi;
1022 auto V_w4_lds_rd_idx = pi;
1026 if constexpr(cl_p == 0)
1028 #if ADD_SBARRIER_FOR_PHASE0
1029 __builtin_amdgcn_sched_barrier(0);
1030 __builtin_amdgcn_s_barrier();
1032 __builtin_amdgcn_sched_barrier(0);
1034 if constexpr(pi == 0)
1042 s_waitcnt_lgkmcnt<0>();
1043 __builtin_amdgcn_sched_barrier(0);
1044 cl_calc(xdl_SP_p01_reg_idx, gemm0);
1045 fmha_alu1(xdl_SP_p23_reg_idx);
1046 fmha_logits_trans(xdl_SP_p01_reg_idx);
1049 __builtin_amdgcn_sched_barrier(0);
1052 s_waitcnt_vmcnt<K_mem_su_ld_insts + V_mem_su_ld_insts>();
1053 __builtin_amdgcn_sched_barrier(0);
1054 __builtin_amdgcn_s_barrier();
1055 __builtin_amdgcn_sched_barrier(0);
1056 cl_load(memK, K_w0_lds_wr_idx, V_w0_lds_rd_idx);
1058 fmha_mask(xdl_SP_p01_reg_idx);
1060 __builtin_amdgcn_sched_barrier(0);
1063 s_waitcnt_lgkmcnt<0>();
1064 __builtin_amdgcn_sched_barrier(0);
1065 __builtin_amdgcn_s_barrier();
1066 __builtin_amdgcn_sched_barrier(0);
1067 asm volatile(
"s_nop 0");
1068 __builtin_amdgcn_sched_barrier(0);
1069 cl_calc(xdl_SP_p23_reg_idx, gemm1);
1072 __builtin_amdgcn_sched_barrier(0);
1075 __builtin_amdgcn_sched_barrier(0);
1078 s_waitcnt_vmcnt<K_mem_su_ld_insts + V_mem_su_ld_insts>();
1079 __builtin_amdgcn_sched_barrier(0);
1080 __builtin_amdgcn_s_barrier();
1081 __builtin_amdgcn_sched_barrier(0);
1082 cl_load(memV, V_w0_lds_wr_idx, K_w0_lds_rd_idx);
1085 kv_token_start += kN0;
1086 if(num_total_loop <= ++i_total_loops)
1093 #if ADD_SBARRIER_FOR_PHASE0
1094 __builtin_amdgcn_sched_barrier(0);
1095 __builtin_amdgcn_s_barrier();
1097 __builtin_amdgcn_sched_barrier(0);
1099 if constexpr(pi == 0)
1107 cl_load(memV, V_w4_lds_wr_idx, K_w4_lds_rd_idx);
1110 __builtin_amdgcn_sched_barrier(0);
1113 s_waitcnt<K_mem_su_ld_insts + V_mem_su_ld_insts, 0>();
1114 __builtin_amdgcn_sched_barrier(0);
1115 __builtin_amdgcn_s_barrier();
1116 __builtin_amdgcn_sched_barrier(0);
1117 asm volatile(
"s_nop 1");
1118 __builtin_amdgcn_sched_barrier(0);
1119 cl_calc(xdl_SP_p01_reg_idx, gemm0);
1120 fmha_alu1(xdl_SP_p23_reg_idx);
1121 fmha_logits_trans(xdl_SP_p01_reg_idx);
1124 __builtin_amdgcn_sched_barrier(0);
1127 __builtin_amdgcn_s_barrier();
1128 __builtin_amdgcn_sched_barrier(0);
1129 cl_load(memK, K_w4_lds_wr_idx, V_w4_lds_rd_idx);
1131 fmha_mask(xdl_SP_p01_reg_idx);
1133 kv_token_start += kN0;
1134 if(num_total_loop <= ++i_total_loops)
1139 __builtin_amdgcn_sched_barrier(0);
1142 s_waitcnt<K_mem_su_ld_insts + V_mem_su_ld_insts, 0>();
1143 __builtin_amdgcn_sched_barrier(0);
1144 __builtin_amdgcn_s_barrier();
1145 __builtin_amdgcn_sched_barrier(0);
1146 asm volatile(
"s_nop 1");
1147 __builtin_amdgcn_sched_barrier(0);
1148 cl_calc(xdl_SP_p23_reg_idx, gemm1);
1151 __builtin_amdgcn_sched_barrier(0);
1159 auto fmha_post_process = [&](
auto d) {
1161 auto V_lds_rd_idx = ps_pi;
1163 if(1 < num_total_loop)
1165 s_waitcnt_vmcnt<K_mem_su_ld_insts>();
1169 s_waitcnt_vmcnt<0>();
1171 __builtin_amdgcn_s_barrier();
1173 V_lds_load(V_lds_rd_idx);
1176 s_waitcnt_lgkmcnt<0>();
1178 auto xdl_SP_p23_reg_idx = ps_pi;
1188 s_waitcnt_vmcnt<0>();
1189 __builtin_amdgcn_s_barrier();
1193 s_waitcnt_lgkmcnt<0>();
1194 __builtin_amdgcn_s_barrier();
1197 if(1 < num_total_loop)
1211 kv_token_start += kN0;
1213 if(num_total_loop <= i_total_loops)
1215 goto label_main_loops_exit;
1218 if(2 < num_total_loop)
1222 s_waitcnt_vmcnt<K_mem_su_ld_insts + V_mem_su_ld_insts>();
1223 __builtin_amdgcn_s_barrier();
1229 if(1 < num_total_loop)
1231 if(warp_group_id == 0)
1236 __builtin_amdgcn_s_setprio(0);
1237 __builtin_amdgcn_s_barrier();
1241 if(warp_group_id != 0)
1243 __builtin_amdgcn_s_setprio(1);
1244 __builtin_amdgcn_s_barrier();
1249 label_main_loops_exit:
1250 if(num_total_loop % 2)
1254 if(!(num_total_loop % 2))
1260 if constexpr(kStoreLSE)
1262 auto lse = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
1264 constexpr
auto lse_spans = decltype(lse)::get_distributed_spans();
1267 lse(i_idx) = m[i_idx] /
C_LOG2E +
log(l[i_idx]);
1274 constexpr
auto o_spans = decltype(o_acc)::get_distributed_spans();
1278 const auto tmp = [&]() {
1279 if constexpr(FmhaMask::IsMasking)
1281 return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
1284 return 1 / l[i_idx];
1287 constexpr
auto i_j_idx =
make_tuple(idx0, idx1);
1288 o_acc(i_j_idx) *= tmp;
1297 template <
typename QDramBlockWindowTmp,
1298 typename KDramBlockWindowTmp,
1299 typename VDramBlockWindowTmp,
1300 typename LSEDramBlockWindowTmp,
1301 typename AttentionVariantParams,
1302 typename BlockIndices>
1304 const KDramBlockWindowTmp& k_dram_block_window_tmp,
1305 const VDramBlockWindowTmp& v_dram_block_window_tmp,
1306 LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
1310 const AttentionVariantParams& variant_params,
1311 const BlockIndices& block_indices,
1312 void* smem_ptr)
const
1316 return operator()(q_dram_block_window_tmp,
1318 k_dram_block_window_tmp,
1320 v_dram_block_window_tmp,
1322 lse_dram_block_window_tmp,
#define ASM_MARKER(marker)
Definition: block_fmha_fwd_v3_pipeline.hpp:14
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
CK_TILE_DEVICE bf16x2_t cvt_pk_bf16_f32(float a, float b)
Definition: block_fmha_fwd_v3_pipeline.hpp:232
CK_TILE_DEVICE float fma_impl_vsv(float a, float b, float c)
Definition: block_fmha_fwd_v3_pipeline.hpp:192
CK_TILE_DEVICE float mul_impl_vv(float lhs, float rhs)
Definition: block_fmha_fwd_v3_pipeline.hpp:214
CK_TILE_DEVICE float add_impl_vv(float lhs, float rhs)
Definition: block_fmha_fwd_v3_pipeline.hpp:205
CK_TILE_DEVICE fp16x2_t cvt_pk_fp16_f32(float a, float b)
Definition: block_fmha_fwd_v3_pipeline.hpp:223
CK_TILE_DEVICE fp32x2_t pk_mul_f32(fp32x2_t lhs, fp32x2_t rhs)
Definition: block_fmha_fwd_v3_pipeline.hpp:241
Definition: cluster_descriptor.hpp:13
CK_TILE_DEVICE bfloat16_t log(bfloat16_t x)
Definition: bfloat16.hpp:431
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_descriptor.hpp:274
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:145
constexpr CK_TILE_DEVICE auto get_slice_tile(const tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile, sequence< SliceBegins... > slice_begins, sequence< SliceEnds... > slice_ends)
Definition: slice_tile.hpp:23
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition: tile_elementwise.hpp:40
CK_TILE_DEVICE void set_tile(DstrTensors &dstr_tensor, const T &value)
Definition: tile_elementwise.hpp:95
CK_TILE_DEVICE void async_load_tile_raw(LdsTileWindow_ &&lds_tile, const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Definition: load_tile.hpp:184
bfloat16_t bf16x2_t
Definition: bfloat16.hpp:433
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_ &acc_tensor, const ReduceFunc &reduce_func, bool_constant< WithBroadcast >={}, bool_constant< CrossWarp >={})
Definition: block_reduce.hpp:21
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition: tile_elementwise.hpp:23
tuple_array< T, N > statically_indexed_array
Definition: statically_indexed_array.hpp:16
float fp32x2_t
Definition: bfloat16.hpp:434
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
_Float16 fp16x2_t
Definition: half.hpp:385
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition: sweep_tile.hpp:20
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:95
CK_TILE_HOST_DEVICE void set_tile_if(static_distributed_tensor< DataType, StaticTileDistribution > &out_tensor, DataType value, XIndicesPredicate predicate)
Definition: static_distributed_tensor.hpp:185
constexpr CK_TILE_DEVICE auto make_tile_window_linear(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, LinearBottomDims_={})
Definition: tile_window_linear.hpp:994
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:24
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition: tile_elementwise.hpp:177
int32_t int32x2_t
Definition: vector_type.hpp:154
CK_TILE_DEVICE auto load_tile_transpose(const tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &__restrict__ tile_window)
transpose loads tile from a tensor and returns the resulting tensor with a new (transposed) tile dist...
Definition: load_tile_transpose.hpp:486
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:157
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:36
CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x)
Definition: bfloat16.hpp:428
constexpr bool is_same_v
Definition: type.hpp:283
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1517
Definition: block_fmha_fwd_v3_pipeline.hpp:255
ck_tile::remove_cvref_t< typename Problem::AttentionVariant > AttentionVariant
Definition: block_fmha_fwd_v3_pipeline.hpp:267
ck_tile::remove_cvref_t< typename Problem::PDataType > PDataType
Definition: block_fmha_fwd_v3_pipeline.hpp:264
ck_tile::remove_cvref_t< typename Problem::SMPLComputeDataType > SMPLComputeDataType
Definition: block_fmha_fwd_v3_pipeline.hpp:262
ck_tile::remove_cvref_t< typename Problem::KDataType > KDataType
Definition: block_fmha_fwd_v3_pipeline.hpp:259
ck_tile::remove_cvref_t< typename Problem::ODataType > ODataType
Definition: block_fmha_fwd_v3_pipeline.hpp:266
ck_tile::remove_cvref_t< typename Problem::BlockFmhaShape > BlockFmhaShape
Definition: block_fmha_fwd_v3_pipeline.hpp:274
static constexpr CK_TILE_DEVICE void s_waitcnt_lgkmcnt()
Definition: block_fmha_fwd_v3_pipeline.hpp:389
ck_tile::remove_cvref_t< typename Problem::OaccDataType > OaccDataType
Definition: block_fmha_fwd_v3_pipeline.hpp:265
static constexpr CK_TILE_DEVICE auto MakeSimpleLdsDesc1D()
Definition: block_fmha_fwd_v3_pipeline.hpp:352
static constexpr CK_TILE_DEVICE void s_waitcnt()
Definition: block_fmha_fwd_v3_pipeline.hpp:373
ck_tile::remove_cvref_t< Problem_ > Problem
Definition: block_fmha_fwd_v3_pipeline.hpp:256
static constexpr CK_TILE_DEVICE void s_waitcnt_vmcnt()
Definition: block_fmha_fwd_v3_pipeline.hpp:383
static constexpr CK_TILE_DEVICE auto make_lds_tile_window(void *base, const Descriptor &desc)
Definition: block_fmha_fwd_v3_pipeline.hpp:362
ck_tile::remove_cvref_t< typename Problem::QDataType > QDataType
Definition: block_fmha_fwd_v3_pipeline.hpp:258
ck_tile::remove_cvref_t< typename Problem::LSEDataType > LSEDataType
Definition: block_fmha_fwd_v3_pipeline.hpp:263
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: block_fmha_fwd_v3_pipeline.hpp:328
CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp &q_dram_block_window_tmp, const QElementFunction &q_element_func, const KDramBlockWindowTmp &k_dram_block_window_tmp, [[maybe_unused]] const KElementFunction &k_element_func, const VDramBlockWindowTmp &v_dram_block_window_tmp, [[maybe_unused]] const VElementFunction &v_element_func, LSEDramBlockWindowTmp &lse_dram_window_tmp, const LSEElementFunction &lse_element_func, [[maybe_unused]] const SAccElementFunction &s_acc_element_func, const PComputeElementFunction &p_compute_element_func, const OAccElementFunction &o_acc_element_func, FmhaMask mask, float scale_s, const AttentionVariant &variant, const AttentionVariantParams &variant_params, const BlockIndices &block_indices, void *smem_ptr) const
Definition: block_fmha_fwd_v3_pipeline.hpp:407
remove_cvref_t< typename BlockFmhaShape::VLayout > VLayout
Definition: block_fmha_fwd_v3_pipeline.hpp:276
ck_tile::remove_cvref_t< typename Problem::SaccDataType > SaccDataType
Definition: block_fmha_fwd_v3_pipeline.hpp:261
ck_tile::remove_cvref_t< Policy_ > Policy
Definition: block_fmha_fwd_v3_pipeline.hpp:257
ck_tile::remove_cvref_t< typename Problem::FmhaMask > FmhaMask
Definition: block_fmha_fwd_v3_pipeline.hpp:268
CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowTmp &k_dram_block_window_tmp, const VDramBlockWindowTmp &v_dram_block_window_tmp, LSEDramBlockWindowTmp &lse_dram_block_window_tmp, FmhaMask mask, float scale_s, const AttentionVariant &variant, const AttentionVariantParams &variant_params, const BlockIndices &block_indices, void *smem_ptr) const
Definition: block_fmha_fwd_v3_pipeline.hpp:1303
ck_tile::remove_cvref_t< typename Problem::VDataType > VDataType
Definition: block_fmha_fwd_v3_pipeline.hpp:260
static constexpr CK_TILE_DEVICE auto MakeSimpleLdsDesc()
Definition: block_fmha_fwd_v3_pipeline.hpp:338
static constexpr CK_TILE_DEVICE void schedule(ck_tile::number< WaveGroup >, ck_tile::number< Phase >)
Definition: block_fmha_fwd_v3_pipeline.hpp:121
static constexpr CK_TILE_DEVICE void schedule(ck_tile::number< WaveGroup >, ck_tile::number< Phase >)
Definition: block_fmha_fwd_v3_pipeline.hpp:47
Definition: block_fmha_fwd_v3_pipeline.hpp:41
Definition: integral_constant.hpp:13
Definition: functional.hpp:114
Definition: numeric.hpp:18
Definition: coordinate_transform.hpp:1392
Definition: sequence.hpp:49
Definition: functional.hpp:43
Definition: tensor_view.hpp:41
#define C_LOG2E
Definition: math.hpp:462