template <typename MeanVarDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename SaveMeanInvStdDataType,
          typename ComputeDataType,
          typename YElementwiseOperation,
          typename MeanVarGridDesc_M_KBlock,
          typename CountGridDesc_M_KBlock,
          typename XYGammaBetaGridDesc_M_K,
          typename SaveMeanInvStdGridDesc_M,
          // ... further parameters (XDataType, YDataType, thread-cluster/slice sizes and the
          // per-operand vector-dim/vector-size settings) are not shown in this listing ...
          index_t SaveMeanInvStdDstVectorSize>
static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
                  (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
              "Invalid thread slice sizes and/or vector sizes configuration, please check!");

static_assert((YDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) ||
                  (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
              "Invalid thread slice sizes and/or vector sizes configuration, please check!");

static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
              "Invalid thread slice sizes and/or save mean and inverse std vector sizes "
              "configuration, please check!");

static_assert(XSrcVectorSize == YDstVectorSize);
static_assert(XSrcVectorSize == GammaSrcVectorSize);
static_assert(XSrcVectorSize == BetaSrcVectorSize);
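// Worked example of the constraints above (illustrative values, not taken from this file):
// with XSrcVectorDim == 1, choosing KThreadSliceSize = 8 and XSrcVectorSize = 4 satisfies
// KThreadSliceSize % XSrcVectorSize == 0, while XSrcVectorSize = 3 would fail the assert.
// The last three asserts then force gamma, beta and y to use the same vector width as x,
// so every row-wise operand is accessed with identically sized vector loads/stores.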
__device__ static void Run(const MeanVarGridDesc_M_KBlock& mean_var_grid_desc_m_kblock,
                           const CountGridDesc_M_KBlock& count_grid_desc_m_kblock,
                           const XYGammaBetaGridDesc_M_K& x_grid_desc_m_k,
                           const XYGammaBetaGridDesc_M_K& gamma_grid_desc_m_k,
                           const XYGammaBetaGridDesc_M_K& beta_grid_desc_m_k,
                           const XYGammaBetaGridDesc_M_K& y_grid_desc_m_k,
                           const SaveMeanInvStdGridDesc_M& save_mean_grid_desc_m,
                           const SaveMeanInvStdGridDesc_M& save_inv_std_grid_desc_m,
                           index_t num_k_mean_var_count_iteration,
                           index_t num_k_block_tile_iteration,
                           index_t k_grid_size,
                           ComputeDataType epsilon,
                           const MeanVarDataType* const p_mean_global,
                           const MeanVarDataType* const p_variance_global,
                           const int32_t* const p_welford_count_global,
                           const XDataType* const __restrict__ p_x_global,
                           const GammaDataType* const __restrict__ p_gamma_global,
                           const BetaDataType* const __restrict__ p_beta_global,
                           YDataType* const __restrict__ p_y_global,
                           SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
                           SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
                           const YElementwiseOperation y_elementwise_op)
{
    // Split the 1D block id into this block's M-cluster and K-cluster coordinates.
    const index_t block_global_id    = get_block_1d_id();
    const index_t block_m_cluster_id = block_global_id / k_grid_size;
    const index_t block_k_cluster_id = block_global_id % k_grid_size;

    const auto thread_cluster_idx = // right-hand side not shown in this listing; standard CK pattern
        thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));

    const auto thread_m_cluster_id = thread_cluster_idx[I0];
    const auto thread_k_cluster_id = thread_cluster_idx[I1];
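    // Buffer roles, as implied by the descriptor names: mean/variance/count hold the per-block
    // partial Welford results of the first split-K stage, laid out as [M, KBlock]; x, gamma,
    // beta and y are the [M, K] row-wise operands; save_mean and save_inv_std are optional
    // [M]-shaped outputs holding the merged statistics.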
    const auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_mean_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize());

    const auto var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_variance_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize());

    const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_welford_count_global, count_grid_desc_m_kblock.GetElementSpaceSize());

    const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_x_global, x_grid_desc_m_k.GetElementSpaceSize());

    const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());

    const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());

    auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_y_global, y_grid_desc_m_k.GetElementSpaceSize());

    auto save_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize());

    auto save_inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize());
    // Thread-local (VGPR) StaticBuffers; most declarations are not shown in this listing.
    // They include the loaded partial tiles (e.g. in_welford_count_thread_buf), the running
    // mean/var/welford_count_thread_buf buffers, and x/gamma buffers of
    // MThreadSliceSize * XSrcVectorSize and MThreadSliceSize * GammaSrcVectorSize elements.
    // The remaining operands alias existing buffers to save registers:
    auto& inv_std_thread_buf = var_thread_buf;
    auto& beta_thread_buf    = gamma_thread_buf;
    auto& y_thread_buf       = x_thread_buf;
    // Threadwise transfers that read this block's column of partial means/variances and Welford
    // counts; template arguments and part of each origin are not shown in this listing.
    auto threadwise_mean_var_load_m_kblock =
        /* ThreadwiseTensorSliceTransfer<..., MeanVarGridDesc_M_KBlock, ...> */ (
            mean_var_grid_desc_m_kblock,
            make_multi_index(/* ... + */ thread_m_cluster_id * MThreadSliceSize,
                             thread_k_cluster_id));

    auto threadwise_count_load_m_kblock =
        /* ThreadwiseTensorSliceTransfer<..., CountGridDesc_M_KBlock, ...> */ (
            count_grid_desc_m_kblock,
            make_multi_index(/* ... + */ thread_m_cluster_id * MThreadSliceSize,
                             thread_k_cluster_id));
    // Threadwise loads for x, gamma and beta over the [M, K] grid; each thread starts at its own
    // M slice and a K offset scaled by the operand's vector size (template and remaining
    // constructor arguments are not shown in this listing).
    auto threadwise_x_load =
        /* ThreadwiseTensorSliceTransfer<..., XYGammaBetaGridDesc_M_K, ...> */ (
            x_grid_desc_m_k,
            make_multi_index(/* ... + */ thread_m_cluster_id * MThreadSliceSize,
                             /* ... + */ thread_k_cluster_id * XSrcVectorSize));

    auto threadwise_gamma_load =
        /* ThreadwiseTensorSliceTransfer<..., XYGammaBetaGridDesc_M_K, ...> */ (
            gamma_grid_desc_m_k,
            make_multi_index(/* ... + */ thread_m_cluster_id * MThreadSliceSize,
                             /* ... + */ thread_k_cluster_id * GammaSrcVectorSize));

    auto threadwise_beta_load =
        /* ThreadwiseTensorSliceTransfer<..., XYGammaBetaGridDesc_M_K, ...> */ (
            beta_grid_desc_m_k,
            make_multi_index(/* ... + */ thread_m_cluster_id * MThreadSliceSize,
                             /* ... + */ thread_k_cluster_id * BetaSrcVectorSize));
    // Threadwise store for y, constructed with the y descriptor, this thread's origin and the
    // user-provided YElementwiseOperation (remaining arguments are not shown in this listing).
    auto threadwise_y_store =
        /* ThreadwiseTensorSliceTransfer<..., XYGammaBetaGridDesc_M_K,
                                         YElementwiseOperation, ...> */ (
            y_grid_desc_m_k,
            make_multi_index(/* ... + */ thread_m_cluster_id * MThreadSliceSize,
                             /* ... + */ thread_k_cluster_id * YDstVectorSize),
            y_elementwise_op);

    // Threadwise stores for the optional merged mean / inverse-std outputs, written with
    // SaveMeanInvStdDstVectorSize-wide accesses along M.
    auto threadwise_mean_store =
        /* ThreadwiseTensorSliceTransfer<..., SaveMeanInvStdDataType, ...,
                                         SaveMeanInvStdGridDesc_M, ...,
                                         SaveMeanInvStdDstVectorSize, ...> */ (
            save_mean_grid_desc_m,
            make_multi_index(/* ... + */ thread_m_cluster_id * MThreadSliceSize),
            /* ... */);

    auto threadwise_inv_std_store =
        /* ThreadwiseTensorSliceTransfer<..., SaveMeanInvStdDataType, ...,
                                         SaveMeanInvStdGridDesc_M, ...,
                                         SaveMeanInvStdDstVectorSize, ...> */ (
            save_inv_std_grid_desc_m,
            make_multi_index(/* ... + */ thread_m_cluster_id * MThreadSliceSize),
            /* ... */);
    // Window step for the mean/variance/count loads along KBlock (value not shown in this listing).
    constexpr auto mean_var_count_thread_copy_step_I0_k = /* ... */;

    // Zero the running Welford statistics (inside a static_for over MThreadSliceSize, not shown).
    mean_thread_buf(I)          = type_convert<ComputeDataType>(0.0f);
    var_thread_buf(I)           = type_convert<ComputeDataType>(0.0f);
    welford_count_thread_buf(I) = 0;
    // First phase: sweep the KBlock dimension, loading the per-block partial statistics and
    // merging them into this thread's running (mean, var, count).
    for(index_t k = 0; k < num_k_mean_var_count_iteration; ++k)
    {
        // Two Run calls on the same transfer: one loads the partial means, one the partial
        // variances (destination buffers and remaining arguments are not shown in this listing).
        threadwise_mean_var_load_m_kblock.Run(mean_var_grid_desc_m_kblock, /* means ... */);
        threadwise_mean_var_load_m_kblock.Run(mean_var_grid_desc_m_kblock, /* variances ... */);

        threadwise_count_load_m_kblock.Run(count_grid_desc_m_kblock,
                                           welford_count_global_val_buf,
                                           /* ... */
                                           in_welford_count_thread_buf);

        // Merge the loaded partials into the running statistics. The threadwise Welford merge
        // takes (src_mean, src_var, src_count, dst_mean, dst_var, dst_count); only the two count
        // buffers, in_welford_count_thread_buf and welford_count_thread_buf, are visible in this
        // listing.

        threadwise_mean_var_load_m_kblock.MoveSrcSliceWindow(
            mean_var_grid_desc_m_kblock, mean_var_count_thread_copy_step_I0_k);
        threadwise_count_load_m_kblock.MoveSrcSliceWindow(count_grid_desc_m_kblock,
                                                          mean_var_count_thread_copy_step_I0_k);
    }
    // Block-level reduction via BlockwiseWelford, then turn the variance into an inverse std
    // (the surrounding static_for over MThreadSliceSize is not shown in this listing).
    BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), welford_count_thread_buf(I));

    inv_std_thread_buf(I) =
        type_convert<ComputeDataType>(1.0f) / ck::math::sqrt(var_thread_buf(I) + epsilon);
    // Only the first K block / K thread writes the merged mean and inverse std back to global
    // memory, and only if the corresponding output pointers were provided.
    if(block_k_cluster_id == 0 && thread_k_cluster_id == 0)
    {
        if(p_save_mean_global != nullptr)
        {
            threadwise_mean_store.Run(/* thread-side descriptor and mean_thread_buf ... */
                                      save_mean_grid_desc_m,
                                      save_mean_global_val_buf);
        }
        if(p_save_inv_std_global != nullptr)
        {
            threadwise_inv_std_store.Run(/* thread-side descriptor and inv_std_thread_buf ... */
                                         save_inv_std_grid_desc_m,
                                         save_inv_std_global_val_buf);
        }
    }
    // Second phase: sweep the K dimension of the row, normalizing x with the merged statistics
    // and applying gamma/beta plus the user-provided YElementwiseOperation. The forward window
    // step thread_copy_fwd_step_m_k is declared earlier in the file but not shown in this listing.
    for(index_t k = 0; k < num_k_block_tile_iteration; ++k)
    {
        // Load the x and gamma tiles for this iteration; the loads run once per thread buffer
        // (index i), and most arguments are not shown in this listing.
        threadwise_x_load.Run(x_grid_desc_m_k, /* ... */);

        threadwise_gamma_load.Run(gamma_grid_desc_m_k,
                                  gamma_global_val_buf,
                                  /* ... */
                                  gamma_thread_buf(i));
        threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_fwd_step_m_k);

        // Normalize: for each element at offset_m_k compute (x - mean) * inv_std_thread_buf(iM)
        // with the merged per-row statistics (this element-wise loop is only partially shown
        // in this listing).

        // Load beta and advance its window.
        threadwise_beta_load.Run(beta_grid_desc_m_k, /* ... */);
        threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_fwd_step_m_k);

        // Scale by gamma, add beta, apply y_elementwise_op and store the y tile
        // (this part of the loop body is not shown in this listing).

        threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_fwd_step_m_k);
    }
}
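// The listing above is fragmentary, so here is a small, self-contained host-side sketch of the
// arithmetic this second split-K stage performs: merging per-block partial Welford statistics
// (mean, variance, count) and then normalizing a row as
// y = gamma * (x - mean) * inv_std + beta, with inv_std = 1 / sqrt(var + epsilon).
// It is illustrative only; PartialStats, welford_merge and welford_accumulate are made-up names
// and are not part of the composable_kernel API.
#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

struct PartialStats
{
    double mean  = 0.0;
    double var   = 0.0; // population variance of the elements seen so far
    int    count = 0;
};

// Single-pass Welford accumulation over one chunk of a row (what the first stage would produce).
PartialStats welford_accumulate(const std::vector<double>& x)
{
    PartialStats s;
    double m2 = 0.0;
    for(double v : x)
    {
        ++s.count;
        const double delta = v - s.mean;
        s.mean += delta / s.count;
        m2 += delta * (v - s.mean);
    }
    s.var = (s.count > 0) ? m2 / s.count : 0.0;
    return s;
}

// Combine two partial statistics (Chan et al. parallel update), as the merge phase above does.
PartialStats welford_merge(const PartialStats& a, const PartialStats& b)
{
    PartialStats out;
    out.count = a.count + b.count;
    if(out.count == 0)
        return out;

    const double delta = b.mean - a.mean;
    out.mean = a.mean + delta * b.count / out.count;

    // Track M2 = var * count so the combination stays a single expression.
    const double m2 = a.var * a.count + b.var * b.count +
                      delta * delta * a.count * b.count / static_cast<double>(out.count);
    out.var = m2 / out.count;
    return out;
}

int main()
{
    const std::vector<double> x = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0};
    const double gamma = 0.5, beta = 1.0, epsilon = 1e-5;

    // Pretend two blocks of the first stage each covered half of the row.
    const std::vector<double> lo(x.begin(), x.begin() + 4);
    const std::vector<double> hi(x.begin() + 4, x.end());
    const PartialStats merged = welford_merge(welford_accumulate(lo), welford_accumulate(hi));

    // The merged statistics match a single pass over the whole row.
    const PartialStats full = welford_accumulate(x);
    assert(std::abs(merged.mean - full.mean) < 1e-12);
    assert(std::abs(merged.var - full.var) < 1e-12);

    // Second phase: normalize every element of the row with the merged statistics.
    const double inv_std = 1.0 / std::sqrt(merged.var + epsilon);
    for(double v : x)
        std::printf("%f\n", gamma * (v - merged.mean) * inv_std + beta);

    return 0;
}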