// Switch for the ASM_MARKER annotations. The #if/#else/#endif lines that select between the
// two ASM_MARKER definitions below are not part of this excerpt.
10 #define ENABLE_ASM_MARKER 1
// Active form: emits the assembler comment "; [POYENC] <marker>" and brackets it with
// __builtin_amdgcn_sched_barrier(0) on both sides.
12 #define ASM_MARKER(marker) \
13 __builtin_amdgcn_sched_barrier(0); \
14 asm volatile("; [POYENC] " #marker); \
15 __builtin_amdgcn_sched_barrier(0);
// Disabled form: expands to nothing.
17 #define ASM_MARKER(marker)
// Guards the extra sched_barrier + s_barrier pair at the `#if ADD_SBARRIER_FOR_PHASE0` sites
// inside the core loop.
20 #define ADD_SBARRIER_FOR_PHASE0 1
// Default CK_TILE_DISABLE_PACKED_FP32 to 0 unless the build has already defined it.
21 #if !defined(CK_TILE_DISABLE_PACKED_FP32)
22 #define CK_TILE_DISABLE_PACKED_FP32 0
// Debug-statement switch. The two lines after the #if are the alternative expansions of
// DEBUG_STMTS: a runtime filter that only passes block 0 / warp WARP_ID / lane LANE_ID, or
// `if constexpr(false)` (the `#define DEBUG_STMTS \` line of the first form is not in this excerpt).
28 #define ENABLE_DEBUG_STMTS 1
29 #if ENABLE_DEBUG_STMTS
31 if(get_block_1d_id() == 0 && get_warp_id() == WARP_ID && get_lane_id() == LANE_ID)
33 #define DEBUG_STMTS if constexpr(false)
38 template <
typename PipelineProblem,
bool kIsMasking>
41 template <
typename PipelineProblem>
// schedule(number<WaveGroup>, number<Phase>): compile-time dispatch that emits
// __builtin_amdgcn_sched_group_barrier(mask, size, sync_id) hints for one phase of the core loop.
// Only fragments of the body are visible here (braces and any enclosing loops are elided).
// The mask bit meanings are defined by the LLVM AMDGPU backend documentation and are
// deliberately not restated here.
44 template <ck_tile::index_t WaveGroup, ck_tile::index_t Phase>
// Wave group 0: hints are emitted in phases 0 and 2; phases 1 and 3 are empty.
50 if constexpr(WaveGroup == 0)
52 if constexpr(Phase == 0)
// Group pattern for this phase: 1 x mask 0x008, then 2 x mask 0x200, then 2 x mask 0x002.
55 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
56 __builtin_amdgcn_sched_group_barrier(0x200, 2, 0);
57 __builtin_amdgcn_sched_group_barrier(0x002, 2, 0);
60 else if constexpr(Phase == 1) {}
61 else if constexpr(Phase == 2)
// The leading 4 x mask 0x002 group is only compiled when packed FP32 is enabled.
63 #if !CK_TILE_DISABLE_PACKED_FP32
64 __builtin_amdgcn_sched_group_barrier(0x002, 4, 0);
67 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
68 __builtin_amdgcn_sched_group_barrier(0x002, 4, 0);
71 else if constexpr(Phase == 3) {}
// Presumably the else-branch for the other wave group (the `else` line is not in this
// excerpt): the same two hint patterns, shifted to phases 1 and 3.
75 if constexpr(Phase == 0) {}
76 else if constexpr(Phase == 1)
79 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
80 __builtin_amdgcn_sched_group_barrier(0x200, 2, 0);
81 __builtin_amdgcn_sched_group_barrier(0x002, 2, 0);
84 else if constexpr(Phase == 2) {}
85 else if constexpr(Phase == 3)
87 #if !CK_TILE_DISABLE_PACKED_FP32
88 __builtin_amdgcn_sched_group_barrier(0x002, 4, 0);
91 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
92 __builtin_amdgcn_sched_group_barrier(0x002, 4, 0);
99 template <
typename PipelineProblem>
// schedule(number<WaveGroup>, number<Phase>) for the second scheduler specialization.
// The lines visible in this excerpt are identical to those of the specialization that precedes
// it in the file; any difference would be in lines that are elided here.
// NOTE(review): if the two bodies really are identical, they could share one implementation.
102 template <ck_tile::index_t WaveGroup, ck_tile::index_t Phase>
// Wave group 0: hints are emitted in phases 0 and 2; phases 1 and 3 are empty.
108 if constexpr(WaveGroup == 0)
110 if constexpr(Phase == 0)
// Group pattern for this phase: 1 x mask 0x008, then 2 x mask 0x200, then 2 x mask 0x002.
113 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
114 __builtin_amdgcn_sched_group_barrier(0x200, 2, 0);
115 __builtin_amdgcn_sched_group_barrier(0x002, 2, 0);
118 else if constexpr(Phase == 1) {}
119 else if constexpr(Phase == 2)
// The leading 4 x mask 0x002 group is only compiled when packed FP32 is enabled.
121 #if !CK_TILE_DISABLE_PACKED_FP32
122 __builtin_amdgcn_sched_group_barrier(0x002, 4, 0);
125 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
126 __builtin_amdgcn_sched_group_barrier(0x002, 4, 0);
129 else if constexpr(Phase == 3) {}
// Presumably the else-branch for the other wave group (the `else` line is not in this
// excerpt): the same two hint patterns, shifted to phases 1 and 3.
133 if constexpr(Phase == 0) {}
134 else if constexpr(Phase == 1)
137 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
138 __builtin_amdgcn_sched_group_barrier(0x200, 2, 0);
139 __builtin_amdgcn_sched_group_barrier(0x002, 2, 0);
142 else if constexpr(Phase == 2) {}
143 else if constexpr(Phase == 3)
145 #if !CK_TILE_DISABLE_PACKED_FP32
146 __builtin_amdgcn_sched_group_barrier(0x002, 4, 0);
149 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
150 __builtin_amdgcn_sched_group_barrier(0x002, 4, 0);
// Inline-assembly helpers. Only the asm statements are visible in this excerpt; the function
// signatures, local declarations and return statements are elided.
//
// fma_impl_vsv(a, b, c): the preprocessor branch below belongs to this function; what the
// CK_TILE_DISABLE_PACKED_FP32 branch contains is not visible here. The asm path issues a single
// v_fma_f32 with `a` and `c` bound through "v" constraints and `b` through an "s" constraint
// (vector / scalar register operands, matching the "vsv" suffix).
160 #if CK_TILE_DISABLE_PACKED_FP32
164 asm volatile(
"v_fma_f32 %[result], %[a], %[b], %[c]"
165 : [result]
"=v"(result)
166 : [
a]
"v"(
a), [b]
"s"(b), [c]
"v"(c));
// add_impl_vv(lhs, rhs): one v_add_f32_e32 with both inputs bound through "v" constraints.
174 asm volatile(
"v_add_f32_e32 %[result], %[lhs], %[rhs]"
175 : [result]
"=v"(result)
176 : [lhs]
"v"(lhs), [rhs]
"v"(rhs));
// cvt_pk_fp16_f32(a, b): one v_cvt_pk_f16_f32 taking the two floats `a` and `b`; the declared
// return type is fp16x2_t.
183 asm volatile(
"v_cvt_pk_f16_f32 %[result], %[a], %[b]"
184 : [result]
"=v"(result)
185 : [
a]
"v"(
a), [b]
"v"(b));
// cvt_pk_bf16_f32(a, b): one v_cvt_pk_bf16_f32 taking the two floats `a` and `b`; the declared
// return type is bf16x2_t.
192 asm volatile(
"v_cvt_pk_bf16_f32 %[result], %[a], %[b]"
193 : [result]
"=v"(result)
194 : [
a]
"v"(
a), [b]
"v"(b));
// pk_mul_f32(lhs, rhs): one v_pk_mul_f32 on two fp32x2_t operands, returning fp32x2_t.
201 asm volatile(
"v_pk_mul_f32 %[result], %[lhs], %[rhs]"
202 : [result]
"=v"(result)
203 : [lhs]
"v"(lhs), [rhs]
"v"(rhs));
208 template <
typename Problem_,
typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
// The gemm0 accumulator tile (`sp_compute`) is also the tensor the softmax steps read and
// overwrite in place, so the accumulator and softmax compute types must be the same type.
224 static_assert(std::is_same_v<SaccDataType, SMPLComputeDataType>,
225 "we will use the same dist tensor 'sp_compute' for both gemm0 & softmax");
239 static_assert(kSubQKHeaddim <= 256,
"hdim bigger than 256 is not suitable for this pipeline!");
241 static constexpr
bool kIsGroupMode = Problem::kIsGroupMode;
242 static constexpr
bool kPadSeqLenQ = Problem::kPadSeqLenQ;
243 static constexpr
bool kPadSeqLenK = Problem::kPadSeqLenK;
244 static constexpr
bool kPadHeadDimQ = Problem::kPadHeadDimQ;
245 static constexpr
bool kPadHeadDimV = Problem::kPadHeadDimV;
246 static constexpr
bool kStoreLSE = Problem::kStoreLSE;
251 kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
253 kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
255 kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
258 kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
261 if constexpr(Problem::kBlockPerCu != -1)
262 return Problem::kBlockPerCu;
273 Policy::template GetSmemSize<Problem>() +
278 template <ck_tile::index_t MPerBlock, ck_tile::index_t NPerBlock>
282 constexpr
auto lds_block_desc =
288 return lds_block_desc;
292 template <ck_tile::index_t MPerBlock>
299 return lds_block_desc;
302 template <
typename DataType,
typename Descriptor>
308 make_tensor_view<address_space_enum::lds>(
reinterpret_cast<DataType*
>(base), desc);
// s_waitcnt<Vmcnt, Lgkmcnt, Expcnt = 7>(): packs the three counters into the immediate passed to
// __builtin_amdgcn_s_waitcnt. Layout, as written in the expression below:
//   bits  3:0  <- Vmcnt[3:0]
//   bits  6:4  <- Expcnt[2:0]
//   bits 11:8  <- Lgkmcnt[3:0]
//   bits 15:14 <- Vmcnt[5:4]   (mask 0b110000 shifted left by 10)
// Template arguments wider than their field are silently truncated by the masks.
// NOTE(review): this split-vmcnt layout is assumed to match the target ISA's s_waitcnt
// encoding - confirm against the ISA guide for every architecture this header is built for.
313 template <u
int16_t Vmcnt, u
int8_t Lgkmcnt, u
int8_t Expcnt = 7>
319 __builtin_amdgcn_s_waitcnt((((0b110000 & Vmcnt) << (14 - 4)) | (0b1111 & Vmcnt)) |
320 ((0b111 & Expcnt) << 4) | ((0b1111 & Lgkmcnt) << 8));
// s_waitcnt_vmcnt<Vmcnt>(): sets only the vmcnt field; Lgkmcnt is passed as 15 and Expcnt
// keeps its default of 7, i.e. the all-ones value of those fields (presumably "do not wait").
323 template <u
int16_t Vmcnt>
326 s_waitcnt<Vmcnt, 15>();
// s_waitcnt_lgkmcnt<Lgkmcnt>(): sets only the lgkmcnt field; Vmcnt is passed as 63, the
// all-ones value of the 6-bit vmcnt field.
329 template <u
int8_t Lgkmcnt>
332 s_waitcnt<63, Lgkmcnt>();
335 template <
typename QDramBlockWindowTmp,
336 typename KDramBlockWindowTmp,
337 typename VDramBlockWindowTmp,
338 typename LSEDramBlockWindowTmp,
339 typename QElementFunction,
340 typename KElementFunction,
341 typename VElementFunction,
342 typename LSEElementFunction,
343 typename SAccElementFunction,
344 typename PComputeElementFunction,
345 typename OAccElementFunction>
347 const QElementFunction& q_element_func,
348 const KDramBlockWindowTmp& k_dram_block_window_tmp,
349 [[maybe_unused]]
const KElementFunction& k_element_func,
350 const VDramBlockWindowTmp& v_dram_block_window_tmp,
351 [[maybe_unused]]
const VElementFunction& v_element_func,
352 LSEDramBlockWindowTmp& lse_dram_window_tmp,
353 const LSEElementFunction& lse_element_func,
354 [[maybe_unused]]
const SAccElementFunction& s_acc_element_func,
355 const PComputeElementFunction& p_compute_element_func,
356 const OAccElementFunction& o_acc_element_func,
359 void* smem_ptr)
const
369 static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[
number<0>{}] &&
370 kN0 == KDramBlockWindowTmp{}.get_window_lengths()[
number<0>{}] &&
371 kK0 == KDramBlockWindowTmp{}.get_window_lengths()[
number<1>{}] &&
372 kK1 == VDramBlockWindowTmp{}.get_window_lengths()[
number<0>{}] &&
373 kN1 == VDramBlockWindowTmp{}.get_window_lengths()[
number<1>{}],
376 static_assert(
sizeof(
SaccDataType) * kM0 * kN0 <= GetSmemSize());
377 auto s_lds = make_tensor_view<address_space_enum::lds>(
378 reinterpret_cast<SaccDataType*
>(
static_cast<char*
>(smem_ptr)),
379 MakeSimpleLdsDesc<kM0, kN0>());
380 [[maybe_unused]]
auto s_lds_window =
383 auto p_lds = make_tensor_view<address_space_enum::lds>(
384 reinterpret_cast<PDataType*
>(
static_cast<char*
>(smem_ptr) +
385 Policy::template GetSmemSize<Problem>()),
386 MakeSimpleLdsDesc<kM0, kN0>());
387 [[maybe_unused]]
auto p_lds_window =
390 auto o_lds = make_tensor_view<address_space_enum::lds>(
391 reinterpret_cast<PDataType*
>(
static_cast<char*
>(smem_ptr)),
392 MakeSimpleLdsDesc<kM0, kN1>());
393 [[maybe_unused]]
auto o_lds_window =
396 auto m_lds = make_tensor_view<address_space_enum::lds>(
398 Policy::template GetSmemSize<Problem>()),
399 MakeSimpleLdsDesc1D<kM0>());
400 [[maybe_unused]]
auto m_lds_window =
403 const index_t warp_group_id = get_warp_id() / 4;
406 constexpr
auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
407 constexpr
auto gemm_1 = Policy::template GetPVBlockGemm<Problem>();
410 q_dram_block_window_tmp, Policy::template MakeQRegTileDistribution<Problem>());
// Binary reducers. f_max is handed to block_tile_reduce for the row maximum and also combines
// the two halves returned by permlane32_swap; f_sum combines the swapped halves of the row sum.
// `max` is an unqualified call, resolved by the enclosing namespace.
413 const auto f_max = [](
auto e0,
auto e1) {
return max(e0, e1); };
414 const auto f_sum = [](
auto e0,
auto e1) {
return e0 + e1; };
// Factory bodies for the per-buffer LDS store windows (the enclosing lambdas taking `i_buf` are
// elided in this excerpt). Both windows start at `smem_ptr`; the buffer index is consumed by the
// policy's store descriptor.
// K store window:
418 return make_lds_tile_window<KDataType>(
419 smem_ptr, Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf));
// V store window.
// NOTE(review): this window is typed with KDataType although it uses the V store descriptor,
// while the V load windows are created with VDataType. Harmless when both types are the same;
// confirm whether VDataType was intended.
425 return make_lds_tile_window<KDataType>(
426 smem_ptr, Policy::template MakeVLdsStoreBlockDescriptor<Problem>(i_buf));
431 make_lds_tile_window<KDataType>(
433 Policy::template MakeKLdsLoadBlockDescriptor<Problem>()),
434 Policy::template MakeKRegTileDistribution<Problem>())),
439 make_lds_tile_window<VDataType>(
441 Policy::template MakeVLdsLoadBlockDescriptor<Problem>()),
442 Policy::template MakeVRegTileDistribution<Problem>())),
446 decltype(make_static_distributed_tensor<QDataType>(
447 Policy::template MakeQRegTileDistribution<Problem>())) q_tile;
458 union sp_compute_type
462 decltype(gemm_0.MakeCBlockTile()) sp_compute;
463 decltype(make_static_distributed_tensor<PDataType>(
464 Policy::template MakePRegTileDistribution<Problem>())) p;
468 decltype(gemm_1.MakeCBlockTile()) o_acc;
469 constexpr
index_t fmha_alu_D_reg_cnt = 0;
471 static_assert(fmha_alu_D_reg_cnt <= o_acc.thread_buf_.size());
473 decltype(block_tile_reduce<SMPLComputeDataType>(
480 make_lds_tile_window<KDataType>(
481 static_cast<char*
>(smem_ptr) + (idx)*Policy::template GetSmemSizeKV<Problem>(),
482 Policy::template MakeKLdsLoadBlockDescriptor<Problem>()),
483 Policy::template MakeKRegTileDistribution<Problem>());
487 v_lds_window_load(idx) =
489 static_cast<char*
>(smem_ptr) +
490 (idx + 2) * Policy::template GetSmemSizeKV<Problem>(),
491 Policy::template MakeVLdsLoadBlockDescriptor<Problem>()),
492 Policy::template MakeVRegTileDistribution<Problem>());
496 auto origin_q =
load_tile(q_dram_window);
499 q_tile = transformed_q;
// 0xff7fffff is the IEEE-754 binary32 bit pattern of -FLT_MAX (the most negative finite
// float), so the running row maximum `m` starts at the lowest finite value, not at -infinity.
503 set_tile(m, bit_cast<float>(0xff7fffff));
506 const auto q_origin = q_dram_window.get_window_origin();
507 const auto [seqlen_k_start, seqlen_k_end] =
511 index_t kv_token_start = seqlen_k_start;
514 if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
516 if(num_total_loop <= 0)
518 if constexpr(kStoreLSE)
521 make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
536 k_dram_block_window_tmp.get_window_lengths(),
538 Policy::template MakeKDramTileDistribution<Problem>());
539 k_dram_window.init_raw();
543 v_dram_block_window_tmp.get_window_lengths(),
545 Policy::template MakeVDramTileDistribution<Problem>());
546 v_dram_window.init_raw();
550 constexpr
index_t k0_loops = kQKHeaddim / kK0;
551 constexpr
index_t k1_loops = kN0 / kK1;
552 static_assert(1 == k0_loops);
553 static_assert(1 == k1_loops);
554 static_assert(kN0 == kK1);
556 constexpr
index_t NumWarpGroups = Problem::kBlockSize / Policy::NumThreadPerWarpGroup;
557 static_assert(NumWarpGroups == 2);
559 [[maybe_unused]]
auto print_dist_tensor = [&](
const auto& dist_tensor,
const char* name) {
560 printf(
"[POYENC] %s (size=%d): %5.2f",
562 decltype(dist_tensor.thread_buf_)::size(),
563 ck_tile::type_convert<float>(dist_tensor.thread_buf_[0]));
564 static_for<1, decltype(dist_tensor.thread_buf_)::size(), 1>{}([&](
auto i) {
565 printf(
", %5.2f", ck_tile::type_convert<float>(dist_tensor.thread_buf_[i]));
570 [[maybe_unused]]
auto print_lds = [&](
auto lds_tile_window,
const char* name) {
571 const auto num_rows = lds_tile_window.get_window_lengths().at(
number<0>{});
572 const auto num_cols = lds_tile_window.get_window_lengths().at(
number<1>{});
574 auto desc = lds_tile_window.get_bottom_tensor_view().desc_;
575 auto data = lds_tile_window.get_bottom_tensor_view().buf_.p_data_;
577 if constexpr(
true || num_rows < num_cols)
579 for(
int row = 0; row < num_rows; ++row)
582 printf(
"[DEVICE] %s[%3d] = %5.2f",
585 ck_tile::type_convert<float>(data[
offset]));
586 for(
int col = 1; col < num_cols; ++col)
590 printf(
"%5.2f", ck_tile::type_convert<float>(data[
offset]));
597 for(
int col = 0; col < num_cols; ++col)
600 printf(
"[DEVICE] %s[%3d] = %5.2f",
603 ck_tile::type_convert<float>(data[
offset]));
604 for(
int row = 1; row < num_rows; ++row)
608 printf(
"%5.2f", ck_tile::type_convert<float>(data[
offset]));
615 [[maybe_unused]]
auto print_lds_1d = [&](
auto lds_tile_window,
const char* name) {
616 const auto num_elems = lds_tile_window.get_window_lengths().at(
number<0>{});
618 auto desc = lds_tile_window.get_bottom_tensor_view().desc_;
619 auto data = lds_tile_window.get_bottom_tensor_view().buf_.p_data_;
622 printf(
"[DEVICE] %s = %5.2f", name, ck_tile::type_convert<float>(data[
offset]));
623 for(
int e = 1; e < num_elems; ++e)
627 printf(
"%5.2f", ck_tile::type_convert<float>(data[
offset]));
634 static constexpr
int K_mem_su_ld_insts = 1;
635 static constexpr
int V_mem_su_ld_insts = 1;
637 auto K_mem_load = [&](
auto k_lds_write_idx) {
645 auto K_lds_load = [&](
auto k_lds_read_idx) {
646 kv_tile.k_tile =
load_tile(k_lds_window_load(k_lds_read_idx));
649 auto V_mem_load = [&](
auto v_lds_write_idx) {
651 __builtin_amdgcn_sched_barrier(0);
657 auto V_lds_load = [&](
auto v_lds_read_idx) {
666 auto fmha_alu0 = [&](
auto sp_reg_idx) {
668 static_assert(m.thread_buf_.size() == 1,
669 "assuming that each thread holds 1 rowmax value");
670 auto m_latest = block_tile_reduce<SMPLComputeDataType>(
671 sp(sp_reg_idx).sp_compute,
sequence<1>{}, f_max, m.thread_buf_[0]);
672 #if defined(__gfx950__)
675 __builtin_amdgcn_permlane32_swap(bit_cast<int32_t>(m_latest.thread_buf_[0]),
676 bit_cast<int32_t>(m_latest.thread_buf_[0]),
680 m_latest.thread_buf_[0] = f_max(bit_cast<SMPLComputeDataType>(swapped_regs.x),
681 bit_cast<SMPLComputeDataType>(swapped_regs.y));
687 constexpr
auto p_spans =
688 std::decay_t<decltype(sp(sp_reg_idx).sp_compute)>::get_distributed_spans();
691 constexpr
auto i_j_idx =
make_tuple(idx0, idx1);
693 sp(sp_reg_idx).sp_compute(i_j_idx), scale_s, -scale_s * m(i_j_idx));
699 auto fmha_alu1 = [&](
auto sp_reg_idx) {
700 constexpr
auto p_spans =
701 std::decay_t<decltype(sp(sp_reg_idx).sp_compute)>::get_distributed_spans();
704 constexpr
auto i_j_idx =
make_tuple(idx0, idx1);
705 sp(sp_reg_idx).sp_compute(i_j_idx) =
710 auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
711 sp(sp_reg_idx).sp_compute,
715 static_assert(rowsum_p.thread_buf_.size() == 1,
716 "assuming that each thread holds 1 rowsum value");
717 #if defined(__gfx950__)
720 __builtin_amdgcn_permlane32_swap(bit_cast<int32_t>(rowsum_p.thread_buf_[0]),
721 bit_cast<int32_t>(rowsum_p.thread_buf_[0]),
724 rowsum_p.thread_buf_[0] = f_sum(bit_cast<SMPLComputeDataType>(swapped_regs.x),
725 bit_cast<SMPLComputeDataType>(swapped_regs.y));
731 [&](
auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; });
734 constexpr
auto o_spans = decltype(o_acc)::get_distributed_spans();
737 const auto tmp =
ck_tile::exp2(scale_s * (m_old[i_idx] - m[i_idx]));
744 [&](
auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; });
749 static_assert(sp(sp_reg_idx).p.thread_buf_.size() % 2 == 0);
750 static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 2>{}([&](
auto idx) {
751 float x = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx]);
752 float y = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 1]);
753 if constexpr(std::is_same_v<PDataType, fp16_t>)
756 sp(sp_reg_idx).p.thread_buf_[idx] = casted.x;
757 sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y;
762 sp(sp_reg_idx).p.thread_buf_[idx] = casted.x;
763 sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y;
768 auto gemm = [&](
auto sp_reg_idx,
auto gemm_idx) {
769 if constexpr(gemm_idx == 0)
772 gemm_0(sp(sp_reg_idx).sp_compute,
774 sequence<0, (k0_loops - 1) * kK0>{},
777 sequence<0, (k0_loops - 1) * kK0>{},
784 sequence<0, (k1_loops - 1) * kK1>{},
787 sequence<0, (k1_loops - 1) * kK1>{},
792 auto cl_calc = [&](
auto sp_reg_idx,
auto gemm_idx) {
793 if constexpr(gemm_idx == 0)
796 gemm_0(sp(sp_reg_idx).sp_compute,
798 sequence<0, (k0_loops - 1) * kK0>{},
801 sequence<0, (k0_loops - 1) * kK0>{},
808 sequence<0, (k1_loops - 1) * kK1>{},
811 sequence<0, (k1_loops - 1) * kK1>{},
// fmha_alu_D_upd: rescales the output accumulator after the row maximum has changed.
// Several lines (declarations of pk_o_acc_scale/input/output, the static_for headers, #else and
// #endif) are elided in this excerpt.
817 auto fmha_alu_D_upd = [&] {
// Correction factor 2^(scale_s * (m_old - m)), taken from element 0 of the per-thread buffers.
818 o_acc_scale =
ck_tile::exp2(scale_s * (m_old.thread_buf_[0] - m.thread_buf_[0]));
// Same factor duplicated into both lanes of the packed scale.
821 pk_o_acc_scale.x = o_acc_scale;
822 pk_o_acc_scale.y = o_acc_scale;
// The remaining accumulator registers are walked two at a time, so their count must be even.
824 static_assert((o_acc.thread_buf_.size() - fmha_alu_D_reg_cnt) % 2 == 0);
// With packed FP32 disabled, some registers are scaled one by one here and the pairwise loop
// below starts two registers later (issued_D_reg_cnt = fmha_alu_D_reg_cnt + 2).
825 #if CK_TILE_DISABLE_PACKED_FP32
826 static_assert(fmha_alu_D_reg_cnt + 2 <= o_acc.thread_buf_.size());
828 [&](
auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; });
831 constexpr
auto issued_D_reg_cnt =
832 #if CK_TILE_DISABLE_PACKED_FP32
833 fmha_alu_D_reg_cnt + 2
// Pairwise pass: load two accumulator registers into `input`, then write the pair back from
// `output` (the operation producing `output` is on an elided line).
841 static_for<issued_D_reg_cnt, o_acc.thread_buf_.size(), 2>{}([&](
auto idx) {
843 input.x = o_acc.thread_buf_[idx];
844 input.y = o_acc.thread_buf_[idx + 1];
848 o_acc.thread_buf_[idx] = output.x;
849 o_acc.thread_buf_[idx + 1] = output.y;
853 auto fmha_mask = [&](
auto sp_reg_idx) {
854 if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
856 bool need_perpixel_check = mask.IsEdgeTile(
858 if(need_perpixel_check)
865 const auto col = kv_token_start + tile_idx.at(
number<1>{});
866 return mask.IsOutOfBound(row, col);
872 auto cl_load = [&](
auto load_type,
auto mem_wr_idx,
auto lds_rd_idx) {
873 if constexpr(load_type == 0)
875 V_mem_load(mem_wr_idx);
876 K_lds_load(lds_rd_idx);
880 K_mem_load(mem_wr_idx);
881 V_lds_load(lds_rd_idx);
885 auto core_loop = [&](
auto cl_p) {
894 auto iteration = [&](
auto pi) {
895 auto xdl_SP_p01_reg_idx =
number<1>{} - pi;
896 auto xdl_SP_p23_reg_idx = pi;
899 auto V_w0_lds_wr_idx = pi;
900 auto K_w0_lds_rd_idx = pi;
901 auto V_w0_lds_rd_idx = pi;
906 auto V_w4_lds_rd_idx = pi;
910 if constexpr(cl_p == 0)
912 #if ADD_SBARRIER_FOR_PHASE0
913 __builtin_amdgcn_sched_barrier(0);
914 __builtin_amdgcn_s_barrier();
916 __builtin_amdgcn_sched_barrier(0);
918 if constexpr(pi == 0)
926 s_waitcnt_lgkmcnt<0>();
927 __builtin_amdgcn_sched_barrier(0);
928 cl_calc(xdl_SP_p01_reg_idx, gemm0);
929 fmha_alu1(xdl_SP_p23_reg_idx);
932 __builtin_amdgcn_sched_barrier(0);
935 s_waitcnt_vmcnt<K_mem_su_ld_insts + V_mem_su_ld_insts>();
936 __builtin_amdgcn_sched_barrier(0);
937 __builtin_amdgcn_s_barrier();
938 __builtin_amdgcn_sched_barrier(0);
939 cl_load(memK, K_w0_lds_wr_idx, V_w0_lds_rd_idx);
940 fmha_mask(xdl_SP_p01_reg_idx);
943 __builtin_amdgcn_sched_barrier(0);
946 s_waitcnt_lgkmcnt<0>();
947 __builtin_amdgcn_sched_barrier(0);
948 __builtin_amdgcn_s_barrier();
949 __builtin_amdgcn_sched_barrier(0);
950 cl_calc(xdl_SP_p23_reg_idx, gemm1);
953 __builtin_amdgcn_sched_barrier(0);
956 __builtin_amdgcn_sched_barrier(0);
959 s_waitcnt_vmcnt<K_mem_su_ld_insts + V_mem_su_ld_insts>();
960 __builtin_amdgcn_sched_barrier(0);
961 __builtin_amdgcn_s_barrier();
962 __builtin_amdgcn_sched_barrier(0);
963 cl_load(memV, V_w0_lds_wr_idx, K_w0_lds_rd_idx);
966 kv_token_start += kN0;
967 if(num_total_loop <= ++i_total_loops)
974 #if ADD_SBARRIER_FOR_PHASE0
975 __builtin_amdgcn_sched_barrier(0);
976 __builtin_amdgcn_s_barrier();
978 __builtin_amdgcn_sched_barrier(0);
980 if constexpr(pi == 0)
988 cl_load(memV, V_w4_lds_wr_idx, K_w4_lds_rd_idx);
991 __builtin_amdgcn_sched_barrier(0);
994 s_waitcnt<K_mem_su_ld_insts + V_mem_su_ld_insts, 0>();
995 __builtin_amdgcn_sched_barrier(0);
996 __builtin_amdgcn_s_barrier();
997 __builtin_amdgcn_sched_barrier(0);
998 cl_calc(xdl_SP_p01_reg_idx, gemm0);
999 fmha_alu1(xdl_SP_p23_reg_idx);
1002 __builtin_amdgcn_sched_barrier(0);
1005 __builtin_amdgcn_s_barrier();
1006 __builtin_amdgcn_sched_barrier(0);
1007 cl_load(memK, K_w4_lds_wr_idx, V_w4_lds_rd_idx);
1008 fmha_mask(xdl_SP_p01_reg_idx);
1011 kv_token_start += kN0;
1012 if(num_total_loop <= ++i_total_loops)
1017 __builtin_amdgcn_sched_barrier(0);
1020 s_waitcnt<K_mem_su_ld_insts + V_mem_su_ld_insts, 0>();
1021 __builtin_amdgcn_sched_barrier(0);
1022 __builtin_amdgcn_s_barrier();
1023 __builtin_amdgcn_sched_barrier(0);
1024 cl_calc(xdl_SP_p23_reg_idx, gemm1);
1027 __builtin_amdgcn_sched_barrier(0);
1035 auto fmha_post_process = [&](
auto d) {
1037 auto V_lds_rd_idx = ps_pi;
1039 s_waitcnt_vmcnt<K_mem_su_ld_insts>();
1040 __builtin_amdgcn_s_barrier();
1042 V_lds_load(V_lds_rd_idx);
1045 s_waitcnt_lgkmcnt<0>();
1047 auto xdl_SP_p23_reg_idx = ps_pi;
1057 s_waitcnt_vmcnt<0>();
1058 __builtin_amdgcn_s_barrier();
1062 s_waitcnt_lgkmcnt<0>();
1063 __builtin_amdgcn_s_barrier();
1066 if(1 < num_total_loop)
1080 kv_token_start += kN0;
1082 if(num_total_loop <= i_total_loops)
1084 goto label_main_loops_exit;
1087 if(2 < num_total_loop)
1091 s_waitcnt_vmcnt<K_mem_su_ld_insts + V_mem_su_ld_insts>();
1092 __builtin_amdgcn_s_barrier();
1098 if(1 < num_total_loop)
1100 if(warp_group_id == 0)
1105 asm volatile(
"s_setprio 0");
1106 __builtin_amdgcn_s_barrier();
1110 if(warp_group_id != 0)
1112 asm volatile(
"s_setprio 1");
1113 __builtin_amdgcn_s_barrier();
1118 label_main_loops_exit:
1119 if(num_total_loop % 2)
1123 if(!(num_total_loop % 2))
// Epilogue fragments: optional log-sum-exp output, then normalisation of the output accumulator.
// The sweep_tile_span wrappers and the store of `lse` are elided in this excerpt.
1129 if constexpr(kStoreLSE)
// `lse` reuses the tile distribution of the row-max tensor `m`.
1131 auto lse = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
1133 constexpr
auto lse_spans = decltype(lse)::get_distributed_spans();
// Per row: lse = m / C_LOG2E + log(l).
// NOTE(review): everywhere else in this pipeline `m` is multiplied by scale_s before use
// (e.g. exp2(scale_s * (m_old - m))), but no scale_s factor appears here - confirm that the
// formula is intended for scale_s != 1.
1136 lse(i_idx) = m[i_idx] /
C_LOG2E +
log(l[i_idx]);
1143 constexpr
auto o_spans = decltype(o_acc)::get_distributed_spans();
// Per-row reciprocal of the row sum `l`. With masking enabled, l == 0 yields a factor of 0
// instead of a division by zero.
1147 const auto tmp = [&]() {
1148 if constexpr(FmhaMask::IsMasking)
1150 return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
1153 return 1 / l[i_idx];
1156 constexpr
auto i_j_idx =
make_tuple(idx0, idx1);
// Scale every element of the row by 1 / l.
1157 o_acc(i_j_idx) *= tmp;
1166 template <
typename QDramBlockWindowTmp,
1167 typename KDramBlockWindowTmp,
1168 typename VDramBlockWindowTmp,
1169 typename LSEDramBlockWindowTmp>
1172 const KDramBlockWindowTmp& k_dram_block_window_tmp,
1173 const VDramBlockWindowTmp& v_dram_block_window_tmp,
1174 LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
1177 void* smem_ptr)
const
1181 return operator()(q_dram_block_window_tmp,
1183 k_dram_block_window_tmp,
1185 v_dram_block_window_tmp,
1187 lse_dram_block_window_tmp,
#define ASM_MARKER(marker)
Definition: block_fmha_fwd_v3_pipeline.hpp:12
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
CK_TILE_DEVICE bf16x2_t cvt_pk_bf16_f32(float a, float b)
Definition: block_fmha_fwd_v3_pipeline.hpp:189
CK_TILE_DEVICE float fma_impl_vsv(float a, float b, float c)
Definition: block_fmha_fwd_v3_pipeline.hpp:158
CK_TILE_DEVICE float add_impl_vv(float lhs, float rhs)
Definition: block_fmha_fwd_v3_pipeline.hpp:171
CK_TILE_DEVICE fp16x2_t cvt_pk_fp16_f32(float a, float b)
Definition: block_fmha_fwd_v3_pipeline.hpp:180
CK_TILE_DEVICE fp32x2_t pk_mul_f32(fp32x2_t lhs, fp32x2_t rhs)
Definition: block_fmha_fwd_v3_pipeline.hpp:198
Definition: cluster_descriptor.hpp:13
CK_TILE_DEVICE bfloat16_t log(bfloat16_t x)
Definition: bfloat16.hpp:432
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_descriptor.hpp:268
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
constexpr CK_TILE_DEVICE auto get_slice_tile(const tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile, sequence< SliceBegins... > slice_begins, sequence< SliceEnds... > slice_ends)
Definition: slice_tile.hpp:23
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition: tile_elementwise.hpp:40
CK_TILE_DEVICE void set_tile(DstrTensors &dstr_tensor, const T &value)
Definition: tile_elementwise.hpp:95
bfloat16_t bf16x2_t
Definition: pk_fp4.hpp:24
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_ &acc_tensor, const ReduceFunc &reduce_func, bool_constant< WithBroadcast >={}, bool_constant< CrossWarp >={})
Definition: block_reduce.hpp:21
tuple_array< T, N > statically_indexed_array
Definition: statically_indexed_array.hpp:16
float fp32x2_t
Definition: pk_fp4.hpp:22
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
_Float16 fp16x2_t
Definition: half.hpp:385
CK_TILE_DEVICE auto load_tile_transpose(const tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window)
transpose loads tile from a tensor and returns the resulting tensor with a new (transposed) tile dist...
Definition: load_tile_transpose.hpp:403
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition: sweep_tile.hpp:20
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:95
CK_TILE_HOST_DEVICE void set_tile_if(static_distributed_tensor< DataType, StaticTileDistribution > &out_tensor, DataType value, XIndicesPredicate predicate)
Definition: static_distributed_tensor.hpp:175
constexpr CK_TILE_DEVICE auto make_tile_window_linear(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, LinearBottomDims_={})
Definition: tile_window_linear.hpp:993
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_ &&lds_tile, const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Definition: load_tile.hpp:110
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:23
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition: tile_elementwise.hpp:177
int32_t int32x2_t
Definition: vector_type.hpp:143
constexpr CK_TILE_HOST_DEVICE T min(T x)
Definition: math.hpp:210
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:22
CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x)
Definition: bfloat16.hpp:429
constexpr bool is_same_v
Definition: type.hpp:283
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1249
Definition: block_fmha_fwd_v3_pipeline.hpp:210
ck_tile::remove_cvref_t< typename Problem::PDataType > PDataType
Definition: block_fmha_fwd_v3_pipeline.hpp:219
ck_tile::remove_cvref_t< typename Problem::SMPLComputeDataType > SMPLComputeDataType
Definition: block_fmha_fwd_v3_pipeline.hpp:217
ck_tile::remove_cvref_t< typename Problem::KDataType > KDataType
Definition: block_fmha_fwd_v3_pipeline.hpp:214
ck_tile::remove_cvref_t< typename Problem::ODataType > ODataType
Definition: block_fmha_fwd_v3_pipeline.hpp:221
ck_tile::remove_cvref_t< typename Problem::BlockFmhaShape > BlockFmhaShape
Definition: block_fmha_fwd_v3_pipeline.hpp:227
static constexpr CK_TILE_DEVICE void s_waitcnt_lgkmcnt()
Definition: block_fmha_fwd_v3_pipeline.hpp:330
ck_tile::remove_cvref_t< typename Problem::OaccDataType > OaccDataType
Definition: block_fmha_fwd_v3_pipeline.hpp:220
static constexpr CK_TILE_DEVICE auto MakeSimpleLdsDesc1D()
Definition: block_fmha_fwd_v3_pipeline.hpp:293
static constexpr CK_TILE_DEVICE void s_waitcnt()
Definition: block_fmha_fwd_v3_pipeline.hpp:314
ck_tile::remove_cvref_t< Problem_ > Problem
Definition: block_fmha_fwd_v3_pipeline.hpp:211
static constexpr CK_TILE_DEVICE void s_waitcnt_vmcnt()
Definition: block_fmha_fwd_v3_pipeline.hpp:324
static constexpr CK_TILE_DEVICE auto make_lds_tile_window(void *base, const Descriptor &desc)
Definition: block_fmha_fwd_v3_pipeline.hpp:303
ck_tile::remove_cvref_t< typename Problem::QDataType > QDataType
Definition: block_fmha_fwd_v3_pipeline.hpp:213
CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowTmp &k_dram_block_window_tmp, const VDramBlockWindowTmp &v_dram_block_window_tmp, LSEDramBlockWindowTmp &lse_dram_block_window_tmp, FmhaMask mask, float scale_s, void *smem_ptr) const
Definition: block_fmha_fwd_v3_pipeline.hpp:1171
ck_tile::remove_cvref_t< typename Problem::LSEDataType > LSEDataType
Definition: block_fmha_fwd_v3_pipeline.hpp:218
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: block_fmha_fwd_v3_pipeline.hpp:269
ck_tile::remove_cvref_t< typename Problem::SaccDataType > SaccDataType
Definition: block_fmha_fwd_v3_pipeline.hpp:216
ck_tile::remove_cvref_t< Policy_ > Policy
Definition: block_fmha_fwd_v3_pipeline.hpp:212
ck_tile::remove_cvref_t< typename Problem::FmhaMask > FmhaMask
Definition: block_fmha_fwd_v3_pipeline.hpp:222
CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp &q_dram_block_window_tmp, const QElementFunction &q_element_func, const KDramBlockWindowTmp &k_dram_block_window_tmp, [[maybe_unused]] const KElementFunction &k_element_func, const VDramBlockWindowTmp &v_dram_block_window_tmp, [[maybe_unused]] const VElementFunction &v_element_func, LSEDramBlockWindowTmp &lse_dram_window_tmp, const LSEElementFunction &lse_element_func, [[maybe_unused]] const SAccElementFunction &s_acc_element_func, const PComputeElementFunction &p_compute_element_func, const OAccElementFunction &o_acc_element_func, FmhaMask mask, float scale_s, void *smem_ptr) const
Definition: block_fmha_fwd_v3_pipeline.hpp:346
ck_tile::remove_cvref_t< typename Problem::VDataType > VDataType
Definition: block_fmha_fwd_v3_pipeline.hpp:215
static constexpr CK_TILE_DEVICE auto MakeSimpleLdsDesc()
Definition: block_fmha_fwd_v3_pipeline.hpp:279
static constexpr CK_TILE_DEVICE void schedule(ck_tile::number< WaveGroup >, ck_tile::number< Phase >)
Definition: block_fmha_fwd_v3_pipeline.hpp:103
static constexpr CK_TILE_DEVICE void schedule(ck_tile::number< WaveGroup >, ck_tile::number< Phase >)
Definition: block_fmha_fwd_v3_pipeline.hpp:45
Definition: block_fmha_fwd_v3_pipeline.hpp:39
Definition: integral_constant.hpp:13
Definition: functional.hpp:86
Definition: numeric.hpp:18
Definition: coordinate_transform.hpp:1392
Definition: sequence.hpp:49
Definition: functional.hpp:43
Definition: tensor_view.hpp:41
#define C_LOG2E
Definition: math.hpp:469