10 template <
typename Problem_,
typename Policy_ =
void>
20 template <
typename XDistributedTensor_,
21 typename YDistributedTensor_,
25 YDistributedTensor_& y_tensor,
26 const ReduceFunc& reduce_func,
27 ReducePacksPerXDim = {})
29 sweep_tile<XDistributedTensor_>(
32 y_tensor(idx_0) = reduce_func(
33 y_tensor(idx_0), ck_tile::type_convert<ComputeDataType>(x_tensor[idx_])...);
35 ReducePacksPerXDim{});
37 constexpr
auto I0 = number<0>{};
38 constexpr
auto I1 = number<1>{};
39 constexpr
auto spans = XDistributedTensor_::get_distributed_spans();
43 constexpr
auto y_dstr_idx =
make_tuple(dstr_idx_i0);
45 auto y = y_tensor[y_dstr_idx];
48 constexpr
auto in_dstr_idx =
make_tuple(dstr_idx_i0, dstr_idx_i1);
49 const auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);
51 y = reduce_func(y, x);
54 y_tensor(y_dstr_idx) = y;
59 template <
typename XDistributedTensor_>
62 static_assert(std::is_same_v<XDataType, typename XDistributedTensor_::DataType>,
"wrong!");
69 XDistributedTensor_::get_tile_distribution()
70 .get_static_tile_distribution_encoding(),
73 auto tensor = make_static_distributed_tensor<ComputeDataType>(dstr);
78 template <
typename XDistributedTensor_,
83 const ReduceFunc& reduce_func,
84 ReducePacksPerXDim = {})
86 auto y_tensor = MakeYBlockTile<XDistributedTensor_>();
88 (*this)(x_tensor, y_tensor, reduce_func, ReducePacksPerXDim{});
94 template <
typename Problem_,
typename Policy_ =
void>
99 template <
typename YDistributedTensor_,
typename ReduceFunc>
102 using Dstr =
typename YDistributedTensor_::StaticTileDistribution;
103 using DstrEncode =
typename Dstr::DstrEncode;
104 using DstrEncodeDetail =
typename DstrEncode::detail;
106 constexpr
index_t NDimP = Dstr::get_num_of_dimension_p();
107 constexpr
index_t NDimR = Dstr::get_num_of_dimension_r();
109 constexpr
index_t idim_p_lane = NDimP - 1;
115 constexpr
index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
119 auto v_local = y_tensor.get_thread_buffer()[i];
126 if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
128 constexpr
index_t r_length = DstrEncode::rs_lengths_[idim_r];
130 constexpr
index_t lid_over_rid_derivative =
131 DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];
134 "wrong! only support power of 2 reduction");
143 (
number<lid_over_rid_derivative << istage.
value>{}.value);
149 v_local = reduce_func(v_local, v_remote);
155 y_tensor.get_thread_buffer()(i) = v_local;
160 template <
typename Problem_,
typename Policy_ =
void>
166 template <
typename YDistributedTensor_>
169 constexpr
index_t num_reduce_warps = [&]() {
170 using Dstr =
typename YDistributedTensor_::StaticTileDistribution;
171 using DstrEncode =
typename Dstr::DstrEncode;
172 using DstrEncodeDetail =
typename DstrEncode::detail;
174 constexpr
index_t NDimR = Dstr::get_num_of_dimension_r();
176 constexpr
index_t idim_p_warp = 0;
180 if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_warp][idim_r])
182 constexpr
index_t r_length = DstrEncode::rs_lengths_[idim_r];
188 return num_reduce_warps;
192 template <
typename YDistributedTensor_>
195 using DataType =
typename YDistributedTensor_::DataType;
198 constexpr
index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
213 constexpr
index_t num_warps = BlockShape::BlockSize / warpSize;
214 return num_warps * thread_buf_size *
sizeof(DataType);
217 template <
typename YDistributedTensor_,
typename ReduceFunc>
219 operator()(YDistributedTensor_& y_tensor,
void* smem,
const ReduceFunc& reduce_func)
221 using DataType =
typename YDistributedTensor_::DataType;
223 constexpr
index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
225 DataType* smem_ptr =
reinterpret_cast<DataType*
>(smem);
228 constexpr
auto num_reduce_warps = GetReduceWarps<YDistributedTensor_>();
229 constexpr
index_t num_warps = BlockShape::BlockSize / warpSize;
230 const index_t smem_offset = warp_id;
233 if constexpr(num_reduce_warps == 1)
240 smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i];
246 index_t local_warp_id = warp_id / num_reduce_warps;
247 index_t local_smem_os = local_warp_id * num_reduce_warps;
248 DataType all_scratch[thread_buf_size * num_reduce_warps];
251 all_scratch[i_0 * num_reduce_warps + i_1] =
252 smem_ptr[i_0 * num_warps + local_smem_os + i_1];
259 auto v_local = all_scratch[i_0 * num_reduce_warps];
262 static_for<0, num_reduce_warps - 1, 1>{}([&](
auto i_1_n1) {
264 const DataType v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
267 v_local = reduce_func(v_local, v_remote);
270 y_tensor.get_thread_buffer()(i_0) = v_local;
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
constexpr CK_TILE_HOST_DEVICE auto make_reduce_tile_distribution_encoding(InDstr, sequence< InReduceDimXs... > reduce_dim_xs_in)
Definition: tile_distribution_encoding.hpp:725
Definition: cluster_descriptor.hpp:13
CK_TILE_DEVICE index_t get_lane_id()
Definition: arch.hpp:69
CK_TILE_DEVICE void set_tile(DstrTensors &dstr_tensor, const T &value)
Definition: tile_elementwise.hpp:63
constexpr CK_TILE_HOST_DEVICE bool is_power_of_two_integer(int32_t x)
Definition: math.hpp:462
CK_TILE_DEVICE T warp_shuffle(const T &v_local, uint32_t src_lane)
Definition: utility.hpp:63
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
constexpr CK_TILE_HOST_DEVICE int32_t integer_log2_floor(int32_t x)
Definition: math.hpp:455
CK_TILE_DEVICE index_t get_warp_id()
Definition: arch.hpp:71
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition: sweep_tile.hpp:20
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:498
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition: sequence.hpp:1017
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
Definition: block_reduce2d.hpp:162
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: block_reduce2d.hpp:193
remove_cvref_t< Problem_ > Problem
Definition: block_reduce2d.hpp:163
CK_TILE_DEVICE void operator()(YDistributedTensor_ &y_tensor, void *smem, const ReduceFunc &reduce_func)
Definition: block_reduce2d.hpp:219
typename Problem::BlockShape BlockShape
Definition: block_reduce2d.hpp:164
static constexpr CK_TILE_DEVICE index_t GetReduceWarps()
Definition: block_reduce2d.hpp:167
Definition: block_reduce2d.hpp:12
constexpr CK_TILE_DEVICE BlockReduce2d()
Definition: block_reduce2d.hpp:18
typename Problem::ComputeDataType ComputeDataType
Definition: block_reduce2d.hpp:16
static CK_TILE_DEVICE auto MakeYBlockTile()
Definition: block_reduce2d.hpp:60
CK_TILE_DEVICE void operator()(const XDistributedTensor_ &x_tensor, YDistributedTensor_ &y_tensor, const ReduceFunc &reduce_func, ReducePacksPerXDim={})
Definition: block_reduce2d.hpp:24
remove_cvref_t< Problem_ > Problem
Definition: block_reduce2d.hpp:14
CK_TILE_DEVICE auto operator()(const XDistributedTensor_ &x_tensor, const ComputeDataType &reduce_init, const ReduceFunc &reduce_func, ReducePacksPerXDim={})
Definition: block_reduce2d.hpp:81
typename Problem::XDataType XDataType
Definition: block_reduce2d.hpp:15
Definition: block_reduce2d.hpp:96
CK_TILE_DEVICE void operator()(YDistributedTensor_ &y_tensor, const ReduceFunc &reduce_func)
Definition: block_reduce2d.hpp:100
remove_cvref_t< Problem_ > Problem
Definition: block_reduce2d.hpp:97
Definition: integral_constant.hpp:13
static constexpr value_type value
Definition: integral_constant.hpp:16
Definition: sequence.hpp:52
Definition: functional.hpp:43