/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp Source File
smoothquant_pipeline_two_pass.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 #include <string>
9 #include <type_traits>
10 
11 namespace ck_tile {
12 
13 template <typename Problem_, typename Policy_ = SmoothquantPipelineDefaultPolicy>
15 {
18 
24 
25  static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
26  static constexpr bool kPadM = false; // TODO - BlockSmoothquantProblem::kPadM
27  static constexpr bool kPadN = Problem::kPadN;
28  static constexpr bool UseMax3 = true; // TODO - Move to trait
29 
30  static constexpr const char* name = []() {
31  if constexpr(kNeedCrossWarpSync)
32  return "bpr_tp"; // block per row
33  else
34  return "wpr_tp"; // warp per row
35  }();
36 
38  {
39  return Policy::template GetSmemSize<Problem>();
40  }
41 
42  template <typename XWindow,
43  typename SmoothScaleWindow,
44  typename QYWindow,
45  typename YScaleWindow>
46  CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
47  const SmoothScaleWindow& smscale_window_,
48  YScaleWindow& yscale_window,
49  QYWindow& qy_window,
50  ck_tile::index_t row_size,
51  void* smem) const
52  {
53  auto x_window =
54  make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
55  auto smscale_window = make_tile_window(
56  smscale_window_, Policy::template MakeSmoothScaleBlockTileDistribution<Problem>());
57 
58  static constexpr index_t Block_N = Problem::BlockShape::Block_N;
59  index_t num_n_tile_iteration =
60  __builtin_amdgcn_readfirstlane(integer_divide_ceil(row_size, Block_N));
61 
62  auto reduce_absmax_func = ReduceOp::AbsMax{};
63  auto reduce_absmax3_func = [](auto acc_, auto v_0_, auto v_1_) {
64  float rtn;
65  asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
66  : "=v"(rtn)
67  : "v"(acc_), "v"(v_0_), "v"(v_1_));
68  return rtn;
69  };
70  auto reduce_max_func = ReduceOp::Max{};
71  auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
72  auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
73  auto block_reduce2d_cross_warp_sync =
74  Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
75 
76  using XTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
77  auto absmax = block_reduce2d.template MakeYBlockTile<XTensorType>();
78  set_tile(absmax, reduce_absmax_func.GetIdentityValue<ComputeDataType>());
79 
80  for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
81  {
82  const auto x = load_tile(x_window);
83  const auto smscale = load_tile(smscale_window);
84  const auto y = tile_elementwise_in(
85  [&](const auto& a, const auto& b) {
86  return type_convert<ComputeDataType>(a) * type_convert<ComputeDataType>(b);
87  },
88  x,
89  smscale);
90 
91  constexpr auto x_size_per_row =
92  x.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(number<1>{});
93  if constexpr(UseMax3 && std::is_same_v<ComputeDataType, float> &&
94  x_size_per_row % 2 == 0)
95  block_reduce2d(y, absmax, reduce_absmax3_func, sequence<1, 2>{});
96  else
97  block_reduce2d(y, absmax, reduce_absmax_func);
98 
99  move_tile_window(x_window, {0, Block_N});
100  move_tile_window(smscale_window, {Block_N});
101  }
102 
103  // compute absmax, cross-lane->cross-warp
104  block_reduce2d_sync(absmax, reduce_max_func);
105  block_reduce2d_cross_warp_sync(absmax, smem, reduce_max_func);
106 
107  // ex: yscale = absmax / 127 if int8
108  auto yscale = tile_elementwise_in(
109  [&](const auto& v_) {
110  return v_ / type_convert<ComputeDataType>(numeric<QYDataType>::max());
111  },
112  absmax);
113  store_tile(yscale_window, cast_tile<YScaleDataType>(yscale));
114 
115  // reverse read x to reuse cache
116  ck_tile::index_t stride_to_right_most_window =
117  row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
118 
119  move_tile_window(x_window, {0, -Block_N});
120  move_tile_window(smscale_window, {-Block_N});
121  move_tile_window(qy_window, {0, stride_to_right_most_window});
122 
123  // recompute y and quantize y to qy
124  for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
125  {
126  const auto x = load_tile(x_window);
127  const auto smscale = load_tile(smscale_window);
128  const auto y = tile_elementwise_in(
129  [&](const auto& a, const auto& b) {
130  return type_convert<ComputeDataType>(a) * type_convert<ComputeDataType>(b);
131  },
132  x,
133  smscale);
134 
135  auto qy = make_static_distributed_tensor<QYDataType>(y.get_tile_distribution());
136  sweep_tile(qy, [&](auto idx) {
137  constexpr auto i_idx = make_tuple(idx[number<0>{}]);
138  auto qy_ = y[idx] / yscale[i_idx];
139  qy(idx) = type_convert<QYDataType>(saturates<QYDataType>{}(qy_));
140  });
141  store_tile(qy_window, qy);
142 
143  move_tile_window(x_window, {0, -Block_N});
144  move_tile_window(smscale_window, {0, -Block_N});
145  move_tile_window(qy_window, {0, -Block_N});
146  }
147  }
148 };
149 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition: tile_elementwise.hpp:40
CK_TILE_DEVICE void set_tile(DstrTensors &dstr_tensor, const T &value)
Definition: tile_elementwise.hpp:95
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:95
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:23
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:22
constexpr CK_TILE_HOST_DEVICE void sweep_tile(const F &f, UnpacksPerXDim={})
Definition: sweep_tile.hpp:231
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1249
Definition: reduce_operator.hpp:91
Definition: reduce_operator.hpp:68
Definition: smoothquant_pipeline_two_pass.hpp:15
ck_tile::remove_cvref_t< Problem_ > Problem
Definition: smoothquant_pipeline_two_pass.hpp:16
ck_tile::remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition: smoothquant_pipeline_two_pass.hpp:21
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: smoothquant_pipeline_two_pass.hpp:37
static constexpr const char * name
Definition: smoothquant_pipeline_two_pass.hpp:30
CK_TILE_DEVICE auto operator()(const XWindow &x_window_, const SmoothScaleWindow &smscale_window_, YScaleWindow &yscale_window, QYWindow &qy_window, ck_tile::index_t row_size, void *smem) const
Definition: smoothquant_pipeline_two_pass.hpp:46
ck_tile::remove_cvref_t< Policy_ > Policy
Definition: smoothquant_pipeline_two_pass.hpp:17
ck_tile::remove_cvref_t< typename Problem::SmoothScaleDataType > SmoothScaleDataType
Definition: smoothquant_pipeline_two_pass.hpp:20
ck_tile::remove_cvref_t< typename Problem::QYDataType > QYDataType
Definition: smoothquant_pipeline_two_pass.hpp:22
static constexpr bool UseMax3
Definition: smoothquant_pipeline_two_pass.hpp:28
static constexpr bool kNeedCrossWarpSync
Definition: smoothquant_pipeline_two_pass.hpp:25
ck_tile::remove_cvref_t< typename Problem::YScaleDataType > YScaleDataType
Definition: smoothquant_pipeline_two_pass.hpp:23
static constexpr bool kPadN
Definition: smoothquant_pipeline_two_pass.hpp:27
ck_tile::remove_cvref_t< typename Problem::XDataType > XDataType
Definition: smoothquant_pipeline_two_pass.hpp:19
static constexpr bool kPadM
Definition: smoothquant_pipeline_two_pass.hpp:26
Definition: integral_constant.hpp:13
Definition: numeric.hpp:18
Definition: unary_element_function.hpp:56
Definition: sequence.hpp:49