55 c_block_outer_dstr_encoding,
typename WG::CWarpDstrEncoding{});
69 return 2 * 2 * 4 * 4 * (16 * 4 + 4) *
sizeof(
bf16_t) * nbufs;
81 template <
typename BRes,
89 const BCoords& cached_coords_b,
91 const OCoords& cached_coords_o,
92 const OFlags& o_flags,
95 const ScaleTensor& scale_,
99 static_assert(BCoords::size() == 8);
100 static_assert(OCoords::size() == 8);
105 static_assert(ScaleTensor::size() == 2);
111 register float v_c0
asm(
"v64");
112 register float v_c1
asm(
"v65");
113 register float v_c2
asm(
"v66");
114 register float v_c3
asm(
"v67");
115 register float v_c4
asm(
"v68");
116 register float v_c5
asm(
"v69");
117 register float v_c6
asm(
"v70");
118 register float v_c7
asm(
"v71");
119 register float v_c8
asm(
"v72");
120 register float v_c9
asm(
"v73");
121 register float v_c10
asm(
"v74");
122 register float v_c11
asm(
"v75");
123 register float v_c12
asm(
"v76");
124 register float v_c13
asm(
"v77");
125 register float v_c14
asm(
"v78");
126 register float v_c15
asm(
"v79");
127 register float v_c16
asm(
"v80");
128 register float v_c17
asm(
"v81");
129 register float v_c18
asm(
"v82");
130 register float v_c19
asm(
"v83");
131 register float v_c20
asm(
"v84");
132 register float v_c21
asm(
"v85");
133 register float v_c22
asm(
"v86");
134 register float v_c23
asm(
"v87");
135 register float v_c24
asm(
"v88");
136 register float v_c25
asm(
"v89");
137 register float v_c26
asm(
"v90");
138 register float v_c27
asm(
"v91");
139 register float v_c28
asm(
"v92");
140 register float v_c29
asm(
"v93");
141 register float v_c30
asm(
"v94");
142 register float v_c31
asm(
"v95");
149 int lane_id = threadIdx.x % 64;
150 int sld_y_os = (lane_id % 16) * 4 + (lane_id / 16) * 128;
160 int sfl_sst = (threadIdx.x % 16 * 4) + (threadIdx.x / 16) * (64 + 4);
167 int sfl_sld = (lane_id % 2) * 2 + (lane_id / 2) * (64 + 4) + (threadIdx.x / 64) * 4;
172 #pragma clang diagnostic push
173 #pragma clang diagnostic ignored "-Winline-asm"
175 #define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
178 [s_loop_cnt]
"+s"(loop_cnt),
214 [v_sld_y_os]
"v"(sld_y_os),
215 [v_sfl_sld]
"v"(sfl_sld),
216 [v_sfl_sst]
"v"(sfl_sst),
217 [s_res_o0]
"s"(res_o[0]),
218 [s_res_o1]
"s"(res_o[1]),
221 [s_res_b0]
"s"(res_b[0]),
222 [s_res_b1]
"s"(res_b[1]),
223 [s_res_b2]
"s"(res_b[2]),
224 [s_res_b3]
"s"(res_b[3]),
242 [s_tile_os_o]
"s"(tile_stride_o_bytes),
243 [s_tile_os_b]
"s"(tile_stride_b_bytes),
246 [v_nan_lo]
"v"(nan_lo),
247 [v_nan_hi]
"v"(nan_hi),
257 "memory",
"a0",
"a1",
"a2",
"a3",
"a4",
"a5",
"a6",
"a7",
"a8",
"a9",
258 "a10",
"a11",
"a12",
"a13",
"a14",
"a15",
"a16",
"a17",
"a18",
"a19",
259 "a20",
"a21",
"a22",
"a23",
"a24",
"a25",
"a26",
"a27",
"a28",
"a29",
260 "a30",
"a31",
"a32",
"a33",
"a34",
"a35",
"a36",
"a37",
"a38",
"a39",
261 "a40",
"a41",
"a42",
"a43",
"a44",
"a45",
"a46",
"a47",
"a48",
"a49",
262 "a50",
"a51",
"a52",
"a53",
"a54",
"a55",
"a56",
"a57",
"a58",
"a59",
263 "a60",
"a61",
"a62",
"a63",
"a64",
"a65",
"a66",
"a67",
"a68",
"a69",
264 "a70",
"a71",
"a72",
"a73",
"a74",
"a75",
"a76",
"a77",
"a78",
"a79",
265 "a80",
"a81",
"a82",
"a83",
"a84",
"a85",
"a86",
"a87",
"a88",
"a89",
266 "a90",
"a91",
"a92",
"a93",
"a94",
"a95",
"a96",
"a97",
"a98",
"a99",
267 "a100",
"a101",
"a102",
"a103",
"a104",
"a105",
"a106",
"a107",
268 "a108",
"a109",
"a110",
"a111",
"a112",
"a113",
"a114",
"a115",
269 "a116",
"a117",
"a118",
"a119",
"a120",
"a121",
"a122",
"a123",
270 "a124",
"a125",
"a126",
"a127",
"a128",
"a129",
"a130",
"a131",
271 "a132",
"a133",
"a134",
"a135",
"a136",
"a137",
"a138",
"a139",
272 "a140",
"a141",
"a142",
"a143",
"a144",
"a145",
"a146",
"a147",
273 "a148",
"a149",
"a150",
"a151",
"a152",
"a153",
"a154",
"a155",
274 "a156",
"a157",
"a158",
"a159",
"a160",
"a161",
"a162",
"a163",
275 "a164",
"a165",
"a166",
"a167",
"a168",
"a169",
"a170",
"a171",
276 "a172",
"a173",
"a174",
"a175",
"a176",
"a177",
"a178",
"a179",
277 "a180",
"a181",
"a182",
"a183",
"a184",
"a185",
"a186",
"a187",
278 "a188",
"a189",
"a190",
"a191",
"a192",
"a193",
"a194",
"a195",
279 "a196",
"a197",
"a198",
"a199",
"a200",
"a201",
"a202",
"a203",
280 "a204",
"a205",
"a206",
"a207",
"a208",
"a209",
"a210",
"a211",
281 "a212",
"a213",
"a214",
"a215",
"a216",
"a217",
"a218",
"a219",
282 "a220",
"a221",
"a222",
"a223",
"a224",
"a225",
"a226",
"a227",
283 "a228",
"a229",
"a230",
"a231",
"a232",
"a233",
"a234",
"a235",
284 "a236",
"a237",
"a238",
"a239",
"a240",
"a241",
"a242",
"a243",
285 "a244",
"a245",
"a246",
"a247",
"a248",
"a249",
"a250",
"a251",
286 "a252",
"a253",
"a254",
"a255",
287 "s8",
"s9",
"s12",
"s13",
"s14",
"s15",
"s38",
"s39",
"s52",
"s86",
290 "v64",
"v65",
"v66",
"v67",
"v68",
"v69",
"v70",
"v71",
291 "v72",
"v73",
"v74",
"v75",
"v76",
"v77",
"v78",
"v79",
292 "v80",
"v81",
"v82",
"v83",
"v84",
"v85",
"v86",
"v87",
293 "v88",
"v89",
"v90",
"v91",
"v92",
"v93",
"v94",
"v95",
294 "v128",
"v129",
"v130",
"v131",
295 "v132",
"v133",
"v134",
"v135",
"v136",
"v137",
"v138",
"v139",
296 "v140",
"v141",
"v142",
"v143",
"v144",
"v145",
"v146",
"v147",
297 "v148",
"v149",
"v150",
"v151",
"v152",
"v153",
"v154",
"v155",
298 "v156",
"v157",
"v158",
"v159",
"v160",
"v161",
"v162",
"v163",
299 "v164",
"v165",
"v166",
"v167",
"v168",
"v169",
"v170",
"v171",
300 "v172",
"v173",
"v174",
"v175",
"v176",
"v177",
"v178",
"v179",
301 "v180",
"v181",
"v182",
"v183",
"v184",
"v185",
"v186",
"v187",
302 "v188",
"v189",
"v190",
"v191",
"v192",
"v193",
"v194",
"v195",
303 "v196",
"v197",
"v198",
"v199",
"v200",
"v201",
"v202",
"v203",
304 "v204",
"v205",
"v206",
"v207",
"v208",
"v209",
"v210",
"v211",
305 "v212",
"v213",
"v214",
"v215",
"v216",
"v217",
"v218",
"v219",
306 "v220",
"v221",
"v222",
"v223",
"v224",
"v225",
"v226",
"v227",
307 "v228",
"v229",
"v230",
"v231",
"v232",
"v233",
"v234",
"v235",
308 "v236",
"v237",
"v238",
"v239",
"v240",
"v241",
"v242",
"v243",
309 "v244",
"v245",
"v246",
"v247",
"v248",
"v249",
"v250",
"v251",
310 "v252",
"v253",
"v254",
"v255"
312 #pragma clang diagnostic pop
325 template <
typename BRes,
330 typename ScaleTensor>
333 const BCoords& cached_coords_b,
335 const OCoords& cached_coords_o,
336 const OFlags& o_flags,
339 const ScaleTensor& scale_,
343 static_assert(BCoords::size() == 8);
344 static_assert(OCoords::size() == 8);
349 static_assert(ScaleTensor::size() == 2);
355 register float v_c0
asm(
"v64");
356 register float v_c1
asm(
"v65");
357 register float v_c2
asm(
"v66");
358 register float v_c3
asm(
"v67");
359 register float v_c4
asm(
"v68");
360 register float v_c5
asm(
"v69");
361 register float v_c6
asm(
"v70");
362 register float v_c7
asm(
"v71");
363 register float v_c8
asm(
"v72");
364 register float v_c9
asm(
"v73");
365 register float v_c10
asm(
"v74");
366 register float v_c11
asm(
"v75");
367 register float v_c12
asm(
"v76");
368 register float v_c13
asm(
"v77");
369 register float v_c14
asm(
"v78");
370 register float v_c15
asm(
"v79");
371 register float v_c16
asm(
"v80");
372 register float v_c17
asm(
"v81");
373 register float v_c18
asm(
"v82");
374 register float v_c19
asm(
"v83");
375 register float v_c20
asm(
"v84");
376 register float v_c21
asm(
"v85");
377 register float v_c22
asm(
"v86");
378 register float v_c23
asm(
"v87");
379 register float v_c24
asm(
"v88");
380 register float v_c25
asm(
"v89");
381 register float v_c26
asm(
"v90");
382 register float v_c27
asm(
"v91");
383 register float v_c28
asm(
"v92");
384 register float v_c29
asm(
"v93");
385 register float v_c30
asm(
"v94");
386 register float v_c31
asm(
"v95");
393 int lane_id = threadIdx.x % 64;
394 int sld_y_os = (lane_id % 16) * 4 + (lane_id / 16) * 128;
404 int sfl_sst = (threadIdx.x % 16 * 4) + (threadIdx.x / 16) * (64 + 4);
411 int sfl_sld = (lane_id % 2) * 2 + (lane_id / 2) * (64 + 4) + (threadIdx.x / 64) * 4;
416 #pragma clang diagnostic push
417 #pragma clang diagnostic ignored "-Winline-asm"
419 #define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
422 [s_loop_cnt]
"+s"(loop_cnt),
458 [v_sld_y_os]
"v"(sld_y_os),
459 [v_sfl_sld]
"v"(sfl_sld),
460 [v_sfl_sst]
"v"(sfl_sst),
461 [s_res_o0]
"s"(res_o[0]),
462 [s_res_o1]
"s"(res_o[1]),
465 [s_res_b0]
"s"(res_b[0]),
466 [s_res_b1]
"s"(res_b[1]),
467 [s_res_b2]
"s"(res_b[2]),
468 [s_res_b3]
"s"(res_b[3]),
486 [s_tile_os_o]
"s"(tile_stride_o_bytes),
487 [s_tile_os_b]
"s"(tile_stride_b_bytes),
490 [v_nan_lo]
"v"(nan_lo),
491 [v_nan_hi]
"v"(nan_hi),
501 "memory",
"a0",
"a1",
"a2",
"a3",
"a4",
"a5",
"a6",
"a7",
"a8",
"a9",
502 "a10",
"a11",
"a12",
"a13",
"a14",
"a15",
"a16",
"a17",
"a18",
"a19",
503 "a20",
"a21",
"a22",
"a23",
"a24",
"a25",
"a26",
"a27",
"a28",
"a29",
504 "a30",
"a31",
"a32",
"a33",
"a34",
"a35",
"a36",
"a37",
"a38",
"a39",
505 "a40",
"a41",
"a42",
"a43",
"a44",
"a45",
"a46",
"a47",
"a48",
"a49",
506 "a50",
"a51",
"a52",
"a53",
"a54",
"a55",
"a56",
"a57",
"a58",
"a59",
507 "a60",
"a61",
"a62",
"a63",
"a64",
"a65",
"a66",
"a67",
"a68",
"a69",
508 "a70",
"a71",
"a72",
"a73",
"a74",
"a75",
"a76",
"a77",
"a78",
"a79",
509 "a80",
"a81",
"a82",
"a83",
"a84",
"a85",
"a86",
"a87",
"a88",
"a89",
510 "a90",
"a91",
"a92",
"a93",
"a94",
"a95",
"a96",
"a97",
"a98",
"a99",
511 "a100",
"a101",
"a102",
"a103",
"a104",
"a105",
"a106",
"a107",
512 "a108",
"a109",
"a110",
"a111",
"a112",
"a113",
"a114",
"a115",
513 "a116",
"a117",
"a118",
"a119",
"a120",
"a121",
"a122",
"a123",
514 "a124",
"a125",
"a126",
"a127",
"a128",
"a129",
"a130",
"a131",
515 "a132",
"a133",
"a134",
"a135",
"a136",
"a137",
"a138",
"a139",
516 "a140",
"a141",
"a142",
"a143",
"a144",
"a145",
"a146",
"a147",
517 "a148",
"a149",
"a150",
"a151",
"a152",
"a153",
"a154",
"a155",
518 "a156",
"a157",
"a158",
"a159",
"a160",
"a161",
"a162",
"a163",
519 "a164",
"a165",
"a166",
"a167",
"a168",
"a169",
"a170",
"a171",
520 "a172",
"a173",
"a174",
"a175",
"a176",
"a177",
"a178",
"a179",
521 "a180",
"a181",
"a182",
"a183",
"a184",
"a185",
"a186",
"a187",
522 "a188",
"a189",
"a190",
"a191",
"a192",
"a193",
"a194",
"a195",
523 "a196",
"a197",
"a198",
"a199",
"a200",
"a201",
"a202",
"a203",
524 "a204",
"a205",
"a206",
"a207",
"a208",
"a209",
"a210",
"a211",
525 "a212",
"a213",
"a214",
"a215",
"a216",
"a217",
"a218",
"a219",
526 "a220",
"a221",
"a222",
"a223",
"a224",
"a225",
"a226",
"a227",
527 "a228",
"a229",
"a230",
"a231",
"a232",
"a233",
"a234",
"a235",
528 "a236",
"a237",
"a238",
"a239",
"a240",
"a241",
"a242",
"a243",
529 "a244",
"a245",
"a246",
"a247",
"a248",
"a249",
"a250",
"a251",
530 "a252",
"a253",
"a254",
"a255",
531 "s8",
"s9",
"s12",
"s13",
"s14",
"s15",
"s38",
"s39",
"s52",
"s86",
534 "v64",
"v65",
"v66",
"v67",
"v68",
"v69",
"v70",
"v71",
535 "v72",
"v73",
"v74",
"v75",
"v76",
"v77",
"v78",
"v79",
536 "v80",
"v81",
"v82",
"v83",
"v84",
"v85",
"v86",
"v87",
537 "v88",
"v89",
"v90",
"v91",
"v92",
"v93",
"v94",
"v95",
538 "v128",
"v129",
"v130",
"v131",
539 "v132",
"v133",
"v134",
"v135",
"v136",
"v137",
"v138",
"v139",
540 "v140",
"v141",
"v142",
"v143",
"v144",
"v145",
"v146",
"v147",
541 "v148",
"v149",
"v150",
"v151",
"v152",
"v153",
"v154",
"v155",
542 "v156",
"v157",
"v158",
"v159",
"v160",
"v161",
"v162",
"v163",
543 "v164",
"v165",
"v166",
"v167",
"v168",
"v169",
"v170",
"v171",
544 "v172",
"v173",
"v174",
"v175",
"v176",
"v177",
"v178",
"v179",
545 "v180",
"v181",
"v182",
"v183",
"v184",
"v185",
"v186",
"v187",
546 "v188",
"v189",
"v190",
"v191",
"v192",
"v193",
"v194",
"v195",
547 "v196",
"v197",
"v198",
"v199",
"v200",
"v201",
"v202",
"v203",
548 "v204",
"v205",
"v206",
"v207",
"v208",
"v209",
"v210",
"v211",
549 "v212",
"v213",
"v214",
"v215",
"v216",
"v217",
"v218",
"v219",
550 "v220",
"v221",
"v222",
"v223",
"v224",
"v225",
"v226",
"v227",
551 "v228",
"v229",
"v230",
"v231",
"v232",
"v233",
"v234",
"v235",
552 "v236",
"v237",
"v238",
"v239",
"v240",
"v241",
"v242",
"v243",
553 "v244",
"v245",
"v246",
"v247",
"v248",
"v249",
"v250",
"v251",
554 "v252",
"v253",
"v254",
"v255"
556 #pragma clang diagnostic pop
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_LDS_ADDR
Definition: config.hpp:58
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
constexpr CK_TILE_HOST_DEVICE auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition: tile_distribution_encoding.hpp:457
Definition: cluster_descriptor.hpp:13
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
int32_t index_t
Definition: integer.hpp:9
int32_t int32_t
Definition: integer.hpp:10
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:480
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:74
bf16_t BDataType
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:75
CK_TILE_DEVICE auto operator()(const BRes &res_b, const BCoords &cached_coords_b, const ORes &res_o, const OCoords &cached_coords_o, const OFlags &o_flags, CK_TILE_LDS_ADDR void *smem, index_t n, const ScaleTensor &scale_, index_t tile_offset_b, index_t tile_offset_o)
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:88
bf16_t ODataType
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:76
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:16
static constexpr index_t WarpPerBlock_M
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:21
static constexpr index_t WarpPerBlock_N
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:22
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:60
static constexpr index_t Block_N
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:18
static constexpr index_t Warp_K
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:27
static constexpr index_t Repeat_M
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:38
static constexpr index_t Block_Nr
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:35
static constexpr index_t WarpPerBlock_K
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:23
static constexpr CK_TILE_DEVICE auto MakeCBlockDist()
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:42
static constexpr index_t BlockSize
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:29
static constexpr index_t Block_K
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:19
static constexpr index_t Block_W
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:34
static constexpr index_t Repeat_K
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:40
static constexpr index_t Block_Kr
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:36
static constexpr index_t Warp_N
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:26
static constexpr index_t Warp_M
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:25
static constexpr index_t Repeat_N
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:39
static constexpr index_t Block_M
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:17
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:318
CK_TILE_DEVICE auto operator()(const BRes &res_b, const BCoords &cached_coords_b, const ORes &res_o, const OCoords &cached_coords_o, const OFlags &o_flags, CK_TILE_LDS_ADDR void *smem, index_t n, const ScaleTensor &scale_, index_t tile_offset_b, index_t tile_offset_o)
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:332
bf16_t BDataType
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:319
bf16_t ODataType
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:320
Definition: warp_gemm_impl.hpp:11
Definition: integral_constant.hpp:13
Definition: sequence.hpp:49
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192