template <typename Problem_, typename Policy_ = void>
struct BlockReduce2d
{
    using Problem = remove_cvref_t<Problem_>;

    using XDataType       = typename Problem::XDataType;
    using ComputeDataType = typename Problem::ComputeDataType;

    CK_TILE_DEVICE constexpr BlockReduce2d() {}

    // Note: this overload does not clear y_tensor. For min/max-style reductions
    // the caller must pre-fill it with the reduction identity (see the
    // reduce_init overload below).
    template <typename XDistributedTensor_,
              typename YDistributedTensor_,
              typename ReduceFunc,
              typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
    CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
                                   YDistributedTensor_& y_tensor,
                                   const ReduceFunc& reduce_func,
                                   ReducePacksPerXDim = {})
    {
#if 1
        sweep_tile<XDistributedTensor_>(
            [&](auto... idx_) {
                // fold the packed x elements into the y element addressed by
                // the first (kept) dimension
                constexpr auto idx_0 = make_tuple(idx_[number<0>{}]...);
                y_tensor(idx_0)      = reduce_func(
                    y_tensor(idx_0), ck_tile::type_convert<ComputeDataType>(x_tensor[idx_])...);
            },
            ReducePacksPerXDim{});
#else
        // equivalent reference path: explicit two-level sweep over the
        // distributed spans
        constexpr auto I0 = number<0>{};
        constexpr auto I1 = number<1>{};

        constexpr auto spans = XDistributedTensor_::get_distributed_spans();

        sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) {
            constexpr auto y_dstr_idx = make_tuple(dstr_idx_i0);

            auto y = y_tensor[y_dstr_idx];

            sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) {
                constexpr auto in_dstr_idx = make_tuple(dstr_idx_i0, dstr_idx_i1);

                const auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);

                y = reduce_func(y, x);
            });

            y_tensor(y_dstr_idx) = y;
        });
#endif
    }
    template <typename XDistributedTensor_>
    CK_TILE_DEVICE static auto MakeYBlockTile()
    {
        static_assert(std::is_same_v<XDataType, typename XDistributedTensor_::DataType>,
                      "wrong!");

        // y keeps dimension 0 of x; dimension 1 is folded into a replication
        // (R) dimension of the distribution
        constexpr auto dstr =
            make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
                XDistributedTensor_::get_tile_distribution()
                    .get_static_tile_distribution_encoding(),
                sequence<1>{}));

        auto tensor = make_static_distributed_tensor<ComputeDataType>(dstr);

        return tensor;
    }
    template <typename XDistributedTensor_,
              typename ReduceFunc,
              typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
    CK_TILE_DEVICE auto operator()(const XDistributedTensor_& x_tensor,
                                   const ComputeDataType& reduce_init,
                                   const ReduceFunc& reduce_func,
                                   ReducePacksPerXDim = {})
    {
        auto y_tensor = MakeYBlockTile<XDistributedTensor_>();
        set_tile(y_tensor, reduce_init);

        (*this)(x_tensor, y_tensor, reduce_func, ReducePacksPerXDim{});

        return y_tensor;
    }
};
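// ---------------------------------------------------------------------------
// Usage sketch (not part of the original header). Assuming a Problem type and
// a distributed input tile x_tile produced elsewhere in a kernel, a
// block-level row reduction composes the functors in this file roughly as
// follows; the lambda is the reduce operator, here a sum:
//
//   BlockReduce2d<Problem>              block_reduce2d{};
//   BlockReduce2dSync<Problem>          block_reduce2d_sync{};
//   BlockReduce2dCrossWarpSync<Problem> block_reduce2d_cross_warp_sync{};
//
//   const auto f_sum = [](auto a, auto b) { return a + b; };
//
//   // per-thread partials over the tile, then intra-warp, then cross-warp
//   auto y_tile = block_reduce2d(x_tile, type_convert<ComputeDataType>(0), f_sum);
//   block_reduce2d_sync(y_tile, f_sum);
//   block_reduce2d_cross_warp_sync(y_tile, smem, f_sum);
// ---------------------------------------------------------------------------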
template <typename Problem_, typename Policy_ = void>
struct BlockReduce2dSync
{
    using Problem = remove_cvref_t<Problem_>;

    template <typename YDistributedTensor_, typename ReduceFunc>
    CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor, const ReduceFunc& reduce_func)
    {
        using Dstr             = typename YDistributedTensor_::StaticTileDistribution;
        using DstrEncode       = typename Dstr::DstrEncode;
        using DstrEncodeDetail = typename DstrEncode::detail;

        constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
        constexpr index_t NDimR = Dstr::get_num_of_dimension_r();

        // the last P dimension maps to the lane id
        constexpr index_t idim_p_lane = NDimP - 1;

        constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();

        static_for<0, thread_buf_size, 1>{}([&](auto i) {
            auto v_local = y_tensor.get_thread_buffer()[i];

            // cross-lane reduce over every R dimension owned by the lane
            // P dimension
            static_for<0, NDimR, 1>{}([&](auto idim_r) {
                if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
                {
                    constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];

                    constexpr index_t lid_over_rid_derivative =
                        DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];

                    static_assert(is_power_of_two_integer(r_length),
                                  "wrong! only support power of 2 reduction");

                    constexpr index_t nstage = integer_log2_floor(r_length);

                    // butterfly reduction: after log2(r_length) xor stages
                    // every participating lane holds the fully reduced value
                    static_for<0, nstage, 1>{}([&](auto istage) {
                        const index_t src_lane =
                            get_lane_id() ^
                            (number<lid_over_rid_derivative << istage.value>{}.value);

                        // pull data from the xor-partner lane, then reduce
                        const auto v_remote = warp_shuffle(v_local, src_lane);

                        v_local = reduce_func(v_local, v_remote);
                    });
                }
            });

            y_tensor.get_thread_buffer()(i) = v_local;
        });
    }
};
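// ---------------------------------------------------------------------------
// Illustration (not part of the original header): the xor pattern above is a
// standard butterfly reduction. A minimal host-side model, assuming r_length
// = 8 lanes, lid_over_rid_derivative == 1, and reduce_func == plus:
//
//   #include <array>
//   #include <cstdio>
//
//   int main()
//   {
//       std::array<int, 8> v{3, 1, 4, 1, 5, 9, 2, 6}; // one value per "lane"
//       for(int istage = 0; istage < 3; ++istage)     // nstage = log2(8) = 3
//       {
//           std::array<int, 8> next = v;
//           for(int lane = 0; lane < 8; ++lane)
//               next[lane] = v[lane] + v[lane ^ (1 << istage)]; // xor partner
//           v = next;
//       }
//       for(int lane = 0; lane < 8; ++lane)
//           std::printf("%d ", v[lane]); // every lane prints the full sum, 31
//   }
//
// Each stage pairs lanes whose ids differ in exactly one bit, which is why
// r_length must be a power of two (the static_assert above).
// ---------------------------------------------------------------------------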
template <typename Problem_, typename Policy_ = void>
struct BlockReduce2dCrossWarpSync
{
    using Problem    = remove_cvref_t<Problem_>;
    using BlockShape = typename Problem::BlockShape;

    // number of warps that hold partials of the same y element and therefore
    // have to be combined
    template <typename YDistributedTensor_>
    CK_TILE_DEVICE static constexpr index_t GetReduceWarps()
    {
        constexpr index_t num_reduce_warps = [&]() {
            using Dstr             = typename YDistributedTensor_::StaticTileDistribution;
            using DstrEncode       = typename Dstr::DstrEncode;
            using DstrEncodeDetail = typename DstrEncode::detail;

            constexpr index_t NDimR = Dstr::get_num_of_dimension_r();

            // the first P dimension maps to the warp id
            constexpr index_t idim_p_warp = 0;

            index_t len = 1;

            static_for<0, NDimR, 1>{}([&](auto idim_r) {
                if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_warp][idim_r])
                {
                    constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];

                    len *= r_length;
                }
            });

            return len;
        }();

        return num_reduce_warps;
    }
    // returns size in bytes
    template <typename YDistributedTensor_>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        using DataType = typename YDistributedTensor_::DataType;

        constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();

        constexpr index_t num_warps = BlockShape::BlockSize / warpSize;

        // one DataType slot per warp for every element of the thread buffer
        return num_warps * thread_buf_size * sizeof(DataType);
    }
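// ---------------------------------------------------------------------------
// Worked example (not part of the original header): with a hypothetical
// BlockShape of BlockSize = 256 on a wave64 device, num_warps = 256 / 64 = 4;
// for thread_buf_size = 2 and DataType = float this reserves
// 4 * 2 * sizeof(float) = 32 bytes of LDS, laid out below as
// smem[i * num_warps + warp_id].
// ---------------------------------------------------------------------------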
    template <typename YDistributedTensor_, typename ReduceFunc>
    CK_TILE_DEVICE void
    operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
    {
        using DataType = typename YDistributedTensor_::DataType;

        constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();

        DataType* smem_ptr    = reinterpret_cast<DataType*>(smem);
        const index_t lane_id = get_lane_id();
        const index_t warp_id = get_warp_id();

        constexpr auto num_reduce_warps = GetReduceWarps<YDistributedTensor_>();
        constexpr index_t num_warps     = BlockShape::BlockSize / warpSize;

        const index_t smem_offset = warp_id;

        // nothing to do if no other warp holds a partial of the same y element
        if constexpr(num_reduce_warps == 1)
            return;

        // lane 0 of every warp publishes its partials to LDS
        if(lane_id == 0)
        {
            static_for<0, thread_buf_size, 1>{}([&](auto i) {
                smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i];
            });
        }

        block_sync_lds();

        // gather the partials of all warps in this thread's reduce group
        index_t local_warp_id = warp_id / num_reduce_warps;
        index_t local_smem_os = local_warp_id * num_reduce_warps;

        DataType all_scratch[thread_buf_size * num_reduce_warps];

        static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
            static_for<0, num_reduce_warps, 1>{}([&](auto i_1) {
                all_scratch[i_0 * num_reduce_warps + i_1] =
                    smem_ptr[i_0 * num_warps + local_smem_os + i_1];
            });
        });

        // serially fold the other warps' partials into the local value
        static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
            auto v_local = all_scratch[i_0 * num_reduce_warps];

            static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) {
                constexpr auto i_1 = i_1_n1 + number<1>{};

                const DataType v_remote = all_scratch[i_0 * num_reduce_warps + i_1];

                v_local = reduce_func(v_local, v_remote);
            });

            y_tensor.get_thread_buffer()(i_0) = v_local;
        });
    }
};
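// ---------------------------------------------------------------------------
// Illustration (not part of the original header): for num_warps = 4 where all
// four warps reduce into the same y elements (num_reduce_warps = 4) and
// thread_buf_size = 2, the LDS staging buffer written above looks like
//
//   index:    0     1     2     3     4     5     6     7
//   value:  w0e0  w1e0  w2e0  w3e0  w0e1  w1e1  w2e1  w3e1
//
// (wW eE = warp W's partial for thread-buffer element E). Every thread then
// reads back its group's four partials per element and folds them serially.
// ---------------------------------------------------------------------------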
// A second cross-warp synchronization variant; its struct name did not
// survive in this listing, so the name below is reconstructed and
// hypothetical. It stages partials through LDS the same way as above, but
// finishes the combine with lane shuffles instead of a serial loop.
template <typename Problem_, typename Policy_ = void>
struct BlockReduce2dCrossWarpSyncShuffle
{
    using Problem    = remove_cvref_t<Problem_>;
    using BlockShape = typename Problem::BlockShape;

    template <typename YDistributedTensor_>
    CK_TILE_DEVICE static constexpr index_t GetReduceWarps()
    {
        constexpr index_t num_reduce_warps = [&]() {
            using Dstr             = typename YDistributedTensor_::StaticTileDistribution;
            using DstrEncode       = typename Dstr::DstrEncode;
            using DstrEncodeDetail = typename DstrEncode::detail;

            constexpr index_t NDimR = Dstr::get_num_of_dimension_r();

            // the first P dimension maps to the warp id
            constexpr index_t idim_p_warp = 0;

            index_t len = 1;

            static_for<0, NDimR, 1>{}([&](auto idim_r) {
                if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_warp][idim_r])
                {
                    constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];

                    len *= r_length;
                }
            });

            return len;
        }();

        return num_reduce_warps;
    }
    // returns size in bytes
    template <typename YDistributedTensor_>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        using DataType = typename YDistributedTensor_::DataType;

        constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();

        constexpr index_t num_warps = BlockShape::BlockSize / warpSize;

        return num_warps * thread_buf_size * sizeof(DataType);
    }
    template <typename YDistributedTensor_, typename ReduceFunc>
    CK_TILE_DEVICE void
    operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
    {
        using Dstr             = typename YDistributedTensor_::StaticTileDistribution;
        using DstrEncode       = typename Dstr::DstrEncode;
        using DstrEncodeDetail = typename DstrEncode::detail;
        using DataType         = typename YDistributedTensor_::DataType;

        constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
        constexpr index_t NDimR = Dstr::get_num_of_dimension_r();

        constexpr index_t idim_p_lane     = NDimP - 1;
        constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();

        DataType* smem_ptr    = reinterpret_cast<DataType*>(smem);
        const index_t lane_id = get_lane_id();
        const index_t warp_id = get_warp_id();

        constexpr index_t num_reduce_warps = GetReduceWarps<YDistributedTensor_>();
        constexpr index_t num_warps        = BlockShape::BlockSize / warpSize;

        // nothing to do if no other warp holds a partial of the same y element
        if constexpr(num_reduce_warps == 1)
            return;

        const index_t smem_offset = warp_id;

        // lane 0 of every warp publishes its partials to LDS
        if(lane_id == 0)
        {
            static_for<0, thread_buf_size, 1>{}([&](auto i) {
                smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i];
            });
        }

        block_sync_lds();

        static_for<0, thread_buf_size, 1>{}([&](auto i) {
            DataType v;

            // the first num_reduce_warps lanes each pick up one warp's partial
            if(lane_id < num_reduce_warps)
            {
                v = smem_ptr[lane_id + i * num_warps];
            }

            // butterfly (xor) shuffle across those lanes combines the
            // per-warp partials
            static_for<0, NDimR, 1>{}([&](auto idim_r) {
                if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
                {
                    constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];

                    constexpr index_t lid_over_rid_derivative =
                        DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];

                    static_assert(is_power_of_two_integer(r_length),
                                  "wrong! only support power of 2 reduction");

                    constexpr index_t nstage = integer_log2_floor(r_length);

                    static_for<0, nstage, 1>{}([&](auto istage) {
                        const DataType o =
                            __shfl_xor(v,
                                       number<lid_over_rid_derivative << istage.value>{}.value);

                        v = reduce_func(v, o);
                    });
                }
            });

            // ...

            y_tensor.get_thread_buffer()(i) = v;
        });
    }
};
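// ---------------------------------------------------------------------------
// Kernel-side composition sketch (not part of the original header): the
// caller owns the LDS scratch and passes it in, e.g.
//
//   __shared__ char smem[BlockReduce2dCrossWarpSync<Problem>::GetSmemSize<
//       decltype(y_tile)>()];
//   block_reduce2d_cross_warp_sync(y_tile, smem, f_sum);
//
// GetSmemSize() is constexpr precisely so it can size a static __shared__
// buffer (or be folded into a kernel's overall LDS budget) at compile time.
// ---------------------------------------------------------------------------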