/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/topk/block/block_topk_stream_2d.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/topk/block/block_topk_stream_2d.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/topk/block/block_topk_stream_2d.hpp Source File
block_topk_stream_2d.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 
8 namespace ck_tile {
9 
10 /*
11 simple 2d topk implementation, along row (dim=1)
12 requirement:
13  1). each row is within a warp
14 */
15 template <typename Problem_, typename Policy_ = void>
17 {
20 
21  using DataType = typename Problem::DataType;
22  using IndexType = typename Problem::IndexType;
23 
24  // TODO: if DataType is subdword, need pack into single dword to use argmax
25  struct ArgmaxPacket
26  {
29  };
30 
31  template <typename DistributedTensor, typename OutWindow, typename IdxWindow, index_t dim = 1>
32  CK_TILE_DEVICE void operator()(const DistributedTensor& x,
33  const OutWindow& out_window,
34  const IdxWindow& idx_window,
35  index_t k,
36  number<dim> = {})
37  {
38  OutWindow out_window_tmp = out_window;
39  IdxWindow idx_window_tmp = idx_window;
40  static_assert(
41  std::is_same_v<typename DistributedTensor::DataType, typename OutWindow::DataType> &&
42  std::is_same_v<typename DistributedTensor::DataType, DataType>);
43  static_assert(std::is_same_v<typename IdxWindow::DataType, IndexType>);
44 
45  DistributedTensor x_tmp = x;
46  constexpr auto dst_dist = typename IdxWindow::TileDstr{};
47 
48  // argmax for topk
49  const auto f_argmax = [](ArgmaxPacket e0, ArgmaxPacket e1) {
50  return e0.arg > e1.arg ? e0 : e1;
51  };
52 
53  for(index_t i_k = 0; i_k < k; i_k++)
54  {
55  constexpr auto span_2d = DistributedTensor::get_distributed_spans();
56  auto packet = [&]() {
57  auto tmp = make_static_distributed_tensor<ArgmaxPacket>(x.get_tile_distribution());
58 
59  sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
60  sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
61  const auto tile_idx = get_x_indices_from_distributed_indices(
62  tmp.get_tile_distribution(), make_tuple(idx0, idx1));
63  constexpr auto i_j_idx = make_tuple(idx0, idx1);
64  ArgmaxPacket t;
65  t.arg = x_tmp(i_j_idx); // !!! we reference x here
66  t.value = tile_idx.at(number<1>{});
67  tmp(i_j_idx) = t;
68  });
69  });
70  return tmp;
71  }();
72 
73  auto argmax_init = ArgmaxPacket{-numeric<DataType>::infinity(), 0};
74  auto r = block_tile_reduce<ArgmaxPacket>(packet, sequence<1>{}, f_argmax, argmax_init);
75  block_tile_reduce_xor_sync(r, f_argmax);
76 
77  auto o = make_static_distributed_tensor<DataType>(dst_dist);
78  auto i = make_static_distributed_tensor<IndexType>(dst_dist);
79  sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
80  sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
81  constexpr auto i_j_idx = make_tuple(idx0, idx1);
82  ArgmaxPacket tmp = r(i_j_idx);
83  o(i_j_idx) = tmp.arg;
84  i(i_j_idx) = tmp.value;
85  });
86  });
87 
88  // update value
89  sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
90  sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
91  const auto tile_idx = get_x_indices_from_distributed_indices(
92  x.get_tile_distribution(), make_tuple(idx0, idx1));
93  auto col_id = tile_idx.at(number<1>{});
94 
95  constexpr auto i_j_idx = make_tuple(idx0, idx1);
96 
97  x_tmp(i_j_idx) = (col_id == r(i_j_idx).value) ? -numeric<DataType>::infinity()
98  : x_tmp(i_j_idx);
99  });
100  });
101 
102  if(threadIdx.x % Problem::ColLanes == 0)
103  {
104  store_tile(out_window_tmp, o);
105  store_tile(idx_window_tmp, i);
106  }
107  move_tile_window(out_window_tmp, {number<0>{}, number<1>{}});
108  move_tile_window(idx_window_tmp, {number<0>{}, number<1>{}});
109  }
110  }
111 };
112 
113 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
CK_TILE_DEVICE void block_tile_reduce_xor_sync(AccDistributedTensor_ &acc_tensor, const ReduceFunc &reduce_func)
Definition: block_reduce.hpp:132
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition: sweep_tile.hpp:20
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:95
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE auto get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, DistributedIndices distributed_indices)
Definition: static_distributed_tensor.hpp:159
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:23
Definition: block_topk_stream_2d.hpp:26
DataType arg
Definition: block_topk_stream_2d.hpp:27
index_t value
Definition: block_topk_stream_2d.hpp:28
Definition: block_topk_stream_2d.hpp:17
remove_cvref_t< Policy_ > Policy
Definition: block_topk_stream_2d.hpp:19
CK_TILE_DEVICE void operator()(const DistributedTensor &x, const OutWindow &out_window, const IdxWindow &idx_window, index_t k, number< dim >={})
Definition: block_topk_stream_2d.hpp:32
remove_cvref_t< Problem_ > Problem
Definition: block_topk_stream_2d.hpp:18
typename Problem::IndexType IndexType
Definition: block_topk_stream_2d.hpp:22
typename Problem::DataType DataType
Definition: block_topk_stream_2d.hpp:21
Definition: integral_constant.hpp:13
static constexpr CK_TILE_HOST_DEVICE T infinity()
Definition: numeric.hpp:38