/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_batched_dropout_randval.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_batched_dropout_randval.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_batched_dropout_randval.hpp Source File
reference_batched_dropout_randval.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 #include <thread>
9 
10 namespace ck_tile {
11 
12 template <typename RandValOutputDataType>
13 CK_TILE_HOST void
15  index_t batch,
16  uint64_t drop_seed,
17  uint64_t drop_offset)
18 {
19  const index_t nhead = randval_b_m_n.mDesc.get_lengths()[0];
20  const index_t real_seqlen_q = randval_b_m_n.mDesc.get_lengths()[1];
21  const index_t real_seqlen_k = randval_b_m_n.mDesc.get_lengths()[2];
22 
23  static_assert(std::is_same_v<RandValOutputDataType, uint8_t>);
24 
25  // BlockDropout generates random numbers by 32x32 tiles. Even when warp gemm 16x16 is used, the
26  // order of values in the bigger 32x32 tile must be the same because fwd and bwd may use
27  // different warp gemms (16x16 or 32x32).
28  // To compute 32x32 tiles, WarpGemmMfmaF16F16F32M32N32K16SwizzleA is used. It is
29  // WarpGemmAttributeMfmaImplF16F16F32M32N32K8 with SFactor = 2 (swizzling factor).
30  // Matrix element to register mapping for WarpGemmAttributeMfmaImplF16F16F32M32N32K8:
31  // C i: (8 * floor(GPR_num / 4) % 32) + 4 * floor(lane / 32) + (GPR_num % 4)
32  // C j: (lane % 32)
33  // With SFactor = 2 it becomes:
34  // C i: (16 * floor(GPR_num / 8) % 32) + 8 * floor(lane / 32) + (GPR_num % 8)
35  // C j: (lane % 32)
36  // See ck_tile/ops/fmha/block/block_dropout.hpp for more details.
37 
38  // The number of Philox 4x32 results required to fill 32x32 tile of 8-bit values
39  constexpr index_t philox_per_tile = 64;
40  constexpr index_t warp_gemm_mn = 32;
41 
42  const index_t rows = integer_divide_ceil(real_seqlen_q, warp_gemm_mn);
43  const index_t cols = integer_divide_ceil(real_seqlen_k, warp_gemm_mn);
44 
45  auto f = [&](index_t i_h, index_t row, index_t col) {
46  uint2 rowcol = make_uint2(row, col);
47  for(index_t lane = 0; lane < philox_per_tile; lane++)
48  {
49  const uint64_t ph_head_offset = drop_offset + (batch * nhead + i_h) * philox_per_tile;
50  const index_t ph_offset = lane;
51  philox ph(drop_seed, ph_head_offset + ph_offset);
52 
53  uint8_t random_uint8_t[16];
54  ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));
55 
56  for(auto r = 0; r < 16; r++)
57  {
58  index_t i = (16 * (r / 8) % 32) + 8 * (lane / 32) + (r % 8);
59  index_t j = (lane % 32);
60  index_t m = row * warp_gemm_mn + i;
61  index_t n = col * warp_gemm_mn + j;
62 
63  if(m < real_seqlen_q && n < real_seqlen_k)
64  {
65  randval_b_m_n(i_h, m, n) = random_uint8_t[r];
66  }
67  }
68  }
69  };
70 
71  make_ParallelTensorFunctor(f, nhead, rows, cols)(std::thread::hardware_concurrency());
72 }
73 
74 } // namespace ck_tile
Definition: philox_rand.hpp:12
CK_TILE_HOST_DEVICE void get_random_16x8(uint8_t *out, const unsigned long long subsequence) const
Definition: philox_rand.hpp:42
#define CK_TILE_HOST
Definition: config.hpp:40
constexpr index_t philox_per_tile
Definition: block_dropout.hpp:35
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
CK_TILE_HOST void reference_batched_dropout_randval(HostTensor< RandValOutputDataType > &randval_b_m_n, index_t batch, uint64_t drop_seed, uint64_t drop_offset)
Definition: reference_batched_dropout_randval.hpp:14
int32_t index_t
Definition: integer.hpp:9
unsigned char uint8_t
Definition: stdint.h:124
unsigned __int64 uint64_t
Definition: stdint.h:136
const std::vector< std::size_t > & get_lengths() const
Definition: host_tensor.hpp:198
Definition: host_tensor.hpp:336
Descriptor mDesc
Definition: host_tensor.hpp:800