/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp File Reference
device_normalization_fwd_splitk_impl.hpp File Reference
#include <iostream>
#include <sstream>
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization_fwd.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp"
#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
Go to the source code of this file.
Namespaces
- ck
- ck::tensor_operation
- ck::tensor_operation::device
Functions
- template<typename GridwiseWelford, typename XDataType, typename WorkspaceMeanVarDataType, typename ComputeDataType, typename XGridDesc_M_K, typename MeanVarGridDesc_M_KBlock>
  __global__ void ck::kernel_normalizationSplitK1st(const XGridDesc_M_K x_grid_desc_m_k, const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock, index_t num_k_block_tile_iteration, const XDataType *const __restrict__ p_x_global, WorkspaceMeanVarDataType *const __restrict__ p_welford_mean, WorkspaceMeanVarDataType *const __restrict__ p_welford_variance, int32_t *const __restrict__ p_welford_count)
- template<typename GridwiseWelfordNormalization, typename WorkspaceMeanVarDataType, typename XDataType, typename GammaDataType, typename BetaDataType, typename YDataType, typename SaveMeanInvStdDataType, typename ComputeDataType, typename YElementwiseOperation, typename MeanVarGridDesc_M_KBlock, typename CountGridDesc_M_KBlock, typename XYGammaBetaGridDesc_M_K, typename SaveMeanInvStdGridDesc_M>
  __global__ void ck::kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock, const CountGridDesc_M_KBlock count_grid_desc_m_kblock, const XYGammaBetaGridDesc_M_K x_grid_desc_m_k, const XYGammaBetaGridDesc_M_K gamma_grid_desc_m_k, const XYGammaBetaGridDesc_M_K beta_grid_desc_m_k, const XYGammaBetaGridDesc_M_K y_grid_desc_m_k, const SaveMeanInvStdGridDesc_M save_mean_grid_desc_m, const SaveMeanInvStdGridDesc_M save_inv_std_grid_desc_m, index_t num_k_mean_var_count_iteration, index_t num_k_block_tile_iteration, index_t k_grid_size, ComputeDataType epsilon, const WorkspaceMeanVarDataType *const p_mean_global, const WorkspaceMeanVarDataType *const p_variance_global, const int32_t *const p_welford_count_global, const XDataType *const __restrict__ p_x_global, const GammaDataType *const __restrict__ p_gamma_global, const BetaDataType *const __restrict__ p_beta_global, YDataType *const __restrict__ p_y_global, SaveMeanInvStdDataType *const __restrict__ p_save_mean_global, SaveMeanInvStdDataType *const __restrict__ p_save_inv_std_global, const YElementwiseOperation y_elementwise_op)