/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp Source File
rmsnorm2d_fwd_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
9 
10 namespace ck_tile {
11 
12 // host side args
14 {
15  const void* p_x; // [m, n], input, fp16/bf16
16  const void* p_x_residual; // [m, n], shortcut input, same precision as input, nullptr if not used
17  const void* p_sm_scale; // [1, n], smooth scale input, fp32, nullptr if not used
18  const void* p_gamma; // [1, n], gamma, same precision as input
19 
20  void* p_y; // [m, n], output, fp16/bf16
21  void* p_y_residual; // [m, n], shortcut output, same precision as input, nullptr if not used
22  void* p_y_scale; // [m, 1], per-row dynamic quant output, nullptr if not used
23  void* p_invRms; // [m, 1], inv-rms output, same precision as input, nullptr if not used
24  void* p_y_unquant; // [m, n], result before quant, nullptr if not used
25 
26  float epsilon;
27 // NOTE(review): listing lines 28-29 (index_t m, index_t n per the member index) are not shown here
30  index_t x_stride; // x row stride
31  index_t xr_stride; // x residual row stride
32  index_t y_stride; // y row stride
33  index_t yr_stride; // y residual row stride
34 };
35 
36 // TODO: extract some of these types into a wrapper class
37 template <typename Pipeline_, typename Epilogue_>
39 {
42  using Problem = typename Pipeline::Problem;
43 
52 
53  // for simplicity, the shortcut (residual) input/output types are the same as XDataType
56 
57  static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, null_type>; // false only when GammaDataType is null_type
58  static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
59  static constexpr bool kSaveUnquant = Problem::Traits::kSaveUnquant;
60 
61  static constexpr index_t Block_M = Problem::BlockShape::Block_M;
62  static constexpr index_t Block_N = Problem::BlockShape::Block_N;
63  static constexpr bool kPadM = false; // padding along M is never needed
64  static constexpr bool kPadN = Problem::Traits::kPadN;
65  static constexpr bool kTwoPass = Problem::Traits::kTwoPass;
66  static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
67  static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
68  static constexpr auto kUseModelSensitiveRMSNorm = Problem::Traits::kUseModelSensitiveRMSNorm;
69 
70  static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
71  static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
72  static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
73  static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
74 
75  static constexpr auto I0 = number<0>{};
76  static constexpr auto I1 = number<1>{};
77 
78  struct Kargs // kernel-side arguments; MakeKargs fills each field from the same-named Hargs field
79  {
80  const void* p_x; // [m, n], input
81  const void* p_x_residual; // [m, n], shortcut input, nullptr if not used
82  const void* p_sm_scale; // [1, n], smooth scale input, nullptr if not used
83  const void* p_gamma; // [1, n], gamma
84 
85  void* p_y; // [m, n], output
86  void* p_y_residual; // [m, n], shortcut output, nullptr if not used
87  void* p_y_scale; // [m, 1], per-row dynamic quant output, nullptr if not used
88  void* p_invRms; // [m, 1], inv-rms output, nullptr if not used
89  void* p_y_unquant; // [m, n], result before quant, nullptr if not used
90 
91  float epsilon;
92 // NOTE(review): listing lines 93-94 (index_t m, index_t n per the member index) are not shown here
95  index_t x_stride; // x row stride
96  index_t xr_stride; // x residual row stride
97  index_t y_stride; // y row stride
98  index_t yr_stride; // y residual row stride
99  };
101 
102  CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
103  { // converts host-side args to kernel args; no field is transformed, each value is copied as-is
104  return Kargs{hargs.p_x, // brace-initialization: the argument order below must match Kargs' member order
105  hargs.p_x_residual,
106  hargs.p_sm_scale,
107  hargs.p_gamma,
108  hargs.p_y,
109  hargs.p_y_residual,
110  hargs.p_y_scale,
111  hargs.p_invRms,
112  hargs.p_y_unquant,
113  hargs.epsilon,
114  hargs.m,
115  hargs.n,
116  hargs.x_stride,
117  hargs.xr_stride,
118  hargs.y_stride,
119  hargs.yr_stride};
120  }
121 
122  CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
123  {
124  return dim3(integer_divide_ceil(hargs.m, Block_M)); // 1-D grid sized from m and Block_M; operator() uses get_block_id() * Block_M as its row offset
125  }
126 
127  CK_TILE_HOST static constexpr auto BlockSize()
128  { // returns BlockShape::GetBlockSize<true>() when is_wave32() is true, otherwise GetBlockSize<false>()
129  return is_wave32() ? Problem::BlockShape::template GetBlockSize<true>()
130  : Problem::BlockShape::template GetBlockSize<false>();
131  }
132 
133  // clang-format off
134  template <typename T> struct t2s; // type-to-string: `name` is the short type tag GetName() puts in the kernel name
135  template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
136  template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
137  template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
138  template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
139  template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
140  template <> struct t2s<ck_tile::int8_t> { static constexpr const char * name = "int8"; };
141  // clang-format on
142 
143  // size in bytes, forwarded from Pipeline::GetSmemSize(); operator() uses it to size its __shared__ char buffer
144  CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }
145 
146  CK_TILE_HOST static std::string GetName()
147  { // returns "rmsnorm2d_fwd_<prec>_<Block_M>x<Block_N>_<WarpPerBlock_M>x<WarpPerBlock_N>_<Warp_M>x<Warp_N>_<Vector_M>x<Vector_N>_<Pipeline::name><suffix>"
148 #define _SS_ std::string
149 #define _TS_ std::to_string
150  // clang-format off
151  using S_ = typename Problem::BlockShape;
152  auto surfix = [&] () { // option suffix ("surfix" is a misspelling of "suffix"); built once by the immediately-invoked lambda
153  std::string n; // NOTE(review): listing lines 154-155 and 159-160 are not shown here, so more flags may be appended than appear below
156  if (kPadN) n += "_pn";
157  if (kSaveInvRms) n += "_rms";
158  if (kTwoPass) n += "_2p";
161  return n; }();
162 
163  auto prec_str = [&] () { // precision tag: X type name, plus "_<Y type name>" when YDataType differs from XDataType
164  std::string base_str = _SS_(t2s<XDataType>::name);
165  if (!std::is_same_v<XDataType, YDataType>) {
166  base_str += _SS_("_") + _SS_(t2s<YDataType>::name);
167  } // NOTE(review): listing lines 168 and 172 are not shown here; they presumably hold the conditions guarding the "_sx"/"_sy" appends below - confirm against the real file
169  base_str += _SS_("_sx") + _SS_(t2s<SmoothScaleDataType>::name);
170  base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
171  }
173  base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
174  }
175  return base_str;
176  }();
177 
178  return _SS_("rmsnorm2d_fwd_") + _SS_(prec_str) + "_" +
179  _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
180  _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
181  _SS_(Pipeline::name) + surfix;
182  // clang-format on
183 #undef _SS_
184 #undef _TS_
185  }
186 
187  CK_TILE_DEVICE void operator()(Kargs kargs) const
188  { // kernel body: builds one tile window per tensor, then hands them all to Pipeline; several listing lines are missing below (see the gaps in the line numbers)
189  const auto iM = get_block_id() * Block_M; // row offset of this block's tile; used as the M origin of every [m, ...] window
190 
191  const auto x_window = [&]() { // x: global [m, n] view with row stride x_stride; Block_M x Block_N window at {iM, 0}
192  const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
193  static_cast<const XDataType*>(kargs.p_x),
194  make_tuple(kargs.m, kargs.n),
195  make_tuple(kargs.x_stride, 1),
197  number<1>{});
198 
199  const auto tmp2_ = pad_tensor_view(
201  return make_tile_window(
202  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
203  }();
204 
205  const auto x_residual_window = [&]() { // shortcut input: same layout as x but with row stride xr_stride
206  if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD || // NOTE(review): second operand (listing line 207) is not shown here
208  {
209  const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
210  static_cast<const XResidualDataType*>(kargs.p_x_residual),
211  make_tuple(kargs.m, kargs.n),
212  make_tuple(kargs.xr_stride, 1),
214  number<1>{});
215 
216  const auto tmp2_ = pad_tensor_view(tmp_,
219  return make_tile_window(
220  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
221  }
222  else
223  { // NOTE(review): else-branch body (listing line 224) is not shown here
225  }
226  }();
227 
228  const auto gamma_window = [&]() { // gamma: global 1-D [n] view; Block_N window at {0}
229  const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
230  static_cast<const GammaDataType*>(kargs.p_gamma),
231  make_tuple(kargs.n),
232  make_tuple(1),
234  number<1>{});
235 
236  const auto tmp2_ =
238 
239  return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
240  }();
241 
242  auto y_window = [&]() { // y: global [m, n] view with row stride y_stride; Block_M x Block_N window at {iM, 0}
243  auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
244  static_cast<YDataType*>(kargs.p_y),
245  make_tuple(kargs.m, kargs.n),
246  make_tuple(kargs.y_stride, 1),
248  number<1>{});
249 
250  auto tmp2_ = pad_tensor_view(
252  return make_tile_window(
253  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
254  }();
255 
256  auto y_residual_window = [&]() { // shortcut output: [m, n] view with row stride yr_stride; NOTE(review): the guarding condition (listing line 257) is not shown here
258  {
259  auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
260  static_cast<YResidualDataType*>(kargs.p_y_residual),
261  make_tuple(kargs.m, kargs.n),
262  make_tuple(kargs.yr_stride, 1),
264  number<1>{});
265 
266  auto tmp2_ = pad_tensor_view(tmp_,
269  return make_tile_window(
270  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
271  }
272  else
273  { // NOTE(review): else-branch body (listing line 274) is not shown here
275  }
276  }();
277 
278  auto inv_rms_window = [&]() { // inv-rms output: built from p_invRms only when kSaveInvRms is set
279  if constexpr(kSaveInvRms)
280  {
281  const auto inv_rms_m = [&]() {
282  const auto inv_rms_dram_naive =
283  make_naive_tensor_view_packed<address_space_enum::global>(
284  static_cast<InvRmsDataType*>(kargs.p_invRms),
285  make_tuple(kargs.m),
286  number<1>{});
287 
288  return pad_tensor_view(
289  inv_rms_dram_naive, make_tuple(number<Block_M>{}), sequence<kPadM>{}); // kPadM is false, so no padding is requested along M
290  }();
291  return make_tile_window(inv_rms_m, make_tuple(number<Block_M>{}), {iM});
292  }
293  else // NOTE(review): else-branch statement (listing line 294) is not shown here
295  }();
296 
297  auto sm_scale_window = [&]() { // smooth scale input: packed 1-D [n] view; Block_N window at {0}; NOTE(review): the guarding condition (listing line 298) is not shown here
299  {
300  const auto win_ = [&]() {
301  const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
302  static_cast<const SmoothScaleDataType*>(kargs.p_sm_scale),
303  make_tuple(kargs.n),
304  number<Vector_N>{});
305 
306  return pad_tensor_view(tmp_0_,
308  sequence<false>{}); // sm_scale does not need padding
309  }();
310  return make_tile_window(win_, make_tuple(number<Block_N>{}), {0});
311  }
312  else
313  { // NOTE(review): else-branch body (listing line 314) is not shown here
315  }
316  }();
317 
318  auto y_scale_window = [&]() { // y scale output: packed 1-D [m] view; Block_M window at {iM}; NOTE(review): the guarding condition (listing lines 319-320) is not shown here
321  {
322  const auto win_ = [&]() {
323  const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
324  static_cast<YScaleDataType*>(kargs.p_y_scale),
325  make_tuple(kargs.m),
326  number<1>{});
327 
328  return pad_tensor_view(
330  }();
331  return make_tile_window(win_, make_tuple(number<Block_M>{}), {iM});
332  }
333  else
334  { // NOTE(review): else-branch body (listing line 335) is not shown here
336  }
337  }();
338 
339  auto unquant_y_window = [&]() { // pre-quant output: [m, n] view over p_y_unquant; NOTE(review): the first part of the condition (listing lines 340-341) is not shown here
342  kSaveUnquant)
343  {
344  auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
345  static_cast<UnquantYDataType*>(kargs.p_y_unquant),
346  make_tuple(kargs.m, kargs.n),
347  make_tuple(kargs.y_stride, 1), // reuses y_stride as the row stride; Kargs has no separate stride for p_y_unquant
349  number<1>{});
350 
351  auto tmp2_ = pad_tensor_view(tmp_,
354  return make_tile_window(
355  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
356  }
357  else
358  { // NOTE(review): else-branch body (listing line 359) is not shown here
360  }
361  }();
362 
363  __shared__ char smem[GetSmemSize()]; // shared-memory byte buffer passed to the pipeline, sized by Pipeline::GetSmemSize()
364 
365  Pipeline{}(x_window, // run the pipeline with every window built above
366  x_residual_window,
367  gamma_window,
368  y_window,
369  y_residual_window,
370  inv_rms_window,
371  sm_scale_window,
372  y_scale_window,
373  unquant_y_window,
374  static_cast<const ComputeDataType>(kargs.epsilon), // epsilon is stored as float and converted to the compute type here
375  kargs.n,
376  smem,
377  Epilogue{});
378  }
379 };
380 
381 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:66
_BitInt(8) fp8_t
Definition: float8.hpp:204
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
_Float16 fp16_t
Definition: half.hpp:110
int8_t int8_t
Definition: int8.hpp:20
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
#define _TS_
#define _SS_
Definition: rmsnorm2d_fwd_traits.hpp:20
Definition: rmsnorm2d_fwd_traits.hpp:34
Definition: rmsnorm2d_fwd_kernel.hpp:79
void * p_invRms
Definition: rmsnorm2d_fwd_kernel.hpp:88
void * p_y_scale
Definition: rmsnorm2d_fwd_kernel.hpp:87
index_t n
Definition: rmsnorm2d_fwd_kernel.hpp:94
const void * p_x
Definition: rmsnorm2d_fwd_kernel.hpp:80
index_t yr_stride
Definition: rmsnorm2d_fwd_kernel.hpp:98
index_t y_stride
Definition: rmsnorm2d_fwd_kernel.hpp:97
void * p_y
Definition: rmsnorm2d_fwd_kernel.hpp:85
index_t xr_stride
Definition: rmsnorm2d_fwd_kernel.hpp:96
const void * p_sm_scale
Definition: rmsnorm2d_fwd_kernel.hpp:82
void * p_y_residual
Definition: rmsnorm2d_fwd_kernel.hpp:86
void * p_y_unquant
Definition: rmsnorm2d_fwd_kernel.hpp:89
index_t m
Definition: rmsnorm2d_fwd_kernel.hpp:93
const void * p_gamma
Definition: rmsnorm2d_fwd_kernel.hpp:83
float epsilon
Definition: rmsnorm2d_fwd_kernel.hpp:91
const void * p_x_residual
Definition: rmsnorm2d_fwd_kernel.hpp:81
index_t x_stride
Definition: rmsnorm2d_fwd_kernel.hpp:95
Definition: rmsnorm2d_fwd_kernel.hpp:134
Definition: rmsnorm2d_fwd_kernel.hpp:14
void * p_invRms
Definition: rmsnorm2d_fwd_kernel.hpp:23
index_t xr_stride
Definition: rmsnorm2d_fwd_kernel.hpp:31
void * p_y_residual
Definition: rmsnorm2d_fwd_kernel.hpp:21
const void * p_x_residual
Definition: rmsnorm2d_fwd_kernel.hpp:16
void * p_y_scale
Definition: rmsnorm2d_fwd_kernel.hpp:22
float epsilon
Definition: rmsnorm2d_fwd_kernel.hpp:26
void * p_y
Definition: rmsnorm2d_fwd_kernel.hpp:20
index_t yr_stride
Definition: rmsnorm2d_fwd_kernel.hpp:33
void * p_y_unquant
Definition: rmsnorm2d_fwd_kernel.hpp:24
index_t x_stride
Definition: rmsnorm2d_fwd_kernel.hpp:30
index_t y_stride
Definition: rmsnorm2d_fwd_kernel.hpp:32
index_t m
Definition: rmsnorm2d_fwd_kernel.hpp:28
index_t n
Definition: rmsnorm2d_fwd_kernel.hpp:29
const void * p_sm_scale
Definition: rmsnorm2d_fwd_kernel.hpp:17
const void * p_x
Definition: rmsnorm2d_fwd_kernel.hpp:15
const void * p_gamma
Definition: rmsnorm2d_fwd_kernel.hpp:18
Definition: rmsnorm2d_fwd_kernel.hpp:39
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: rmsnorm2d_fwd_kernel.hpp:187
XDataType XResidualDataType
Definition: rmsnorm2d_fwd_kernel.hpp:54
remove_cvref_t< typename Problem::UnquantYDataType > UnquantYDataType
Definition: rmsnorm2d_fwd_kernel.hpp:51
remove_cvref_t< Epilogue_ > Epilogue
Definition: rmsnorm2d_fwd_kernel.hpp:41
remove_cvref_t< typename Problem::YScaleDataType > YScaleDataType
Definition: rmsnorm2d_fwd_kernel.hpp:50
static constexpr bool kTwoPass
Definition: rmsnorm2d_fwd_kernel.hpp:65
static constexpr bool kSaveInvRms
Definition: rmsnorm2d_fwd_kernel.hpp:58
static constexpr CK_TILE_HOST auto GridSize(const Hargs &hargs)
Definition: rmsnorm2d_fwd_kernel.hpp:122
static constexpr auto I0
Definition: rmsnorm2d_fwd_kernel.hpp:75
typename Pipeline::Problem Problem
Definition: rmsnorm2d_fwd_kernel.hpp:42
remove_cvref_t< typename Problem::InvRmsDataType > InvRmsDataType
Definition: rmsnorm2d_fwd_kernel.hpp:48
static constexpr bool kPadN
Definition: rmsnorm2d_fwd_kernel.hpp:64
static CK_TILE_HOST std::string GetName()
Definition: rmsnorm2d_fwd_kernel.hpp:146
remove_cvref_t< typename Problem::YDataType > YDataType
Definition: rmsnorm2d_fwd_kernel.hpp:47
static constexpr auto kFusedQuant
Definition: rmsnorm2d_fwd_kernel.hpp:67
remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition: rmsnorm2d_fwd_kernel.hpp:46
remove_cvref_t< Pipeline_ > Pipeline
Definition: rmsnorm2d_fwd_kernel.hpp:40
static constexpr auto I1
Definition: rmsnorm2d_fwd_kernel.hpp:76
static constexpr bool kPadM
Definition: rmsnorm2d_fwd_kernel.hpp:63
remove_cvref_t< typename Problem::SmoothScaleDataType > SmoothScaleDataType
Definition: rmsnorm2d_fwd_kernel.hpp:49
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: rmsnorm2d_fwd_kernel.hpp:144
static constexpr index_t Block_M
Definition: rmsnorm2d_fwd_kernel.hpp:61
static constexpr auto kFusedAdd
Definition: rmsnorm2d_fwd_kernel.hpp:66
XDataType YResidualDataType
Definition: rmsnorm2d_fwd_kernel.hpp:55
static constexpr bool kHasGamma
Definition: rmsnorm2d_fwd_kernel.hpp:57
remove_cvref_t< typename Problem::XDataType > XDataType
Definition: rmsnorm2d_fwd_kernel.hpp:44
static constexpr index_t kBlockSize
Definition: rmsnorm2d_fwd_kernel.hpp:73
static constexpr index_t ThreadPerWarp_N
Definition: rmsnorm2d_fwd_kernel.hpp:70
static constexpr CK_TILE_HOST auto BlockSize()
Definition: rmsnorm2d_fwd_kernel.hpp:127
static constexpr CK_TILE_HOST Kargs MakeKargs(const Hargs &hargs)
Definition: rmsnorm2d_fwd_kernel.hpp:102
remove_cvref_t< typename Problem::GammaDataType > GammaDataType
Definition: rmsnorm2d_fwd_kernel.hpp:45
static constexpr index_t Block_N
Definition: rmsnorm2d_fwd_kernel.hpp:62
static constexpr bool kSaveUnquant
Definition: rmsnorm2d_fwd_kernel.hpp:59
static constexpr index_t Repeat_N
Definition: rmsnorm2d_fwd_kernel.hpp:72
static constexpr auto kUseModelSensitiveRMSNorm
Definition: rmsnorm2d_fwd_kernel.hpp:68
static constexpr index_t Vector_N
Definition: rmsnorm2d_fwd_kernel.hpp:71
Definition: integral_constant.hpp:13
Definition: sequence.hpp:49