/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/block/block_masking.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/block/block_masking.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/block/block_masking.hpp Source File
block_masking.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 
8 namespace ck_tile {
9 
// Attention-mask modes.
// NO_MASK disables masking. The MASK_FROM_* values cover causal or sliding-window masks
// anchored at the named corner. MASK_GENERIC describes an arbitrary (y, x) window.
enum struct GenericAttentionMaskEnum
{
    NO_MASK = 0,

    // below enum could be causal, or sliding window
    MASK_FROM_TOP_LEFT     = 1,
    MASK_FROM_BOTTOM_RIGHT = 2,

    // this enum may not be used by xformers/FA, since it is hard to
    // specify a left/right window for the varlen case. It is kept here
    // for debugging purposes.
    MASK_GENERIC,
};

// clang-format off
/*  generic Attention Mask Coordinate
    use x (horizontal axis) and y (vertical axis) to describe the mask.
    the top-left corner is the origin.
    In general, row i keeps the columns j with  i - y + 1 <= j < i + x
    (clipped to [0, seq_k)); '1' = kept, '*' = masked.

    x=1/y=5(top-left)  x=4/y=5(botm-r)    x=6/y=5            x=8/y=5(no mask)
    1 * * * * * * *    1 1 1 1 * * * *    1 1 1 1 1 1 * *    1 1 1 1 1 1 1 1
    1 1 * * * * * *    1 1 1 1 1 * * *    1 1 1 1 1 1 1 *    1 1 1 1 1 1 1 1
    1 1 1 * * * * *    1 1 1 1 1 1 * *    1 1 1 1 1 1 1 1    1 1 1 1 1 1 1 1
    1 1 1 1 * * * *    1 1 1 1 1 1 1 *    1 1 1 1 1 1 1 1    1 1 1 1 1 1 1 1
    1 1 1 1 1 * * *    1 1 1 1 1 1 1 1    1 1 1 1 1 1 1 1    1 1 1 1 1 1 1 1
    l=7,-1/r=0(tl)     l=7,-1/r=0(br)

    x=1/y=2            x=4/y=2            x=6/y=2            x=8/y=2
    1 * * * * * * *    1 1 1 1 * * * *    1 1 1 1 1 1 * *    1 1 1 1 1 1 1 1
    1 1 * * * * * *    1 1 1 1 1 * * *    1 1 1 1 1 1 1 *    1 1 1 1 1 1 1 1
    * 1 1 * * * * *    * 1 1 1 1 1 * *    * 1 1 1 1 1 1 1    * 1 1 1 1 1 1 1
    * * 1 1 * * * *    * * 1 1 1 1 1 *    * * 1 1 1 1 1 1    * * 1 1 1 1 1 1
    * * * 1 1 * * *    * * * 1 1 1 1 1    * * * 1 1 1 1 1    * * * 1 1 1 1 1
    l=1/r=0(tl)        l=1/r=3(tl)        l=1/r=5(tl)        l=1/r=7(tl)
                       l=4/r=0(br)        l=4/r=2(br)        l=4/r=4(br)

    x=4/y=-1           x=6/y=-1           x=8/y=-1
    * * 1 1 * * * *    * * 1 1 1 1 * *    * * 1 1 1 1 1 1
    * * * 1 1 * * *    * * * 1 1 1 1 *    * * * 1 1 1 1 1
    * * * * 1 1 * *    * * * * 1 1 1 1    * * * * 1 1 1 1
    * * * * * 1 1 *    * * * * * 1 1 1    * * * * * 1 1 1
    * * * * * * 1 1    * * * * * * 1 1    * * * * * * 1 1

    x=-2/y=5           x=1/y=5(top-left)  x=0/y=5(botm-r)
    * * * * * * * *    1 * * *            * * * *
    * * * * * * * *    1 1 * *            1 * * *
    * * * * * * * *    1 1 1 *            1 1 * *
    1 * * * * * * *    1 1 1 1            1 1 1 *
    1 1 * * * * * *    1 1 1 1            1 1 1 1

    Validations:
      x + y > 1 (i.e. x + y >= 2)

    Note:
      y = seq_q, x = 1                  -> top-left
      y = seq_q, x = seq_k - seq_q + 1  -> bottom-right
      y < seq_q, x < seq_k              -> local-attn
      y = seq_q, x = seq_k              -> no mask

*/
// Short kernel-name tag for GenericAttentionMask<IsMasking, IsLocal>:
//   "mn" - masking disabled (IsLocal is irrelevant)
//   "mc" - masking without local window (only the upper-right area is masked)
//   "mg" - masking with local window (upper-right and lower-left areas may be masked)
namespace impl {
    template <bool IsMasking_, bool IsLocal_> struct MaskName;
    template <bool IsLocal_> struct MaskName<false, IsLocal_> { static constexpr const char * name = "mn"; };
    template <>              struct MaskName<true, false>     { static constexpr const char * name = "mc"; };
    template <>              struct MaskName<true, true>      { static constexpr const char * name = "mg"; };
}
// clang-format on

79 template <bool IsMasking_ = true, bool IsLocal_ = false>
81 {
82  static constexpr bool IsMasking = IsMasking_; // false will disable masking
83  static constexpr bool IsLocal = IsLocal_; // if true, upper/lower area could have mask,
84  // else only upper-right could have mask
85 
86  static constexpr const char* name = impl::MaskName<IsMasking, IsLocal>::name;
87 
89  : GenericAttentionMask(0, 0, y_total_, x_total_)
90  {
91  }
92 
94  GenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_)
95  : y(y_), x(x_), y_total(y_total_), x_total(x_total_)
96  {
97  }
98  template <typename MaskCoordinates>
99  CK_TILE_HOST_DEVICE GenericAttentionMask(const MaskCoordinates& mask_coord)
100  : y(mask_coord.at(number<0>{})),
101  x(mask_coord.at(number<1>{})),
102  y_total(mask_coord.at(number<2>{})),
103  x_total(mask_coord.at(number<3>{}))
104  {
105  }
106 
107  // to get the loop length along X axis, return index:[start, end), end-start=length
108  // use this if need loop over X axis tile by tile (like k-seqlen loopover)
109  // TODO: x_end still could be negative, so end-start could be negative(need check)
110  template <index_t YTile, index_t XTile>
111  CK_TILE_HOST_DEVICE constexpr auto
113  {
114  if constexpr(!IsMasking)
115  {
116  return ck_tile::make_tuple(0, x_total);
117  }
118  else
119  {
120  // get the tile start/end range assum we loop over along X tile by tile
121  index_t x_start = [&]() {
122  if constexpr(IsLocal)
123  {
124  index_t tmp = max(-y + i_y + 1, 0);
125  return (tmp / XTile) * XTile; // round to tile aligned
126  }
127  else
128  {
129  return 0;
130  }
131  }();
132 
133  // TODO: end could be negative, we ignore clamp here, and let caller to check
134  // ... in which case end-start is negative
135  index_t x_end = [&]() {
136  index_t tmp = min(i_y + YTile - 1 + x, x_total);
137  return ((tmp + XTile - 1) / XTile) * XTile;
138  }();
139 
140  return ck_tile::make_tuple(x_start, x_end);
141  }
142  }
143 
144  // to get the loop length along Y axis, return index:[start, end), end-start=length
145  // use this if need loop over Y axis tile by tile (like q-seqlen loopover)
146  // TODO: y_end still could be negative, so end-start could be negative(need check)
147  template <index_t YTile, index_t XTile>
148  CK_TILE_HOST_DEVICE constexpr auto
150  {
151  if constexpr(!IsMasking)
152  {
153  return ck_tile::make_tuple(0, y_total);
154  }
155  else
156  {
157  // get the tile start/end range assum we loop over along Y tile by tile
158  index_t y_start = [&]() {
159  index_t tmp = max(-x + i_x + 1, 0);
160  return (tmp / YTile) * YTile; // round to tile aligned
161  }();
162 
163  // TODO: end could be negative, we ignore clamp here, and let caller to check
164  // ... in which case end-start is negative
165  index_t y_end = [&]() {
166  index_t tmp = min(i_x + XTile - 1 + y, y_total);
167  return ((tmp + YTile - 1) / YTile) * YTile;
168  }();
169 
170  return ck_tile::make_tuple(y_start, y_end);
171  }
172  }
173 
174  // per-pixel check if out-of-bound, if true, need mask a value(like -INF)
175  CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const
176  {
177  if constexpr(!IsMasking)
178  {
179  return i_x >= x_total;
180  }
181  else
182  {
183  // no need to do min/max here, since i_x will never be < 0 or >= x_total
184  index_t x_start = -y + i_y + 1;
185  index_t x_end = min(i_y + x, x_total);
186 
187  if constexpr(IsLocal)
188  {
189  return i_x < x_start || i_x >= x_end;
190  }
191  else
192  {
193  return i_x >= x_end || i_y >= y_total;
194  }
195  }
196  }
197 
198  // if current tile is at the edge, means need per-pixel mask check.
199  // otherwise no need to check per-pixel
200  // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y()
201  // can be used as a fast-path to decide if do per-pixel check or not
202  template <index_t TileHeight, index_t TileWidth>
203  CK_TILE_HOST_DEVICE constexpr auto
205  {
206  if constexpr(!IsMasking)
207  {
208  // TODO: no need to check begin
209  return (i_tile_left + TileWidth) > x_total;
210  }
211  else
212  {
213  if constexpr(IsLocal)
214  {
215  // check top-right corner > x or left-borrom corner < x
216  index_t i_tile_right = i_tile_left + TileWidth;
217  index_t i_tile_bottom = i_tile_top + TileHeight;
218  index_t x_end = min(i_tile_top + x, x_total);
219 
220  bool top_right_edge = i_tile_right > (i_tile_top + x);
221  bool bottom_left_edge = i_tile_bottom > (i_tile_left + y);
222  bool is_partial_out_of_bound =
223  i_tile_right > x_end; // only consider right-pad for now
224 
225  return top_right_edge || bottom_left_edge || is_partial_out_of_bound;
226  }
227  else
228  {
229  // only need to check top-right corner > x
230  index_t i_tile_right = i_tile_left + TileWidth;
231  index_t x_end = min(i_tile_top + x, x_total);
232 
233  bool top_right_edge = i_tile_right > x_end;
234  return top_right_edge;
235  }
236  }
237  }
238 
239  private:
240  index_t y, x;
241  index_t y_total, x_total;
242 };
243 
// clang-format off
// Short kernel-name tag for SimplifiedGenericAttentionMask: "mask" when masking is enabled,
// "nomask" otherwise.
namespace impl {
    template <bool IsMasking_> struct SimplifiedMaskName
    {
        static constexpr const char * name = IsMasking_ ? "mask" : "nomask";
    };
}
// clang-format on

252 // this version only have 2 variation: masking and non-masking
253 // This is more friendly to codegen (e.g. need generate less kernel)
254 // ... with the trade-off that may have more instruction in causal mode
255 template <bool IsMasking_ = true>
257 {
258  static constexpr bool IsMasking = IsMasking_; // false will disable masking
259 
260  static constexpr const char* name = impl::SimplifiedMaskName<IsMasking>::name;
261 
263  : SimplifiedGenericAttentionMask(0, 0, y_total_, x_total_)
264  {
265  }
266 
269  : y(y_), x(x_), y_total(y_total_), x_total(x_total_)
270  {
271  }
272  template <typename MaskCoordinates>
273  CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(const MaskCoordinates& mask_coord)
274  : y(mask_coord.at(number<0>{})),
275  x(mask_coord.at(number<1>{})),
276  y_total(mask_coord.at(number<2>{})),
277  x_total(mask_coord.at(number<3>{}))
278  {
279  }
280 
281  // to get the loop length along X axis, return index:[start, end), end-start=length
282  // use this if need loop over X axis tile by tile (like k-seqlen loopover)
283  // TODO: x_end still could be negative, so end-start could be negative(need check)
284  template <index_t YTile, index_t XTile>
285  CK_TILE_HOST_DEVICE constexpr auto
287  {
288  if constexpr(!IsMasking)
289  {
290  return ck_tile::make_tuple(0, x_total);
291  }
292  else
293  {
294  // get the tile start/end range assum we loop over along X tile by tile
295  index_t x_start = [&]() {
296  index_t tmp = max(-y + i_y + 1, 0);
297  return (tmp / XTile) * XTile; // round to tile aligned
298  }();
299 
300  // TODO: end could be negative, we ignore clamp here, and let caller to check
301  // ... in which case end-start is negative
302  index_t x_end = [&]() {
303  index_t tmp = min(i_y + YTile - 1 + x, x_total);
304  return ((tmp + XTile - 1) / XTile) * XTile;
305  }();
306 
307  return ck_tile::make_tuple(x_start, x_end);
308  }
309  }
310 
311  template <index_t TileHeight, index_t TileWidth>
313  number<TileHeight> height,
314  number<TileWidth> width,
315  index_t num_splits,
316  index_t i_split) const
317  {
318  auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width);
319 
320  const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits));
321  const index_t split_start = x_per_split * i_split;
322  const index_t split_end = ck_tile::min(x_total, split_start + x_per_split);
323 
324  return ck_tile::make_tuple(ck_tile::max(origin_start, split_start),
325  ck_tile::min(origin_end, split_end));
326  }
327 
328  // to get the loop length along Y axis, return index:[start, end), end-start=length
329  // use this if need loop over Y axis tile by tile (like q-seqlen loopover)
330  // TODO: y_end still could be negative, so end-start could be negative(need check)
331  template <index_t YTile, index_t XTile>
332  CK_TILE_HOST_DEVICE constexpr auto
334  {
335  if constexpr(!IsMasking)
336  {
337  return ck_tile::make_tuple(0, y_total);
338  }
339  else
340  {
341  // get the tile start/end range assum we loop over along Y tile by tile
342  index_t y_start = [&]() {
343  index_t tmp = max(-x + i_x + 1, 0);
344  return (tmp / YTile) * YTile; // round to tile aligned
345  }();
346 
347  // TODO: end could be negative, we ignore clamp here, and let caller to check
348  // ... in which case end-start is negative
349  index_t y_end = [&]() {
350  index_t tmp = min(i_x + XTile - 1 + y, y_total);
351  return ((tmp + YTile - 1) / YTile) * YTile;
352  }();
353 
354  return ck_tile::make_tuple(y_start, y_end);
355  }
356  }
357 
358  // per-pixel check if out-of-bound, if true, need mask a value(like -INF)
359  CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const
360  {
361  if constexpr(!IsMasking)
362  {
363  // the only case that need do following compare is under kPadSeqLenK
364  // ... for non-masking kernel.
365  return i_x >= x_total;
366  }
367  else
368  {
369  index_t x_start = -y + i_y + 1; // this could be negative, but it's fine
370  index_t x_end = min(i_y + x, x_total); // need min in case x is padded
371 
372  return i_x < x_start || i_x >= x_end || i_y >= y_total;
373  }
374  }
375 
376  // if current tile is at the edge, means need per-pixel mask check.
377  // otherwise no need to check per-pixel
378  // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y()
379  // can be used as a fast-path to decide if do per-pixel check or not
380  template <index_t TileHeight, index_t TileWidth>
381  CK_TILE_HOST_DEVICE constexpr auto
383  {
384  if constexpr(!IsMasking)
385  {
386  // the only case that need do following compare is under kPadSeqLenK
387  // ... for non-masking kernel.
388  // return (i_x < x_total) && ((i_x + TileWidth) > x_total);
389 
390  // TODO: no need to check begin
391  return (i_x + TileWidth) > x_total;
392  }
393  else
394  {
395  // check top-right corner > x or left-borrom corner < x
396  index_t i_x_end = i_x + TileWidth;
397  index_t i_y_end = i_y + TileHeight;
398  // index_t x_end = min(i_y + x, x_total);
399 
400  bool top_right_edge = i_x_end > min(i_y + x, x_total); // consider right pad
401  bool bottom_left_edge = i_y_end > min(i_x + y, y_total); // consider bottom pad
402  // bool is_partial_out_of_bound = i_x_end > x_end; // only consider right-pad for now
403 
404  return top_right_edge || bottom_left_edge;
405  }
406  }
407 
408  private:
409  index_t y, x;
410  index_t y_total, x_total;
411 };
412 
// clang-format off
// Short kernel-name tag for SimplifiedRatioAttentionMask: "mask" when masking is enabled,
// "nomask" otherwise.
namespace impl {
    template <bool IsMasking_> struct SimplifiedRatioMaskName
    {
        static constexpr const char * name = IsMasking_ ? "mask" : "nomask";
    };
}
// clang-format on

421 // this version is used for cases that the step length of y-direction changes greater than one. It
422 // means that the mask is not a regular triangular matrix.
423 
424 // clang-format off
425 /* y_ratio is used to describe the step length of y-direction changes
426  in certain performance optimization scenarios like merging seqlen
427  and qk_head_ratio, for example:
428 
429  x=1/y=6/y_ratio=2(top-left)
430  1 * * * * * * *
431  1 * * * * * * *
432  1 1 * * * * * *
433  1 1 * * * * * *
434  1 1 1 * * * * *
435  1 1 1 * * * * *
436 
437 */
438 // clang-format on
439 template <bool IsMasking_ = true>
441 {
442  static constexpr bool IsMasking = IsMasking_; // false will disable masking
443 
444  static constexpr const char* name = impl::SimplifiedRatioMaskName<IsMasking>::name;
445 
447  : SimplifiedRatioAttentionMask(0, 0, y_total_, x_total_, 0, 1, mdiv{})
448  {
449  }
450 
453  index_t y_real_, index_t x_, index_t y_total_, index_t x_total_, mdiv y_ratio_mdiv_)
454  : SimplifiedRatioAttentionMask(/*y_=*/y_real_ * static_cast<index_t>(y_ratio_mdiv_.get()),
455  /*x_=*/x_,
456  /*y_total_=*/y_total_,
457  /*x_total_=*/x_total_,
458  /*y_real_=*/y_real_,
459  /*y_ratio_=*/static_cast<index_t>(y_ratio_mdiv_.get()),
460  /*y_ratio_mdiv_=*/y_ratio_mdiv_)
461 
462  {
463  }
466  index_t x_,
467  index_t y_total_,
468  index_t x_total_,
469  index_t y_real_,
470  index_t y_ratio_,
471  mdiv y_ratio_mdiv_)
472  : y(y_),
473  x(x_),
474  y_total(y_total_),
475  x_total(x_total_),
476  y_real(y_real_),
477  y_ratio(y_ratio_),
478  y_ratio_mdiv(y_ratio_mdiv_)
479  {
480  }
481 
482  // to get the loop length along X axis, return index:[start, end), end-start=length
483  // use this if need loop over X axis tile by tile (like k-seqlen loopover)
484  // TODO: x_end still could be negative, so end-start could be negative(need check)
485  template <index_t YTile, index_t XTile>
486  CK_TILE_HOST_DEVICE constexpr auto
488  {
489  if constexpr(!IsMasking)
490  {
491  return ck_tile::make_tuple(0, x_total);
492  }
493  else
494  {
495  // get the tile start/end range assum we loop over along X tile by tile
496  index_t x_start = [&]() {
497  index_t tmp = -y_real +
498  static_cast<index_t>(y_ratio_mdiv.div(static_cast<uint32_t>(i_y))) +
499  1;
500 
501  return (tmp / XTile) * XTile; // round to tile aligned
502  }();
503 
504  // TODO: end could be negative, we ignore clamp here, and let caller to check
505  // ... in which case end-start is negative
506  index_t x_end = [&]() {
507  uint32_t y_offset = i_y + YTile - 1;
508  index_t tmp = min(static_cast<index_t>(y_ratio_mdiv.div(y_offset)) + x, x_total);
509  return ((tmp + XTile - 1) / XTile) * XTile;
510  }();
511 
512  return ck_tile::make_tuple(x_start, x_end);
513  }
514  }
515 
516  // to get the loop length along Y axis, return index:[start, end), end-start=length
517  // use this if need loop over Y axis tile by tile (like q-seqlen loopover)
518  // TODO: y_end still could be negative, so end-start could be negative(need check)
519  template <index_t YTile, index_t XTile>
520  CK_TILE_HOST_DEVICE constexpr auto
522  {
523  if constexpr(!IsMasking)
524  {
525  return ck_tile::make_tuple(0, y_total);
526  }
527  else
528  {
529  // get the tile start/end range assum we loop over along Y tile by tile
530  index_t y_start = [&]() {
531  index_t tmp = max((-x + i_x + 1) * y_ratio, 0);
532  return (tmp / YTile) * YTile; // round to tile aligned
533  }();
534 
535  // TODO: end could be negative, we ignore clamp here, and let caller to check
536  // ... in which case end-start is negative
537  index_t y_end = [&]() {
538  index_t tmp = min((i_x + XTile - 1) * y_ratio + y, y_total);
539  return ((tmp + YTile - 1) / YTile) * YTile;
540  }();
541 
542  return ck_tile::make_tuple(y_start, y_end);
543  }
544  }
545 
546  // per-pixel check if out-of-bound, if true, need mask a value(like -INF)
547  CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const
548  {
549  if constexpr(!IsMasking)
550  {
551  return i_x >= x_total;
552  }
553  else
554  {
555  index_t x_tmp = static_cast<index_t>(y_ratio_mdiv.div(static_cast<uint32_t>(i_y)));
556  index_t x_start = -y_real + x_tmp + 1;
557  index_t x_end = min(x_tmp + x,
558  x_total); // need min in case x is padded
559  return i_x < x_start || i_x >= x_end || i_y >= y_total;
560  }
561  }
562 
563  // if current tile is at the edge, means need per-pixel mask check.
564  // otherwise no need to check per-pixel
565  // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y()
566  // can be used as a fast-path to decide if do per-pixel check or not
567  template <index_t TileHeight, index_t TileWidth>
568  CK_TILE_HOST_DEVICE constexpr auto
570  {
571  if constexpr(!IsMasking)
572  {
573  // the only case that need do following compare is under kPadSeqLenK
574  // ... for non-masking kernel.
575  // return (i_x < x_total) && ((i_x + TileWidth) > x_total);
576 
577  return (i_x + TileWidth) > x_total;
578  }
579  else
580  {
581  // check top-right corner > x or left-borrom corner < x
582  index_t i_x_end = i_x + TileWidth;
583  index_t i_y_end = i_y + TileHeight;
584  // index_t x_end = min(i_y + x, x_total);
585  uint32_t y_tmp = static_cast<uint32_t>(i_y);
586  bool top_right_edge = i_x_end > min(static_cast<index_t>(y_ratio_mdiv.div(y_tmp)) + x,
587  x_total); // consider right pad
588  bool bottom_left_edge =
589  i_y_end > min(i_x * y_ratio + y, y_total); // consider bottom pad
590  return top_right_edge || bottom_left_edge;
591  }
592  }
593 
594  private:
595  index_t y, x;
596  index_t y_total, x_total;
597  // y_real is vertical axis before multiplying y_ratio. y_real * y_ratio = y
598  index_t y_real;
599  index_t y_ratio;
600  mdiv y_ratio_mdiv;
601 };
602 
603 // TODO: prefer use this function in host code
604 // can convert from the FA style left/right to our generic coordinate
605 // if left_size < 0 && right_size = 0, it is normal causal mask
606 // local is left_size >=0 or right_size >=0
607 CK_TILE_HOST_DEVICE constexpr auto
609  index_t right_size,
610  index_t y_total,
611  index_t x_total,
612  bool is_top_left = true)
613 {
614  // TODO: below should all use sgpr arithmetic
615  index_t left_size_tmp = is_top_left ? y_total - 1 : x_total - 1;
616  index_t right_size_tmp = is_top_left ? x_total - 1 : y_total - 1;
617 
618  left_size = left_size < 0 ? left_size_tmp : left_size;
619  right_size = right_size < 0 ? right_size_tmp : right_size;
620 
621  index_t x_tmp = is_top_left ? 0 : x_total - y_total;
622  index_t y_tmp = is_top_left ? 0 : y_total - x_total;
623 
624  index_t x = 1 + right_size + x_tmp;
625  index_t y = 1 + left_size + y_tmp;
626 
627  return ck_tile::make_tuple(y, x, y_total, x_total);
628 }
629 
630 template <typename MaskType>
631 CK_TILE_HOST_DEVICE constexpr auto
633  index_t right_size,
634  index_t y_total,
635  index_t x_total,
636  bool is_top_left = true)
637 {
639  left_size, right_size, y_total, x_total, is_top_left);
640  return MaskType{r.at(number<0>{}), r.at(number<1>{}), y_total, x_total};
641 }
642 } // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_generic_attention_mask_from_lr_window(index_t left_size, index_t right_size, index_t y_total, index_t x_total, bool is_top_left=true)
Definition: block_masking.hpp:632
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
GenericAttentionMaskEnum
Definition: block_masking.hpp:11
constexpr CK_TILE_HOST_DEVICE T min(T x)
Definition: math.hpp:210
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
constexpr CK_TILE_HOST_DEVICE auto make_generic_attention_mask_coordinates_from_lr_window(index_t left_size, index_t right_size, index_t y_total, index_t x_total, bool is_top_left=true)
Definition: block_masking.hpp:608
unsigned int uint32_t
Definition: stdint.h:126
Definition: block_masking.hpp:81
CK_TILE_HOST_DEVICE GenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_)
Definition: block_masking.hpp:94
constexpr CK_TILE_HOST_DEVICE auto IsEdgeTile(index_t i_tile_top, index_t i_tile_left, number< TileHeight >, number< TileWidth >) const
Definition: block_masking.hpp:204
constexpr CK_TILE_HOST_DEVICE auto GetTileRangeAlongX(index_t i_y, number< YTile >, number< XTile >) const
Definition: block_masking.hpp:112
constexpr CK_TILE_HOST_DEVICE auto IsOutOfBound(index_t i_y, index_t i_x) const
Definition: block_masking.hpp:175
CK_TILE_HOST_DEVICE GenericAttentionMask(index_t y_total_, index_t x_total_)
Definition: block_masking.hpp:88
static constexpr const char * name
Definition: block_masking.hpp:86
constexpr CK_TILE_HOST_DEVICE auto GetTileRangeAlongY(index_t i_x, number< YTile >, number< XTile >) const
Definition: block_masking.hpp:149
CK_TILE_HOST_DEVICE GenericAttentionMask(const MaskCoordinates &mask_coord)
Definition: block_masking.hpp:99
static constexpr bool IsMasking
Definition: block_masking.hpp:82
static constexpr bool IsLocal
Definition: block_masking.hpp:83
Definition: block_masking.hpp:257
constexpr CK_TILE_HOST_DEVICE auto IsEdgeTile(index_t i_y, index_t i_x, number< TileHeight >, number< TileWidth >) const
Definition: block_masking.hpp:382
constexpr CK_TILE_HOST_DEVICE auto GetTileRangeAlongX(index_t i_y, number< YTile >, number< XTile >) const
Definition: block_masking.hpp:286
constexpr CK_TILE_HOST_DEVICE auto IsOutOfBound(index_t i_y, index_t i_x) const
Definition: block_masking.hpp:359
static constexpr const char * name
Definition: block_masking.hpp:260
constexpr CK_TILE_HOST_DEVICE auto GetTileRangeAlongY(index_t i_x, number< YTile >, number< XTile >) const
Definition: block_masking.hpp:333
constexpr CK_TILE_HOST_DEVICE auto GetTileRangeAlongX(index_t i_y, number< TileHeight > height, number< TileWidth > width, index_t num_splits, index_t i_split) const
Definition: block_masking.hpp:312
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(const MaskCoordinates &mask_coord)
Definition: block_masking.hpp:273
static constexpr bool IsMasking
Definition: block_masking.hpp:258
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(index_t y_total_, index_t x_total_)
Definition: block_masking.hpp:262
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_)
Definition: block_masking.hpp:268
Definition: block_masking.hpp:441
CK_TILE_HOST_DEVICE SimplifiedRatioAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_, index_t y_real_, index_t y_ratio_, mdiv y_ratio_mdiv_)
Definition: block_masking.hpp:465
static constexpr const char * name
Definition: block_masking.hpp:444
CK_TILE_HOST_DEVICE SimplifiedRatioAttentionMask(index_t y_real_, index_t x_, index_t y_total_, index_t x_total_, mdiv y_ratio_mdiv_)
Definition: block_masking.hpp:452
constexpr CK_TILE_HOST_DEVICE auto GetTileRangeAlongX(index_t i_y, number< YTile >, number< XTile >) const
Definition: block_masking.hpp:487
CK_TILE_HOST_DEVICE SimplifiedRatioAttentionMask(index_t y_total_, index_t x_total_)
Definition: block_masking.hpp:446
constexpr CK_TILE_HOST_DEVICE auto GetTileRangeAlongY(index_t i_x, number< YTile >, number< XTile >) const
Definition: block_masking.hpp:521
constexpr CK_TILE_HOST_DEVICE auto IsEdgeTile(index_t i_y, index_t i_x, number< TileHeight >, number< TileWidth >) const
Definition: block_masking.hpp:569
static constexpr bool IsMasking
Definition: block_masking.hpp:442
constexpr CK_TILE_HOST_DEVICE auto IsOutOfBound(index_t i_y, index_t i_x) const
Definition: block_masking.hpp:547
Definition: integral_constant.hpp:13
Definition: block_masking.hpp:71
Definition: block_masking.hpp:246
Definition: block_masking.hpp:415
Definition: magic_div.hpp:186
CK_TILE_HOST_DEVICE uint32_t div(uint32_t dividend_) const
Definition: magic_div.hpp:212