#define ENABLE_ASM_MARKER 1
#if ENABLE_ASM_MARKER
#define ASM_MARKER(marker)                   \
    __builtin_amdgcn_sched_barrier(0);       \
    asm volatile("; [POYENC] " #marker);     \
    __builtin_amdgcn_sched_barrier(0);
#else
#define ASM_MARKER(marker)
#endif

#define ADD_SBARRIER_FOR_PHASE0 1

#if !defined(CK_TILE_DISABLE_PACKED_FP32)
#define CK_TILE_DISABLE_PACKED_FP32 0
#endif

// thread selected by DEBUG_STMTS (values reconstructed; any warp/lane id works)
#define WARP_ID 0
#define LANE_ID 0

#define ENABLE_DEBUG_STMTS 1
#if ENABLE_DEBUG_STMTS
#define DEBUG_STMTS \
    if(get_block_1d_id() == 0 && get_warp_id() == WARP_ID && get_lane_id() == LANE_ID)
#else
#define DEBUG_STMTS if constexpr(false)
#endif
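// Usage sketch (illustrative, not part of the original file): DEBUG_STMTS gates
// device-side printf debugging to a single block/warp/lane, e.g.
//
//   DEBUG_STMTS { printf("gemm0 done, kv_token_start = %d\n", kv_token_start); }
//
// With ENABLE_DEBUG_STMTS == 0 the guarded body compiles away entirely via
// `if constexpr(false)`.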
// Per-phase instruction schedulers, specialized on kIsMasking. NOTE: the
// original struct name did not survive extraction; "Scheduler" is a placeholder.
template <typename PipelineProblem, bool kIsMasking>
struct Scheduler;

template <typename PipelineProblem>
struct Scheduler<PipelineProblem, /*kIsMasking=*/true>
{
    // schedule() emits __builtin_amdgcn_sched_group_barrier hints that pin the
    // instruction mix for each pipeline phase. Mask bits: 0x002 = VALU,
    // 0x004 = SALU, 0x008 = MFMA, 0x200 = LDS write; arguments are
    // (mask, instruction count, sync id).
    template <ck_tile::index_t WaveGroup, ck_tile::index_t Phase>
    static constexpr CK_TILE_DEVICE void schedule(ck_tile::number<WaveGroup>,
                                                  ck_tile::number<Phase>)
    {
        if constexpr(WaveGroup == 0)
        {
            if constexpr(Phase == 0)
            {
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // 1 MFMA
                __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // 2 LDS writes
                __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // 2 VALU
            }
            else if constexpr(Phase == 1)
            {
                __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // 2 VALU
                __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // 4 SALU
            }
            else if constexpr(Phase == 2)
            {
#if !CK_TILE_DISABLE_PACKED_FP32
                __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // 4 VALU (packed fp32)
#else
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // 1 MFMA
                __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // 4 VALU
#endif
            }
            else if constexpr(Phase == 3)
            {
                __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // 2 VALU
                __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // 4 SALU
            }
        }
        else
        {
            if constexpr(Phase == 0)
            {
                __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // 2 VALU
                __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // 4 SALU
            }
            else if constexpr(Phase == 1)
            {
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // 1 MFMA
                __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // 2 LDS writes
                __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // 2 VALU
            }
            else if constexpr(Phase == 2)
            {
                __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // 2 VALU
                __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // 4 SALU
            }
            else if constexpr(Phase == 3)
            {
#if !CK_TILE_DISABLE_PACKED_FP32
                __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // 4 VALU (packed fp32)
#else
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // 1 MFMA
                __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // 4 VALU
#endif
            }
        }
    }
};
template <typename PipelineProblem>
struct Scheduler<PipelineProblem, /*kIsMasking=*/false>
{
    template <ck_tile::index_t WaveGroup, ck_tile::index_t Phase>
    static constexpr CK_TILE_DEVICE void schedule(ck_tile::number<WaveGroup>,
                                                  ck_tile::number<Phase>)
    {
        if constexpr(WaveGroup == 0)
        {
            if constexpr(Phase == 0)
            {
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // 1 MFMA
                __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // 2 LDS writes
                __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // 2 VALU
            }
            else if constexpr(Phase == 1)
            {
                __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // 2 VALU
                __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // 4 SALU
            }
            else if constexpr(Phase == 2)
            {
#if !CK_TILE_DISABLE_PACKED_FP32
                __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // 4 VALU (packed fp32)
#else
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // 1 MFMA
                __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // 4 VALU
#endif
            }
            else if constexpr(Phase == 3)
            {
                __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // 2 VALU
                __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // 4 SALU
            }
        }
        else
        {
            if constexpr(Phase == 0)
            {
                __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // 2 VALU
                __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // 4 SALU
            }
            else if constexpr(Phase == 1)
            {
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // 1 MFMA
                __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // 2 LDS writes
                __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // 2 VALU
            }
            else if constexpr(Phase == 2)
            {
                __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // 2 VALU
                __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // 4 SALU
            }
            else if constexpr(Phase == 3)
            {
#if !CK_TILE_DISABLE_PACKED_FP32
                __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // 4 VALU (packed fp32)
#else
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // 1 MFMA
                __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // 4 VALU
#endif
            }
        }
    }
};
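// Illustrative note (not from the original file): each
// __builtin_amdgcn_sched_group_barrier(mask, size, sync_id) call asks the
// backend scheduler to place, at this point of the phase, `size` instructions
// matching `mask`. E.g. the Phase-0 group for wave group 0,
//
//   __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // 1 MFMA
//   __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // then 2 LDS writes
//   __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // then 2 VALU ops
//
// pins an MFMA / ds_write / VALU interleaving instead of leaving the ordering
// to the compiler's default heuristics.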
// Single-instruction inline-asm math helpers. The scalar v_fma/v_add/v_mul
// wrappers back the CK_TILE_DISABLE_PACKED_FP32 (non-packed) path, while
// v_pk_mul_f32 serves the packed-fp32 path; the suffixes encode the operand
// classes ("v" = VGPR, "s" = SGPR).
CK_TILE_DEVICE float fma_impl_vsv(float a, float b, float c)
{
    float result;
    asm volatile("v_fma_f32 %[result], %[a], %[b], %[c]"
                 : [result] "=v"(result)
                 : [a] "v"(a), [b] "s"(b), [c] "v"(c));
    return result;
}

CK_TILE_DEVICE float add_impl_vv(float lhs, float rhs)
{
    float result;
    asm volatile("v_add_f32_e32 %[result], %[lhs], %[rhs]"
                 : [result] "=v"(result)
                 : [lhs] "v"(lhs), [rhs] "v"(rhs));
    return result;
}

CK_TILE_DEVICE float mul_impl_vv(float lhs, float rhs)
{
    float result;
    asm volatile("v_mul_f32_e32 %[result], %[lhs], %[rhs]"
                 : [result] "=v"(result)
                 : [lhs] "v"(lhs), [rhs] "v"(rhs));
    return result;
}

CK_TILE_DEVICE fp16x2_t cvt_pk_fp16_f32(float a, float b)
{
    fp16x2_t result;
    asm volatile("v_cvt_pk_f16_f32 %[result], %[a], %[b]"
                 : [result] "=v"(result)
                 : [a] "v"(a), [b] "v"(b));
    return result;
}

CK_TILE_DEVICE bf16x2_t cvt_pk_bf16_f32(float a, float b)
{
    bf16x2_t result;
    asm volatile("v_cvt_pk_bf16_f32 %[result], %[a], %[b]"
                 : [result] "=v"(result)
                 : [a] "v"(a), [b] "v"(b));
    return result;
}

CK_TILE_DEVICE fp32x2_t pk_mul_f32(fp32x2_t lhs, fp32x2_t rhs)
{
    fp32x2_t result;
    asm volatile("v_pk_mul_f32 %[result], %[lhs], %[rhs]"
                 : [result] "=v"(result)
                 : [lhs] "v"(lhs), [rhs] "v"(rhs));
    return result;
}
template <typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
struct BlockFmhaFwdV3Pipeline
{
    using Problem             = ck_tile::remove_cvref_t<Problem_>;
    using Policy              = ck_tile::remove_cvref_t<Policy_>;
    using QDataType           = ck_tile::remove_cvref_t<typename Problem::QDataType>;
    using KDataType           = ck_tile::remove_cvref_t<typename Problem::KDataType>;
    using VDataType           = ck_tile::remove_cvref_t<typename Problem::VDataType>;
    using SaccDataType        = ck_tile::remove_cvref_t<typename Problem::SaccDataType>;
    using SMPLComputeDataType = ck_tile::remove_cvref_t<typename Problem::SMPLComputeDataType>;
    using LSEDataType         = ck_tile::remove_cvref_t<typename Problem::LSEDataType>;
    using PDataType           = ck_tile::remove_cvref_t<typename Problem::PDataType>;
    using OaccDataType        = ck_tile::remove_cvref_t<typename Problem::OaccDataType>;
    using ODataType           = ck_tile::remove_cvref_t<typename Problem::ODataType>;
    using FmhaMask            = ck_tile::remove_cvref_t<typename Problem::FmhaMask>;
    using BlockFmhaShape      = ck_tile::remove_cvref_t<typename Problem::BlockFmhaShape>;

    static_assert(std::is_same_v<SaccDataType, SMPLComputeDataType>,
                  "we will use the same dist tensor 'sp_compute' for both gemm0 & softmax");

    static constexpr index_t kM0           = BlockFmhaShape::kM0;
    static constexpr index_t kN0           = BlockFmhaShape::kN0;
    static constexpr index_t kK0           = BlockFmhaShape::kK0;
    static constexpr index_t kN1           = BlockFmhaShape::kN1;
    static constexpr index_t kK1           = BlockFmhaShape::kK1;
    static constexpr index_t kQKHeaddim    = BlockFmhaShape::kQKHeaddim;
    static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;

    static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
    static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
    static constexpr bool kPadSeqLenQ  = Problem::kPadSeqLenQ;
    static constexpr bool kPadSeqLenK  = Problem::kPadSeqLenK;
    static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
    static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
    static constexpr bool kStoreLSE    = Problem::kStoreLSE;

    static constexpr index_t kAlignmentQ =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
    static constexpr index_t kAlignmentK =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
    static constexpr index_t kAlignmentV =
        kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
    static constexpr index_t kAlignmentO =
        kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();

    static constexpr index_t kBlockPerCu = []() {
        if constexpr(Problem::kBlockPerCu != -1)
            return Problem::kBlockPerCu;
        else
            return 1; // hdim-dependent defaults elided
    }();
    static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
    {
        return Policy::template GetSmemSize<Problem>() + /* ... */;
    }
    // plain row-major [MPerBlock, NPerBlock] descriptor, used by the debug windows
    template <ck_tile::index_t MPerBlock, ck_tile::index_t NPerBlock>
    static constexpr CK_TILE_DEVICE auto MakeSimpleLdsDesc()
    {
        constexpr auto lds_block_desc =
            make_naive_tensor_descriptor(make_tuple(number<MPerBlock>{}, number<NPerBlock>{}),
                                         make_tuple(number<NPerBlock>{}, number<1>{}));
        return lds_block_desc;
    }

    template <ck_tile::index_t MPerBlock>
    static constexpr CK_TILE_DEVICE auto MakeSimpleLdsDesc1D()
    {
        constexpr auto lds_block_desc = make_naive_tensor_descriptor(
            make_tuple(number<MPerBlock>{}), make_tuple(number<1>{}));
        return lds_block_desc;
    }
    template <typename DataType, typename Descriptor>
    static constexpr CK_TILE_DEVICE auto make_lds_tile_window(void* base, const Descriptor& desc)
    {
        auto view =
            make_tensor_view<address_space_enum::lds>(reinterpret_cast<DataType*>(base), desc);
        return make_tile_window(view, desc.get_lengths(), {0, 0});
    }
    // Thin wrappers over s_waitcnt. On gfx9, the 16-bit immediate packs
    // vmcnt[3:0] at bits 3:0, expcnt at bits 6:4, lgkmcnt at bits 11:8, and
    // vmcnt[5:4] at bits 15:14 -- hence the split masking below.
    template <uint16_t Vmcnt, uint8_t Lgkmcnt, uint8_t Expcnt = 7>
    static constexpr CK_TILE_DEVICE void s_waitcnt()
    {
        __builtin_amdgcn_s_waitcnt((((0b110000 & Vmcnt) << (14 - 4)) | (0b1111 & Vmcnt)) |
                                   ((0b111 & Expcnt) << 4) | ((0b1111 & Lgkmcnt) << 8));
    }

    template <uint16_t Vmcnt>
    static constexpr CK_TILE_DEVICE void s_waitcnt_vmcnt()
    {
        s_waitcnt<Vmcnt, 15>(); // lgkmcnt = 15 (max): don't wait on LDS/SMEM ops
    }

    template <uint8_t Lgkmcnt>
    static constexpr CK_TILE_DEVICE void s_waitcnt_lgkmcnt()
    {
        s_waitcnt<63, Lgkmcnt>(); // vmcnt = 63 (max): don't wait on vector-memory ops
    }
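    // Worked example (illustrative, not from the original file):
    // s_waitcnt<63, 0>() -- i.e. s_waitcnt_lgkmcnt<0>() -- encodes vmcnt = 63
    // (no vector-memory wait) and lgkmcnt = 0 (drain all LDS/SMEM ops):
    //
    //   ((0b110000 & 63) << 10) | (0b1111 & 63) = 0xC00F   // vmcnt split fields
    //   (0b111 & 7) << 4                        = 0x0070   // expcnt (default 7)
    //   (0b1111 & 0) << 8                       = 0x0000   // lgkmcnt
    //   immediate                               = 0xC07F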
    template <typename QDramBlockWindowTmp,
              typename KDramBlockWindowTmp,
              typename VDramBlockWindowTmp,
              typename LSEDramBlockWindowTmp,
              typename QElementFunction,
              typename KElementFunction,
              typename VElementFunction,
              typename LSEElementFunction,
              typename SAccElementFunction,
              typename PComputeElementFunction,
              typename OAccElementFunction>
    CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
                                   const QElementFunction& q_element_func,
                                   const KDramBlockWindowTmp& k_dram_block_window_tmp,
                                   [[maybe_unused]] const KElementFunction& k_element_func,
                                   const VDramBlockWindowTmp& v_dram_block_window_tmp,
                                   [[maybe_unused]] const VElementFunction& v_element_func,
                                   LSEDramBlockWindowTmp& lse_dram_window_tmp,
                                   const LSEElementFunction& lse_element_func,
                                   [[maybe_unused]] const SAccElementFunction& s_acc_element_func,
                                   const PComputeElementFunction& p_compute_element_func,
                                   const OAccElementFunction& o_acc_element_func,
                                   FmhaMask mask,
                                   float scale_s,
                                   void* smem_ptr) const
    {
        static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
                          kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
                      "wrong!");

        // debug-only LDS views over the scratch buffer (all [[maybe_unused]])
        static_assert(sizeof(SaccDataType) * kM0 * kN0 <= GetSmemSize());
        auto s_lds = make_tensor_view<address_space_enum::lds>(
            reinterpret_cast<SaccDataType*>(static_cast<char*>(smem_ptr)),
            MakeSimpleLdsDesc<kM0, kN0>());
        [[maybe_unused]] auto s_lds_window =
            make_tile_window(s_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});

        auto p_lds = make_tensor_view<address_space_enum::lds>(
            reinterpret_cast<PDataType*>(static_cast<char*>(smem_ptr) +
                                         Policy::template GetSmemSize<Problem>()),
            MakeSimpleLdsDesc<kM0, kN0>());
        [[maybe_unused]] auto p_lds_window =
            make_tile_window(p_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});

        auto o_lds = make_tensor_view<address_space_enum::lds>(
            reinterpret_cast<PDataType*>(static_cast<char*>(smem_ptr)),
            MakeSimpleLdsDesc<kM0, kN1>());
        [[maybe_unused]] auto o_lds_window =
            make_tile_window(o_lds, make_tuple(number<kM0>{}, number<kN1>{}), {0, 0});

        auto m_lds = make_tensor_view<address_space_enum::lds>(
            reinterpret_cast<SMPLComputeDataType*>(static_cast<char*>(smem_ptr) +
                                                   Policy::template GetSmemSize<Problem>()),
            MakeSimpleLdsDesc1D<kM0>());
        [[maybe_unused]] auto m_lds_window =
            make_tile_window(m_lds, make_tuple(number<kM0>{}), {0});
        const index_t warp_group_id = get_warp_id() / 4; // warps 0-3 vs. warps 4-7

        constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
        constexpr auto gemm_1 = Policy::template GetPVBlockGemm<Problem>();

        auto q_dram_window = make_tile_window(
            q_dram_block_window_tmp, Policy::template MakeQRegTileDistribution<Problem>());

        const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
        const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
        // K/V LDS buffers: store-side windows (HBM -> LDS) ...
        auto k_lds_window_store = generate_tuple(
            [&](auto i_buf) {
                return make_lds_tile_window<KDataType>(
                    smem_ptr, Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf));
            },
            number<2>{});
        auto v_lds_window_store = generate_tuple(
            [&](auto i_buf) {
                return make_lds_tile_window<KDataType>(
                    smem_ptr, Policy::template MakeVLdsStoreBlockDescriptor<Problem>(i_buf));
            },
            number<2>{});

        // ... and load-side windows (LDS -> registers), filled in below
        statically_indexed_array<
            decltype(make_tile_window(
                make_lds_tile_window<KDataType>(
                    smem_ptr, Policy::template MakeKLdsLoadBlockDescriptor<Problem>()),
                Policy::template MakeKRegTileDistribution<Problem>())),
            2>
            k_lds_window_load;
        statically_indexed_array<
            decltype(make_tile_window(
                make_lds_tile_window<VDataType>(
                    smem_ptr, Policy::template MakeVLdsLoadBlockDescriptor<Problem>()),
                Policy::template MakeVRegTileDistribution<Problem>())),
            2>
            v_lds_window_load;

        decltype(make_static_distributed_tensor<QDataType>(
            Policy::template MakeQRegTileDistribution<Problem>())) q_tile;

        // K and V register tiles alias the same registers; only one is live at a time
        union kv_tile_type
        {
            decltype(load_tile(k_lds_window_load(number<0>{}))) k_tile;
            decltype(load_tile_transpose(v_lds_window_load(number<0>{}))) v_tile;
        } kv_tile;

        // S (gemm0 accumulator) and P (softmax result) also alias the same registers
        union sp_compute_type
        {
            decltype(gemm_0.MakeCBlockTile()) sp_compute;
            decltype(make_static_distributed_tensor<PDataType>(
                Policy::template MakePRegTileDistribution<Problem>())) p;
        };
        statically_indexed_array<sp_compute_type, 2> sp; // double-buffered S/P tiles

        decltype(gemm_1.MakeCBlockTile()) o_acc;
        constexpr index_t fmha_alu_D_reg_cnt = 6;
        static_assert(fmha_alu_D_reg_cnt <= o_acc.thread_buf_.size());
        // per-row running max / previous max / running sum
        decltype(block_tile_reduce<SMPLComputeDataType>(
            sp(number<0>{}).sp_compute, sequence<1>{}, f_max, SMPLComputeDataType{0})) m, m_old,
            l;

        static_for<0, 2, 1>{}([&](auto idx) {
            k_lds_window_load(idx) = make_tile_window(
                make_lds_tile_window<KDataType>(
                    static_cast<char*>(smem_ptr) + (idx)*Policy::template GetSmemSizeKV<Problem>(),
                    Policy::template MakeKLdsLoadBlockDescriptor<Problem>()),
                Policy::template MakeKRegTileDistribution<Problem>());

            v_lds_window_load(idx) = make_tile_window(
                make_lds_tile_window<VDataType>(
                    static_cast<char*>(smem_ptr) +
                        (idx + 2) * Policy::template GetSmemSizeKV<Problem>(),
                    Policy::template MakeVLdsLoadBlockDescriptor<Problem>()),
                Policy::template MakeVRegTileDistribution<Problem>());
        });
        // Q is loaded once and stays in registers for the whole loop
        auto origin_q            = load_tile(q_dram_window);
        const auto transformed_q = tile_elementwise_in(q_element_func, origin_q);
        q_tile                   = transformed_q;

        // rowmax starts at -FLT_MAX, rowsum at 0
        set_tile(m, bit_cast<float>(0xff7fffff));
        clear_tile(l);

        const auto q_origin = q_dram_window.get_window_origin();
        const auto [seqlen_k_start, seqlen_k_end] =
            mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});

        index_t kv_token_start = seqlen_k_start;
        const index_t num_total_loop =
            integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);

        // early exit: this row block is fully masked out
        if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
        {
            if(num_total_loop <= 0)
            {
                if constexpr(kStoreLSE)
                {
                    auto lse =
                        make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
                    set_tile(lse, -numeric<LSEDataType>::infinity());
                    store_tile(lse_dram_window_tmp, lse);
                }
                clear_tile(o_acc);
                return o_acc;
            }
        }
        auto k_dram_window =
            make_tile_window_linear(k_dram_block_window_tmp.get_bottom_tensor_view(),
                                    k_dram_block_window_tmp.get_window_lengths(),
                                    {seqlen_k_start, 0},
                                    Policy::template MakeKDramTileDistribution<Problem>());
        k_dram_window.init_raw();

        auto v_dram_window =
            make_tile_window_linear(v_dram_block_window_tmp.get_bottom_tensor_view(),
                                    v_dram_block_window_tmp.get_window_lengths(),
                                    {seqlen_k_start, 0},
                                    Policy::template MakeVDramTileDistribution<Problem>());
        v_dram_window.init_raw();
        constexpr index_t k0_loops = kQKHeaddim / kK0;
        constexpr index_t k1_loops = kN0 / kK1;
        static_assert(1 == k0_loops);
        static_assert(1 == k1_loops);
        static_assert(kN0 == kK1);

        constexpr index_t NumWarpGroups = Problem::kBlockSize / Policy::NumThreadPerWarpGroup;
        static_assert(NumWarpGroups == 2);
        // debug helper: dump one thread's distributed-tensor register contents
        [[maybe_unused]] auto print_dist_tensor = [&](const auto& dist_tensor, const char* name) {
            printf("[POYENC] %s (size=%d): %5.2f",
                   name,
                   decltype(dist_tensor.thread_buf_)::size(),
                   ck_tile::type_convert<float>(dist_tensor.thread_buf_[0]));
            static_for<1, decltype(dist_tensor.thread_buf_)::size(), 1>{}([&](auto i) {
                printf(", %5.2f", ck_tile::type_convert<float>(dist_tensor.thread_buf_[i]));
            });
            printf("\n");
        };
        // debug helper: dump a 2D LDS tile window, row by row or column by column
        [[maybe_unused]] auto print_lds = [&](auto lds_tile_window, const char* name) {
            const auto num_rows = lds_tile_window.get_window_lengths().at(number<0>{});
            const auto num_cols = lds_tile_window.get_window_lengths().at(number<1>{});

            auto desc = lds_tile_window.get_bottom_tensor_view().desc_;
            auto data = lds_tile_window.get_bottom_tensor_view().buf_.p_data_;

            if constexpr(true || num_rows < num_cols) // 'true ||' forces the row-major path
            {
                for(int row = 0; row < num_rows; ++row)
                {
                    auto offset = desc.calculate_offset(make_tuple(row, 0));
                    printf("[DEVICE] %s[%3d] = %5.2f",
                           name,
                           row,
                           ck_tile::type_convert<float>(data[offset]));
                    for(int col = 1; col < num_cols; ++col)
                    {
                        offset = desc.calculate_offset(make_tuple(row, col));
                        printf("%5.2f", ck_tile::type_convert<float>(data[offset]));
                    }
                    printf("\n");
                }
            }
            else
            {
                for(int col = 0; col < num_cols; ++col)
                {
                    auto offset = desc.calculate_offset(make_tuple(0, col));
                    printf("[DEVICE] %s[%3d] = %5.2f",
                           name,
                           col,
                           ck_tile::type_convert<float>(data[offset]));
                    for(int row = 1; row < num_rows; ++row)
                    {
                        offset = desc.calculate_offset(make_tuple(row, col));
                        printf("%5.2f", ck_tile::type_convert<float>(data[offset]));
                    }
                    printf("\n");
                }
            }
        };
        // debug helper: dump a 1D LDS tile window
        [[maybe_unused]] auto print_lds_1d = [&](auto lds_tile_window, const char* name) {
            const auto num_elems = lds_tile_window.get_window_lengths().at(number<0>{});

            auto desc = lds_tile_window.get_bottom_tensor_view().desc_;
            auto data = lds_tile_window.get_bottom_tensor_view().buf_.p_data_;

            auto offset = desc.calculate_offset(make_tuple(0));
            printf("[DEVICE] %s = %5.2f", name, ck_tile::type_convert<float>(data[offset]));
            for(int e = 1; e < num_elems; ++e)
            {
                offset = desc.calculate_offset(make_tuple(e));
                printf("%5.2f", ck_tile::type_convert<float>(data[offset]));
            }
            printf("\n");
        };
        constexpr int K_mem_su_ld_insts = k_dram_window.get_num_of_access();
        constexpr int V_mem_su_ld_insts = v_dram_window.get_num_of_access();

        // async HBM -> LDS copies, advancing the DRAM window along seqlen
        auto K_mem_load = [&](auto k_lds_write_idx) {
            async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window);
            move_tile_window(k_dram_window, {kN0, 0});
        };

        auto K_lds_load = [&](auto k_lds_read_idx) {
            kv_tile.k_tile = load_tile(k_lds_window_load(k_lds_read_idx));
        };

        auto V_mem_load = [&](auto v_lds_write_idx) {
            async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window);
            move_tile_window(v_dram_window, {kK1, 0});
        };

        auto V_lds_load = [&](auto v_lds_read_idx) {
            kv_tile.v_tile = load_tile_transpose(v_lds_window_load(v_lds_read_idx));
        };
        auto fmha_alu0 = [&](auto sp_reg_idx) {
            // (1) update the running rowmax
            static_assert(m.thread_buf_.size() == 1,
                          "assuming that each thread holds 1 rowmax value");
            auto m_latest = block_tile_reduce<SMPLComputeDataType>(
                sp(sp_reg_idx).sp_compute, sequence<1>{}, f_max, m.thread_buf_[0]);
#if defined(__gfx950__)
            // complete the reduction across the two 32-lane halves of the wave
            const auto swapped_regs = bit_cast<int32x2_t>(
                __builtin_amdgcn_permlane32_swap(bit_cast<int32_t>(m_latest.thread_buf_[0]),
                                                 bit_cast<int32_t>(m_latest.thread_buf_[0]),
                                                 false,
                                                 false));
            m_latest.thread_buf_[0] = f_max(bit_cast<SMPLComputeDataType>(swapped_regs.x),
                                            bit_cast<SMPLComputeDataType>(swapped_regs.y));
#endif
            m_old = m;
            m     = m_latest;

            // (2) shift & scale the logits: s <- scale_s * s - scale_s * m
            constexpr auto p_spans =
                std::decay_t<decltype(sp(sp_reg_idx).sp_compute)>::get_distributed_spans();
            sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
                sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
                    sp(sp_reg_idx).sp_compute(i_j_idx) = fma_impl_vsv(
                        sp(sp_reg_idx).sp_compute(i_j_idx), scale_s, -scale_s * m(i_j_idx));
                });
            });
        };
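        // Illustrative recap (not from the original file) of what fmha_alu0
        // computes, in exp2 form: with running rowmax m and a fresh logits tile S,
        //
        //   m_new = max(m, rowmax(S))
        //   S'    = scale_s * (S - m_new)   // one v_fma: S * scale_s + (-scale_s * m_new)
        //
        // so fmha_alu1 can take exp2(S') without overflow, and previously
        // accumulated results are later rescaled by exp2(scale_s * (m_old - m_new)).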
        auto fmha_alu1 = [&](auto sp_reg_idx) {
            // (1) p = exp2(shifted logits)
            constexpr auto p_spans =
                std::decay_t<decltype(sp(sp_reg_idx).sp_compute)>::get_distributed_spans();
            sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
                sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
                    sp(sp_reg_idx).sp_compute(i_j_idx) =
                        ck_tile::exp2(sp(sp_reg_idx).sp_compute(i_j_idx));
                });
            });

            // (2) this tile's rowsum, completed across the two 32-lane halves
            auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
                sp(sp_reg_idx).sp_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0});
            static_assert(rowsum_p.thread_buf_.size() == 1,
                          "assuming that each thread holds 1 rowsum value");
#if defined(__gfx950__)
            const auto swapped_regs = bit_cast<int32x2_t>(
                __builtin_amdgcn_permlane32_swap(bit_cast<int32_t>(rowsum_p.thread_buf_[0]),
                                                 bit_cast<int32_t>(rowsum_p.thread_buf_[0]),
                                                 false,
                                                 false));
            rowsum_p.thread_buf_[0] = f_sum(bit_cast<SMPLComputeDataType>(swapped_regs.x),
                                            bit_cast<SMPLComputeDataType>(swapped_regs.y));
#endif

            // (3) rescale the first fmha_alu_D_reg_cnt o_acc registers here;
            //     fmha_alu_D_upd covers the remainder with packed multiplies
            static_for<0, fmha_alu_D_reg_cnt, 1>{}([&](auto idx) {
                o_acc.thread_buf_[idx] = mul_impl_vv(o_acc.thread_buf_[idx], o_acc_scale);
            });

            // (4) l <- exp2(scale_s * (m_old - m)) * l + rowsum_p
            constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
            sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);
                const auto tmp = ck_tile::exp2(scale_s * (m_old[i_idx] - m[i_idx]));
                l(i_idx)       = tmp * l[i_idx] + rowsum_p[i_idx];
            });

            // (5) narrow P to fp16/bf16, two values per packed convert
            static_assert(sp(sp_reg_idx).p.thread_buf_.size() % 2 == 0);
            static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 2>{}([&](auto idx) {
                float x = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx]);
                float y = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 1]);
                if constexpr(std::is_same_v<PDataType, fp16_t>)
                {
                    const auto casted = cvt_pk_fp16_f32(x, y);

                    sp(sp_reg_idx).p.thread_buf_[idx]     = casted.x;
                    sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y;
                }
                else
                {
                    const auto casted = cvt_pk_bf16_f32(x, y);

                    sp(sp_reg_idx).p.thread_buf_[idx]     = casted.x;
                    sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y;
                }
            });
        };
        // gemm() is used in the prologue/epilogue; cl_calc() below is its
        // core-loop twin, kept separate so the steady-state loop can be
        // scheduled differently
        auto gemm = [&](auto sp_reg_idx, auto gemm_idx) {
            if constexpr(gemm_idx == 0)
            {
                // gemm0: S = Q * K^T
                clear_tile(sp(sp_reg_idx).sp_compute);
                gemm_0(sp(sp_reg_idx).sp_compute,
                       get_slice_tile(q_tile,
                                      sequence<0, (k0_loops - 1) * kK0>{},
                                      sequence<kM0, k0_loops * kK0>{}),
                       get_slice_tile(kv_tile.k_tile,
                                      sequence<0, (k0_loops - 1) * kK0>{},
                                      sequence<kN0, k0_loops * kK0>{}));
            }
            else
            {
                // gemm1: O += P * V
                gemm_1(o_acc,
                       get_slice_tile(sp(sp_reg_idx).p,
                                      sequence<0, (k1_loops - 1) * kK1>{},
                                      sequence<kM0, k1_loops * kK1>{}),
                       get_slice_tile(kv_tile.v_tile,
                                      sequence<0, (k1_loops - 1) * kK1>{},
                                      sequence<kN1, k1_loops * kK1>{}));
            }
        };

        auto cl_calc = [&](auto sp_reg_idx, auto gemm_idx) {
            if constexpr(gemm_idx == 0)
            {
                clear_tile(sp(sp_reg_idx).sp_compute);
                gemm_0(sp(sp_reg_idx).sp_compute,
                       get_slice_tile(q_tile,
                                      sequence<0, (k0_loops - 1) * kK0>{},
                                      sequence<kM0, k0_loops * kK0>{}),
                       get_slice_tile(kv_tile.k_tile,
                                      sequence<0, (k0_loops - 1) * kK0>{},
                                      sequence<kN0, k0_loops * kK0>{}));
            }
            else
            {
                gemm_1(o_acc,
                       get_slice_tile(sp(sp_reg_idx).p,
                                      sequence<0, (k1_loops - 1) * kK1>{},
                                      sequence<kM0, k1_loops * kK1>{}),
                       get_slice_tile(kv_tile.v_tile,
                                      sequence<0, (k1_loops - 1) * kK1>{},
                                      sequence<kN1, k1_loops * kK1>{}));
            }
        };
        float o_acc_scale = 0.f; // O-accumulator rescale factor, shared with fmha_alu1

        auto fmha_alu_D_upd = [&] {
            o_acc_scale = ck_tile::exp2(scale_s * (m_old.thread_buf_[0] - m.thread_buf_[0]));

            fp32x2_t pk_o_acc_scale;
            pk_o_acc_scale.x = o_acc_scale;
            pk_o_acc_scale.y = o_acc_scale;

            static_assert((o_acc.thread_buf_.size() - fmha_alu_D_reg_cnt) % 2 == 0);
#if CK_TILE_DISABLE_PACKED_FP32
            // without packed math, two extra registers are rescaled scalar-wise
            static_assert(fmha_alu_D_reg_cnt + 2 <= o_acc.thread_buf_.size());
            static_for<fmha_alu_D_reg_cnt, fmha_alu_D_reg_cnt + 2, 1>{}(
                [&](auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; });
#endif

            constexpr auto issued_D_reg_cnt =
#if CK_TILE_DISABLE_PACKED_FP32
                fmha_alu_D_reg_cnt + 2
#else
                fmha_alu_D_reg_cnt
#endif
                ;
            // the rest is rescaled two lanes at a time with v_pk_mul_f32
            static_for<issued_D_reg_cnt, o_acc.thread_buf_.size(), 2>{}([&](auto idx) {
                fp32x2_t input;
                input.x = o_acc.thread_buf_[idx];
                input.y = o_acc.thread_buf_[idx + 1];

                const auto output = pk_mul_f32(input, pk_o_acc_scale);

                o_acc.thread_buf_[idx]     = output.x;
                o_acc.thread_buf_[idx + 1] = output.y;
            });
        };
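        // Design note (inferred, not stated in the original): the o_acc rescale
        // is deliberately split. The first fmha_alu_D_reg_cnt registers are
        // issued separately (see fmha_alu1) so the VALU work can interleave with
        // MFMAs; the remainder is covered here one v_pk_mul_f32 per register
        // pair. E.g. for a 16-register accumulator and fmha_alu_D_reg_cnt = 6,
        // the packed loop handles registers 6..15 in five packed multiplies.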
        auto fmha_mask = [&](auto sp_reg_idx) {
            if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
            {
                const bool need_perpixel_check = mask.IsEdgeTile(
                    q_origin.at(number<0>{}), kv_token_start, number<kM0>{}, number<kN0>{});
                if(need_perpixel_check)
                {
                    set_tile_if(sp(sp_reg_idx).sp_compute,
                                -numeric<SMPLComputeDataType>::infinity(),
                                [&](auto tile_idx) {
                                    const auto row =
                                        q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
                                    const auto col = kv_token_start + tile_idx.at(number<1>{});
                                    return mask.IsOutOfBound(row, col);
                                });
                }
            }
        };
        // tag constants (names taken from their uses below)
        constexpr auto memV  = number<0>{};
        constexpr auto memK  = number<1>{};
        constexpr auto gemm0 = number<0>{};
        constexpr auto gemm1 = number<1>{};

        auto cl_load = [&](auto load_type, auto mem_wr_idx, auto lds_rd_idx) {
            if constexpr(load_type == 0)
            {
                V_mem_load(mem_wr_idx);
                K_lds_load(lds_rd_idx);
            }
            else
            {
                K_mem_load(mem_wr_idx);
                V_lds_load(lds_rd_idx);
            }
        };
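        // Illustrative schedule (not from the original file): in steady state the
        // two cl_load flavors alternate so HBM traffic and LDS reads stay balanced:
        //
        //   cl_load(memK, wr, rd); // issue K tile HBM->LDS, read V tile LDS->regs
        //   cl_load(memV, wr, rd); // issue V tile HBM->LDS, read K tile LDS->regs
        //
        // Each call pairs one async store-to-LDS stream with a register-tile read
        // of the *other* operand, so gemm0/gemm1 always have data ready.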
        index_t i_total_loops = 1; // tile 0 is consumed by the prologue below

        auto core_loop = [&](auto cl_p) {
            // Software-pipelined steady state. cl_p selects which wave group's
            // phase ordering to run; pi ping-pongs the S/P register buffers and
            // the K/V LDS buffers between consecutive iterations.
            auto iteration = [&](auto pi) {
                auto xdl_SP_p01_reg_idx = number<1>{} - pi;
                auto xdl_SP_p23_reg_idx = pi;

                auto K_w0_lds_wr_idx = pi;
                auto V_w0_lds_wr_idx = pi;
                auto K_w0_lds_rd_idx = pi;
                auto V_w0_lds_rd_idx = pi;

                // buffer parity for the second wave group (half an iteration behind)
                auto K_w4_lds_wr_idx = number<1>{} - pi;
                auto V_w4_lds_wr_idx = number<1>{} - pi;
                auto K_w4_lds_rd_idx = number<1>{} - pi;
                auto V_w4_lds_rd_idx = pi;

                if constexpr(cl_p == 0)
                {
#if ADD_SBARRIER_FOR_PHASE0
                    __builtin_amdgcn_sched_barrier(0);
                    __builtin_amdgcn_s_barrier();
#endif
                    __builtin_amdgcn_sched_barrier(0);
                    if constexpr(pi == 0)
                    {
                        // ... (first-iteration special casing)
                    }
                    // phase 0: gemm0 on one S/P buffer, softmax tail on the other
                    s_waitcnt_lgkmcnt<0>();
                    __builtin_amdgcn_sched_barrier(0);
                    cl_calc(xdl_SP_p01_reg_idx, gemm0);
                    fmha_alu1(xdl_SP_p23_reg_idx);
                    __builtin_amdgcn_sched_barrier(0);

                    // phase 1: K HBM->LDS, V LDS->regs; mask the fresh S tile
                    s_waitcnt_vmcnt<K_mem_su_ld_insts + V_mem_su_ld_insts>();
                    __builtin_amdgcn_sched_barrier(0);
                    __builtin_amdgcn_s_barrier();
                    __builtin_amdgcn_sched_barrier(0);
                    cl_load(memK, K_w0_lds_wr_idx, V_w0_lds_rd_idx);
                    fmha_mask(xdl_SP_p01_reg_idx);
                    __builtin_amdgcn_sched_barrier(0);

                    // phase 2: gemm1 (P*V) plus rowmax / O-rescale ALU work
                    s_waitcnt_lgkmcnt<0>();
                    __builtin_amdgcn_sched_barrier(0);
                    __builtin_amdgcn_s_barrier();
                    __builtin_amdgcn_sched_barrier(0);
                    asm volatile("s_nop 0");
                    __builtin_amdgcn_sched_barrier(0);
                    cl_calc(xdl_SP_p23_reg_idx, gemm1);
                    fmha_alu0(xdl_SP_p01_reg_idx);
                    __builtin_amdgcn_sched_barrier(0);
                    fmha_alu_D_upd();
                    __builtin_amdgcn_sched_barrier(0);

                    // phase 3: V HBM->LDS, K LDS->regs for the next tile
                    s_waitcnt_vmcnt<K_mem_su_ld_insts + V_mem_su_ld_insts>();
                    __builtin_amdgcn_sched_barrier(0);
                    __builtin_amdgcn_s_barrier();
                    __builtin_amdgcn_sched_barrier(0);
                    cl_load(memV, V_w0_lds_wr_idx, K_w0_lds_rd_idx);

                    kv_token_start += kN0;
                    if(num_total_loop <= ++i_total_loops)
                    {
                        return true; // K/V range consumed; leave the core loop
                    }
                }
                else
                {
#if ADD_SBARRIER_FOR_PHASE0
                    __builtin_amdgcn_sched_barrier(0);
                    __builtin_amdgcn_s_barrier();
#endif
                    __builtin_amdgcn_sched_barrier(0);
                    if constexpr(pi == 0)
                    {
                        // ... (first-iteration special casing)
                    }
                    cl_load(memV, V_w4_lds_wr_idx, K_w4_lds_rd_idx);
                    __builtin_amdgcn_sched_barrier(0);

                    s_waitcnt<K_mem_su_ld_insts + V_mem_su_ld_insts, 0>();
                    __builtin_amdgcn_sched_barrier(0);
                    __builtin_amdgcn_s_barrier();
                    __builtin_amdgcn_sched_barrier(0);
                    asm volatile("s_nop 1");
                    __builtin_amdgcn_sched_barrier(0);
                    cl_calc(xdl_SP_p01_reg_idx, gemm0);
                    fmha_alu1(xdl_SP_p23_reg_idx);
                    __builtin_amdgcn_sched_barrier(0);

                    __builtin_amdgcn_s_barrier();
                    __builtin_amdgcn_sched_barrier(0);
                    cl_load(memK, K_w4_lds_wr_idx, V_w4_lds_rd_idx);
                    fmha_mask(xdl_SP_p01_reg_idx);

                    kv_token_start += kN0;
                    if(num_total_loop <= ++i_total_loops)
                    {
                        return true;
                    }
                    __builtin_amdgcn_sched_barrier(0);

                    s_waitcnt<K_mem_su_ld_insts + V_mem_su_ld_insts, 0>();
                    __builtin_amdgcn_sched_barrier(0);
                    __builtin_amdgcn_s_barrier();
                    __builtin_amdgcn_sched_barrier(0);
                    asm volatile("s_nop 1");
                    __builtin_amdgcn_sched_barrier(0);
                    cl_calc(xdl_SP_p23_reg_idx, gemm1);
                    fmha_alu0(xdl_SP_p01_reg_idx);
                    fmha_alu_D_upd();
                    __builtin_amdgcn_sched_barrier(0);
                }
                return false;
            };

            // loop driver (reconstructed): alternate the two buffer parities
            while(true)
            {
                if(iteration(number<0>{}))
                    break;
                if(iteration(number<1>{}))
                    break;
            }
        };
        auto fmha_post_process = [&](auto ps_pi) {
            // pipeline drain for S/P buffer ps_pi: wait out the remaining loads,
            // then run the tail softmax and P*V for that buffer
            auto V_lds_rd_idx = ps_pi;
            if(1 < num_total_loop)
            {
                s_waitcnt_vmcnt<K_mem_su_ld_insts>();
            }
            else
            {
                s_waitcnt_vmcnt<0>();
            }
            __builtin_amdgcn_s_barrier();
            V_lds_load(V_lds_rd_idx);
            s_waitcnt_lgkmcnt<0>();

            auto xdl_SP_p23_reg_idx = ps_pi;
            fmha_alu1(xdl_SP_p23_reg_idx);
            gemm(xdl_SP_p23_reg_idx, gemm1);

            s_waitcnt_vmcnt<0>();
            __builtin_amdgcn_s_barrier();

            s_waitcnt_lgkmcnt<0>();
            __builtin_amdgcn_s_barrier();
        };

        // ---- prologue: process the first tile(s) before entering the core loop ----
        if(1 < num_total_loop)
        {
            // ... (prefetch of the second K/V tile)
        }
        kv_token_start += kN0;

        if(num_total_loop <= i_total_loops)
        {
            goto label_main_loops_exit;
        }

        if(2 < num_total_loop)
        {
            // ... (second prefetch round)
            s_waitcnt_vmcnt<K_mem_su_ld_insts + V_mem_su_ld_insts>();
            __builtin_amdgcn_s_barrier();
        }

        // ---- steady state: the two warp groups run phase-shifted core loops ----
        if(1 < num_total_loop)
        {
            if(warp_group_id == 0)
            {
                __builtin_amdgcn_s_setprio(0);
                __builtin_amdgcn_s_barrier();
                core_loop(number<0>{});
            }
            if(warp_group_id != 0)
            {
                __builtin_amdgcn_s_setprio(1);
                __builtin_amdgcn_s_barrier();
                core_loop(number<1>{});
            }
        }

    label_main_loops_exit:
        // drain whichever S/P buffer is still in flight
        if(num_total_loop % 2)
        {
            fmha_post_process(number<0>{});
        }
        if(!(num_total_loop % 2))
        {
            fmha_post_process(number<1>{});
        }
        // ---- epilogue: LSE store and final O normalization ----
        if constexpr(kStoreLSE)
        {
            auto lse = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());

            constexpr auto lse_spans = decltype(lse)::get_distributed_spans();
            sweep_tile_span(lse_spans[number<0>{}], [&](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);
                // lse = m * ln(2) + ln(l); m is tracked in the log2 domain
                lse(i_idx) = m[i_idx] / C_LOG2E + log(l[i_idx]);
            });
            store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
        }

        // o = o_acc / l (guard against fully masked rows when masking is enabled)
        constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
        sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
            constexpr auto i_idx = make_tuple(idx0);
            const auto tmp = [&]() {
                if constexpr(FmhaMask::IsMasking)
                {
                    return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
                }
                else
                {
                    return 1 / l[i_idx];
                }
            }();
            sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
                constexpr auto i_j_idx = make_tuple(idx0, idx1);
                o_acc(i_j_idx) *= tmp;
            });
        });

        return tile_elementwise_in(o_acc_element_func, o_acc);
    }
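    // Numeric sanity check (illustrative, not from the original file): for a row
    // with two unmasked scores s = {0, 0} and scale_s = 1, the loop keeps m = 0
    // and l = exp2(0) + exp2(0) = 2, so
    //
    //   o   = o_acc / l             -> the two V rows averaged, weight 0.5 each
    //   lse = m / C_LOG2E + log(l)  -> 0 + ln(2)
    //
    // matching softmax({0, 0}) = {0.5, 0.5} and logsumexp({0, 0}) = ln(2).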
    template <typename QDramBlockWindowTmp,
              typename KDramBlockWindowTmp,
              typename VDramBlockWindowTmp,
              typename LSEDramBlockWindowTmp>
    CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
                                   const KDramBlockWindowTmp& k_dram_block_window_tmp,
                                   const VDramBlockWindowTmp& v_dram_block_window_tmp,
                                   LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
                                   FmhaMask mask,
                                   float scale_s,
                                   void* smem_ptr) const
    {
        // forward to the generic overload with identity element-wise functions
        return operator()(q_dram_block_window_tmp,
                          identity{},
                          k_dram_block_window_tmp,
                          identity{},
                          v_dram_block_window_tmp,
                          identity{},
                          lse_dram_block_window_tmp,
                          identity{},
                          identity{},
                          identity{},
                          identity{},
                          mask,
                          scale_s,
                          smem_ptr);
    }
};