template <typename XDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename YDataType,
          typename SaveMeanInvStdDataType,
          typename ComputeDataType,
          typename YElementwiseOperation,
          typename GridDesc_M_K,
          /* ... */
          index_t SaveMeanInvStdDstVectorSize,
          /* ... */>
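// Gridwise normalization kernel that computes the variance "naively" as E[x^2] - E[x]^2:
// each block reduces a tile of rows over the K dimension, producing
// y = gamma * (x - mean) * inv_std + beta and, optionally, the saved mean and the saved
// inverse standard deviation.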
static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
                  (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
              "Invalid thread slice sizes and/or vector sizes configuration, please check!");

static_assert((YDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) ||
                  (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
              "Invalid thread slice sizes and/or vector sizes configuration, please check!");

static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
              "Invalid thread slice sizes and/or save mean and inverse std vector sizes "
              "configuration, please check!");

static_assert(XSrcVectorSize == YDstVectorSize);
static_assert(XSrcVectorSize == GammaSrcVectorSize);
static_assert(XSrcVectorSize == BetaSrcVectorSize);
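// The asserts above require the vector widths to evenly divide the corresponding per-thread
// slice extents, and x, gamma, beta and y to all use the same vector width.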
static __device__ void Run(const GridDesc_M_K& x_grid_desc_m_k,
                           const GridDesc_M_K& gamma_grid_desc_m_k,
                           const GridDesc_M_K& beta_grid_desc_m_k,
                           const GridDesc_M_K& y_grid_desc_m_k,
                           const GridDesc_M& save_mean_grid_desc_m,
                           const GridDesc_M& save_inv_std_grid_desc_m,
                           index_t num_k_block_tile_iteration,
                           ComputeDataType epsilon,
                           const XDataType* const __restrict__ p_x_global,
                           const GammaDataType* const __restrict__ p_gamma_global,
                           const BetaDataType* const __restrict__ p_beta_global,
                           YDataType* const __restrict__ p_y_global,
                           SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
                           SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
                           const YElementwiseOperation y_elementwise_op)
{
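    // LDS scratch for the blockwise reductions, plus dynamic-buffer views of the global outputs.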
    __shared__ ComputeDataType p_reduce_work_buffer[BlockSize];

    auto reduce_work_buf =
        make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);

    auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_y_global, y_grid_desc_m_k.GetElementSpaceSize());

    auto save_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize());

    auto save_inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize());
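    // Per-thread register buffers. The aliases below reuse storage for values whose live
    // ranges do not overlap: beta shares gamma's buffer, x^2 shares y's, and the variance and
    // inverse standard deviation share the mean-square buffer.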
    // ... x thread buffer declaration elided (sized MThreadSliceSize * XSrcVectorSize)
    // ... gamma thread buffer declaration elided (sized MThreadSliceSize * GammaSrcVectorSize)
    auto& beta_thread_buf = gamma_thread_buf;

    // ... y thread buffer declaration elided (sized MThreadSliceSize * YDstVectorSize)
    auto& x_square_thread_buf = y_thread_buf;

    // ... mean_thread_buf and mean_square_thread_buf declarations elided
    auto& var_thread_buf     = mean_square_thread_buf;
    auto& inv_std_thread_buf = mean_square_thread_buf;
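    // Decompose this thread's flat id into its (m, k) position inside the block's thread cluster.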
    const auto thread_cluster_idx =
        thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));

    const auto thread_m_cluster_id = thread_cluster_idx[I0];
    const auto thread_k_cluster_id = thread_cluster_idx[I1];
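    // Threadwise transfer objects: vectorized loads for x, gamma and beta, a vectorized store
    // for y, and stores for the saved mean / inverse standard deviation. Each one is anchored
    // at this thread's slice origin within the block tile.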
    auto threadwise_x_load =
        // ...
        thread_m_cluster_id * MThreadSliceSize,
        thread_k_cluster_id * XSrcVectorSize));

    auto threadwise_gamma_load =
        // ...
        thread_m_cluster_id * MThreadSliceSize,
        thread_k_cluster_id * GammaSrcVectorSize));

    auto threadwise_beta_load =
        // ...
        thread_m_cluster_id * MThreadSliceSize,
        thread_k_cluster_id * BetaSrcVectorSize));

    auto threadwise_y_store =
        // ...
        YElementwiseOperation,
        // ...
        thread_m_cluster_id * MThreadSliceSize,
        thread_k_cluster_id * YDstVectorSize),
        // ...

    auto threadwise_mean_store =
        // ...
        SaveMeanInvStdDataType,
        // ...
        SaveMeanInvStdDstVectorSize,
        // ...
        save_mean_grid_desc_m,
        // ...
        thread_m_cluster_id * MThreadSliceSize),
        // ...

    auto threadwise_inv_std_store =
        // ...
        SaveMeanInvStdDataType,
        // ...
        SaveMeanInvStdDstVectorSize,
        // ...
        save_inv_std_grid_desc_m,
        // ...
        thread_m_cluster_id * MThreadSliceSize),
        // ...
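    // Backward slice-window step (used to rewind between sweeps) and dynamic-buffer views of
    // the global inputs.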
    constexpr auto thread_copy_bwd_step_m_k = // ...

    const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_x_global, x_grid_desc_m_k.GetElementSpaceSize());

    const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());

    const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());
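    // Length of the reduced (K) dimension, read back from the x descriptor; the running sums
    // are divided by it to turn them into means.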
    ComputeDataType reduce_length = type_convert<ComputeDataType>(
        x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]);

    static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
        mean_thread_buf(I)        = reduce::Add::template GetIdentityValue<ComputeDataType>();
        mean_square_thread_buf(I) = reduce::Add::template GetIdentityValue<ComputeDataType>();
    });
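    // Single-sweep path: the whole K extent is covered in one pass, so x and gamma are loaded
    // once, the sums and sums of squares are accumulated, and y is produced without re-reading x.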
    if constexpr(SweepOnce)
    {
        threadwise_x_load.Run(x_grid_desc_m_k,
                              // ...

        threadwise_gamma_load.Run(gamma_grid_desc_m_k,
                                  gamma_global_val_buf,
                                  // ...
                                  gamma_thread_buf(i));

        // ...
        constexpr auto offset_m_k =
            // ...

        threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
        // ...
        thread_copy_fwd_step_m_k);
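        // Blockwise reduction across the K thread cluster, then per-row statistics:
        // mean = E[x], variance = E[x^2] - E[x]^2, inv_std = 1 / sqrt(variance + epsilon).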
        // ...
        mean_thread_buf(I) = mean_thread_buf(I) / reduce_length;
        // ...
        mean_square_thread_buf(I) = mean_square_thread_buf(I) / reduce_length;
        // ...
        var_thread_buf(I) =
            mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));

        inv_std_thread_buf(I) = type_convert<ComputeDataType>(1.0f) /
                                ck::math::sqrt(var_thread_buf(I) + epsilon);
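        // Only one thread per row (k-cluster id 0) writes the optional mean / inv-std outputs.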
        if(thread_k_cluster_id == 0)
        {
            if(p_save_mean_global != nullptr)
            {
                threadwise_mean_store.Run(// ...
                                          save_mean_grid_desc_m,
                                          save_mean_global_val_buf);
            }

            if(p_save_inv_std_global != nullptr)
            {
                threadwise_inv_std_store.Run(// ...
                                             save_inv_std_grid_desc_m,
                                             save_inv_std_global_val_buf);
            }
        }
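        // Normalize and write the output: y = gamma * (x - mean) * inv_std + beta, with the
        // user-supplied y_elementwise_op applied by the y store.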
        // ...
        constexpr auto offset_m_k =
            // ...
            inv_std_thread_buf(iM);
        // ...

        threadwise_beta_load.Run(beta_grid_desc_m_k,
                                 // ...

        // ...
        thread_copy_fwd_step_m_k);

        // ...
        constexpr auto offset_m_k =
            // ...

        threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k,
                                              thread_copy_fwd_step_m_k);
        // ...
    }
    else
    {
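        // Multi-tile path: a first sweep over num_k_block_tile_iteration K tiles accumulates
        // the per-row sum and sum of squares before the statistics are reduced and finalized.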
        for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
        {
            threadwise_x_load.Run(x_grid_desc_m_k,
                                  // ...

            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);

            // ...
            constexpr auto offset_m_k =
                // ...
        }

        // ...
        mean_thread_buf(I) = mean_thread_buf(I) / reduce_length;
        // ...
        mean_square_thread_buf(I) = mean_square_thread_buf(I) / reduce_length;
        // ...
        var_thread_buf(I) =
            mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));

        inv_std_thread_buf(I) = 1 / ck::math::sqrt(var_thread_buf(I) + epsilon);
        if(thread_k_cluster_id == 0)
        {
            if(p_save_mean_global != nullptr)
            {
                threadwise_mean_store.Run(// ...
                                          save_mean_grid_desc_m,
                                          save_mean_global_val_buf);
            }

            if(p_save_inv_std_global != nullptr)
            {
                threadwise_inv_std_store.Run(// ...
                                             save_inv_std_grid_desc_m,
                                             save_inv_std_global_val_buf);
            }
        }
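        // Second sweep: rewind the x window, position the y window at the tail, and re-read
        // x, gamma and beta tile by tile to normalize and write y. Each iteration advances the
        // loads by one tile and then steps every window back by two tiles, so the tiles are
        // revisited from the tail towards the front.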
        auto thread_copy_tail_m_k =
            // ...

        threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
        // ...
        threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);

        for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
        {
            threadwise_x_load.Run(x_grid_desc_m_k,
                                  // ...

            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);

            threadwise_gamma_load.Run(gamma_grid_desc_m_k,
                                      gamma_global_val_buf,
                                      // ...
                                      gamma_thread_buf(i));

            // ...
            thread_copy_fwd_step_m_k);

            // ...
            constexpr auto offset_m_k =
                // ...
                inv_std_thread_buf(iM);
            // ...

            threadwise_beta_load.Run(beta_grid_desc_m_k,
                                     // ...

            // ...
            thread_copy_fwd_step_m_k);

            // ...
            constexpr auto offset_m_k =
                // ...

            threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k,
                                                  thread_copy_fwd_step_m_k);
            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
            threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
                                                     2 * thread_copy_bwd_step_m_k);
            threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
                                                    2 * thread_copy_bwd_step_m_k);
            threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k,
                                                  2 * thread_copy_bwd_step_m_k);
        }
    }
}
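For reference, the per-row arithmetic this kernel performs reduces to one pass that accumulates the sum and the sum of squares, followed by the naive variance identity Var[x] = E[x^2] - E[x]^2. The sketch below is a minimal host-side illustration of that math, not part of Composable Kernel; the function name and the plain float types are made up for the example, and the kernel's y_elementwise_op, vectorized transfers and blockwise reductions are deliberately omitted.

// naive_variance_reference.cpp -- standalone host-side sketch (hypothetical helper).
#include <cmath>
#include <cstdio>
#include <vector>

// Normalize one row of length n: y = gamma * (x - mean) * inv_std + beta,
// with the variance computed "naively" as E[x^2] - E[x]^2 in a single pass over x.
void naive_layernorm_row(const float* x, const float* gamma, const float* beta,
                         float* y, float* save_mean, float* save_inv_std,
                         int n, float eps)
{
    float sum = 0.f, sum_sq = 0.f;
    for(int k = 0; k < n; ++k)
    {
        sum    += x[k];        // corresponds to the mean accumulator in the kernel
        sum_sq += x[k] * x[k]; // corresponds to the mean-square accumulator
    }

    const float mean        = sum / n;                   // E[x]
    const float mean_square = sum_sq / n;                // E[x^2]
    const float var         = mean_square - mean * mean; // E[x^2] - E[x]^2
    const float inv_std     = 1.f / std::sqrt(var + eps);

    for(int k = 0; k < n; ++k)
        y[k] = gamma[k] * (x[k] - mean) * inv_std + beta[k];

    if(save_mean != nullptr)    *save_mean    = mean;
    if(save_inv_std != nullptr) *save_inv_std = inv_std;
}

int main()
{
    const int n = 8;
    std::vector<float> x{1, 2, 3, 4, 5, 6, 7, 8}, gamma(n, 1.f), beta(n, 0.f), y(n);
    float mean = 0.f, inv_std = 0.f;
    naive_layernorm_row(x.data(), gamma.data(), beta.data(), y.data(), &mean, &inv_std, n, 1e-5f);
    std::printf("mean = %f, inv_std = %f, y[0] = %f\n", mean, inv_std, y[0]);
    return 0;
}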