/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/block/block_masking.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/block/block_masking.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/block/block_masking.hpp Source File
block_masking.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 
8 namespace ck_tile {
9 
11 {
12  NO_MASK = 0,
13 
14  // below enum could be causal, or sliding window
17 
18  // this enum may not be used by xformer/FA, since it's hard to
19  // specify the left/right window for the varlen case. It is kept here for
20  // debugging purposes
22 };
23 
24 // clang-format off
25 /* generic Attention Mask Coordinate
26  use x(horizontal axis), y(vertical axis) to describe mask.
27  top-left corner is origin
28 
29  x=1/y=5(top-left) x=4/y=5(botm-r) x=6/y=5 x=8/y=5(no mask)
30  1 * * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 1 1
31  1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 1
32  1 1 1 * * * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
33  1 1 1 1 * * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
34  1 1 1 1 1 * * * 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
35  l=7,-1/r=0(tl) l=7,-1/r=0(br)
36 
37  x=1/y=2 x=4/y=2 x=6/y=2 x=8/y=2
38  1 * * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 1 1
39  1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 1
40  * 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1
41  * * 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1
42  * * * 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 * * * 1 1 1 1 1
43  l=1/r=0(tl) l=1/r=3(tl) l=1/r=5(tl) l=1/r=7(tl)
44  l=4/r=0(br) l=4/r=2(br) l=4/r=4(br)
45 
46  x=4/y=-1 x=6/y=-1 x=8/y=-1
47  * * 1 1 * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 1
48  * * * 1 1 * * * * * * 1 1 1 1 * * * * 1 1 1 1 1
49  * * * * 1 1 * * * * * * 1 1 1 1 * * * * 1 1 1 1
50  * * * * * 1 1 * * * * * * 1 1 1 * * * * * 1 1 1
51  * * * * * * 1 1 * * * * * * 1 1 * * * * * * 1 1
52 
53  x=-2/y=5 x=1/y=5(top-left) x=0/y=5(botm-r)
54  * * * * * * * * 1 * * * * * * *
55  * * * * * * * * 1 1 * * 1 * * *
56  * * * * * * * * 1 1 1 * 1 1 * *
57  1 * * * * * * * 1 1 1 1 1 1 1 *
58  1 1 * * * * * * 1 1 1 1 1 1 1 1
59 
60  Validations:
61  x + y > 1 (x + y >= 2)
62 
63  Note:
64  y = seq_q, x = 1 -> top-left
65  y = seq_q, x = seq_k - seq_q + 1 -> bottom-right
66  y < seq_q, x < seq_k -> local-attn
67  y = seq_q, x = seq_k -> no mask
68 
69 */
// Compile-time table mapping the (IsMasking_, IsLocal_) pair to a short name string.
// Only the four explicit specializations below are defined; the primary template is
// declared but left without a definition.
70 namespace impl {
71  template <bool IsMasking_, bool IsLocal_> struct MaskName;
72  template<> struct MaskName<false, false> { static constexpr const char * name = "mn"; };
// IsLocal_ makes no difference when masking is disabled: same "mn" string as above.
73  template<> struct MaskName<false, true> { static constexpr const char * name = "mn"; };
// presumably "mc" = causal-style mask and "mg" = generic/local mask -- TODO confirm naming
74  template<> struct MaskName<true, false> { static constexpr const char * name = "mc"; };
75  template<> struct MaskName<true, true> { static constexpr const char * name = "mg"; };
76 }
77 // clang-format on
78 
79 template <bool IsMasking_ = true, bool IsLocal_ = false>
81 {
82  static constexpr bool IsMasking = IsMasking_; // false will disable masking
83  static constexpr bool IsLocal = IsLocal_; // if true, upper/lower area could have mask,
84  // else only upper-right could have mask
85 
86  static constexpr const char* name = impl::MaskName<IsMasking, IsLocal>::name;
87 
// Initializer list and body of the (y_total_, x_total_) constructor; its declaration line
// is missing from this listing. It delegates to the five-argument constructor with
// y = 0, x = 0 and sink = 0.
89  : GenericAttentionMask(0, 0, 0, y_total_, x_total_)
90  {
91  }
92 
// Five-argument constructor: stores the mask coordinates (y_, x_), the sink size and the
// total extents (y_total_, x_total_) unchanged into the corresponding members.
94  GenericAttentionMask(index_t y_, index_t x_, index_t sink_, index_t y_total_, index_t x_total_)
95  : y(y_), x(x_), sink(sink_), y_total(y_total_), x_total(x_total_)
96  {
97  }
// Builds the mask from a coordinate container read through .at(number<I>{}), with the
// element order (y, x, sink, y_total, x_total).
98  template <typename MaskCoordinates>
99  CK_TILE_HOST_DEVICE GenericAttentionMask(const MaskCoordinates& mask_coord)
100  : y(mask_coord.at(number<0>{})),
101  x(mask_coord.at(number<1>{})),
102  sink(mask_coord.at(number<2>{})),
103  y_total(mask_coord.at(number<3>{})),
104  x_total(mask_coord.at(number<4>{}))
105  {
106  }
107 
108  // Get the loop range along the X axis as a tuple (start, end) meaning [start, end); end - start = length.
109  // Use this when looping over the X axis tile by tile (like the k-seqlen loop).
110  // TODO: x_end could still be negative, so end - start could be negative (caller needs to check)
// NOTE: the declarator line that names this function and its parameters
// (GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const) is missing from this listing.
// i_y is the first row of the current Y tile; YTile/XTile are the tile extents.
111  template <index_t YTile, index_t XTile>
112  CK_TILE_HOST_DEVICE constexpr auto
114  {
115  if constexpr(!IsMasking)
116  {
// masking disabled: the whole [0, x_total) range is visited (end is not tile-rounded here)
117  return ck_tile::make_tuple(0, x_total);
118  }
119  else
120  {
121  // get the tile start/end range, assuming we loop along X tile by tile
122  index_t x_start = [&]() {
123  if constexpr(IsLocal)
124  {
// first visible column of row i_y is i_y - y + 1, clamped at 0
125  index_t tmp = max(-y + i_y + 1, 0);
126  return (tmp / XTile) * XTile; // round down to a tile boundary
127  }
128  else
129  {
// non-local masks never hide the left side, so start at column 0
130  return 0;
131  }
132  }();
133 
134  // TODO: end could be negative; we do not clamp here and let the caller check
135  // ... in which case end - start is negative
136  index_t x_end = [&]() {
// one past the last visible column of the LAST row of the tile (i_y + YTile - 1), capped at x_total
137  index_t tmp = min(i_y + YTile - 1 + x, x_total);
// round up to a tile boundary
138  return ((tmp + XTile - 1) / XTile) * XTile;
139  }();
140 
141  return ck_tile::make_tuple(x_start, x_end);
142  }
143  }
144 
// Sink-aware variant of the X-range query; returns a 3-element tuple.
// NOTE: the declarator line (GetSinkTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const)
// is missing from this listing.
// The tuple is (sink_seq_end, x_start, x_end) -- presumably "visit [0, sink_seq_end) and then
// [x_start, x_end)"; confirm against the calling pipeline.
145  template <index_t YTile, index_t XTile>
146  CK_TILE_HOST_DEVICE constexpr auto
148  {
149  if constexpr(!IsMasking)
150  {
// masking disabled: no separate sink segment, window is [0, x_total)
151  return ck_tile::make_tuple(0, 0, x_total);
152  }
153  else
154  {
155  // get the tile start/end range, assuming we loop along X tile by tile
156  index_t x_start = [&]() {
157  if constexpr(IsLocal)
158  {
159  index_t tmp = max(-y + i_y + 1, 0);
160  return (tmp / XTile) * XTile; // round down to a tile boundary
161  }
162  else
163  {
164  return 0;
165  }
166  }();
167 
168  // TODO: end could be negative; we do not clamp here and let the caller check
169  // ... in which case end - start is negative
170  index_t x_end = [&]() {
171  index_t tmp = min(i_y + YTile - 1 + x, x_total);
172  return ((tmp + XTile - 1) / XTile) * XTile;
173  }();
174 
// sink length rounded up to a whole number of X tiles (0 when there is no sink)
175  index_t sink_seq_end = sink > 0 ? ((sink + XTile - 1) / XTile) * XTile : 0;
// the window already starts at or before the end of the sink tiles: collapse both
// into a single range [0, x_end)
176  if(x_start <= sink_seq_end && sink > 0)
177  return ck_tile::make_tuple(0, 0, x_end);
178  else
179  return ck_tile::make_tuple(sink_seq_end, x_start, x_end);
180  }
181  }
182 
183  // Get the loop range along the Y axis as a tuple (start, end) meaning [start, end); end - start = length.
184  // Use this when looping over the Y axis tile by tile (like the q-seqlen loop).
185  // TODO: y_end could still be negative, so end - start could be negative (caller needs to check)
// NOTE: the declarator line (GetTileRangeAlongY(index_t i_x, number<YTile>, number<XTile>) const)
// is missing from this listing. i_x is the first column of the current X tile.
186  template <index_t YTile, index_t XTile>
187  CK_TILE_HOST_DEVICE constexpr auto
189  {
190  if constexpr(!IsMasking)
191  {
192  return ck_tile::make_tuple(0, y_total);
193  }
194  else
195  {
196  // get the tile start/end range, assuming we loop along Y tile by tile
197  index_t y_start = [&]() {
// i_x - x + 1 clamped at 0, then rounded down to a tile boundary
198  index_t tmp = max(-x + i_x + 1, 0);
199  return (tmp / YTile) * YTile; // round down to a tile boundary
200  }();
201 
202  // TODO: end could be negative; we do not clamp here and let the caller check
203  // ... in which case end - start is negative
204  index_t y_end = [&]() {
// computed from the LAST column of the tile (i_x + XTile - 1), capped at y_total, rounded up
205  index_t tmp = min(i_x + XTile - 1 + y, y_total);
206  return ((tmp + YTile - 1) / YTile) * YTile;
207  }();
208 
209  return ck_tile::make_tuple(y_start, y_end);
210  }
211  }
212 
213  // Per-element check: returns true when (i_y, i_x) is out of bound and must be masked (e.g. set to -INF).
214  CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const
215  {
216  if constexpr(!IsMasking)
217  {
// masking disabled: only columns at or beyond x_total are out of bound
218  return i_x >= x_total;
219  }
220  else
221  {
222  // no need to do min/max here, since i_x will never be < 0 or >= x_total
// visible columns of row i_y are [x_start, x_end)
223  index_t x_start = -y + i_y + 1;
224  index_t x_end = min(i_y + x, x_total);
225 
226  if constexpr(IsLocal)
227  {
// local mask: both the left and the right boundary are checked (i_y is not compared to y_total here)
228  return i_x < x_start || i_x >= x_end;
229  }
230  else
231  {
// non-local mask: only the right boundary is checked, plus rows at or beyond y_total
232  return i_x >= x_end || i_y >= y_total;
233  }
234  }
235  }
236 
// Per-element check like IsOutOfBound, except that columns below `sink` are reported as
// in bound (false) when the extra conditions in the if-statements below hold.
237  CK_TILE_HOST_DEVICE constexpr auto IsOutOfSinkBound(index_t i_y, index_t i_x) const
238  {
// NOTE(review): this `if constexpr` has no else branch, so the statements below are still
// instantiated when IsMasking is false (they are simply never reached at run time).
239  if constexpr(!IsMasking)
240  return i_x >= x_total;
241  // no need to do min/max here, since i_x will never be < 0 or >= x_total
242  index_t x_start = -y + i_y + 1;
243  index_t x_end = min(i_y + x, x_total);
244 
245  if constexpr(IsLocal)
246  {
// NOTE(review): the last term compares the row index i_y with x_total (not y_total) --
// confirm this is intended.
247  if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total)
248  return false;
249  else
250  return i_x < x_start || i_x >= x_end;
251  }
252  else
253  {
// same sink condition as the local branch; the fallback is the non-local bound check
254  if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total)
255  return false;
256  else
257  return i_x >= x_end || i_y >= y_total;
258  }
259  }
260 
261  // If the current tile is at the edge of the mask, a per-element mask check is needed;
262  // otherwise no per-element check is needed.
263  // Attention! Assumes the indices passed to this function are within the range of GetTileRangeAlongX/Y().
264  // Can be used as a fast path to decide whether to do the per-element check or not.
// NOTE: the declarator line (IsEdgeTile(index_t i_tile_top, index_t i_tile_left,
// number<TileHeight>, number<TileWidth>) const) is missing from this listing.
265  template <index_t TileHeight, index_t TileWidth>
266  CK_TILE_HOST_DEVICE constexpr auto
268  {
269  if constexpr(!IsMasking)
270  {
271  // TODO: no need to check begin
// masking disabled: only a tile that sticks out past x_total is an edge tile
272  return (i_tile_left + TileWidth) > x_total;
273  }
274  else
275  {
276  if constexpr(IsLocal)
277  {
278  // check top-right corner > x or bottom-left corner < x
279  index_t i_tile_right = i_tile_left + TileWidth;
280  index_t i_tile_bottom = i_tile_top + TileHeight;
281  index_t x_end = min(i_tile_top + x, x_total);
282 
// right side of the tile extends past the right mask boundary of its top row
283  bool top_right_edge = i_tile_right > (i_tile_top + x);
// bottom of the tile extends past the lower mask boundary of its left column
284  bool bottom_left_edge = i_tile_bottom > (i_tile_left + y);
285  bool is_partial_out_of_bound =
286  i_tile_right > x_end; // only consider right-pad for now
287 
288  return top_right_edge || bottom_left_edge || is_partial_out_of_bound;
289  }
290  else
291  {
292  // only need to check top-right corner > x
293  index_t i_tile_right = i_tile_left + TileWidth;
// x_end already folds in x_total, so this single test also covers right padding
294  index_t x_end = min(i_tile_top + x, x_total);
295 
296  bool top_right_edge = i_tile_right > x_end;
297  return top_right_edge;
298  }
299  }
300  }
301 
302  private:
// y, x: mask coordinates on the vertical/horizontal axes; sink: number of leading columns
// treated specially by the *Sink* methods.
303  index_t y, x, sink;
// total extents along the Y and X axes, used to cap the computed ranges
304  index_t y_total, x_total;
305 };
306 
307 // clang-format off
// Compile-time name string for the simplified mask: "nomask" when IsMasking_ is false,
// "mask" when it is true. The primary template is declared but not defined.
308 namespace impl {
309  template <bool IsMasking_> struct SimplifiedMaskName;
310  template<> struct SimplifiedMaskName<false> { static constexpr const char * name = "nomask"; };
311  template<> struct SimplifiedMaskName<true> { static constexpr const char * name = "mask"; };
312 }
313 // clang-format on
314 
315 // This version has only two variations: masking and non-masking.
316 // This is friendlier to codegen (e.g. fewer kernels need to be generated)
317 // ... with the trade-off that it may execute more instructions in causal mode
318 template <bool IsMasking_ = true>
320 {
321  static constexpr bool IsMasking = IsMasking_; // false will disable masking
322 
323  static constexpr const char* name = impl::SimplifiedMaskName<IsMasking>::name;
324 
// Initializer list and body of the (y_total_, x_total_) constructor; its declaration line
// is missing from this listing. It delegates to the five-argument constructor with
// y = 0, x = 0 and sink = 0.
326  : SimplifiedGenericAttentionMask(0, 0, 0, y_total_, x_total_)
327  {
328  }
329 
// Parameter list of the five-argument constructor (the line carrying its name is missing
// from this listing); every argument is stored unchanged into the member of the same name.
332  index_t y_, index_t x_, index_t sink_, index_t y_total_, index_t x_total_)
333  : y(y_), x(x_), sink(sink_), y_total(y_total_), x_total(x_total_)
334  {
335  }
// Builds the mask from a coordinate container read through .at(number<I>{}), with the
// element order (y, x, sink, y_total, x_total).
336  template <typename MaskCoordinates>
337  CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(const MaskCoordinates& mask_coord)
338  : y(mask_coord.at(number<0>{})),
339  x(mask_coord.at(number<1>{})),
340  sink(mask_coord.at(number<2>{})),
341  y_total(mask_coord.at(number<3>{})),
342  x_total(mask_coord.at(number<4>{}))
343  {
344  }
345 
346  // Get the loop range along the X axis as a tuple (start, end) meaning [start, end); end - start = length.
347  // Use this when looping over the X axis tile by tile (like the k-seqlen loop).
348  // TODO: x_end could still be negative, so end - start could be negative (caller needs to check)
// NOTE: the declarator line (GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const)
// is missing from this listing.
349  template <index_t YTile, index_t XTile>
350  CK_TILE_HOST_DEVICE constexpr auto
352  {
353  if constexpr(!IsMasking)
354  {
355  return ck_tile::make_tuple(0, x_total);
356  }
357  else
358  {
359  // get the tile start/end range, assuming we loop along X tile by tile
// unlike GenericAttentionMask there is no IsLocal switch: the left bound is always computed
360  index_t x_start = [&]() {
361  index_t tmp = max(-y + i_y + 1, 0);
362  return (tmp / XTile) * XTile; // round down to a tile boundary
363  }();
364 
365  // TODO: end could be negative; we do not clamp here and let the caller check
366  // ... in which case end - start is negative
367  index_t x_end = [&]() {
// right bound of the last row of the tile, capped at x_total, rounded up to a tile boundary
368  index_t tmp = min(i_y + YTile - 1 + x, x_total);
369  return ((tmp + XTile - 1) / XTile) * XTile;
370  }();
371 
372  return ck_tile::make_tuple(x_start, x_end);
373  }
374  }
375 
// Sink-aware variant of the X-range query; returns a 3-element tuple
// (sink_seq_end, x_start, x_end).
// NOTE: the declarator line (GetSinkTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const)
// is missing from this listing.
376  template <index_t YTile, index_t XTile>
377  CK_TILE_HOST_DEVICE constexpr auto
379  {
380  if constexpr(!IsMasking)
381  {
382  return ck_tile::make_tuple(0, 0, x_total);
383  }
384  else
385  {
386  // get the tile start/end range, assuming we loop along X tile by tile
387  index_t x_start = [&]() {
388  index_t tmp = max(-y + i_y + 1, 0);
389  return (tmp / XTile) * XTile; // round down to a tile boundary
390  }();
391 
392  // TODO: end could be negative; we do not clamp here and let the caller check
393  // ... in which case end - start is negative
394  index_t x_end = [&]() {
395  index_t tmp = min(i_y + YTile - 1 + x, x_total);
396  return ((tmp + XTile - 1) / XTile) * XTile;
397  }();
398 
// sink length rounded up to a whole number of X tiles (0 when there is no sink)
399  index_t sink_seq_end = sink > 0 ? ((sink + XTile - 1) / XTile) * XTile : 0;
400 
// window starts at or before the end of the sink tiles: report one merged range [0, x_end)
401  if(x_start <= sink_seq_end && sink > 0)
402  return ck_tile::make_tuple(0, 0, x_end);
403  else
404  return ck_tile::make_tuple(sink_seq_end, x_start, x_end);
405  }
406  }
407 
// Split-aware X range: intersects the mask's tile range for row tile i_y with the slice of
// [0, x_total) that belongs to split i_split out of num_splits.
// NOTE: the line carrying the function name and the first parameter
// (GetTileRangeAlongX(index_t i_y, ...) is missing from this listing.
408  template <index_t TileHeight, index_t TileWidth>
410  number<TileHeight> height,
411  number<TileWidth> width,
412  index_t num_splits,
413  index_t i_split) const
414  {
// range without splitting, from the three-argument overload
415  auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width);
416 
// each split covers ceil(x_total / num_splits) columns, but never fewer than 1
417  const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits));
418  const index_t split_start = x_per_split * i_split;
419  const index_t split_end = ck_tile::min(x_total, split_start + x_per_split);
420 
// intersection of the two ranges; end may be <= start when they do not overlap
421  return ck_tile::make_tuple(ck_tile::max(origin_start, split_start),
422  ck_tile::min(origin_end, split_end));
423  }
424 
// Split-aware, sink-aware X range; returns a 3-element tuple (sink_offset, start, end).
// NOTE: the line carrying the function name and the first parameter
// (GetSinkTileRangeAlongX(index_t i_y, ...) is missing from this listing.
425  template <index_t TileHeight, index_t TileWidth>
427  number<TileHeight> height,
428  number<TileWidth> width,
429  index_t num_splits,
430  index_t i_split) const
431  {
// NOTE(review): this calls the non-sink GetTileRangeAlongX overload for the unsplit range
432  auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width);
433  const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits));
434  const index_t split_start = x_per_split * i_split; // e.g. 128
435  const index_t split_end = ck_tile::min(x_total, split_start + x_per_split); // e.g. 256
// sink length rounded up to a whole number of tiles of width `width` (0 when there is no sink)
436  const index_t sink_seq_end = sink > 0 ? ((sink + width - 1) / width) * width : 0;
// intersection of the unsplit range with this split's slice
437  const index_t start = ck_tile::max(origin_start, split_start);
438  const index_t end = ck_tile::min(origin_end, split_end);
// true when this split's slice contains origin_start, the start of the unsplit range
439  const bool is_first_intersecting_split =
440  (split_start <= origin_start && split_end >= origin_start);
// true when the sink tiles end at or before the start of this split's range
441  const bool sink_in_range = (sink_seq_end <= start);
442 
// the sink segment is reported only when both conditions above hold; otherwise 0
443  const index_t sink_offset =
444  (is_first_intersecting_split && sink_in_range) ? sink_seq_end : 0;
445  return ck_tile::make_tuple(sink_offset, start, end);
446  }
447 
448  // Get the loop range along the Y axis as a tuple (start, end) meaning [start, end); end - start = length.
449  // Use this when looping over the Y axis tile by tile (like the q-seqlen loop).
450  // TODO: y_end could still be negative, so end - start could be negative (caller needs to check)
// NOTE: the declarator line (GetTileRangeAlongY(index_t i_x, number<YTile>, number<XTile>) const)
// is missing from this listing.
451  template <index_t YTile, index_t XTile>
452  CK_TILE_HOST_DEVICE constexpr auto
454  {
455  if constexpr(!IsMasking)
456  {
457  return ck_tile::make_tuple(0, y_total);
458  }
459  else
460  {
461  // get the tile start/end range, assuming we loop along Y tile by tile
462  index_t y_start = [&]() {
463  index_t tmp = max(-x + i_x + 1, 0);
464  return (tmp / YTile) * YTile; // round down to a tile boundary
465  }();
466 
467  // TODO: end could be negative; we do not clamp here and let the caller check
468  // ... in which case end - start is negative
469  index_t y_end = [&]() {
// computed from the last column of the tile, capped at y_total, rounded up to a tile boundary
470  index_t tmp = min(i_x + XTile - 1 + y, y_total);
471  return ((tmp + YTile - 1) / YTile) * YTile;
472  }();
473 
474  return ck_tile::make_tuple(y_start, y_end);
475  }
476  }
477 
478  // Per-element check: returns true when (i_y, i_x) is out of bound and must be masked (e.g. set to -INF).
479  CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const
480  {
481  if constexpr(!IsMasking)
482  {
483  // the only case that needs the following compare is under kPadSeqLenK
484  // ... for the non-masking kernel.
485  return i_x >= x_total;
486  }
487  else
488  {
489  index_t x_start = -y + i_y + 1; // this could be negative, but it's fine
490  index_t x_end = min(i_y + x, x_total); // need min in case x is padded
// left bound, right bound and bottom padding are always checked (no IsLocal switch here)
491  return i_x < x_start || i_x >= x_end || i_y >= y_total;
492  }
493  }
494 
// Per-element check like IsOutOfBound, except that columns below `sink` are reported as
// in bound (false) when the extra conditions in the if-statement below hold.
495  CK_TILE_HOST_DEVICE constexpr auto IsOutOfSinkBound(index_t i_y, index_t i_x) const
496  {
// NOTE(review): no else branch, so the code below is still instantiated when IsMasking is false
497  if constexpr(!IsMasking)
498  return i_x >= x_total;
499  index_t x_start = -y + i_y + 1; // this could be negative, but it's fine
500  index_t x_end = min(i_y + x, x_total); // need min in case x is padded
// NOTE(review): the last term compares the row index i_y with x_total (not y_total) --
// confirm this is intended.
501  if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total)
502  return false;
503  else
504  return i_x < x_start || i_x >= x_end || i_y >= y_total;
505  }
506 
507  // If the current tile is at the edge of the mask, a per-element mask check is needed;
508  // otherwise no per-element check is needed.
509  // Attention! Assumes the indices passed to this function are within the range of GetTileRangeAlongX/Y().
510  // Can be used as a fast path to decide whether to do the per-element check or not.
// NOTE: the declarator line (IsEdgeTile(index_t i_y, index_t i_x, number<TileHeight>,
// number<TileWidth>) const) is missing from this listing.
511  template <index_t TileHeight, index_t TileWidth>
512  CK_TILE_HOST_DEVICE constexpr auto
514  {
515  if constexpr(!IsMasking)
516  {
517  // the only case that needs the following compare is under kPadSeqLenK
518  // ... for the non-masking kernel.
519  // return (i_x < x_total) && ((i_x + TileWidth) > x_total);
520 
521  // TODO: no need to check begin
522  return (i_x + TileWidth) > x_total;
523  }
524  else
525  {
526  // check top-right corner > x or bottom-left corner < x
// one-past-the-end column and row of the tile
527  index_t i_x_end = i_x + TileWidth;
528  index_t i_y_end = i_y + TileHeight;
529  // index_t x_end = min(i_y + x, x_total);
530 
531  bool top_right_edge = i_x_end > min(i_y + x, x_total); // consider right pad
532  bool bottom_left_edge = i_y_end > min(i_x + y, y_total); // consider bottom pad
533  // bool is_partial_out_of_bound = i_x_end > x_end; // only consider right-pad for now
534 
535  return top_right_edge || bottom_left_edge;
536  }
537  }
538 
539  private:
// y, x: mask coordinates on the vertical/horizontal axes; sink: number of leading columns
// treated specially by the *Sink* methods.
540  index_t y, x, sink;
// total extents along the Y and X axes, used to cap the computed ranges
541  index_t y_total, x_total;
542 };
543 
544 // clang-format off
// Compile-time name string for the ratio mask: "nomask" when IsMasking_ is false,
// "mask" when it is true (the same strings as SimplifiedMaskName).
545 namespace impl {
546  template <bool IsMasking_> struct SimplifiedRatioMaskName;
547  template<> struct SimplifiedRatioMaskName<false> { static constexpr const char * name = "nomask"; };
548  template<> struct SimplifiedRatioMaskName<true> { static constexpr const char * name = "mask"; };
549 }
550 // clang-format on
551 
552 // This version is used for cases where the step length along the y-direction is greater than
553 // one, which means that the mask is not a regular triangular matrix.
554 
555 // clang-format off
556 /* y_ratio is used to describe the step length of y-direction changes
557  in certain performance optimization scenarios like merging seqlen
558  and qk_head_ratio, for example:
559 
560  x=1/y=6/y_ratio=2(top-left)
561  1 * * * * * * *
562  1 * * * * * * *
563  1 1 * * * * * *
564  1 1 * * * * * *
565  1 1 1 * * * * *
566  1 1 1 * * * * *
567 
568 */
569 // clang-format on
570 template <bool IsMasking_ = true>
572 {
573  static constexpr bool IsMasking = IsMasking_; // false will disable masking
574 
575  static constexpr const char* name = impl::SimplifiedRatioMaskName<IsMasking>::name;
576 
// Initializer list and body of the (y_total_, x_total_) constructor; its declaration line
// is missing from this listing. It delegates to the seven-argument constructor with
// y = 0, x = 0, y_real = 0, y_ratio = 1 and a default-constructed mdiv.
578  : SimplifiedRatioAttentionMask(0, 0, y_total_, x_total_, 0, 1, mdiv{})
579  {
580  }
581 
// Parameter list of the (y_real_, x_, y_total_, x_total_, y_ratio_mdiv_) constructor (the line
// carrying its name is missing from this listing). It takes the ratio from
// y_ratio_mdiv_.get() and delegates with y = y_real_ * ratio.
584  index_t y_real_, index_t x_, index_t y_total_, index_t x_total_, mdiv y_ratio_mdiv_)
585  : SimplifiedRatioAttentionMask(/*y_=*/y_real_ * static_cast<index_t>(y_ratio_mdiv_.get()),
586  /*x_=*/x_,
587  /*y_total_=*/y_total_,
588  /*x_total_=*/x_total_,
589  /*y_real_=*/y_real_,
590  /*y_ratio_=*/static_cast<index_t>(y_ratio_mdiv_.get()),
591  /*y_ratio_mdiv_=*/y_ratio_mdiv_)
592 
593  {
594  }
// Remaining parameters of the seven-argument constructor (the line carrying its name and
// the first parameter, index_t y_, is missing from this listing). Every argument is stored
// unchanged; the relation y_real * y_ratio = y is not checked here.
597  index_t x_,
598  index_t y_total_,
599  index_t x_total_,
600  index_t y_real_,
601  index_t y_ratio_,
602  mdiv y_ratio_mdiv_)
603  : y(y_),
604  x(x_),
605  y_total(y_total_),
606  x_total(x_total_),
607  y_real(y_real_),
608  y_ratio(y_ratio_),
609  y_ratio_mdiv(y_ratio_mdiv_)
610  {
611  }
612 
613  // Get the loop range along the X axis as a tuple (start, end) meaning [start, end); end - start = length.
614  // Use this when looping over the X axis tile by tile (like the k-seqlen loop).
615  // TODO: x_end could still be negative, so end - start could be negative (caller needs to check)
// NOTE: the declarator line (GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const)
// is missing from this listing.
616  template <index_t YTile, index_t XTile>
617  CK_TILE_HOST_DEVICE constexpr auto
619  {
620  if constexpr(!IsMasking)
621  {
622  return ck_tile::make_tuple(0, x_total);
623  }
624  else
625  {
626  // get the tile start/end range, assuming we loop along X tile by tile
627  index_t x_start = [&]() {
// y_ratio_mdiv.div(i_y) presumably yields i_y / y_ratio -- confirm against mdiv.
// NOTE(review): unlike the other mask classes, tmp is not clamped with max(..., 0)
// before rounding, so x_start can be negative -- confirm this is intended.
628  index_t tmp = -y_real +
629  static_cast<index_t>(y_ratio_mdiv.div(static_cast<uint32_t>(i_y))) +
630  1;
631 
632  return (tmp / XTile) * XTile; // round to a tile boundary (integer division)
633  }();
634 
635  // TODO: end could be negative; we do not clamp here and let the caller check
636  // ... in which case end - start is negative
637  index_t x_end = [&]() {
// last row of the tile, converted to unsigned for mdiv::div
638  uint32_t y_offset = i_y + YTile - 1;
639  index_t tmp = min(static_cast<index_t>(y_ratio_mdiv.div(y_offset)) + x, x_total);
// round up to a tile boundary
640  return ((tmp + XTile - 1) / XTile) * XTile;
641  }();
642 
643  return ck_tile::make_tuple(x_start, x_end);
644  }
645  }
646 
647  // Get the loop range along the Y axis as a tuple (start, end) meaning [start, end); end - start = length.
648  // Use this when looping over the Y axis tile by tile (like the q-seqlen loop).
649  // TODO: y_end could still be negative, so end - start could be negative (caller needs to check)
// NOTE: the declarator line (GetTileRangeAlongY(index_t i_x, number<YTile>, number<XTile>) const)
// is missing from this listing.
650  template <index_t YTile, index_t XTile>
651  CK_TILE_HOST_DEVICE constexpr auto
653  {
654  if constexpr(!IsMasking)
655  {
656  return ck_tile::make_tuple(0, y_total);
657  }
658  else
659  {
660  // get the tile start/end range, assuming we loop along Y tile by tile
661  index_t y_start = [&]() {
// column offsets are scaled by y_ratio to convert them into row units
662  index_t tmp = max((-x + i_x + 1) * y_ratio, 0);
663  return (tmp / YTile) * YTile; // round down to a tile boundary
664  }();
665 
666  // TODO: end could be negative; we do not clamp here and let the caller check
667  // ... in which case end - start is negative
668  index_t y_end = [&]() {
// last column of the tile scaled by y_ratio, plus y, capped at y_total, rounded up
669  index_t tmp = min((i_x + XTile - 1) * y_ratio + y, y_total);
670  return ((tmp + YTile - 1) / YTile) * YTile;
671  }();
672 
673  return ck_tile::make_tuple(y_start, y_end);
674  }
675  }
676 
677  // Per-element check: returns true when (i_y, i_x) is out of bound and must be masked (e.g. set to -INF).
678  CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const
679  {
680  if constexpr(!IsMasking)
681  {
682  return i_x >= x_total;
683  }
684  else
685  {
// row index passed through y_ratio_mdiv.div (presumably i_y / y_ratio -- confirm against mdiv)
686  index_t x_tmp = static_cast<index_t>(y_ratio_mdiv.div(static_cast<uint32_t>(i_y)));
// visible columns of this row are [x_start, x_end); x_start may be negative
687  index_t x_start = -y_real + x_tmp + 1;
688  index_t x_end = min(x_tmp + x,
689  x_total); // need min in case x is padded
690  return i_x < x_start || i_x >= x_end || i_y >= y_total;
691  }
692  }
693 
694  // If the current tile is at the edge of the mask, a per-element mask check is needed;
695  // otherwise no per-element check is needed.
696  // Attention! Assumes the indices passed to this function are within the range of GetTileRangeAlongX/Y().
697  // Can be used as a fast path to decide whether to do the per-element check or not.
// NOTE: the declarator line (IsEdgeTile(index_t i_y, index_t i_x, number<TileHeight>,
// number<TileWidth>) const) is missing from this listing.
698  template <index_t TileHeight, index_t TileWidth>
699  CK_TILE_HOST_DEVICE constexpr auto
701  {
702  if constexpr(!IsMasking)
703  {
704  // the only case that needs the following compare is under kPadSeqLenK
705  // ... for the non-masking kernel.
706  // return (i_x < x_total) && ((i_x + TileWidth) > x_total);
707 
708  return (i_x + TileWidth) > x_total;
709  }
710  else
711  {
712  // check top-right corner > x or bottom-left corner < x
713  index_t i_x_end = i_x + TileWidth;
714  index_t i_y_end = i_y + TileHeight;
715  // index_t x_end = min(i_y + x, x_total);
// top row of the tile, converted to unsigned for mdiv::div
716  uint32_t y_tmp = static_cast<uint32_t>(i_y);
717  bool top_right_edge = i_x_end > min(static_cast<index_t>(y_ratio_mdiv.div(y_tmp)) + x,
718  x_total); // consider right pad
// the left column is scaled by y_ratio to convert it into row units
719  bool bottom_left_edge =
720  i_y_end > min(i_x * y_ratio + y, y_total); // consider bottom pad
721  return top_right_edge || bottom_left_edge;
722  }
723  }
724 
725  private:
// mask coordinates on the vertical/horizontal axes
726  index_t y, x;
// total extents along the Y and X axes, used to cap the computed ranges
727  index_t y_total, x_total;
728  // y_real is the vertical-axis value before multiplying by y_ratio: y_real * y_ratio = y
729  index_t y_real;
// y_ratio as a plain integer (used for multiplication) ...
730  index_t y_ratio;
// ... and as an mdiv object, whose div() is applied to row indices
731  mdiv y_ratio_mdiv;
732 };
733 
734 template <typename>
736 {
737 };
738 
739 template <bool IsMasking, bool IsLocal>
741 {
742 };
743 
// Convenience variable template: the ::value of is_generic_attention_mask<Mask>.
744 template <typename Mask>
745 static constexpr bool is_generic_attention_mask_v = is_generic_attention_mask<Mask>::value;
746 
747 // TODO: prefer to use this function in host code
748 // Converts the FA-style left/right window sizes to our generic mask coordinates.
749 // If left_size < 0 && right_size = 0, it is a normal causal mask;
750 // local attention is left_size >= 0 or right_size >= 0.
// Returns the tuple (y, x, sink_size, y_total, x_total).
// NOTE: the line carrying the function name and the first parameter
// (make_generic_attention_mask_coordinates_from_lr_window(index_t left_size, ...) is missing
// from this listing.
751 CK_TILE_HOST_DEVICE constexpr auto
753  index_t right_size,
754  index_t sink_size,
755  index_t y_total,
756  index_t x_total,
757  bool is_top_left = true)
758 {
759  // TODO: below should all use sgpr arithmetic
// replacement values for a negative ("unlimited") window size
760  index_t left_size_tmp = is_top_left ? y_total - 1 : x_total - 1;
761  index_t right_size_tmp = is_top_left ? x_total - 1 : y_total - 1;
762 
763  left_size = left_size < 0 ? left_size_tmp : left_size;
764  right_size = right_size < 0 ? right_size_tmp : right_size;
765 
// extra offsets applied only when is_top_left is false (zero when it is true)
766  index_t x_tmp = is_top_left ? 0 : x_total - y_total;
767  index_t y_tmp = is_top_left ? 0 : y_total - x_total;
768 
769  index_t x = 1 + right_size + x_tmp;
770  index_t y = 1 + left_size + y_tmp;
771 
772  return ck_tile::make_tuple(y, x, sink_size, y_total, x_total);
773 }
774 
// Builds a MaskType from FA-style left/right window sizes plus a sink size.
// NOTE: two lines are missing from this listing: the one carrying the function name and first
// parameter (make_generic_attention_mask_from_lr_window(index_t left_size, ...), and the one
// that begins the statement defining `r`. Judging by the argument list below, `r` is
// presumably the result of make_generic_attention_mask_coordinates_from_lr_window -- confirm
// against the original header.
775 template <typename MaskType>
776 CK_TILE_HOST_DEVICE constexpr auto
778  index_t right_size,
779  index_t sink_size,
780  index_t y_total,
781  index_t x_total,
782  bool is_top_left = true)
783 {
785  left_size, right_size, sink_size, y_total, x_total, is_top_left);
// elements 0 and 1 of r are passed as the mask's first two constructor arguments
786  return MaskType{r.at(number<0>{}), r.at(number<1>{}), sink_size, y_total, x_total};
787 }
788 
// Variant of the factory above without a sink_size parameter: 0 is passed as the sink both to
// the coordinate computation and to the MaskType constructor.
// NOTE: the line carrying the function name and first parameter, and the line that begins the
// statement defining `r`, are missing from this listing; the name is presumably an overload of
// make_generic_attention_mask_from_lr_window -- confirm against the original header.
789 template <typename MaskType>
790 CK_TILE_HOST_DEVICE constexpr auto
792  index_t right_size,
793  index_t y_total,
794  index_t x_total,
795  bool is_top_left = true)
796 {
798  left_size, right_size, 0, y_total, x_total, is_top_left);
799  return MaskType{r.at(number<0>{}), r.at(number<1>{}), 0, y_total, x_total};
800 }
801 } // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:145
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_generic_attention_mask_from_lr_window(index_t left_size, index_t right_size, index_t sink_size, index_t y_total, index_t x_total, bool is_top_left=true)
Definition: block_masking.hpp:777
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
GenericAttentionMaskEnum
Definition: block_masking.hpp:11
constexpr CK_TILE_HOST_DEVICE auto make_generic_attention_mask_coordinates_from_lr_window(index_t left_size, index_t right_size, index_t sink_size, index_t y_total, index_t x_total, bool is_top_left=true)
Definition: block_masking.hpp:752
constexpr CK_TILE_HOST_DEVICE T min(T x)
Definition: math.hpp:206
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:157
bool_constant< false > false_type
Definition: integral_constant.hpp:63
bool_constant< true > true_type
Definition: integral_constant.hpp:62
unsigned int uint32_t
Definition: stdint.h:126
Definition: block_masking.hpp:81
constexpr CK_TILE_HOST_DEVICE auto GetSinkTileRangeAlongX(index_t i_y, number< YTile >, number< XTile >) const
Definition: block_masking.hpp:147
constexpr CK_TILE_HOST_DEVICE auto IsEdgeTile(index_t i_tile_top, index_t i_tile_left, number< TileHeight >, number< TileWidth >) const
Definition: block_masking.hpp:267
constexpr CK_TILE_HOST_DEVICE auto GetTileRangeAlongX(index_t i_y, number< YTile >, number< XTile >) const
Definition: block_masking.hpp:113
constexpr CK_TILE_HOST_DEVICE auto IsOutOfBound(index_t i_y, index_t i_x) const
Definition: block_masking.hpp:214
CK_TILE_HOST_DEVICE GenericAttentionMask(index_t y_total_, index_t x_total_)
Definition: block_masking.hpp:88
CK_TILE_HOST_DEVICE GenericAttentionMask(index_t y_, index_t x_, index_t sink_, index_t y_total_, index_t x_total_)
Definition: block_masking.hpp:94
static constexpr const char * name
Definition: block_masking.hpp:86
constexpr CK_TILE_HOST_DEVICE auto GetTileRangeAlongY(index_t i_x, number< YTile >, number< XTile >) const
Definition: block_masking.hpp:188
constexpr CK_TILE_HOST_DEVICE auto IsOutOfSinkBound(index_t i_y, index_t i_x) const
Definition: block_masking.hpp:237
CK_TILE_HOST_DEVICE GenericAttentionMask(const MaskCoordinates &mask_coord)
Definition: block_masking.hpp:99
static constexpr bool IsMasking
Definition: block_masking.hpp:82
static constexpr bool IsLocal
Definition: block_masking.hpp:83
Definition: block_masking.hpp:320
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(index_t y_, index_t x_, index_t sink_, index_t y_total_, index_t x_total_)
Definition: block_masking.hpp:331
constexpr CK_TILE_HOST_DEVICE auto IsEdgeTile(index_t i_y, index_t i_x, number< TileHeight >, number< TileWidth >) const
Definition: block_masking.hpp:513
constexpr CK_TILE_HOST_DEVICE auto GetTileRangeAlongX(index_t i_y, number< YTile >, number< XTile >) const
Definition: block_masking.hpp:351
constexpr CK_TILE_HOST_DEVICE auto IsOutOfBound(index_t i_y, index_t i_x) const
Definition: block_masking.hpp:479
constexpr CK_TILE_HOST_DEVICE auto IsOutOfSinkBound(index_t i_y, index_t i_x) const
Definition: block_masking.hpp:495
static constexpr const char * name
Definition: block_masking.hpp:323
constexpr CK_TILE_HOST_DEVICE auto GetTileRangeAlongY(index_t i_x, number< YTile >, number< XTile >) const
Definition: block_masking.hpp:453
constexpr CK_TILE_HOST_DEVICE auto GetTileRangeAlongX(index_t i_y, number< TileHeight > height, number< TileWidth > width, index_t num_splits, index_t i_split) const
Definition: block_masking.hpp:409
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(const MaskCoordinates &mask_coord)
Definition: block_masking.hpp:337
static constexpr bool IsMasking
Definition: block_masking.hpp:321
constexpr CK_TILE_HOST_DEVICE auto GetSinkTileRangeAlongX(index_t i_y, number< TileHeight > height, number< TileWidth > width, index_t num_splits, index_t i_split) const
Definition: block_masking.hpp:426
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(index_t y_total_, index_t x_total_)
Definition: block_masking.hpp:325
constexpr CK_TILE_HOST_DEVICE auto GetSinkTileRangeAlongX(index_t i_y, number< YTile >, number< XTile >) const
Definition: block_masking.hpp:378
Definition: block_masking.hpp:572
CK_TILE_HOST_DEVICE SimplifiedRatioAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_, index_t y_real_, index_t y_ratio_, mdiv y_ratio_mdiv_)
Definition: block_masking.hpp:596
static constexpr const char * name
Definition: block_masking.hpp:575
CK_TILE_HOST_DEVICE SimplifiedRatioAttentionMask(index_t y_real_, index_t x_, index_t y_total_, index_t x_total_, mdiv y_ratio_mdiv_)
Definition: block_masking.hpp:583
constexpr CK_TILE_HOST_DEVICE auto GetTileRangeAlongX(index_t i_y, number< YTile >, number< XTile >) const
Definition: block_masking.hpp:618
CK_TILE_HOST_DEVICE SimplifiedRatioAttentionMask(index_t y_total_, index_t x_total_)
Definition: block_masking.hpp:577
constexpr CK_TILE_HOST_DEVICE auto GetTileRangeAlongY(index_t i_x, number< YTile >, number< XTile >) const
Definition: block_masking.hpp:652
constexpr CK_TILE_HOST_DEVICE auto IsEdgeTile(index_t i_y, index_t i_x, number< TileHeight >, number< TileWidth >) const
Definition: block_masking.hpp:700
static constexpr bool IsMasking
Definition: block_masking.hpp:573
constexpr CK_TILE_HOST_DEVICE auto IsOutOfBound(index_t i_y, index_t i_x) const
Definition: block_masking.hpp:678
Definition: integral_constant.hpp:13
Definition: block_masking.hpp:71
Definition: block_masking.hpp:309
Definition: block_masking.hpp:546
Definition: block_masking.hpp:736
Definition: magic_div.hpp:186
CK_TILE_HOST_DEVICE uint32_t div(uint32_t dividend_) const
Definition: magic_div.hpp:212