/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp Source File
rmsnorm2d_fwd_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
9 
10 namespace ck_tile {
11 
12 // host side args
14 {
15  const void* p_x; // [m ,n], input, fp16/bf16
16  const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
17  const void* p_sm_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
18  const void* p_gamma; // [1, n], gamma, prec same as input
19 
20  void* p_y; // [m, n], output, fp16/bf16
21  void* p_y_residual; // [m, n], shortcut output, prec same as input, nullptr if not used
22  void* p_y_scale; // [m, 1], output a dynamic quant per row, nullptr if not used
23  void* p_invRms; // [m, 1], output inv-rms, prec same as input, nullptr if not used
24  void* p_y_unquant; // [m, n], output result before quant, nullptr if not used
25 
26  float epsilon;
27 
30  index_t x_stride; // x row_stride
31  index_t xr_stride; // x residual row stride
32  index_t y_stride; // y row stride
33  index_t yr_stride; // y residual row stride
34 };
35 
36 // TODO: Extract some type to wrapper class
37 template <typename Pipeline_, typename Epilogue_>
39 {
42  using Problem = typename Pipeline::Problem;
43 
52 
53  // for simplicity, shortcut input/output type is same as X
56 
// Compile-time configuration, forwarded from Problem (its Traits and BlockShape).
// kHasGamma is true unless GammaDataType is null_type.
57  static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, null_type>;
58  static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms; // gates the inv-rms output window in operator()
59  static constexpr bool kSaveUnquant = Problem::Traits::kSaveUnquant; // part of the condition for the unquantized-y window
60 
// Tile extent handled by one block: Block_M rows by Block_N columns.
61  static constexpr index_t Block_M = Problem::BlockShape::Block_M;
62  static constexpr index_t Block_N = Problem::BlockShape::Block_N;
63  static constexpr bool kPadM = false; // always no need to pad along M
64  static constexpr bool kPadN = Problem::Traits::kPadN; // adds "_pn" to the kernel name
65  static constexpr bool kTwoPass = Problem::Traits::kTwoPass; // adds "_2p" to the kernel name
// The next three take whatever type the trait declares (auto); kFusedAdd is compared
// against Rmsnorm2dFusedAddEnum values in operator().
66  static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
67  static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
68  static constexpr auto kUseModelSensitiveRMSNorm = Problem::Traits::kUseModelSensitiveRMSNorm;
69 
// Thread/vector layout along N and the block size, copied from the block shape.
70  static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
71  static constexpr index_t Vector_N = Problem::BlockShape::Vector_N; // also passed as a number<> argument to the smooth-scale view
72  static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
73  static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
74 
// Compile-time integer constants 0 and 1.
75  static constexpr auto I0 = number<0>{};
76  static constexpr auto I1 = number<1>{};
77 
// Kernel-side argument pack. MakeKargs fills it positionally from the host-side
// argument struct, so the member order here must match the initializer order there.
// Listing lines 93-94 (m, n) are restored: MakeKargs initializes them between
// epsilon and x_stride, and operator() reads kargs.m / kargs.n.
78  struct Kargs
79  {
80  const void* p_x; // input, viewed as (m, n) with row stride x_stride
81  const void* p_x_residual; // residual (shortcut) input, row stride xr_stride
82  const void* p_sm_scale; // smooth-scale input, viewed as length n
83  const void* p_gamma; // gamma, viewed as length n
84 
85  void* p_y; // output, viewed as (m, n) with row stride y_stride
86  void* p_y_residual; // residual (shortcut) output, row stride yr_stride
87  void* p_y_scale; // per-row quantization scale output, viewed as length m
88  void* p_invRms; // inverse-RMS output, viewed as length m
89  void* p_y_unquant; // pre-quantization output, viewed as (m, n) with row stride y_stride
90 
91  float epsilon; // cast to ComputeDataType before being handed to the pipeline
92 
93  index_t m; // number of rows; the grid is sized as ceil(m / Block_M)
94  index_t n; // number of elements per row
95  index_t x_stride; // x row_stride
96  index_t xr_stride; // x residual row stride
97  index_t y_stride; // y row stride
98  index_t yr_stride; // y residual row stride
99  };
101 
// Builds the kernel argument pack from the host-side arguments.
// Every field is copied unchanged via positional aggregate initialization, so the
// order below has to follow the member order of Kargs.
102  CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
103  {
104  return Kargs{hargs.p_x,
105  hargs.p_x_residual,
106  hargs.p_sm_scale,
107  hargs.p_gamma,
108  hargs.p_y,
109  hargs.p_y_residual,
110  hargs.p_y_scale,
111  hargs.p_invRms,
112  hargs.p_y_unquant,
113  hargs.epsilon,
114  hargs.m,
115  hargs.n,
116  hargs.x_stride,
117  hargs.xr_stride,
118  hargs.y_stride,
119  hargs.yr_stride};
120  }
121 
// Launch grid: one block per Block_M rows, i.e. ceil(hargs.m / Block_M) blocks
// in the first dim3 component. operator() uses the block id to pick its first row.
122  CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
123  {
124  return dim3(integer_divide_ceil(hargs.m, Block_M));
125  }
126 
// Block size for the launch, taken directly from the problem's block shape.
127  CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; }
128 
129  // clang-format off
// Type-to-string table used by GetName(): each specialization supplies the short
// precision tag for one element type. The primary template is only declared, so
// asking for the name of a type without a specialization fails to compile.
130  template <typename T> struct t2s;
131  template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
132  template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
133  template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
134  template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
135  template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
136  template <> struct t2s<ck_tile::int8_t> { static constexpr const char * name = "int8"; };
137  // clang-format on
138 
139  // Size in bytes, as reported by the pipeline; operator() uses it to size its __shared__ buffer.
140  CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }
141 
// Returns a descriptive kernel name of the form
//   "rmsnorm2d_fwd_" + <precision tags> + "_" + <BlockM>x<BlockN> + "_" +
//   <WarpPerBlockM>x<WarpPerBlockN> + "_" + <WarpM>x<WarpN> + "_" +
//   <VectorM>x<VectorN> + "_" + Pipeline::name + <feature suffix>
// where the sizes come from Problem::BlockShape and the precision tags from t2s.
// NOTE(review): this listing omits lines 150-151, 155-156, 164 and 168, so some
// suffix entries and the conditions guarding the "_sx"/"_sy" tags are not visible here.
142  CK_TILE_HOST static std::string GetName()
143  {
// Local shorthands for std::string / std::to_string; both are #undef'd before returning.
144 #define _SS_ std::string
145 #define _TS_ std::to_string
146  // clang-format off
147  using S_ = typename Problem::BlockShape;
// Feature suffix: one short tag is appended per enabled compile-time option.
148  auto surfix = [&] () {
149  std::string n;
152  if (kPadN) n += "_pn";
153  if (kSaveInvRms) n += "_rms";
154  if (kTwoPass) n += "_2p";
157  return n; }();
158 
// Precision tags: the X type name, followed by "_" + the Y type name only when Y differs from X.
159  auto prec_str = [&] () {
160  std::string base_str = _SS_(t2s<XDataType>::name);
161  if (!std::is_same_v<XDataType, YDataType>) {
162  base_str += _SS_("_") + _SS_(t2s<YDataType>::name);
163  }
// The condition opening this block (listing line 164) is not in view;
// the block appends both the smooth-scale and the y-scale type tags.
165  base_str += _SS_("_sx") + _SS_(t2s<SmoothScaleDataType>::name);
166  base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
167  }
// The condition opening this block (listing line 168) is not in view;
// the block appends only the y-scale type tag.
169  base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
170  }
171  return base_str;
172  }();
173 
174  return _SS_("rmsnorm2d_fwd_") + _SS_(prec_str) + "_" +
175  _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
176  _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
177  _SS_(Pipeline::name) + surfix;
178  // clang-format on
179 #undef _SS_
180 #undef _TS_
181  }
182 
// Device entry point. Each block starts at row iM = block id * Block_M, builds one
// tile window per input/output tensor over global memory, and hands all of them to
// the pipeline together with epsilon, n, a shared-memory buffer and an Epilogue.
// NOTE(review): this listing omits many lines (e.g. 192, 196, 203, 220, 253, 290,
// 294, 315-316, 336-337). Calls that look truncated below are extraction gaps, and
// the conditions and else-bodies on those lines cannot be confirmed from this view.
183  CK_TILE_DEVICE void operator()(Kargs kargs) const
184  {
// First row handled by this block.
185  const auto iM = get_block_id() * Block_M;
186 
// Input x: (m, n) view with strides (x_stride, 1); window of Block_M x Block_N at {iM, 0}.
187  const auto x_window = [&]() {
188  const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
189  static_cast<const XDataType*>(kargs.p_x),
190  make_tuple(kargs.m, kargs.n),
191  make_tuple(kargs.x_stride, 1),
193  number<1>{});
194 
// The pad_tensor_view arguments (listing line 196) are not in view.
195  const auto tmp2_ = pad_tensor_view(
197  return make_tile_window(
198  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
199  }();
200 
// Residual input: same shape as x but with row stride xr_stride. Built only when the
// kFusedAdd condition holds; the second half of that condition (line 203) is not in view.
201  const auto x_residual_window = [&]() {
202  if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD ||
204  {
205  const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
206  static_cast<const XResidualDataType*>(kargs.p_x_residual),
207  make_tuple(kargs.m, kargs.n),
208  make_tuple(kargs.xr_stride, 1),
210  number<1>{});
211 
212  const auto tmp2_ = pad_tensor_view(tmp_,
215  return make_tile_window(
216  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
217  }
218  else
219  {
// Else-body (line 220) not in view; presumably returns a null tile window - TODO confirm.
221  }
222  }();
223 
// Gamma: 1-D view of length n with unit stride; window of Block_N at {0}.
224  const auto gamma_window = [&]() {
225  const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
226  static_cast<const GammaDataType*>(kargs.p_gamma),
227  make_tuple(kargs.n),
228  make_tuple(1),
230  number<1>{});
231 
// The initializer of tmp2_ (listing line 233) is not in view.
232  const auto tmp2_ =
234 
235  return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
236  }();
237 
// Output y: (m, n) view with strides (y_stride, 1); window of Block_M x Block_N at {iM, 0}.
238  auto y_window = [&]() {
239  auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
240  static_cast<YDataType*>(kargs.p_y),
241  make_tuple(kargs.m, kargs.n),
242  make_tuple(kargs.y_stride, 1),
244  number<1>{});
245 
246  auto tmp2_ = pad_tensor_view(
248  return make_tile_window(
249  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
250  }();
251 
// Residual output: (m, n) view with row stride yr_stride. The guarding condition
// (line 253) and the else-body (line 270) are not in view.
252  auto y_residual_window = [&]() {
254  {
255  auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
256  static_cast<YResidualDataType*>(kargs.p_y_residual),
257  make_tuple(kargs.m, kargs.n),
258  make_tuple(kargs.yr_stride, 1),
260  number<1>{});
261 
262  auto tmp2_ = pad_tensor_view(tmp_,
265  return make_tile_window(
266  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
267  }
268  else
269  {
271  }
272  }();
273 
// Inverse-RMS output, built only when kSaveInvRms: packed 1-D view of length m,
// padded with the kPadM flag (always false); window of Block_M at {iM}.
274  auto inv_rms_window = [&]() {
275  if constexpr(kSaveInvRms)
276  {
277  const auto inv_rms_m = [&]() {
278  const auto inv_rms_dram_naive =
279  make_naive_tensor_view_packed<address_space_enum::global>(
280  static_cast<InvRmsDataType*>(kargs.p_invRms),
281  make_tuple(kargs.m),
282  number<1>{});
283 
284  return pad_tensor_view(
285  inv_rms_dram_naive, make_tuple(number<Block_M>{}), sequence<kPadM>{});
286  }();
287  return make_tile_window(inv_rms_m, make_tuple(number<Block_M>{}), {iM});
288  }
289  else
// Else-body (line 290) not in view.
291  }();
292 
// Smooth-scale input: packed 1-D view of length n, created with number<Vector_N>;
// window of Block_N at {0}. The guarding condition (line 294) is not in view.
293  auto sm_scale_window = [&]() {
295  {
296  const auto win_ = [&]() {
297  const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
298  static_cast<const SmoothScaleDataType*>(kargs.p_sm_scale),
299  make_tuple(kargs.n),
300  number<Vector_N>{});
301 
302  return pad_tensor_view(tmp_0_,
304  sequence<false>{}); // sm_scale no need pad
305  }();
306  return make_tile_window(win_, make_tuple(number<Block_N>{}), {0});
307  }
308  else
309  {
311  }
312  }();
313 
// Per-row y-scale output: packed 1-D view of length m; window of Block_M at {iM}.
// The guarding condition (lines 315-316) is not in view.
314  auto y_scale_window = [&]() {
317  {
318  const auto win_ = [&]() {
319  const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
320  static_cast<YScaleDataType*>(kargs.p_y_scale),
321  make_tuple(kargs.m),
322  number<1>{});
323 
324  return pad_tensor_view(
326  }();
327  return make_tile_window(win_, make_tuple(number<Block_M>{}), {iM});
328  }
329  else
330  {
332  }
333  }();
334 
// Unquantized y output: (m, n) view that reuses y_stride as its row stride.
// kSaveUnquant is the last term of the guarding condition; the earlier terms
// (lines 336-337) are not in view.
335  auto unquant_y_window = [&]() {
338  kSaveUnquant)
339  {
340  auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
341  static_cast<UnquantYDataType*>(kargs.p_y_unquant),
342  make_tuple(kargs.m, kargs.n),
343  make_tuple(kargs.y_stride, 1),
345  number<1>{});
346 
347  auto tmp2_ = pad_tensor_view(tmp_,
350  return make_tile_window(
351  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
352  }
353  else
354  {
356  }
357  }();
358 
// Shared-memory buffer, sized in bytes by GetSmemSize().
359  __shared__ char smem[GetSmemSize()];
360 
// Run the pipeline on this block's windows; epsilon is converted to ComputeDataType first.
361  Pipeline{}(x_window,
362  x_residual_window,
363  gamma_window,
364  y_window,
365  y_residual_window,
366  inv_rms_window,
367  sm_scale_window,
368  y_scale_window,
369  unquant_y_window,
370  static_cast<const ComputeDataType>(kargs.epsilon),
371  kargs.n,
372  smem,
373  Epilogue{});
374  }
375 };
376 
377 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:66
_BitInt(8) fp8_t
Definition: float8.hpp:204
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
_Float16 fp16_t
Definition: half.hpp:110
int8_t int8_t
Definition: int8.hpp:20
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
#define _TS_
#define _SS_
Definition: rmsnorm2d_fwd_traits.hpp:20
Definition: rmsnorm2d_fwd_traits.hpp:34
Definition: rmsnorm2d_fwd_kernel.hpp:79
void * p_invRms
Definition: rmsnorm2d_fwd_kernel.hpp:88
void * p_y_scale
Definition: rmsnorm2d_fwd_kernel.hpp:87
index_t n
Definition: rmsnorm2d_fwd_kernel.hpp:94
const void * p_x
Definition: rmsnorm2d_fwd_kernel.hpp:80
index_t yr_stride
Definition: rmsnorm2d_fwd_kernel.hpp:98
index_t y_stride
Definition: rmsnorm2d_fwd_kernel.hpp:97
void * p_y
Definition: rmsnorm2d_fwd_kernel.hpp:85
index_t xr_stride
Definition: rmsnorm2d_fwd_kernel.hpp:96
const void * p_sm_scale
Definition: rmsnorm2d_fwd_kernel.hpp:82
void * p_y_residual
Definition: rmsnorm2d_fwd_kernel.hpp:86
void * p_y_unquant
Definition: rmsnorm2d_fwd_kernel.hpp:89
index_t m
Definition: rmsnorm2d_fwd_kernel.hpp:93
const void * p_gamma
Definition: rmsnorm2d_fwd_kernel.hpp:83
float epsilon
Definition: rmsnorm2d_fwd_kernel.hpp:91
const void * p_x_residual
Definition: rmsnorm2d_fwd_kernel.hpp:81
index_t x_stride
Definition: rmsnorm2d_fwd_kernel.hpp:95
Definition: rmsnorm2d_fwd_kernel.hpp:130
Definition: rmsnorm2d_fwd_kernel.hpp:14
void * p_invRms
Definition: rmsnorm2d_fwd_kernel.hpp:23
index_t xr_stride
Definition: rmsnorm2d_fwd_kernel.hpp:31
void * p_y_residual
Definition: rmsnorm2d_fwd_kernel.hpp:21
const void * p_x_residual
Definition: rmsnorm2d_fwd_kernel.hpp:16
void * p_y_scale
Definition: rmsnorm2d_fwd_kernel.hpp:22
float epsilon
Definition: rmsnorm2d_fwd_kernel.hpp:26
void * p_y
Definition: rmsnorm2d_fwd_kernel.hpp:20
index_t yr_stride
Definition: rmsnorm2d_fwd_kernel.hpp:33
void * p_y_unquant
Definition: rmsnorm2d_fwd_kernel.hpp:24
index_t x_stride
Definition: rmsnorm2d_fwd_kernel.hpp:30
index_t y_stride
Definition: rmsnorm2d_fwd_kernel.hpp:32
index_t m
Definition: rmsnorm2d_fwd_kernel.hpp:28
index_t n
Definition: rmsnorm2d_fwd_kernel.hpp:29
const void * p_sm_scale
Definition: rmsnorm2d_fwd_kernel.hpp:17
const void * p_x
Definition: rmsnorm2d_fwd_kernel.hpp:15
const void * p_gamma
Definition: rmsnorm2d_fwd_kernel.hpp:18
Definition: rmsnorm2d_fwd_kernel.hpp:39
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: rmsnorm2d_fwd_kernel.hpp:183
XDataType XResidualDataType
Definition: rmsnorm2d_fwd_kernel.hpp:54
remove_cvref_t< typename Problem::UnquantYDataType > UnquantYDataType
Definition: rmsnorm2d_fwd_kernel.hpp:51
remove_cvref_t< Epilogue_ > Epilogue
Definition: rmsnorm2d_fwd_kernel.hpp:41
remove_cvref_t< typename Problem::YScaleDataType > YScaleDataType
Definition: rmsnorm2d_fwd_kernel.hpp:50
static constexpr bool kTwoPass
Definition: rmsnorm2d_fwd_kernel.hpp:65
static constexpr bool kSaveInvRms
Definition: rmsnorm2d_fwd_kernel.hpp:58
static constexpr CK_TILE_HOST auto GridSize(const Hargs &hargs)
Definition: rmsnorm2d_fwd_kernel.hpp:122
static constexpr auto I0
Definition: rmsnorm2d_fwd_kernel.hpp:75
typename Pipeline::Problem Problem
Definition: rmsnorm2d_fwd_kernel.hpp:42
remove_cvref_t< typename Problem::InvRmsDataType > InvRmsDataType
Definition: rmsnorm2d_fwd_kernel.hpp:48
static constexpr bool kPadN
Definition: rmsnorm2d_fwd_kernel.hpp:64
static CK_TILE_HOST std::string GetName()
Definition: rmsnorm2d_fwd_kernel.hpp:142
remove_cvref_t< typename Problem::YDataType > YDataType
Definition: rmsnorm2d_fwd_kernel.hpp:47
static constexpr auto kFusedQuant
Definition: rmsnorm2d_fwd_kernel.hpp:67
remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition: rmsnorm2d_fwd_kernel.hpp:46
remove_cvref_t< Pipeline_ > Pipeline
Definition: rmsnorm2d_fwd_kernel.hpp:40
static constexpr auto I1
Definition: rmsnorm2d_fwd_kernel.hpp:76
static constexpr bool kPadM
Definition: rmsnorm2d_fwd_kernel.hpp:63
remove_cvref_t< typename Problem::SmoothScaleDataType > SmoothScaleDataType
Definition: rmsnorm2d_fwd_kernel.hpp:49
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: rmsnorm2d_fwd_kernel.hpp:140
static constexpr index_t Block_M
Definition: rmsnorm2d_fwd_kernel.hpp:61
static constexpr auto kFusedAdd
Definition: rmsnorm2d_fwd_kernel.hpp:66
XDataType YResidualDataType
Definition: rmsnorm2d_fwd_kernel.hpp:55
static constexpr bool kHasGamma
Definition: rmsnorm2d_fwd_kernel.hpp:57
remove_cvref_t< typename Problem::XDataType > XDataType
Definition: rmsnorm2d_fwd_kernel.hpp:44
static constexpr index_t kBlockSize
Definition: rmsnorm2d_fwd_kernel.hpp:73
static constexpr index_t ThreadPerWarp_N
Definition: rmsnorm2d_fwd_kernel.hpp:70
static constexpr CK_TILE_HOST auto BlockSize()
Definition: rmsnorm2d_fwd_kernel.hpp:127
static constexpr CK_TILE_HOST Kargs MakeKargs(const Hargs &hargs)
Definition: rmsnorm2d_fwd_kernel.hpp:102
remove_cvref_t< typename Problem::GammaDataType > GammaDataType
Definition: rmsnorm2d_fwd_kernel.hpp:45
static constexpr index_t Block_N
Definition: rmsnorm2d_fwd_kernel.hpp:62
static constexpr bool kSaveUnquant
Definition: rmsnorm2d_fwd_kernel.hpp:59
static constexpr index_t Repeat_N
Definition: rmsnorm2d_fwd_kernel.hpp:72
static constexpr auto kUseModelSensitiveRMSNorm
Definition: rmsnorm2d_fwd_kernel.hpp:68
static constexpr index_t Vector_N
Definition: rmsnorm2d_fwd_kernel.hpp:71
Definition: integral_constant.hpp:13
Definition: sequence.hpp:49