Composable Kernel: include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

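// This header provides the block-level building blocks for mean/variance (norm reduce):
//  * BlockNormReduce              - per-thread accumulation of mean/var over a distributed tile,
//                                   either as plain sum / sum-of-squares or as a Welford update
//  * BlockNormReduceSync          - intra-warp reduction of the partial results via lane shuffles
//  * BlockNormReduceCrossWarpSync - cross-warp reduction of the partial results through LDS
//  * block_tile_welford_calculate_max_count / block_tile_welford_post_scale_var - helpers for
//    padded rows and for turning the accumulated value into a variance
// welford_update / welford_merge used below are the thread-level helpers from thread_welford.hpp.
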
template <typename Problem_, typename Policy_ = void>
struct BlockNormReduce
{
    using Problem         = remove_cvref_t<Problem_>;
    using XDataType       = typename Problem::XDataType;
    using ComputeDataType = typename Problem::ComputeDataType;
    static constexpr bool kFastFDiv = Problem::kFastFDiv;
    static constexpr bool kWelford  = Problem::kWelford;

    CK_TILE_DEVICE constexpr BlockNormReduce() {}

    // [CAUTION] max_count_ deals with the padding problem.
    // max_count_ depends on the caller, e.g. naive and splitN norm_reduce compute
    // max_count_ differently
    // -> use block_tile_welford_calculate_max_count to compute it
    template <typename XDistributedTensor_,
              typename MeanDistributedTensor_,
              typename VarDistributedTensor_>
    CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
                                   MeanDistributedTensor_& mean_tensor,
                                   VarDistributedTensor_& var_tensor,
                                   int& cur_count_, // -> prefer init as zero
                                   const int& max_count_)
    {
        constexpr auto I0 = number<0>{};
        constexpr auto I1 = number<1>{};

        constexpr auto spans = XDistributedTensor_::get_distributed_spans();

        sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) {
            if(cur_count_ < max_count_)
            {
                ++cur_count_;
                sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) {
                    constexpr auto in_dstr_idx  = make_tuple(dstr_idx_i0, dstr_idx_i1);
                    constexpr auto out_dstr_idx = make_tuple(dstr_idx_i0);

                    auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);
                    if(kWelford)
                    {
                        welford_update(mean_tensor(out_dstr_idx),
                                       var_tensor(out_dstr_idx),
                                       x,
                                       cur_count_,
                                       bool_constant<kFastFDiv>{});
                    }
                    else
                    {
                        mean_tensor(out_dstr_idx) += x;
                        var_tensor(out_dstr_idx) += x * x;
                    }
                });
            }
        });
    }
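
    // For reference, the per-sample Welford recurrence applied above (a sketch of what
    // welford_update from thread_welford.hpp computes, with count already incremented):
    //   delta = x - mean;  mean += delta / count;  var += delta * (x - mean);
    // i.e. in Welford mode var_tensor accumulates the running M2 (sum of squared deviations),
    // which block_tile_welford_post_scale_var later divides by count to obtain the variance.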

    template <typename XDistributedTensor_>
    CK_TILE_DEVICE static auto MakeMeanVarBlockTile()
    {
        static_assert(std::is_same_v<XDataType, typename XDistributedTensor_::DataType>, "wrong!");

        constexpr auto reduce_dims = sequence<1>{};

        constexpr auto dstr =
            make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
                XDistributedTensor_::get_tile_distribution()
                    .get_static_tile_distribution_encoding(),
                reduce_dims));

        auto tensor = make_static_distributed_tensor<ComputeDataType>(dstr);

        return tensor;
    }

    template <typename XDistributedTensor_>
    CK_TILE_DEVICE auto
    operator()(const XDistributedTensor_& x_tensor, int& cur_count_, const int& max_count_)
    {
        auto mean_tensor = MakeMeanVarBlockTile<XDistributedTensor_>();
        auto var_tensor  = MakeMeanVarBlockTile<XDistributedTensor_>();
        clear_tile(mean_tensor);
        clear_tile(var_tensor);

        (*this)(x_tensor, mean_tensor, var_tensor, cur_count_, max_count_);

        return ck_tile::make_tuple(mean_tensor, var_tensor);
    }
};

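// A minimal usage sketch of the convenience overload above (MyProblem and x_tensor are
// illustrative placeholders, not part of this header):
//   int cur_count = 0;
//   auto mean_var = BlockNormReduce<MyProblem>{}(x_tensor, cur_count, max_count);
// It creates zero-initialized mean/var partial tiles via clear_tile, runs the accumulating
// operator() above, and returns both tiles as a tuple.
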
template <typename Problem_, typename Policy_ = void>
struct BlockNormReduceSync
{
    using Problem = remove_cvref_t<Problem_>;
    static constexpr bool kFastFDiv = Problem::kFastFDiv;
    static constexpr bool kWelford  = Problem::kWelford;

    template <typename MeanDistributedTensor_, typename VarDistributedTensor_>
    CK_TILE_DEVICE void
    operator()(MeanDistributedTensor_& mean_tensor, VarDistributedTensor_& var_tensor, int& count)
    {
        using Dstr             = typename MeanDistributedTensor_::StaticTileDistribution;
        using DstrEncode       = typename Dstr::DstrEncode;
        using DstrEncodeDetail = typename DstrEncode::detail;

        static_assert(std::is_same_v<Dstr, typename VarDistributedTensor_::StaticTileDistribution>,
                      "wrong!");

        constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
        constexpr index_t NDimR = Dstr::get_num_of_dimension_r();

        constexpr index_t idim_p_lane = NDimP - 1;

        // const auto ps_idx = make_array<index_t>(get_warp_id(), get_lane_id());
        // const auto rs_idx =
        //     mean_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);

        constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
        static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());

        const int original_count = count;

        // loop over thread data
        static_for<0, thread_buf_size, 1>{}([&](auto i) {
            auto v_local_mean  = mean_tensor.get_thread_buffer()[i];
            auto v_local_var   = var_tensor.get_thread_buffer()[i];
            auto v_local_count = original_count;

            // cross-lane reduce for replication:
            // only reduce over the R dimensions that the lane P dimension owns
            // (i.e. the R dimensions that lane id maps to)
            static_for<0, NDimR, 1>{}([&](auto idim_r) {
                // FIXME: nasty to use does_p_own_r_
                if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
                {
                    constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];

                    constexpr index_t lid_over_rid_derivative =
                        DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];

                    static_assert(is_power_of_two_integer(r_length),
                                  "wrong! only support power of 2 reduction");

                    constexpr index_t nstage = integer_log2_floor(r_length);

                    // reduction sweep forward
                    static_for<0, nstage, 1>{}([&](auto istage) {
                        // xor
                        index_t src_lane =
                            (__lane_id()) ^
                            (number<lid_over_rid_derivative << istage.value>{}.value);

                        // pull data from remote lane
                        const auto v_remote_mean = warp_shuffle(v_local_mean, src_lane);
                        const auto v_remote_var  = warp_shuffle(v_local_var, src_lane);
                        if(kWelford)
                        {
                            const auto v_remote_count = warp_shuffle(v_local_count, src_lane);

                            // norm_reduce merge
                            welford_merge(v_local_mean,
                                          v_local_var,
                                          v_local_count,
                                          v_remote_mean,
                                          v_remote_var,
                                          v_remote_count,
                                          bool_constant<kFastFDiv>{});
                        }
                        else
                        {
                            v_local_mean += v_remote_mean;
                            v_local_var += v_remote_var;
                        }
                    });
                }
            });

            mean_tensor.get_thread_buffer()(i) = v_local_mean;
            var_tensor.get_thread_buffer()(i)  = v_local_var;
            if(kWelford)
            {
                count = v_local_count;
            }
        });
    }
};

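// A worked example of the xor sweep above (illustrative numbers, not a specific config):
// suppose the lane-owned R length is 4 and lid_over_rid_derivative is 1; then
//   stage 0: lane i exchanges with lane i ^ 1 -> lanes {0,1} and {2,3} each hold a pairwise merge
//   stage 1: lane i exchanges with lane i ^ 2 -> all four lanes hold the full merge
// i.e. after log2(r_length) stages every participating lane carries the same mean/var/count.
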
template <typename Problem_, typename Policy_ = void>
struct BlockNormReduceCrossWarpSync
{
    using Problem    = remove_cvref_t<Problem_>;
    using BlockShape = typename Problem::BlockShape;
    static constexpr bool kFastFDiv = Problem::kFastFDiv;
    static constexpr bool kWelford  = Problem::kWelford;
    using smem_dtype = std::conditional_t<kWelford, fp32x4_t, fp32x2_t>;

    template <typename MeanDistributedTensor_>
    static constexpr CK_TILE_DEVICE index_t GetReduceWarps()
    {
        constexpr index_t num_reduce_warps = [&]() {
            using Dstr             = typename MeanDistributedTensor_::StaticTileDistribution;
            using DstrEncode       = typename Dstr::DstrEncode;
            using DstrEncodeDetail = typename DstrEncode::detail;

            constexpr index_t NDimR = Dstr::get_num_of_dimension_r();

            constexpr index_t idim_p_warp = 0;

            index_t len_ = 1;
            static_for<0, NDimR, 1>{}([&](auto idim_r) {
                if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_warp][idim_r])
                {
                    constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
                    len_ *= r_length;
                }
            });
            return len_;
        }();
        return num_reduce_warps;
    }

    // return in bytes
    template <typename MeanDistributedTensor_>
    static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
    {
        // constexpr auto num_reduce_warps = GetReduceWarps<MeanDistributedTensor_>();

        // the data to exchange is very small, so we just pack mean+var+count -> 4 dwords
        constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();

        // we need to store the data from every wave into smem
        // e.g. 2x2 waves, reduce along N
        // -------------> reduce N
        // | w0 | w1 |   ___>   | w01 |
        // | w2 | w3 |          | w23 |
        //
        // -> store data from every wave into LDS
        //
        // e.g. 1x4 waves, reduce along N
        // -------------> reduce N
        // | w0 | w1 | w2 | w3 |  ----->  | w0123 |
        //
        // -> also store data from every wave into LDS
        constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
        return num_warps * 4 * thread_buf_size * sizeof(float);
    }

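    // Worked example (illustrative numbers): with BlockSize = 256 on a 64-lane wave there are
    // 4 warps; for thread_buf_size = 2 this reserves 4 * 4 * 2 * sizeof(float) = 128 bytes.
    // Note that 4 floats per entry are reserved even in the non-Welford (fp32x2_t) case.
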
    template <typename MeanDistributedTensor_, typename VarDistributedTensor_>
    CK_TILE_DEVICE void operator()(MeanDistributedTensor_& mean_tensor,
                                   VarDistributedTensor_& var_tensor,
                                   int& count,
                                   void* smem)
    {
        using DataType = typename MeanDistributedTensor_::DataType;
        using Dstr     = typename MeanDistributedTensor_::StaticTileDistribution;
        // using DstrEncode       = typename Dstr::DstrEncode;
        // using DstrEncodeDetail = typename DstrEncode::detail;

        static_assert(std::is_same_v<Dstr, typename VarDistributedTensor_::StaticTileDistribution>,
                      "wrong!");

        constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
        static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());

        // Note: we always pack everything into fp32x4
        smem_dtype* smem_ptr  = reinterpret_cast<smem_dtype*>(smem);
        const index_t lane_id = get_lane_id();
        const index_t warp_id = get_warp_id();
        constexpr auto num_reduce_warps = GetReduceWarps<MeanDistributedTensor_>();
        constexpr index_t num_warps     = BlockShape::BlockSize / get_warp_size();
        const index_t smem_offset       = warp_id;

        // skip if there is nothing to do
        if constexpr(num_reduce_warps == 1)
            return;

        // store into smem only from lane-0 of each warp
        if(lane_id == 0)
        {
            static_for<0, thread_buf_size, 1>{}([&](auto i) {
                smem_dtype local_scratch_;
                local_scratch_[0] = bit_cast<float>(mean_tensor.get_thread_buffer()[i]);
                local_scratch_[1] = bit_cast<float>(var_tensor.get_thread_buffer()[i]);
                if(kWelford)
                {
                    local_scratch_[2] = bit_cast<float>(count);
                }
                smem_ptr[smem_offset + i * num_warps] = local_scratch_;
            });
        }
        block_sync_lds();

        // load from smem; here we let every thread do the compute :)
        index_t local_warp_id = warp_id / num_reduce_warps;
        index_t local_smem_os = local_warp_id * num_reduce_warps;
        smem_dtype all_scratch[thread_buf_size * num_reduce_warps];
        static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
            static_for<0, num_reduce_warps, 1>{}([&](auto i_1) {
                all_scratch[i_0 * num_reduce_warps + i_1] =
                    smem_ptr[i_0 * num_warps + local_smem_os + i_1];
            });
        });
        block_sync_lds(); // TODO: we don't need sync here

        // const int original_count = count;

        static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
            // TODO: use descriptor for this
            auto v_local      = all_scratch[i_0 * num_reduce_warps];
            auto v_local_mean = bit_cast<DataType>(v_local[0]);
            auto v_local_var  = bit_cast<DataType>(v_local[1]);
            int v_local_count = kWelford ? bit_cast<int>(v_local[2]) : 0;

            // further reduce mean/var
            static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) {
                constexpr auto i_1        = number<i_1_n1 + 1>{};
                const smem_dtype v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
                const auto v_remote_mean  = bit_cast<DataType>(v_remote[0]);
                const auto v_remote_var   = bit_cast<DataType>(v_remote[1]);
                if(kWelford)
                {
                    const auto v_remote_count = bit_cast<int>(v_remote[2]);

                    welford_merge(v_local_mean,
                                  v_local_var,
                                  v_local_count,
                                  v_remote_mean,
                                  v_remote_var,
                                  v_remote_count,
                                  bool_constant<kFastFDiv>{});
                }
                else
                {
                    v_local_mean += v_remote_mean;
                    v_local_var += v_remote_var;
                }
            });

            mean_tensor.get_thread_buffer()(i_0) = v_local_mean;
            var_tensor.get_thread_buffer()(i_0)  = v_local_var;
            if(kWelford)
                count = v_local_count;
        });
    }
};

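// LDS layout used above: lane 0 of warp w writes element i of its thread buffer to
// smem_ptr[w + i * num_warps]; warps are then grouped into sets of num_reduce_warps, so with
// e.g. num_warps = 4 and num_reduce_warps = 2, warps {0,1} merge into one result and warps
// {2,3} into another, matching the 2x2 diagram in GetSmemSize().
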
// compute the max count for a last-dim reduce
// everything may have vector/repeat, so the max count could be uneven
// TODO: specify which dim to compute and properly set the problem
// TODO: BlockShape reuses layernorm_fwd_shape :)
template <typename BlockShape>
CK_TILE_DEVICE constexpr index_t block_tile_welford_calculate_max_count(int row_size)
{
#if 0
    using S = BlockShape;
    index_t LastloopN = row_size % S::Block_N == 0 ? S::Block_N : row_size % S::Block_N;
    constexpr index_t NThread = S::WarpPerBlock_N * S::ThreadPerWarp_N;
    index_t iNLane = get_thread_id() % NThread;
    index_t iN0 = LastloopN / (S::Vector_N * S::ThreadPerWarp_N);
    index_t iN1 = (LastloopN % (S::Vector_N * S::ThreadPerWarp_N)) / S::Vector_N;
    index_t N2 = (LastloopN % (S::Vector_N * S::ThreadPerWarp_N)) % S::Vector_N;
    index_t iN3 = iNLane < iN1 ? S::Vector_N : iNLane == iN1 ? N2 : 0;
    return iN0 * S::Vector_N + iN3;
#endif
    using S_ = BlockShape;
    constexpr index_t ThreadsPerBlock_N = S_::WarpPerBlock_N * S_::ThreadPerWarp_N;

    // TODO: we always check the vector size; row_size needs to be evenly divisible by Vector_N
    const index_t element_per_row = row_size / S_::Vector_N;
    index_t lane_id_n             = get_thread_id() % ThreadsPerBlock_N;

    index_t cnt = 0;
    // TODO: Repeat_N cannot be too long, otherwise this is not good
    static_for<0, S_::Repeat_N, 1>{}([&](auto) {
        index_t _a = lane_id_n < element_per_row ? 1 : 0;
        cnt += _a;
        lane_id_n += ThreadsPerBlock_N;
    });
    return cnt * S_::Vector_N;
}

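// Worked example (illustrative numbers): with ThreadPerWarp_N = 64, WarpPerBlock_N = 1,
// Vector_N = 1, Repeat_N = 2 and row_size = 100 we get element_per_row = 100; a lane with
// lane_id_n = 40 checks positions 40 and 104, only 40 < 100, so it returns 1 * Vector_N = 1,
// while a lane with lane_id_n = 20 checks 20 and 84 and returns 2. This is the per-thread
// max_count_ that is fed into BlockNormReduce::operator().
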
// Note: this function must be called after all the computation
template <typename VarDistributedTensor_, bool FastFdiv_ = false>
CK_TILE_DEVICE constexpr void block_tile_welford_post_scale_var(VarDistributedTensor_& var_tensor,
                                                                int count,
                                                                bool_constant<FastFdiv_> = {})
{
    using DataType = typename VarDistributedTensor_::DataType;
    tile_elementwise_inout(
        [&count](auto& x) {
            if(FastFdiv_ && std::is_same_v<DataType, float>)
            {
                x = x * __builtin_amdgcn_rcpf(type_convert<DataType>(count));
            }
            else
            {
                x = x / type_convert<DataType>(count);
            }
        },
        var_tensor);
}
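
// A minimal end-to-end sketch of how these pieces typically compose (a hedged example; MyProblem,
// BlockShape, XTensor, x_tensor, smem and row_size are illustrative placeholders, not part of
// this header; smem should provide GetSmemSize() bytes of LDS):
//   int count = 0;
//   const int max_count = block_tile_welford_calculate_max_count<BlockShape>(row_size);
//   auto mean = BlockNormReduce<MyProblem>::MakeMeanVarBlockTile<XTensor>();
//   auto var  = BlockNormReduce<MyProblem>::MakeMeanVarBlockTile<XTensor>();
//   clear_tile(mean); clear_tile(var);
//   BlockNormReduce<MyProblem>{}(x_tensor, mean, var, count, max_count);  // per-thread accumulate
//   BlockNormReduceSync<MyProblem>{}(mean, var, count);                   // intra-warp shuffles
//   BlockNormReduceCrossWarpSync<MyProblem>{}(mean, var, count, smem);    // cross-warp via LDS
//   block_tile_welford_post_scale_var(var, count);                        // M2 -> variance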
} // namespace ck_tile