11 template <
typename Problem_,
typename Policy_ =
void>
17 static constexpr
bool kFastFDiv = Problem::kFastFDiv;
18 static constexpr
bool kWelford = Problem::kWelford;
26 template <
typename XDistributedTensor_,
27 typename MeanDistributedTensor_,
28 typename VarDistributedTensor_>
30 MeanDistributedTensor_& mean_tensor,
31 VarDistributedTensor_& var_tensor,
33 const int& max_count_)
38 constexpr
auto spans = XDistributedTensor_::get_distributed_spans();
41 if(cur_count_ < max_count_)
45 constexpr
auto in_dstr_idx =
make_tuple(dstr_idx_i0, dstr_idx_i1);
46 constexpr
auto out_dstr_idx =
make_tuple(dstr_idx_i0);
48 auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);
52 var_tensor(out_dstr_idx),
59 mean_tensor(out_dstr_idx) += x;
60 var_tensor(out_dstr_idx) += x * x;
67 template <
typename XDistributedTensor_>
70 static_assert(std::is_same_v<XDataType, typename XDistributedTensor_::DataType>,
"wrong!");
76 XDistributedTensor_::get_tile_distribution()
77 .get_static_tile_distribution_encoding(),
80 auto tensor = make_static_distributed_tensor<ComputeDataType>(dstr);
85 template <
typename XDistributedTensor_>
87 operator()(
const XDistributedTensor_& x_tensor,
int& cur_count_,
const int& max_count_)
89 auto mean_tensor = MakeMeanVarBlockTile<XDistributedTensor_>();
90 auto var_tensor = MakeMeanVarBlockTile<XDistributedTensor_>();
94 (*this)(x_tensor, mean_tensor, var_tensor, cur_count_, max_count_);
100 template <
typename Problem_,
typename Policy_ =
void>
105 static constexpr
bool kWelford = Problem::kWelford;
107 template <
typename MeanDistributedTensor_,
typename VarDistributedTensor_>
109 operator()(MeanDistributedTensor_& mean_tensor, VarDistributedTensor_& var_tensor,
int& count)
111 using Dstr =
typename MeanDistributedTensor_::StaticTileDistribution;
112 using DstrEncode =
typename Dstr::DstrEncode;
113 using DstrEncodeDetail =
typename DstrEncode::detail;
115 static_assert(std::is_same_v<Dstr, typename VarDistributedTensor_::StaticTileDistribution>,
118 constexpr
index_t NDimP = Dstr::get_num_of_dimension_p();
119 constexpr
index_t NDimR = Dstr::get_num_of_dimension_r();
121 constexpr
index_t idim_p_lane = NDimP - 1;
127 constexpr
index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
128 static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());
130 const int original_count = count;
134 auto v_local_mean = mean_tensor.get_thread_buffer()[i];
135 auto v_local_var = var_tensor.get_thread_buffer()[i];
136 auto v_local_count = original_count;
143 if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
145 constexpr
index_t r_length = DstrEncode::rs_lengths_[idim_r];
147 constexpr
index_t lid_over_rid_derivative =
148 DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];
151 "wrong! only support power of 2 reduction");
160 (
number<lid_over_rid_derivative << istage.
value>{}.value);
163 const auto v_remote_mean =
warp_shuffle(v_local_mean, src_lane);
164 const auto v_remote_var =
warp_shuffle(v_local_var, src_lane);
167 const auto v_remote_count =
warp_shuffle(v_local_count, src_lane);
170 welford_merge(v_local_mean,
180 v_local_mean += v_remote_mean;
181 v_local_var += v_remote_var;
187 mean_tensor.get_thread_buffer()(i) = v_local_mean;
188 var_tensor.get_thread_buffer()(i) = v_local_var;
191 count = v_local_count;
197 template <
typename Problem_,
typename Policy_ =
void>
202 static constexpr
bool kFastFDiv = Problem::kFastFDiv;
203 static constexpr
bool kWelford = Problem::kWelford;
204 using smem_dtype = std::conditional_t<kWelford, fp32x4_t, fp32x2_t>;
206 template <
typename MeanDistributedTensor_>
209 constexpr
index_t num_reduce_warps = [&]() {
210 using Dstr =
typename MeanDistributedTensor_::StaticTileDistribution;
211 using DstrEncode =
typename Dstr::DstrEncode;
212 using DstrEncodeDetail =
typename DstrEncode::detail;
214 constexpr
index_t NDimR = Dstr::get_num_of_dimension_r();
216 constexpr
index_t idim_p_warp = 0;
220 if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_warp][idim_r])
222 constexpr
index_t r_length = DstrEncode::rs_lengths_[idim_r];
228 return num_reduce_warps;
232 template <
typename MeanDistributedTensor_>
238 constexpr
index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
254 return num_warps * 4 * thread_buf_size *
sizeof(float);
257 template <
typename MeanDistributedTensor_,
typename VarDistributedTensor_>
259 VarDistributedTensor_& var_tensor,
263 using DataType =
typename MeanDistributedTensor_::DataType;
264 using Dstr =
typename MeanDistributedTensor_::StaticTileDistribution;
268 static_assert(std::is_same_v<Dstr, typename VarDistributedTensor_::StaticTileDistribution>,
271 constexpr
index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
272 static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());
276 const index_t lane_id = get_lane_id();
277 const index_t warp_id = get_warp_id();
278 constexpr
auto num_reduce_warps = GetReduceWarps<MeanDistributedTensor_>();
280 const index_t smem_offset = warp_id;
283 if constexpr(num_reduce_warps == 1)
291 local_scratch_[0] = bit_cast<float>(mean_tensor.get_thread_buffer()[i]);
292 local_scratch_[1] = bit_cast<float>(var_tensor.get_thread_buffer()[i]);
295 local_scratch_[2] = bit_cast<float>(count);
297 smem_ptr[smem_offset + i * num_warps] = local_scratch_;
303 index_t local_warp_id = warp_id / num_reduce_warps;
304 index_t local_smem_os = local_warp_id * num_reduce_warps;
305 smem_dtype all_scratch[thread_buf_size * num_reduce_warps];
308 all_scratch[i_0 * num_reduce_warps + i_1] =
309 smem_ptr[i_0 * num_warps + local_smem_os + i_1];
318 auto v_local = all_scratch[i_0 * num_reduce_warps];
319 auto v_local_mean = bit_cast<DataType>(v_local[0]);
320 auto v_local_var = bit_cast<DataType>(v_local[1]);
321 int v_local_count = kWelford ? bit_cast<int>(v_local[2]) : 0;
324 static_for<0, num_reduce_warps - 1, 1>{}([&](
auto i_1_n1) {
326 const smem_dtype v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
327 const auto v_remote_mean = bit_cast<DataType>(v_remote[0]);
328 const auto v_remote_var = bit_cast<DataType>(v_remote[1]);
331 const auto v_remote_count = bit_cast<int>(v_remote[2]);
333 welford_merge(v_local_mean,
343 v_local_mean += v_remote_mean;
344 v_local_var += v_remote_var;
348 mean_tensor.get_thread_buffer()(i_0) = v_local_mean;
349 var_tensor.get_thread_buffer()(i_0) = v_local_var;
351 count = v_local_count;
360 template <
typename BlockShape>
364 using S = BlockShape;
365 index_t LastloopN = row_size % S::Block_N == 0 ? S::Block_N : row_size % S::Block_N;
366 constexpr
index_t NThread = S::WarpPerBlock_N * S::ThreadPerWarp_N;
367 index_t iNLane = get_thread_id() % NThread;
368 index_t iN0 = LastloopN / (S::Vector_N * S::ThreadPerWarp_N);
369 index_t iN1 = (LastloopN % (S::Vector_N * S::ThreadPerWarp_N)) / S::Vector_N;
370 index_t N2 = (LastloopN % (S::Vector_N * S::ThreadPerWarp_N)) % S::Vector_N;
371 index_t iN3 = iNLane < iN1 ? S::Vector_N : iNLane == iN1 ? N2 : 0;
372 return iN0 * S::Vector_N + iN3;
374 using S_ = BlockShape;
375 constexpr
index_t ThreadsPerBlock_N = S_::WarpPerBlock_N * S_::ThreadPerWarp_N;
378 const index_t element_per_row = row_size / S_::Vector_N;
379 index_t lane_id_n = get_thread_id() % ThreadsPerBlock_N;
384 index_t _a = lane_id_n < element_per_row ? 1 : 0;
386 lane_id_n += ThreadsPerBlock_N;
388 return cnt * S_::Vector_N;
392 template <
typename VarDistributedTensor_,
bool FastFdiv_ = false>
397 using DataType =
typename VarDistributedTensor_::DataType;
400 if(FastFdiv_ && std::is_same_v<DataType, float>)
402 x = x * __builtin_amdgcn_rcpf(type_convert<DataType>(count));
406 x = x / type_convert<DataType>(count);
CK_TILE_DEVICE void block_sync_lds()
Definition: arch.hpp:190
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
constexpr CK_TILE_HOST_DEVICE auto make_reduce_tile_distribution_encoding(InDstr, sequence< InReduceDimXs... > reduce_dim_xs_in)
Definition: tile_distribution_encoding.hpp:762
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE bool is_power_of_two_integer(int32_t x)
Definition: math.hpp:462
CK_TILE_DEVICE T warp_shuffle(const T &v_local, uint32_t src_lane)
Definition: utility.hpp:78
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition: tile_elementwise.hpp:23
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_DEVICE void block_tile_welford_post_scale_var(VarDistributedTensor_ &var_tensor, int count, bool_constant< FastFdiv_ >={})
Definition: block_norm_reduce.hpp:393
constexpr CK_TILE_HOST_DEVICE int32_t integer_log2_floor(int32_t x)
Definition: math.hpp:455
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition: sweep_tile.hpp:20
CK_TILE_DEVICE void welford_update(T &mean, T &var, T x, int count, bool_constant< kFastFDiv >={})
Definition: thread_welford.hpp:11
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition: tile_elementwise.hpp:177
constexpr CK_TILE_DEVICE index_t block_tile_welford_calculate_max_count(int row_size)
Definition: block_norm_reduce.hpp:361
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:480
__host__ constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:42
Definition: block_norm_reduce.hpp:199
typename Problem::BlockShape BlockShape
Definition: block_norm_reduce.hpp:201
remove_cvref_t< Problem_ > Problem
Definition: block_norm_reduce.hpp:200
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: block_norm_reduce.hpp:233
std::conditional_t< kWelford, fp32x4_t, fp32x2_t > smem_dtype
Definition: block_norm_reduce.hpp:204
CK_TILE_DEVICE void operator()(MeanDistributedTensor_ &mean_tensor, VarDistributedTensor_ &var_tensor, int &count, void *smem)
Definition: block_norm_reduce.hpp:258
static constexpr CK_TILE_DEVICE index_t GetReduceWarps()
Definition: block_norm_reduce.hpp:207
Definition: block_norm_reduce.hpp:13
constexpr CK_TILE_DEVICE BlockNormReduce()
Definition: block_norm_reduce.hpp:20
typename Problem::ComputeDataType ComputeDataType
Definition: block_norm_reduce.hpp:16
static CK_TILE_DEVICE auto MakeMeanVarBlockTile()
Definition: block_norm_reduce.hpp:68
remove_cvref_t< Problem_ > Problem
Definition: block_norm_reduce.hpp:14
static constexpr bool kWelford
Definition: block_norm_reduce.hpp:18
CK_TILE_DEVICE void operator()(const XDistributedTensor_ &x_tensor, MeanDistributedTensor_ &mean_tensor, VarDistributedTensor_ &var_tensor, int &cur_count_, const int &max_count_)
Definition: block_norm_reduce.hpp:29
CK_TILE_DEVICE auto operator()(const XDistributedTensor_ &x_tensor, int &cur_count_, const int &max_count_)
Definition: block_norm_reduce.hpp:87
typename Problem::XDataType XDataType
Definition: block_norm_reduce.hpp:15
static constexpr bool kFastFDiv
Definition: block_norm_reduce.hpp:17
Definition: block_norm_reduce.hpp:102
CK_TILE_DEVICE void operator()(MeanDistributedTensor_ &mean_tensor, VarDistributedTensor_ &var_tensor, int &count)
Definition: block_norm_reduce.hpp:109
static constexpr bool kWelford
Definition: block_norm_reduce.hpp:105
remove_cvref_t< Problem_ > Problem
Definition: block_norm_reduce.hpp:103
static constexpr bool kFastFDiv
Definition: block_norm_reduce.hpp:104
Definition: integral_constant.hpp:13
static constexpr value_type value
Definition: integral_constant.hpp:16
Definition: sequence.hpp:49
Definition: functional.hpp:43