template <typename Problem_, typename Policy_ = void>
struct BlockReduce2d
{
    using Problem         = remove_cvref_t<Problem_>;
    using XDataType       = typename Problem::XDataType;
    using ComputeDataType = typename Problem::ComputeDataType;

    CK_TILE_DEVICE constexpr BlockReduce2d() {}

    template <bool kProcessIndex,
              typename XDistributedTensor_,
              typename YDistributedTensor_,
              typename YIndexDistributedTensor_,
              typename ReduceFunc,
              typename IndexCalculatorFunc,
              typename ReducePacksPerXDim>
    CK_TILE_DEVICE void reduce_impl(const XDistributedTensor_& x_tensor,
                                    YDistributedTensor_& y_tensor,
                                    YIndexDistributedTensor_& y_index_tensor,
                                    const ReduceFunc& reduce_func,
                                    const IndexCalculatorFunc& index_calculator,
                                    ReducePacksPerXDim = {})
    {
        sweep_tile<XDistributedTensor_>(
            [&](auto idx) {
                // (the lines deriving idx_0, the kept y-dimension component of
                // idx, were lost in extraction)
                auto val = ck_tile::type_convert<ComputeDataType>(x_tensor[idx]);

                if constexpr(kProcessIndex)
                {
                    // map the distributed index back to tile x-coordinates, then
                    // let the caller turn those into a user-visible index
                    const auto x_indices = get_x_indices_from_distributed_indices(
                        XDistributedTensor_::get_tile_distribution(), idx);
                    const auto new_idx = index_calculator(x_indices);
                    auto current_idx   = y_index_tensor(idx_0);

                    AccumulateWithIndex{}(
                        reduce_func, y_tensor(idx_0), current_idx, val, new_idx);

                    y_index_tensor(idx_0) =
                        type_convert<typename YIndexDistributedTensor_::DataType>(current_idx);
                }
                else
                {
                    Accumulate{}(reduce_func, y_tensor(idx_0), val);
                }
            },
            ReducePacksPerXDim{});
    }
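    // For orientation, a minimal sketch of what the two accumulator helpers are
    // assumed to do, based on their documented behavior ("deterministic first
    // occurring index"); illustrative only, not the definitions from
    // reduce_operator_accumulate.hpp:
    //
    //   struct Accumulate
    //   {
    //       template <typename F, typename T>
    //       CK_TILE_DEVICE void operator()(const F& f, T& acc, const T& v) const
    //       {
    //           acc = f(acc, v);
    //       }
    //   };
    //
    //   struct AccumulateWithIndex
    //   {
    //       template <typename F, typename T, typename I>
    //       CK_TILE_DEVICE void
    //       operator()(const F& f, T& acc, I& acc_idx, const T& v, const I& v_idx) const
    //       {
    //           bool changed = false;
    //           acc = f(acc, v, changed); // f reports whether v displaced acc
    //           if(changed)
    //               acc_idx = v_idx;      // ties keep the first-seen index
    //       }
    //   };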
    template <typename XDistributedTensor_,
              typename YDistributedTensor_,
              typename ReduceFunc,
              typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
    CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
                                   YDistributedTensor_& y_tensor,
                                   const ReduceFunc& reduce_func,
                                   ReducePacksPerXDim = {})
    {
        // no index output: y_tensor doubles as the dummy index tensor, and the
        // index calculator is trivial
        reduce_impl<false>(x_tensor,
                           y_tensor,
                           y_tensor,
                           reduce_func,
                           [](auto) { return 0; },
                           ReducePacksPerXDim{});
    }
    template <typename XDistributedTensor_,
              typename YDistributedTensor_,
              typename YIndexDistributedTensor_,
              typename ReduceFunc,
              typename IndexCalculatorFunc,
              typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
    CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
                                   YDistributedTensor_& y_tensor,
                                   YIndexDistributedTensor_& y_index_tensor,
                                   const ReduceFunc& reduce_func,
                                   const IndexCalculatorFunc& index_calculator,
                                   ReducePacksPerXDim = {})
    {
        reduce_impl<Problem::kOutputIndex>(x_tensor,
                                           y_tensor,
                                           y_index_tensor,
                                           reduce_func,
                                           index_calculator,
                                           ReducePacksPerXDim{});
    }
    // (the signature of the member enclosing the following body was lost in
    // extraction; what survives is a plain two-level span sweep that folds the
    // reduced dimension of x_tensor into y_tensor)
    {
        constexpr auto I0    = number<0>{};
        constexpr auto I1    = number<1>{};
        constexpr auto spans = XDistributedTensor_::get_distributed_spans();

        sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) {
            constexpr auto y_dstr_idx = make_tuple(dstr_idx_i0);

            auto y = y_tensor[y_dstr_idx];

            sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) {
                constexpr auto in_dstr_idx = make_tuple(dstr_idx_i0, dstr_idx_i1);

                const auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);

                y = reduce_func(y, x);
            });

            y_tensor(y_dstr_idx) = y;
        });
    }
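    // For intuition: with per-thread spans of 2 x 4, the nested sweep above
    // performs, per thread,
    //   for i0 in 0..1 { for i1 in 0..3 { y[i0] = reduce_func(y[i0], x[i0][i1]); } }
    // i.e. the reduced (second) dimension is folded away locally, before any
    // cross-lane traffic happens.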
    template <typename XDistributedTensor_>
    CK_TILE_DEVICE static auto MakeYBlockTile()
    {
        static_assert(std::is_same_v<XDataType, typename XDistributedTensor_::DataType>,
                      "wrong!");

        // drop the reduced x dimension (dim 1) from the input's distribution
        constexpr auto dstr =
            make_static_tile_distribution(make_reduce_tile_distribution_encoding(
                XDistributedTensor_::get_tile_distribution()
                    .get_static_tile_distribution_encoding(),
                sequence<1>{}));

        auto tensor = make_static_distributed_tensor<ComputeDataType>(dstr);

        return tensor;
    }
    template <typename XDistributedTensor_, typename IndexDataType = index_t>
    CK_TILE_DEVICE static auto MakeYIndexBlockTile()
    {
        static_assert(std::is_same_v<XDataType, typename XDistributedTensor_::DataType>,
                      "wrong!");

        constexpr auto dstr =
            make_static_tile_distribution(make_reduce_tile_distribution_encoding(
                XDistributedTensor_::get_tile_distribution()
                    .get_static_tile_distribution_encoding(),
                sequence<1>{}));

        auto tensor = make_static_distributed_tensor<IndexDataType>(dstr);

        return tensor;
    }
    template <typename XDistributedTensor_,
              typename ReduceFunc,
              typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
    CK_TILE_DEVICE auto operator()(const XDistributedTensor_& x_tensor,
                                   const ComputeDataType& reduce_init,
                                   const ReduceFunc& reduce_func,
                                   ReducePacksPerXDim = {})
    {
        auto y_tensor = MakeYBlockTile<XDistributedTensor_>();

        // (the initialization line was lost in extraction; set_tile with
        // reduce_init is the natural reconstruction)
        set_tile(y_tensor, reduce_init);

        (*this)(x_tensor, y_tensor, reduce_func, ReducePacksPerXDim{});

        return y_tensor;
    }
};
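// A minimal usage sketch of BlockReduce2d (hedged: `MyProblem`, `x_tile` and
// the sum lambda are illustrative placeholders, not library names). After this
// call every thread holds a partial result for its slice of each row; the Sync
// and CrossWarpSync structs below combine those partials across lanes and
// warps:
//
//   auto block_reduce2d = BlockReduce2d<MyProblem>{};
//   auto y_partial      = block_reduce2d(
//       x_tile,                                // distributed [M, N] input tile
//       type_convert<ComputeDataType>(0.f),    // reduce_init
//       [](auto a, auto b) { return a + b; }); // reduce_func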
template <typename Problem_, typename Policy_ = void>
struct BlockReduce2dSync
{
    using Problem = remove_cvref_t<Problem_>;

    template <bool kProcessIndex,
              typename YDistributedTensor_,
              typename YIndexDistributedTensor_,
              typename ReduceFunc>
    CK_TILE_DEVICE void reduce_impl(YDistributedTensor_& y_tensor,
                                    YIndexDistributedTensor_& y_index_tensor,
                                    const ReduceFunc& reduce_func)
    {
        using Dstr             = typename YDistributedTensor_::StaticTileDistribution;
        using DstrEncode       = typename Dstr::DstrEncode;
        using DstrEncodeDetail = typename DstrEncode::detail;

        constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
        constexpr index_t NDimR = Dstr::get_num_of_dimension_r();

        constexpr index_t idim_p_lane = NDimP - 1;

        constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();

        static_for<0, thread_buf_size, 1>{}([&](auto i) {
            auto v_local = y_tensor.get_thread_buffer()[i];

            using IndexDataType = typename YIndexDistributedTensor_::DataType;
            IndexDataType idx_local{};

            if constexpr(kProcessIndex)
            {
                idx_local = y_index_tensor.get_thread_buffer()[i];
            }

            // reduce only over the replication (R) dimensions that the lane
            // P dimension owns
            static_for<0, NDimR, 1>{}([&](auto idim_r) {
                if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
                {
                    constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];

                    constexpr index_t lid_over_rid_derivative =
                        DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];

                    static_assert(is_power_of_two_integer(r_length),
                                  "wrong! only support power of 2 reduction");

                    constexpr index_t nstage = integer_log2_floor(r_length);

                    static_for<0, nstage, 1>{}([&](auto istage) {
                        // (partner-lane computation reconstructed as the usual
                        // xor butterfly; the exact line was elided)
                        const index_t src_lane =
                            get_lane_id() ^
                            (number<lid_over_rid_derivative << istage.value>{}.value);

                        const auto v_remote = warp_shuffle(v_local, src_lane);

                        if constexpr(kProcessIndex)
                        {
                            const auto idx_remote = warp_shuffle(idx_local, src_lane);

                            AccumulateWithIndex{}(
                                reduce_func, v_local, idx_local, v_remote, idx_remote);
                        }
                        else
                        {
                            Accumulate{}(reduce_func, v_local, v_remote);
                        }
                    });
                }
            });

            y_tensor.get_thread_buffer()(i) = v_local;

            if constexpr(kProcessIndex)
            {
                y_index_tensor.get_thread_buffer()(i) = idx_local;
            }
        });
    }
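    // The stage loop above is the classic xor-butterfly warp reduction. A
    // standalone sketch of the same technique in raw HIP (lane_reduce is a
    // made-up name; `lanes` must be a power of two):
    //
    //   template <typename T, typename F>
    //   __device__ T lane_reduce(T v, const F& f, int lanes)
    //   {
    //       for(int stride = 1; stride < lanes; stride <<= 1)
    //           v = f(v, __shfl_xor(v, stride)); // every lane ends with the result
    //       return v;
    //   }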
    template <typename YDistributedTensor_, typename ReduceFunc>
    CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor, const ReduceFunc& reduce_func)
    {
        reduce_impl<false>(y_tensor, y_tensor, reduce_func);
    }

    template <typename YDistributedTensor_,
              typename YIndexDistributedTensor_,
              typename ReduceFunc>
    CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor,
                                   YIndexDistributedTensor_& y_index_tensor,
                                   const ReduceFunc& reduce_func)
    {
        reduce_impl<Problem::kOutputIndex>(y_tensor, y_index_tensor, reduce_func);
    }
};
template <typename Problem_, typename Policy_ = void>
struct BlockReduce2dCrossWarpSync
{
    using Problem    = remove_cvref_t<Problem_>;
    using BlockShape = typename Problem::BlockShape;

    template <typename YDistributedTensor_>
    CK_TILE_DEVICE static constexpr index_t GetReduceWarps()
    {
        constexpr index_t num_reduce_warps = [&]() {
            using Dstr             = typename YDistributedTensor_::StaticTileDistribution;
            using DstrEncode       = typename Dstr::DstrEncode;
            using DstrEncodeDetail = typename DstrEncode::detail;

            constexpr index_t NDimR = Dstr::get_num_of_dimension_r();

            constexpr index_t idim_p_warp = 0;

            // the number of cooperating warps is the product of the lengths of
            // every R dimension owned by the warp P dimension
            // (accumulation lines reconstructed; only the loop skeleton survived)
            index_t len = 1;
            static_for<0, NDimR, 1>{}([&](auto idim_r) {
                if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_warp][idim_r])
                {
                    constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
                    len *= r_length;
                }
            });
            return len;
        }();

        return num_reduce_warps;
    }
    template <typename YDistributedTensor_>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        using DataType = typename YDistributedTensor_::DataType;
        constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();

        // (the original derivation of num_warps was lost in extraction; the
        // line below assumes the usual block-size / warp-size split)
        constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();

        // one partial per warp, per thread-buffer element
        return num_warps * thread_buf_size * sizeof(DataType);
    }

    template <typename YIndexDistributedTensor_>
    CK_TILE_HOST_DEVICE static constexpr index_t GetIndicesSmemSize()
    {
        using IndexDataType = typename YIndexDistributedTensor_::DataType;
        constexpr index_t thread_buf_size = YIndexDistributedTensor_::get_thread_buffer_size();

        // (num_warps reconstructed as above)
        constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();

        return num_warps * thread_buf_size * sizeof(IndexDataType);
    }
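    // Worked example with illustrative numbers: a 256-thread block on a
    // 64-lane wavefront gives num_warps = 4; with thread_buf_size = 2 and
    // float data, GetSmemSize() = 4 * 2 * sizeof(float) = 32 bytes of LDS,
    // and with index_t indices GetIndicesSmemSize() = 4 * 2 * 4 = 32 bytes.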
    template <bool kProcessIndex,
              typename YDistributedTensor_,
              typename YIndexDistributedTensor_,
              typename ReduceFunc>
    CK_TILE_DEVICE void reduce_impl(YDistributedTensor_& y_tensor,
                                    YIndexDistributedTensor_& y_index_tensor,
                                    void* smem,
                                    void* smem_indices_ptr,
                                    const ReduceFunc& reduce_func)
    {
        using DataType      = typename YDistributedTensor_::DataType;
        using IndexDataType = typename YIndexDistributedTensor_::DataType;

        constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();

        DataType* smem_ptr          = reinterpret_cast<DataType*>(smem);
        IndexDataType* smem_indices = nullptr;
        if constexpr(kProcessIndex)
        {
            smem_indices = reinterpret_cast<IndexDataType*>(smem_indices_ptr);
        }

        const index_t lane_id = get_lane_id();
        const index_t warp_id = get_warp_id();

        constexpr index_t num_reduce_warps = GetReduceWarps<YDistributedTensor_>();

        // early out: each warp already holds the full result
        if constexpr(num_reduce_warps == 1)
        {
            return;
        }

        // (num_warps definition and the lane-0 store guard were lost in
        // extraction; reconstructed to match GetSmemSize above — after the warp
        // sync every lane holds the warp partial, so lane 0 alone publishes it)
        constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();

        const index_t smem_offset = warp_id;

        if(lane_id == 0)
        {
            static_for<0, thread_buf_size, 1>{}([&](auto i) {
                smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i];
                if constexpr(kProcessIndex)
                {
                    smem_indices[smem_offset + i * num_warps] =
                        y_index_tensor.get_thread_buffer()[i];
                }
            });
        }

        block_sync_lds();

        const index_t local_warp_id = warp_id / num_reduce_warps;
        const index_t local_smem_os = local_warp_id * num_reduce_warps;

        static_for<0, thread_buf_size, 1>{}([&](auto i) {
            // gather the cooperating warps' partials into registers
            DataType v[num_reduce_warps];
            [[maybe_unused]] std::
                conditional_t<kProcessIndex, IndexDataType[num_reduce_warps], IndexDataType> idx_v;

            static_for<0, num_reduce_warps, 1>{}([&](auto idx) {
                v[idx] = smem_ptr[i * num_warps + local_smem_os + idx];
                if constexpr(kProcessIndex)
                {
                    idx_v[idx] = smem_indices[i * num_warps + local_smem_os + idx];
                }
            });

            static_assert(is_power_of_two_integer(num_reduce_warps),
                          "wrong! only support power of 2 reduction");

            constexpr index_t nstage = integer_log2_floor(num_reduce_warps);

            // in-register tree reduction over the gathered partials
            static_for<0, nstage, 1>{}([&](auto istage) {
                constexpr index_t stride = 1 << istage.value;

                static_for<0, num_reduce_warps, stride * 2>{}([&](auto idx_) {
                    constexpr index_t i0 = idx_;
                    constexpr index_t i1 = idx_ + stride;

                    if constexpr(i1 < num_reduce_warps)
                    {
                        if constexpr(kProcessIndex)
                        {
                            AccumulateWithIndex{}(reduce_func, v[i0], idx_v[i0], v[i1], idx_v[i1]);
                        }
                        else
                        {
                            Accumulate{}(reduce_func, v[i0], v[i1]);
                        }
                    }
                });
            });

            y_tensor.get_thread_buffer()(i) = v[0];
            if constexpr(kProcessIndex)
            {
                y_index_tensor.get_thread_buffer()(i) = idx_v[0];
            }
        });
    }
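    // Worked example of the tree above for num_reduce_warps = 4 (writing `op`
    // for the accumulator application):
    //   stage 0, stride 1: v[0] = op(v[0], v[1]); v[2] = op(v[2], v[3])
    //   stage 1, stride 2: v[0] = op(v[0], v[2])
    // so the combined result lands in v[0] after log2(4) = 2 stages.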
    template <typename YDistributedTensor_, typename ReduceFunc>
    CK_TILE_DEVICE void
    operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
    {
        reduce_impl<false>(y_tensor, y_tensor, smem, nullptr, reduce_func);
    }

    template <typename YDistributedTensor_,
              typename YIndexDistributedTensor_,
              typename ReduceFunc>
    CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor,
                                   YIndexDistributedTensor_& y_index_tensor,
                                   void* smem,
                                   void* smem_indices,
                                   const ReduceFunc& reduce_func)
    {
        reduce_impl<Problem::kOutputIndex>(
            y_tensor, y_index_tensor, smem, smem_indices, reduce_func);
    }
};
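// How the pieces are meant to compose inside a kernel (hedged sketch; `P`,
// `x_tile`, `init` and the lambda are placeholders, and the LDS sizing follows
// GetSmemSize above):
//
//   auto f = [](auto a, auto b) { return a > b ? a : b; }; // max-reduce
//   auto y = BlockReduce2d<P>{}(x_tile, init, f);          // per-thread partials
//
//   __shared__ char smem[BlockReduce2dCrossWarpSync<P>::template GetSmemSize<decltype(y)>()];
//
//   BlockReduce2dSync<P>{}(y, f);                // intra-warp, via shuffles
//   BlockReduce2dCrossWarpSync<P>{}(y, smem, f); // inter-warp, via LDS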
// (the name of this second cross-warp struct was lost in extraction; the
// placeholder below is ours. It differs from BlockReduce2dCrossWarpSync in how
// warp partials are combined: sequentially, in warp-id order, which keeps
// index-returning reductions deterministic.)
template <typename Problem_, typename Policy_ = void>
struct BlockReduce2dCrossWarpSyncSequential /* placeholder name */
{
    using Problem    = remove_cvref_t<Problem_>;
    using BlockShape = typename Problem::BlockShape;

    template <typename YDistributedTensor_>
    CK_TILE_DEVICE static constexpr index_t GetReduceWarps()
    {
        constexpr index_t num_reduce_warps = [&]() {
            using Dstr             = typename YDistributedTensor_::StaticTileDistribution;
            using DstrEncode       = typename Dstr::DstrEncode;
            using DstrEncodeDetail = typename DstrEncode::detail;

            constexpr index_t NDimR = Dstr::get_num_of_dimension_r();

            constexpr index_t idim_p_warp = 0;

            index_t len = 1;
            static_for<0, NDimR, 1>{}([&](auto idim_r) {
                if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_warp][idim_r])
                {
                    constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
                    len *= r_length;
                }
            });
            return len;
        }();

        return num_reduce_warps;
    }
    template <typename YDistributedTensor_>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        using DataType = typename YDistributedTensor_::DataType;
        constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();

        // (num_warps reconstructed as in the struct above)
        constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();

        return num_warps * thread_buf_size * sizeof(DataType);
    }

    template <typename YIndexDistributedTensor_>
    CK_TILE_HOST_DEVICE static constexpr index_t GetIndicesSmemSize()
    {
        using IndexDataType = typename YIndexDistributedTensor_::DataType;
        constexpr index_t thread_buf_size = YIndexDistributedTensor_::get_thread_buffer_size();

        constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();

        return num_warps * thread_buf_size * sizeof(IndexDataType);
    }
    template <bool kProcessIndex,
              typename YDistributedTensor_,
              typename YIndexDistributedTensor_,
              typename ReduceFunc>
    CK_TILE_DEVICE void reduce_impl(YDistributedTensor_& y_tensor,
                                    YIndexDistributedTensor_& y_index_tensor,
                                    void* smem,
                                    void* smem_indices_ptr,
                                    const ReduceFunc& reduce_func)
    {
        using DataType      = typename YDistributedTensor_::DataType;
        using IndexDataType = typename YIndexDistributedTensor_::DataType;

        constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();

        DataType* smem_ptr          = reinterpret_cast<DataType*>(smem);
        IndexDataType* smem_indices = nullptr;
        if constexpr(kProcessIndex)
        {
            smem_indices = reinterpret_cast<IndexDataType*>(smem_indices_ptr);
        }

        const index_t lane_id = get_lane_id();
        const index_t warp_id = get_warp_id();
        constexpr auto num_reduce_warps = GetReduceWarps<YDistributedTensor_>();

        const index_t smem_offset = warp_id;

        if constexpr(num_reduce_warps == 1)
        {
            return;
        }

        // (num_warps definition and the lane-0 store guard were lost in
        // extraction; reconstructed to match the variant above)
        constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();

        if(lane_id == 0)
        {
            static_for<0, thread_buf_size, 1>{}([&](auto i) {
                smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i];
                if constexpr(kProcessIndex)
                {
                    smem_indices[smem_offset + i * num_warps] =
                        y_index_tensor.get_thread_buffer()[i];
                }
            });
        }

        block_sync_lds();

        index_t local_warp_id = warp_id / num_reduce_warps;
        index_t local_smem_os = local_warp_id * num_reduce_warps;

        // gather every cooperating warp's partials into registers
        DataType all_scratch[thread_buf_size * num_reduce_warps];
        [[maybe_unused]] std::conditional_t<kProcessIndex,
                                            IndexDataType[thread_buf_size * num_reduce_warps],
                                            IndexDataType>
            all_indices;

        static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
            static_for<0, num_reduce_warps, 1>{}([&](auto i_1) {
                all_scratch[i_0 * num_reduce_warps + i_1] =
                    smem_ptr[i_0 * num_warps + local_smem_os + i_1];

                if constexpr(kProcessIndex)
                {
                    all_indices[i_0 * num_reduce_warps + i_1] =
                        smem_indices[i_0 * num_warps + local_smem_os + i_1];
                }
            });
        });

        static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
            auto v_local = all_scratch[i_0 * num_reduce_warps];

            IndexDataType idx_local{};
            if constexpr(kProcessIndex)
            {
                idx_local = all_indices[i_0 * num_reduce_warps];
            }

            // sequential accumulation in fixed warp-id order (deterministic)
            static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) {
                constexpr auto i_1       = number<i_1_n1 + 1>{};
                const DataType v_remote = all_scratch[i_0 * num_reduce_warps + i_1];

                if constexpr(kProcessIndex)
                {
                    const IndexDataType idx_remote = all_indices[i_0 * num_reduce_warps + i_1];

                    bool changed = false;
                    v_local      = reduce_func(v_local, v_remote, changed);
                    if(changed)
                    {
                        idx_local = idx_remote;
                    }
                }
                else
                {
                    v_local = reduce_func(v_local, v_remote);
                }
            });

            y_tensor.get_thread_buffer()(i_0) = v_local;
            if constexpr(kProcessIndex)
            {
                y_index_tensor.get_thread_buffer()(i_0) = idx_local;
            }
        });
    }
    template <typename YDistributedTensor_, typename ReduceFunc>
    CK_TILE_DEVICE void
    operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
    {
        reduce_impl<false>(y_tensor, y_tensor, smem, nullptr, reduce_func);
    }

    template <typename YDistributedTensor_,
              typename YIndexDistributedTensor_,
              typename ReduceFunc>
    CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor,
                                   YIndexDistributedTensor_& y_index_tensor,
                                   void* smem,
                                   void* smem_indices,
                                   const ReduceFunc& reduce_func)
    {
        reduce_impl<Problem::kOutputIndex>(
            y_tensor, y_index_tensor, smem, smem_indices, reduce_func);
    }
};
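// Why the sequential variant exists (a note, not library documentation): the
// tree reduction in BlockReduce2dCrossWarpSync pairs warp partials in a
// stage-dependent order, so for non-associative accumulators, or for
// argmax/argmin ties, different pairings can pick different indices. Scanning
// partials in fixed warp-id order makes the outcome reproducible. E.g. for an
// argmax where warps 0 and 1 both hold the value 5:
//
//   bool changed = false;
//   v_local = argmax_func(v_local /* 5, warp 0 */, v_remote /* 5, warp 1 */, changed);
//   // changed stays false on the tie, so idx_local keeps warp 0's (first) index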