/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp Source File
layernorm2d_fwd_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
9 
10 namespace ck_tile {
11 
12 // host side args
14 {
15  const void* p_x; // [m ,n], input, fp16/bf16
16  const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
17  const void* p_sm_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
18  const void* p_x_bias; // [1, n], bias, prec same as input
19  const void* p_gamma; // [1, n], gamma, prec same as input
20  const void* p_beta; // [1, n], beta, prec same as input
21 
22  void* p_y; // [m, n], output, fp16/bf16
23  void* p_y_residual; // [m, n], shortcut output, prec same as input, nullptr if not used
24  void* p_y_scale; // [m, 1], output a dynamic quant per row, nullptr if not used
25  void* p_mean; // [m, 1], output mean, prec same as input, nullptr if not used
26  void* p_invStd; // [m, 1], output inv-stdvariance, prec same as input, nullptr if not used
27 
28  float epsilon;
29 
32  index_t x_stride; // x row_stride
33  index_t xr_stride; // x residule row stride
34  index_t y_stride; // y row stride
35  index_t yr_stride; // y residule row stride
36 };
37 
38 // TODO: Extract some type to wrapper class
39 template <typename Pipeline_, typename Epilogue_>
41 {
44  using Problem = typename Pipeline::Problem;
45 
56 
57  // for simplicity, shortcut input/output type is same as X
60 
61  static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, null_type>;
62  static constexpr bool kHasBeta = !std::is_same_v<BetaDataType, null_type>;
63  static constexpr bool kSaveMeanInvStd = Problem::Traits::kSaveMeanInvStd;
64  static constexpr bool kSaveMean = Problem::Traits::kSaveMeanInvStd;
65  static constexpr bool kSaveInvStd = Problem::Traits::kSaveMeanInvStd;
66 
67  static constexpr index_t Block_M = Problem::BlockShape::Block_M;
68  static constexpr index_t Block_N = Problem::BlockShape::Block_N;
69  static constexpr bool kPadM = false; // always no need to pad along M
70  static constexpr bool kPadN = Problem::Traits::kPadN;
71  static constexpr bool kTwoPass = Problem::Traits::kTwoPass;
72  static constexpr auto kXbias = Problem::Traits::kXbias;
73  static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
74  static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
75 
76  static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
77  static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
78  static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
79  static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
80  static constexpr auto I0 = number<0>{};
81  static constexpr auto I1 = number<1>{};
82 
83  struct Kargs
84  {
85  const void* p_x; // [m ,n], input, fp16/bf16
86  const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
87  const void* p_sm_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
88  const void* p_x_bias; // [1, n], bias, prec same as input
89  const void* p_gamma; // [1, n], gamma, prec same as input
90  const void* p_beta; // [1, n], beta, prec same as input
91 
92  void* p_y; // [m, n], output, fp16/bf16
93  void* p_y_residual; // [m, n], shortcut output, prec same as input, nullptr if not used
94  void* p_y_scale; // [m, 1], output a dynamic quant per row, nullptr if not used
95 
96  void* p_mean; // [m, 1], output mean, prec same as input, nullptr if not used
97  void* p_invStd; // [m, 1], output inv-stdvariance, prec same as input, nullptr if not used
98 
99  float epsilon;
100 
103  index_t x_stride; // x row_stride
104  index_t xr_stride; // x residule row stride
105  index_t y_stride; // y row stride
106  index_t yr_stride; // y residule row stride
107  };
109 
110  CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
111  {
112  return Kargs{hargs.p_x,
113  hargs.p_x_residual,
114  hargs.p_sm_scale,
115  hargs.p_x_bias,
116  hargs.p_gamma,
117  hargs.p_beta,
118  hargs.p_y,
119  hargs.p_y_residual,
120  hargs.p_y_scale,
121  hargs.p_mean,
122  hargs.p_invStd,
123  hargs.epsilon,
124  hargs.m,
125  hargs.n,
126  hargs.x_stride,
127  hargs.xr_stride,
128  hargs.y_stride,
129  hargs.yr_stride};
130  }
131 
132  CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
133  {
134  return dim3(integer_divide_ceil(hargs.m, Block_M));
135  }
136 
137  CK_TILE_HOST static constexpr auto BlockSize()
138  {
139  return is_wave32() ? Problem::BlockShape::template GetBlockSize<true>()
140  : Problem::BlockShape::template GetBlockSize<false>();
141  }
142 
143  // clang-format off
 // t2s<T>::name maps an element type to the short precision tag that GetName()
 // embeds in the kernel name (e.g. "fp16", "bf16", "int8"). Only the primary
 // template is declared (never defined), so instantiating t2s with a type that
 // has no specialization below fails to compile.
144  template <typename T> struct t2s;
145  template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
146  template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
147  template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
148  template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
149  template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
150  template <> struct t2s<ck_tile::int8_t> { static constexpr const char * name = "int8"; };
151  // clang-format on
152 
153  // in byte
154  CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }
155 
156  CK_TILE_HOST static std::string GetName()
157  {
158 #define _SS_ std::string
159 #define _TS_ std::to_string
160  // clang-format off
161  using S_ = typename Problem::BlockShape;
162  auto surfix = [&] () {
163  std::string n;
167  if (kPadN) n += "_pn";
168  if (kSaveMeanInvStd) n += "_mv";
169  // if (kTwoPass) n += "_2p";
170  return n; }();
171 
172  auto prec_str = [&] () {
173  std::string base_str = _SS_(t2s<XDataType>::name);
174  if (!std::is_same_v<XDataType, YDataType>) {
175  base_str += _SS_("_") + _SS_(t2s<YDataType>::name);
176  }
178  base_str += _SS_("_sx") + _SS_(t2s<SmoothScaleDataType>::name);
179  base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
180  }
182  base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
183  }
184  return base_str;
185  }();
186 
187  return _SS_("layernorm2d_fwd_") + _SS_(prec_str) + "_" +
188  _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
189  _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
190  _SS_(Pipeline::name) + surfix;
191  // clang-format on
192 #undef _SS_
193 #undef _TS_
194  }
195 
196  CK_TILE_DEVICE void operator()(Kargs kargs) const
197  {
198  const auto iM = get_block_id() * Block_M;
199 
200  const auto x_window = [&]() {
201  const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
202  static_cast<const XDataType*>(kargs.p_x),
203  make_tuple(kargs.m, kargs.n),
204  make_tuple(kargs.x_stride, 1),
206  number<1>{});
207 
208  // NOTE: we don't do any pad in this kernel for loading, assume that inside kernel will
209  // check the max count dynamically
210  const auto tmp2_ = pad_tensor_view(
212  return make_tile_window(
213  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
214  }();
215 
216  const auto x_residual_window = [&]() {
219  {
220  const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
221  static_cast<const XResidualDataType*>(kargs.p_x_residual),
222  make_tuple(kargs.m, kargs.n),
223  make_tuple(kargs.xr_stride, 1),
225  number<1>{});
226 
227  // NOTE: we don't do any pad in this kernel for loading, assume that inside kernel
228  // will check the max count dynamically
229  const auto tmp2_ = pad_tensor_view(tmp_,
232  return make_tile_window(
233  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
234  }
235  else
236  {
238  }
239  }();
240 
241  const auto x_bias_window = [&]() {
242  if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
243  {
244  const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
245  static_cast<const XBiasDataType*>(kargs.p_x_bias),
246  make_tuple(kargs.n),
247  make_tuple(1),
249  number<1>{});
250 
251  const auto tmp2_ =
253 
254  return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
255  }
256  else
257  {
259  }
260  }();
261 
262  const auto gamma_window = [&]() {
263  const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
264  static_cast<const GammaDataType*>(kargs.p_gamma),
265  make_tuple(kargs.n),
266  make_tuple(1),
268  number<1>{});
269 
270  const auto tmp2_ =
272 
273  return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
274  }();
275 
276  const auto beta_window = [&]() {
277  const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
278  static_cast<const BetaDataType*>(kargs.p_beta),
279  make_tuple(kargs.n),
280  make_tuple(1),
282  number<1>{});
283 
284  const auto tmp2_ =
287  }();
288 
289  auto y_window = [&]() {
290  auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
291  static_cast<YDataType*>(kargs.p_y),
292  make_tuple(kargs.m, kargs.n),
293  make_tuple(kargs.y_stride, 1),
295  number<1>{});
296 
297  auto tmp2_ = pad_tensor_view(
299  return make_tile_window(
300  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
301  }();
302 
303  auto y_residual_window = [&]() {
305  {
306  auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
307  static_cast<YResidualDataType*>(kargs.p_y_residual),
308  make_tuple(kargs.m, kargs.n),
309  make_tuple(kargs.yr_stride, 1),
311  number<1>{});
312 
313  auto tmp2_ = pad_tensor_view(tmp_,
316  return make_tile_window(
317  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
318  }
319  else
320  {
322  }
323  }();
324 
325  auto mean_window = [&]() {
326  if constexpr(kSaveMean)
327  {
328  const auto mean_m = [&]() {
329  const auto mean_dram_naive =
330  make_naive_tensor_view_packed<address_space_enum::global>(
331  static_cast<MeanDataType*>(kargs.p_mean),
332  make_tuple(kargs.m),
333  number<1>{});
334 
335  return pad_tensor_view(
336  mean_dram_naive, make_tuple(number<Block_M>{}), sequence<kPadM>{});
337  }();
338  return make_tile_window(mean_m, make_tuple(number<Block_M>{}), {iM});
339  }
340  else
342  }();
343 
344  auto inv_std_window = [&]() {
345  if constexpr(kSaveInvStd)
346  {
347  const auto inv_std_m = [&]() {
348  const auto inv_std_dram_naive =
349  make_naive_tensor_view_packed<address_space_enum::global>(
350  static_cast<InvStdDataType*>(kargs.p_invStd),
351  make_tuple(kargs.m),
352  number<1>{});
353 
354  return pad_tensor_view(
355  inv_std_dram_naive, make_tuple(number<Block_M>{}), sequence<kPadM>{});
356  }();
357  return make_tile_window(inv_std_m, make_tuple(number<Block_M>{}), {iM});
358  }
359  else
361  }();
362 
363  auto sm_scale_window = [&]() {
365  {
366  const auto win_ = [&]() {
367  const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
368  static_cast<const SmoothScaleDataType*>(kargs.p_sm_scale),
369  make_tuple(kargs.n),
370  number<Vector_N>{});
371 
372  return pad_tensor_view(tmp_0_,
374  sequence<false>{}); // sm_scale no need pad
375  }();
376  return make_tile_window(win_, make_tuple(number<Block_N>{}), {0});
377  }
378  else
380  }();
381 
382  auto y_scale_window = [&]() {
385  {
386  const auto win_ = [&]() {
387  const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
388  static_cast<YScaleDataType*>(kargs.p_y_scale),
389  make_tuple(kargs.m),
390  number<1>{});
391 
392  return pad_tensor_view(
394  }();
395  return make_tile_window(win_, make_tuple(number<Block_M>{}), {iM});
396  }
397  else
399  }();
400 
401  __shared__ char smem[GetSmemSize()];
402 
403  Pipeline{}(x_window,
404  x_residual_window,
405  x_bias_window,
406  gamma_window,
407  beta_window,
408  y_window,
409  y_residual_window,
410  mean_window,
411  inv_std_window,
412  sm_scale_window,
413  y_scale_window,
414  static_cast<const ComputeDataType>(kargs.epsilon),
415  kargs.n,
416  smem,
417  Epilogue{});
418  }
419 };
420 
421 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
#define _TS_
#define _SS_
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:66
_BitInt(8) fp8_t
Definition: float8.hpp:204
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
_Float16 fp16_t
Definition: half.hpp:110
int8_t int8_t
Definition: int8.hpp:20
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
Definition: layernorm2d_fwd_traits.hpp:33
Definition: layernorm2d_fwd_traits.hpp:47
Definition: layernorm2d_fwd_kernel.hpp:84
index_t yr_stride
Definition: layernorm2d_fwd_kernel.hpp:106
void * p_y_scale
Definition: layernorm2d_fwd_kernel.hpp:94
const void * p_x_bias
Definition: layernorm2d_fwd_kernel.hpp:88
void * p_y_residual
Definition: layernorm2d_fwd_kernel.hpp:93
const void * p_beta
Definition: layernorm2d_fwd_kernel.hpp:90
index_t n
Definition: layernorm2d_fwd_kernel.hpp:102
const void * p_gamma
Definition: layernorm2d_fwd_kernel.hpp:89
const void * p_x
Definition: layernorm2d_fwd_kernel.hpp:85
const void * p_sm_scale
Definition: layernorm2d_fwd_kernel.hpp:87
void * p_invStd
Definition: layernorm2d_fwd_kernel.hpp:97
index_t m
Definition: layernorm2d_fwd_kernel.hpp:101
void * p_y
Definition: layernorm2d_fwd_kernel.hpp:92
void * p_mean
Definition: layernorm2d_fwd_kernel.hpp:96
index_t y_stride
Definition: layernorm2d_fwd_kernel.hpp:105
index_t x_stride
Definition: layernorm2d_fwd_kernel.hpp:103
const void * p_x_residual
Definition: layernorm2d_fwd_kernel.hpp:86
index_t xr_stride
Definition: layernorm2d_fwd_kernel.hpp:104
float epsilon
Definition: layernorm2d_fwd_kernel.hpp:99
Definition: layernorm2d_fwd_kernel.hpp:144
Definition: layernorm2d_fwd_kernel.hpp:14
void * p_invStd
Definition: layernorm2d_fwd_kernel.hpp:26
index_t xr_stride
Definition: layernorm2d_fwd_kernel.hpp:33
index_t y_stride
Definition: layernorm2d_fwd_kernel.hpp:34
index_t n
Definition: layernorm2d_fwd_kernel.hpp:31
index_t m
Definition: layernorm2d_fwd_kernel.hpp:30
void * p_y_residual
Definition: layernorm2d_fwd_kernel.hpp:23
const void * p_x
Definition: layernorm2d_fwd_kernel.hpp:15
const void * p_x_bias
Definition: layernorm2d_fwd_kernel.hpp:18
float epsilon
Definition: layernorm2d_fwd_kernel.hpp:28
const void * p_gamma
Definition: layernorm2d_fwd_kernel.hpp:19
void * p_mean
Definition: layernorm2d_fwd_kernel.hpp:25
void * p_y_scale
Definition: layernorm2d_fwd_kernel.hpp:24
const void * p_beta
Definition: layernorm2d_fwd_kernel.hpp:20
const void * p_x_residual
Definition: layernorm2d_fwd_kernel.hpp:16
void * p_y
Definition: layernorm2d_fwd_kernel.hpp:22
const void * p_sm_scale
Definition: layernorm2d_fwd_kernel.hpp:17
index_t x_stride
Definition: layernorm2d_fwd_kernel.hpp:32
index_t yr_stride
Definition: layernorm2d_fwd_kernel.hpp:35
Definition: layernorm2d_fwd_kernel.hpp:41
typename Pipeline::Problem Problem
Definition: layernorm2d_fwd_kernel.hpp:44
static constexpr CK_TILE_HOST auto GridSize(const Hargs &hargs)
Definition: layernorm2d_fwd_kernel.hpp:132
remove_cvref_t< typename Problem::BetaDataType > BetaDataType
Definition: layernorm2d_fwd_kernel.hpp:49
remove_cvref_t< Pipeline_ > Pipeline
Definition: layernorm2d_fwd_kernel.hpp:42
remove_cvref_t< typename Problem::XDataType > XDataType
Definition: layernorm2d_fwd_kernel.hpp:46
static constexpr bool kHasBeta
Definition: layernorm2d_fwd_kernel.hpp:62
static CK_TILE_HOST std::string GetName()
Definition: layernorm2d_fwd_kernel.hpp:156
static constexpr auto kFusedAdd
Definition: layernorm2d_fwd_kernel.hpp:73
static constexpr index_t Repeat_N
Definition: layernorm2d_fwd_kernel.hpp:78
static constexpr CK_TILE_HOST auto BlockSize()
Definition: layernorm2d_fwd_kernel.hpp:137
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: layernorm2d_fwd_kernel.hpp:154
static constexpr CK_TILE_HOST Kargs MakeKargs(const Hargs &hargs)
Definition: layernorm2d_fwd_kernel.hpp:110
static constexpr index_t Block_N
Definition: layernorm2d_fwd_kernel.hpp:68
remove_cvref_t< typename Problem::XBiasDataType > XBiasDataType
Definition: layernorm2d_fwd_kernel.hpp:47
static constexpr auto I0
Definition: layernorm2d_fwd_kernel.hpp:80
static constexpr bool kHasGamma
Definition: layernorm2d_fwd_kernel.hpp:61
static constexpr index_t ThreadPerWarp_N
Definition: layernorm2d_fwd_kernel.hpp:76
static constexpr bool kSaveMeanInvStd
Definition: layernorm2d_fwd_kernel.hpp:63
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: layernorm2d_fwd_kernel.hpp:196
static constexpr bool kPadN
Definition: layernorm2d_fwd_kernel.hpp:70
remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition: layernorm2d_fwd_kernel.hpp:50
remove_cvref_t< typename Problem::YScaleDataType > YScaleDataType
Definition: layernorm2d_fwd_kernel.hpp:55
static constexpr index_t Vector_N
Definition: layernorm2d_fwd_kernel.hpp:77
static constexpr bool kTwoPass
Definition: layernorm2d_fwd_kernel.hpp:71
XDataType YResidualDataType
Definition: layernorm2d_fwd_kernel.hpp:59
static constexpr auto kFusedQuant
Definition: layernorm2d_fwd_kernel.hpp:74
static constexpr bool kSaveMean
Definition: layernorm2d_fwd_kernel.hpp:64
remove_cvref_t< typename Problem::YDataType > YDataType
Definition: layernorm2d_fwd_kernel.hpp:51
static constexpr bool kPadM
Definition: layernorm2d_fwd_kernel.hpp:69
static constexpr auto kXbias
Definition: layernorm2d_fwd_kernel.hpp:72
static constexpr index_t Block_M
Definition: layernorm2d_fwd_kernel.hpp:67
XDataType XResidualDataType
Definition: layernorm2d_fwd_kernel.hpp:58
remove_cvref_t< typename Problem::InvStdDataType > InvStdDataType
Definition: layernorm2d_fwd_kernel.hpp:53
remove_cvref_t< typename Problem::MeanDataType > MeanDataType
Definition: layernorm2d_fwd_kernel.hpp:52
static constexpr bool kSaveInvStd
Definition: layernorm2d_fwd_kernel.hpp:65
static constexpr index_t kBlockSize
Definition: layernorm2d_fwd_kernel.hpp:79
remove_cvref_t< typename Problem::SmoothScaleDataType > SmoothScaleDataType
Definition: layernorm2d_fwd_kernel.hpp:54
remove_cvref_t< Epilogue_ > Epilogue
Definition: layernorm2d_fwd_kernel.hpp:43
remove_cvref_t< typename Problem::GammaDataType > GammaDataType
Definition: layernorm2d_fwd_kernel.hpp:48
static constexpr auto I1
Definition: layernorm2d_fwd_kernel.hpp:81
Definition: layernorm2d_fwd_traits.hpp:18
Definition: integral_constant.hpp:13
Definition: sequence.hpp:49