template <typename Problem_,
          typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
static_assert(kQLoadOnce == Policy::QLoadOnce);

static_assert(kSubQKHeaddim <= 256,
              "hdim bigger than 256 is not suitable for this pipeline!");
static constexpr auto BiasEnum  = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr index_t kAlignmentQ =
    kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK =
    kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV = []() {
    if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
        return Problem::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
    else
        return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
}();
static constexpr index_t kAlignmentO =
    kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentBias =
    kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
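// NOTE: each alignment falls back to 1 (element-wise access) whenever the corresponding
// dimension may be padded, since a wider vectorized access could cross the padded boundary.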
if constexpr(Problem::kBlockPerCu != -1)
    return Problem::kBlockPerCu;
static constexpr const char* name = "qr_async";

using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
{
    return Policy::template GetSmemSize<Problem>();
}
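// Main pipeline entry. Given per-block DRAM windows for Q, K, V, bias, dropout random
// values and LSE, it computes one [kM0, kN1] output tile of
// O = softmax(scale_s * Q * K^T + bias) * V with an online (streaming) softmax over the
// K/V sequence tiles, prefetching whole K tiles for the next iteration.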
template <typename QDramBlockWindowTmp,
          typename KDramBlockWindowTmp,
          typename VDramBlockWindowTmp,
          typename BiasDramBlockWindowTmp,
          typename RandValDramBlockWindowTmp,
          typename LSEDramBlockWindowTmp,
          typename QElementFunction,
          typename KElementFunction,
          typename VElementFunction,
          typename BiasElementFunction,
          typename LSEElementFunction,
          typename SAccElementFunction,
          typename PComputeElementFunction,
          typename OAccElementFunction,
          typename PositionEncoding,
          typename AttentionVariantParams,
          typename BlockIndices>
CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
                                    const QElementFunction& q_element_func,
                                    const KDramBlockWindowTmp& k_dram_block_window_tmp,
                                    const KElementFunction& k_element_func,
                                    const VDramBlockWindowTmp& v_dram_block_window_tmp,
                                    const VElementFunction& v_element_func,
                                    const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
                                    const BiasElementFunction& bias_element_func,
                                    RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
                                    LSEDramBlockWindowTmp& lse_dram_window_tmp,
                                    const LSEElementFunction& lse_element_func,
                                    const SAccElementFunction& s_acc_element_func,
                                    const PComputeElementFunction& p_compute_element_func,
                                    const OAccElementFunction& o_acc_element_func,
                                    FmhaMask mask,
                                    PositionEncoding position_encoding,
                                    float scale_s,
                                    const AttentionVariant&,
                                    const AttentionVariantParams&,
                                    const BlockIndices&,
                                    void* smem_ptr,
                                    DropoutType& dropout) const
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
              kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
              kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
              kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
              kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
              kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
              kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
static_assert(2 <= k0_loops);
static_assert(2 <= k1_loops);
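// NOTE: k0_loops / k1_loops are the numbers of kK0- / kK1-wide sub-tiles in the QK and PV
// GEMMs; at least two are required, presumably so that loading/storing the next sub-tile
// can overlap with the GEMM on the current one.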
constexpr bool kPreloadWholeNextIterationK =
    Policy::template IsPreloadWholeNextIterationK<Problem>();

constexpr auto NumKLdsBuffers = Policy::template GetNumKLdsBuffers<Problem>();
constexpr auto NumVLdsBuffers = Policy::template GetNumVLdsBuffers<Problem>();
constexpr auto NumPrefetchV   = Policy::template GetNumPrefetchV<Problem>();

static_assert(NumKLdsBuffers >= 2);
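// NOTE: K sub-tiles are written into the LDS buffers round-robin (index i_k0 % NumKLdsBuffers
// below), so at least two buffers are needed to fill one buffer while another is consumed.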
auto q_dram_window =
    make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
                     q_dram_block_window_tmp.get_window_lengths(),
                     q_dram_block_window_tmp.get_window_origin(),
                     Policy::template MakeQRegTileDistribution<Problem>());

const auto q_origin = q_dram_window.get_window_origin();
const auto [seqlen_k_start, seqlen_k_end] =
auto k_dram_block_window =
    make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
                     k_dram_block_window_tmp.get_window_lengths(),
                     {seqlen_k_start, 0});

auto k_dram_window =
    make_tile_window(k_dram_block_window.get_bottom_tensor_view(),
                     k_dram_block_window.get_window_lengths(),
                     k_dram_block_window.get_window_origin(),
                     Policy::template MakeKDramTileDistribution<Problem>());

using k_tile_type = decltype(load_tile(k_dram_window));
auto k_tiles = [&]() {
    if constexpr(kPreloadWholeNextIterationK)

__builtin_amdgcn_sched_barrier(0);
auto k_lds = make_tensor_view<address_space_enum::lds>(
    k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());

    k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});

using k_lds_window_type =

    v_dram_block_window_tmp.get_window_lengths(),
    Policy::template MakeVDramTileDistribution<Problem>());

auto v_lds = make_tensor_view<address_space_enum::lds>(
    reinterpret_cast<VDataType*>(static_cast<char*>(smem_ptr) +
                                 Policy::template GetExclusiveKLdsBytes<Problem>()),
    Policy::template MakeVLdsBlockDescriptor<Problem>());

    v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});

using v_tile_type = decltype(load_tile(v_dram_window));

using v_lds_window_type =
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();

using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
auto s_acc = SaccBlockTileType{};
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_acc));

using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(

using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());

auto o_acc = OaccBlockTileType{};
auto m     = MLBlockTileType{};
auto l     = MLBlockTileType{};

if(num_total_loop <= 0)

    make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
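// NOTE: m is the running per-row maximum of the attention scores and l the running per-row
// sum of exp(score - m), i.e. the online-softmax state:
//   m_new = max(m_old, rowmax(S));  l_new = exp(m_old - m_new) * l_old + rowsum(exp(S - m_new))
// o_acc is rescaled by exp(m_old - m_new) each iteration before P * V is accumulated.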
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window =
    make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
                     bias_dram_block_window_tmp.get_window_lengths(),
                     {bias_origin.at(number<0>{}), seqlen_k_start},
                     Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());

auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
    randval_dram_block_window_tmp, seqlen_k_start);
if constexpr(kPreloadWholeNextIterationK)

if(i_total_loops == 0)

if(num_total_loop > 1)

static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {

    k_lds_windows[number<i_k0 % NumKLdsBuffers>{}],

    k_tiles[number<i_k0 + 1>{}] = load_tile(k_dram_window);
    if constexpr(i_k0 < k0_loops - 2)

    if constexpr(i_k0 == 0)

        sequence<0, i_k0 * kK0>{},
        sequence<kM0, (i_k0 + 1) * kK0>{}),
        k_lds_windows[number<i_k0 % NumKLdsBuffers>{}]);

k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}],
static_for<0, k0_loops, 1>{}([&](auto i_k0) {
    k_tiles[number<i_k0>{}] = load_tile(k_dram_window);

    if constexpr(i_k0 < k0_loops - 1)

    sequence<0, (k0_loops - 1) * kK0>{},
    sequence<kM0, k0_loops * kK0>{}),
    k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}]);
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {

    k_lds_windows[number<i_k0 % NumKLdsBuffers>{}],

    k_tiles[number<i_k0 + 1>{}] = load_tile(k_dram_window);
    if constexpr(i_k0 < k0_loops - 2)

    if constexpr(i_k0 == 0)

        sequence<0, i_k0 * kK0>{},
        sequence<kM0, (i_k0 + 1) * kK0>{}),
        k_lds_windows[number<i_k0 % NumKLdsBuffers>{}]);

k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}],

    sequence<0, (k0_loops - 1) * kK0>{},
    sequence<kM0, k0_loops * kK0>{}),
    k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}]);
if(i_total_loops < num_total_loop - 1)

if constexpr(1 < k0_loops - 1)

    get_slice_tile(q_tile, sequence<0, kK0>{}, sequence<kM0, 2 * kK0>{}),

static_for<2, k0_loops, 1>{}([&](auto i_k0) {
    store_tile(k_lds_windows[number<i_k0 % NumKLdsBuffers>{}],
               k_tiles[number<i_k0>{}]);

    k_tiles[number<i_k0>{}] = load_tile(k_dram_window);
    if constexpr(i_k0 < k0_loops - 1)

        sequence<0, i_k0 * kK0>{},
        sequence<kM0, (i_k0 + 1) * kK0>{}),
        k_lds_windows[number<i_k0 % NumKLdsBuffers>{}]);
static_for<1, k0_loops, 1>{}([&](auto i_k0) {

    k_lds_windows[number<i_k0 % NumKLdsBuffers>{}],

    sequence<0, i_k0 * kK0>{},
    sequence<kM0, (i_k0 + 1) * kK0>{}),
    k_lds_windows[number<i_k0 % NumKLdsBuffers>{}]);

static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
    store_tile(k_lds_windows[number<i_k0 % NumKLdsBuffers>{}],

    if constexpr(i_k0 == 0)

    if constexpr(i_k0 < k0_loops - 1)

    if constexpr(i_k0 < k0_loops - 2)

        sequence<0, i_k0 * kK0>{},
        sequence<kM0, (i_k0 + 1) * kK0>{}),
        k_lds_windows[number<i_k0 % NumKLdsBuffers>{}]);

    sequence<0, (k0_loops - 1) * kK0>{},
    sequence<kM0, k0_loops * kK0>{}),
    k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}]);
__builtin_amdgcn_sched_barrier(0);

const auto bias_tile = load_tile(bias_dram_window);

static_for<1, NumPrefetchV, 1>{}([&](auto i_buf) {
    v_tiles[i_buf] = load_tile(v_dram_window);
[&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
    x += type_convert<SaccDataType>(bias_element_func(y));
#else
    x += log2e_v<SaccDataType> * type_convert<SaccDataType>(bias_element_func(y));
#endif
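// NOTE: in the fast-exp2 path the bias is pre-multiplied by log2(e) here so that the
// softmax below can use exp2() instead of exp().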
const auto k_origin = k_dram_block_window.get_window_origin();
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();

const auto tile_idx = get_x_indices_from_distributed_indices(
    s_acc.get_tile_distribution(), make_tuple(idx0, idx1));

const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);

s_acc(i_j_idx) *= scale_s;
position_encoding.update(s_acc(i_j_idx), row, col);
#if !CK_TILE_FMHA_FWD_FAST_EXP2

const auto k_origin = k_dram_block_window.get_window_origin();
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
                                           k_origin.at(number<0>{}),

if(need_perpixel_check)

    const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
    const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
    return mask.IsOutOfBound(row, col);
const auto s = cast_tile<SMPLComputeDataType>(s_acc);
auto m_local = block_tile_reduce<SMPLComputeDataType>(

const auto m_old = m;

tile_elementwise_inout(
    [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local);

auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
    s.get_tile_distribution());

    ? type_convert<SMPLComputeDataType>(0.f)
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();

#if CK_TILE_FMHA_FWD_FAST_EXP2
auto row_max = scale_s * get_validated_m(m[i_idx]);

constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2

p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));

p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);

p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
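// NOTE: with CK_TILE_FMHA_FWD_FAST_EXP2 the exponential is evaluated through exp2():
// either p = exp2(s - m) when the scale/log2(e) factors were already folded into s
// (e.g. via the bias path above), or p = exp2(scale_s * s - row_max) with
// row_max = scale_s * m. Without the fast path, plain exp(s - m) is used.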
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(

constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();

#if CK_TILE_FMHA_FWD_FAST_EXP2
const auto tmp = [&]() {

    return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));

    auto row_max = scale_s * get_validated_m(m[i_idx]);
    return exp2(scale_s * m_old[i_idx] - row_max);

const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));

l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];

constexpr auto i_j_idx = make_tuple(idx0, idx1);

o_acc(i_j_idx) *= tmp;
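// NOTE: tmp = exp(m_old - m_new) is the online-softmax correction factor; it rescales both
// the running row sum l and the partial output o_acc so that contributions from earlier
// K/V tiles stay consistent with the updated row maximum.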
    reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>();

dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
    smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window);

__builtin_amdgcn_sched_barrier(0x7f);
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)

    auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
        Policy::template MakeShuffledVRegBlockDescriptor<Problem>());

__builtin_amdgcn_sched_barrier(0);

if constexpr(!kPreloadWholeNextIterationK)

    if(i_total_loops < num_total_loop - 1)

__builtin_amdgcn_sched_barrier(0);
if constexpr(k1_loops > 1)

if constexpr(NumPrefetchV == 1)

static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {

        p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
    v_lds_windows[number<i_k1 % NumVLdsBuffers>{}]);

    auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
        Policy::template MakeShuffledVRegBlockDescriptor<Problem>());

static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
    if constexpr(i_k1 < k1_loops - NumPrefetchV)
        v_tiles[number<i_k1 % NumPrefetchV>{}] = load_tile(v_dram_window);

        p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
    v_lds_windows[number<i_k1 % NumVLdsBuffers>{}]);

    auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
        Policy::template MakeShuffledVRegBlockDescriptor<Problem>());

        v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]);

    v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],

        v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]));

    if constexpr(i_k1 < k1_loops - NumPrefetchV)

v_lds_windows[number<(k1_loops - 1) % NumVLdsBuffers>{}]);

if constexpr(Policy::template IsFirstKLdsBufferOverlapLastVLdsBuffer<Problem>())

    __builtin_amdgcn_sched_barrier(0);
    __builtin_amdgcn_s_barrier();

} while(++i_total_loops < num_total_loop);
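// End of the main loop over K/V sequence tiles; the epilogue below computes the LSE and
// normalizes the accumulated output.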
auto lse = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());

constexpr auto lse_spans = decltype(lse)::get_distributed_spans();
sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {

#if CK_TILE_FMHA_FWD_FAST_EXP2

    lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);

    lse(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]);

    lse(i_idx) = m_[i_idx] + log(l_[i_idx]);
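// NOTE: the stored LSE is log(sum(exp(scores))) = m + log(l) in natural-log units; in the
// fast-exp2 path m is kept in log2 units, hence the division by C_LOG2E (and the extra
// scale_s factor when the scale has not been folded into m).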
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();

const auto tmp = [&]() {
    if constexpr(FmhaMask::IsMasking)
        return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];

constexpr auto i_j_idx = make_tuple(idx0, idx1);

o_acc(i_j_idx) *= tmp;
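// NOTE: the epilogue scales each output row by 1 / l (the softmax denominator); with masking
// enabled, a fully masked row has l == 0 and is forced to 0 instead of producing inf/NaN.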
template <typename QDramBlockWindowTmp,
          typename KDramBlockWindowTmp,
          typename VDramBlockWindowTmp,
          typename BiasDramBlockWindowTmp,
          typename RandValDramBlockWindowTmp,
          typename LSEDramBlockWindowTmp,
          typename PositionEncoding,
          typename AttentionVariantParams,
          typename BlockIndices>
CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
                                    const KDramBlockWindowTmp& k_dram_block_window_tmp,
                                    const VDramBlockWindowTmp& v_dram_block_window_tmp,
                                    const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
                                    RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
                                    LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
                                    FmhaMask mask,
                                    PositionEncoding position_encoding,
                                    float scale_s,
                                    const AttentionVariant& variant,
                                    const AttentionVariantParams& variant_params,
                                    const BlockIndices& block_indices,
                                    void* smem_ptr,
                                    DropoutType& dropout) const
{
    return operator()(q_dram_block_window_tmp,

                      k_dram_block_window_tmp,

                      v_dram_block_window_tmp,

                      bias_dram_block_window_tmp,

                      randval_dram_block_window_tmp,
                      lse_dram_block_window_tmp,