/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/block/block_masking.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/block/block_masking.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/block/block_masking.hpp Source File
block_masking.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 
8 namespace ck_tile {
9 
11 {
12  NO_MASK = 0,
13 
14  // below enum could be causal, or sliding window
17 
18  // this enum may not be used by xformer/FA, since it's hard to
19  // specify the left/right window for the varlen case. It is kept here for
20  // debugging purposes.
22 };
23 
24 // clang-format off
25 /* generic Attention Mask Coordinate
26  use x(horizontal axis), y(vertical axis) to describe mask.
27  top-left corner is origin
28 
29  x=1/y=5(top-left) x=4/y=5(botm-r) x=6/y=5 x=8/y=5(no mask)
30  1 * * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 1 1
31  1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 1
32  1 1 1 * * * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
33  1 1 1 1 * * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
34  1 1 1 1 1 * * * 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
35  l=7,-1/r=0(tl) l=7,-1/r=0(br)
36 
37  x=1/y=2 x=4/y=2 x=6/y=2 x=8/y=2
38  1 * * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 1 1
39  1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 1
40  * 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1
41  * * 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1
42  * * * 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 * * * 1 1 1 1 1
43  l=1/r=0(tl) l=1/r=3(tl) l=1/r=5(tl) l=1/r=7(tl)
44  l=4/r=0(br) l=4/r=2(br) l=4/r=4(br)
45 
46  x=4/y=-1 x=6/y=-1 x=8/y=-1
47  * * 1 1 * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 1
48  * * * 1 1 * * * * * * 1 1 1 1 * * * * 1 1 1 1 1
49  * * * * 1 1 * * * * * * 1 1 1 1 * * * * 1 1 1 1
50  * * * * * 1 1 * * * * * * 1 1 1 * * * * * 1 1 1
51  * * * * * * 1 1 * * * * * * 1 1 * * * * * * 1 1
52 
53  x=-2/y=5 x=1/y=5(top-left) x=0/y=5(botm-r)
54  * * * * * * * * 1 * * * * * * *
55  * * * * * * * * 1 1 * * 1 * * *
56  * * * * * * * * 1 1 1 * 1 1 * *
57  1 * * * * * * * 1 1 1 1 1 1 1 *
58  1 1 * * * * * * 1 1 1 1 1 1 1 1
59 
60  Validations:
61  x + y > 1 (x + y >= 2)
62 
63  Note:
64  y = seq_q, x = 1 -> top-left
65  y = seq_q, x = seq_k - seq_q + 1 -> bottom-right
66  y < seq_q, x < seq_k -> local-attn
67  y = seq_q, x = seq_k -> no mask
68 
69 */
70 namespace impl {
71  template <bool IsMasking_, bool IsLocal_> struct MaskName;
72  template<> struct MaskName<false, false> { static constexpr const char * name = "mn"; };
73  template<> struct MaskName<false, true> { static constexpr const char * name = "mn"; };
74  template<> struct MaskName<true, false> { static constexpr const char * name = "mc"; };
75  template<> struct MaskName<true, true> { static constexpr const char * name = "mg"; };
76 }
77 // clang-format on
78 
79 template <bool IsMasking_ = true, bool IsLocal_ = false>
81 {
82  static constexpr bool IsMasking = IsMasking_; // false will disable masking
83  static constexpr bool IsLocal = IsLocal_; // if true, upper/lower area could have mask,
84  // else only upper-right could have mask
85 
86  static constexpr const char* name = impl::MaskName<IsMasking, IsLocal>::name;
87 
89  : GenericAttentionMask(0, 0, y_total_, x_total_)
90  {
91  }
92 
94  GenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_)
95  : y(y_), x(x_), y_total(y_total_), x_total(x_total_)
96  {
97  }
98  template <typename MaskCoordinates>
99  CK_TILE_HOST_DEVICE GenericAttentionMask(const MaskCoordinates& mask_coord)
100  : y(mask_coord.at(number<0>{})),
101  x(mask_coord.at(number<1>{})),
102  y_total(mask_coord.at(number<2>{})),
103  x_total(mask_coord.at(number<3>{}))
104  {
105  }
106 
107  // to get the loop length along X axis, return index:[start, end), end-start=length
108  // use this if need loop over X axis tile by tile (like k-seqlen loopover)
109  // TODO: x_end still could be negative, so end-start could be negative(need check)
110  template <index_t YTile, index_t XTile>
111  CK_TILE_HOST_DEVICE constexpr auto
113  {
114  if constexpr(!IsMasking)
115  {
116  return ck_tile::make_tuple(0, x_total);
117  }
118  else
119  {
120  // get the tile start/end range, assuming we loop along X tile by tile
121  index_t x_start = [&]() {
122  if constexpr(IsLocal)
123  {
124  index_t tmp = max(-y + i_y + 1, 0);
125  return (tmp / XTile) * XTile; // round to tile aligned
126  }
127  else
128  {
129  return 0;
130  }
131  }();
132 
133  // TODO: end could be negative; we skip the clamp here and let the caller check
134  // ... in which case end-start is negative
135  index_t x_end = [&]() {
136  index_t tmp = min(i_y + YTile - 1 + x, x_total);
137  return ((tmp + XTile - 1) / XTile) * XTile;
138  }();
139 
140  return ck_tile::make_tuple(x_start, x_end);
141  }
142  }
143 
144  // to get the loop length along Y axis, return index:[start, end), end-start=length
145  // use this if need loop over Y axis tile by tile (like q-seqlen loopover)
146  // TODO: y_end still could be negative, so end-start could be negative(need check)
147  template <index_t YTile, index_t XTile>
148  CK_TILE_HOST_DEVICE constexpr auto
150  {
151  if constexpr(!IsMasking)
152  {
153  return ck_tile::make_tuple(0, y_total);
154  }
155  else
156  {
157  // get the tile start/end range, assuming we loop along Y tile by tile
158  index_t y_start = [&]() {
159  index_t tmp = max(-x + i_x + 1, 0);
160  return (tmp / YTile) * YTile; // round to tile aligned
161  }();
162 
163  // TODO: end could be negative; we skip the clamp here and let the caller check
164  // ... in which case end-start is negative
165  index_t y_end = [&]() {
166  index_t tmp = min(i_x + XTile - 1 + y, y_total);
167  return ((tmp + YTile - 1) / YTile) * YTile;
168  }();
169 
170  return ck_tile::make_tuple(y_start, y_end);
171  }
172  }
173 
174  // per-pixel check if out-of-bound, if true, need mask a value(like -INF)
175  CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const
176  {
177  if constexpr(!IsMasking)
178  {
179  return i_x >= x_total;
180  }
181  else
182  {
183  // no need to do min/max here, since i_x will never be < 0 or >= x_total
184  index_t x_start = -y + i_y + 1;
185  index_t x_end = min(i_y + x, x_total);
186 
187  if constexpr(IsLocal)
188  {
189  return i_x < x_start || i_x >= x_end;
190  }
191  else
192  {
193  return i_x >= x_end || i_y >= y_total;
194  }
195  }
196  }
197 
198  // if the current tile is at the edge, a per-pixel mask check is needed;
199  // otherwise there is no need to check per-pixel.
200  // Attention! assumes the index passed to this function is within the range of GetTileRangeAlongX/Y()
201  // can be used as a fast-path to decide whether to do the per-pixel check or not
202  template <index_t TileHeight, index_t TileWidth>
203  CK_TILE_HOST_DEVICE constexpr auto
205  {
206  if constexpr(IsLocal)
207  {
208  // check top-right corner > x or bottom-left corner < x
209  index_t i_tile_right = i_tile_left + TileWidth;
210  index_t i_tile_bottom = i_tile_top + TileHeight;
211  index_t x_end = min(i_tile_top + x, x_total);
212 
213  bool top_right_edge = i_tile_right > (i_tile_top + x);
214  bool bottom_left_edge = i_tile_bottom > (i_tile_left + y);
215  bool is_partial_out_of_bound = i_tile_right > x_end; // only consider right-pad for now
216 
217  return top_right_edge || bottom_left_edge || is_partial_out_of_bound;
218  }
219  else
220  {
221  // only need to check top-right corner > x
222  index_t i_tile_right = i_tile_left + TileWidth;
223  index_t x_end = min(i_tile_top + x, x_total);
224 
225  bool top_right_edge = i_tile_right > x_end;
226  return top_right_edge;
227  }
228  }
229 
230  private:
231  index_t y, x;
232  index_t y_total, x_total;
233 };
234 
235 // clang-format off
236 namespace impl {
237  template <bool IsMasking_> struct SimplifiedMaskName;
238  template<> struct SimplifiedMaskName<false> { static constexpr const char * name = "nomask"; };
239  template<> struct SimplifiedMaskName<true> { static constexpr const char * name = "mask"; };
240 }
241 // clang-format on
242 
243 // this version has only 2 variations: masking and non-masking
244 // This is more friendly to codegen (e.g. fewer kernels need to be generated)
245 // ... with the trade-off that causal mode may execute more instructions
246 template <bool IsMasking_ = true>
248 {
249  static constexpr bool IsMasking = IsMasking_; // false will disable masking
250 
251  static constexpr const char* name = impl::SimplifiedMaskName<IsMasking>::name;
252 
254  : SimplifiedGenericAttentionMask(0, 0, y_total_, x_total_)
255  {
256  }
257 
260  : y(y_), x(x_), y_total(y_total_), x_total(x_total_)
261  {
262  }
263  template <typename MaskCoordinates>
264  CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(const MaskCoordinates& mask_coord)
265  : y(mask_coord.at(number<0>{})),
266  x(mask_coord.at(number<1>{})),
267  y_total(mask_coord.at(number<2>{})),
268  x_total(mask_coord.at(number<3>{}))
269  {
270  }
271 
272  // to get the loop length along X axis, return index:[start, end), end-start=length
273  // use this if need loop over X axis tile by tile (like k-seqlen loopover)
274  // TODO: x_end still could be negative, so end-start could be negative(need check)
275  template <index_t YTile, index_t XTile>
276  CK_TILE_HOST_DEVICE constexpr auto
278  {
279  if constexpr(!IsMasking)
280  {
281  return ck_tile::make_tuple(0, x_total);
282  }
283  else
284  {
285  // get the tile start/end range, assuming we loop along X tile by tile
286  index_t x_start = [&]() {
287  index_t tmp = max(-y + i_y + 1, 0);
288  return (tmp / XTile) * XTile; // round to tile aligned
289  }();
290 
291  // TODO: end could be negative; we skip the clamp here and let the caller check
292  // ... in which case end-start is negative
293  index_t x_end = [&]() {
294  index_t tmp = min(i_y + YTile - 1 + x, x_total);
295  return ((tmp + XTile - 1) / XTile) * XTile;
296  }();
297 
298  return ck_tile::make_tuple(x_start, x_end);
299  }
300  }
301 
302  template <index_t TileHeight, index_t TileWidth>
304  number<TileHeight> height,
305  number<TileWidth> width,
306  index_t num_splits,
307  index_t i_split) const
308  {
309  auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width);
310 
311  const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits));
312  const index_t split_start = x_per_split * i_split;
313  const index_t split_end = ck_tile::min(x_total, split_start + x_per_split);
314 
315  return ck_tile::make_tuple(ck_tile::max(origin_start, split_start),
316  ck_tile::min(origin_end, split_end));
317  }
318 
319  // to get the loop length along Y axis, return index:[start, end), end-start=length
320  // use this if need loop over Y axis tile by tile (like q-seqlen loopover)
321  // TODO: y_end still could be negative, so end-start could be negative(need check)
322  template <index_t YTile, index_t XTile>
323  CK_TILE_HOST_DEVICE constexpr auto
325  {
326  if constexpr(!IsMasking)
327  {
328  return ck_tile::make_tuple(0, y_total);
329  }
330  else
331  {
332  // get the tile start/end range, assuming we loop along Y tile by tile
333  index_t y_start = [&]() {
334  index_t tmp = max(-x + i_x + 1, 0);
335  return (tmp / YTile) * YTile; // round to tile aligned
336  }();
337 
338  // TODO: end could be negative; we skip the clamp here and let the caller check
339  // ... in which case end-start is negative
340  index_t y_end = [&]() {
341  index_t tmp = min(i_x + XTile - 1 + y, y_total);
342  return ((tmp + YTile - 1) / YTile) * YTile;
343  }();
344 
345  return ck_tile::make_tuple(y_start, y_end);
346  }
347  }
348 
349  // per-pixel check if out-of-bound, if true, need mask a value(like -INF)
350  CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const
351  {
352  if constexpr(!IsMasking)
353  {
354  // the only case that need do following compare is under kPadSeqLenK
355  // ... for non-masking kernel.
356  return i_x >= x_total;
357  }
358  else
359  {
360  index_t x_start = -y + i_y + 1; // this could be negative, but it's fine
361  index_t x_end = min(i_y + x, x_total); // need min in case x is padded
362 
363  return i_x < x_start || i_x >= x_end || i_y >= y_total;
364  }
365  }
366 
367  // if the current tile is at the edge, a per-pixel mask check is needed;
368  // otherwise there is no need to check per-pixel.
369  // Attention! assumes the index passed to this function is within the range of GetTileRangeAlongX/Y()
370  // can be used as a fast-path to decide whether to do the per-pixel check or not
371  template <index_t TileHeight, index_t TileWidth>
372  CK_TILE_HOST_DEVICE constexpr auto
374  {
375  if constexpr(!IsMasking)
376  {
377  // the only case that need do following compare is under kPadSeqLenK
378  // ... for non-masking kernel.
379  // return (i_x < x_total) && ((i_x + TileWidth) > x_total);
380 
381  // TODO: no need to check begin
382  return (i_x + TileWidth) > x_total;
383  }
384  else
385  {
386  // check top-right corner > x or bottom-left corner < x
387  index_t i_x_end = i_x + TileWidth;
388  index_t i_y_end = i_y + TileHeight;
389  // index_t x_end = min(i_y + x, x_total);
390 
391  bool top_right_edge = i_x_end > min(i_y + x, x_total); // consider right pad
392  bool bottom_left_edge = i_y_end > min(i_x + y, y_total); // consider bottom pad
393  // bool is_partial_out_of_bound = i_x_end > x_end; // only consider right-pad for now
394 
395  return top_right_edge || bottom_left_edge;
396  }
397  }
398 
399  private:
400  index_t y, x;
401  index_t y_total, x_total;
402 };
403 
404 // clang-format off
405 namespace impl {
406  template <bool IsMasking_> struct SimplifiedRatioMaskName;
407  template<> struct SimplifiedRatioMaskName<false> { static constexpr const char * name = "nomask"; };
408  template<> struct SimplifiedRatioMaskName<true> { static constexpr const char * name = "mask"; };
409 }
410 // clang-format on
411 
412 // This version is used when the step length along the y-direction is greater than one,
413 // which means that the mask is not a regular triangular matrix.
414 
415 // clang-format off
416 /* y_ratio is used to describe the step length of y-direction changes
417  in certain performance optimization scenarios like merging seqlen
418  and qk_head_ratio, for example:
419 
420  x=1/y=6/y_ratio=2(top-left)
421  1 * * * * * * *
422  1 * * * * * * *
423  1 1 * * * * * *
424  1 1 * * * * * *
425  1 1 1 * * * * *
426  1 1 1 * * * * *
427 
428 */
429 // clang-format on
430 template <bool IsMasking_ = true>
432 {
433  static constexpr bool IsMasking = IsMasking_; // false will disable masking
434 
435  static constexpr const char* name = impl::SimplifiedRatioMaskName<IsMasking>::name;
436 
438  : SimplifiedRatioAttentionMask(0, 0, y_total_, x_total_, 0, 1, mdiv{})
439  {
440  }
441 
444  index_t y_real_, index_t x_, index_t y_total_, index_t x_total_, mdiv y_ratio_mdiv_)
445  : SimplifiedRatioAttentionMask(/*y_=*/y_real_ * static_cast<index_t>(y_ratio_mdiv_.get()),
446  /*x_=*/x_,
447  /*y_total_=*/y_total_,
448  /*x_total_=*/x_total_,
449  /*y_real_=*/y_real_,
450  /*y_ratio_=*/static_cast<index_t>(y_ratio_mdiv_.get()),
451  /*y_ratio_mdiv_=*/y_ratio_mdiv_)
452 
453  {
454  }
457  index_t x_,
458  index_t y_total_,
459  index_t x_total_,
460  index_t y_real_,
461  index_t y_ratio_,
462  mdiv y_ratio_mdiv_)
463  : y(y_),
464  x(x_),
465  y_total(y_total_),
466  x_total(x_total_),
467  y_real(y_real_),
468  y_ratio(y_ratio_),
469  y_ratio_mdiv(y_ratio_mdiv_)
470  {
471  }
472 
473  // to get the loop length along X axis, return index:[start, end), end-start=length
474  // use this if need loop over X axis tile by tile (like k-seqlen loopover)
475  // TODO: x_end still could be negative, so end-start could be negative(need check)
476  template <index_t YTile, index_t XTile>
477  CK_TILE_HOST_DEVICE constexpr auto
479  {
480  if constexpr(!IsMasking)
481  {
482  return ck_tile::make_tuple(0, x_total);
483  }
484  else
485  {
486  // get the tile start/end range, assuming we loop along X tile by tile
487  index_t x_start = [&]() {
488  index_t tmp = -y_real +
489  static_cast<index_t>(y_ratio_mdiv.div(static_cast<uint32_t>(i_y))) +
490  1;
491 
492  return (tmp / XTile) * XTile; // round to tile aligned
493  }();
494 
495  // TODO: end could be negative; we skip the clamp here and let the caller check
496  // ... in which case end-start is negative
497  index_t x_end = [&]() {
498  uint32_t y_offset = i_y + YTile - 1;
499  index_t tmp = min(static_cast<index_t>(y_ratio_mdiv.div(y_offset)) + x, x_total);
500  return ((tmp + XTile - 1) / XTile) * XTile;
501  }();
502 
503  return ck_tile::make_tuple(x_start, x_end);
504  }
505  }
506 
507  // to get the loop length along Y axis, return index:[start, end), end-start=length
508  // use this if need loop over Y axis tile by tile (like q-seqlen loopover)
509  // TODO: y_end still could be negative, so end-start could be negative(need check)
510  template <index_t YTile, index_t XTile>
511  CK_TILE_HOST_DEVICE constexpr auto
513  {
514  if constexpr(!IsMasking)
515  {
516  return ck_tile::make_tuple(0, y_total);
517  }
518  else
519  {
520  // get the tile start/end range, assuming we loop along Y tile by tile
521  index_t y_start = [&]() {
522  index_t tmp = max((-x + i_x + 1) * y_ratio, 0);
523  return (tmp / YTile) * YTile; // round to tile aligned
524  }();
525 
526  // TODO: end could be negative; we skip the clamp here and let the caller check
527  // ... in which case end-start is negative
528  index_t y_end = [&]() {
529  index_t tmp = min((i_x + XTile - 1) * y_ratio + y, y_total);
530  return ((tmp + YTile - 1) / YTile) * YTile;
531  }();
532 
533  return ck_tile::make_tuple(y_start, y_end);
534  }
535  }
536 
537  // per-pixel check if out-of-bound, if true, need mask a value(like -INF)
538  CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const
539  {
540  if constexpr(!IsMasking)
541  {
542  return i_x >= x_total;
543  }
544  else
545  {
546  index_t x_tmp = static_cast<index_t>(y_ratio_mdiv.div(static_cast<uint32_t>(i_y)));
547  index_t x_start = -y_real + x_tmp + 1;
548  index_t x_end = min(x_tmp + x,
549  x_total); // need min in case x is padded
550  return i_x < x_start || i_x >= x_end || i_y >= y_total;
551  }
552  }
553 
554  // if the current tile is at the edge, a per-pixel mask check is needed;
555  // otherwise there is no need to check per-pixel.
556  // Attention! assumes the index passed to this function is within the range of GetTileRangeAlongX/Y()
557  // can be used as a fast-path to decide whether to do the per-pixel check or not
558  template <index_t TileHeight, index_t TileWidth>
559  CK_TILE_HOST_DEVICE constexpr auto
561  {
562  if constexpr(!IsMasking)
563  {
564  // the only case that need do following compare is under kPadSeqLenK
565  // ... for non-masking kernel.
566  // return (i_x < x_total) && ((i_x + TileWidth) > x_total);
567 
568  return (i_x + TileWidth) > x_total;
569  }
570  else
571  {
572  // check top-right corner > x or bottom-left corner < x
573  index_t i_x_end = i_x + TileWidth;
574  index_t i_y_end = i_y + TileHeight;
575  // index_t x_end = min(i_y + x, x_total);
576  uint32_t y_tmp = static_cast<uint32_t>(i_y);
577  bool top_right_edge = i_x_end > min(static_cast<index_t>(y_ratio_mdiv.div(y_tmp)) + x,
578  x_total); // consider right pad
579  bool bottom_left_edge =
580  i_y_end > min(i_x * y_ratio + y, y_total); // consider bottom pad
581  return top_right_edge || bottom_left_edge;
582  }
583  }
584 
585  private:
586  index_t y, x;
587  index_t y_total, x_total;
588  // y_real is vertical axis before multiplying y_ratio. y_real * y_ratio = y
589  index_t y_real;
590  index_t y_ratio;
591  mdiv y_ratio_mdiv;
592 };
593 
594 // TODO: prefer to use this function in host code
595 // converts the FA-style left/right window sizes to our generic coordinates
596 // if left_size < 0 && right_size == 0, it is a normal causal mask
597 // local attention is left_size >= 0 or right_size >= 0
598 CK_TILE_HOST_DEVICE constexpr auto
600  index_t right_size,
601  index_t y_total,
602  index_t x_total,
603  bool is_top_left = true)
604 {
605  // TODO: below should all use sgpr arithmetic
606  index_t left_size_tmp = is_top_left ? y_total - 1 : x_total - 1;
607  index_t right_size_tmp = is_top_left ? x_total - 1 : y_total - 1;
608 
609  left_size = left_size < 0 ? left_size_tmp : left_size;
610  right_size = right_size < 0 ? right_size_tmp : right_size;
611 
612  index_t x_tmp = is_top_left ? 0 : x_total - y_total;
613  index_t y_tmp = is_top_left ? 0 : y_total - x_total;
614 
615  index_t x = 1 + right_size + x_tmp;
616  index_t y = 1 + left_size + y_tmp;
617 
618  return ck_tile::make_tuple(y, x, y_total, x_total);
619 }
620 
621 template <typename MaskType>
622 CK_TILE_HOST_DEVICE constexpr auto
624  index_t right_size,
625  index_t y_total,
626  index_t x_total,
627  bool is_top_left = true)
628 {
630  left_size, right_size, y_total, x_total, is_top_left);
631  return MaskType{r.at(number<0>{}), r.at(number<1>{}), y_total, x_total};
632 }
633 } // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_generic_attention_mask_from_lr_window(index_t left_size, index_t right_size, index_t y_total, index_t x_total, bool is_top_left=true)
Definition: block_masking.hpp:623
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
GenericAttentionMaskEnum
Definition: block_masking.hpp:11
constexpr CK_TILE_HOST_DEVICE T min(T x)
Definition: math.hpp:210
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
constexpr CK_TILE_HOST_DEVICE auto make_generic_attention_mask_coordinates_from_lr_window(index_t left_size, index_t right_size, index_t y_total, index_t x_total, bool is_top_left=true)
Definition: block_masking.hpp:599
unsigned int uint32_t
Definition: stdint.h:126
Definition: block_masking.hpp:81
CK_TILE_HOST_DEVICE GenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_)
Definition: block_masking.hpp:94
constexpr CK_TILE_HOST_DEVICE auto IsEdgeTile(index_t i_tile_top, index_t i_tile_left, number< TileHeight >, number< TileWidth >) const
Definition: block_masking.hpp:204
constexpr CK_TILE_HOST_DEVICE auto GetTileRangeAlongX(index_t i_y, number< YTile >, number< XTile >) const
Definition: block_masking.hpp:112
constexpr CK_TILE_HOST_DEVICE auto IsOutOfBound(index_t i_y, index_t i_x) const
Definition: block_masking.hpp:175
CK_TILE_HOST_DEVICE GenericAttentionMask(index_t y_total_, index_t x_total_)
Definition: block_masking.hpp:88
static constexpr const char * name
Definition: block_masking.hpp:86
constexpr CK_TILE_HOST_DEVICE auto GetTileRangeAlongY(index_t i_x, number< YTile >, number< XTile >) const
Definition: block_masking.hpp:149
CK_TILE_HOST_DEVICE GenericAttentionMask(const MaskCoordinates &mask_coord)
Definition: block_masking.hpp:99
static constexpr bool IsMasking
Definition: block_masking.hpp:82
static constexpr bool IsLocal
Definition: block_masking.hpp:83
Definition: block_masking.hpp:248
constexpr CK_TILE_HOST_DEVICE auto IsEdgeTile(index_t i_y, index_t i_x, number< TileHeight >, number< TileWidth >) const
Definition: block_masking.hpp:373
constexpr CK_TILE_HOST_DEVICE auto GetTileRangeAlongX(index_t i_y, number< YTile >, number< XTile >) const
Definition: block_masking.hpp:277
constexpr CK_TILE_HOST_DEVICE auto IsOutOfBound(index_t i_y, index_t i_x) const
Definition: block_masking.hpp:350
static constexpr const char * name
Definition: block_masking.hpp:251
constexpr CK_TILE_HOST_DEVICE auto GetTileRangeAlongY(index_t i_x, number< YTile >, number< XTile >) const
Definition: block_masking.hpp:324
constexpr CK_TILE_HOST_DEVICE auto GetTileRangeAlongX(index_t i_y, number< TileHeight > height, number< TileWidth > width, index_t num_splits, index_t i_split) const
Definition: block_masking.hpp:303
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(const MaskCoordinates &mask_coord)
Definition: block_masking.hpp:264
static constexpr bool IsMasking
Definition: block_masking.hpp:249
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(index_t y_total_, index_t x_total_)
Definition: block_masking.hpp:253
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_)
Definition: block_masking.hpp:259
Definition: block_masking.hpp:432
CK_TILE_HOST_DEVICE SimplifiedRatioAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_, index_t y_real_, index_t y_ratio_, mdiv y_ratio_mdiv_)
Definition: block_masking.hpp:456
static constexpr const char * name
Definition: block_masking.hpp:435
CK_TILE_HOST_DEVICE SimplifiedRatioAttentionMask(index_t y_real_, index_t x_, index_t y_total_, index_t x_total_, mdiv y_ratio_mdiv_)
Definition: block_masking.hpp:443
constexpr CK_TILE_HOST_DEVICE auto GetTileRangeAlongX(index_t i_y, number< YTile >, number< XTile >) const
Definition: block_masking.hpp:478
CK_TILE_HOST_DEVICE SimplifiedRatioAttentionMask(index_t y_total_, index_t x_total_)
Definition: block_masking.hpp:437
constexpr CK_TILE_HOST_DEVICE auto GetTileRangeAlongY(index_t i_x, number< YTile >, number< XTile >) const
Definition: block_masking.hpp:512
constexpr CK_TILE_HOST_DEVICE auto IsEdgeTile(index_t i_y, index_t i_x, number< TileHeight >, number< TileWidth >) const
Definition: block_masking.hpp:560
static constexpr bool IsMasking
Definition: block_masking.hpp:433
constexpr CK_TILE_HOST_DEVICE auto IsOutOfBound(index_t i_y, index_t i_x) const
Definition: block_masking.hpp:538
Definition: integral_constant.hpp:13
Definition: block_masking.hpp:71
Definition: block_masking.hpp:237
Definition: block_masking.hpp:406
Definition: magic_div.hpp:186
CK_TILE_HOST_DEVICE uint32_t div(uint32_t dividend_) const
Definition: magic_div.hpp:212