// MakeRandvalDramWindow of the first dropout struct in this file (the no-op
// variant): takes a temporary random-value DRAM block window and a start
// offset along the QK sequence dimension. IsFwd defaults to true here.
40 template <
typename BlockGemm,
bool IsFwd = true,
typename RandValDramBlockWindowTmp>
// Both arguments are cast to void, i.e. they are deliberately unused.
// NOTE(review): the return statement is not part of this excerpt; presumably a
// null tile window is returned (make_null_tile_window is referenced) — confirm.
45 (void)randval_dram_block_window_tmp;
46 (void)seqlen_qk_start;
// Fragment of the BlockDropout constructor parameter list: the 64-bit `seed`
// and the trailing runtime flag `is_store_randval_`.
// NOTE(review): the remaining parameters and the member-initializer list are
// not visible here; presumably the flag initializes `is_store_randval` — confirm.
57 unsigned long long seed,
61 bool is_store_randval_)
// BlockDropout::MakeRandvalDramWindow: derives per-step tile sizes from the
// block GEMM configuration, builds a tile window over the bottom tensor view of
// the caller's temporary window, and returns it as `randval_dram_window`.
71 template <
typename BlockGemm,
bool IsFwd = true,
typename RandValDramBlockWindowTmp>
// Policy-provided tuple; elements 1 and 2 are read below as MWarp and NWarp.
// NOTE(review): WG and BlockGemmShape are declared on lines that are not part
// of this excerpt — presumably WG is element 0 of `config`; confirm.
76 constexpr
auto config =
77 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
// True when the warp GEMM's M extent is 32.
79 constexpr
bool IsWG32 = WG::kM == 32;
80 constexpr
index_t MWarp = config.template at<1>();
81 constexpr
index_t NWarp = config.template at<2>();
83 constexpr
index_t kMPerBlock = BlockGemmShape::kM;
// Two M iterations per warp only for the non-32 warp GEMM when the block's M
// extent exceeds what MWarp warps cover in one pass; otherwise one.
84 constexpr
index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
// Extent covered by one generation step in M and in N.
85 constexpr
index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
86 constexpr
index_t kNPerStep = NWarp * WG::kN;
88 const auto block_origin = randval_dram_block_window_tmp.get_window_origin();
// Two window origins are visible below: {block row, seqlen_qk_start} and
// {seqlen_qk_start, block column}.
// NOTE(review): the selecting condition is not visible; presumably it is
// `if constexpr(IsFwd)` — confirm.
89 auto randval_dram_window = [&]() {
93 randval_dram_block_window_tmp.get_bottom_tensor_view(),
95 {block_origin.at(number<0>{}), seqlen_qk_start});
100 randval_dram_block_window_tmp.get_bottom_tensor_view(),
102 {seqlen_qk_start, block_origin.at(number<1>{})});
106 return randval_dram_window;
// MakeRandValLdsBlockDescriptor: compile-time construction of the descriptor
// returned as `randval_lds_block_desc`, sized from the per-step tile extents.
109 template <
typename BlockGemm>
// Same configuration preamble as the other helpers in this struct.
// NOTE(review): WG and BlockGemmShape come from lines not shown here.
112 constexpr
auto config =
113 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
115 constexpr
bool IsWG32 = WG::kM == 32;
116 constexpr
index_t MWarp = config.template at<1>();
117 constexpr
index_t NWarp = config.template at<2>();
119 constexpr
index_t kMPerBlock = BlockGemmShape::kM;
120 constexpr
index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
121 constexpr
index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
122 constexpr
index_t kNPerStep = NWarp * WG::kN;
// N is split as kN0 x kN1; kN1 is defined on a line not shown in this excerpt.
124 constexpr
index_t kN0 = kNPerStep / kN1;
// NOTE(review): `randval_lds_block_desc_0` is presumably a naive descriptor
// that is then transformed (merge / pass-through transforms are referenced by
// this file) into the final descriptor — confirm against the full source.
133 randval_lds_block_desc_0,
140 return randval_lds_block_desc;
// MakeRandValTileDistribution (BlockDropout): builds the tile distribution
// encoding for generated random values by embedding an inner warp-level
// encoding into an outer block-level one.
143 template <
typename BlockGemm>
// NOTE(review): WG and BlockGemmShape come from lines not shown here.
146 constexpr
auto config =
147 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
149 constexpr
bool IsWG32 = WG::kM == 32;
150 constexpr
index_t MWarp = config.template at<1>();
151 constexpr
index_t NWarp = config.template at<2>();
153 constexpr
index_t kMPerBlock = BlockGemmShape::kM;
154 constexpr
index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
// A single N iteration per warp is used for random-value generation.
155 constexpr
index_t NIterPerWarp = 1;
// Inner encoding is the CWarpDstrEncoding of a warp GEMM type parameterized by
// WG's B/C data types and IsWG32.
// NOTE(review): the template being instantiated and its other arguments are
// on lines not shown; presumably WarpGemmDispatcher — confirm.
168 constexpr
auto randval_block_inner_part_dstr_encoding =
170 typename WG::BDataType,
171 typename WG::CDataType,
176 IsWG32>::CWarpDstrEncoding{};
// NOTE(review): presumably make_embed_tile_distribution_encoding(outer, inner);
// the call head and outer encoding are not visible here.
178 constexpr
auto randval_block_part_dstr_encode =
180 randval_block_inner_part_dstr_encoding);
// MakeRandValLdsShuffleTileDistribution: builds the tile distribution used by
// the LDS read window in Run. Unlike MakeRandValTileDistribution, the inner
// encoding here is WG::CWarpDstrEncoding taken directly from the warp GEMM.
185 template <
typename BlockGemm>
// NOTE(review): WG and BlockGemmShape come from lines not shown here.
188 constexpr
auto config =
189 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
191 constexpr
bool IsWG32 = WG::kM == 32;
192 constexpr
index_t MWarp = config.template at<1>();
193 constexpr
index_t NWarp = config.template at<2>();
195 constexpr
index_t kMPerBlock = BlockGemmShape::kM;
196 constexpr
index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
197 constexpr
index_t NIterPerWarp = 1;
// NOTE(review): the call head and outer encoding are on lines not shown;
// presumably make_embed_tile_distribution_encoding — confirm.
207 constexpr
auto randval_block_part_dstr_encode =
209 typename WG::CWarpDstrEncoding{});
// BlockDropout::Run: generates uint8 random values tile by tile, passes them
// through an LDS buffer (randval_ptr) to re-read them with a different tile
// distribution, and then either stores them to randval_dram_window or uses
// them to mask/scale p_compute in place.
// NOTE(review): several control-flow lines (if/else conditions, braces, window
// moves) are not part of this excerpt; comments below describe only what the
// visible statements do.
214 template <
typename BlockGemm,
215 typename PComputeDataType,
216 typename RandValOutputDataType,
217 typename PComputeWindow,
218 typename RandValDramWindow>
221 PComputeWindow& p_compute,
222 RandValDramWindow& randval_dram_window)
const
// Configuration preamble shared with the Make* helpers.
// NOTE(review): WG and BlockGemmShape come from lines not shown here.
224 constexpr
auto config =
225 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
227 constexpr
bool IsWG32 = WG::kM == 32;
228 constexpr
index_t MWarp = config.template at<1>();
229 constexpr
index_t NWarp = config.template at<2>();
231 constexpr
index_t kMPerBlock = BlockGemmShape::kM;
232 constexpr
index_t kNPerBlock = BlockGemmShape::kN;
233 constexpr
index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
234 constexpr
index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
235 constexpr
index_t kNPerStep = NWarp * WG::kN;
// View the caller-provided buffer as uint8_t in the LDS address space, laid
// out by MakeRandValLdsBlockDescriptor.
238 auto randval_lds = make_tensor_view<address_space_enum::lds>(
239 reinterpret_cast<uint8_t*
>(randval_ptr), MakeRandValLdsBlockDescriptor<BlockGemm>());
242 randval_lds, MakeRandValLdsBlockDescriptor<BlockGemm>().get_lengths(), {0, 0});
// Per-thread register tile that receives the freshly generated bytes.
245 auto randval_dist_generated =
246 make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
// Same LDS window (lengths and origin), but read with the shuffle distribution.
248 const auto randval_lds_read_window =
250 randval_lds_window.get_window_lengths(),
251 randval_lds_window.get_window_origin(),
252 MakeRandValLdsShuffleTileDistribution<BlockGemm>());
// Starting M index is taken from dimension 0 of the DRAM window origin.
254 const index_t start_m0_idx = randval_dram_window.get_window_origin().at(
number<0>{});
// Position of this warp in the MWarp x NWarp grid (warp id laid out N-fastest).
255 const index_t iMWarp = get_warp_id() / NWarp;
256 const index_t iNWarp = get_warp_id() % NWarp;
// Produces the random-value tile for step (i_m0, i_n0).
258 auto generate_randval = [&](
auto i_m0,
auto i_n0) {
// Per-thread scratch buffer sized by the distributed tensor's element count.
260 uint8_t random_uint8_t[randval_dist_generated.kThreadElementSpaceSize];
// Warp-GEMM tile coordinates of this warp for the current step.
261 const index_t wg_m0 = (start_m0_idx / WG::kM) + (i_m0 * MWarp + iMWarp) * MIterPerWarp;
262 const index_t wg_n0 = (start_n0_idx / WG::kN) + (i_n0 * NWarp + iNWarp);
// Variant 1: the (wg_m0, wg_n0) pair is packed into the 64-bit subsequence and
// the lane id is the offset; 16 bytes per thread are expected.
// NOTE(review): the condition selecting this variant is not visible.
267 const unsigned long long ph_subsequence =
268 bit_cast<unsigned long long>(make_uint2(wg_m0, wg_n0));
269 const index_t ph_offset = get_lane_id();
271 static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
// Variant 2: the subsequence addresses a 2x2 group of warp-GEMM tiles
// (coordinates halved); subtile_m0 is this tile's row within the group.
278 const unsigned long long ph_subsequence =
279 bit_cast<unsigned long long>(make_uint2(wg_m0 / 2, wg_n0 / 2));
280 const index_t subtile_m0 = wg_m0 % 2;
// Offset keeps the low 4 lane bits and moves lane bit 4 up to bit 5; the rest
// of this expression is on a line not shown in this excerpt.
283 const index_t ph_offset = (get_lane_id() & 15) +
284 (((get_lane_id() >> 4) & 1) << 5) +
287 if constexpr(MIterPerWarp == 1)
// One M iteration: 8 bytes per thread, selected by indices
// subtile_m0 * 2 + {0, 1}.
289 static_assert(randval_dist_generated.kThreadElementSpaceSize == 8);
291 random_uint8_t, ph_subsequence, subtile_m0 * 2 + 0, subtile_m0 * 2 + 1);
// Two M iterations: 16 bytes per thread.
295 static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
// Variant 3: lane bit 4 selects the N subtile; the offset clears lane bit 4
// (mask 47 = 0b101111) and replaces it with the parity of wg_n0.
// NOTE(review): the condition selecting this variant is not visible.
301 const index_t subtile_n0 = (get_lane_id() >> 4) & 1;
302 const index_t ph_offset = (get_lane_id() & 47) + ((wg_n0 % 2) << 4);
304 if constexpr(MIterPerWarp == 1)
// One M iteration: 4 bytes per thread at index subtile_m0 * 2 + subtile_n0.
306 static_assert(randval_dist_generated.kThreadElementSpaceSize == 4);
308 random_uint8_t, ph_subsequence, subtile_m0 * 2 + subtile_n0);
// Two M iterations: 8 bytes per thread covering both M subtiles (rows 0 and 1)
// at this thread's N subtile.
312 static_assert(randval_dist_generated.kThreadElementSpaceSize == 8);
314 random_uint8_t, ph_subsequence, 0 * 2 + subtile_n0, 1 * 2 + subtile_n0);
// Copy the scratch bytes into the distributed tensor in span-sweep order.
319 constexpr
auto randval_dist_generated_spans =
320 decltype(randval_dist_generated)::get_distributed_spans();
321 int i_random_idx = 0;
325 randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
// Round-trip through LDS: write with the generation distribution, read back
// with the shuffle distribution.
// NOTE(review): presumably block_sync_lds() calls surround these on lines not
// shown — confirm.
329 store_tile(randval_lds_window, randval_dist_generated);
331 const auto randval =
load_tile(randval_lds_read_window);
// Store pass: M steps outer, N steps inner; each tile is cast to
// RandValOutputDataType and written to the DRAM window.
// NOTE(review): the guard enabling this pass is not visible; presumably
// `is_store_randval` — confirm.
338 static_for<0, kMPerBlock / kMPerStep, 1>{}([&](
auto i_m0) {
339 static_for<0, kNPerBlock / kNPerStep, 1>{}([&](
auto i_n0) {
340 const auto randval = generate_randval(i_m0, i_n0);
342 const auto randval_store = cast_tile<RandValOutputDataType>(randval);
343 store_tile(randval_dram_window, randval_store);
// Apply pass: same loop order; random values are regenerated per tile and
// used to update p_compute.
350 static_for<0, kMPerBlock / kMPerStep, 1>{}([&](
auto i_m0) {
351 static_for<0, kNPerBlock / kNPerStep, 1>{}([&](
auto i_n0) {
352 const auto randval = generate_randval(i_m0, i_n0);
354 constexpr
auto randval_spans = decltype(randval)::get_distributed_spans();
// Map indices of the random-value tile onto indices of the p_compute tile.
357 constexpr
auto p_idx0 =
359 idx0.
impl_.template at<0>()>{};
360 constexpr
auto p_idx1 =
362 idx1.
impl_.template at<1>(),
363 idx1.impl_.template at<2>()>{};
// Keep-and-rescale when the random byte is <= the threshold, otherwise zero.
366 p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
367 ? p_compute[p_idx] * rp_undrop
368 : PComputeDataType(0);
// Template header with three compile-time switches: dropout enabled, 32-wide
// warp GEMM, and store-random-values.
// NOTE(review): presumably the primary template declaration of BlockDropoutBwd,
// whose declaration line is not part of this excerpt — confirm.
384 template <
bool IsDropout_,
bool IsWG32_,
bool IsStoreRandval_>
// Variant with dropout disabled: IsDropout is fixed to false and
// IsStoreRandval mirrors the template argument.
// NOTE(review): presumably a partial specialization of BlockDropoutBwd for
// IsDropout_ == false; the struct line is not visible — confirm.
387 template <
bool IsWG32_,
bool IsStoreRandval_>
390 static constexpr
bool IsDropout =
false;
391 static constexpr
bool IsStoreRandval = IsStoreRandval_;
// MakeRandvalDramWindow for this variant; note IsFwd defaults to false here.
393 template <
typename BlockGemm,
bool IsFwd = false,
typename RandValDramBlockWindowTmp>
// Both arguments are cast to void, i.e. deliberately unused.
// NOTE(review): the return statement is not visible in this excerpt.
398 (void)randval_dram_block_window_tmp;
399 (void)seqlen_qk_start;
// Variant with dropout enabled: IsDropout is fixed to true and IsStoreRandval
// mirrors the template argument.
// NOTE(review): presumably a partial specialization of BlockDropoutBwd for
// IsDropout_ == true; the struct line is not visible — confirm.
405 template <
bool IsWG32_,
bool IsStoreRandval_>
408 static constexpr
bool IsDropout =
true;
409 static constexpr
bool IsStoreRandval = IsStoreRandval_;
// Constructor fragment: 64-bit seed and offset parameters, and the two member
// initializers that copy the rescale factor and the uint8 keep threshold.
// NOTE(review): other parameters and initializers are on lines not shown.
414 unsigned long long seed,
415 unsigned long long offset,
421 rp_undrop(rp_undrop_),
422 p_undrop_in_uint8_t(p_undrop_in_uint8_t_)
// BlockDropoutBwd::MakeRandvalDramWindow: derives per-step tile sizes, builds a
// tile window over the bottom tensor view of the caller's temporary window,
// and returns it as `randval_dram_window`. IsFwd defaults to false here.
426 template <
typename BlockGemm,
bool IsFwd = false,
typename RandValDramBlockWindowTmp>
// Policy-provided tuple; elements 1 and 2 are read below as MWarp and NWarp.
// NOTE(review): WG and BlockGemmShape come from lines not shown here.
431 constexpr
auto config =
432 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
434 constexpr
bool IsWG32 = WG::kM == 32;
435 constexpr
index_t MWarp = config.template at<1>();
436 constexpr
index_t NWarp = config.template at<2>();
438 constexpr
index_t kMPerBlock = BlockGemmShape::kM;
// Two M iterations per warp only for the non-32 warp GEMM when the block's M
// extent exceeds what MWarp warps cover in one pass; otherwise one.
439 constexpr
index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
440 constexpr
index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
441 constexpr
index_t kNPerStep = NWarp * WG::kN;
443 const auto block_origin = randval_dram_block_window_tmp.get_window_origin();
// Two window origins are visible below: {block row, seqlen_qk_start} and
// {seqlen_qk_start, block column}.
// NOTE(review): the selecting condition is not visible; presumably it is
// `if constexpr(IsFwd)` — confirm.
444 auto randval_dram_window = [&]() {
448 randval_dram_block_window_tmp.get_bottom_tensor_view(),
450 {block_origin.at(number<0>{}), seqlen_qk_start});
455 randval_dram_block_window_tmp.get_bottom_tensor_view(),
457 {seqlen_qk_start, block_origin.at(number<1>{})});
461 return randval_dram_window;
// BlockDropoutBwd::MakeRandValTileDistribution: builds the tile distribution
// encoding for generated random values by embedding an inner warp-level
// encoding into an outer block-level one.
464 template <
typename BlockGemm>
// NOTE(review): WG and BlockGemmShape come from lines not shown here.
467 constexpr
auto config =
468 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
470 constexpr
bool IsWG32 = WG::kM == 32;
471 constexpr
index_t MWarp = config.template at<1>();
472 constexpr
index_t NWarp = config.template at<2>();
474 constexpr
index_t kMPerBlock = BlockGemmShape::kM;
475 constexpr
index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
// A single N iteration per warp is used for random-value generation.
476 constexpr
index_t NIterPerWarp = 1;
// Inner encoding is the CWarpDstrEncoding of a warp GEMM type parameterized by
// WG's B/C data types and IsWG32.
// NOTE(review): the template being instantiated is on a line not shown;
// presumably WarpGemmDispatcher — confirm.
486 constexpr
auto randval_block_inner_part_dstr_encoding =
488 typename WG::BDataType,
489 typename WG::CDataType,
494 IsWG32>::CWarpDstrEncoding{};
// NOTE(review): tail of a statement comparing against WG::CWarpDstrEncoding;
// presumably a static_assert(is_same_v<...>) that the inner encoding matches
// the warp GEMM's own C encoding — confirm.
497 typename WG::CWarpDstrEncoding>);
499 constexpr
auto randval_block_part_dstr_encode =
501 randval_block_inner_part_dstr_encoding);
// BlockDropoutBwd::Run: generates uint8 random values tile by tile directly in
// registers (no LDS buffer appears in the visible statements, unlike
// BlockDropout::Run), applies them to p_compute, and, when IsStoreRandval is
// set, also writes them to randval_dram_window.
// NOTE(review): several control-flow lines (if/else conditions, braces, window
// moves) are not part of this excerpt; comments below describe only what the
// visible statements do.
506 template <
typename BlockGemm,
507 typename RandValOutputDataType,
508 typename PComputeWindow,
509 typename RandValDramWindow>
512 PComputeWindow& p_compute,
513 RandValDramWindow& randval_dram_window)
const
// Configuration preamble shared with the Make* helpers.
// NOTE(review): WG and BlockGemmShape come from lines not shown here.
515 constexpr
auto config =
516 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
518 constexpr
bool IsWG32 = WG::kM == 32;
519 constexpr
index_t MWarp = config.template at<1>();
520 constexpr
index_t NWarp = config.template at<2>();
522 constexpr
index_t kMPerBlock = BlockGemmShape::kM;
523 constexpr
index_t kNPerBlock = BlockGemmShape::kN;
524 constexpr
index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
525 constexpr
index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
526 constexpr
index_t kNPerStep = NWarp * WG::kN;
// Per-thread register tile that receives the generated bytes; it is also what
// generate_randval returns.
529 auto randval_dist_generated =
530 make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
// Position of this warp in the MWarp x NWarp grid (warp id laid out N-fastest).
532 const index_t iMWarp = get_warp_id() / NWarp;
533 const index_t iNWarp = get_warp_id() % NWarp;
// Produces the random-value tile for step (i_m0, i_n0).
535 auto generate_randval = [&](
auto i_m0,
auto i_n0) {
// Per-thread scratch buffer sized by the distributed tensor's element count.
537 uint8_t random_uint8_t[randval_dist_generated.kThreadElementSpaceSize];
// Warp-GEMM tile coordinates of this warp for the current step.
538 const index_t wg_m0 = (start_m0_idx / WG::kM) + (i_m0 * MWarp + iMWarp) * MIterPerWarp;
539 const index_t wg_n0 = (start_n0_idx / WG::kN) + (i_n0 * NWarp + iNWarp);
// Variant 1: the (wg_m0, wg_n0) pair is packed into the 64-bit subsequence and
// the lane id is the offset; 16 bytes per thread are expected.
// NOTE(review): the condition selecting this variant is not visible.
544 const unsigned long long ph_subsequence =
545 bit_cast<unsigned long long>(make_uint2(wg_m0, wg_n0));
546 const index_t ph_offset = get_lane_id();
548 static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
// Variant 2: the subsequence addresses a 2x2 group of warp-GEMM tiles
// (coordinates halved); subtile_m0 is this tile's row within the group.
555 const unsigned long long ph_subsequence =
556 bit_cast<unsigned long long>(make_uint2(wg_m0 / 2, wg_n0 / 2));
557 const index_t subtile_m0 = wg_m0 % 2;
// Offset keeps the low 4 lane bits and moves lane bit 4 up to bit 5; the rest
// of this expression is on a line not shown in this excerpt.
560 const index_t ph_offset = (get_lane_id() & 15) +
561 (((get_lane_id() >> 4) & 1) << 5) +
564 if constexpr(MIterPerWarp == 1)
// One M iteration: 8 bytes per thread, selected by indices
// subtile_m0 * 2 + {0, 1}.
566 static_assert(randval_dist_generated.kThreadElementSpaceSize == 8);
568 random_uint8_t, ph_subsequence, subtile_m0 * 2 + 0, subtile_m0 * 2 + 1);
// Two M iterations: 16 bytes per thread.
572 static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
// Variant 3: lane bit 4 selects the N subtile; the offset clears lane bit 4
// (mask 47 = 0b101111) and replaces it with the parity of wg_n0.
// NOTE(review): the condition selecting this variant is not visible.
578 const index_t subtile_n0 = (get_lane_id() >> 4) & 1;
579 const index_t ph_offset = (get_lane_id() & 47) + ((wg_n0 % 2) << 4);
581 if constexpr(MIterPerWarp == 1)
// One M iteration: 4 bytes per thread at index subtile_m0 * 2 + subtile_n0.
583 static_assert(randval_dist_generated.kThreadElementSpaceSize == 4);
585 random_uint8_t, ph_subsequence, subtile_m0 * 2 + subtile_n0);
// Two M iterations: 8 bytes per thread covering both M subtiles (rows 0 and 1)
// at this thread's N subtile.
589 static_assert(randval_dist_generated.kThreadElementSpaceSize == 8);
591 random_uint8_t, ph_subsequence, 0 * 2 + subtile_n0, 1 * 2 + subtile_n0);
// Copy the scratch bytes into the distributed tensor in span-sweep order.
596 constexpr
auto randval_dist_generated_spans =
597 decltype(randval_dist_generated)::get_distributed_spans();
598 int i_random_idx = 0;
602 randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
605 return randval_dist_generated;
// Main loop: N steps outer, M steps inner (the opposite nesting from
// BlockDropout::Run).
608 static_for<0, kNPerBlock / kNPerStep, 1>{}([&](
auto i_n0) {
609 static_for<0, kMPerBlock / kMPerStep, 1>{}([&](
auto i_m0) {
610 const auto randval = generate_randval(i_m0, i_n0);
613 constexpr
auto randval_spans = decltype(randval)::get_distributed_spans();
// Map indices of the random-value tile onto indices of the p_compute tile.
617 constexpr
auto p_idx0 =
619 idx0.
impl_.template at<0>(),
620 idx0.impl_.template at<1>(),
621 idx0.impl_.template at<2>()>{};
// Element is selected by comparing the random byte against the uint8
// threshold; both result branches are on lines not shown in this excerpt.
624 p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
// Compile-time gated store of the generated tile to the DRAM window.
630 if constexpr(IsStoreRandval)
632 const auto randval_store = cast_tile<RandValOutputDataType>(randval);
633 store_tile(randval_dram_window, randval_store);
// NOTE(review): bodies of the next two guards are not visible; presumably
// they move randval_dram_window between steps — confirm.
637 if constexpr(IsStoreRandval)
642 if constexpr(IsStoreRandval)
CK_TILE_DEVICE void block_sync_lds()
Definition: arch.hpp:192
Definition: philox_rand.hpp:12
CK_TILE_HOST_DEVICE void get_random_4x8(uint8_t *out, const unsigned long long subsequence, const index_t idx) const
Definition: philox_rand.hpp:75
CK_TILE_HOST_DEVICE void get_random_8x8(uint8_t *out, const unsigned long long subsequence, const index_t idx0, const index_t idx1) const
Definition: philox_rand.hpp:56
CK_TILE_HOST_DEVICE void get_random_16x8(uint8_t *out, const unsigned long long subsequence) const
Definition: philox_rand.hpp:42
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
constexpr index_t philox_per_tile
Definition: block_dropout.hpp:35
constexpr CK_TILE_HOST_DEVICE auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition: tile_distribution_encoding.hpp:457
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:66
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:2834
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition: sweep_tile.hpp:20
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:95
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:23
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:22
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:480
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition: warp_gemm_dispatcher.hpp:184
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__device__ uint32_t amd_wave_read_first_lane(uint32_t value)
Definition: amd_wave_read_first_lane.hpp:100
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
constexpr bool is_same_v
Definition: type.hpp:283
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
unsigned char uint8_t
Definition: stdint.h:124
static constexpr CK_TILE_HOST_DEVICE auto MakeRandvalDramWindow(RandValDramBlockWindowTmp &randval_dram_block_window_tmp, index_t seqlen_qk_start)
Definition: block_dropout.hpp:395
CK_TILE_HOST_DEVICE BlockDropoutBwd(index_t i_batch, index_t i_head, index_t nheads, unsigned long long seed, unsigned long long offset, float rp_undrop_, uint8_t p_undrop_in_uint8_t_)
Definition: block_dropout.hpp:411
const unsigned long long ph_seed
Definition: block_dropout.hpp:648
static constexpr CK_TILE_HOST_DEVICE auto MakeRandValTileDistribution()
Definition: block_dropout.hpp:465
static constexpr CK_TILE_HOST_DEVICE auto MakeRandvalDramWindow(RandValDramBlockWindowTmp &randval_dram_block_window_tmp, index_t seqlen_qk_start)
Definition: block_dropout.hpp:428
const uint8_t p_undrop_in_uint8_t
Definition: block_dropout.hpp:651
const unsigned long long ph_head_offset
Definition: block_dropout.hpp:649
CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx, const index_t start_n0_idx, PComputeWindow &p_compute, RandValDramWindow &randval_dram_window) const
Definition: block_dropout.hpp:510
const float rp_undrop
Definition: block_dropout.hpp:650
Definition: block_dropout.hpp:385
Definition: block_dropout.hpp:53
const uint8_t p_undrop_in_uint8_t
Definition: block_dropout.hpp:378
CK_TILE_HOST_DEVICE BlockDropout(index_t i_batch, index_t i_head, index_t nheads, unsigned long long seed, unsigned long long offset, float rp_undrop_, uint8_t p_undrop_in_uint8_t_, bool is_store_randval_)
Definition: block_dropout.hpp:54
const float rp_undrop
Definition: block_dropout.hpp:377
const unsigned long long ph_head_offset
Definition: block_dropout.hpp:376
const bool is_store_randval
Definition: block_dropout.hpp:379
static constexpr CK_TILE_HOST_DEVICE auto MakeRandvalDramWindow(RandValDramBlockWindowTmp &randval_dram_block_window_tmp, index_t seqlen_qk_start)
Definition: block_dropout.hpp:73
CK_TILE_HOST_DEVICE void Run(void *randval_ptr, const index_t start_n0_idx, PComputeWindow &p_compute, RandValDramWindow &randval_dram_window) const
Definition: block_dropout.hpp:219
const unsigned long long ph_seed
Definition: block_dropout.hpp:375
static constexpr CK_TILE_HOST_DEVICE auto MakeRandValTileDistribution()
Definition: block_dropout.hpp:144
static constexpr CK_TILE_HOST_DEVICE auto MakeRandValLdsShuffleTileDistribution()
Definition: block_dropout.hpp:186
static constexpr CK_TILE_HOST_DEVICE auto MakeRandValLdsBlockDescriptor()
Definition: block_dropout.hpp:110
Definition: block_dropout.hpp:39
static constexpr CK_TILE_HOST_DEVICE auto MakeRandvalDramWindow(RandValDramBlockWindowTmp &randval_dram_block_window_tmp, index_t seqlen_qk_start)
Definition: block_dropout.hpp:42
Definition: integral_constant.hpp:13
Definition: coordinate_transform.hpp:1392
Definition: sequence.hpp:49
Definition: functional.hpp:43
Definition: tile_distribution.hpp:42
static constexpr auto impl_
Definition: tile_distribution.hpp:45
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192