/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/block/block_dropout.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/block/block_dropout.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/block/block_dropout.hpp Source File
block_dropout.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 
9 namespace ck_tile {
10 
11 // BlockDropoutBwd and BlockDropout (fwd) support two warp gemm tile sizes: 32x32 (MFMA only) and
12 // 16x16 (MFMA and WMMA). Even if fwd and bwd use different tile sizes, the generated random
13 // numbers will be the same; they are also the same for MFMA (on CDNA), WMMA (on RDNA), or host
14 // (for verification, see ck_tile/host/reference/reference_batched_dropout_randval.hpp).
15 //
16 // The (row, col) coordinate of the current 32x32 tile in the P matrix determines a subsequence of
17 // random numbers (ph_subsequence).
18 // The (batch, head, 0..63) coordinate determines an offset in the subsequence (ph_head_offset and
19 // ph_offset).
20 // This means that subsequences are non-overlapping, reproducible and independent of mask or window.
21 //
22 // There are 3 modes (all produce the same results):
23 // * For 32x32 MFMA tile each of 64 lanes generates 4 * 32 bits or 16 bytes, so one warp generates
24 // the entire 32x32 tile (64 * 16 = 32 * 32).
25 // * For 16x16 MFMA tile one warp generates 1/4 of the 32x32 tile ((16 * 16) / (64 * 16) = 1/4), 4
26 // warps generate the same 64 * 16 random bytes and each uses its own quarter. If kMPerBlock >
27 // MWarp * WG::kM one warp can generate two 16x16 tiles (MIterPerWarp = 2) so fewer instructions
28 // are needed for generating a 32x32 tile.
29 // * For 16x16 WMMA tile one warp generates 1/2 of the 32x32 tile ((16 * 16) / (32 * 16) = 1/2), 2
30 // warps generate the same 64 * 16 random bytes and each uses its own half. If kMPerBlock > MWarp *
31 // WG::kM one warp can generate two 16x16 tiles.
32 
33 namespace detail {
34 // The number of Philox 4x32 results required to fill 32x32 tile of 8-bit values
// (one 4x32 result is 4 * 32 bits = 16 bytes, and 64 * 16 bytes = 32 * 32 bytes).
// The dropout constructors in this file also use it as the spacing between the Philox
// offsets of consecutive (batch, head) pairs.
35 constexpr index_t philox_per_tile = 64;
36 } // namespace detail
37 
39 {
// No-dropout variant of MakeRandvalDramWindow: same signature as the real implementation,
// but both arguments are explicitly discarded (cast to void).
// NOTE(review): the return statement (listing line 48) is missing from this extract. The
// symbol index at the end of the page references make_null_tile_window, so this presumably
// returns a null tile window - confirm against the original header.
40  template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
41  CK_TILE_HOST_DEVICE static constexpr auto
42  MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
43  index_t seqlen_qk_start)
44  {
45  (void)randval_dram_block_window_tmp;
46  (void)seqlen_qk_start;
47 
49  }
50 };
51 
53 {
// NOTE(review): listing lines 54 and 62 are missing from this extract. According to the
// symbol index at the end of the page, line 54 opens the constructor as
// BlockDropout(index_t i_batch, ...); line 62 presumably starts the member-initializer list
// with ph_seed - confirm against the original header.
55  index_t i_head,
56  index_t nheads,
57  unsigned long long seed,
58  unsigned long long offset,
59  float rp_undrop_,
60  uint8_t p_undrop_in_uint8_t_,
61  bool is_store_randval_)
// Philox offset at which this (batch, head) pair starts: the caller's base offset plus
// philox_per_tile offsets for every preceding (batch, head) pair. The value is passed
// through amd_wave_read_first_lane (presumably to make it wave-uniform - confirm).
63  ph_head_offset(amd_wave_read_first_lane(offset + (i_batch * nheads + i_head) *
64  detail::philox_per_tile)),
65  rp_undrop(rp_undrop_),
66  p_undrop_in_uint8_t(p_undrop_in_uint8_t_),
67  is_store_randval(is_store_randval_)
68  {
69  }
70 
// Builds the tile window through which Run() stores generated random values.
//
// The per-step tile sizes are derived from the warp-gemm configuration of BlockGemm:
//   kMPerStep = MIterPerWarp * MWarp * WG::kM
//   kNPerStep = NWarp * WG::kN
// MIterPerWarp is 2 only when the warp gemm is not 32 rows tall and the block is taller
// than MWarp * WG::kM; otherwise it is 1.
//
// The returned window reuses the bottom tensor view of randval_dram_block_window_tmp. Its
// origin takes one coordinate from that window's origin and uses seqlen_qk_start for the
// other one: the column when IsFwd is true, the row otherwise.
//
// NOTE(review): listing lines 82, 94 and 101 are missing from this extract (presumably the
// BlockGemmShape alias and the window-lengths argument of make_tile_window, which is where
// kMPerStep and kNPerStep would be used) - confirm against the original header.
71  template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
72  CK_TILE_HOST_DEVICE static constexpr auto
73  MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
74  index_t seqlen_qk_start)
75  {
76  constexpr auto config =
77  BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
78  using WG = remove_cvref_t<decltype(config.template at<0>())>;
79  constexpr bool IsWG32 = WG::kM == 32;
80  constexpr index_t MWarp = config.template at<1>();
81  constexpr index_t NWarp = config.template at<2>();
83  constexpr index_t kMPerBlock = BlockGemmShape::kM;
84  constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
85  constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
86  constexpr index_t kNPerStep = NWarp * WG::kN;
87 
88  const auto block_origin = randval_dram_block_window_tmp.get_window_origin();
89  auto randval_dram_window = [&]() {
90  if constexpr(IsFwd)
91  {
     // Keep the block's row, start at column seqlen_qk_start.
92  return make_tile_window(
93  randval_dram_block_window_tmp.get_bottom_tensor_view(),
95  {block_origin.at(number<0>{}), seqlen_qk_start}); // M/N
96  }
97  else
98  {
     // Start at row seqlen_qk_start, keep the block's column.
99  return make_tile_window(
100  randval_dram_block_window_tmp.get_bottom_tensor_view(),
102  {seqlen_qk_start, block_origin.at(number<1>{})}); // M/N
103  }
104  }();
105 
106  return randval_dram_window;
107  }
108 
// Returns the tensor descriptor of the LDS buffer that Run() uses to transpose one step of
// random values (Run() builds its LDS tensor view and LDS tile window from this descriptor).
//
// The N extent of a step (kNPerStep) is split into kN0 = kNPerStep / kN1 groups of kN1 = 8
// elements. A 3-D naive descriptor is built first and then transformed into the descriptor
// that is returned.
//
// NOTE(review): listing lines 110, 118, 127 and 134-138 are missing from this extract (the
// function signature, the BlockGemmShape alias, the lengths tuple, and the transform
// arguments). The symbol index at the end of the page names this function
// MakeRandValLdsBlockDescriptor and references make_merge_transform and
// make_pass_through_transform, which are presumably the transforms applied here - confirm
// against the original header.
109  template <typename BlockGemm>
111  {
112  constexpr auto config =
113  BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
114  using WG = remove_cvref_t<decltype(config.template at<0>())>;
115  constexpr bool IsWG32 = WG::kM == 32;
116  constexpr index_t MWarp = config.template at<1>();
117  constexpr index_t NWarp = config.template at<2>();
119  constexpr index_t kMPerBlock = BlockGemmShape::kM;
120  constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
121  constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
122  constexpr index_t kNPerStep = NWarp * WG::kN;
123  constexpr index_t kN1 = 8;
124  constexpr index_t kN0 = kNPerStep / kN1;
125 
126  constexpr auto randval_lds_block_desc_0 = make_naive_tensor_descriptor(
     // Second argument (presumably the strides). The outermost stride is
     // (kMPerStep + 1) * kN1 rather than kMPerStep * kN1, i.e. one extra row of kN1
     // elements per outer index (looks like padding, presumably against LDS bank
     // conflicts - TODO confirm).
128  ck_tile::make_tuple(number<(kMPerStep + 1) * kN1>{}, number<kN1>{}, number<1>{}),
129  number<kN1>{},
130  number<1>{});
131 
132  constexpr auto randval_lds_block_desc = transform_tensor_descriptor(
133  randval_lds_block_desc_0,
139 
140  return randval_lds_block_desc;
141  }
142 
// Returns the static tile distribution of the register tensor into which Run() writes
// freshly generated random bytes (before they are shuffled through LDS).
//
// The distribution is an outer block-level encoding embedded with an inner warp-level
// encoding. The inner encoding is the CWarpDstrEncoding of a WarpGemmDispatcher
// instantiated with WG's data types and tile sizes, with `false` and `IsWG32` as its 7th
// and 8th template arguments (TransposeC and SwizzleA according to the symbol index at the
// end of the page).
//
// NOTE(review): listing lines 144, 152 and 161-164 are missing from this extract (the
// function signature, the BlockGemmShape alias and most of the outer encoding, which is
// presumably where MWarp, NWarp, MIterPerWarp and NIterPerWarp are used) - confirm against
// the original header.
143  template <typename BlockGemm>
145  {
146  constexpr auto config =
147  BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
148  using WG = remove_cvref_t<decltype(config.template at<0>())>;
149  constexpr bool IsWG32 = WG::kM == 32;
150  constexpr index_t MWarp = config.template at<1>();
151  constexpr index_t NWarp = config.template at<2>();
153  constexpr index_t kMPerBlock = BlockGemmShape::kM;
154  constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
155  constexpr index_t NIterPerWarp = 1;
156 
157  // The tile distribution is different from the one in MakeRandValLdsShuffleTileDistribution,
158  // because it can combine 2 (MIterPerWarp) 16x16 subtiles for generating them at once
159  constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding<
160  sequence<>,
165  sequence<1, 0>>{};
166 
167  // Use Bwd WarpGemm to ensure that Fwd's random values are consistent with Bwd.
168  constexpr auto randval_block_inner_part_dstr_encoding =
169  typename WarpGemmDispatcher<typename WG::ADataType,
170  typename WG::BDataType,
171  typename WG::CDataType,
172  WG::kM,
173  WG::kN,
174  WG::kK,
175  false,
176  IsWG32>::CWarpDstrEncoding{};
177 
178  constexpr auto randval_block_part_dstr_encode =
179  detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
180  randval_block_inner_part_dstr_encoding);
181 
182  return make_static_tile_distribution(randval_block_part_dstr_encode);
183  }
184 
// Returns the static tile distribution with which Run() reads random values back from LDS
// (it is the distribution of Run()'s randval_lds_read_window).
//
// Unlike MakeRandValTileDistribution, the inner warp-level encoding is taken directly from
// the block gemm's own warp gemm (WG::CWarpDstrEncoding), and the last visible sequence of
// the outer encoding is <0, 0> instead of <1, 0>.
//
// NOTE(review): listing lines 186, 194 and 201-204 are missing from this extract (the
// function signature, the BlockGemmShape alias and most of the outer encoding) - confirm
// against the original header.
185  template <typename BlockGemm>
187  {
188  constexpr auto config =
189  BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
190  using WG = remove_cvref_t<decltype(config.template at<0>())>;
191  constexpr bool IsWG32 = WG::kM == 32;
192  constexpr index_t MWarp = config.template at<1>();
193  constexpr index_t NWarp = config.template at<2>();
195  constexpr index_t kMPerBlock = BlockGemmShape::kM;
196  constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
197  constexpr index_t NIterPerWarp = 1;
198 
199  constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding<
200  sequence<>,
205  sequence<0, 0>>{};
206 
207  constexpr auto randval_block_part_dstr_encode =
208  detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
209  typename WG::CWarpDstrEncoding{});
210 
211  return make_static_tile_distribution(randval_block_part_dstr_encode);
212  }
213 
// Applies dropout to the P tile held in p_compute and, if is_store_randval is set, also
// writes the generated random bytes through randval_dram_window.
//
// Parameters:
//   randval_ptr         - scratch buffer, reinterpreted as uint8_t* and viewed as an LDS
//                         tensor described by MakeRandValLdsBlockDescriptor.
//   start_n0_idx        - N coordinate of the tile; divided by WG::kN to obtain the
//                         warp-gemm column index used for the Philox subsequence.
//   p_compute           - tile modified in place: an element is multiplied by rp_undrop
//                         when its random byte is <= p_undrop_in_uint8_t, and set to
//                         PComputeDataType(0) otherwise.
//   randval_dram_window - its origin row supplies start_m0_idx; when is_store_randval is
//                         set it also receives the random bytes and is moved as described
//                         at the store loop.
//
// The block is processed in steps of kMPerStep x kNPerStep. When is_store_randval is set,
// the random values of every step are generated twice: once in the store pass and once in
// the dropout pass.
//
// NOTE(review): listing lines 230 and 361 are missing from this extract (presumably the
// BlockGemmShape alias and the opening of the p_idx1 tile_distributed_index) - confirm
// against the original header.
214  template <typename BlockGemm,
215  typename PComputeDataType,
216  typename RandValOutputDataType,
217  typename PComputeWindow,
218  typename RandValDramWindow>
219  CK_TILE_HOST_DEVICE void Run(void* randval_ptr,
220  const index_t start_n0_idx,
221  PComputeWindow& p_compute,
222  RandValDramWindow& randval_dram_window) const
223  {
224  constexpr auto config =
225  BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
226  using WG = remove_cvref_t<decltype(config.template at<0>())>;
227  constexpr bool IsWG32 = WG::kM == 32;
228  constexpr index_t MWarp = config.template at<1>();
229  constexpr index_t NWarp = config.template at<2>();
231  constexpr index_t kMPerBlock = BlockGemmShape::kM;
232  constexpr index_t kNPerBlock = BlockGemmShape::kN;
233  constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
234  constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
235  constexpr index_t kNPerStep = NWarp * WG::kN;
236 
237  // randval tile in LDS
238  auto randval_lds = make_tensor_view<address_space_enum::lds>(
239  reinterpret_cast<uint8_t*>(randval_ptr), MakeRandValLdsBlockDescriptor<BlockGemm>());
240 
241  auto randval_lds_window = make_tile_window(
242  randval_lds, MakeRandValLdsBlockDescriptor<BlockGemm>().get_lengths(), {0, 0});
243 
244  // register distribute
245  auto randval_dist_generated =
246  make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
247 
     // Same LDS view, lengths and origin as randval_lds_window, but read with the shuffle
     // distribution instead of the generation distribution.
248  const auto randval_lds_read_window =
249  make_tile_window(randval_lds_window.get_bottom_tensor_view(),
250  randval_lds_window.get_window_lengths(),
251  randval_lds_window.get_window_origin(),
252  MakeRandValLdsShuffleTileDistribution<BlockGemm>());
253 
254  const index_t start_m0_idx = randval_dram_window.get_window_origin().at(number<0>{});
     // Position of this warp in the MWarp x NWarp warp grid (N varies fastest).
255  const index_t iMWarp = get_warp_id() / NWarp;
256  const index_t iNWarp = get_warp_id() % NWarp;
257 
     // Generates the random bytes of step (i_m0, i_n0) and returns them as a tile loaded
     // with the shuffle distribution.
258  auto generate_randval = [&](auto i_m0, auto i_n0) {
259  // Generate random numbers
260  uint8_t random_uint8_t[randval_dist_generated.kThreadElementSpaceSize];
     // Coordinates of this warp's tile, in units of warp-gemm tiles (WG::kM x WG::kN).
261  const index_t wg_m0 = (start_m0_idx / WG::kM) + (i_m0 * MWarp + iMWarp) * MIterPerWarp;
262  const index_t wg_n0 = (start_n0_idx / WG::kN) + (i_n0 * NWarp + iNWarp);
263  if constexpr(IsWG32)
264  {
265  // Generate the whole 32x32 tile at once (each tile consists of random numbers taken
266  // from a separate subsequence of Philox)
267  const unsigned long long ph_subsequence =
268  bit_cast<unsigned long long>(make_uint2(wg_m0, wg_n0));
269  const index_t ph_offset = get_lane_id();
270  const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset);
271  static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
272  ph.get_random_16x8(random_uint8_t, ph_subsequence);
273  }
274  else
275  {
276  // Generate one or two 16x16 subtiles of the 32x32 tile (depending on whether
277  // MIterPerWarp is equal to 1 or 2)
     // The subsequence is selected by the coordinate of the enclosing 32x32 tile, hence
     // the division of both 16x16 tile coordinates by 2.
278  const unsigned long long ph_subsequence =
279  bit_cast<unsigned long long>(make_uint2(wg_m0 / 2, wg_n0 / 2));
280  const index_t subtile_m0 = wg_m0 % 2;
281  if constexpr(get_warp_size() == 32)
282  {
     // Offset bits: [3:0] = lane bits [3:0], bit 4 = wg_n0 % 2, bit 5 = lane bit 4.
283  const index_t ph_offset = (get_lane_id() & 15) +
284  (((get_lane_id() >> 4) & 1) << 5) +
285  ((wg_n0 % 2) << 4);
286  const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset);
287  if constexpr(MIterPerWarp == 1)
288  {
289  static_assert(randval_dist_generated.kThreadElementSpaceSize == 8);
290  ph.get_random_8x8(
291  random_uint8_t, ph_subsequence, subtile_m0 * 2 + 0, subtile_m0 * 2 + 1);
292  }
293  else
294  {
295  static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
296  ph.get_random_16x8(random_uint8_t, ph_subsequence);
297  }
298  }
299  else
300  {
     // 47 == 0b101111: lane bit 4 is removed from the offset (it becomes subtile_n0)
     // and its position is taken by wg_n0 % 2.
301  const index_t subtile_n0 = (get_lane_id() >> 4) & 1;
302  const index_t ph_offset = (get_lane_id() & 47) + ((wg_n0 % 2) << 4);
303  const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset);
304  if constexpr(MIterPerWarp == 1)
305  {
306  static_assert(randval_dist_generated.kThreadElementSpaceSize == 4);
307  ph.get_random_4x8(
308  random_uint8_t, ph_subsequence, subtile_m0 * 2 + subtile_n0);
309  }
310  else
311  {
312  static_assert(randval_dist_generated.kThreadElementSpaceSize == 8);
313  ph.get_random_8x8(
314  random_uint8_t, ph_subsequence, 0 * 2 + subtile_n0, 1 * 2 + subtile_n0);
315  }
316  }
317  }
318 
     // Copy the per-thread bytes into the distributed tensor in sweep order.
319  constexpr auto randval_dist_generated_spans =
320  decltype(randval_dist_generated)::get_distributed_spans();
321  int i_random_idx = 0;
322  sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) {
323  sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) {
324  constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1);
325  randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
326  });
327  });
328  // Transpose randval using LDS
     // block_sync_lds() is called after the store (before reading the buffer back) and
     // again after the load (before the buffer is overwritten by the next call).
329  store_tile(randval_lds_window, randval_dist_generated);
330  block_sync_lds();
331  const auto randval = load_tile(randval_lds_read_window);
332  block_sync_lds();
333  return randval;
334  };
335 
336  if(is_store_randval)
337  {
     // Store pass: write every step's random bytes, walking the window across N within a
     // row of steps and then down by kMPerStep.
338  static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
339  static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
340  const auto randval = generate_randval(i_m0, i_n0);
341  // save to Global
342  const auto randval_store = cast_tile<RandValOutputDataType>(randval);
343  store_tile(randval_dram_window, randval_store);
344  move_tile_window(randval_dram_window, {0, kNPerStep});
345  });
346  move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
347  });
     // Net movement of the whole pass: back to the starting row, kNPerBlock columns
     // further along N.
348  move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock});
349  }
     // Dropout pass: regenerate each step's random bytes and apply them to p_compute.
350  static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
351  static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
352  const auto randval = generate_randval(i_m0, i_n0);
353  // Drop values of P based on the generated probabilities
354  constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
355  sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
356  sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
     // Map the index inside this step's randval tile to the index inside p_compute.
357  constexpr auto p_idx0 =
358  tile_distributed_index<i_m0 * MIterPerWarp +
359  idx0.impl_.template at<0>()>{};
360  constexpr auto p_idx1 =
362  idx1.impl_.template at<1>(),
363  idx1.impl_.template at<2>()>{};
364  constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
365  constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
     // Kept elements are rescaled by rp_undrop; dropped elements become zero.
366  p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
367  ? p_compute[p_idx] * rp_undrop
368  : PComputeDataType(0);
369  });
370  });
371  });
372  });
373  }
374 
375  const unsigned long long ph_seed;
376  const unsigned long long ph_head_offset;
377  const float rp_undrop;
379  const bool is_store_randval;
380 };
381 
382 // TODO: IsWG32_ is not needed as template parameter and can be removed. IsDropout_ == false can be
383 // replaced with NullBlockDropout. This requires changes in xformers and other libs.
384 template <bool IsDropout_, bool IsWG32_, bool IsStoreRandval_>
386 
// Specialization of BlockDropoutBwd for IsDropout_ == false. It exposes only the two
// compile-time flags and a MakeRandvalDramWindow with the same signature as the
// IsDropout_ == true specialization; there is no constructor, state or Run().
387 template <bool IsWG32_, bool IsStoreRandval_>
388 struct BlockDropoutBwd<false, IsWG32_, IsStoreRandval_>
389 {
390  static constexpr bool IsDropout = false;
391  static constexpr bool IsStoreRandval = IsStoreRandval_;
392 
     // Both arguments are explicitly discarded (cast to void).
     // NOTE(review): the return statement (listing line 401) is missing from this extract.
     // The symbol index at the end of the page references make_null_tile_window, so this
     // presumably returns a null tile window - confirm against the original header.
393  template <typename BlockGemm, bool IsFwd = false, typename RandValDramBlockWindowTmp>
394  CK_TILE_HOST_DEVICE static constexpr auto
395  MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
396  index_t seqlen_qk_start)
397  {
398  (void)randval_dram_block_window_tmp;
399  (void)seqlen_qk_start;
400 
402  }
403 };
404 
// Specialization of BlockDropoutBwd for IsDropout_ == true.
//
// Differences from BlockDropout that are visible in this file:
//   * whether random values are stored is the compile-time flag IsStoreRandval instead of
//     a runtime member;
//   * Run() uses the generated register tensor directly, with no transpose through LDS;
//   * Run() does not zero or rescale dropped elements, it negates them;
//   * Run() iterates with N as the outer loop and M as the inner loop.
//
// NOTE(review): listing lines 411, 437, 449, 456, 465, 473, 480-483, 521 and 651 are
// missing from this extract. The symbol index at the end of the page places the
// constructor head BlockDropoutBwd(index_t i_batch, ...) at line 411, the signature of
// MakeRandValTileDistribution at line 465, and the member p_undrop_in_uint8_t at line 651.
// The remaining lines are presumably the BlockGemmShape aliases, the window-lengths
// arguments and most of the outer encoding - confirm against the original header.
405 template <bool IsWG32_, bool IsStoreRandval_>
406 struct BlockDropoutBwd<true, IsWG32_, IsStoreRandval_>
407 {
408  static constexpr bool IsDropout = true;
409  static constexpr bool IsStoreRandval = IsStoreRandval_;
410 
412  index_t i_head,
413  index_t nheads,
414  unsigned long long seed,
415  unsigned long long offset,
416  float rp_undrop_,
417  uint8_t p_undrop_in_uint8_t_)
418  : ph_seed(amd_wave_read_first_lane(seed)),
     // Philox offset at which this (batch, head) pair starts: the caller's base offset
     // plus philox_per_tile offsets for every preceding (batch, head) pair.
419  ph_head_offset(amd_wave_read_first_lane(offset + (i_batch * nheads + i_head) *
420  detail::philox_per_tile)),
421  rp_undrop(rp_undrop_),
422  p_undrop_in_uint8_t(p_undrop_in_uint8_t_)
423  {
424  }
425 
     // Builds the tile window through which Run() stores generated random values. The
     // window reuses the bottom tensor view of randval_dram_block_window_tmp; its origin
     // takes one coordinate from that window's origin and uses seqlen_qk_start for the
     // other one (the column when IsFwd is true, the row otherwise). IsFwd defaults to
     // false here.
426  template <typename BlockGemm, bool IsFwd = false, typename RandValDramBlockWindowTmp>
427  CK_TILE_HOST_DEVICE static constexpr auto
428  MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
429  index_t seqlen_qk_start)
430  {
431  constexpr auto config =
432  BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
433  using WG = remove_cvref_t<decltype(config.template at<0>())>;
434  constexpr bool IsWG32 = WG::kM == 32;
435  constexpr index_t MWarp = config.template at<1>();
436  constexpr index_t NWarp = config.template at<2>();
438  constexpr index_t kMPerBlock = BlockGemmShape::kM;
439  constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
440  constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
441  constexpr index_t kNPerStep = NWarp * WG::kN;
442 
443  const auto block_origin = randval_dram_block_window_tmp.get_window_origin();
444  auto randval_dram_window = [&]() {
445  if constexpr(IsFwd)
446  {
447  return make_tile_window(
448  randval_dram_block_window_tmp.get_bottom_tensor_view(),
450  {block_origin.at(number<0>{}), seqlen_qk_start}); // M/N
451  }
452  else
453  {
454  return make_tile_window(
455  randval_dram_block_window_tmp.get_bottom_tensor_view(),
457  {seqlen_qk_start, block_origin.at(number<1>{})}); // M/N
458  }
459  }();
460 
461  return randval_dram_window;
462  }
463 
     // Returns the static tile distribution of the register tensor that holds the generated
     // random bytes: an outer block-level encoding embedded with the CWarpDstrEncoding of a
     // WarpGemmDispatcher instantiated from WG's data types and tile sizes.
464  template <typename BlockGemm>
466  {
467  constexpr auto config =
468  BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
469  using WG = remove_cvref_t<decltype(config.template at<0>())>;
470  constexpr bool IsWG32 = WG::kM == 32;
471  constexpr index_t MWarp = config.template at<1>();
472  constexpr index_t NWarp = config.template at<2>();
474  constexpr index_t kMPerBlock = BlockGemmShape::kM;
475  constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
476  constexpr index_t NIterPerWarp = 1;
477 
478  constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding<
479  sequence<>,
484  sequence<1, 0>>{};
485 
486  constexpr auto randval_block_inner_part_dstr_encoding =
487  typename WarpGemmDispatcher<typename WG::ADataType,
488  typename WG::BDataType,
489  typename WG::CDataType,
490  WG::kM,
491  WG::kN,
492  WG::kK,
493  false,
494  IsWG32>::CWarpDstrEncoding{};
     // Compile-time check that the dispatcher-derived encoding is the same type as the
     // block gemm's own C-warp encoding.
495  static_assert(
496  std::is_same_v<remove_cvref_t<decltype(randval_block_inner_part_dstr_encoding)>,
497  typename WG::CWarpDstrEncoding>);
498 
499  constexpr auto randval_block_part_dstr_encode =
500  detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
501  randval_block_inner_part_dstr_encoding);
502 
503  return make_static_tile_distribution(randval_block_part_dstr_encode);
504  }
505 
     // Marks dropped elements of p_compute by negating them and, when IsStoreRandval is
     // true, stores the generated random bytes through randval_dram_window.
     //
     //   start_m0_idx, start_n0_idx - M and N coordinates of the tile; divided by WG::kM and
     //                                WG::kN to obtain the warp-gemm tile indices used for
     //                                the Philox subsequence.
     //   p_compute                  - tile modified in place: an element keeps its value
     //                                when its random byte is <= p_undrop_in_uint8_t and is
     //                                negated otherwise.
     //   randval_dram_window        - only touched when IsStoreRandval is true.
     //
     // rp_undrop is not used inside this function.
506  template <typename BlockGemm,
507  typename RandValOutputDataType,
508  typename PComputeWindow,
509  typename RandValDramWindow>
510  CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx,
511  const index_t start_n0_idx,
512  PComputeWindow& p_compute,
513  RandValDramWindow& randval_dram_window) const
514  {
515  constexpr auto config =
516  BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
517  using WG = remove_cvref_t<decltype(config.template at<0>())>;
518  constexpr bool IsWG32 = WG::kM == 32;
519  constexpr index_t MWarp = config.template at<1>();
520  constexpr index_t NWarp = config.template at<2>();
522  constexpr index_t kMPerBlock = BlockGemmShape::kM;
523  constexpr index_t kNPerBlock = BlockGemmShape::kN;
524  constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
525  constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
526  constexpr index_t kNPerStep = NWarp * WG::kN;
527 
528  // register distribute
529  auto randval_dist_generated =
530  make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
531 
     // Position of this warp in the MWarp x NWarp warp grid (N varies fastest).
532  const index_t iMWarp = get_warp_id() / NWarp;
533  const index_t iNWarp = get_warp_id() % NWarp;
534 
     // Generates the random bytes of step (i_m0, i_n0) and returns a copy of the register
     // tensor that holds them.
535  auto generate_randval = [&](auto i_m0, auto i_n0) {
536  // Generate random numbers
537  uint8_t random_uint8_t[randval_dist_generated.kThreadElementSpaceSize];
     // Coordinates of this warp's tile, in units of warp-gemm tiles (WG::kM x WG::kN).
538  const index_t wg_m0 = (start_m0_idx / WG::kM) + (i_m0 * MWarp + iMWarp) * MIterPerWarp;
539  const index_t wg_n0 = (start_n0_idx / WG::kN) + (i_n0 * NWarp + iNWarp);
540  if constexpr(IsWG32)
541  {
542  // Generate the whole 32x32 tile at once (each tile consists of random numbers
543  // taken from a separate subsequence of Philox)
544  const unsigned long long ph_subsequence =
545  bit_cast<unsigned long long>(make_uint2(wg_m0, wg_n0));
546  const index_t ph_offset = get_lane_id();
547  const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset);
548  static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
549  ph.get_random_16x8(random_uint8_t, ph_subsequence);
550  }
551  else
552  {
553  // Generate one or two 16x16 subtiles of the 32x32 tile (depending on whether
554  // MIterPerWarp is equal to 1 or 2)
     // The subsequence is selected by the coordinate of the enclosing 32x32 tile, hence
     // the division of both 16x16 tile coordinates by 2.
555  const unsigned long long ph_subsequence =
556  bit_cast<unsigned long long>(make_uint2(wg_m0 / 2, wg_n0 / 2));
557  const index_t subtile_m0 = wg_m0 % 2;
558  if constexpr(get_warp_size() == 32)
559  {
     // Offset bits: [3:0] = lane bits [3:0], bit 4 = wg_n0 % 2, bit 5 = lane bit 4.
560  const index_t ph_offset = (get_lane_id() & 15) +
561  (((get_lane_id() >> 4) & 1) << 5) +
562  ((wg_n0 % 2) << 4);
563  const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset);
564  if constexpr(MIterPerWarp == 1)
565  {
566  static_assert(randval_dist_generated.kThreadElementSpaceSize == 8);
567  ph.get_random_8x8(
568  random_uint8_t, ph_subsequence, subtile_m0 * 2 + 0, subtile_m0 * 2 + 1);
569  }
570  else
571  {
572  static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
573  ph.get_random_16x8(random_uint8_t, ph_subsequence);
574  }
575  }
576  else
577  {
     // 47 == 0b101111: lane bit 4 is removed from the offset (it becomes subtile_n0)
     // and its position is taken by wg_n0 % 2.
578  const index_t subtile_n0 = (get_lane_id() >> 4) & 1;
579  const index_t ph_offset = (get_lane_id() & 47) + ((wg_n0 % 2) << 4);
580  const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset);
581  if constexpr(MIterPerWarp == 1)
582  {
583  static_assert(randval_dist_generated.kThreadElementSpaceSize == 4);
584  ph.get_random_4x8(
585  random_uint8_t, ph_subsequence, subtile_m0 * 2 + subtile_n0);
586  }
587  else
588  {
589  static_assert(randval_dist_generated.kThreadElementSpaceSize == 8);
590  ph.get_random_8x8(
591  random_uint8_t, ph_subsequence, 0 * 2 + subtile_n0, 1 * 2 + subtile_n0);
592  }
593  }
594  }
595 
     // Copy the per-thread bytes into the distributed tensor in sweep order.
596  constexpr auto randval_dist_generated_spans =
597  decltype(randval_dist_generated)::get_distributed_spans();
598  int i_random_idx = 0;
599  sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) {
600  sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) {
601  constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1);
602  randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
603  });
604  });
605  return randval_dist_generated;
606  };
607 
608  static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
609  static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
610  const auto randval = generate_randval(i_m0, i_n0);
611  // Drop values of P based on the generated probabilities, negative sign is used to
612  // distinguish such values later in bwd pipeline.
613  constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
614  sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
615  sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
616  constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
     // Map the index inside this step's randval tile to the index inside p_compute.
617  constexpr auto p_idx0 =
618  tile_distributed_index<i_m0 * MIterPerWarp +
619  idx0.impl_.template at<0>(),
620  idx0.impl_.template at<1>(),
621  idx0.impl_.template at<2>()>{};
622  constexpr auto p_idx1 = tile_distributed_index<i_n0>{};
623  constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
624  p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
625  ? p_compute[p_idx]
626  : -p_compute[p_idx];
627  });
628  });
629  // save to Global
630  if constexpr(IsStoreRandval)
631  {
632  const auto randval_store = cast_tile<RandValOutputDataType>(randval);
633  store_tile(randval_dram_window, randval_store);
634  move_tile_window(randval_dram_window, {kMPerStep, 0});
635  }
636  });
637  if constexpr(IsStoreRandval)
638  {
     // Back to the first row of the block, one step further along N.
639  move_tile_window(randval_dram_window, {-kMPerBlock, kNPerStep});
640  }
641  });
642  if constexpr(IsStoreRandval)
643  {
     // Net movement of the whole call: kMPerBlock rows down, N back at its starting
     // column.
644  move_tile_window(randval_dram_window, {kMPerBlock, -kNPerBlock});
645  }
646  }
647 
648  const unsigned long long ph_seed;
649  const unsigned long long ph_head_offset;
650  const float rp_undrop;
652 };
653 
654 } // namespace ck_tile
CK_TILE_DEVICE void block_sync_lds()
Definition: arch.hpp:192
Definition: philox_rand.hpp:12
CK_TILE_HOST_DEVICE void get_random_4x8(uint8_t *out, const unsigned long long subsequence, const index_t idx) const
Definition: philox_rand.hpp:75
CK_TILE_HOST_DEVICE void get_random_8x8(uint8_t *out, const unsigned long long subsequence, const index_t idx0, const index_t idx1) const
Definition: philox_rand.hpp:56
CK_TILE_HOST_DEVICE void get_random_16x8(uint8_t *out, const unsigned long long subsequence) const
Definition: philox_rand.hpp:42
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
constexpr index_t philox_per_tile
Definition: block_dropout.hpp:35
constexpr CK_TILE_HOST_DEVICE auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition: tile_distribution_encoding.hpp:457
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:66
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:2834
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition: sweep_tile.hpp:20
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:95
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:23
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:22
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:480
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition: warp_gemm_dispatcher.hpp:184
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__device__ uint32_t amd_wave_read_first_lane(uint32_t value)
Definition: amd_wave_read_first_lane.hpp:100
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
constexpr bool is_same_v
Definition: type.hpp:283
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
unsigned char uint8_t
Definition: stdint.h:124
static constexpr CK_TILE_HOST_DEVICE auto MakeRandvalDramWindow(RandValDramBlockWindowTmp &randval_dram_block_window_tmp, index_t seqlen_qk_start)
Definition: block_dropout.hpp:395
CK_TILE_HOST_DEVICE BlockDropoutBwd(index_t i_batch, index_t i_head, index_t nheads, unsigned long long seed, unsigned long long offset, float rp_undrop_, uint8_t p_undrop_in_uint8_t_)
Definition: block_dropout.hpp:411
const unsigned long long ph_seed
Definition: block_dropout.hpp:648
static constexpr CK_TILE_HOST_DEVICE auto MakeRandValTileDistribution()
Definition: block_dropout.hpp:465
static constexpr CK_TILE_HOST_DEVICE auto MakeRandvalDramWindow(RandValDramBlockWindowTmp &randval_dram_block_window_tmp, index_t seqlen_qk_start)
Definition: block_dropout.hpp:428
const uint8_t p_undrop_in_uint8_t
Definition: block_dropout.hpp:651
const unsigned long long ph_head_offset
Definition: block_dropout.hpp:649
CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx, const index_t start_n0_idx, PComputeWindow &p_compute, RandValDramWindow &randval_dram_window) const
Definition: block_dropout.hpp:510
const float rp_undrop
Definition: block_dropout.hpp:650
Definition: block_dropout.hpp:385
Definition: block_dropout.hpp:53
const uint8_t p_undrop_in_uint8_t
Definition: block_dropout.hpp:378
CK_TILE_HOST_DEVICE BlockDropout(index_t i_batch, index_t i_head, index_t nheads, unsigned long long seed, unsigned long long offset, float rp_undrop_, uint8_t p_undrop_in_uint8_t_, bool is_store_randval_)
Definition: block_dropout.hpp:54
const float rp_undrop
Definition: block_dropout.hpp:377
const unsigned long long ph_head_offset
Definition: block_dropout.hpp:376
const bool is_store_randval
Definition: block_dropout.hpp:379
static constexpr CK_TILE_HOST_DEVICE auto MakeRandvalDramWindow(RandValDramBlockWindowTmp &randval_dram_block_window_tmp, index_t seqlen_qk_start)
Definition: block_dropout.hpp:73
CK_TILE_HOST_DEVICE void Run(void *randval_ptr, const index_t start_n0_idx, PComputeWindow &p_compute, RandValDramWindow &randval_dram_window) const
Definition: block_dropout.hpp:219
const unsigned long long ph_seed
Definition: block_dropout.hpp:375
static constexpr CK_TILE_HOST_DEVICE auto MakeRandValTileDistribution()
Definition: block_dropout.hpp:144
static constexpr CK_TILE_HOST_DEVICE auto MakeRandValLdsShuffleTileDistribution()
Definition: block_dropout.hpp:186
static constexpr CK_TILE_HOST_DEVICE auto MakeRandValLdsBlockDescriptor()
Definition: block_dropout.hpp:110
Definition: block_dropout.hpp:39
static constexpr CK_TILE_HOST_DEVICE auto MakeRandvalDramWindow(RandValDramBlockWindowTmp &randval_dram_block_window_tmp, index_t seqlen_qk_start)
Definition: block_dropout.hpp:42
Definition: integral_constant.hpp:13
Definition: coordinate_transform.hpp:1392
Definition: sequence.hpp:49
Definition: functional.hpp:43
Definition: tile_distribution.hpp:42
static constexpr auto impl_
Definition: tile_distribution.hpp:45
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192