/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp Source File
dynamic_quant_epilogue.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/reduce.hpp"
8 
9 namespace ck_tile {
10 
// Compile-time switch bundle for the dynamic-quant epilogue. Every member below simply
// re-exports the template parameter of the same name so that it can be read through
// Problem::Traits by the epilogue.
11 template <bool kPadM_,
12  bool kPadN_,
13  bool UseSmoothInputScale_,
14  bool UseRawStore_ = true,
15  bool UseMax3_ = false>
17 {
    // kPadM / kPadN: when either one is set (and UseRawStore is set too) the epilogue's Impl
    // writes the quantized tile with store_tile_raw instead of store_tile.
18  static constexpr bool kPadM = kPadM_;
19  static constexpr bool kPadN = kPadN_;
    // NOTE(review): not read anywhere in the visible part of this listing — confirm where it
    // is consumed.
20  static constexpr bool UseSmoothInputScale = UseSmoothInputScale_;
    // Allows the store_tile_raw path described above; defaults to true.
21  static constexpr bool UseRawStore = UseRawStore_;
    // Enables the v_max3_f32 absmax fast path in Impl; defaults to false.
22  static constexpr bool UseMax3 = UseMax3_;
23 };
24 
25 // this epilogue just stores out an M*N matrix, row major
// Type bundle consumed by the epilogue: accumulator, smooth-scale, y-scale and output element
// types plus the block shape and the traits class.
26 template <typename AccDataType_,
27  typename SmoothScaleDataType_,
28  typename YScaleDataType_,
29  typename ODataType_,
30  typename BlockShape_,
31  typename Traits_>
33 {
38  using BlockShape = remove_cvref_t<BlockShape_>; // can consume a generic 2d shape
40 };
41 
42 // TODO: we should put descriptor creation function into policy
43 template <typename Problem_, typename Policy_ = void>
45 {
52  static constexpr bool kPadM = Problem::Traits::kPadM;
53  static constexpr bool kPadN = Problem::Traits::kPadN;
54  static constexpr bool UseRawStore = Problem::Traits::UseRawStore;
55  static constexpr bool UseMax3 = Problem::Traits::UseMax3;
56 
// Factory for the first reduction stage invoked by Impl: returns a default-constructed
// BlockReduce2d<P_>.
// NOTE(review): the declaration of P_ (original line 59) is missing from this listing.
57  CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d()
58  {
60  return BlockReduce2d<P_>{};
61  }
62 
// Factory for the second reduction stage invoked by Impl: returns a default-constructed
// BlockReduce2dSync<P_>.
// NOTE(review): the declaration of P_ (original line 65) is missing from this listing.
63  CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync()
64  {
66  return BlockReduce2dSync<P_>{};
67  }
68 
70  {
73  }
74 
76  {
77  using S = BlockShape;
78 #if 0
79  // don't remove this
80  // Note that if we purposely set the encoding like this, it will result in a compile failure
81  // TODO: sm_scale create local-scratch to accept arbitrary acc input (with same length)
90 #else
98  sequence<0, 3>>{});
99 #endif
100  }
101 
103  {
104  auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync();
105  return reduce_crosswarp_sync.GetSmemSize();
106  }
107 
// Shared quantization routine behind both operator() overloads.
//   1. Reduce o_acc_tile to one absolute maximum per dimension-0 index ("row").
//   2. y_scale = row_absmax / numeric<ODataType>::max(); stored to y_scale_window as
//      YScaleDataType.
//   3. Divide every element by its row's y_scale, cast to ODataType and store the result to
//      o_dram_window_tmp.
// smem is handed unchanged to the cross-warp reduction stage; GetSmemSize() of this struct
// reports that stage's size requirement.
108  template <typename ODramWindowTmp, typename YScaleWindow, typename OAccTile>
109  CK_TILE_DEVICE auto Impl(ODramWindowTmp& o_dram_window_tmp,
110  YScaleWindow& y_scale_window,
111  const OAccTile& o_acc_tile,
112  void* smem)
113  {
114  auto reduce = GetBlockReduce2d();
115  auto reduce_sync = GetBlockReduce2dSync();
116  auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync();
117 
     // Work on a copy: the input tile is taken by const reference, while the tile is rescaled
     // in place further down.
118  auto o_acc_tmp = o_acc_tile;
119 
     // Reduction operator: running maximum of absolute values.
120  const auto f_absmax = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); };
121 
     // First reduction stage. 0 is a valid initial value because absolute values are never
     // negative.
122  auto row_absmax = [&]() {
123  constexpr auto y_size_per_row =
124  OAccTile{}.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(
125  number<1>{});
     // The fast path is only compiled in for float accumulators (the instruction is f32) and
     // an even y_size_per_row, since f_max3 takes two values per call.
126  if constexpr(UseMax3 && std::is_same_v<AccDataType, float> && y_size_per_row % 2 == 0)
127  {
128  // fast max3+abs implementation
     // A single v_max3_f32 folds |v_0_| and |v_1_| into the running maximum.
129  const auto f_max3 = [](auto acc_, auto v_0_, auto v_1_) {
130  float rtn;
131  asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
132  : "=v"(rtn)
133  : "v"(acc_), "v"(v_0_), "v"(v_1_));
134  return rtn;
135  };
     // sequence<1, 2> presumably asks reduce to feed two elements at a time along
     // dimension 1 — TODO confirm against block_reduce2d.hpp.
136  return reduce(o_acc_tmp, type_convert<AccDataType>(0), f_max3, sequence<1, 2>{});
137  }
138  else
139  {
140  return reduce(o_acc_tmp, type_convert<AccDataType>(0), f_absmax);
141  }
142  }();
     // Remaining reduction stages operate on the partial maxima; only the cross-warp stage
     // receives smem. (The names suggest an intra-warp step followed by a cross-warp step —
     // confirm in block_reduce2d.hpp.)
143  reduce_sync(row_absmax, f_absmax);
144  reduce_crosswarp_sync(row_absmax, smem, f_absmax);
145 
146  // here y_scale is still AccDataType; it is converted to YScaleDataType when stored below
147  auto y_scale = tile_elementwise_in(
148  [&](const auto& v_) {
149  return v_ / type_convert<AccDataType>(numeric<ODataType>::max());
150  },
151  row_absmax);
152 
153  store_tile(y_scale_window, cast_tile<YScaleDataType>(y_scale));
154 
     // Rescale each element by the y_scale entry selected by its dimension-0 index.
     // NOTE(review): if a whole row is zero, row_absmax and y_scale are 0 and this becomes
     // 0 / 0 — confirm whether callers rule that out or whether the result is acceptable.
155  sweep_tile(o_acc_tmp, [&](auto idx) {
156  constexpr auto row_id = make_tuple(idx[number<0>{}]);
157  o_acc_tmp(idx) = o_acc_tmp[idx] / y_scale(row_id);
158  });
159 
160  // TODO: this is ugly
     // The raw store is selected only when UseRawStore is set and kPadM or kPadN is set;
     // every other configuration uses the regular store_tile.
161  if constexpr(UseRawStore && (kPadM || kPadN))
162  {
163  store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tmp));
165  }
166  else
167  {
168  store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tmp));
169  }
170  }
171 
172  // TODO: this function assumes the store-out vector size is the same as the OAccTile last
173  // dimension size; how do we fix this?
174 
175  // Smooth Dynamic Quant
     // Multiplies every accumulator element by the smooth-scale value selected by the element's
     // dimension-1 index (converted to AccDataType first), then forwards the scaled copy to Impl,
     // which computes the per-row y_scale and stores both outputs.
     // o_acc_tile itself is not modified; the scaling is applied to a local copy.
176  template <typename ODramWindowTmp,
177  typename SmoothScaleWindow,
178  typename YScaleWindow,
179  typename OAccTile>
180  CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
181  const SmoothScaleWindow& sm_scale_window_,
182  YScaleWindow& y_scale_window,
183  const OAccTile& o_acc_tile,
184  void* smem)
185  {
186  const auto sm_scale_window =
     // NOTE(review): the initializer (original line 187) is missing from this listing;
     // presumably it rebuilds sm_scale_window_ with MakeSmoothInputScaleTileDistribution() —
     // confirm against the real header.
188 
189  auto sm_scale = load_tile(sm_scale_window);
190 
191  auto o_acc_tmp = o_acc_tile;
192 
     // j_idx keeps only the dimension-1 index, so the same scale value is applied to every
     // element that shares that index.
193  sweep_tile(o_acc_tmp, [&](auto idx) {
194  constexpr auto j_idx = make_tuple(idx[number<1>{}]);
195  const auto xs_ = type_convert<AccDataType>(sm_scale[j_idx]);
196  o_acc_tmp(idx) = o_acc_tmp(idx) * xs_;
197  });
198 
199  Impl(o_dram_window_tmp, y_scale_window, o_acc_tmp, smem);
200  }
201 
202  // Dynamic Quant
     // Overload without a smooth-scale window: forwards all arguments unchanged to Impl, so the
     // accumulator tile is quantized as-is.
203  template <typename ODramWindowTmp, typename YScaleWindow, typename OAccTile>
204  CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
205  YScaleWindow& y_scale_window,
206  const OAccTile& o_acc_tile,
207  void* smem)
208  {
209  Impl(o_dram_window_tmp, y_scale_window, o_acc_tile, smem);
210  }
211 };
212 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition: tile_elementwise.hpp:40
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
CK_TILE_DEVICE void buffer_store_fence(index_t cnt=0)
Definition: amd_buffer_addressing.hpp:1000
CK_TILE_DEVICE void store_tile_raw(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:46
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
CK_TILE_HOST_DEVICE bfloat16_t abs(const bfloat16_t &x)
Definition: bfloat16.hpp:404
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:23
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:22
constexpr CK_TILE_HOST_DEVICE void sweep_tile(const F &f, UnpacksPerXDim={})
Definition: sweep_tile.hpp:231
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:480
Definition: block_reduce2d.hpp:200
Definition: block_reduce2d.hpp:45
Definition: block_reduce2d_problem.hpp:12
Definition: block_reduce2d.hpp:135
Definition: dynamic_quant_epilogue.hpp:45
static constexpr bool UseMax3
Definition: dynamic_quant_epilogue.hpp:55
remove_cvref_t< typename Problem::BlockShape > BlockShape
Definition: dynamic_quant_epilogue.hpp:51
remove_cvref_t< typename Problem::YScaleDataType > YScaleDataType
Definition: dynamic_quant_epilogue.hpp:49
remove_cvref_t< typename Problem::ODataType > ODataType
Definition: dynamic_quant_epilogue.hpp:50
static constexpr bool kPadM
Definition: dynamic_quant_epilogue.hpp:52
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: dynamic_quant_epilogue.hpp:102
remove_cvref_t< typename Problem::SmoothScaleDataType > SmoothScaleDataType
Definition: dynamic_quant_epilogue.hpp:48
static constexpr CK_TILE_DEVICE auto MakeSmoothInputScaleTileDistribution()
Definition: dynamic_quant_epilogue.hpp:75
CK_TILE_DEVICE auto operator()(ODramWindowTmp &o_dram_window_tmp, YScaleWindow &y_scale_window, const OAccTile &o_acc_tile, void *smem)
Definition: dynamic_quant_epilogue.hpp:204
remove_cvref_t< typename Problem::AccDataType > AccDataType
Definition: dynamic_quant_epilogue.hpp:47
static constexpr CK_TILE_HOST_DEVICE auto GetBlockReduce2dSync()
Definition: dynamic_quant_epilogue.hpp:63
static constexpr bool kPadN
Definition: dynamic_quant_epilogue.hpp:53
static constexpr CK_TILE_HOST_DEVICE auto GetBlockReduce2d()
Definition: dynamic_quant_epilogue.hpp:57
static constexpr bool UseRawStore
Definition: dynamic_quant_epilogue.hpp:54
CK_TILE_DEVICE auto operator()(ODramWindowTmp &o_dram_window_tmp, const SmoothScaleWindow &sm_scale_window_, YScaleWindow &y_scale_window, const OAccTile &o_acc_tile, void *smem)
Definition: dynamic_quant_epilogue.hpp:180
static constexpr CK_TILE_HOST_DEVICE auto GetBlockReduce2dCrossWarpSync()
Definition: dynamic_quant_epilogue.hpp:69
remove_cvref_t< Problem_ > Problem
Definition: dynamic_quant_epilogue.hpp:46
CK_TILE_DEVICE auto Impl(ODramWindowTmp &o_dram_window_tmp, YScaleWindow &y_scale_window, const OAccTile &o_acc_tile, void *smem)
Definition: dynamic_quant_epilogue.hpp:109
Definition: dynamic_quant_epilogue.hpp:33
remove_cvref_t< YScaleDataType_ > YScaleDataType
Definition: dynamic_quant_epilogue.hpp:36
remove_cvref_t< ODataType_ > ODataType
Definition: dynamic_quant_epilogue.hpp:37
remove_cvref_t< Traits_ > Traits
Definition: dynamic_quant_epilogue.hpp:39
remove_cvref_t< BlockShape_ > BlockShape
Definition: dynamic_quant_epilogue.hpp:38
remove_cvref_t< SmoothScaleDataType_ > SmoothScaleDataType
Definition: dynamic_quant_epilogue.hpp:35
remove_cvref_t< AccDataType_ > AccDataType
Definition: dynamic_quant_epilogue.hpp:34
Definition: dynamic_quant_epilogue.hpp:17
static constexpr bool kPadM
Definition: dynamic_quant_epilogue.hpp:18
static constexpr bool UseRawStore
Definition: dynamic_quant_epilogue.hpp:21
static constexpr bool kPadN
Definition: dynamic_quant_epilogue.hpp:19
static constexpr bool UseSmoothInputScale
Definition: dynamic_quant_epilogue.hpp:20
static constexpr bool UseMax3
Definition: dynamic_quant_epilogue.hpp:22
Definition: integral_constant.hpp:13
Definition: numeric.hpp:18
Definition: sequence.hpp:49
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192