// Second pass of a split-K normalization: merges the per-KBlock partial
// Welford results produced by the first pass, then normalizes X into Y.
template <typename MeanVarDataType,
          typename XDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename YDataType,
          typename SaveMeanInvStdDataType,
          typename ComputeDataType,
          typename YElementwiseOperation,
          typename MeanVarGridDesc_M_KBlock,
          typename CountGridDesc_M_KBlock,
          typename XYGammaBetaGridDesc_M_K,
          typename SaveMeanInvStdGridDesc_M,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t XSrcVectorDim,
          index_t XSrcVectorSize,
          index_t GammaSrcVectorDim,
          index_t GammaSrcVectorSize,
          index_t BetaSrcVectorDim,
          index_t BetaSrcVectorSize,
          index_t YDstVectorDim,
          index_t YDstVectorSize,
          index_t SaveMeanInvStdDstVectorSize>
struct GridwiseNormalizationSplitK2nd
{
 
    static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
                      (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static_assert((YDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) ||
                      (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
                  "Invalid thread slice sizes and/or save mean and inverse std vector sizes "
                  "configuration, please check!");

    static_assert(XSrcVectorSize == YDstVectorSize);
    static_assert(XSrcVectorSize == GammaSrcVectorSize);
    static_assert(XSrcVectorSize == BetaSrcVectorSize);
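    // Example of a configuration that satisfies the asserts above (values
    // illustrative, not library defaults): XSrcVectorDim == 1 with
    // XSrcVectorSize == 4 requires KThreadSliceSize % 4 == 0, and forces
    // gamma, beta and y to use the same vector width of 4.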
 
    // ...

    __device__ static void Run(const MeanVarGridDesc_M_KBlock& mean_var_grid_desc_m_kblock,
                               const CountGridDesc_M_KBlock& count_grid_desc_m_kblock,
                               const XYGammaBetaGridDesc_M_K& x_grid_desc_m_k,
                               const XYGammaBetaGridDesc_M_K& gamma_grid_desc_m_k,
                               const XYGammaBetaGridDesc_M_K& beta_grid_desc_m_k,
                               const XYGammaBetaGridDesc_M_K& y_grid_desc_m_k,
                               const SaveMeanInvStdGridDesc_M& save_mean_grid_desc_m,
                               const SaveMeanInvStdGridDesc_M& save_inv_std_grid_desc_m,
                               index_t num_k_mean_var_count_iteration,
                               index_t num_k_block_tile_iteration,
                               index_t k_grid_size,
                               ComputeDataType epsilon,
                               const MeanVarDataType* const p_mean_global,
                               const MeanVarDataType* const p_variance_global,
                               const int32_t* const p_welford_count_global,
                               const XDataType* const __restrict__ p_x_global,
                               const GammaDataType* const __restrict__ p_gamma_global,
                               const BetaDataType* const __restrict__ p_beta_global,
                               YDataType* const __restrict__ p_y_global,
                               SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
                               SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
                               const YElementwiseOperation y_elementwise_op)
    {
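        // Phase 1 merges the per-KBlock partial Welford results (mean,
        // variance, count) from the first pass; phase 2 optionally stores
        // the merged mean and inverse std; phase 3 re-reads X and writes
        // Y = y_elementwise_op(gamma * (X - mean) * inv_std + beta).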
 
        // Decompose the 1-D block id into (M cluster, K cluster) coordinates.
        const index_t block_global_id    = get_block_1d_id();
        const index_t block_m_cluster_id = block_global_id / k_grid_size;
        const index_t block_k_cluster_id = block_global_id % k_grid_size;

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];
 
        const auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_mean_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize());

        const auto var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_variance_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize());

        const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_welford_count_global, count_grid_desc_m_kblock.GetElementSpaceSize());

        const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_x_global, x_grid_desc_m_k.GetElementSpaceSize());

        const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());

        const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());

        auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_y_global, y_grid_desc_m_k.GetElementSpaceSize());

        auto save_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize());

        auto save_inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize());
 
        // Registers for the pass-1 partials being read in ...
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
            in_mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
            in_var_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
            in_welford_count_thread_buf;

        // ... and for the merged (final) statistics per row slice.
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
            mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
            var_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
            welford_count_thread_buf;
        auto& inv_std_thread_buf = var_thread_buf;

        // X and gamma tiles are staged through tuples of register buffers,
        // one per K sub-tile (ThreadBufferNumber of them).
        auto x_thread_buf = generate_tuple(
            [&](auto) {
                return StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType,
                                    MThreadSliceSize * XSrcVectorSize, true>{};
            },
            Number<ThreadBufferNumber>{});

        auto gamma_thread_buf = generate_tuple(
            [&](auto) {
                return StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType,
                                    MThreadSliceSize * GammaSrcVectorSize, true>{};
            },
            Number<ThreadBufferNumber>{});

        auto& beta_thread_buf = gamma_thread_buf;
        auto& y_thread_buf    = x_thread_buf;
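        // Aliasing beta onto gamma's registers and y onto x's keeps register
        // pressure down: each aliased pair is live in disjoint phases of the
        // main loop below.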
 
        // Loads for the pass-1 partial mean/variance and counts: each thread
        // reads an (MThreadSliceSize x 1) slice at KBlock column
        // thread_k_cluster_id and walks across the whole KBlock dimension.
        auto threadwise_mean_var_load_m_kblock =
            ThreadwiseTensorSliceTransfer_v2<MeanVarDataType, ComputeDataType,
                                             MeanVarGridDesc_M_KBlock, /* ... */>(
                mean_var_grid_desc_m_kblock,
                make_multi_index(block_m_cluster_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id));

        auto threadwise_count_load_m_kblock =
            ThreadwiseTensorSliceTransfer_v2<int32_t, int32_t,
                                             CountGridDesc_M_KBlock, /* ... */>(
                count_grid_desc_m_kblock,
                make_multi_index(block_m_cluster_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id));

        // Loads for X, gamma and beta; each block only touches its own
        // K partition of the row.
        auto threadwise_x_load =
            ThreadwiseTensorSliceTransfer_v2<XDataType, ComputeDataType,
                                             XYGammaBetaGridDesc_M_K, /* ... */>(
                x_grid_desc_m_k,
                make_multi_index(block_m_cluster_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 /* block-level K offset */ +
                                     thread_k_cluster_id * XSrcVectorSize));

        auto threadwise_gamma_load =
            ThreadwiseTensorSliceTransfer_v2<GammaDataType, ComputeDataType,
                                             XYGammaBetaGridDesc_M_K, /* ... */>(
                gamma_grid_desc_m_k,
                make_multi_index(block_m_cluster_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 /* block-level K offset */ +
                                     thread_k_cluster_id * GammaSrcVectorSize));

        auto threadwise_beta_load =
            ThreadwiseTensorSliceTransfer_v2<BetaDataType, ComputeDataType,
                                             XYGammaBetaGridDesc_M_K, /* ... */>(
                beta_grid_desc_m_k,
                make_multi_index(block_m_cluster_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 /* block-level K offset */ +
                                     thread_k_cluster_id * BetaSrcVectorSize));

        auto threadwise_y_store =
            ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType, YDataType,
                                               /* ... */
                                               XYGammaBetaGridDesc_M_K,
                                               YElementwiseOperation, /* ... */>(
                y_grid_desc_m_k,
                make_multi_index(block_m_cluster_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 /* block-level K offset */ +
                                     thread_k_cluster_id * YDstVectorSize),
                y_elementwise_op);

        // Stores for the saved statistics (1-D in M, PassThrough element op).
        auto threadwise_mean_store =
            ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType, SaveMeanInvStdDataType,
                                               /* ... */
                                               SaveMeanInvStdGridDesc_M, PassThroughOp,
                                               /* ... */
                                               SaveMeanInvStdDstVectorSize, /* ... */>(
                save_mean_grid_desc_m,
                make_multi_index(block_m_cluster_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize),
                PassThroughOp{});

        auto threadwise_inv_std_store =
            ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType, SaveMeanInvStdDataType,
                                               /* ... */
                                               SaveMeanInvStdGridDesc_M, PassThroughOp,
                                               /* ... */
                                               SaveMeanInvStdDstVectorSize, /* ... */>(
                save_inv_std_grid_desc_m,
                make_multi_index(block_m_cluster_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize),
                PassThroughOp{});
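        // Phase 1: accumulate the pass-1 partials along the KBlock dimension,
        // advancing the mean/var/count slice windows after each iteration.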
 
        constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
        constexpr auto mean_var_count_thread_copy_step_I0_k =
            make_multi_index(I0, KThreadClusterSize);

        // Start every per-thread accumulator from the empty Welford state.
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            mean_thread_buf(I)          = type_convert<ComputeDataType>(0.0f);
            var_thread_buf(I)           = type_convert<ComputeDataType>(0.0f);
            welford_count_thread_buf(I) = 0;
        });

        for(index_t k = 0; k < num_k_mean_var_count_iteration; ++k)
        {
            threadwise_mean_var_load_m_kblock.Run(mean_var_grid_desc_m_kblock,
                                                  mean_global_val_buf,
                                                  thread_buffer_desc_m_1,
                                                  make_tuple(I0, I0),
                                                  in_mean_thread_buf);

            threadwise_mean_var_load_m_kblock.Run(mean_var_grid_desc_m_kblock,
                                                  var_global_val_buf,
                                                  thread_buffer_desc_m_1,
                                                  make_tuple(I0, I0),
                                                  in_var_thread_buf);

            threadwise_count_load_m_kblock.Run(count_grid_desc_m_kblock,
                                               welford_count_global_val_buf,
                                               thread_buffer_desc_m_1,
                                               make_tuple(I0, I0),
                                               in_welford_count_thread_buf);

            ThreadwiseWelfordMerge::Run(in_mean_thread_buf,
                                        in_var_thread_buf,
                                        in_welford_count_thread_buf,
                                        mean_thread_buf,
                                        var_thread_buf,
                                        welford_count_thread_buf);

            threadwise_mean_var_load_m_kblock.MoveSrcSliceWindow(
                mean_var_grid_desc_m_kblock, mean_var_count_thread_copy_step_I0_k);
            threadwise_count_load_m_kblock.MoveSrcSliceWindow(count_grid_desc_m_kblock,
                                                              mean_var_count_thread_copy_step_I0_k);
        }

        // Merge across the thread block, then fold variance into 1/sqrt(var + eps).
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            block_sync_lds();
            BlockwiseWelford::Run(
                mean_thread_buf(I), var_thread_buf(I), welford_count_thread_buf(I));

            inv_std_thread_buf(I) =
                type_convert<ComputeDataType>(1.0f) / ck::math::sqrt(var_thread_buf(I) + epsilon);
        });
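        // The merges above combine two Welford states (m_a, v_a, n_a) and
        // (m_b, v_b, n_b) with the standard parallel rule (Chan et al.):
        //   n = n_a + n_b
        //   d = m_b - m_a
        //   m = m_a + d * n_b / n
        //   v = (n_a * v_a + n_b * v_b + d * d * n_a * n_b / n) / n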
 
        // Phase 2: only one block/thread column per row writes the saved
        // statistics, so each row is stored exactly once.
        if(block_k_cluster_id == 0 && thread_k_cluster_id == 0)
        {
            if(p_save_mean_global != nullptr)
            {
                threadwise_mean_store.Run(thread_buffer_desc_m,
                                          make_tuple(I0),
                                          mean_thread_buf,
                                          save_mean_grid_desc_m,
                                          save_mean_global_val_buf);
            }

            if(p_save_inv_std_global != nullptr)
            {
                threadwise_inv_std_store.Run(thread_buffer_desc_m,
                                             make_tuple(I0),
                                             inv_std_thread_buf,
                                             save_inv_std_grid_desc_m,
                                             save_inv_std_global_val_buf);
            }
        }
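        // Phase 3: stream X through registers one K block-tile at a time;
        // normalize and scale by gamma, then add beta, then store Y.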
 
        for(index_t k = 0; k < num_k_block_tile_iteration; ++k)
        {
            static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
                threadwise_x_load.Run(x_grid_desc_m_k,
                                      x_global_val_buf,
                                      thread_buffer_desc_m_k,
                                      make_tuple(I0, I0),
                                      x_thread_buf(i));

                threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);

                threadwise_gamma_load.Run(gamma_grid_desc_m_k,
                                          gamma_global_val_buf,
                                          thread_buffer_desc_m_k,
                                          make_tuple(I0, I0),
                                          gamma_thread_buf(i));

                threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
                                                         thread_copy_fwd_step_m_k);
            });

            // y = gamma * (x - mean) * inv_std, computed in registers.
            static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    static_for<0, XSrcVectorSize, 1>{}([&](auto iK) {
                        constexpr auto offset_m_k =
                            thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));

                        y_thread_buf(i)(Number<offset_m_k>{}) =
                            gamma_thread_buf(i)[Number<offset_m_k>{}] *
                            (x_thread_buf(i)[Number<offset_m_k>{}] - mean_thread_buf(iM)) *
                            inv_std_thread_buf(iM);
                    });
                });
            });

            static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
                threadwise_beta_load.Run(beta_grid_desc_m_k,
                                         beta_global_val_buf,
                                         thread_buffer_desc_m_k,
                                         make_tuple(I0, I0),
                                         beta_thread_buf(i));

                threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
                                                        thread_copy_fwd_step_m_k);
            });

            // y = y + beta, then write the tile out.
            static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    static_for<0, XSrcVectorSize, 1>{}([&](auto iK) {
                        constexpr auto offset_m_k =
                            thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));

                        y_thread_buf(i)(Number<offset_m_k>{}) +=
                            beta_thread_buf(i)[Number<offset_m_k>{}];
                    });
                });
            });

            static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
                threadwise_y_store.Run(thread_buffer_desc_m_k,
                                       make_tuple(I0, I0),
                                       y_thread_buf(i),
                                       y_grid_desc_m_k,
                                       y_global_val_buf);

                threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_fwd_step_m_k);
            });
        }
    }
};
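// Reference semantics of this pass for one row m (scalar sketch, names
// illustrative): with inv_std = 1 / sqrt(var[m] + epsilon),
//   y[m][k] = y_elementwise_op(gamma[m][k] * (x[m][k] - mean[m]) * inv_std
//                              + beta[m][k])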
 