10 template <
typename GridwiseReduction,
12 typename GammaDataType,
13 typename BetaDataType,
15 typename SaveMeanInvStdDataType,
16 typename ComputeDataType,
17 typename YElementwiseOperation,
18 typename GridDesc_M_K,
22 const GridDesc_M_K gamma_grid_desc_m_k,
23 const GridDesc_M_K beta_grid_desc_m_k,
24 const GridDesc_M_K y_grid_desc_m_k,
25 const GridDesc_M save_mean_grid_desc_m,
26 const GridDesc_M save_inv_std_grid_desc_m,
27 index_t num_k_block_tile_iteration,
28 ComputeDataType epsilon,
29 const XDataType*
const __restrict__ p_x_global,
30 const GammaDataType*
const __restrict__ p_gamma_global,
31 const BetaDataType*
const __restrict__ p_beta_global,
32 YDataType*
const __restrict__ p_y_global,
33 SaveMeanInvStdDataType*
const __restrict__ p_save_mean_global,
34 SaveMeanInvStdDataType*
const __restrict__ p_save_inv_std_global,
35 const YElementwiseOperation y_elementwise_op)
37 GridwiseReduction::Run(x_grid_desc_m_k,
41 save_mean_grid_desc_m,
42 save_inv_std_grid_desc_m,
43 num_k_block_tile_iteration,
50 p_save_inv_std_global,
54 template <
typename XDataType,
55 typename GammaDataType,
56 typename BetaDataType,
58 typename SaveMeanInvStdDataType,
59 typename ComputeDataType,
60 typename YElementwiseOperation,
61 typename GridDesc_M_K,
76 index_t SaveMeanInvStdDstVectorSize,
80 using GridwiseNormalizationGenericNaive =
85 SaveMeanInvStdDataType,
87 YElementwiseOperation,
103 SaveMeanInvStdDstVectorSize,
105 using GridwiseNormalizationSweepOnceNaive =
110 SaveMeanInvStdDataType,
112 YElementwiseOperation,
128 SaveMeanInvStdDstVectorSize,
130 using GridwiseNormalizationGenericWelford =
135 SaveMeanInvStdDataType,
137 YElementwiseOperation,
153 SaveMeanInvStdDstVectorSize,
155 using GridwiseNormalizationSweepOnceWelford =
160 SaveMeanInvStdDataType,
162 YElementwiseOperation,
178 SaveMeanInvStdDstVectorSize,
181 if constexpr(UseWelford)
188 SaveMeanInvStdDataType,
190 YElementwiseOperation,
198 SaveMeanInvStdDataType,
200 YElementwiseOperation,
211 SaveMeanInvStdDataType,
213 YElementwiseOperation,
221 SaveMeanInvStdDataType,
223 YElementwiseOperation,
__global__ void kernel_normalization(const GridDesc_M_K x_grid_desc_m_k, const GridDesc_M_K gamma_grid_desc_m_k, const GridDesc_M_K beta_grid_desc_m_k, const GridDesc_M_K y_grid_desc_m_k, const GridDesc_M save_mean_grid_desc_m, const GridDesc_M save_inv_std_grid_desc_m, index_t num_k_block_tile_iteration, ComputeDataType epsilon, const XDataType *const __restrict__ p_x_global, const GammaDataType *const __restrict__ p_gamma_global, const BetaDataType *const __restrict__ p_beta_global, YDataType *const __restrict__ p_y_global, SaveMeanInvStdDataType *const __restrict__ p_save_mean_global, SaveMeanInvStdDataType *const __restrict__ p_save_inv_std_global, const YElementwiseOperation y_elementwise_op)
Definition: gridwise_normalization_selector.hpp:21
int32_t index_t
Definition: ck.hpp:298
auto NormalizationKernelSelector(bool isSweepOnce)
Definition: gridwise_normalization_selector.hpp:78
Definition: gridwise_normalization_naive_variance.hpp:42
Definition: gridwise_normalization_welford_variance.hpp:40