/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp Source File
layernorm2d_fwd_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
9 
10 namespace ck_tile {
11 
12 // host side args: pointers, sizes, strides and epsilon for one launch (presumably the Hargs that MakeKargs copies field-by-field into Kargs)
14 {
15  const void* p_x; // [m, n], input, fp16/bf16
16  const void* p_x_residual; // [m, n], shortcut input, prec same as input, nullptr if not used
17  const void* p_sm_scale; // [1, n], smooth scale input, fp32, nullptr if not used
18  const void* p_x_bias; // [1, n], bias, prec same as input; the kernel only reads it when kXbias == ADD_BIAS
19  const void* p_gamma; // [1, n], gamma, prec same as input
20  const void* p_beta; // [1, n], beta, prec same as input
21 
22  void* p_y; // [m, n], output, fp16/bf16
23  void* p_y_residual; // [m, n], shortcut output, prec same as input, nullptr if not used
24  void* p_y_scale; // [m, 1], output per-row dynamic-quant scale, nullptr if not used
25  void* p_mean; // [m, 1], output mean, prec same as input, nullptr if not used
26  void* p_invStd; // [m, 1], output inverse standard deviation, prec same as input, nullptr if not used
27 
28  float epsilon; // cast to the kernel's ComputeDataType before being handed to the pipeline
29 
32  index_t x_stride; // x row stride, in elements
33  index_t xr_stride; // x residual row stride, in elements
34  index_t y_stride; // y row stride, in elements
35  index_t yr_stride; // y residual row stride, in elements
36 };
37 
38 // TODO: Extract some type to wrapper class
39 template <typename Pipeline_, typename Epilogue_> // Pipeline_ is invoked on the tile windows built in operator(); Epilogue_ is default-constructed and passed to it
41 {
44  using Problem = typename Pipeline::Problem;
45 
56 
57  // for simplicity, shortcut input/output type is same as X
60 
61  static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, null_type>; // a null_type GammaDataType means "no gamma"
62  static constexpr bool kHasBeta = !std::is_same_v<BetaDataType, null_type>; // a null_type BetaDataType means "no beta"
63  static constexpr bool kSaveMeanInvStd = Problem::Traits::kSaveMeanInvStd;
64  static constexpr bool kSaveMean = Problem::Traits::kSaveMeanInvStd; // mean and inv-std outputs are both controlled by the single kSaveMeanInvStd trait
65  static constexpr bool kSaveInvStd = Problem::Traits::kSaveMeanInvStd;
66 
67  static constexpr index_t Block_M = Problem::BlockShape::Block_M; // rows handled by one workgroup (see GridSize)
68  static constexpr index_t Block_N = Problem::BlockShape::Block_N; // column extent of each 2-D tile window
69  static constexpr bool kPadM = false; // always no need to pad along M
70  static constexpr bool kPadN = Problem::Traits::kPadN;
71  static constexpr bool kTwoPass = Problem::Traits::kTwoPass;
72  static constexpr auto kXbias = Problem::Traits::kXbias;
73  static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
74  static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
75 
76  static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
77  static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
78  static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
79  static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
80  static constexpr auto I0 = number<0>{};
81  static constexpr auto I1 = number<1>{};
82 
83  struct Kargs // kernel-side argument pack: built by MakeKargs, passed by value to operator()
84  {
85  const void* p_x; // [m, n], input, fp16/bf16
86  const void* p_x_residual; // [m, n], shortcut input, prec same as input, nullptr if not used
87  const void* p_sm_scale; // [1, n], smooth scale input, fp32, nullptr if not used
88  const void* p_x_bias; // [1, n], bias, prec same as input; only read when kXbias == ADD_BIAS
89  const void* p_gamma; // [1, n], gamma, prec same as input
90  const void* p_beta; // [1, n], beta, prec same as input
91 
92  void* p_y; // [m, n], output, fp16/bf16
93  void* p_y_residual; // [m, n], shortcut output, prec same as input, nullptr if not used
94  void* p_y_scale; // [m, 1], output per-row dynamic-quant scale, nullptr if not used
95 
96  void* p_mean; // [m, 1], output mean, prec same as input, nullptr if not used
97  void* p_invStd; // [m, 1], output inverse standard deviation, prec same as input, nullptr if not used
98 
99  float epsilon; // cast to ComputeDataType in operator() before reaching the pipeline
100 
103  index_t x_stride; // x row stride, in elements
104  index_t xr_stride; // x residual row stride, in elements
105  index_t y_stride; // y row stride, in elements
106  index_t yr_stride; // y residual row stride, in elements
107  };
109 
110  CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
111  {
112  return Kargs{hargs.p_x, // positional aggregate init: argument order must match Kargs' member declaration order
113  hargs.p_x_residual,
114  hargs.p_sm_scale,
115  hargs.p_x_bias,
116  hargs.p_gamma,
117  hargs.p_beta,
118  hargs.p_y,
119  hargs.p_y_residual,
120  hargs.p_y_scale,
121  hargs.p_mean,
122  hargs.p_invStd,
123  hargs.epsilon,
124  hargs.m,
125  hargs.n,
126  hargs.x_stride,
127  hargs.xr_stride,
128  hargs.y_stride,
129  hargs.yr_stride};
130  }
131 
132  CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
133  {
134  return dim3(integer_divide_ceil(hargs.m, Block_M)); // one workgroup per Block_M rows; operator() starts block b at row b * Block_M
135  }
136 
137  CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; } // same value as kBlockSize; pairs with GridSize() at launch
138 
139  // clang-format off
140  template <typename T> struct t2s; // type -> short tag used by GetName(); the primary is only declared, so an unlisted type fails to compile. NOTE(review): in-class explicit specializations rely on CWG 727 (accepted by Clang, rejected by GCC)
141  template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
142  template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
143  template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
144  template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
145  template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
146  template <> struct t2s<ck_tile::int8_t> { static constexpr const char * name = "int8"; };
147  // clang-format on
148 
149  // in bytes; forwarded from Pipeline::GetSmemSize() and used to size the __shared__ buffer in operator()
150  CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }
151 
152  CK_TILE_HOST static std::string GetName() // layernorm2d_fwd_<prec>_<Block_M>x<Block_N>_<WarpPerBlock_M>x<WarpPerBlock_N>_<Warp_M>x<Warp_N>_<Vector_M>x<Vector_N>_<Pipeline::name><suffix>
153  {
154 #define _SS_ std::string // NOTE(review): _SS_/_TS_ begin with '_' + uppercase, which is reserved for the implementation; a non-reserved name would be safer
155 #define _TS_ std::to_string
156  // clang-format off
157  using S_ = typename Problem::BlockShape;
158  auto surfix = [&] () { // (sic: "suffix") trait tags, e.g. _pn when kPadN, _mv when kSaveMeanInvStd
159  std::string n;
163  if (kPadN) n += "_pn";
164  if (kSaveMeanInvStd) n += "_mv";
165  // if (kTwoPass) n += "_2p";
166  return n; }();
167 
168  auto prec_str = [&] () { // X type tag, then Y type tag if it differs, then optional _sx/_sy scale-type tags
169  std::string base_str = _SS_(t2s<XDataType>::name);
170  if (!std::is_same_v<XDataType, YDataType>) {
171  base_str += _SS_("_") + _SS_(t2s<YDataType>::name);
172  }
174  base_str += _SS_("_sx") + _SS_(t2s<SmoothScaleDataType>::name);
175  base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
176  }
178  base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
179  }
180  return base_str;
181  }();
182 
183  return _SS_("layernorm2d_fwd_") + _SS_(prec_str) + "_" +
184  _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
185  _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
186  _SS_(Pipeline::name) + surfix;
187  // clang-format on
188 #undef _SS_
189 #undef _TS_
190  }
191 
192  CK_TILE_DEVICE void operator()(Kargs kargs) const // per-workgroup entry: builds a tile window for every input/output, then runs Pipeline with Epilogue
193  {
194  const auto iM = get_block_id() * Block_M; // first row of this workgroup's Block_M-row tile
195 
196  const auto x_window = [&]() { // [m, n] input; Block_M x Block_N window at origin {iM, 0}
197  const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
198  static_cast<const XDataType*>(kargs.p_x),
199  make_tuple(kargs.m, kargs.n),
200  make_tuple(kargs.x_stride, 1),
202  number<1>{});
203 
204  // NOTE: we don't do any pad in this kernel for loading, assume that inside kernel will
205  // check the max count dynamically
206  const auto tmp2_ = pad_tensor_view(
208  return make_tile_window(
209  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
210  }();
211 
212  const auto x_residual_window = [&]() { // [m, n] shortcut input; the if-constexpr guard is not visible here, presumably tied to kFusedAdd
215  {
216  const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
217  static_cast<const XResidualDataType*>(kargs.p_x_residual),
218  make_tuple(kargs.m, kargs.n),
219  make_tuple(kargs.xr_stride, 1),
221  number<1>{});
222 
223  // NOTE: we don't do any pad in this kernel for loading, assume that inside kernel
224  // will check the max count dynamically
225  const auto tmp2_ = pad_tensor_view(tmp_,
228  return make_tile_window(
229  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
230  }
231  else // presumably yields a null tile window (make_null_tile_window); the branch body is not visible in this listing
232  {
234  }
235  }();
236 
237  const auto x_bias_window = [&]() { // [n] row vector; only materialized when kXbias == ADD_BIAS
238  if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
239  {
240  const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
241  static_cast<const XBiasDataType*>(kargs.p_x_bias),
242  make_tuple(kargs.n),
243  make_tuple(1),
245  number<1>{});
246 
247  const auto tmp2_ =
249 
250  return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
251  }
252  else
253  {
255  }
256  }();
257 
258  const auto gamma_window = [&]() { // [n] row vector shared by every row, hence window origin {0}
259  const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
260  static_cast<const GammaDataType*>(kargs.p_gamma),
261  make_tuple(kargs.n),
262  make_tuple(1),
264  number<1>{});
265 
266  const auto tmp2_ =
268 
269  return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
270  }();
271 
272  const auto beta_window = [&]() { // [n] row vector shared by every row, like gamma
273  const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
274  static_cast<const BetaDataType*>(kargs.p_beta),
275  make_tuple(kargs.n),
276  make_tuple(1),
278  number<1>{});
279 
280  const auto tmp2_ =
283  }();
284 
285  auto y_window = [&]() { // [m, n] output; same tile origin {iM, 0} as x_window
286  auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
287  static_cast<YDataType*>(kargs.p_y),
288  make_tuple(kargs.m, kargs.n),
289  make_tuple(kargs.y_stride, 1),
291  number<1>{});
292 
293  auto tmp2_ = pad_tensor_view(
295  return make_tile_window(
296  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
297  }();
298 
299  auto y_residual_window = [&]() { // [m, n] shortcut output; guard not visible here, presumably tied to kFusedAdd
301  {
302  auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
303  static_cast<YResidualDataType*>(kargs.p_y_residual),
304  make_tuple(kargs.m, kargs.n),
305  make_tuple(kargs.yr_stride, 1),
307  number<1>{});
308 
309  auto tmp2_ = pad_tensor_view(tmp_,
312  return make_tile_window(
313  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
314  }
315  else
316  {
318  }
319  }();
320 
321  auto mean_window = [&]() { // [m] vector, one entry per row at origin {iM}; only materialized when kSaveMean
322  if constexpr(kSaveMean)
323  {
324  const auto mean_m = [&]() {
325  const auto mean_dram_naive =
326  make_naive_tensor_view_packed<address_space_enum::global>(
327  static_cast<MeanDataType*>(kargs.p_mean),
328  make_tuple(kargs.m),
329  number<1>{});
330 
331  return pad_tensor_view(
332  mean_dram_naive, make_tuple(number<Block_M>{}), sequence<kPadM>{});
333  }();
334  return make_tile_window(mean_m, make_tuple(number<Block_M>{}), {iM});
335  }
336  else
338  }();
339 
340  auto inv_std_window = [&]() { // [m] vector, one entry per row at origin {iM}; only materialized when kSaveInvStd
341  if constexpr(kSaveInvStd)
342  {
343  const auto inv_std_m = [&]() {
344  const auto inv_std_dram_naive =
345  make_naive_tensor_view_packed<address_space_enum::global>(
346  static_cast<InvStdDataType*>(kargs.p_invStd),
347  make_tuple(kargs.m),
348  number<1>{});
349 
350  return pad_tensor_view(
351  inv_std_dram_naive, make_tuple(number<Block_M>{}), sequence<kPadM>{});
352  }();
353  return make_tile_window(inv_std_m, make_tuple(number<Block_M>{}), {iM});
354  }
355  else
357  }();
358 
359  auto sm_scale_window = [&]() { // [n] vector at origin {0}; guard not visible here, presumably a kFusedQuant mode
361  {
362  const auto win_ = [&]() {
363  const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
364  static_cast<const SmoothScaleDataType*>(kargs.p_sm_scale),
365  make_tuple(kargs.n),
366  number<Vector_N>{});
367 
368  return pad_tensor_view(tmp_0_,
370  sequence<false>{}); // sm_scale needs no padding
371  }();
372  return make_tile_window(win_, make_tuple(number<Block_N>{}), {0});
373  }
374  else
376  }();
377 
378  auto y_scale_window = [&]() { // [m] vector, one scale per row at origin {iM}; guard not visible here, presumably a kFusedQuant mode
381  {
382  const auto win_ = [&]() {
383  const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
384  static_cast<YScaleDataType*>(kargs.p_y_scale),
385  make_tuple(kargs.m),
386  number<1>{});
387 
388  return pad_tensor_view(
390  }();
391  return make_tile_window(win_, make_tuple(number<Block_M>{}), {iM});
392  }
393  else
395  }();
396 
397  __shared__ char smem[GetSmemSize()]; // per-workgroup scratch; size comes from Pipeline::GetSmemSize()
398 
399  Pipeline{}(x_window,
400  x_residual_window,
401  x_bias_window,
402  gamma_window,
403  beta_window,
404  y_window,
405  y_residual_window,
406  mean_window,
407  inv_std_window,
408  sm_scale_window,
409  y_scale_window,
410  static_cast<const ComputeDataType>(kargs.epsilon), // NOTE(review): the top-level const in the cast target has no effect
411  kargs.n, // row length; presumably used by the pipeline for the dynamic count check noted above
412  smem,
413  Epilogue{});
414  }
415 };
416 
417 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
#define _TS_
#define _SS_
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:66
_BitInt(8) fp8_t
Definition: float8.hpp:204
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
_Float16 fp16_t
Definition: half.hpp:110
int8_t int8_t
Definition: int8.hpp:20
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
Definition: layernorm2d_fwd_traits.hpp:33
Definition: layernorm2d_fwd_traits.hpp:47
Definition: layernorm2d_fwd_kernel.hpp:84
index_t yr_stride
Definition: layernorm2d_fwd_kernel.hpp:106
void * p_y_scale
Definition: layernorm2d_fwd_kernel.hpp:94
const void * p_x_bias
Definition: layernorm2d_fwd_kernel.hpp:88
void * p_y_residual
Definition: layernorm2d_fwd_kernel.hpp:93
const void * p_beta
Definition: layernorm2d_fwd_kernel.hpp:90
index_t n
Definition: layernorm2d_fwd_kernel.hpp:102
const void * p_gamma
Definition: layernorm2d_fwd_kernel.hpp:89
const void * p_x
Definition: layernorm2d_fwd_kernel.hpp:85
const void * p_sm_scale
Definition: layernorm2d_fwd_kernel.hpp:87
void * p_invStd
Definition: layernorm2d_fwd_kernel.hpp:97
index_t m
Definition: layernorm2d_fwd_kernel.hpp:101
void * p_y
Definition: layernorm2d_fwd_kernel.hpp:92
void * p_mean
Definition: layernorm2d_fwd_kernel.hpp:96
index_t y_stride
Definition: layernorm2d_fwd_kernel.hpp:105
index_t x_stride
Definition: layernorm2d_fwd_kernel.hpp:103
const void * p_x_residual
Definition: layernorm2d_fwd_kernel.hpp:86
index_t xr_stride
Definition: layernorm2d_fwd_kernel.hpp:104
float epsilon
Definition: layernorm2d_fwd_kernel.hpp:99
Definition: layernorm2d_fwd_kernel.hpp:140
Definition: layernorm2d_fwd_kernel.hpp:14
void * p_invStd
Definition: layernorm2d_fwd_kernel.hpp:26
index_t xr_stride
Definition: layernorm2d_fwd_kernel.hpp:33
index_t y_stride
Definition: layernorm2d_fwd_kernel.hpp:34
index_t n
Definition: layernorm2d_fwd_kernel.hpp:31
index_t m
Definition: layernorm2d_fwd_kernel.hpp:30
void * p_y_residual
Definition: layernorm2d_fwd_kernel.hpp:23
const void * p_x
Definition: layernorm2d_fwd_kernel.hpp:15
const void * p_x_bias
Definition: layernorm2d_fwd_kernel.hpp:18
float epsilon
Definition: layernorm2d_fwd_kernel.hpp:28
const void * p_gamma
Definition: layernorm2d_fwd_kernel.hpp:19
void * p_mean
Definition: layernorm2d_fwd_kernel.hpp:25
void * p_y_scale
Definition: layernorm2d_fwd_kernel.hpp:24
const void * p_beta
Definition: layernorm2d_fwd_kernel.hpp:20
const void * p_x_residual
Definition: layernorm2d_fwd_kernel.hpp:16
void * p_y
Definition: layernorm2d_fwd_kernel.hpp:22
const void * p_sm_scale
Definition: layernorm2d_fwd_kernel.hpp:17
index_t x_stride
Definition: layernorm2d_fwd_kernel.hpp:32
index_t yr_stride
Definition: layernorm2d_fwd_kernel.hpp:35
Definition: layernorm2d_fwd_kernel.hpp:41
typename Pipeline::Problem Problem
Definition: layernorm2d_fwd_kernel.hpp:44
static constexpr CK_TILE_HOST auto GridSize(const Hargs &hargs)
Definition: layernorm2d_fwd_kernel.hpp:132
remove_cvref_t< typename Problem::BetaDataType > BetaDataType
Definition: layernorm2d_fwd_kernel.hpp:49
remove_cvref_t< Pipeline_ > Pipeline
Definition: layernorm2d_fwd_kernel.hpp:42
remove_cvref_t< typename Problem::XDataType > XDataType
Definition: layernorm2d_fwd_kernel.hpp:46
static constexpr bool kHasBeta
Definition: layernorm2d_fwd_kernel.hpp:62
static CK_TILE_HOST std::string GetName()
Definition: layernorm2d_fwd_kernel.hpp:152
static constexpr auto kFusedAdd
Definition: layernorm2d_fwd_kernel.hpp:73
static constexpr index_t Repeat_N
Definition: layernorm2d_fwd_kernel.hpp:78
static constexpr CK_TILE_HOST auto BlockSize()
Definition: layernorm2d_fwd_kernel.hpp:137
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: layernorm2d_fwd_kernel.hpp:150
static constexpr CK_TILE_HOST Kargs MakeKargs(const Hargs &hargs)
Definition: layernorm2d_fwd_kernel.hpp:110
static constexpr index_t Block_N
Definition: layernorm2d_fwd_kernel.hpp:68
remove_cvref_t< typename Problem::XBiasDataType > XBiasDataType
Definition: layernorm2d_fwd_kernel.hpp:47
static constexpr auto I0
Definition: layernorm2d_fwd_kernel.hpp:80
static constexpr bool kHasGamma
Definition: layernorm2d_fwd_kernel.hpp:61
static constexpr index_t ThreadPerWarp_N
Definition: layernorm2d_fwd_kernel.hpp:76
static constexpr bool kSaveMeanInvStd
Definition: layernorm2d_fwd_kernel.hpp:63
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: layernorm2d_fwd_kernel.hpp:192
static constexpr bool kPadN
Definition: layernorm2d_fwd_kernel.hpp:70
remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition: layernorm2d_fwd_kernel.hpp:50
remove_cvref_t< typename Problem::YScaleDataType > YScaleDataType
Definition: layernorm2d_fwd_kernel.hpp:55
static constexpr index_t Vector_N
Definition: layernorm2d_fwd_kernel.hpp:77
static constexpr bool kTwoPass
Definition: layernorm2d_fwd_kernel.hpp:71
XDataType YResidualDataType
Definition: layernorm2d_fwd_kernel.hpp:59
static constexpr auto kFusedQuant
Definition: layernorm2d_fwd_kernel.hpp:74
static constexpr bool kSaveMean
Definition: layernorm2d_fwd_kernel.hpp:64
remove_cvref_t< typename Problem::YDataType > YDataType
Definition: layernorm2d_fwd_kernel.hpp:51
static constexpr bool kPadM
Definition: layernorm2d_fwd_kernel.hpp:69
static constexpr auto kXbias
Definition: layernorm2d_fwd_kernel.hpp:72
static constexpr index_t Block_M
Definition: layernorm2d_fwd_kernel.hpp:67
XDataType XResidualDataType
Definition: layernorm2d_fwd_kernel.hpp:58
remove_cvref_t< typename Problem::InvStdDataType > InvStdDataType
Definition: layernorm2d_fwd_kernel.hpp:53
remove_cvref_t< typename Problem::MeanDataType > MeanDataType
Definition: layernorm2d_fwd_kernel.hpp:52
static constexpr bool kSaveInvStd
Definition: layernorm2d_fwd_kernel.hpp:65
static constexpr index_t kBlockSize
Definition: layernorm2d_fwd_kernel.hpp:79
remove_cvref_t< typename Problem::SmoothScaleDataType > SmoothScaleDataType
Definition: layernorm2d_fwd_kernel.hpp:54
remove_cvref_t< Epilogue_ > Epilogue
Definition: layernorm2d_fwd_kernel.hpp:43
remove_cvref_t< typename Problem::GammaDataType > GammaDataType
Definition: layernorm2d_fwd_kernel.hpp:48
static constexpr auto I1
Definition: layernorm2d_fwd_kernel.hpp:81
Definition: layernorm2d_fwd_traits.hpp:18
Definition: integral_constant.hpp:13
Definition: sequence.hpp:49