Composable Kernel: include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"

#include <string>
#include <type_traits>

namespace ck_tile {

struct TopkSoftmaxHostArgs
{
    const void* p_input;
    void* p_output;
    void* p_indices;
    index_t num_rows;
    index_t num_experts;
    index_t topk;
    index_t stride_input;  // row stride for input, at least num_experts
    index_t stride_output; // row stride for output/indices, at least topk
};

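// Usage note (not part of the original header): for densely packed, row-major
// tensors, stride_input is simply num_experts and stride_output is topk; larger
// values let each row start inside a padded or strided buffer.
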
template <typename Pipeline_>
struct TopkSoftmaxKernel
{
    using Pipeline = remove_cvref_t<Pipeline_>;
    using Problem  = remove_cvref_t<typename Pipeline::Problem>;

    using InputType  = typename Problem::InputType;
    using WeightType = typename Problem::WeightType;
    using IndexType  = typename Problem::IndexType;

    static constexpr index_t kBlockSize = Problem::BlockSize;

    struct TopkSoftmaxKargs
    {
        const void* p_input;
        void* p_output;
        void* p_indices;
        index_t num_rows;
        index_t num_experts;
        index_t topk;
        index_t stride_input;  // row stride for input, at least num_experts
        index_t stride_output; // row stride for output/indices, at least topk
    };

    using Kargs = TopkSoftmaxKargs;
    using Hargs = TopkSoftmaxHostArgs;

    CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
    {
        if constexpr(Problem::LaunchType > 0)
        {
            int num_cu = [&]() {
                hipDeviceProp_t dev_prop;
                hipDevice_t dev;
                HIP_CHECK_ERROR(hipGetDevice(&dev));
                HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
                return dev_prop.multiProcessorCount;
            }();
            return dim3(num_cu * Problem::LaunchType);
        }
        else
        {
            const int num_warps = (h.num_rows + Problem::RowsPerWarp - 1) / Problem::RowsPerWarp;
            const int num_blocks =
                (num_warps + Problem::WarpsPerBlock - 1) / Problem::WarpsPerBlock;
            return dim3(num_blocks);
        }
    }

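    // Worked example (illustrative numbers, not taken from this header): with
    // Problem::RowsPerWarp == 4 and Problem::WarpsPerBlock == 4, a problem with
    // h.num_rows == 1000 needs ceil(1000 / 4) == 250 warps, which round up to
    // ceil(250 / 4) == 63 blocks, so GridSize(h) returns dim3(63). When
    // Problem::LaunchType > 0 the grid is instead sized from the CU count of the
    // current device, LaunchType blocks per CU.
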
    CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
    {
        Kargs k;
        k.p_input       = h.p_input;
        k.p_output      = h.p_output;
        k.p_indices     = h.p_indices;
        k.num_rows      = h.num_rows;
        k.num_experts   = h.num_experts;
        k.topk          = h.topk;
        k.stride_input  = h.stride_input;
        k.stride_output = h.stride_output;
        return k;
    }

    CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::BlockSize; }

    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        // first row handled by this workgroup
        index_t block_row_id = static_cast<index_t>(blockIdx.x * Problem::RowsPerBlock);

        if(block_row_id > kargs.num_rows)
            return;

        // scalarize the per-block offsets and the remaining row count
        index_t block_os_inp = __builtin_amdgcn_readfirstlane(block_row_id * kargs.stride_input);
        index_t block_os_out = __builtin_amdgcn_readfirstlane(block_row_id * kargs.stride_output);
        index_t num_rows_rem = __builtin_amdgcn_readfirstlane(kargs.num_rows - block_row_id);

        // [num_rows_rem, num_experts] view over the gating scores owned by this block
        const auto input_window = [&]() {
            const InputType* p_input =
                reinterpret_cast<const InputType*>(kargs.p_input) + block_os_inp;

            auto tmp = make_naive_tensor_view<address_space_enum::global>(
                p_input,
                make_tuple(num_rows_rem, kargs.num_experts),
                make_tuple(kargs.stride_input, 1),
                /* vector-length hint elided in this listing */
                number<1>{});

            auto view = pad_tensor_view(
                tmp,
                /* tile lengths elided in this listing */
                sequence<0, 1>{}); // outer-most dim needs no padding (leverage OOB checks)

            return make_tile_window(
                view,
                /* window lengths elided in this listing */
                {0, 0});
        }();

        // [num_rows_rem, topk] view that receives the top-k softmax weights
        auto output_window = [&]() {
            WeightType* p_output = reinterpret_cast<WeightType*>(kargs.p_output) + block_os_out;
            auto tmp = make_naive_tensor_view<address_space_enum::global>(
                p_output,
                make_tuple(num_rows_rem, kargs.topk),
                make_tuple(kargs.stride_output, 1),
                /* vector-length hint elided in this listing */
                number<1>{});
            auto view =
                pad_tensor_view(tmp,
                                /* tile lengths elided in this listing */
                                sequence<0, 0>{}); // 1. outer-most dim needs no padding (leverage OOB checks)
                                                   // 2. we loop over topk one by one, so no padding needed
            return make_tile_window(
                /* arguments elided in this listing: view, window lengths, {0, 0} */);
        }();

        // [num_rows_rem, topk] view that receives the selected expert indices
        auto indices_window = [&]() {
            IndexType* p_indices = reinterpret_cast<IndexType*>(kargs.p_indices) + block_os_out;
            auto tmp = make_naive_tensor_view<address_space_enum::global>(
                p_indices,
                make_tuple(num_rows_rem, kargs.topk),
                make_tuple(kargs.stride_output, 1),
                /* vector-length hint elided in this listing */
                number<1>{});
            auto view =
                pad_tensor_view(tmp,
                                /* tile lengths elided in this listing */
                                sequence<0, 0>{}); // 1. outer-most dim needs no padding (leverage OOB checks)
                                                   // 2. we loop over topk one by one, so no padding needed
            return make_tile_window(
                /* arguments elided in this listing: view, window lengths, {0, 0} */);
        }();

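        // A sketch of the intent of the call below (inferred from the op name, not
        // spelled out in this header): for each row the pipeline computes a softmax
        // over the num_experts scores in input_window and writes the topk largest
        // weights and their expert ids through output_window and indices_window,
        // starting at block_row_id.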
        Pipeline{}(input_window,
                   output_window,
                   indices_window,
                   kargs.num_rows,
                   kargs.num_experts,
                   kargs.topk,
                   block_row_id);
    }
};
} // namespace ck_tile
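
A minimal host-side launch sketch follows. The topk_softmax_entry shim, the run_topk_softmax helper, the templated Pipeline parameter, and the buffer names are illustrative assumptions rather than APIs defined in this header; only TopkSoftmaxKernel, TopkSoftmaxHostArgs, MakeKargs, GridSize, and BlockSize come from the code above.

#include <hip/hip_runtime.h>
#include "ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp"

// Thin __global__ shim (an assumption, not part of ck_tile) that forwards the
// kernel arguments to the device functor defined in this header.
template <typename Kernel>
__global__ void topk_softmax_entry(typename Kernel::Kargs kargs)
{
    Kernel{}(kargs);
}

// Hypothetical helper: fills the host args for densely packed, row-major tensors
// and launches one grid as sized by GridSize().
template <typename Pipeline>
void run_topk_softmax(const void* d_scores,  // [num_rows, num_experts] gating scores
                      void* d_topk_weights,  // [num_rows, topk] softmax weights
                      void* d_topk_ids,      // [num_rows, topk] expert indices
                      ck_tile::index_t num_rows,
                      ck_tile::index_t num_experts,
                      ck_tile::index_t topk,
                      hipStream_t stream)
{
    using Kernel = ck_tile::TopkSoftmaxKernel<Pipeline>;

    ck_tile::TopkSoftmaxHostArgs h{};
    h.p_input       = d_scores;
    h.p_output      = d_topk_weights;
    h.p_indices     = d_topk_ids;
    h.num_rows      = num_rows;
    h.num_experts   = num_experts;
    h.topk          = topk;
    h.stride_input  = num_experts; // densely packed rows
    h.stride_output = topk;

    auto kargs  = Kernel::MakeKargs(h);
    dim3 grids  = Kernel::GridSize(h);
    dim3 blocks = dim3(Kernel::BlockSize());

    topk_softmax_entry<Kernel><<<grids, blocks, 0, stream>>>(kargs);
}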