/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/block/block_dropout.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/block/block_dropout.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/block/block_dropout.hpp Source File
block_dropout.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 
9 namespace ck_tile {
10 
// NOTE(review): this is the body of the dropout-disabled policy struct. Its header
// (listing line 11) and its return statement (listing line 21) were dropped by the
// documentation renderer. The tooltip index references make_null_tile_window, which is
// presumably what line 21 returns -- confirm against the real header.
12 {
// Dropout-disabled MakeRandvalDramWindow: both arguments are cast to void, i.e.
// deliberately unused, so no random-value window is derived from them.
13  template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
14  __host__ __device__ static constexpr auto
15  MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
16  index_t seqlen_qk_start)
17  {
18  (void)randval_dram_block_window_tmp;
19  (void)seqlen_qk_start;
20 
22  }
23 };
24 
26 {
// Constructor. Its first line (listing line 27) is missing from this rendering. Per the
// tooltip index, the full signature is
//   BlockDropout(index_t i_batch, index_t i_head, index_t nheads,
//                unsigned long long seed, unsigned long long offset,
//                float rp_undrop_, uint8_t p_undrop_in_uint8_t_, bool is_store_randval_)
// Members initialised here:
//   ph                  - Philox generator, seeded with `seed` and a per-lane offset
//                         offset + (i_batch * nheads + i_head) * warp_size + lane_id.
//   rp_undrop           - factor applied to kept elements in Run.
//   p_undrop_in_uint8_t - keep threshold: an element is kept when its random byte is
//                         <= this value (see Run).
//   is_store_randval    - runtime switch for writing random bytes to global memory.
// NOTE(review): (i_batch * nheads + i_head) * get_warp_size() is computed in 32-bit
// index_t before being added to the 64-bit `offset`; confirm batch * nheads stays small
// enough not to overflow.
28  index_t i_head,
29  index_t nheads,
30  unsigned long long seed,
31  unsigned long long offset,
32  float rp_undrop_,
33  uint8_t p_undrop_in_uint8_t_,
34  bool is_store_randval_)
35  : ph(seed, offset + (i_batch * nheads + i_head) * get_warp_size() + get_lane_id()),
36  rp_undrop(rp_undrop_),
37  p_undrop_in_uint8_t(p_undrop_in_uint8_t_),
38  is_store_randval(is_store_randval_)
39  {
40  }
41 
// Builds the global-memory tile window used to store generated random values.
// It reuses the bottom tensor view of randval_dram_block_window_tmp and only chooses
// a new origin:
//   IsFwd  : (row origin of the incoming window, seqlen_qk_start)
//   !IsFwd : (seqlen_qk_start, column origin of the incoming window)
// NOTE(review): the window-length arguments (listing lines 61 and 68) were dropped by
// the renderer. They are presumably built from kMPerStep / kNPerStep, which are
// otherwise unused here -- confirm.
42  template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
43  CK_TILE_HOST_DEVICE static constexpr auto
44  MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
45  index_t seqlen_qk_start)
46  {
47  constexpr auto config =
48  BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
49  using WG = remove_cvref_t<decltype(config.template at<0>())>;
50  constexpr index_t MWarp = config.template at<1>();
51  constexpr index_t NWarp = config.template at<2>();
// One "step" spans all warps' warp-gemm tiles in each dimension.
52  constexpr index_t kMPerStep = MWarp * WG::kM;
53  constexpr index_t kNPerStep = NWarp * WG::kN;
54 
55  const auto block_origin = randval_dram_block_window_tmp.get_window_origin();
56  auto randval_dram_window = [&]() {
57  if constexpr(IsFwd)
58  {
// Forward: keep the incoming row origin, start columns at seqlen_qk_start.
59  return make_tile_window(
60  randval_dram_block_window_tmp.get_bottom_tensor_view(),
62  {block_origin.at(number<0>{}), seqlen_qk_start}); // M/N
63  }
64  else
65  {
// Backward: start rows at seqlen_qk_start, keep the incoming column origin.
66  return make_tile_window(
67  randval_dram_block_window_tmp.get_bottom_tensor_view(),
69  {seqlen_qk_start, block_origin.at(number<1>{})}); // M/N
70  }
71  }();
72 
73  return randval_dram_window;
74  }
75 
// LDS descriptor for one step of random bytes (kMPerStep rows x WG::kN columns), used
// to reshuffle bytes between the generation layout and the P layout in Run.
// NOTE(review): the signature line (listing line 77, `static constexpr
// CK_TILE_HOST_DEVICE auto MakeRandValLdsBlockDescriptor()` per the tooltip index),
// the lengths tuple (line 89) and the transform arguments (lines 96-100) are missing
// from this rendering.
76  template <typename BlockGemm>
78  {
79  constexpr auto config =
80  BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
81  using WG = remove_cvref_t<decltype(config.template at<0>())>;
82  constexpr index_t MWarp = config.template at<1>();
// M extent covers all M-warps; N extent is a single warp-gemm tile (no NWarp factor).
83  constexpr index_t kMPerStep = MWarp * WG::kM;
84  constexpr index_t kNPerStep = WG::kN;
// Columns are grouped into kN0 chunks of kN1 = 8 bytes each.
85  constexpr index_t kN1 = 8;
86  constexpr index_t kN0 = kNPerStep / kN1;
87 
// Strides tuple: the outermost dimension advances by (kMPerStep + 1) * kN1 rather than
// kMPerStep * kN1, i.e. one extra kN1-byte chunk of padding. Presumably this is to
// reduce LDS bank conflicts -- confirm.
88  constexpr auto randval_lds_block_desc_0 = make_naive_tensor_descriptor(
90  ck_tile::make_tuple(number<(kMPerStep + 1) * kN1>{}, number<kN1>{}, number<1>{}),
91  number<kN1>{},
92  number<1>{});
93 
94  constexpr auto randval_lds_block_desc = transform_tensor_descriptor(
95  randval_lds_block_desc_0,
101 
102  return randval_lds_block_desc;
103  }
104 
// Register distribution in which random bytes are generated (one warp-gemm tile per
// warp, MIterPerWarp = NIterPerWarp = 1). Run static_asserts that this gives each lane
// 16 elements.
// NOTE(review): the signature line (listing line 106, `static constexpr
// CK_TILE_HOST_DEVICE auto MakeRandValTileDistribution()` per the tooltip index), the
// outer-encoding arguments (lines 118-121) and the per-branch inner encodings (lines
// 130 and 134) are missing from this rendering.
105  template <typename BlockGemm>
107  {
108  constexpr auto config =
109  BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
110  constexpr index_t MWarp = config.template at<1>();
111  constexpr index_t NWarp = config.template at<2>();
112 
113  constexpr index_t MIterPerWarp = 1;
114  constexpr index_t NIterPerWarp = 1;
115 
116  constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding<
117  sequence<>,
122  sequence<0, 0>>{};
123 
124  // Use Bwd WarpGemm to ensure that Fwd's random values are consistent with Bwd.
// The inner (per-warp) encoding depends only on the gemm data types: fp16 A/B with
// fp32 C selects the first branch, every other combination the second.
125  constexpr auto randval_block_inner_part_dstr_encoding = []() {
126  if constexpr(std::is_same_v<typename BlockGemm::ADataType, half_t> &&
127  std::is_same_v<typename BlockGemm::BDataType, half_t> &&
128  std::is_same_v<typename BlockGemm::CDataType, float>)
129  {
131  }
132  else
133  {
135  }
136  }();
137 
138  constexpr auto randval_block_part_dstr_encode =
139  detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
140  randval_block_inner_part_dstr_encoding);
141 
142  return make_static_tile_distribution(randval_block_part_dstr_encode);
143  }
144 
// Register distribution used to read random bytes back from LDS. The inner encoding is
// the warp gemm's own C-tile encoding (WG::CWarpDstrEncoding), so the loaded bytes line
// up with the layout of the P tile they are applied to.
// NOTE(review): the signature line (listing line 146) and the outer-encoding arguments
// (lines 159-162) are missing from this rendering.
145  template <typename BlockGemm>
147  {
148  constexpr auto config =
149  BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
150  using WG = remove_cvref_t<decltype(config.template at<0>())>;
151  constexpr index_t MWarp = config.template at<1>();
152  constexpr index_t NWarp = config.template at<2>();
153 
154  constexpr index_t MIterPerWarp = 1;
155  constexpr index_t NIterPerWarp = 1;
156 
157  constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding<
158  sequence<>,
163  sequence<0, 0>>{};
164 
165  constexpr auto randval_block_part_dstr_encode =
166  detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
167  typename WG::CWarpDstrEncoding{});
168 
169  return make_static_tile_distribution(randval_block_part_dstr_encode);
170  }
171 
// Applies dropout in place to the P block held in p_compute.
//   randval_ptr         : LDS scratch (reinterpreted as uint8_t*) used to move random
//                         bytes from the generation layout to the P layout.
//   start_n0_idx        : column index of this block; with the window's row origin it
//                         forms the Philox subsequence key of every warp tile.
//   p_compute           : element kept and multiplied by rp_undrop when its random
//                         byte <= p_undrop_in_uint8_t, otherwise set to 0.
//   randval_dram_window : written and moved only when is_store_randval is true.
// Each (i_m0, i_n0) step draws 16 bytes per lane via get_random_16x8.
// NOTE(review): listing line 187 is missing from this rendering. BlockGemmShape is used
// right below, so it presumably declares that alias -- confirm.
172  template <typename BlockGemm,
173  typename PComputeDataType,
174  typename RandValOutputDataType,
175  typename PComputeWindow,
176  typename RandValDramWindow>
177  CK_TILE_HOST_DEVICE void Run(void* randval_ptr,
178  const index_t start_n0_idx,
179  PComputeWindow& p_compute,
180  RandValDramWindow& randval_dram_window) const
181  {
182  constexpr auto config =
183  BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
184  using WG = remove_cvref_t<decltype(config.template at<0>())>;
185  constexpr index_t MWarp = config.template at<1>();
186  constexpr index_t NWarp = config.template at<2>();
188  constexpr index_t kMPerBlock = BlockGemmShape::kM;
189  constexpr index_t kNPerBlock = BlockGemmShape::kN;
190  constexpr index_t kMPerStep = MWarp * WG::kM;
191  constexpr index_t kNPerStep = NWarp * WG::kN;
192 
193  // randval tile in LDS
194  auto randval_lds = make_tensor_view<address_space_enum::lds>(
195  reinterpret_cast<uint8_t*>(randval_ptr), MakeRandValLdsBlockDescriptor<BlockGemm>());
196 
197  auto randval_lds_window = make_tile_window(
198  randval_lds, MakeRandValLdsBlockDescriptor<BlockGemm>().get_lengths(), {0, 0});
199 
200  // register distribute
201  auto randval_dist_generated =
202  make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
// One get_random_16x8 call must fill exactly this lane's generated elements.
203  static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
204 
205  auto randval_lds_read_window =
206  make_tile_window(randval_lds_window.get_bottom_tensor_view(),
207  randval_lds_window.get_window_lengths(),
208  randval_lds_window.get_window_origin(),
209  MakeRandValLdsShuffleTileDistribution<BlockGemm>());
210 
// The row key is taken from the DRAM window's current origin. It is read even when
// is_store_randval is false, so the caller must position the window in both cases.
211  const int start_m0_idx = randval_dram_window.get_window_origin().at(number<0>{});
// Store path: generates bytes with the same ph and the same rowcol formula as the apply
// loop below (presumably identical values -- Philox is counter-based). It writes them to
// global memory and finally leaves the window advanced by kNPerBlock columns.
212  if(is_store_randval)
213  {
214  static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
215  static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
// Warp-tile coordinates in units of WG::kM rows / WG::kN columns; packed into a uint2
// and reinterpreted as the 64-bit Philox subsequence.
216  int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id();
217  int block_col_start = (start_n0_idx / WG::kN) + i_n0;
218  uint2 rowcol = make_uint2(block_row_start, block_col_start);
219 
220  // generate random number
221  uint8_t random_uint8_t[16];
222  ph.get_random_16x8(random_uint8_t,
223  reinterpret_cast<unsigned long long&>(rowcol));
224 
225  constexpr auto randval_dist_generated_spans =
226  decltype(randval_dist_generated)::get_distributed_spans();
227  int i_random_idx = 0;
228  sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) {
229  sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) {
230  constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1);
231  randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
232  });
233  });
234  // save to LDS
235  store_tile(randval_lds_window, randval_dist_generated);
236  block_sync_lds();
237  // read from LDS to register
238  auto randval = load_tile(randval_lds_read_window);
239  // save to Global
240  const auto randval_store = cast_tile<RandValOutputDataType>(randval);
241  store_tile(randval_dram_window, randval_store);
// Next N step.
242  move_tile_window(randval_dram_window, {0, kNPerStep});
243  });
// Next M step; rewind the N position.
244  move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
245  });
// Rewind M and advance N by a whole block, ready for the next call.
246  move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock});
// NOTE(review): the `;` after this brace is an empty statement (harmless, may trigger
// -Wextra-semi).
247  };
// Apply path: regenerate the bytes, shuffle them through LDS and apply dropout to P.
248  static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
249  static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
250  int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id();
251  int block_col_start = (start_n0_idx / WG::kN) + i_n0;
252  uint2 rowcol = make_uint2(block_row_start, block_col_start);
253 
254  // generate random number
255  uint8_t random_uint8_t[16];
256  ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));
257 
258  constexpr auto randval_dist_generated_spans =
259  decltype(randval_dist_generated)::get_distributed_spans();
260  int i_random_idx = 0;
261  sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) {
262  sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) {
263  constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1);
264  randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
265  });
266  });
267  // save to LDS
268  store_tile(randval_lds_window, randval_dist_generated);
269  block_sync_lds();
270  // read from LDS to register
271  auto randval = load_tile(randval_lds_read_window);
272  constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
273  sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
274  sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
// Map the randval element to P's distributed index: the M index is the current
// step i_m0; the N index combines step i_n0 with the randval N sub-indices.
275  constexpr auto p_idx0 = tile_distributed_index<i_m0>{};
276  constexpr auto p_idx1 =
277  tile_distributed_index<i_n0, idx1.impl_.at(1), idx1.impl_.at(2)>{};
278  constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
279  constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
// Keep (scaled by rp_undrop) when the byte is within the keep threshold, else zero.
280  p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
281  ? p_compute[p_idx] * rp_undrop
282  : PComputeDataType(0);
283  });
284  });
285  });
286  });
287  }
288 
// Data members. Listing lines 289 (`ck_tile::philox ph;`) and 291
// (`const uint8_t p_undrop_in_uint8_t;`) are missing from this rendering; both are
// named in the tooltip index.
// rp_undrop: multiplier applied to kept P elements in Run.
290  const float rp_undrop;
// is_store_randval: runtime switch; Run writes random bytes to global memory only if set.
292  const bool is_store_randval;
293 };
294 
// Primary template of the backward-pass dropout policy. The declaration line itself
// (listing line 296) is missing from this rendering -- presumably
// `struct BlockDropoutBwd;`, confirm. Only the partial specializations for
// IsDropout_ == false and IsDropout_ == true are defined below.
295 template <bool IsDropout_, bool IsWG32_, bool IsStoreRandval_>
297 
// Backward-pass policy with dropout disabled.
// NOTE(review): the return statement of MakeRandvalDramWindow (listing line 312) was
// dropped by the renderer. It is presumably a null tile window, as in the forward no-op
// policy -- confirm.
298 template <bool IsWG32_, bool IsStoreRandval_>
299 struct BlockDropoutBwd<false, IsWG32_, IsStoreRandval_>
300 {
301  static constexpr bool IsDropout = false;
302  static constexpr bool IsStoreRandval = IsStoreRandval_;
303 
// Both arguments are cast to void, i.e. deliberately unused.
304  template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
305  __host__ __device__ static constexpr auto
306  MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
307  index_t seqlen_qk_start)
308  {
309  (void)randval_dram_block_window_tmp;
310  (void)seqlen_qk_start;
311 
313  }
314 };
315 
316 template <bool IsWG32_, bool IsStoreRandval_>
317 struct BlockDropoutBwd<true, IsWG32_, IsStoreRandval_>
318 {
319  static constexpr bool IsDropout = true;
320  // true: 32*32 warp gemm
321  // false: 16*16 warp gemm
322  static constexpr bool IsWG32 = IsWG32_;
// Compile-time switch: Run casts and writes random bytes to global memory only when true.
323  static constexpr bool IsStoreRandval = IsStoreRandval_;
324 
// Constructor. Its first line (listing line 325) is missing from this rendering. Per the
// tooltip index, the full signature is
//   BlockDropoutBwd(index_t i_batch, index_t i_head, index_t nheads,
//                   unsigned long long seed, unsigned long long offset,
//                   float rp_undrop_, uint8_t p_undrop_in_uint8_t_)
// Unlike the forward BlockDropout, there is no runtime store flag; storing is controlled
// by IsStoreRandval.
// Philox per-lane offset term:
//   IsWG32  : lane_id
//   !IsWG32 : (lane_id & 47) + ((warp_id & 1) << 4)
//             47 == 0b101111 clears bit 4 of the lane id, and bit 0 of the warp id is
//             substituted for it. Lanes l and l ^ 16 of one warp therefore share an
//             offset, and an even/odd warp pair covers both values of that bit.
//             Presumably this reproduces the 32x32 lane pattern with 16x16 warp gemms
//             -- confirm.
// NOTE(review): (i_batch * nheads + i_head) * get_warp_size() is computed in 32-bit
// index_t before being added to the 64-bit `offset`.
326  index_t i_head,
327  index_t nheads,
328  unsigned long long seed,
329  unsigned long long offset,
330  float rp_undrop_,
331  uint8_t p_undrop_in_uint8_t_)
332  : ph(seed,
333  offset + (i_batch * nheads + i_head) * get_warp_size() +
334  (IsWG32 ? get_lane_id() : ((get_lane_id() & 47) + ((get_warp_id() & 1) << 4)))),
335  rp_undrop(rp_undrop_),
336  p_undrop_in_uint8_t(p_undrop_in_uint8_t_)
337  {
338  }
339 
// Builds the global-memory window for random values; origin selection is the same as in
// the forward policy:
//   IsFwd  : (row origin of the incoming window, seqlen_qk_start)
//   !IsFwd : (seqlen_qk_start, column origin of the incoming window)
// For the backward pass with 16x16 warp gemms and kMPerBlock > 16, the M step is doubled
// (two warp-gemm rows per step).
// NOTE(review): listing lines 347 (presumably the BlockGemmShape alias, used on line 349)
// and the window-length arguments (lines 371, 378) were dropped by the renderer --
// confirm.
340  template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
341  CK_TILE_HOST_DEVICE static constexpr auto
342  MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
343  index_t seqlen_qk_start)
344  {
345  constexpr auto config =
346  BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
348  using WG = remove_cvref_t<decltype(config.template at<0>())>;
349  constexpr index_t kMPerBlock = BlockGemmShape::kM;
350  constexpr index_t MWarp = config.template at<1>();
351  constexpr index_t NWarp = config.template at<2>();
352  constexpr bool MBwdWG16MultiIterCheck = (!IsFwd) && (!IsWG32) && (kMPerBlock > 16);
353  constexpr index_t kMPerStep = [&]() {
354  if constexpr(MBwdWG16MultiIterCheck)
355  {
356  return MWarp * WG::kM * 2;
357  }
358  else
359  {
360  return MWarp * WG::kM;
361  }
362  }();
363  constexpr index_t kNPerStep = NWarp * WG::kN;
364 
365  const auto block_origin = randval_dram_block_window_tmp.get_window_origin();
366  auto randval_dram_window = [&]() {
367  if constexpr(IsFwd)
368  {
369  return make_tile_window(
370  randval_dram_block_window_tmp.get_bottom_tensor_view(),
372  {block_origin.at(number<0>{}), seqlen_qk_start}); // M/N
373  }
374  else
375  {
376  return make_tile_window(
377  randval_dram_block_window_tmp.get_bottom_tensor_view(),
379  {seqlen_qk_start, block_origin.at(number<1>{})}); // M/N
380  }
381  }();
382 
383  return randval_dram_window;
384  }
385 
// LDS descriptor for one step of random bytes (kMPerStep rows x WG::kN columns). The
// visible text is identical to the forward policy's version.
// NOTE(review): the signature line (listing line 387), the lengths tuple (line 399) and
// the transform arguments (lines 406-410) are missing from this rendering.
386  template <typename BlockGemm>
388  {
389  constexpr auto config =
390  BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
391  using WG = remove_cvref_t<decltype(config.template at<0>())>;
392  constexpr index_t MWarp = config.template at<1>();
393  constexpr index_t kMPerStep = MWarp * WG::kM;
394  constexpr index_t kNPerStep = WG::kN;
395  constexpr index_t kN1 = 8;
396  constexpr index_t kN0 = kNPerStep / kN1;
397 
// Outermost stride is (kMPerStep + 1) * kN1: one kN1-byte chunk of padding, presumably
// to reduce LDS bank conflicts -- confirm.
398  constexpr auto randval_lds_block_desc_0 = make_naive_tensor_descriptor(
400  ck_tile::make_tuple(number<(kMPerStep + 1) * kN1>{}, number<kN1>{}, number<1>{}),
401  number<kN1>{},
402  number<1>{});
403 
404  constexpr auto randval_lds_block_desc = transform_tensor_descriptor(
405  randval_lds_block_desc_0,
411 
412  return randval_lds_block_desc;
413  }
414 
// Register distribution for random bytes. For the backward pass with 16x16 warp gemms
// and kMPerBlock > 16, each warp covers two M iterations (MIterPerWarp = 2); otherwise
// one. The inner encoding is selected by gemm data types and IsWG32.
// NOTE(review): the signature line (listing line 416), line 420 (presumably the
// BlockGemmShape alias used on line 421), the outer-encoding arguments (lines 440-443)
// and the four selected inner encodings (lines 454, 456, 461, 463) are missing from this
// rendering.
415  template <typename BlockGemm, bool IsFwd = true>
417  {
418  constexpr auto config =
419  BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
421  constexpr index_t kMPerBlock = BlockGemmShape::kM;
422  constexpr index_t MWarp = config.template at<1>();
423  constexpr index_t NWarp = config.template at<2>();
424  constexpr bool MBwdWG16MultiIterCheck = (!IsFwd) && (!IsWG32) && (kMPerBlock > 16);
425 
426  constexpr index_t MIterPerWarp = [&]() {
427  if constexpr(MBwdWG16MultiIterCheck)
428  {
429  return 2;
430  }
431  else
432  {
433  return 1;
434  }
435  }();
436  constexpr index_t NIterPerWarp = 1;
437 
438  constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding<
439  sequence<>,
444  sequence<0, 0>>{};
445 
446  // Use Bwd WarpGemm to ensure that Fwd's random values are consistent with Bwd,
447  // except for head dim 256.
448  constexpr auto randval_block_inner_part_dstr_encoding = []() {
449  if constexpr(std::is_same_v<typename BlockGemm::ADataType, half_t> &&
450  std::is_same_v<typename BlockGemm::BDataType, half_t> &&
451  std::is_same_v<typename BlockGemm::CDataType, float>)
452  {
453  if constexpr(IsWG32)
455  else
457  }
458  else
459  {
460  if constexpr(IsWG32)
462  else
464  }
465  }();
466 
467  constexpr auto randval_block_part_dstr_encode =
468  detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
469  randval_block_inner_part_dstr_encoding);
470 
471  return make_static_tile_distribution(randval_block_part_dstr_encode);
472  }
473 
// Register distribution for reading random bytes back from LDS. Its inner encoding is
// the warp gemm's C-tile encoding (WG::CWarpDstrEncoding), matching the P tile layout.
// NOTE(review): the signature line (listing line 475) and the outer-encoding arguments
// (lines 488-491) are missing from this rendering.
474  template <typename BlockGemm>
476  {
477  constexpr auto config =
478  BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
479  using WG = remove_cvref_t<decltype(config.template at<0>())>;
480  constexpr index_t MWarp = config.template at<1>();
481  constexpr index_t NWarp = config.template at<2>();
482 
483  constexpr index_t MIterPerWarp = 1;
484  constexpr index_t NIterPerWarp = 1;
485 
486  constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding<
487  sequence<>,
492  sequence<0, 0>>{};
493 
494  constexpr auto randval_block_part_dstr_encode =
495  detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
496  typename WG::CWarpDstrEncoding{});
497 
498  return make_static_tile_distribution(randval_block_part_dstr_encode);
499  }
500 
// LDS-shuffled dropout: same generation/apply scheme as the forward BlockDropout::Run.
// Differences: the row origin is passed explicitly as start_m0_idx, and the random
// bytes are stored in the same loop (no second generation pass). The store is gated at
// compile time by IsStoreRandval.
//   p_compute : element kept and multiplied by rp_undrop when its random byte
//               <= p_undrop_in_uint8_t, otherwise set to 0.
// Window movement (only when IsStoreRandval): +kNPerStep columns per N step, +kMPerStep
// rows per M step with an N rewind, then -kMPerBlock rows / +kNPerBlock columns at the end.
// NOTE(review): listing line 517 is missing from this rendering; BlockGemmShape is used
// right below, so it presumably declares that alias -- confirm.
501  template <typename BlockGemm,
502  typename PComputeDataType,
503  typename RandValOutputDataType,
504  typename PComputeWindow,
505  typename RandValDramWindow>
506  CK_TILE_HOST_DEVICE void Run(void* randval_ptr,
507  const index_t start_m0_idx,
508  const index_t start_n0_idx,
509  PComputeWindow& p_compute,
510  RandValDramWindow& randval_dram_window) const
511  {
512  constexpr auto config =
513  BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
514  using WG = remove_cvref_t<decltype(config.template at<0>())>;
515  constexpr index_t MWarp = config.template at<1>();
516  constexpr index_t NWarp = config.template at<2>();
518  constexpr index_t kMPerBlock = BlockGemmShape::kM;
519  constexpr index_t kNPerBlock = BlockGemmShape::kN;
520  constexpr index_t kMPerStep = MWarp * WG::kM;
521  constexpr index_t kNPerStep = NWarp * WG::kN;
522 
523  // randval tile in LDS
524  auto randval_lds = make_tensor_view<address_space_enum::lds>(
525  reinterpret_cast<uint8_t*>(randval_ptr), MakeRandValLdsBlockDescriptor<BlockGemm>());
526 
527  auto randval_lds_window = make_tile_window(
528  randval_lds, MakeRandValLdsBlockDescriptor<BlockGemm>().get_lengths(), {0, 0});
529 
530  // register distribute
531  auto randval_dist_generated =
532  make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
// One get_random_16x8 call must fill exactly this lane's generated elements.
533  static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
534 
535  auto randval_lds_read_window =
536  make_tile_window(randval_lds_window.get_bottom_tensor_view(),
537  randval_lds_window.get_window_lengths(),
538  randval_lds_window.get_window_origin(),
539  MakeRandValLdsShuffleTileDistribution<BlockGemm>());
540 
541  static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
542  static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
// Warp-tile coordinates (units of WG::kM rows / WG::kN columns) form the Philox
// subsequence key.
543  int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id();
544  int block_col_start = (start_n0_idx / WG::kN) + i_n0;
545  uint2 rowcol = make_uint2(block_row_start, block_col_start);
546 
547  // generate random number
548  uint8_t random_uint8_t[16];
549  ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));
550 
551  constexpr auto randval_dist_generated_spans =
552  decltype(randval_dist_generated)::get_distributed_spans();
553  int i_random_idx = 0;
554  sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) {
555  sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) {
556  constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1);
557  randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
558  });
559  });
560  // save to LDS
561  store_tile(randval_lds_window, randval_dist_generated);
562  block_sync_lds();
563  // read from LDS to register
564  auto randval = load_tile(randval_lds_read_window);
565  constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
566  sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
567  sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
568  constexpr auto p_idx0 = tile_distributed_index<i_m0>{};
569  constexpr auto p_idx1 =
570  tile_distributed_index<i_n0, idx1.impl_.at(1), idx1.impl_.at(2)>{};
571  constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
572  constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
// Keep (scaled by rp_undrop) when within the keep threshold, else zero.
573  p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
574  ? p_compute[p_idx] * rp_undrop
575  : PComputeDataType(0);
576  });
577  });
578  // save to Global
579  if constexpr(IsStoreRandval)
580  {
581  const auto randval_store = cast_tile<RandValOutputDataType>(randval);
582  store_tile(randval_dram_window, randval_store);
583  move_tile_window(randval_dram_window, {0, kNPerStep});
584  }
585  });
586  if constexpr(IsStoreRandval)
587  {
588  move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
589  }
590  });
591  if constexpr(IsStoreRandval)
592  {
593  move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock});
594  }
595  }
596 
597  template <typename BlockGemm,
598  typename RandValOutputDataType,
599  typename PComputeWindow,
600  typename RandValDramWindow>
601  CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx,
602  const index_t start_n0_idx,
603  PComputeWindow& p_compute,
604  RandValDramWindow& randval_dram_window) const
605  {
606  constexpr auto config =
607  BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
608  using WG = remove_cvref_t<decltype(config.template at<0>())>;
609  constexpr index_t MWarp = config.template at<1>();
610  constexpr index_t NWarp = config.template at<2>();
612  constexpr index_t kMPerBlock = BlockGemmShape::kM;
613  constexpr index_t kNPerBlock = BlockGemmShape::kN;
614  constexpr bool MBwdWG16MultiIterCheck = (!IsWG32) && (kMPerBlock > 16);
615  constexpr bool MBwdWG16SingleIterCheck = (!IsWG32) && (kMPerBlock == 16);
616  constexpr index_t kMPerStep = [&]() {
617  if constexpr(MBwdWG16MultiIterCheck)
618  {
619  return MWarp * WG::kM * 2;
620  }
621  else
622  {
623  return MWarp * WG::kM;
624  }
625  }();
626  constexpr index_t kNPerStep = NWarp * WG::kN;
627 
628  // register distribute
629  auto randval = make_static_distributed_tensor<uint8_t>(
630  MakeRandValTileDistribution<BlockGemm, false>());
631  if constexpr(IsWG32)
632  static_assert(randval.kThreadElementSpaceSize == 16);
633  else
634  static_assert(randval.kThreadElementSpaceSize == 4 ||
635  randval.kThreadElementSpaceSize == 8);
636 
637  static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
638  static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
639  int block_row_start, block_col_start;
640  if constexpr(IsWG32)
641  {
642  block_row_start = (start_m0_idx / WG::kM) + i_m0;
643  block_col_start = (start_n0_idx / WG::kN) + (i_n0 * NWarp) + get_warp_id();
644  }
645  else
646  {
647  block_row_start = start_m0_idx / 32 + i_m0;
648  block_col_start = (start_n0_idx / 32) + get_warp_id() / 2 + i_n0 * 2;
649  }
650  uint2 rowcol = make_uint2(block_row_start, block_col_start);
651 
652  // generate random number
653  uint8_t* random_uint8_t_;
654  if constexpr(MBwdWG16SingleIterCheck)
655  {
656  uint8_t random_uint8_t[4];
657  // m0t0 ~m0t15/m0t32~m0t47: 0
658  // m0t16~m0t31/m0t48~m0t63: 1
659  // m1t0 ~m1t15/m1t32~m1t47: 2
660  // m1t16~m1t31/m1t48~m1t63: 3
661  const index_t start_idx =
662  ((get_lane_id() >> 4) & 1) + (((start_m0_idx >> 4) & 1) << 1);
663  ph.get_random_4x8(
664  random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol), start_idx);
665  random_uint8_t_ = random_uint8_t;
666  }
667  else if constexpr(MBwdWG16MultiIterCheck)
668  {
669  uint8_t random_uint8_t[8];
670  // t0 ~t15/t32~t47: 0
671  // t16~t31/t48~t63: 1
672  const index_t start_idx = (get_lane_id() >> 4) & 1;
673  ph.get_random_8x8(
674  random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol), start_idx);
675  random_uint8_t_ = random_uint8_t;
676  }
677  else
678  {
679  uint8_t random_uint8_t[16];
680  ph.get_random_16x8(random_uint8_t,
681  reinterpret_cast<unsigned long long&>(rowcol));
682  random_uint8_t_ = random_uint8_t;
683  }
684 
685  constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
686  int i_random_idx = 0;
687  sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
688  sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
689  constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
690  randval(r_idx) = random_uint8_t_[i_random_idx++];
691  constexpr auto p_idx0 = tile_distributed_index<i_m0 + idx0.impl_.at(0),
692  idx0.impl_.at(1),
693  idx0.impl_.at(2)>{};
694  constexpr auto p_idx1 = tile_distributed_index<i_n0>{};
695  constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
696  p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
697  ? p_compute[p_idx]
698  : -p_compute[p_idx];
699  });
700  });
701  // save to Global
702  if constexpr(IsStoreRandval)
703  {
704  const auto randval_store = cast_tile<RandValOutputDataType>(randval);
705  store_tile(randval_dram_window, randval_store);
706  move_tile_window(randval_dram_window, {kMPerStep, 0});
707  }
708  });
709  if constexpr(IsStoreRandval)
710  {
711  move_tile_window(randval_dram_window, {-kMPerBlock, kNPerStep});
712  }
713  });
714  if constexpr(IsStoreRandval)
715  {
716  move_tile_window(randval_dram_window, {kMPerBlock, -kNPerBlock});
717  }
718  }
719 
// Data members. Listing lines 720 (`ck_tile::philox ph;`) and 722
// (`const uint8_t p_undrop_in_uint8_t;`) are missing from this rendering; both are
// named in the tooltip index.
// rp_undrop: multiplier applied to kept P elements in the LDS-based Run.
721  const float rp_undrop;
723 };
724 
725 } // namespace ck_tile
CK_TILE_DEVICE void block_sync_lds()
Definition: arch.hpp:190
Definition: philox_rand.hpp:12
CK_TILE_HOST_DEVICE void get_random_16x8(uint8_t *out, const unsigned long long subsequence) const
Definition: philox_rand.hpp:42
CK_TILE_HOST_DEVICE void get_random_8x8(uint8_t *out, const unsigned long long subsequence, const index_t start_idx) const
Definition: philox_rand.hpp:56
CK_TILE_HOST_DEVICE void get_random_4x8(uint8_t *out, const unsigned long long subsequence, const index_t start_idx) const
Definition: philox_rand.hpp:73
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
constexpr CK_TILE_HOST_DEVICE auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition: tile_distribution_encoding.hpp:457
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:66
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition: sweep_tile.hpp:20
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:95
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:23
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:22
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:480
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:42
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
unsigned char uint8_t
Definition: stdint.h:124
__host__ static constexpr __device__ auto MakeRandvalDramWindow(RandValDramBlockWindowTmp &randval_dram_block_window_tmp, index_t seqlen_qk_start)
Definition: block_dropout.hpp:306
CK_TILE_HOST_DEVICE BlockDropoutBwd(index_t i_batch, index_t i_head, index_t nheads, unsigned long long seed, unsigned long long offset, float rp_undrop_, uint8_t p_undrop_in_uint8_t_)
Definition: block_dropout.hpp:325
static constexpr CK_TILE_HOST_DEVICE auto MakeRandValLdsBlockDescriptor()
Definition: block_dropout.hpp:387
static constexpr CK_TILE_HOST_DEVICE auto MakeRandvalDramWindow(RandValDramBlockWindowTmp &randval_dram_block_window_tmp, index_t seqlen_qk_start)
Definition: block_dropout.hpp:342
static constexpr CK_TILE_HOST_DEVICE auto MakeRandValLdsShuffleTileDistribution()
Definition: block_dropout.hpp:475
const uint8_t p_undrop_in_uint8_t
Definition: block_dropout.hpp:722
static constexpr CK_TILE_HOST_DEVICE auto MakeRandValTileDistribution()
Definition: block_dropout.hpp:416
ck_tile::philox ph
Definition: block_dropout.hpp:720
CK_TILE_HOST_DEVICE void Run(void *randval_ptr, const index_t start_m0_idx, const index_t start_n0_idx, PComputeWindow &p_compute, RandValDramWindow &randval_dram_window) const
Definition: block_dropout.hpp:506
CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx, const index_t start_n0_idx, PComputeWindow &p_compute, RandValDramWindow &randval_dram_window) const
Definition: block_dropout.hpp:601
const float rp_undrop
Definition: block_dropout.hpp:721
Definition: block_dropout.hpp:296
Definition: block_dropout.hpp:26
const uint8_t p_undrop_in_uint8_t
Definition: block_dropout.hpp:291
CK_TILE_HOST_DEVICE BlockDropout(index_t i_batch, index_t i_head, index_t nheads, unsigned long long seed, unsigned long long offset, float rp_undrop_, uint8_t p_undrop_in_uint8_t_, bool is_store_randval_)
Definition: block_dropout.hpp:27
ck_tile::philox ph
Definition: block_dropout.hpp:289
const float rp_undrop
Definition: block_dropout.hpp:290
const bool is_store_randval
Definition: block_dropout.hpp:292
static constexpr CK_TILE_HOST_DEVICE auto MakeRandvalDramWindow(RandValDramBlockWindowTmp &randval_dram_block_window_tmp, index_t seqlen_qk_start)
Definition: block_dropout.hpp:44
CK_TILE_HOST_DEVICE void Run(void *randval_ptr, const index_t start_n0_idx, PComputeWindow &p_compute, RandValDramWindow &randval_dram_window) const
Definition: block_dropout.hpp:177
static constexpr CK_TILE_HOST_DEVICE auto MakeRandValTileDistribution()
Definition: block_dropout.hpp:106
static constexpr CK_TILE_HOST_DEVICE auto MakeRandValLdsShuffleTileDistribution()
Definition: block_dropout.hpp:146
static constexpr CK_TILE_HOST_DEVICE auto MakeRandValLdsBlockDescriptor()
Definition: block_dropout.hpp:77
Definition: block_dropout.hpp:12
__host__ static constexpr __device__ auto MakeRandvalDramWindow(RandValDramBlockWindowTmp &randval_dram_block_window_tmp, index_t seqlen_qk_start)
Definition: block_dropout.hpp:15
typename WarpGemmAttribute::CWarpDstrEncoding CWarpDstrEncoding
Definition: warp_gemm_impl.hpp:30
Definition: integral_constant.hpp:13
Definition: coordinate_transform.hpp:1392
Definition: sequence.hpp:49
Definition: functional.hpp:43
Definition: tile_distribution.hpp:42
static constexpr auto impl_
Definition: tile_distribution.hpp:45
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192