/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/block/block_position_encoding.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/block/block_position_encoding.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/block/block_position_encoding.hpp Source File
block_position_encoding.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 #include <cmath>
9 #include <vector>
10 
11 namespace ck_tile {
12 
// Enumerates the position-encoding variants.
// NOTE(review): the declaration line that names this enum is missing from this extract
// (the embedded listing jumps from 12 to 14). The cross-reference index at the end of the
// page lists `PositionEncodingEnum` as defined at listing line 14 -- confirm against the
// original header.
14 {
15  NO = 0, // presumably "no position encoding" -- TODO confirm against callers
16  ALIBI = 1, // presumably selects the Alibi bias implemented below -- TODO confirm
17 };
18 
19 /*
20 VERTICAL:
21  [0] 1 2 3 4 5
22  [0] 1 2 3 4 5
23  [0] 1 2 3 4 5
24  [0] 1 2 3 4 5
25 
26 FROM_TOP_LEFT(but negative):
27  [0] 1 2 3 4 5
28  1 [0] 1 2 3 4
29  2 1 [0] 1 2 3
30  3 2 1 [0] 1 2
31 
32 FROM_BOTTOM_RIGHT(but negative):
33  2 1 [0] 1 2 3
34  3 2 1 [0] 1 2
35  4 3 2 1 [0] 1
36  5 4 3 2 1 [0]
37 */
38 
// Selects the layout of the per-element distance used by Alibi; the diagrams in the
// comment above show the distance pattern for each mode, with [0] marking the zero point.
39 enum struct AlibiMode
40 {
41  VERTICAL = 0, // the only mode for which the Alibi constructor keeps the slope's sign
42  FROM_TOP_LEFT = 1, // keep sync with mask enum
// NOTE(review): Alibi below also uses AlibiMode::FROM_BOTTOM_RIGHT, but its enumerator
// (embedded listing line 43) is missing from this extract -- restore it from the original
// header; its value must stay in sync with the mask enum as well.
44 };
45 
// Linear-bias (ALiBi) position encoding: update() adds `slope * position` to one element,
// where `position` comes from sad() applied to a zero point and a shifted index.
//
// Template parameters:
//   DataType           - type of the slope and of the element updated in place.
//   RowMajor           - selects which branch of update() is compiled and how the two
//                        shifts are derived in the constructor.
//   LogMaxSadOprndSize - values <= 16 make the device sad() narrow its operands to
//                        uint16_t and call sad_u16; larger values fall through to sad_u32.
//
// NOTE(review): embedded listing lines 58, 87, 104 and 132 are missing from this extract.
// Per the cross-reference index at the end of the page, 58 is presumably the
// `AlibiMode mode_ = AlibiMode::VERTICAL` constructor parameter, 87 the
// `CK_TILE_DEVICE uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)` signature and 132 the
// `AlibiMode mode` member. The content of line 104 cannot be recovered from what is shown.
46 template <typename DataType, bool RowMajor = true, unsigned LogMaxSadOprndSize = 16>
47 struct Alibi
48 {
49  static_assert(1 <= LogMaxSadOprndSize && LogMaxSadOprndSize <= 32,
50  "for LogMaxSadOprndSize <= 16, we use SAD uint16_t, otherwise, use SAD uint32_t");
51 
52  // RowMajor states whether the pixels held by one thread lie along a row or along a column.
53  // This may impact the performance of update(), while the results are the same.
54  // e.g. fwd prefers RowMajor=true, bwd in some cases prefers RowMajor=false
  //
  // Constructor: stores the (possibly negated) slope, the two shifts and the mode.
  //   slope_   - bias slope supplied by the caller.
  //   y_total_ - total extent along y (rows, presumably the query length -- TODO confirm).
  //   x_total_ - total extent along x (columns, presumably the key length -- TODO confirm).
55  CK_TILE_HOST_DEVICE Alibi(DataType slope_,
56  index_t y_total_,
57  index_t x_total_,
  // NOTE(review): the last parameter (`mode_`, listing line 58) is missing here.
59  {
  // VERTICAL keeps the sign of slope_; every other mode negates it.
60  slope = mode_ == AlibiMode::VERTICAL ? slope_ : -slope_;
61 
  // Both shifts are zero unless the mode is FROM_BOTTOM_RIGHT. In that mode one of them is
  // max(y_total_ - x_total_, 0) and the other max(x_total_ - y_total_, 0), with the roles
  // swapped between RowMajor and non-RowMajor.
62  shift_left_up = [&]() {
63  if(RowMajor)
64  {
65  return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(y_total_ - x_total_, 0) : 0;
66  }
67  else
68  {
69  return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(x_total_ - y_total_, 0) : 0;
70  }
71  }();
72  shift_right_down = [&]() {
73  if(RowMajor)
74  {
75  return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(x_total_ - y_total_, 0) : 0;
76  }
77  else
78  {
79  return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(y_total_ - x_total_, 0) : 0;
80  }
81  }();
82  mode = mode_;
83  }
84 
  // Host overload: always forwards to sad_u32.
  // (presumably sad computes |x - y| + acc -- confirm in math.hpp)
85  CK_TILE_HOST uint32_t sad(uint32_t x, uint32_t y, uint32_t acc) { return sad_u32(x, y, acc); }
86 
  // NOTE(review): the signature of the device overload (listing line 87) is missing here.
  // Its body narrows the operands to uint16_t and uses sad_u16 when
  // LogMaxSadOprndSize <= 16, otherwise it uses sad_u32.
88  {
89  if constexpr(LogMaxSadOprndSize <= 16)
90  {
91  return sad_u16(
92  static_cast<uint16_t>(x), static_cast<uint16_t>(y), static_cast<uint16_t>(acc));
93  }
94 
95  return sad_u32(x, y, acc);
96  }
97 
  // Adds the bias for element (row_idx, col_idx) to `pixel` in place.
  //   pixel   - element to update; receives `slope * position`.
  //   row_idx - row index of the element.
  //   col_idx - column index of the element.
98  CK_TILE_HOST_DEVICE void update(DataType& pixel, index_t row_idx, index_t col_idx)
99  {
100  if constexpr(RowMajor)
101  {
102  // at least 3 instructions per row
103  index_t current_zero_point =
  // NOTE(review): the initializer expression (listing line 104) is missing from this extract.
105 
106  // for every thread, most of the pixels are along the row, so the operation below
107  // should be the main hot spot.
  // The signed indices are reinterpreted as uint32_t before being handed to sad().
108  auto position = type_convert<DataType>(sad(bit_cast<uint32_t>(current_zero_point),
109  bit_cast<uint32_t>(col_idx + shift_left_up),
110  0));
111  pixel += slope * position;
112  }
113  else
114  {
115  // at least 3 instructions per col;
  // Zero point for this column; VERTICAL additionally adds row_idx.
116  index_t current_zero_point = mode == AlibiMode::VERTICAL
117  ? row_idx + col_idx + shift_right_down
118  : col_idx + shift_right_down;
119 
120  // for every thread, most of the pixels are along the col, so the operation below
121  // should be the main hot spot.
122  auto position = type_convert<DataType>(sad(bit_cast<uint32_t>(current_zero_point),
123  bit_cast<uint32_t>(row_idx + shift_left_up),
124  0));
125  pixel += slope * position;
126  }
127  }
128 
129  DataType slope; // float? -- already negated by the constructor for non-VERTICAL modes
130  index_t shift_left_up; // always non-negative (either 0 or the result of max(..., 0))
131  index_t shift_right_down; // always non-negative (either 0 or the result of max(..., 0))
  // NOTE(review): the `mode` member assigned in the constructor (listing line 132) is
  // missing from this extract.
133 };
134 
// Position encoding that does nothing: update() has the same signature as Alibi::update()
// but an empty body, so the element is left unchanged.
// NOTE(review): the `struct` declaration line (embedded listing line 136) is missing from
// this extract, so the type's name is not visible here -- restore it from the original header.
135 template <typename DataType>
137 {
  // Intentionally empty; all three parameters are unnamed because they are unused.
138  CK_TILE_HOST_DEVICE void update(DataType& /*pixel*/, index_t /*row_idx*/, index_t /*col_idx*/)
139  {
140  }
141 };
142 
143 //
144 // Builds an Alibi from FA-style left/right window sizes, converting them to our generic
// coordinate.
145 // If window_left_size < 0 && window_right_size == 0, it is the normal causal mask and
// AlibiMode::VERTICAL is used.
146 // Otherwise (local: left_size >= 0 or right_size >= 0) the mode is taken from mask_enum.
//
// Parameters visible in this extract:
//   window_left_size / window_right_size - FA-style window sizes, used only to detect causal.
//   y_total / x_total                    - forwarded unchanged to the Alibi constructor.
//   mask_enum                            - cast to AlibiMode when the mask is not causal.
// Returns an Alibi<DataType, RowMajor, LogMaxSadOprndSize>.
//
// NOTE(review): the line carrying the function name and the first parameter (embedded
// listing line 148) is missing from this extract. The cross-reference index at the end of
// the page gives it as
// `CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope, ...)`.
147 template <typename DataType, bool RowMajor = true, unsigned LogMaxSadOprndSize = 16>
149  index_t window_left_size,
150  index_t window_right_size,
151  index_t y_total,
152  index_t x_total,
153  GenericAttentionMaskEnum mask_enum)
154 {
155  // assume mask_enum will never be NO_MASK, since if we do not have mask, it's
156  // totally OK to use constexpr
157  bool is_causal = window_left_size < 0 && window_right_size == 0;
  // The cast relies on AlibiMode and GenericAttentionMaskEnum sharing enumerator values.
158  AlibiMode alibi_mode =
159  is_causal ? AlibiMode::VERTICAL
160  : static_cast<AlibiMode>(mask_enum) /*either top-left or bottom-right*/;
161  return Alibi<DataType, RowMajor, LogMaxSadOprndSize>{slope, y_total, x_total, alibi_mode};
162 }
163 
164 // https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
165 // Do we need a device version?
166 template <typename DataType>
167 CK_TILE_HOST std::vector<DataType> get_alibi_slopes(ck_tile::index_t nheads)
168 {
169  auto get_slopes_power_of_2 = [](ck_tile::index_t n) {
170  float start = std::powf(
171  static_cast<float>(2),
172  -std::powf(static_cast<float>(2), -static_cast<float>((integer_log2_floor(n) - 3))));
173 
174  std::vector<DataType> rtn;
175  for(auto i = 0; i < n; i++)
176  {
177  rtn.push_back(static_cast<DataType>(start * std::powf(start, i)));
178  }
179  return rtn;
180  };
181  if(is_power_of_two_integer(nheads))
182  {
183  // power of 2 calculation
184  return get_slopes_power_of_2(nheads);
185  }
186  else
187  {
188  ck_tile::index_t closest_power_of_2 = 1 << integer_log2_floor(nheads);
189  auto v0 = get_slopes_power_of_2(closest_power_of_2);
190  auto v1 = get_slopes_power_of_2(closest_power_of_2 * 2);
191  auto v1_sliced = [&](auto vec, ck_tile::index_t rem) {
192  std::vector<DataType> sliced;
193  for(ck_tile::index_t i = 0; i < static_cast<ck_tile::index_t>(vec.size()); i++)
194  {
195  if(i % 2 == 0)
196  sliced.push_back(vec[i]);
197  }
198  std::vector<DataType> sliced_2(sliced.begin(), sliced.begin() + rem);
199  return sliced_2;
200  }(v1, nheads - closest_power_of_2);
201  v0.insert(v0.end(), v1_sliced.begin(), v1_sliced.end());
202  return v0;
203  }
204 }
205 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE bool is_power_of_two_integer(int32_t x)
Definition: math.hpp:462
CK_TILE_HOST std::vector< DataType > get_alibi_slopes(ck_tile::index_t nheads)
Definition: block_position_encoding.hpp:167
int32_t index_t
Definition: integer.hpp:9
CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope, index_t window_left_size, index_t window_right_size, index_t y_total, index_t x_total, GenericAttentionMaskEnum mask_enum)
Definition: block_position_encoding.hpp:148
constexpr CK_TILE_HOST_DEVICE int32_t integer_log2_floor(int32_t x)
Definition: math.hpp:455
CK_TILE_DEVICE uint32_t sad_u32(uint32_t x, uint32_t y, uint32_t acc)
Definition: math.hpp:504
CK_TILE_DEVICE uint16_t sad_u16(uint16_t x, uint16_t y, uint16_t acc)
Definition: math.hpp:499
PositionEncodingEnum
Definition: block_position_encoding.hpp:14
GenericAttentionMaskEnum
Definition: block_masking.hpp:11
AlibiMode
Definition: block_position_encoding.hpp:40
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
unsigned short uint16_t
Definition: stdint.h:125
unsigned int uint32_t
Definition: stdint.h:126
Definition: block_position_encoding.hpp:48
CK_TILE_HOST_DEVICE Alibi(DataType slope_, index_t y_total_, index_t x_total_, AlibiMode mode_=AlibiMode::VERTICAL)
Definition: block_position_encoding.hpp:55
AlibiMode mode
Definition: block_position_encoding.hpp:132
CK_TILE_DEVICE uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
Definition: block_position_encoding.hpp:87
index_t shift_right_down
Definition: block_position_encoding.hpp:131
CK_TILE_HOST uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
Definition: block_position_encoding.hpp:85
DataType slope
Definition: block_position_encoding.hpp:129
CK_TILE_HOST_DEVICE void update(DataType &pixel, index_t row_idx, index_t col_idx)
Definition: block_position_encoding.hpp:98
index_t shift_left_up
Definition: block_position_encoding.hpp:130
Definition: block_position_encoding.hpp:137
CK_TILE_HOST_DEVICE void update(DataType &, index_t, index_t)
Definition: block_position_encoding.hpp:138