10 template <
typename GridwiseReduction,
 
   12           typename GammaDataType,
 
   13           typename BetaDataType,
 
   15           typename SaveMeanInvStdDataType,
 
   16           typename ComputeDataType,
 
   17           typename YElementwiseOperation,
 
   18           typename GridDesc_M_K,
 
   22                      const GridDesc_M_K gamma_grid_desc_m_k,
 
   23                      const GridDesc_M_K beta_grid_desc_m_k,
 
   24                      const GridDesc_M_K y_grid_desc_m_k,
 
   25                      const GridDesc_M save_mean_grid_desc_m,
 
   26                      const GridDesc_M save_inv_std_grid_desc_m,
 
   27                      index_t num_k_block_tile_iteration,
 
   28                      ComputeDataType epsilon,
 
   29                      const XDataType* 
const __restrict__ p_x_global,
 
   30                      const GammaDataType* 
const __restrict__ p_gamma_global,
 
   31                      const BetaDataType* 
const __restrict__ p_beta_global,
 
   32                      YDataType* 
const __restrict__ p_y_global,
 
   33                      SaveMeanInvStdDataType* 
const __restrict__ p_save_mean_global,
 
   34                      SaveMeanInvStdDataType* 
const __restrict__ p_save_inv_std_global,
 
   35                      const YElementwiseOperation y_elementwise_op)
 
   37     GridwiseReduction::Run(x_grid_desc_m_k,
 
   41                            save_mean_grid_desc_m,
 
   42                            save_inv_std_grid_desc_m,
 
   43                            num_k_block_tile_iteration,
 
   50                            p_save_inv_std_global,
 
   54 template <
typename XDataType,
 
   55           typename GammaDataType,
 
   56           typename BetaDataType,
 
   58           typename SaveMeanInvStdDataType,
 
   59           typename ComputeDataType,
 
   60           typename YElementwiseOperation,
 
   61           typename GridDesc_M_K,
 
   76           index_t SaveMeanInvStdDstVectorSize,
 
   80     using GridwiseNormalizationGenericNaive =
 
   85                                                     SaveMeanInvStdDataType,
 
   87                                                     YElementwiseOperation,
 
  103                                                     SaveMeanInvStdDstVectorSize,
 
  105     using GridwiseNormalizationSweepOnceNaive =
 
  110                                                     SaveMeanInvStdDataType,
 
  112                                                     YElementwiseOperation,
 
  128                                                     SaveMeanInvStdDstVectorSize,
 
  130     using GridwiseNormalizationGenericWelford =
 
  135                                                       SaveMeanInvStdDataType,
 
  137                                                       YElementwiseOperation,
 
  153                                                       SaveMeanInvStdDstVectorSize,
 
  155     using GridwiseNormalizationSweepOnceWelford =
 
  160                                                       SaveMeanInvStdDataType,
 
  162                                                       YElementwiseOperation,
 
  178                                                       SaveMeanInvStdDstVectorSize,
 
  181     if constexpr(UseWelford)
 
  188                                                   SaveMeanInvStdDataType,
 
  190                                                   YElementwiseOperation,
 
  198                                                   SaveMeanInvStdDataType,
 
  200                                                   YElementwiseOperation,
 
  211                                                   SaveMeanInvStdDataType,
 
  213                                                   YElementwiseOperation,
 
  221                                                   SaveMeanInvStdDataType,
 
  223                                                   YElementwiseOperation,
 
__global__ void kernel_normalization(const GridDesc_M_K x_grid_desc_m_k, const GridDesc_M_K gamma_grid_desc_m_k, const GridDesc_M_K beta_grid_desc_m_k, const GridDesc_M_K y_grid_desc_m_k, const GridDesc_M save_mean_grid_desc_m, const GridDesc_M save_inv_std_grid_desc_m, index_t num_k_block_tile_iteration, ComputeDataType epsilon, const XDataType *const __restrict__ p_x_global, const GammaDataType *const __restrict__ p_gamma_global, const BetaDataType *const __restrict__ p_beta_global, YDataType *const __restrict__ p_y_global, SaveMeanInvStdDataType *const __restrict__ p_save_mean_global, SaveMeanInvStdDataType *const __restrict__ p_save_inv_std_global, const YElementwiseOperation y_elementwise_op)
Definition: gridwise_normalization_selector.hpp:21
 
int32_t index_t
Definition: ck.hpp:299
 
auto NormalizationKernelSelector(bool isSweepOnce)
Definition: gridwise_normalization_selector.hpp:78
 
Definition: gridwise_normalization_naive_variance.hpp:42
 
Definition: gridwise_normalization_welford_variance.hpp:40