10 #define ENABLE_ASM_MARKER 1 
   12 #define ASM_MARKER(marker)               \ 
   13     __builtin_amdgcn_sched_barrier(0);   \ 
   14     asm volatile("; [POYENC] " #marker); \
 
   15     __builtin_amdgcn_sched_barrier(0);
 
   17 #define ASM_MARKER(marker) 
   20 #define ADD_SBARRIER_FOR_PHASE0 1 
   21 #if !defined(CK_TILE_DISABLE_PACKED_FP32) 
   22 #define CK_TILE_DISABLE_PACKED_FP32 0 
   28 #define ENABLE_DEBUG_STMTS 1 
   29 #if ENABLE_DEBUG_STMTS 
   31     if(get_block_1d_id() == 0 && get_warp_id() == WARP_ID && get_lane_id() == LANE_ID) 
   33 #define DEBUG_STMTS if constexpr(false) 
   38 template <
typename PipelineProblem, 
bool kIsMasking>
 
   41 template <
typename PipelineProblem>
 
   44     template <ck_tile::index_t WaveGroup, ck_tile::index_t Phase>
 
   50         if constexpr(WaveGroup == 0)
 
   52             if constexpr(Phase == 0)
 
   55                     __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
   56                     __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); 
 
   57                     __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); 
 
   60             else if constexpr(Phase == 1)
 
   62                 __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); 
 
   63                 __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); 
 
   65             else if constexpr(Phase == 2)
 
   67 #if !CK_TILE_DISABLE_PACKED_FP32 
   68                 __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); 
 
   71                     __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
   72                     __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); 
 
   75             else if constexpr(Phase == 3)
 
   77                 __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); 
 
   78                 __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); 
 
   83             if constexpr(Phase == 0)
 
   85                 __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); 
 
   86                 __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); 
 
   88             else if constexpr(Phase == 1)
 
   91                     __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
   92                     __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); 
 
   93                     __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); 
 
   96             else if constexpr(Phase == 2)
 
   98                 __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); 
 
   99                 __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); 
 
  101             else if constexpr(Phase == 3)
 
  103 #if !CK_TILE_DISABLE_PACKED_FP32 
  104                 __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); 
 
  107                     __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
  108                     __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); 
 
  115 template <
typename PipelineProblem>
 
  118     template <ck_tile::index_t WaveGroup, ck_tile::index_t Phase>
 
  124         if constexpr(WaveGroup == 0)
 
  126             if constexpr(Phase == 0)
 
  129                     __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
  130                     __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); 
 
  131                     __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); 
 
  134             else if constexpr(Phase == 1)
 
  136                 __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); 
 
  137                 __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); 
 
  139             else if constexpr(Phase == 2)
 
  141 #if !CK_TILE_DISABLE_PACKED_FP32 
  142                 __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); 
 
  145                     __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
  146                     __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); 
 
  149             else if constexpr(Phase == 3)
 
  151                 __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); 
 
  152                 __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); 
 
  157             if constexpr(Phase == 0)
 
  159                 __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); 
 
  160                 __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); 
 
  162             else if constexpr(Phase == 1)
 
  165                     __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
  166                     __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); 
 
  167                     __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); 
 
  170             else if constexpr(Phase == 2)
 
  172                 __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); 
 
  173                 __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); 
 
  175             else if constexpr(Phase == 3)
 
  177 #if !CK_TILE_DISABLE_PACKED_FP32 
  178                 __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); 
 
  181                     __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
  182                     __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); 
 
  192 #if CK_TILE_DISABLE_PACKED_FP32 
  196     asm volatile(
"v_fma_f32 %[result], %[a], %[b], %[c]" 
  197                  : [result] 
"=v"(result)
 
  198                  : [
a] 
"v"(
a), [b] 
"s"(b), [c] 
"v"(c));
 
  206     asm volatile(
"v_add_f32_e32 %[result], %[lhs], %[rhs]" 
  207                  : [result] 
"=v"(result)
 
  208                  : [lhs] 
"v"(lhs), [rhs] 
"v"(rhs));
 
  215     asm volatile(
"v_mul_f32_e32 %[result], %[lhs], %[rhs]" 
  216                  : [result] 
"=v"(result)
 
  217                  : [lhs] 
"v"(lhs), [rhs] 
"v"(rhs));
 
  224     asm volatile(
"v_cvt_pk_f16_f32 %[result], %[a], %[b]" 
  225                  : [result] 
"=v"(result)
 
  226                  : [
a] 
"v"(
a), [b] 
"v"(b));
 
  233     asm volatile(
"v_cvt_pk_bf16_f32 %[result], %[a], %[b]" 
  234                  : [result] 
"=v"(result)
 
  235                  : [
a] 
"v"(
a), [b] 
"v"(b));
 
  242     asm volatile(
"v_pk_mul_f32 %[result], %[lhs], %[rhs]" 
  243                  : [result] 
"=v"(result)
 
  244                  : [lhs] 
"v"(lhs), [rhs] 
"v"(rhs));
 
  249 template <
typename Problem_, 
typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
 
  265     static_assert(std::is_same_v<SaccDataType, SMPLComputeDataType>,
 
  266                   "we will the same dist tensor 'sp_compute' for both gemm0 & softmax");
 
  280     static_assert(kSubQKHeaddim <= 256, 
"hdim bigger than 256 is not suitable for this pipeline!");
 
  282     static constexpr 
bool kIsGroupMode = Problem::kIsGroupMode;
 
  283     static constexpr 
bool kPadSeqLenQ  = Problem::kPadSeqLenQ;
 
  284     static constexpr 
bool kPadSeqLenK  = Problem::kPadSeqLenK;
 
  285     static constexpr 
bool kPadHeadDimQ = Problem::kPadHeadDimQ;
 
  286     static constexpr 
bool kPadHeadDimV = Problem::kPadHeadDimV;
 
  287     static constexpr 
bool kStoreLSE    = Problem::kStoreLSE;
 
  292         kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
 
  294         kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
 
  296         kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
 
  299         kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
 
  302         if constexpr(Problem::kBlockPerCu != -1)
 
  303             return Problem::kBlockPerCu;
 
  314                             Policy::template GetSmemSize<Problem>() +
 
  319     template <ck_tile::index_t MPerBlock, ck_tile::index_t NPerBlock>
 
  323         constexpr 
auto lds_block_desc =
 
  329         return lds_block_desc;
 
  333     template <ck_tile::index_t MPerBlock>
 
  340         return lds_block_desc;
 
  343     template <
typename DataType, 
typename Descriptor>
 
  349             make_tensor_view<address_space_enum::lds>(
reinterpret_cast<DataType*
>(base), desc);
 
  354     template <u
int16_t Vmcnt, u
int8_t Lgkmcnt, u
int8_t Expcnt = 7>
 
  360         __builtin_amdgcn_s_waitcnt((((0b110000 & Vmcnt) << (14 - 4)) | (0b1111 & Vmcnt)) |
 
  361                                    ((0b111 & Expcnt) << 4) | ((0b1111 & Lgkmcnt) << 8));
 
  364     template <u
int16_t Vmcnt>
 
  367         s_waitcnt<Vmcnt, 15>();
 
  370     template <u
int8_t Lgkmcnt>
 
  373         s_waitcnt<63, Lgkmcnt>();
 
  376     template <
typename QDramBlockWindowTmp,
 
  377               typename KDramBlockWindowTmp,
 
  378               typename VDramBlockWindowTmp,
 
  379               typename LSEDramBlockWindowTmp,
 
  380               typename QElementFunction,
 
  381               typename KElementFunction,
 
  382               typename VElementFunction,
 
  383               typename LSEElementFunction,
 
  384               typename SAccElementFunction,
 
  385               typename PComputeElementFunction,
 
  386               typename OAccElementFunction>
 
  388                                    const QElementFunction& q_element_func,
 
  389                                    const KDramBlockWindowTmp& k_dram_block_window_tmp, 
 
  390                                    [[maybe_unused]] 
const KElementFunction& k_element_func,
 
  391                                    const VDramBlockWindowTmp& v_dram_block_window_tmp, 
 
  392                                    [[maybe_unused]] 
const VElementFunction& v_element_func,
 
  393                                    LSEDramBlockWindowTmp& lse_dram_window_tmp, 
 
  394                                    const LSEElementFunction& lse_element_func,
 
  395                                    [[maybe_unused]] 
const SAccElementFunction& s_acc_element_func,
 
  396                                    const PComputeElementFunction& p_compute_element_func,
 
  397                                    const OAccElementFunction& o_acc_element_func,
 
  400                                    void* smem_ptr)
 const 
  410         static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[
number<0>{}] &&
 
  411                           kN0 == KDramBlockWindowTmp{}.get_window_lengths()[
number<0>{}] &&
 
  412                           kK0 == KDramBlockWindowTmp{}.get_window_lengths()[
number<1>{}] &&
 
  413                           kK1 == VDramBlockWindowTmp{}.get_window_lengths()[
number<0>{}] &&
 
  414                           kN1 == VDramBlockWindowTmp{}.get_window_lengths()[
number<1>{}],
 
  417         static_assert(
sizeof(
SaccDataType) * kM0 * kN0 <= GetSmemSize());
 
  418         auto s_lds = make_tensor_view<address_space_enum::lds>(
 
  419             reinterpret_cast<SaccDataType*
>(
static_cast<char*
>(smem_ptr)),
 
  420             MakeSimpleLdsDesc<kM0, kN0>());
 
  421         [[maybe_unused]] 
auto s_lds_window =
 
  424         auto p_lds = make_tensor_view<address_space_enum::lds>(
 
  425             reinterpret_cast<PDataType*
>(
static_cast<char*
>(smem_ptr) +
 
  426                                          Policy::template GetSmemSize<Problem>()),
 
  427             MakeSimpleLdsDesc<kM0, kN0>());
 
  428         [[maybe_unused]] 
auto p_lds_window =
 
  431         auto o_lds = make_tensor_view<address_space_enum::lds>(
 
  432             reinterpret_cast<PDataType*
>(
static_cast<char*
>(smem_ptr)),
 
  433             MakeSimpleLdsDesc<kM0, kN1>());
 
  434         [[maybe_unused]] 
auto o_lds_window =
 
  437         auto m_lds = make_tensor_view<address_space_enum::lds>(
 
  439                                                    Policy::template GetSmemSize<Problem>()),
 
  440             MakeSimpleLdsDesc1D<kM0>());
 
  441         [[maybe_unused]] 
auto m_lds_window =
 
  444         const index_t warp_group_id = get_warp_id() / 4;
 
  447         constexpr 
auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
 
  448         constexpr 
auto gemm_1 = Policy::template GetPVBlockGemm<Problem>();
 
  451             q_dram_block_window_tmp, Policy::template MakeQRegTileDistribution<Problem>());
 
  454         const auto f_max = [](
auto e0, 
auto e1) { 
return max(e0, e1); };
 
  455         const auto f_sum = [](
auto e0, 
auto e1) { 
return e0 + e1; };
 
  459                 return make_lds_tile_window<KDataType>(
 
  460                     smem_ptr, Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf));
 
  466                 return make_lds_tile_window<KDataType>(
 
  467                     smem_ptr, Policy::template MakeVLdsStoreBlockDescriptor<Problem>(i_buf));
 
  472                                      make_lds_tile_window<KDataType>(
 
  474                                          Policy::template MakeKLdsLoadBlockDescriptor<Problem>()),
 
  475                                      Policy::template MakeKRegTileDistribution<Problem>())),
 
  480                                      make_lds_tile_window<VDataType>(
 
  482                                          Policy::template MakeVLdsLoadBlockDescriptor<Problem>()),
 
  483                                      Policy::template MakeVRegTileDistribution<Problem>())),
 
  487         decltype(make_static_distributed_tensor<QDataType>(
 
  488             Policy::template MakeQRegTileDistribution<Problem>())) q_tile;
 
  499         union sp_compute_type
 
  503             decltype(gemm_0.MakeCBlockTile()) sp_compute;
 
  504             decltype(make_static_distributed_tensor<PDataType>(
 
  505                 Policy::template MakePRegTileDistribution<Problem>())) p;
 
  509         decltype(gemm_1.MakeCBlockTile()) o_acc;
 
  510         constexpr 
index_t fmha_alu_D_reg_cnt = 6; 
 
  512         static_assert(fmha_alu_D_reg_cnt <= o_acc.thread_buf_.size());
 
  514         decltype(block_tile_reduce<SMPLComputeDataType>(
 
  521                 make_lds_tile_window<KDataType>(
 
  522                     static_cast<char*
>(smem_ptr) + (idx)*Policy::template GetSmemSizeKV<Problem>(),
 
  523                     Policy::template MakeKLdsLoadBlockDescriptor<Problem>()),
 
  524                 Policy::template MakeKRegTileDistribution<Problem>());
 
  528             v_lds_window_load(idx) =
 
  530                                      static_cast<char*
>(smem_ptr) +
 
  531                                          (idx + 2) * Policy::template GetSmemSizeKV<Problem>(),
 
  532                                      Policy::template MakeVLdsLoadBlockDescriptor<Problem>()),
 
  533                                  Policy::template MakeVRegTileDistribution<Problem>());
 
  537             auto origin_q      = 
load_tile(q_dram_window);
 
  540             q_tile = transformed_q;
 
  544         set_tile(m, bit_cast<float>(0xff7fffff)); 
 
  547         const auto q_origin = q_dram_window.get_window_origin();
 
  548         const auto [seqlen_k_start, seqlen_k_end] =
 
  552         index_t kv_token_start    = seqlen_k_start;
 
  555         if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
 
  557             if(num_total_loop <= 0)
 
  559                 if constexpr(kStoreLSE)
 
  562                         make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
 
  577                              k_dram_block_window_tmp.get_window_lengths(),
 
  579                              Policy::template MakeKDramTileDistribution<Problem>());
 
  580         k_dram_window.init_raw();
 
  584                              v_dram_block_window_tmp.get_window_lengths(),
 
  586                              Policy::template MakeVDramTileDistribution<Problem>());
 
  587         v_dram_window.init_raw();
 
  591         constexpr 
index_t k0_loops = kQKHeaddim / kK0;
 
  592         constexpr 
index_t k1_loops = kN0 / kK1;
 
  593         static_assert(1 == k0_loops);
 
  594         static_assert(1 == k1_loops);
 
  595         static_assert(kN0 == kK1);
 
  597         constexpr 
index_t NumWarpGroups = Problem::kBlockSize / Policy::NumThreadPerWarpGroup;
 
  598         static_assert(NumWarpGroups == 2);
 
  600         [[maybe_unused]] 
auto print_dist_tensor = [&](
const auto& dist_tensor, 
const char* name) {
 
  601             printf(
"[POYENC] %s (size=%d): %5.2f",
 
  603                    decltype(dist_tensor.thread_buf_)::size(),
 
  604                    ck_tile::type_convert<float>(dist_tensor.thread_buf_[0]));
 
  605             static_for<1, decltype(dist_tensor.thread_buf_)::size(), 1>{}([&](
auto i) {
 
  606                 printf(
", %5.2f", ck_tile::type_convert<float>(dist_tensor.thread_buf_[i]));
 
  611         [[maybe_unused]] 
auto print_lds = [&](
auto lds_tile_window, 
const char* name) {
 
  612             const auto num_rows = lds_tile_window.get_window_lengths().at(
number<0>{});
 
  613             const auto num_cols = lds_tile_window.get_window_lengths().at(
number<1>{});
 
  615             auto desc = lds_tile_window.get_bottom_tensor_view().desc_;
 
  616             auto data = lds_tile_window.get_bottom_tensor_view().buf_.p_data_;
 
  618             if constexpr(
true || num_rows < num_cols)
 
  620                 for(
int row = 0; row < num_rows; ++row)
 
  623                     printf(
"[DEVICE] %s[%3d] = %5.2f",
 
  626                            ck_tile::type_convert<float>(data[
offset]));
 
  627                     for(
int col = 1; col < num_cols; ++col)
 
  631                         printf(
"%5.2f", ck_tile::type_convert<float>(data[
offset]));
 
  638                 for(
int col = 0; col < num_cols; ++col)
 
  641                     printf(
"[DEVICE] %s[%3d] = %5.2f",
 
  644                            ck_tile::type_convert<float>(data[
offset]));
 
  645                     for(
int row = 1; row < num_rows; ++row)
 
  649                         printf(
"%5.2f", ck_tile::type_convert<float>(data[
offset]));
 
  656         [[maybe_unused]] 
auto print_lds_1d = [&](
auto lds_tile_window, 
const char* name) {
 
  657             const auto num_elems = lds_tile_window.get_window_lengths().at(
number<0>{});
 
  659             auto desc = lds_tile_window.get_bottom_tensor_view().desc_;
 
  660             auto data = lds_tile_window.get_bottom_tensor_view().buf_.p_data_;
 
  663             printf(
"[DEVICE] %s = %5.2f", name, ck_tile::type_convert<float>(data[
offset]));
 
  664             for(
int e = 1; e < num_elems; ++e)
 
  668                 printf(
"%5.2f", ck_tile::type_convert<float>(data[
offset]));
 
  675         constexpr 
int K_mem_su_ld_insts = k_dram_window.get_num_of_access();
 
  676         constexpr 
int V_mem_su_ld_insts = v_dram_window.get_num_of_access();
 
  678         auto K_mem_load = [&](
auto k_lds_write_idx) {
 
  686         auto K_lds_load = [&](
auto k_lds_read_idx) {
 
  687             kv_tile.k_tile = 
load_tile(k_lds_window_load(k_lds_read_idx));
 
  690         auto V_mem_load = [&](
auto v_lds_write_idx) {
 
  697         auto V_lds_load = [&](
auto v_lds_read_idx) {
 
  706         auto fmha_alu0 = [&](
auto sp_reg_idx) {
 
  708             static_assert(m.thread_buf_.size() == 1,
 
  709                           "assuming that each thread holds 1 rowmax value");
 
  710             auto m_latest = block_tile_reduce<SMPLComputeDataType>(
 
  711                 sp(sp_reg_idx).sp_compute, 
sequence<1>{}, f_max, m.thread_buf_[0]);
 
  712 #if defined(__gfx950__) 
  715                 __builtin_amdgcn_permlane32_swap(bit_cast<int32_t>(m_latest.thread_buf_[0]),
 
  716                                                  bit_cast<int32_t>(m_latest.thread_buf_[0]),
 
  720             m_latest.thread_buf_[0] = f_max(bit_cast<SMPLComputeDataType>(swapped_regs.x),
 
  721                                             bit_cast<SMPLComputeDataType>(swapped_regs.y));
 
  727             constexpr 
auto p_spans =
 
  728                 std::decay_t<decltype(sp(sp_reg_idx).sp_compute)>::get_distributed_spans();
 
  731                     constexpr 
auto i_j_idx        = 
make_tuple(idx0, idx1);
 
  733                         sp(sp_reg_idx).sp_compute(i_j_idx), scale_s, -scale_s * m(i_j_idx));
 
  739         auto fmha_alu1 = [&](
auto sp_reg_idx) {
 
  740             constexpr 
auto p_spans =
 
  741                 std::decay_t<decltype(sp(sp_reg_idx).sp_compute)>::get_distributed_spans();
 
  744                     constexpr 
auto i_j_idx = 
make_tuple(idx0, idx1);
 
  745                     sp(sp_reg_idx).sp_compute(i_j_idx) =
 
  750             auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
 
  751                 sp(sp_reg_idx).sp_compute,
 
  755             static_assert(rowsum_p.thread_buf_.size() == 1,
 
  756                           "assuming that each thread holds 1 rowsum value");
 
  757 #if defined(__gfx950__) 
  760                 __builtin_amdgcn_permlane32_swap(bit_cast<int32_t>(rowsum_p.thread_buf_[0]),
 
  761                                                  bit_cast<int32_t>(rowsum_p.thread_buf_[0]),
 
  764             rowsum_p.thread_buf_[0] = f_sum(bit_cast<SMPLComputeDataType>(swapped_regs.x),
 
  765                                             bit_cast<SMPLComputeDataType>(swapped_regs.y));
 
  775             constexpr 
auto o_spans = decltype(o_acc)::get_distributed_spans();
 
  778                 const auto tmp       = 
ck_tile::exp2(scale_s * (m_old[i_idx] - m[i_idx]));
 
  792             static_assert(sp(sp_reg_idx).p.thread_buf_.size() % 2 == 0);
 
  793             static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 2>{}([&](
auto idx) {
 
  794                 float x = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx]);
 
  795                 float y = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 1]);
 
  796                 if constexpr(std::is_same_v<PDataType, fp16_t>)
 
  799                     sp(sp_reg_idx).p.thread_buf_[idx]     = casted.x;
 
  800                     sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y;
 
  805                     sp(sp_reg_idx).p.thread_buf_[idx]     = casted.x;
 
  806                     sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y;
 
  815         auto gemm = [&](
auto sp_reg_idx, 
auto gemm_idx) {
 
  816             if constexpr(gemm_idx == 0)
 
  819                 gemm_0(sp(sp_reg_idx).sp_compute,
 
  821                                       sequence<0, (k0_loops - 1) * kK0>{},
 
  824                                       sequence<0, (k0_loops - 1) * kK0>{},
 
  831                                       sequence<0, (k1_loops - 1) * kK1>{},
 
  834                                       sequence<0, (k1_loops - 1) * kK1>{},
 
  839         auto cl_calc = [&](
auto sp_reg_idx, 
auto gemm_idx) {
 
  840             if constexpr(gemm_idx == 0)
 
  843                 gemm_0(sp(sp_reg_idx).sp_compute,
 
  845                                       sequence<0, (k0_loops - 1) * kK0>{},
 
  848                                       sequence<0, (k0_loops - 1) * kK0>{},
 
  855                                       sequence<0, (k1_loops - 1) * kK1>{},
 
  858                                       sequence<0, (k1_loops - 1) * kK1>{},
 
  864         auto fmha_alu_D_upd = [&] {
 
  865             o_acc_scale = 
ck_tile::exp2(scale_s * (m_old.thread_buf_[0] - m.thread_buf_[0]));
 
  868             pk_o_acc_scale.x = o_acc_scale;
 
  869             pk_o_acc_scale.y = o_acc_scale;
 
  871             static_assert((o_acc.thread_buf_.size() - fmha_alu_D_reg_cnt) % 2 == 0);
 
  872 #if CK_TILE_DISABLE_PACKED_FP32 
  873             static_assert(fmha_alu_D_reg_cnt + 2 <= o_acc.thread_buf_.size());
 
  875                 [&](
auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; });
 
  878             constexpr 
auto issued_D_reg_cnt =
 
  879 #if CK_TILE_DISABLE_PACKED_FP32 
  880                 fmha_alu_D_reg_cnt + 2
 
  888             static_for<issued_D_reg_cnt, o_acc.thread_buf_.size(), 2>{}([&](
auto idx) {
 
  890                 input.x = o_acc.thread_buf_[idx];
 
  891                 input.y = o_acc.thread_buf_[idx + 1];
 
  895                 o_acc.thread_buf_[idx]     = output.x;
 
  896                 o_acc.thread_buf_[idx + 1] = output.y;
 
  900         auto fmha_mask = [&](
auto sp_reg_idx) {
 
  901             if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
 
  903                 bool need_perpixel_check = mask.IsEdgeTile(
 
  905                 if(need_perpixel_check)
 
  912                                     const auto col = kv_token_start + tile_idx.at(
number<1>{});
 
  913                                     return mask.IsOutOfBound(row, col);
 
  919         auto cl_load = [&](
auto load_type, 
auto mem_wr_idx, 
auto lds_rd_idx) {
 
  920             if constexpr(load_type == 0)
 
  922                 V_mem_load(mem_wr_idx);
 
  923                 K_lds_load(lds_rd_idx);
 
  927                 K_mem_load(mem_wr_idx);
 
  928                 V_lds_load(lds_rd_idx);
 
  932         auto core_loop = [&](
auto cl_p) {
 
  941             auto iteration = [&](
auto pi) {
 
  942                 auto xdl_SP_p01_reg_idx = 
number<1>{} - pi;
 
  943                 auto xdl_SP_p23_reg_idx = pi;
 
  946                 auto V_w0_lds_wr_idx = pi;
 
  947                 auto K_w0_lds_rd_idx = pi;
 
  948                 auto V_w0_lds_rd_idx = pi;
 
  953                 auto V_w4_lds_rd_idx = pi;
 
  957                 if constexpr(cl_p == 0)
 
  959 #if ADD_SBARRIER_FOR_PHASE0 
  960                     __builtin_amdgcn_sched_barrier(0);
 
  961                     __builtin_amdgcn_s_barrier();
 
  963                     __builtin_amdgcn_sched_barrier(0);
 
  965                     if constexpr(pi == 0)
 
  973                     s_waitcnt_lgkmcnt<0>();
 
  974                     __builtin_amdgcn_sched_barrier(0);
 
  975                     cl_calc(xdl_SP_p01_reg_idx, gemm0);
 
  976                     fmha_alu1(xdl_SP_p23_reg_idx);
 
  979                     __builtin_amdgcn_sched_barrier(0);
 
  982                     s_waitcnt_vmcnt<K_mem_su_ld_insts + V_mem_su_ld_insts>();
 
  983                     __builtin_amdgcn_sched_barrier(0);
 
  984                     __builtin_amdgcn_s_barrier();
 
  985                     __builtin_amdgcn_sched_barrier(0);
 
  986                     cl_load(memK, K_w0_lds_wr_idx, V_w0_lds_rd_idx);
 
  988                     fmha_mask(xdl_SP_p01_reg_idx);
 
  990                     __builtin_amdgcn_sched_barrier(0);
 
  993                     s_waitcnt_lgkmcnt<0>();
 
  994                     __builtin_amdgcn_sched_barrier(0);
 
  995                     __builtin_amdgcn_s_barrier();
 
  996                     __builtin_amdgcn_sched_barrier(0);
 
  997                     asm volatile(
"s_nop 0");
 
  998                     __builtin_amdgcn_sched_barrier(0);
 
  999                     cl_calc(xdl_SP_p23_reg_idx, gemm1);
 
 1002                     __builtin_amdgcn_sched_barrier(0);
 
 1005                     __builtin_amdgcn_sched_barrier(0);
 
 1008                     s_waitcnt_vmcnt<K_mem_su_ld_insts + V_mem_su_ld_insts>();
 
 1009                     __builtin_amdgcn_sched_barrier(0);
 
 1010                     __builtin_amdgcn_s_barrier();
 
 1011                     __builtin_amdgcn_sched_barrier(0);
 
 1012                     cl_load(memV, V_w0_lds_wr_idx, K_w0_lds_rd_idx);
 
 1015                     kv_token_start += kN0;
 
 1016                     if(num_total_loop <= ++i_total_loops)
 
 1023 #if ADD_SBARRIER_FOR_PHASE0 
 1024                     __builtin_amdgcn_sched_barrier(0);
 
 1025                     __builtin_amdgcn_s_barrier();
 
 1027                     __builtin_amdgcn_sched_barrier(0);
 
 1029                     if constexpr(pi == 0)
 
 1037                     cl_load(memV, V_w4_lds_wr_idx, K_w4_lds_rd_idx);
 
 1040                     __builtin_amdgcn_sched_barrier(0);
 
 1043                     s_waitcnt<K_mem_su_ld_insts + V_mem_su_ld_insts, 0>();
 
 1044                     __builtin_amdgcn_sched_barrier(0);
 
 1045                     __builtin_amdgcn_s_barrier();
 
 1046                     __builtin_amdgcn_sched_barrier(0);
 
 1047                     asm volatile(
"s_nop 1");
 
 1048                     __builtin_amdgcn_sched_barrier(0);
 
 1049                     cl_calc(xdl_SP_p01_reg_idx, gemm0);
 
 1050                     fmha_alu1(xdl_SP_p23_reg_idx);
 
 1053                     __builtin_amdgcn_sched_barrier(0);
 
 1056                     __builtin_amdgcn_s_barrier();
 
 1057                     __builtin_amdgcn_sched_barrier(0);
 
 1058                     cl_load(memK, K_w4_lds_wr_idx, V_w4_lds_rd_idx);
 
 1060                     fmha_mask(xdl_SP_p01_reg_idx);
 
 1062                     kv_token_start += kN0;
 
 1063                     if(num_total_loop <= ++i_total_loops)
 
 1068                     __builtin_amdgcn_sched_barrier(0);
 
 1071                     s_waitcnt<K_mem_su_ld_insts + V_mem_su_ld_insts, 0>();
 
 1072                     __builtin_amdgcn_sched_barrier(0);
 
 1073                     __builtin_amdgcn_s_barrier();
 
 1074                     __builtin_amdgcn_sched_barrier(0);
 
 1075                     asm volatile(
"s_nop 1");
 
 1076                     __builtin_amdgcn_sched_barrier(0);
 
 1077                     cl_calc(xdl_SP_p23_reg_idx, gemm1);
 
 1080                     __builtin_amdgcn_sched_barrier(0);
 
 1088         auto fmha_post_process = [&](
auto d) {
 
 1090             auto V_lds_rd_idx = ps_pi;
 
 1092             if(1 < num_total_loop)
 
 1094                 s_waitcnt_vmcnt<K_mem_su_ld_insts>();
 
 1098                 s_waitcnt_vmcnt<0>();
 
 1100             __builtin_amdgcn_s_barrier();
 
 1102             V_lds_load(V_lds_rd_idx);
 
 1105             s_waitcnt_lgkmcnt<0>();
 
 1107             auto xdl_SP_p23_reg_idx = ps_pi;
 
 1117             s_waitcnt_vmcnt<0>();
 
 1118             __builtin_amdgcn_s_barrier();
 
 1122             s_waitcnt_lgkmcnt<0>();
 
 1123             __builtin_amdgcn_s_barrier();
 
 1126             if(1 < num_total_loop)
 
 1140             kv_token_start += kN0;
 
 1142             if(num_total_loop <= i_total_loops)
 
 1144                 goto label_main_loops_exit;
 
 1147             if(2 < num_total_loop)
 
 1151                 s_waitcnt_vmcnt<K_mem_su_ld_insts + V_mem_su_ld_insts>();
 
 1152                 __builtin_amdgcn_s_barrier();
 
 1158         if(1 < num_total_loop)
 
 1160             if(warp_group_id == 0)
 
 1165                 __builtin_amdgcn_s_setprio(0);
 
 1166                 __builtin_amdgcn_s_barrier();
 
 1170             if(warp_group_id != 0)
 
 1172                 __builtin_amdgcn_s_setprio(1);
 
 1173                 __builtin_amdgcn_s_barrier();
 
 1178     label_main_loops_exit:
 
 1179         if(num_total_loop % 2)
 
 1183         if(!(num_total_loop % 2))
 
 1189         if constexpr(kStoreLSE)
 
 1191             auto lse = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
 
 1193             constexpr 
auto lse_spans = decltype(lse)::get_distributed_spans();
 
 1196                 lse(i_idx)           = m[i_idx] / 
C_LOG2E + 
log(l[i_idx]);
 
 1203         constexpr 
auto o_spans = decltype(o_acc)::get_distributed_spans();
 
 1207             const auto tmp       = [&]() {
 
 1208                 if constexpr(FmhaMask::IsMasking)
 
 1210                     return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
 
 1213                     return 1 / l[i_idx];
 
 1216                 constexpr 
auto i_j_idx = 
make_tuple(idx0, idx1);
 
 1217                 o_acc(i_j_idx) *= tmp;
 
 1226     template <
typename QDramBlockWindowTmp,
 
 1227               typename KDramBlockWindowTmp,
 
 1228               typename VDramBlockWindowTmp,
 
 1229               typename LSEDramBlockWindowTmp>
 
 1231                                    const KDramBlockWindowTmp& k_dram_block_window_tmp, 
 
 1232                                    const VDramBlockWindowTmp& v_dram_block_window_tmp, 
 
 1233                                    LSEDramBlockWindowTmp& lse_dram_block_window_tmp,   
 
 1236                                    void* smem_ptr)
 const 
 1240         return operator()(q_dram_block_window_tmp,
 
 1242                           k_dram_block_window_tmp,
 
 1244                           v_dram_block_window_tmp,
 
 1246                           lse_dram_block_window_tmp,
 
#define ASM_MARKER(marker)
Definition: block_fmha_fwd_v3_pipeline.hpp:12
 
#define CK_TILE_DEVICE
Definition: config.hpp:41
 
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
 
CK_TILE_DEVICE bf16x2_t cvt_pk_bf16_f32(float a, float b)
Definition: block_fmha_fwd_v3_pipeline.hpp:230
 
CK_TILE_DEVICE float fma_impl_vsv(float a, float b, float c)
Definition: block_fmha_fwd_v3_pipeline.hpp:190
 
CK_TILE_DEVICE float mul_impl_vv(float lhs, float rhs)
Definition: block_fmha_fwd_v3_pipeline.hpp:212
 
CK_TILE_DEVICE float add_impl_vv(float lhs, float rhs)
Definition: block_fmha_fwd_v3_pipeline.hpp:203
 
CK_TILE_DEVICE fp16x2_t cvt_pk_fp16_f32(float a, float b)
Definition: block_fmha_fwd_v3_pipeline.hpp:221
 
CK_TILE_DEVICE fp32x2_t pk_mul_f32(fp32x2_t lhs, fp32x2_t rhs)
Definition: block_fmha_fwd_v3_pipeline.hpp:239
 
Definition: cluster_descriptor.hpp:13
 
CK_TILE_DEVICE bfloat16_t log(bfloat16_t x)
Definition: bfloat16.hpp:428
 
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_descriptor.hpp:274
 
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
 
constexpr CK_TILE_DEVICE auto get_slice_tile(const tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile, sequence< SliceBegins... > slice_begins, sequence< SliceEnds... > slice_ends)
Definition: slice_tile.hpp:23
 
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition: tile_elementwise.hpp:40
 
CK_TILE_DEVICE void set_tile(DstrTensors &dstr_tensor, const T &value)
Definition: tile_elementwise.hpp:95
 
bfloat16_t bf16x2_t
Definition: pk_fp4.hpp:24
 
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_ &acc_tensor, const ReduceFunc &reduce_func, bool_constant< WithBroadcast >={}, bool_constant< CrossWarp >={})
Definition: block_reduce.hpp:21
 
tuple_array< T, N > statically_indexed_array
Definition: statically_indexed_array.hpp:16
 
float fp32x2_t
Definition: pk_fp4.hpp:22
 
int32_t index_t
Definition: integer.hpp:9
 
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
 
_Float16 fp16x2_t
Definition: half.hpp:385
 
CK_TILE_DEVICE auto load_tile_transpose(const tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window)
transpose loads tile from a tensor and returns the resulting tensor with a new (transposed) tile dist...
Definition: load_tile_transpose.hpp:403
 
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition: sweep_tile.hpp:20
 
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
 
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:95
 
CK_TILE_HOST_DEVICE void set_tile_if(static_distributed_tensor< DataType, StaticTileDistribution > &out_tensor, DataType value, XIndicesPredicate predicate)
Definition: static_distributed_tensor.hpp:175
 
constexpr CK_TILE_DEVICE auto make_tile_window_linear(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, LinearBottomDims_={})
Definition: tile_window_linear.hpp:993
 
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
 
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
 
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_ &&lds_tile, const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Definition: load_tile.hpp:133
 
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:23
 
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition: tile_elementwise.hpp:177
 
int32_t int32x2_t
Definition: vector_type.hpp:154
 
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
 
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:22
 
CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x)
Definition: bfloat16.hpp:425
 
constexpr bool is_same_v
Definition: type.hpp:283
 
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1517
 
Definition: block_fmha_fwd_v3_pipeline.hpp:251
 
ck_tile::remove_cvref_t< typename Problem::PDataType > PDataType
Definition: block_fmha_fwd_v3_pipeline.hpp:260
 
ck_tile::remove_cvref_t< typename Problem::SMPLComputeDataType > SMPLComputeDataType
Definition: block_fmha_fwd_v3_pipeline.hpp:258
 
ck_tile::remove_cvref_t< typename Problem::KDataType > KDataType
Definition: block_fmha_fwd_v3_pipeline.hpp:255
 
ck_tile::remove_cvref_t< typename Problem::ODataType > ODataType
Definition: block_fmha_fwd_v3_pipeline.hpp:262
 
ck_tile::remove_cvref_t< typename Problem::BlockFmhaShape > BlockFmhaShape
Definition: block_fmha_fwd_v3_pipeline.hpp:268
 
static constexpr CK_TILE_DEVICE void s_waitcnt_lgkmcnt()
Definition: block_fmha_fwd_v3_pipeline.hpp:371
 
ck_tile::remove_cvref_t< typename Problem::OaccDataType > OaccDataType
Definition: block_fmha_fwd_v3_pipeline.hpp:261
 
static constexpr CK_TILE_DEVICE auto MakeSimpleLdsDesc1D()
Definition: block_fmha_fwd_v3_pipeline.hpp:334
 
static constexpr CK_TILE_DEVICE void s_waitcnt()
Definition: block_fmha_fwd_v3_pipeline.hpp:355
 
ck_tile::remove_cvref_t< Problem_ > Problem
Definition: block_fmha_fwd_v3_pipeline.hpp:252
 
static constexpr CK_TILE_DEVICE void s_waitcnt_vmcnt()
Definition: block_fmha_fwd_v3_pipeline.hpp:365
 
static constexpr CK_TILE_DEVICE auto make_lds_tile_window(void *base, const Descriptor &desc)
Definition: block_fmha_fwd_v3_pipeline.hpp:344
 
ck_tile::remove_cvref_t< typename Problem::QDataType > QDataType
Definition: block_fmha_fwd_v3_pipeline.hpp:254
 
ck_tile::remove_cvref_t< typename Problem::LSEDataType > LSEDataType
Definition: block_fmha_fwd_v3_pipeline.hpp:259
 
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: block_fmha_fwd_v3_pipeline.hpp:310
 
ck_tile::remove_cvref_t< typename Problem::SaccDataType > SaccDataType
Definition: block_fmha_fwd_v3_pipeline.hpp:257
 
ck_tile::remove_cvref_t< Policy_ > Policy
Definition: block_fmha_fwd_v3_pipeline.hpp:253
 
ck_tile::remove_cvref_t< typename Problem::FmhaMask > FmhaMask
Definition: block_fmha_fwd_v3_pipeline.hpp:263
 
CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp &q_dram_block_window_tmp, const QElementFunction &q_element_func, const KDramBlockWindowTmp &k_dram_block_window_tmp, [[maybe_unused]] const KElementFunction &k_element_func, const VDramBlockWindowTmp &v_dram_block_window_tmp, [[maybe_unused]] const VElementFunction &v_element_func, LSEDramBlockWindowTmp &lse_dram_window_tmp, const LSEElementFunction &lse_element_func, [[maybe_unused]] const SAccElementFunction &s_acc_element_func, const PComputeElementFunction &p_compute_element_func, const OAccElementFunction &o_acc_element_func, FmhaMask mask, float scale_s, void *smem_ptr) const
Definition: block_fmha_fwd_v3_pipeline.hpp:387
 
ck_tile::remove_cvref_t< typename Problem::VDataType > VDataType
Definition: block_fmha_fwd_v3_pipeline.hpp:256
 
static constexpr CK_TILE_DEVICE auto MakeSimpleLdsDesc()
Definition: block_fmha_fwd_v3_pipeline.hpp:320
 
CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowTmp &k_dram_block_window_tmp, const VDramBlockWindowTmp &v_dram_block_window_tmp, LSEDramBlockWindowTmp &lse_dram_block_window_tmp, FmhaMask mask, float scale_s, void *smem_ptr) const
Definition: block_fmha_fwd_v3_pipeline.hpp:1230
 
static constexpr CK_TILE_DEVICE void schedule(ck_tile::number< WaveGroup >, ck_tile::number< Phase >)
Definition: block_fmha_fwd_v3_pipeline.hpp:119
 
static constexpr CK_TILE_DEVICE void schedule(ck_tile::number< WaveGroup >, ck_tile::number< Phase >)
Definition: block_fmha_fwd_v3_pipeline.hpp:45
 
Definition: block_fmha_fwd_v3_pipeline.hpp:39
 
Definition: integral_constant.hpp:13
 
Definition: functional.hpp:86
 
Definition: numeric.hpp:18
 
Definition: coordinate_transform.hpp:1392
 
Definition: sequence.hpp:49
 
Definition: functional.hpp:43
 
Definition: tensor_view.hpp:41
 
#define C_LOG2E
Definition: math.hpp:469