// From gridwise_normalization_naive_variance.hpp
template <typename XDataType,
          typename GammaDataType,
          typename BetaDataType,
          // ...
          typename SaveMeanInvStdDataType,
          typename ComputeDataType,
          typename YElementwiseOperation,
          typename GridDesc_M_K,
          // ...
          index_t SaveMeanInvStdDstVectorSize,
          // ... (remaining template parameters and the struct declaration elided)

    static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
                      (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static_assert((YDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) ||
                      (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
                  "Invalid thread slice sizes and/or save mean and inverse std vector sizes "
                  "configuration, please check!");

    static_assert(XSrcVectorSize == YDstVectorSize);
    static_assert(XSrcVectorSize == GammaSrcVectorSize);
    static_assert(XSrcVectorSize == BetaSrcVectorSize);
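For a concrete reading of these compile-time checks, a parameter set along the following lines would satisfy them; the values are illustrative only and not taken from any shipped instantiation:

    // Hypothetical configuration (illustrative values only):
    constexpr index_t MThreadSliceSize            = 8;
    constexpr index_t KThreadSliceSize            = 8;
    constexpr index_t XSrcVectorDim               = 1; // vectorize along K
    constexpr index_t XSrcVectorSize              = 4; // KThreadSliceSize % 4 == 0
    constexpr index_t YDstVectorDim               = 1;
    constexpr index_t YDstVectorSize              = 4; // must equal XSrcVectorSize
    constexpr index_t GammaSrcVectorSize          = 4; // must equal XSrcVectorSize
    constexpr index_t BetaSrcVectorSize           = 4; // must equal XSrcVectorSize
    constexpr index_t SaveMeanInvStdDstVectorSize = 2; // MThreadSliceSize % 2 == 0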
 
    static __device__ void Run(const GridDesc_M_K& x_grid_desc_m_k,
                               const GridDesc_M_K& gamma_grid_desc_m_k,
                               const GridDesc_M_K& beta_grid_desc_m_k,
                               const GridDesc_M_K& y_grid_desc_m_k,
                               const GridDesc_M& save_mean_grid_desc_m,
                               const GridDesc_M& save_inv_std_grid_desc_m,
                               index_t num_k_block_tile_iteration,
                               ComputeDataType epsilon,
                               const XDataType* const __restrict__ p_x_global,
                               const GammaDataType* const __restrict__ p_gamma_global,
                               const BetaDataType* const __restrict__ p_beta_global,
                               YDataType* const __restrict__ p_y_global,
                               SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
                               SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
                               const YElementwiseOperation y_elementwise_op)
    {
 
        __shared__ ComputeDataType p_reduce_work_buffer[BlockSize];

        auto reduce_work_buf =
            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);

        auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_y_global, y_grid_desc_m_k.GetElementSpaceSize());

        auto save_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize());

        auto save_inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize());
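        // Note: make_dynamic_buffer pairs a raw global or LDS pointer with its element
        // count (BlockSize for the LDS workspace, GetElementSpaceSize() of the matching
        // descriptor for the global tensors) so the threadwise transfer objects created
        // below can read from and write to it.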
 
        // ... (thread-local VGPR buffers; declarations partially elided)
        //     x buffer:     MThreadSliceSize * XSrcVectorSize elements
        //     gamma buffer: MThreadSliceSize * GammaSrcVectorSize elements
        auto& beta_thread_buf = gamma_thread_buf;
        //     y buffer:     MThreadSliceSize * YDstVectorSize elements
        auto& x_square_thread_buf = y_thread_buf;
        // ... (mean_thread_buf / mean_square_thread_buf declarations elided;
        //      var_thread_buf and inv_std_thread_buf are references to mean_square_thread_buf)
 
        const auto thread_cluster_idx = /* ... */;

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        auto threadwise_x_load = /* ... slice origin components:
            thread_m_cluster_id * MThreadSliceSize,
            thread_k_cluster_id * XSrcVectorSize ... */;
 
        auto threadwise_gamma_load = /* ... slice origin components:
            thread_m_cluster_id * MThreadSliceSize,
            thread_k_cluster_id * GammaSrcVectorSize ... */;

        auto threadwise_beta_load = /* ... slice origin components:
            thread_m_cluster_id * MThreadSliceSize,
            thread_k_cluster_id * BetaSrcVectorSize ... */;

        auto threadwise_y_store = /* ... uses YElementwiseOperation; slice origin components:
            thread_m_cluster_id * MThreadSliceSize,
            thread_k_cluster_id * YDstVectorSize ... */;

        auto threadwise_mean_store = /* ... SaveMeanInvStdDataType,
            SaveMeanInvStdDstVectorSize, save_mean_grid_desc_m,
            origin component: thread_m_cluster_id * MThreadSliceSize ... */;

        auto threadwise_inv_std_store = /* ... SaveMeanInvStdDataType,
            SaveMeanInvStdDstVectorSize, save_inv_std_grid_desc_m,
            origin component: thread_m_cluster_id * MThreadSliceSize ... */;
 
        constexpr auto thread_copy_bwd_step_m_k = /* ... */;

        const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_x_global, x_grid_desc_m_k.GetElementSpaceSize());

        const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());

        const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());
 
        ComputeDataType reduce_length = type_convert<ComputeDataType>(
            x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]);

        // ... (loop over the M thread slice, elided)
            mean_thread_buf(I)        = reduce::Add::template GetIdentityValue<ComputeDataType>();
            mean_square_thread_buf(I) = reduce::Add::template GetIdentityValue<ComputeDataType>();
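        // Two code paths follow. With SweepOnce (presumably enabled when the thread buffers
        // cover the whole K extent in one pass) x and gamma are loaded once and the
        // statistics plus the normalization are produced in a single sweep. Otherwise a
        // first pass accumulates sum(x) and sum(x * x) over num_k_block_tile_iteration
        // tiles, the slice windows are stepped back (thread_copy_bwd_step_m_k), and a
        // second pass re-reads x to apply the normalization. In both paths the variance is
        // formed in the "naive" way as E[x^2] - E[x]^2 (hence the file name), rather than
        // with a running Welford update.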
 
        if constexpr(SweepOnce)
        {
            // ...
            threadwise_x_load.Run(x_grid_desc_m_k, /* ... */);
            // ...
            threadwise_gamma_load.Run(gamma_grid_desc_m_k,
                                      gamma_global_val_buf,
                                      /* ... */
                                      gamma_thread_buf(i));
            // ... (accumulate x and x * x into mean_thread_buf / mean_square_thread_buf)
                    constexpr auto offset_m_k = /* ... */;
            // ...
            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
            threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
                                                     thread_copy_fwd_step_m_k);
            // ... (block-wide reduction of the partial sums, elided)
                mean_thread_buf(I) = mean_thread_buf(I) / reduce_length;
                // ...
                mean_square_thread_buf(I) = mean_square_thread_buf(I) / reduce_length;
                // ...
                var_thread_buf(I) =
                    mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
                inv_std_thread_buf(I) = type_convert<ComputeDataType>(1.0f) /
                                        ck::math::sqrt(var_thread_buf(I) + epsilon);
            // ...
            if(thread_k_cluster_id == 0)
            {
                if(p_save_mean_global != nullptr)
                {
                    // ...
                    threadwise_mean_store.Run(/* ... */,
                                              save_mean_grid_desc_m,
                                              save_mean_global_val_buf);
                }
                if(p_save_inv_std_global != nullptr)
                {
                    // ...
                    threadwise_inv_std_store.Run(/* ... */,
                                                 save_inv_std_grid_desc_m,
                                                 save_inv_std_global_val_buf);
                }
            }
            // ... (normalize: (x - mean_thread_buf(iM)) is scaled by
            //      inv_std_thread_buf(iM); then gamma is applied, elided)
            threadwise_beta_load.Run(beta_grid_desc_m_k, /* ... */);
            // ... (add beta, apply y_elementwise_op, and store y, elided)
            threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k,
                                                  thread_copy_fwd_step_m_k);
        }
 
        else
        {
            // first pass over K: accumulate sum(x) and sum(x * x)
            for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
            {
                // ...
                threadwise_x_load.Run(x_grid_desc_m_k, /* ... */);
                // ...
                threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
                // ... (accumulate into mean_thread_buf / mean_square_thread_buf)
                        constexpr auto offset_m_k = /* ... */;
                // ...
            }

            // ... (block-wide reduction of the partial sums, elided)
                mean_thread_buf(I) = mean_thread_buf(I) / reduce_length;
                // ...
                mean_square_thread_buf(I) = mean_square_thread_buf(I) / reduce_length;
                // ...
                var_thread_buf(I) =
                    mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
                inv_std_thread_buf(I) = 1 / ck::math::sqrt(var_thread_buf(I) + epsilon);

            if(thread_k_cluster_id == 0)
            {
                if(p_save_mean_global != nullptr)
                {
                    // ...
                    threadwise_mean_store.Run(/* ... */,
                                              save_mean_grid_desc_m,
                                              save_mean_global_val_buf);
                }
                if(p_save_inv_std_global != nullptr)
                {
                    // ...
                    threadwise_inv_std_store.Run(/* ... */,
                                                 save_inv_std_grid_desc_m,
                                                 save_inv_std_global_val_buf);
                }
            }

            // step the x window back to the start of the row for the second pass
            auto thread_copy_tail_m_k = /* ... */;

            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
            // ...
            threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);

            // second pass over K: normalize and write y
            for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
            {
                // ...
                threadwise_x_load.Run(x_grid_desc_m_k, /* ... */);
                // ...
                threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
                // ...
                threadwise_gamma_load.Run(gamma_grid_desc_m_k,
                                          gamma_global_val_buf,
                                          /* ... */
                                          gamma_thread_buf(i));
                threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
                                                         thread_copy_fwd_step_m_k);
                // ... (normalize: (x - mean_thread_buf(iM)) is scaled by
                //      inv_std_thread_buf(iM); then gamma is applied, elided)
                        constexpr auto offset_m_k = /* ... */;
                // ...
                threadwise_beta_load.Run(beta_grid_desc_m_k, /* ... */);
                threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
                                                        thread_copy_fwd_step_m_k);
                // ... (add beta, apply y_elementwise_op, and store y, elided)
                        constexpr auto offset_m_k = /* ... */;
                // ...
                threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k,
                                                      thread_copy_fwd_step_m_k);

                threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
                threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
                                                         2 * thread_copy_bwd_step_m_k);
                threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
                                                        2 * thread_copy_bwd_step_m_k);
                threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k,
                                                      2 * thread_copy_bwd_step_m_k);
            }
        }
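Stripped of the tiling, vectorization, and slice-window bookkeeping, the per-row arithmetic the kernel implements is the naive-variance form visible above: mean = sum(x)/K, mean_sq = sum(x^2)/K, var = mean_sq - mean^2, inv_std = 1/sqrt(var + epsilon), and y = y_elementwise_op(gamma * (x - mean) * inv_std + beta), with mean and inv_std additionally written out through p_save_mean_global and p_save_inv_std_global when those pointers are non-null. The following scalar host-side sketch is illustrative only; the function name, container types, and the element-wise functor are assumptions, not the library API:

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Reference for one normalized row of length K (illustrative, not the CK API).
    template <typename T, typename Op>
    void naive_variance_layernorm_row(const std::vector<T>& x,
                                      const std::vector<T>& gamma,
                                      const std::vector<T>& beta,
                                      std::vector<T>& y,
                                      T epsilon,
                                      Op y_elementwise_op)
    {
        const std::size_t K = x.size();
        T sum = 0, sum_sq = 0;
        for(std::size_t k = 0; k < K; ++k) // one sweep: accumulate x and x * x
        {
            sum    += x[k];
            sum_sq += x[k] * x[k];
        }
        const T mean    = sum / static_cast<T>(K);
        const T mean_sq = sum_sq / static_cast<T>(K);
        const T var     = mean_sq - mean * mean;           // naive variance: E[x^2] - E[x]^2
        const T inv_std = T(1) / std::sqrt(var + epsilon);

        for(std::size_t k = 0; k < K; ++k)
            y[k] = y_elementwise_op(gamma[k] * (x[k] - mean) * inv_std + beta[k]);
    }

Called with an identity functor such as [](float v) { return v; }, this reproduces the values the gridwise kernel writes for one row, up to accumulation order and data-type conversions.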
 