/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp Source File
flatmm_sn_32x128x512_1x4x1_16x16x32.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 namespace ck_tile {
11 
12 // "S"tream update output along "N"
13 // A in smem, B load from global
14 // require 4 wave, occupancy=1c
16 {
// Tile-shape constants for this flatmm configuration: block tile sizes,
// warps per block along each dimension, per-warp tile sizes, and the
// quantities derived from them. The number in each trailing comment is the
// value the expression evaluates to with the constants defined here.
17  static constexpr index_t Block_M = 32;
18  static constexpr index_t Block_N = 128;
19  static constexpr index_t Block_K = 512;
20 
21  static constexpr index_t WarpPerBlock_M = 1;
22  static constexpr index_t WarpPerBlock_N = 4;
23  static constexpr index_t WarpPerBlock_K = 1;
24 
25  static constexpr index_t Warp_M = 16;
26  static constexpr index_t Warp_N = 16;
27  static constexpr index_t Warp_K = 32;
28 
29  static constexpr index_t BlockSize = 256; // with 1 * 4 * 1 = 4 warps this is 64 threads per warp
30 
31  // static constexpr index_t KPack = 2; // this is used to guarantee every thread can do dwordx4
32 
33  // TODO: note Nr/Kr/W need to take KPack into account
34  static constexpr index_t Block_W = Warp_N * Warp_K; // 16 * 32 = 512 elements
35  static constexpr index_t Block_Nr = Block_N / Warp_N; // 128 / 16 = 8 (2 per warp with WarpPerBlock_N = 4)
36  static constexpr index_t Block_Kr = Block_K / Warp_K; // 512 / 32 = 16
37 
38  static constexpr index_t Repeat_M = Block_M / (Warp_M * WarpPerBlock_M); // 2
39  static constexpr index_t Repeat_N = Block_N / (Warp_N * WarpPerBlock_N); // 2
40  static constexpr index_t Repeat_K = Block_K / (Warp_K * WarpPerBlock_K); // 16
41 
// Builds and returns the block-level static tile distribution for the C tile:
// an outer (block-level) tile_distribution_encoding is combined with the
// warp-level encoding WG::CWarpDstrEncoding via
// detail::make_embed_tile_distribution_encoding, and the result is passed to
// make_static_tile_distribution.
// NOTE(review): this listing omits source lines 46-48 and 52 (the numbering
// jumps 45 -> 49 and 51 -> 53), so the remaining template arguments of the
// outer encoding and the definition of `WG` are not visible here.
42  static CK_TILE_DEVICE constexpr auto MakeCBlockDist()
43  {
44  constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
45  sequence<>,
49  sequence<2, 1>, // !! note: this differs (from what is not stated here - presumably another flatmm configuration; confirm)
50  sequence<0, 0>>{};
51 
53 
    // Embed the warp-level C encoding into the block-level outer encoding.
54  constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
55  c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
56  constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
57  return c_block_dstr;
58  }
59 
61  {
    // Body of GetSmemSize(); its signature (source line 60) is missing from
    // this listing. Returns the size in bytes of the shuffle buffer whose
    // layout is sketched below, multiplied by nbufs.
62  // y y p p p y
63  // reg before shfl M0(2)*N0(2)*Nl(4)*Nw(4)*Mw(16)*Nv(4)
64  // but order is N0*M0*Nv
65  // in LDS we need to store as
66  // M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4)
67  // y y wave-id lid/16 lid%16 v
    // (the "+ 4" per 64-element row is presumably padding - confirm)
68  constexpr index_t nbufs = 2;
    // 2*2*4*4 = 64 rows of (16*4 + 4) = 68 elements -> 4352 elements per
    // buffer; with nbufs = 2 that is 17408 bytes if sizeof(bf16_t) == 2
    // (presumably - confirm).
69  return 2 * 2 * 4 * 4 * (16 * 4 + 4) * sizeof(bf16_t) * nbufs;
70  }
71 };
72 
74 {
// Element types of the B input and the O output. In the visible code they are
// only used through sizeof(...) to turn element offsets into byte offsets.
75  using BDataType = bf16_t;
76  using ODataType = bf16_t;
77 
78  // TODO: needs to be paired with tile_window_linear!
79  // TODO: init_raw() needs to be called before calling this function!
80  // template <typename AWindow, typename BWindow, typename OWindow, typename ScaleTensor>
// Runs the flatmm micro-kernel variant selected by CK_TILE_FLATMM_UK_MFMA_BF16
// through a single inline-asm statement.
//
// Parameters (as used by the visible code):
//   res_b            indexed [0..3]; each entry is passed as an SGPR ("s") operand
//   cached_coords_b  exactly 8 entries (static_assert); each is multiplied by
//                    sizeof(BDataType) and passed in a VGPR ("v")
//   res_o            indexed [0..1]; each entry is passed as an SGPR operand
//   cached_coords_o  exactly 8 entries (static_assert); each is multiplied by
//                    sizeof(ODataType) and passed in a VGPR
//   o_flags          entries 0..7 are passed as SGPR operands
//   smem             LDS pointer, bound as a read-write ("+r") operand
//   n                extent along N; the loop count handed to the asm is
//                    n / Block_N (integer division, a remainder is dropped)
//   scale_           exactly 2 entries (static_assert), read as float
//   tile_offset_b/o  element offsets, multiplied by sizeof(B/ODataType)
//
// No return statement is present, so the `auto` return type deduces to void.
// NOTE(review): the asm template itself (source line 176) is missing from
// this listing, so what the kernel loads, computes and stores cannot be
// confirmed from here.
81  template <typename BRes,
82  typename BCoords,
83  typename ORes,
84  typename OCoords,
85  typename OFlags,
86  typename ScaleTensor>
87  CK_TILE_DEVICE auto
88  operator()(const BRes& res_b,
89  const BCoords& cached_coords_b,
90  const ORes& res_o,
91  const OCoords& cached_coords_o,
92  const OFlags& o_flags, // this should be in sgpr
93  CK_TILE_LDS_ADDR void* smem,
94  index_t n, // loop along n dim
95  const ScaleTensor& scale_,
96  index_t tile_offset_b, // stride b is fixed to blockKr * blockW, but still can adjust
97  index_t tile_offset_o)
98  {
99  static_assert(BCoords::size() == 8); // 8
100  static_assert(OCoords::size() == 8);
101 
    // Per-tile strides converted from elements to bytes.
102  const index_t tile_stride_b_bytes = tile_offset_b * sizeof(BDataType);
103  const index_t tile_stride_o_bytes = tile_offset_o * sizeof(ODataType);
104 
105  static_assert(ScaleTensor::size() == 2);
106  float s0 = scale_[number<0>{}];
107  float s1 = scale_[number<1>{}];
108 
    // Number of Block_N-wide steps along N; n % Block_N is not accounted for.
109  index_t loop_cnt = n / Block_N;
110 
    // 32 float values pinned to v64..v95. They are declared without
    // initializers and bound below as read-write ("+v") operands.
    // NOTE(review): presumably the asm body zeroes them before accumulating
    // - confirm against the asm template.
111  register float v_c0 asm("v64");
112  register float v_c1 asm("v65");
113  register float v_c2 asm("v66");
114  register float v_c3 asm("v67");
115  register float v_c4 asm("v68");
116  register float v_c5 asm("v69");
117  register float v_c6 asm("v70");
118  register float v_c7 asm("v71");
119  register float v_c8 asm("v72");
120  register float v_c9 asm("v73");
121  register float v_c10 asm("v74");
122  register float v_c11 asm("v75");
123  register float v_c12 asm("v76");
124  register float v_c13 asm("v77");
125  register float v_c14 asm("v78");
126  register float v_c15 asm("v79");
127  register float v_c16 asm("v80");
128  register float v_c17 asm("v81");
129  register float v_c18 asm("v82");
130  register float v_c19 asm("v83");
131  register float v_c20 asm("v84");
132  register float v_c21 asm("v85");
133  register float v_c22 asm("v86");
134  register float v_c23 asm("v87");
135  register float v_c24 asm("v88");
136  register float v_c25 asm("v89");
137  register float v_c26 asm("v90");
138  register float v_c27 asm("v91");
139  register float v_c28 asm("v92");
140  register float v_c29 asm("v93");
141  register float v_c30 asm("v94");
142  register float v_c31 asm("v95");
    // Bit patterns handed to the asm as v_nan_hi / v_nan_lo; presumably the
    // bf16 NaN pattern for the high and low half of a packed dword - confirm.
143  int32_t nan_hi = 0x7fff0000;
144  int32_t nan_lo = 0x00007fff;
145 
146  // in smem, the layout is M0(2)*K0(128)*M1(16)*K1(4)
147  // every thread needs 8xK in contiguous registers
148  // ... and every wave needs the same data
149  int lane_id = threadIdx.x % 64;
150  int sld_y_os = (lane_id % 16) * 4 + (lane_id / 16) * 128;
151  sld_y_os *= 2; // presumably element offset -> byte offset for 2-byte elements - confirm
152 
153  // y y p p p y
154  // reg before shfl M0(2)*N0(2)*Nl(4)*Nw(4)*Mw(16)*Nv(4)
155  // but order is N0*M0*Nv
156  // in LDS we need to store as
157  // M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4)
158  // y y wave-id lid/16 lid%16 v
159  // sst(v3) = (v0/16*34 + v0%16 * 2 + wid*136) * 4
    // Uses the block-wide threadIdx.x (not lane_id): each group of 16
    // threads starts a new row of 64 + 4 elements.
160  int sfl_sst = (threadIdx.x % 16 * 4) + (threadIdx.x / 16) * (64 + 4);
161  sfl_sst *= 2; // presumably element offset -> byte offset - confirm
162 
163  // from LDS we need to load as
164  // M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16) * Nv(4) + 4)
165  // ( 2 issue) (rem 32-lane) (4 wave*4issue) 2lane*1issue(pk2)
166  // sld(v4) = v0/2 *34*4 + v0 % 2 *4 + wid*2 *4
167  int sfl_sld = (lane_id % 2) * 2 + (lane_id / 2) * (64 + 4) + (threadIdx.x / 64) * 4;
168  sfl_sld *= 2; // presumably element offset -> byte offset - confirm
169 
170  // B nr->kr
171  // clang-format off
172 #pragma clang diagnostic push
173 #pragma clang diagnostic ignored "-Winline-asm"
174  asm volatile(
175 #define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
    // NOTE(review): source line 176 (the asm template, presumably pulled in
    // by an #include) is absent from this listing.
    // Output (read-write) operands: LDS pointer, loop counter, v64..v95 values.
177  :[smem_]"+r"(smem),
178  [s_loop_cnt]"+s"(loop_cnt),
179  [c0]"+v" (v_c0),
180  [c1]"+v" (v_c1),
181  [c2]"+v" (v_c2),
182  [c3]"+v" (v_c3),
183  [c4]"+v" (v_c4),
184  [c5]"+v" (v_c5),
185  [c6]"+v" (v_c6),
186  [c7]"+v" (v_c7),
187  [c8]"+v" (v_c8),
188  [c9]"+v" (v_c9),
189  [c10]"+v"(v_c10),
190  [c11]"+v"(v_c11),
191  [c12]"+v"(v_c12),
192  [c13]"+v"(v_c13),
193  [c14]"+v"(v_c14),
194  [c15]"+v"(v_c15),
195  [c16]"+v"(v_c16),
196  [c17]"+v"(v_c17),
197  [c18]"+v"(v_c18),
198  [c19]"+v"(v_c19),
199  [c20]"+v"(v_c20),
200  [c21]"+v"(v_c21),
201  [c22]"+v"(v_c22),
202  [c23]"+v"(v_c23),
203  [c24]"+v"(v_c24),
204  [c25]"+v"(v_c25),
205  [c26]"+v"(v_c26),
206  [c27]"+v"(v_c27),
207  [c28]"+v"(v_c28),
208  [c29]"+v"(v_c29),
209  [c30]"+v"(v_c30),
210  [c31]"+v"(v_c31)
    // Input operands: "n" = immediate, "v" = VGPR, "s" = SGPR.
211  :
212  [sld_a_base]"n"(0),
213  [shfl_base]"n"(0),
214  [v_sld_y_os]"v"(sld_y_os),
215  [v_sfl_sld]"v"(sfl_sld),
216  [v_sfl_sst]"v"(sfl_sst),
217  [s_res_o0]"s"(res_o[0]),
218  [s_res_o1]"s"(res_o[1]),
219  //[s_res_o2]"s"(res_o[2]),
220  //[s_res_o3]"s"(res_o[3]),
221  [s_res_b0]"s"(res_b[0]),
222  [s_res_b1]"s"(res_b[1]),
223  [s_res_b2]"s"(res_b[2]),
224  [s_res_b3]"s"(res_b[3]),
225  [v_os_o0]"v"(static_cast<index_t>(cached_coords_o[number<0>{}] * sizeof(ODataType))),
226  [v_os_o1]"v"(static_cast<index_t>(cached_coords_o[number<1>{}] * sizeof(ODataType))),
227  [v_os_o2]"v"(static_cast<index_t>(cached_coords_o[number<2>{}] * sizeof(ODataType))),
228  [v_os_o3]"v"(static_cast<index_t>(cached_coords_o[number<3>{}] * sizeof(ODataType))),
229  [v_os_o4]"v"(static_cast<index_t>(cached_coords_o[number<4>{}] * sizeof(ODataType))),
230  [v_os_o5]"v"(static_cast<index_t>(cached_coords_o[number<5>{}] * sizeof(ODataType))),
231  [v_os_o6]"v"(static_cast<index_t>(cached_coords_o[number<6>{}] * sizeof(ODataType))),
232  [v_os_o7]"v"(static_cast<index_t>(cached_coords_o[number<7>{}] * sizeof(ODataType))),
233  [v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))),
234  [v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
235  [v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
236  [v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
237  [v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
238  [v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
239  [v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
240  [v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
241 
242  [s_tile_os_o]"s"(tile_stride_o_bytes),
243  [s_tile_os_b]"s"(tile_stride_b_bytes),
244  [scale_0]"v"(s0),
245  [scale_1]"v"(s1),
246  [v_nan_lo]"v"(nan_lo),
247  [v_nan_hi]"v"(nan_hi),
248  [s_execflag_0]"s"(o_flags[number<0>{}]),
249  [s_execflag_1]"s"(o_flags[number<1>{}]),
250  [s_execflag_2]"s"(o_flags[number<2>{}]),
251  [s_execflag_3]"s"(o_flags[number<3>{}]),
252  [s_execflag_4]"s"(o_flags[number<4>{}]),
253  [s_execflag_5]"s"(o_flags[number<5>{}]),
254  [s_execflag_6]"s"(o_flags[number<6>{}]),
255  [s_execflag_7]"s"(o_flags[number<7>{}])
    // Clobbers: memory, a0..a255, a fixed set of SGPRs, v50/v54/v55,
    // v64..v95 and v128..v255.
    // NOTE(review): v64..v95 are listed as clobbers although they are also
    // the pinned "+v" operands above.
256  :
257  "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
258  "a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
259  "a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
260  "a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
261  "a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49",
262  "a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59",
263  "a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69",
264  "a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79",
265  "a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89",
266  "a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99",
267  "a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107",
268  "a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115",
269  "a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123",
270  "a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131",
271  "a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139",
272  "a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147",
273  "a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155",
274  "a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163",
275  "a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171",
276  "a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179",
277  "a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187",
278  "a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195",
279  "a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203",
280  "a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211",
281  "a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219",
282  "a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227",
283  "a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235",
284  "a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
285  "a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
286  "a252", "a253", "a254", "a255",
287  "s8", "s9", "s12", "s13", "s14", "s15", "s38", "s39", "s52", "s86",
288  "s36", "s37",
289  "v50", "v54", "v55",
290  "v64","v65","v66","v67","v68","v69","v70","v71",
291  "v72","v73","v74","v75","v76","v77","v78","v79",
292  "v80","v81","v82","v83","v84","v85","v86","v87",
293  "v88","v89","v90","v91","v92","v93","v94","v95",
294  "v128", "v129", "v130", "v131",
295  "v132", "v133", "v134", "v135", "v136", "v137", "v138", "v139",
296  "v140", "v141", "v142", "v143", "v144", "v145", "v146", "v147",
297  "v148", "v149", "v150", "v151", "v152", "v153", "v154", "v155",
298  "v156", "v157", "v158", "v159", "v160", "v161", "v162", "v163",
299  "v164", "v165", "v166", "v167", "v168", "v169", "v170", "v171",
300  "v172", "v173", "v174", "v175", "v176", "v177", "v178", "v179",
301  "v180", "v181", "v182", "v183", "v184", "v185", "v186", "v187",
302  "v188", "v189", "v190", "v191", "v192", "v193", "v194", "v195",
303  "v196", "v197", "v198", "v199", "v200", "v201", "v202", "v203",
304  "v204", "v205", "v206", "v207", "v208", "v209", "v210", "v211",
305  "v212", "v213", "v214", "v215", "v216", "v217", "v218", "v219",
306  "v220", "v221", "v222", "v223", "v224", "v225", "v226", "v227",
307  "v228", "v229", "v230", "v231", "v232", "v233", "v234", "v235",
308  "v236", "v237", "v238", "v239", "v240", "v241", "v242", "v243",
309  "v244", "v245", "v246", "v247", "v248", "v249", "v250", "v251",
310  "v252", "v253", "v254", "v255"
311  );
312 #pragma clang diagnostic pop
313  // clang-format on
314  }
315 };
316 
318 {
// Element types of the B input and the O output. In the visible code they are
// only used through sizeof(...) to turn element offsets into byte offsets.
// NOTE(review): both aliases are bf16_t although the operator() that follows
// selects CK_TILE_FLATMM_UK_MFMA_FP16; confirm whether a 16-bit half type was
// intended (the byte-offset arithmetic is the same for any 2-byte type).
319  using BDataType = bf16_t;
320  using ODataType = bf16_t;
321 
322  // TODO: needs to be paired with tile_window_linear!
323  // TODO: init_raw() needs to be called before calling this function!
324  // template <typename AWindow, typename BWindow, typename OWindow, typename ScaleTensor>
// Runs the flatmm micro-kernel variant selected by CK_TILE_FLATMM_UK_MFMA_FP16
// through a single inline-asm statement.
//
// Parameters (as used by the visible code):
//   res_b            indexed [0..3]; each entry is passed as an SGPR ("s") operand
//   cached_coords_b  exactly 8 entries (static_assert); each is multiplied by
//                    sizeof(BDataType) and passed in a VGPR ("v")
//   res_o            indexed [0..1]; each entry is passed as an SGPR operand
//   cached_coords_o  exactly 8 entries (static_assert); each is multiplied by
//                    sizeof(ODataType) and passed in a VGPR
//   o_flags          entries 0..7 are passed as SGPR operands
//   smem             LDS pointer, bound as a read-write ("+r") operand
//   n                extent along N; the loop count handed to the asm is
//                    n / Block_N (integer division, a remainder is dropped)
//   scale_           exactly 2 entries (static_assert), read as float
//   tile_offset_b/o  element offsets, multiplied by sizeof(B/ODataType)
//
// No return statement is present, so the `auto` return type deduces to void.
// NOTE(review): the asm template itself (source line 420) is missing from
// this listing, so what the kernel loads, computes and stores cannot be
// confirmed from here.
325  template <typename BRes,
326  typename BCoords,
327  typename ORes,
328  typename OCoords,
329  typename OFlags,
330  typename ScaleTensor>
331  CK_TILE_DEVICE auto
332  operator()(const BRes& res_b,
333  const BCoords& cached_coords_b,
334  const ORes& res_o,
335  const OCoords& cached_coords_o,
336  const OFlags& o_flags, // this should be in sgpr
337  CK_TILE_LDS_ADDR void* smem,
338  index_t n, // loop along n dim
339  const ScaleTensor& scale_,
340  index_t tile_offset_b, // stride b is fixed to blockKr * blockW, but still can adjust
341  index_t tile_offset_o)
342  {
343  static_assert(BCoords::size() == 8); // 8
344  static_assert(OCoords::size() == 8);
345 
    // Per-tile strides converted from elements to bytes.
346  const index_t tile_stride_b_bytes = tile_offset_b * sizeof(BDataType);
347  const index_t tile_stride_o_bytes = tile_offset_o * sizeof(ODataType);
348 
349  static_assert(ScaleTensor::size() == 2);
350  float s0 = scale_[number<0>{}];
351  float s1 = scale_[number<1>{}];
352 
    // Number of Block_N-wide steps along N; n % Block_N is not accounted for.
353  index_t loop_cnt = n / Block_N;
354 
    // 32 float values pinned to v64..v95. They are declared without
    // initializers and bound below as read-write ("+v") operands.
    // NOTE(review): presumably the asm body zeroes them before accumulating
    // - confirm against the asm template.
355  register float v_c0 asm("v64");
356  register float v_c1 asm("v65");
357  register float v_c2 asm("v66");
358  register float v_c3 asm("v67");
359  register float v_c4 asm("v68");
360  register float v_c5 asm("v69");
361  register float v_c6 asm("v70");
362  register float v_c7 asm("v71");
363  register float v_c8 asm("v72");
364  register float v_c9 asm("v73");
365  register float v_c10 asm("v74");
366  register float v_c11 asm("v75");
367  register float v_c12 asm("v76");
368  register float v_c13 asm("v77");
369  register float v_c14 asm("v78");
370  register float v_c15 asm("v79");
371  register float v_c16 asm("v80");
372  register float v_c17 asm("v81");
373  register float v_c18 asm("v82");
374  register float v_c19 asm("v83");
375  register float v_c20 asm("v84");
376  register float v_c21 asm("v85");
377  register float v_c22 asm("v86");
378  register float v_c23 asm("v87");
379  register float v_c24 asm("v88");
380  register float v_c25 asm("v89");
381  register float v_c26 asm("v90");
382  register float v_c27 asm("v91");
383  register float v_c28 asm("v92");
384  register float v_c29 asm("v93");
385  register float v_c30 asm("v94");
386  register float v_c31 asm("v95");
    // Bit patterns handed to the asm as v_nan_hi / v_nan_lo.
    // NOTE(review): these are the same constants as in the BF16 kernel;
    // whether the FP16 asm path reads them is not visible here.
387  int32_t nan_hi = 0x7fff0000;
388  int32_t nan_lo = 0x00007fff;
389 
390  // in smem, the layout is M0(2)*K0(128)*M1(16)*K1(4)
391  // every thread needs 8xK in contiguous registers
392  // ... and every wave needs the same data
393  int lane_id = threadIdx.x % 64;
394  int sld_y_os = (lane_id % 16) * 4 + (lane_id / 16) * 128;
395  sld_y_os *= 2; // presumably element offset -> byte offset for 2-byte elements - confirm
396 
397  // y y p p p y
398  // reg before shfl M0(2)*N0(2)*Nl(4)*Nw(4)*Mw(16)*Nv(4)
399  // but order is N0*M0*Nv
400  // in LDS we need to store as
401  // M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4)
402  // y y wave-id lid/16 lid%16 v
403  // sst(v3) = (v0/16*34 + v0%16 * 2 + wid*136) * 4
    // Uses the block-wide threadIdx.x (not lane_id): each group of 16
    // threads starts a new row of 64 + 4 elements.
404  int sfl_sst = (threadIdx.x % 16 * 4) + (threadIdx.x / 16) * (64 + 4);
405  sfl_sst *= 2; // presumably element offset -> byte offset - confirm
406 
407  // from LDS we need to load as
408  // M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16) * Nv(4) + 4)
409  // ( 2 issue) (rem 32-lane) (4 wave*4issue) 2lane*1issue(pk2)
410  // sld(v4) = v0/2 *34*4 + v0 % 2 *4 + wid*2 *4
411  int sfl_sld = (lane_id % 2) * 2 + (lane_id / 2) * (64 + 4) + (threadIdx.x / 64) * 4;
412  sfl_sld *= 2; // presumably element offset -> byte offset - confirm
413 
414  // B nr->kr
415  // clang-format off
416 #pragma clang diagnostic push
417 #pragma clang diagnostic ignored "-Winline-asm"
418  asm volatile(
419 #define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
    // NOTE(review): source line 420 (the asm template, presumably pulled in
    // by an #include) is absent from this listing. CK_TILE_FLATMM_UK_MFMA is
    // also #defined by the BF16 kernel earlier in this file and no #undef is
    // visible; presumably the omitted include undefines it - confirm.
    // Output (read-write) operands: LDS pointer, loop counter, v64..v95 values.
421  :[smem_]"+r"(smem),
422  [s_loop_cnt]"+s"(loop_cnt),
423  [c0]"+v" (v_c0),
424  [c1]"+v" (v_c1),
425  [c2]"+v" (v_c2),
426  [c3]"+v" (v_c3),
427  [c4]"+v" (v_c4),
428  [c5]"+v" (v_c5),
429  [c6]"+v" (v_c6),
430  [c7]"+v" (v_c7),
431  [c8]"+v" (v_c8),
432  [c9]"+v" (v_c9),
433  [c10]"+v"(v_c10),
434  [c11]"+v"(v_c11),
435  [c12]"+v"(v_c12),
436  [c13]"+v"(v_c13),
437  [c14]"+v"(v_c14),
438  [c15]"+v"(v_c15),
439  [c16]"+v"(v_c16),
440  [c17]"+v"(v_c17),
441  [c18]"+v"(v_c18),
442  [c19]"+v"(v_c19),
443  [c20]"+v"(v_c20),
444  [c21]"+v"(v_c21),
445  [c22]"+v"(v_c22),
446  [c23]"+v"(v_c23),
447  [c24]"+v"(v_c24),
448  [c25]"+v"(v_c25),
449  [c26]"+v"(v_c26),
450  [c27]"+v"(v_c27),
451  [c28]"+v"(v_c28),
452  [c29]"+v"(v_c29),
453  [c30]"+v"(v_c30),
454  [c31]"+v"(v_c31)
    // Input operands: "n" = immediate, "v" = VGPR, "s" = SGPR.
455  :
456  [sld_a_base]"n"(0),
457  [shfl_base]"n"(0),
458  [v_sld_y_os]"v"(sld_y_os),
459  [v_sfl_sld]"v"(sfl_sld),
460  [v_sfl_sst]"v"(sfl_sst),
461  [s_res_o0]"s"(res_o[0]),
462  [s_res_o1]"s"(res_o[1]),
463  //[s_res_o2]"s"(res_o[2]),
464  //[s_res_o3]"s"(res_o[3]),
465  [s_res_b0]"s"(res_b[0]),
466  [s_res_b1]"s"(res_b[1]),
467  [s_res_b2]"s"(res_b[2]),
468  [s_res_b3]"s"(res_b[3]),
469  [v_os_o0]"v"(static_cast<index_t>(cached_coords_o[number<0>{}] * sizeof(ODataType))),
470  [v_os_o1]"v"(static_cast<index_t>(cached_coords_o[number<1>{}] * sizeof(ODataType))),
471  [v_os_o2]"v"(static_cast<index_t>(cached_coords_o[number<2>{}] * sizeof(ODataType))),
472  [v_os_o3]"v"(static_cast<index_t>(cached_coords_o[number<3>{}] * sizeof(ODataType))),
473  [v_os_o4]"v"(static_cast<index_t>(cached_coords_o[number<4>{}] * sizeof(ODataType))),
474  [v_os_o5]"v"(static_cast<index_t>(cached_coords_o[number<5>{}] * sizeof(ODataType))),
475  [v_os_o6]"v"(static_cast<index_t>(cached_coords_o[number<6>{}] * sizeof(ODataType))),
476  [v_os_o7]"v"(static_cast<index_t>(cached_coords_o[number<7>{}] * sizeof(ODataType))),
477  [v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))),
478  [v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
479  [v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
480  [v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
481  [v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
482  [v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
483  [v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
484  [v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
485 
486  [s_tile_os_o]"s"(tile_stride_o_bytes),
487  [s_tile_os_b]"s"(tile_stride_b_bytes),
488  [scale_0]"v"(s0),
489  [scale_1]"v"(s1),
490  [v_nan_lo]"v"(nan_lo),
491  [v_nan_hi]"v"(nan_hi),
492  [s_execflag_0]"s"(o_flags[number<0>{}]),
493  [s_execflag_1]"s"(o_flags[number<1>{}]),
494  [s_execflag_2]"s"(o_flags[number<2>{}]),
495  [s_execflag_3]"s"(o_flags[number<3>{}]),
496  [s_execflag_4]"s"(o_flags[number<4>{}]),
497  [s_execflag_5]"s"(o_flags[number<5>{}]),
498  [s_execflag_6]"s"(o_flags[number<6>{}]),
499  [s_execflag_7]"s"(o_flags[number<7>{}])
    // Clobbers: memory, a0..a255, a fixed set of SGPRs, v50/v54/v55,
    // v64..v95 and v128..v255.
    // NOTE(review): v64..v95 are listed as clobbers although they are also
    // the pinned "+v" operands above.
500  :
501  "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
502  "a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
503  "a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
504  "a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
505  "a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49",
506  "a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59",
507  "a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69",
508  "a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79",
509  "a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89",
510  "a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99",
511  "a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107",
512  "a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115",
513  "a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123",
514  "a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131",
515  "a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139",
516  "a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147",
517  "a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155",
518  "a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163",
519  "a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171",
520  "a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179",
521  "a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187",
522  "a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195",
523  "a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203",
524  "a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211",
525  "a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219",
526  "a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227",
527  "a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235",
528  "a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
529  "a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
530  "a252", "a253", "a254", "a255",
531  "s8", "s9", "s12", "s13", "s14", "s15", "s38", "s39", "s52", "s86",
532  "s36", "s37",
533  "v50", "v54", "v55",
534  "v64","v65","v66","v67","v68","v69","v70","v71",
535  "v72","v73","v74","v75","v76","v77","v78","v79",
536  "v80","v81","v82","v83","v84","v85","v86","v87",
537  "v88","v89","v90","v91","v92","v93","v94","v95",
538  "v128", "v129", "v130", "v131",
539  "v132", "v133", "v134", "v135", "v136", "v137", "v138", "v139",
540  "v140", "v141", "v142", "v143", "v144", "v145", "v146", "v147",
541  "v148", "v149", "v150", "v151", "v152", "v153", "v154", "v155",
542  "v156", "v157", "v158", "v159", "v160", "v161", "v162", "v163",
543  "v164", "v165", "v166", "v167", "v168", "v169", "v170", "v171",
544  "v172", "v173", "v174", "v175", "v176", "v177", "v178", "v179",
545  "v180", "v181", "v182", "v183", "v184", "v185", "v186", "v187",
546  "v188", "v189", "v190", "v191", "v192", "v193", "v194", "v195",
547  "v196", "v197", "v198", "v199", "v200", "v201", "v202", "v203",
548  "v204", "v205", "v206", "v207", "v208", "v209", "v210", "v211",
549  "v212", "v213", "v214", "v215", "v216", "v217", "v218", "v219",
550  "v220", "v221", "v222", "v223", "v224", "v225", "v226", "v227",
551  "v228", "v229", "v230", "v231", "v232", "v233", "v234", "v235",
552  "v236", "v237", "v238", "v239", "v240", "v241", "v242", "v243",
553  "v244", "v245", "v246", "v247", "v248", "v249", "v250", "v251",
554  "v252", "v253", "v254", "v255"
555  );
556 #pragma clang diagnostic pop
557  // clang-format on
558  }
559 };
560 
561 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_LDS_ADDR
Definition: config.hpp:58
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
constexpr CK_TILE_HOST_DEVICE auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition: tile_distribution_encoding.hpp:457
Definition: cluster_descriptor.hpp:13
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
int32_t index_t
Definition: integer.hpp:9
int32_t int32_t
Definition: integer.hpp:10
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:480
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:74
bf16_t BDataType
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:75
CK_TILE_DEVICE auto operator()(const BRes &res_b, const BCoords &cached_coords_b, const ORes &res_o, const OCoords &cached_coords_o, const OFlags &o_flags, CK_TILE_LDS_ADDR void *smem, index_t n, const ScaleTensor &scale_, index_t tile_offset_b, index_t tile_offset_o)
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:88
bf16_t ODataType
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:76
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:16
static constexpr index_t WarpPerBlock_M
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:21
static constexpr index_t WarpPerBlock_N
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:22
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:60
static constexpr index_t Block_N
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:18
static constexpr index_t Warp_K
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:27
static constexpr index_t Repeat_M
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:38
static constexpr index_t Block_Nr
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:35
static constexpr index_t WarpPerBlock_K
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:23
static constexpr CK_TILE_DEVICE auto MakeCBlockDist()
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:42
static constexpr index_t BlockSize
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:29
static constexpr index_t Block_K
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:19
static constexpr index_t Block_W
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:34
static constexpr index_t Repeat_K
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:40
static constexpr index_t Block_Kr
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:36
static constexpr index_t Warp_N
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:26
static constexpr index_t Warp_M
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:25
static constexpr index_t Repeat_N
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:39
static constexpr index_t Block_M
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:17
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:318
CK_TILE_DEVICE auto operator()(const BRes &res_b, const BCoords &cached_coords_b, const ORes &res_o, const OCoords &cached_coords_o, const OFlags &o_flags, CK_TILE_LDS_ADDR void *smem, index_t n, const ScaleTensor &scale_, index_t tile_offset_b, index_t tile_offset_o)
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:332
bf16_t BDataType
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:319
bf16_t ODataType
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:320
Definition: warp_gemm_impl.hpp:11
Definition: integral_constant.hpp:13
Definition: sequence.hpp:49
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192