Composable Kernel: gridwise_normalization_splitk_1st.hpp
Source file: include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/utility/math.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

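// First stage of the split-K normalization pipeline. The K (reduce) dimension is
// partitioned across k_grid_size workgroups; each workgroup computes a partial
// mean, a partial variance, and an element count for its K slice with Welford's
// algorithm and writes them to (M, KBlock)-shaped grids, so that the
// second-stage kernel can merge the partial results per row of M.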
template <typename XDataType,
          typename ComputeDataType,
          typename MeanVarDataType,
          typename XGridDesc_M_K,
          typename MeanVarGridDesc_M_KBlock,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t XSrcVectorDim,
          index_t XSrcVectorSize>
struct GridwiseNormalizationSplitK1st
{
    static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
                      (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0);

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

    using ThreadBufferDimAccessOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    using ThreadClusterArrangeOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, XSrcVectorSize>;
    static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));

    using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
    static constexpr auto thread_buffer_desc_m_1 =
        make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, I1));

    using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{})));
    using ThreadReduceDstDesc_M =
        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

    using ThreadwiseWelford =
        ThreadwiseWelford<ComputeDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;

    using BlockwiseWelford = BlockwiseWelford<ComputeDataType,
                                              BlockSize,
                                              ThreadClusterLengths_M_K,
                                              ThreadClusterArrangeOrder,
                                              false>;

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    static constexpr index_t M_BlockTileSize     = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize     = KThreadClusterSize * KThreadSliceSize;
    static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize;

    static constexpr auto ThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};

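    // How many K elements a thread actually reduces. Interior K blocks are full,
    // so each of their threads owns KThreadSliceSize elements per block tile. The
    // rightmost K block sees the unpadded tail (kRaw vs. the padded length k) and
    // clamps every vector load against it. Worked example with hypothetical sizes
    // kRaw = 1000, k = 1024, kGridSize = 4: each left block gets
    // ceil(1024 / 4) = 256 columns, leaving 1000 - 3 * 256 = 232 for the
    // rightmost block.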
    __device__ static int
    GetKPerThread(int k, int kRaw, int kGridSize, int block_k_cluster_id, int thread_k_cluster_id)
    {
        bool is_rightmost_block = block_k_cluster_id == kGridSize - 1;

        if(is_rightmost_block)
        {
            int left_kPerBlock  = math::integer_divide_ceil(k, kGridSize);
            int kRightmostBlock = kRaw - left_kPerBlock * (kGridSize - 1);
            int kPerThread      = kRightmostBlock < K_BlockTileSize
                                      ? 0
                                      : KThreadSliceSize * (kRightmostBlock / K_BlockTileSize);
            int kPerBlockTail   = kRightmostBlock - kPerThread * KThreadClusterSize;

            if(kPerBlockTail > 0)
            {
                static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
                    int thread_max_len =
                        (thread_k_cluster_id + 1) * XSrcVectorSize + K_BlockTileStepSize * i;
                    int delta = math::clamp(thread_max_len - kPerBlockTail, 0, XSrcVectorSize);
                    kPerThread += XSrcVectorSize - delta;
                });
            }

            return kPerThread;
        }
        else
        {
            int kPerBlock = math::integer_divide_ceil(k, kGridSize);
            return KThreadSliceSize * (kPerBlock / K_BlockTileSize);
        }
    }

    // Calculate mean and variance by Welford along the K dimension
    __device__ static void Run(const XGridDesc_M_K& x_grid_desc_m_k,
                               const MeanVarGridDesc_M_KBlock& mean_var_grid_desc_m_kblock,
                               index_t num_k_block_tile_iteration,
                               const XDataType* const __restrict__ p_x_global,
                               MeanVarDataType* const p_mean_global,
                               MeanVarDataType* const p_variance_global,
                               int32_t* const p_welford_count_global)
    {
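        // Per-thread register staging: ThreadBufferNumber buffers, each holding
        // one MThreadSliceSize x XSrcVectorSize vector-load tile of x in VGPRs.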
        auto x_thread_buf = generate_tuple(
            [&](auto) {
                return StaticBuffer<AddressSpaceEnum::Vgpr,
                                    ComputeDataType,
                                    MThreadSliceSize * XSrcVectorSize,
                                    true>{};
            },
            Number<ThreadBufferNumber>{});

        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
            mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
            var_thread_buf;

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();

        const index_t k_grid_size        = mean_var_grid_desc_m_kblock.GetLength(I1);
        const index_t block_m_cluster_id = block_global_id / k_grid_size;
        const index_t block_k_cluster_id = block_global_id % k_grid_size;

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;

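        // Source window for the vectorized x loads: rows start at this block's M
        // tile plus the thread's M slice; columns start at this K block's
        // reduceSizePerBlock slice plus the thread's vector offset.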
        auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
                                                                  ComputeDataType,
                                                                  XGridDesc_M_K,
                                                                  decltype(thread_buffer_desc_m_k),
                                                                  ThreadBufferLengths_M_K,
                                                                  ThreadBufferDimAccessOrder,
                                                                  XSrcVectorDim,
                                                                  XSrcVectorSize,
                                                                  1,
                                                                  true>(
            x_grid_desc_m_k,
            make_multi_index(
                block_m_cluster_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                block_k_cluster_id * reduceSizePerBlock + thread_k_cluster_id * XSrcVectorSize));

        auto mean_var_count_store_index = make_multi_index(
            block_m_cluster_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
            block_k_cluster_id);

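        // Partial results land in column block_k_cluster_id of the (M, KBlock)
        // mean/variance grids, one column per K block for the merge stage.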
        auto threadwise_welford_mean_var_store =
            ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
                                               MeanVarDataType,
                                               decltype(thread_buffer_desc_m_1),
                                               MeanVarGridDesc_M_KBlock,
                                               PassThroughOp,
                                               ThreadBufferLengths_M_1,
                                               Sequence<0, 1>,
                                               1,
                                               1,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                mean_var_grid_desc_m_kblock, mean_var_count_store_index, PassThroughOp{});

        constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);

        const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_x_global, x_grid_desc_m_k.GetElementSpaceSize());

        auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_mean_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize());

        auto var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_variance_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize());

        auto threadwise_welford = ThreadwiseWelford();
        int kRaw = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
        threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k.GetLength(I1),
                                                      kRaw,
                                                      k_grid_size,
                                                      block_k_cluster_id,
                                                      thread_k_cluster_id);

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            mean_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
            var_thread_buf(I)  = type_convert<ComputeDataType>(0.0f);
        });

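        // ThreadwiseWelford applies the single-pass update per element x:
        //   count <- count + 1
        //   mean  <- mean + (x - mean) / count
        //   var   <- var + (x - mean_new) * (x - mean_old)
        // i.e. var accumulates the sum of squared deviations (M2), which stays
        // numerically stable without a second pass over the data.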
        for(index_t k = 0; k < num_k_block_tile_iteration; ++k)
        {
            static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
                threadwise_x_load.Run(x_grid_desc_m_k,
                                      x_global_val_buf,
                                      thread_buffer_desc_m_k,
                                      make_tuple(I0, I0),
                                      x_thread_buf(i));
                threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
                threadwise_welford.Run(x_thread_buf[i], mean_thread_buf, var_thread_buf);
            });
        }

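        // Cross-thread reduction: BlockwiseWelford merges the per-thread
        // (mean, var, count) triples across the K thread cluster through LDS
        // using the pairwise Welford combination, e.g. for partitions a and b:
        //   count = count_a + count_b
        //   mean  = mean_a + (mean_b - mean_a) * count_b / count
        //   var   = var_a + var_b + (mean_b - mean_a)^2 * count_a * count_b / count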
        int welford_count = 0;
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if constexpr(I > 0)
                block_sync_lds();

            int count = threadwise_welford.cur_count_;
            BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);

            // The value of count is the same for all I
            if constexpr(I == MThreadSliceSize - 1)
                welford_count = count;
        });

        if(thread_k_cluster_id == 0)
        {
            threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
                                                  make_tuple(I0, I0),
                                                  mean_thread_buf,
                                                  mean_var_grid_desc_m_kblock,
                                                  mean_global_val_buf);

            threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
                                                  make_tuple(I0, I0),
                                                  var_thread_buf,
                                                  mean_var_grid_desc_m_kblock,
                                                  var_global_val_buf);

            if(block_m_cluster_id == 0 && thread_m_cluster_id == 0)
                p_welford_count_global[block_k_cluster_id] = welford_count;
        }
    }
};

} // namespace ck