/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp Source File
moe_smoothquant_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
8 
9 namespace ck_tile {
10 
11 // Host-side arguments for the MoE smooth-quant kernel. The kernel's MakeKargs()
// copies every field one-to-one into its device-side Kargs.
13 {
14  const void* p_x; // [tokens, hidden_size], input, fp16/bf16; rows are x_stride elements apart
15  const void* p_smscale; // [experts, hidden_size], input, columnwise scale, fp32; rows densely packed (hidden_size apart)
16  const void* p_topk_ids; // [tokens, topk], expert id per (token, top-k slot); the kernel reads it as index_t
17 
18  void* p_yscale; // [topk * tokens, 1], output, rowwise quant scale; slot k occupies elements [k*tokens, (k+1)*tokens)
19  void* p_qy; // [topk * tokens, hidden_size], output; slot k starts at element k*tokens*y_stride
20 
25  index_t x_stride; // input x row stride, in elements
26  index_t y_stride; // output qy row stride, in elements (consecutive top-k slots are tokens*y_stride apart)
27 };
28 
29 // TODO: Extract some of these types into a wrapper class
// MoE smooth-quant kernel. Each workgroup handles one top-k slot for one tile of
// Block_M consecutive tokens (see GridSize()). It builds global-memory tile windows
// for x, the selected expert's smooth-scale row, yscale and qy, and hands them to
// Pipeline_, whose smoothing/quantization math is not part of this file.
30 template <typename Pipeline_>
32 {
34  using Problem = typename Pipeline::Problem;
35 
41 
// Tile geometry and feature flags, all forwarded from the pipeline's Problem.
42  static constexpr index_t Block_M = Problem::BlockShape::Block_M;
43  static constexpr index_t Block_N = Problem::BlockShape::Block_N;
// NOTE(review): GridSize() rounds tokens up to a multiple of Block_M, so when
// tokens % Block_M != 0 the last tile extends past the final token. kPadM is
// presumably what the x/qy pad_tensor_view calls use for M. Confirm that callers
// guarantee divisibility (or Block_M == 1), or that accesses are bounds-checked.
44  static constexpr bool kPadM = false; // always no need to pad along M
45  static constexpr bool kPadN = Problem::kPadN;
46  static constexpr bool kTwoPass = Problem::kTwoPass;
47 
48  static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
49  static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
50  static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
51  static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
52 
53  static constexpr auto I0 = number<0>{};
54  static constexpr auto I1 = number<1>{};
55 
// Only a single repeat along M is supported. Presumably this is because operator()
// derives exactly one row index per thread (i_token_in_thrd) — TODO confirm.
56  static_assert(Problem::BlockShape::Repeat_M == 1);
57 
// Device-side kernel arguments. operator() takes them by value; MakeKargs() fills
// them field-for-field from the host arguments. Strides are in elements.
58  struct Kargs
59  {
60  const void* p_x; // [tokens, hidden_size], input, fp16/bf16
61  const void* p_smscale; // [experts, hidden_size], input, columnwise scale, fp32
62  const void* p_topk_ids; // [tokens, topk], read as index_t
63 
64  void* p_yscale; // [topk, tokens, 1], output, rowwise quant scale
65  void* p_qy; // [topk, tokens, hidden_size], output
66 
71  index_t x_stride; // input x row stride
72  index_t y_stride; // output qy row stride; top-k slots are tokens*y_stride apart
73  };
75 
76  CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
77  {
78  return Kargs{hargs.p_x,
79  hargs.p_smscale,
80  hargs.p_topk_ids,
81  hargs.p_yscale,
82  hargs.p_qy,
83  hargs.tokens,
84  hargs.hidden_size,
85  hargs.experts,
86  hargs.topk,
87  hargs.x_stride,
88  hargs.y_stride};
89  }
90 
91  CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
92  {
93  return dim3(hargs.topk, integer_divide_ceil(hargs.tokens, Block_M), 1);
94  }
95 
96  CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; }
97 
98  // clang-format off
99  template <typename T> struct t2s;
100  template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
101  template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
102  template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
103  template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
104  template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
105  template <> struct t2s<ck_tile::int8_t> { static constexpr const char * name = "i8"; };
106  // clang-format on
107 
108  // in byte
109  CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }
110 
111  CK_TILE_HOST static std::string GetName()
112  {
113  // clang-format off
114  using S_ = typename Problem::BlockShape;
115  auto surfix = [&] () {
116  std::string n;
117  if (kPadN) n += "_pn";
118  if (kTwoPass) n += "_2p";
119  return n; }();
120 
121  #define _SS_ std::string
122  #define _TS_ std::to_string
123  return _SS_("moe_smoothquant_") + _SS_(t2s<XDataType>::name) + "_" + _SS_(t2s<QYDataType>::name) + "_" +
124  _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
125  _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
126  _SS_(Pipeline::name) + surfix;
127  #undef _SS_
128  #undef _TS_
129  // clang-format on
130  }
131 
// Device entry: processes one top-k slot (blockIdx.x) for one tile of Block_M
// tokens (blockIdx.y), matching the launch shape produced by GridSize().
// It looks up this slot's expert id, builds global-memory tile windows for x,
// that expert's smooth-scale row, yscale and qy, and delegates the computation
// to Pipeline together with a __shared__ scratch buffer of GetSmemSize() bytes.
132  CK_TILE_DEVICE void operator()(Kargs kargs) const
133  {
134  const index_t i_topk = blockIdx.x;
135  const index_t i_token = blockIdx.y * Block_M;
// Row of the tile this thread belongs to. readfirstlane makes the value
// wavefront-uniform by taking lane 0's result.
// NOTE(review): this is only correct if all lanes of a wavefront map to the same
// row (threadIdx.x / ThreadPerBlock_N constant across the wave) — confirm against
// the BlockShape configurations in use.
136  const index_t i_token_in_thrd =
137  __builtin_amdgcn_readfirstlane(threadIdx.x / Problem::BlockShape::ThreadPerBlock_N);
138 
// Expert selected for (token, top-k slot), read via a raw pointer from the
// [tokens, topk] id table.
// NOTE(review): there is no check against kargs.tokens. GridSize() rounds the
// token count up, so when tokens % Block_M != 0 the last tile can read past the
// end of p_topk_ids. Confirm that callers avoid this, or add a bound check.
139  const index_t i_expert = reinterpret_cast<const index_t*>(
140  kargs.p_topk_ids)[(i_token + i_token_in_thrd) * kargs.topk + i_topk];
141 
142  // [tokens, hidden_size]; rows x_stride elements apart, window starts at row i_token
143  const auto x_window = [&]() {
144  const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
145  static_cast<const XDataType*>(kargs.p_x),
146  make_tuple(kargs.tokens, kargs.hidden_size),
147  make_tuple(kargs.x_stride, 1),
149  number<1>{});
150 
151  const auto tmp2_ = pad_tensor_view(
153  return make_tile_window(
154  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {i_token, 0});
155  }();
156 
157  // [experts, hidden_size]: 1-D view of row i_expert only (rows densely packed)
158  const auto smscale_window = [&]() {
159  const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
160  static_cast<const SmoothScaleDataType*>(kargs.p_smscale) +
161  i_expert * kargs.hidden_size,
162  make_tuple(kargs.hidden_size),
163  make_tuple(1),
165  number<1>{});
166 
167  const auto tmp2_ =
169 
170  return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
171  }();
172 
173  // [topk, tokens]: contiguous slice for slot i_topk, window starts at token i_token
174  auto yscale_window = [&]() {
175  const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
176  static_cast<YScaleDataType*>(kargs.p_yscale) + i_topk * kargs.tokens,
177  make_tuple(kargs.tokens),
178  make_tuple(1),
179  number<1>{});
180 
181  const auto tmp2_ =
183 
184  return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {i_token});
185  }();
186 
187  // [topk, tokens, hidden_size]: slot i_topk starts tokens*y_stride elements after slot i_topk-1
// NOTE(review): the slice offset i_topk * tokens * y_stride is evaluated in 32-bit
// index_t (signed overflow is UB). It overflows once (topk-1)*tokens*y_stride
// exceeds INT32_MAX, e.g. topk=8, tokens=65536, y_stride=8192. Consider widening
// it to a 64-bit integer type — TODO confirm the intended size limits.
188  auto qy_window = [&]() {
189  auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
190  static_cast<QYDataType*>(kargs.p_qy) + i_topk * kargs.tokens * kargs.y_stride,
191  make_tuple(kargs.tokens, kargs.hidden_size),
192  make_tuple(kargs.y_stride, 1),
194  number<1>{});
195 
196  auto tmp2_ = pad_tensor_view(
198  return make_tile_window(
199  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {i_token, 0});
200  }();
201 
// Scratch space for the pipeline, sized by the pipeline itself.
202  __shared__ char smem[GetSmemSize()];
203 
204  Pipeline{}(x_window, smscale_window, yscale_window, qy_window, kargs.hidden_size, smem);
205  }
206 };
207 
208 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
#define _TS_
#define _SS_
Definition: cluster_descriptor.hpp:13
_BitInt(8) fp8_t
Definition: float8.hpp:204
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
_Float16 fp16_t
Definition: half.hpp:110
int8_t int8_t
Definition: int8.hpp:20
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
Definition: moe_smoothquant_kernel.hpp:59
index_t y_stride
Definition: moe_smoothquant_kernel.hpp:72
const void * p_x
Definition: moe_smoothquant_kernel.hpp:60
index_t tokens
Definition: moe_smoothquant_kernel.hpp:67
index_t topk
Definition: moe_smoothquant_kernel.hpp:70
index_t x_stride
Definition: moe_smoothquant_kernel.hpp:71
void * p_yscale
Definition: moe_smoothquant_kernel.hpp:64
const void * p_smscale
Definition: moe_smoothquant_kernel.hpp:61
void * p_qy
Definition: moe_smoothquant_kernel.hpp:65
index_t experts
Definition: moe_smoothquant_kernel.hpp:69
const void * p_topk_ids
Definition: moe_smoothquant_kernel.hpp:62
index_t hidden_size
Definition: moe_smoothquant_kernel.hpp:68
Definition: moe_smoothquant_kernel.hpp:99
Definition: moe_smoothquant_kernel.hpp:13
index_t x_stride
Definition: moe_smoothquant_kernel.hpp:25
index_t topk
Definition: moe_smoothquant_kernel.hpp:24
index_t hidden_size
Definition: moe_smoothquant_kernel.hpp:22
void * p_yscale
Definition: moe_smoothquant_kernel.hpp:18
index_t experts
Definition: moe_smoothquant_kernel.hpp:23
index_t y_stride
Definition: moe_smoothquant_kernel.hpp:26
index_t tokens
Definition: moe_smoothquant_kernel.hpp:21
const void * p_topk_ids
Definition: moe_smoothquant_kernel.hpp:16
const void * p_smscale
Definition: moe_smoothquant_kernel.hpp:15
void * p_qy
Definition: moe_smoothquant_kernel.hpp:19
const void * p_x
Definition: moe_smoothquant_kernel.hpp:14
Definition: moe_smoothquant_kernel.hpp:32
static constexpr bool kTwoPass
Definition: moe_smoothquant_kernel.hpp:46
remove_cvref_t< typename Problem::SmoothScaleDataType > SmoothScaleDataType
Definition: moe_smoothquant_kernel.hpp:37
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: moe_smoothquant_kernel.hpp:132
remove_cvref_t< typename Problem::YScaleDataType > YScaleDataType
Definition: moe_smoothquant_kernel.hpp:39
static constexpr bool kPadM
Definition: moe_smoothquant_kernel.hpp:44
static constexpr CK_TILE_HOST Kargs MakeKargs(const Hargs &hargs)
Definition: moe_smoothquant_kernel.hpp:76
static constexpr auto I0
Definition: moe_smoothquant_kernel.hpp:53
static constexpr bool kPadN
Definition: moe_smoothquant_kernel.hpp:45
remove_cvref_t< typename Problem::QYDataType > QYDataType
Definition: moe_smoothquant_kernel.hpp:40
remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition: moe_smoothquant_kernel.hpp:38
remove_cvref_t< Pipeline_ > Pipeline
Definition: moe_smoothquant_kernel.hpp:33
remove_cvref_t< typename Problem::XDataType > XDataType
Definition: moe_smoothquant_kernel.hpp:36
static constexpr index_t Vector_N
Definition: moe_smoothquant_kernel.hpp:49
static constexpr CK_TILE_HOST auto BlockSize()
Definition: moe_smoothquant_kernel.hpp:96
static constexpr index_t Block_N
Definition: moe_smoothquant_kernel.hpp:43
static CK_TILE_HOST std::string GetName()
Definition: moe_smoothquant_kernel.hpp:111
typename Pipeline::Problem Problem
Definition: moe_smoothquant_kernel.hpp:34
static constexpr CK_TILE_HOST auto GridSize(const Hargs &hargs)
Definition: moe_smoothquant_kernel.hpp:91
static constexpr index_t Repeat_N
Definition: moe_smoothquant_kernel.hpp:50
static constexpr index_t ThreadPerWarp_N
Definition: moe_smoothquant_kernel.hpp:48
static constexpr index_t kBlockSize
Definition: moe_smoothquant_kernel.hpp:51
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: moe_smoothquant_kernel.hpp:109
static constexpr auto I1
Definition: moe_smoothquant_kernel.hpp:54
static constexpr index_t Block_M
Definition: moe_smoothquant_kernel.hpp:42
Definition: integral_constant.hpp:13
Definition: sequence.hpp:49