67 template <index_t LanesPerK, index_t WarpSize,
typename =
void>
68 struct LdsStoreDescSelector;
70 template <index_t LanesPerK, index_t WarpSize>
71 struct LdsStoreDescSelector<LanesPerK, WarpSize, std::
enable_if_t<(LanesPerK >= WarpSize)>>
73 template <index_t NumWarps, index_t Block_M, index_t Block_K, index_t KVector, index_t KPad>
77 static_assert(LanesPerK % WarpSize == 0);
78 constexpr
index_t wavesPerK = LanesPerK / WarpSize;
94 number<wavesPerK*(WarpSize * KVector + KPad)>{},
95 number<WarpSize * KVector + KPad>{},
107 make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}),
108 make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
110 return lds_block_desc_issues_warps_lanes;
115 template <index_t LanesPerK, index_t WarpSize>
116 struct LdsStoreDescSelector<LanesPerK, WarpSize, std::
enable_if_t<(LanesPerK < WarpSize)>>
118 template <index_t NumWarps, index_t Block_M, index_t Block_K, index_t KVector, index_t KPad>
122 static_assert(WarpSize % LanesPerK == 0);
123 constexpr
index_t LaneGroups = WarpSize / LanesPerK;
128 number<LaneGroups>{},
134 number<WarpSize * KVector + KPad>{},
145 number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
146 make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
147 make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
149 return lds_block_desc_issues_warps_lanes;
167 c_block_outer_dstr_encoding,
typename WG::CWarpDstrEncoding{});
174 using CDataType = float;
176 auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
177 return c_block_tensor;
187 constexpr
index_t KPad = KPack_;
189 static_assert(
Block_K % KVector == 0);
192 return LdsStoreDescSelector<LanesPerK, WarpSize>::
193 template MakeDesc<NumWarps, Block_M, Block_K, KVector, KPad>();
201 constexpr
index_t KPad = KPack_;
203 constexpr
index_t kAMLane = 16;
204 constexpr
index_t kABKLane = 4;
205 constexpr
index_t kABKPerLane = 4;
207 static_assert(KPack_ == (kABKPerLane * kKIter));
209 constexpr
auto lds_block_desc_0 =
236 constexpr
index_t kAMLane = 16;
237 constexpr
index_t kABKLane = 4;
238 constexpr
index_t kABKPerLane = 4;
259 #define _EXPAND_ASM_ARGS_OUT_ONE_ACC \
260 [s_loop_cnt]"+s"(loop_cnt), \
261 [v_acc_0]"+v"(v_acc[0]), \
262 [v_acc_1]"+v"(v_acc[1]), \
263 [v_acc_2]"+v"(v_acc[2]), \
264 [v_acc_3]"+v"(v_acc[3]), \
265 [v_acc_4]"+v"(v_acc[4]), \
266 [v_acc_5]"+v"(v_acc[5]), \
267 [v_acc_6]"+v"(v_acc[6]), \
268 [v_acc_7]"+v"(v_acc[7]), \
269 [v_acc_8]"+v"(v_acc[8]), \
270 [v_acc_9]"+v"(v_acc[9]), \
271 [v_acc_10]"+v"(v_acc[10]), \
272 [v_acc_11]"+v"(v_acc[11]), \
273 [v_acc_12]"+v"(v_acc[12]), \
274 [v_acc_13]"+v"(v_acc[13]), \
275 [v_acc_14]"+v"(v_acc[14]), \
276 [v_acc_15]"+v"(v_acc[15]), \
279 #define _EXPAND_ASM_ARGS_OUT_TWO_ACC \
280 [s_loop_cnt]"+s"(loop_cnt), \
281 [v_acc_0]"+v"(v_acc[0]), \
282 [v_acc_1]"+v"(v_acc[1]), \
283 [v_acc_2]"+v"(v_acc[2]), \
284 [v_acc_3]"+v"(v_acc[3]), \
285 [v_acc_4]"+v"(v_acc[4]), \
286 [v_acc_5]"+v"(v_acc[5]), \
287 [v_acc_6]"+v"(v_acc[6]), \
288 [v_acc_7]"+v"(v_acc[7]), \
289 [v_acc_8]"+v"(v_acc[8]), \
290 [v_acc_9]"+v"(v_acc[9]), \
291 [v_acc_10]"+v"(v_acc[10]), \
292 [v_acc_11]"+v"(v_acc[11]), \
293 [v_acc_12]"+v"(v_acc[12]), \
294 [v_acc_13]"+v"(v_acc[13]), \
295 [v_acc_14]"+v"(v_acc[14]), \
296 [v_acc_15]"+v"(v_acc[15]), \
297 [v_acc_16]"+v"(v_acc[16]), \
298 [v_acc_17]"+v"(v_acc[17]), \
299 [v_acc_18]"+v"(v_acc[18]), \
300 [v_acc_19]"+v"(v_acc[19]), \
301 [v_acc_20]"+v"(v_acc[20]), \
302 [v_acc_21]"+v"(v_acc[21]), \
303 [v_acc_22]"+v"(v_acc[22]), \
304 [v_acc_23]"+v"(v_acc[23]), \
305 [v_acc_24]"+v"(v_acc[24]), \
306 [v_acc_25]"+v"(v_acc[25]), \
307 [v_acc_26]"+v"(v_acc[26]), \
308 [v_acc_27]"+v"(v_acc[27]), \
309 [v_acc_28]"+v"(v_acc[28]), \
310 [v_acc_29]"+v"(v_acc[29]), \
311 [v_acc_30]"+v"(v_acc[30]), \
312 [v_acc_31]"+v"(v_acc[31]), \
315 #define _EXPAND_ASM_ARGS_IN \
316 [s_res_a0]"s"(res_a[0]), \
317 [s_res_a1]"s"(res_a[1]), \
318 [s_res_a2]"s"(res_a[2]), \
319 [s_res_a3]"s"(res_a[3]), \
320 [s_res_b0]"s"(res_b[0]), \
321 [s_res_b1]"s"(res_b[1]), \
322 [s_res_b2]"s"(res_b[2]), \
323 [s_res_b3]"s"(res_b[3]), \
324 [v_os_a0]"v"(static_cast<index_t>(cached_coords_a[number<0>{}] * sizeof(ADataType))), \
325 [v_os_a1]"v"(static_cast<index_t>(cached_coords_a[number<1>{}] * sizeof(ADataType))), \
326 [v_os_a2]"v"(static_cast<index_t>(cached_coords_a[number<2>{}] * sizeof(ADataType))), \
327 [v_os_a3]"v"(static_cast<index_t>(cached_coords_a[number<3>{}] * sizeof(ADataType))), \
328 [v_os_a4]"v"(static_cast<index_t>(cached_coords_a[number<4>{}] * sizeof(ADataType))), \
329 [v_os_a5]"v"(static_cast<index_t>(cached_coords_a[number<5>{}] * sizeof(ADataType))), \
330 [v_os_a6]"v"(static_cast<index_t>(cached_coords_a[number<6>{}] * sizeof(ADataType))), \
331 [v_os_a7]"v"(static_cast<index_t>(cached_coords_a[number<7>{}] * sizeof(ADataType))), \
333 [v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))), \
334 [v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))), \
335 [v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))), \
336 [v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))), \
337 [v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))), \
338 [v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))), \
339 [v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))), \
340 [v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))), \
342 [v_os_slda]"v"(static_cast<index_t>(a_sld.cached_coords_[number<0>{}].get_offset() * sizeof(ADataType))),\
343 [s_m0_init]"s"(m0_init_value), \
344 [s_size_per_issue]"s"(size_per_issue), \
345 [smem_sz]"n"(smem_buf_size), \
346 [sld_os_0]"n"(sld_os[number<0>{}].value), \
347 [sld_os_1]"n"(sld_os[number<1>{}].value), \
348 [sld_os_2]"n"(sld_os[number<2>{}].value), \
349 [sld_os_3]"n"(sld_os[number<3>{}].value), \
350 [sld_os_4]"n"(sld_os[number<4>{}].value), \
351 [sld_os_5]"n"(sld_os[number<5>{}].value), \
352 [sld_os_6]"n"(sld_os[number<6>{}].value), \
353 [sld_os_7]"n"(sld_os[number<7>{}].value), \
354 [s_tile_os_a]"s"(tile_offset_a_bytes), \
355 [s_tile_os_b]"s"(tile_offset_b_bytes)
357 #define _EXPAND_ASM_ARGS_CLOBBER \
358 "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9", \
359 "a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19", \
360 "a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29", \
361 "a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39", \
362 "a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49", \
363 "a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59", \
364 "a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69", \
365 "a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79", \
366 "a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89", \
367 "a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99", \
368 "a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107", \
369 "a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115", \
370 "a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123", \
371 "a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131", \
372 "a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139", \
373 "a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147", \
374 "a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155", \
375 "a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163", \
376 "a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171", \
377 "a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179", \
378 "a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187", \
379 "a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195", \
380 "a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203", \
381 "a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211", \
382 "a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219", \
383 "a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227", \
384 "a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235", \
385 "a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243", \
386 "a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251", \
387 "a252", "a253", "a254", "a255", \
388 "s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23", \
390 "v64", "v65", "v66", "v67", "v68", "v69", \
391 "v70", "v71", "v72", "v73", "v74", "v75", "v76", "v77", "v78", "v79", \
392 "v80", "v81", "v82", "v83", "v84", "v85", "v86", "v87", "v88", "v89", \
393 "v90", "v91", "v92", "v93", "v94", "v95", "v96", "v97", "v98", "v99", \
394 "v100", "v101", "v102", "v103", "v104", "v105", "v106", "v107", \
395 "v108", "v109", "v110", "v111", "v112", "v113", "v114", "v115", \
396 "v116", "v117", "v118", "v119", "v120", "v121", "v122", "v123", \
397 "v124", "v125", "v126", "v127"
409 template <
typename ARes,
typename ACoords,
typename BRes,
typename BCoords,
bool Is2B = false>
412 const ACoords& cached_coords_a,
414 const BCoords& cached_coords_b,
422 static_assert(BCoords::size() ==
Repeat_N);
425 make_tensor_view<address_space_enum::lds>(
432 constexpr
auto a_outer_dstr_enc = tile_distribution_encoding<
433 sequence<WarpPerBlock_N>,
434 tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_K>>,
435 tuple<sequence<1, 0>>,
436 tuple<sequence<1, 0>>,
439 constexpr
auto a_block_dstr_encode =
442 make_tensor_view<address_space_enum::lds>(
453 constexpr
auto smem_buf_size =
455 static_assert(a_sld.get_num_of_access() == 8);
458 return number<a_sld.get_bottom_linear_offset(i_access) *
sizeof(
ADataType)>{};
460 number<a_sld.get_num_of_access()>{});
470 #pragma clang diagnostic push
471 #pragma clang diagnostic ignored "-Winline-asm"
474 #define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
475 #define CK_TILE_FLATMM_UK_2B 1
479 [s_res_b4]
"s"(res_b[4]),
480 [s_res_b5]
"s"(res_b[5]),
481 [s_res_b6]
"s"(res_b[6]),
482 [s_res_b7]
"s"(res_b[7])
486 #pragma clang diagnostic pop
490 for(
auto i = 0; i < 16; i++)
492 c.at(number<0>{}).get_thread_buffer()[4 * i + 0] = v_acc[i].x;
493 c.at(number<0>{}).get_thread_buffer()[4 * i + 1] = v_acc[i].y;
494 c.at(number<0>{}).get_thread_buffer()[4 * i + 2] = v_acc[i].z;
495 c.at(number<0>{}).get_thread_buffer()[4 * i + 3] = v_acc[i].w;
497 for(
auto i = 0; i < 16; i++)
499 c.at(number<1>{}).get_thread_buffer()[4 * i + 0] = v_acc[16 + i].x;
500 c.at(number<1>{}).get_thread_buffer()[4 * i + 1] = v_acc[16 + i].y;
501 c.at(number<1>{}).get_thread_buffer()[4 * i + 2] = v_acc[16 + i].z;
502 c.at(number<1>{}).get_thread_buffer()[4 * i + 3] = v_acc[16 + i].w;
512 #pragma clang diagnostic push
513 #pragma clang diagnostic ignored "-Winline-asm"
516 #define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
523 #pragma clang diagnostic pop
527 for(
auto i = 0; i < 16; i++)
529 c.get_thread_buffer()[4 * i + 0] = v_acc[i].x;
530 c.get_thread_buffer()[4 * i + 1] = v_acc[i].y;
531 c.get_thread_buffer()[4 * i + 2] = v_acc[i].z;
532 c.get_thread_buffer()[4 * i + 3] = v_acc[i].w;
546 template <
typename ARes,
typename ACoords,
typename BRes,
typename BCoords,
bool Is2B = false>
549 const ACoords& cached_coords_a,
551 const BCoords& cached_coords_b,
559 static_assert(BCoords::size() ==
Repeat_N);
562 make_tensor_view<address_space_enum::lds>(
569 constexpr
auto a_outer_dstr_enc = tile_distribution_encoding<
570 sequence<WarpPerBlock_N>,
571 tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_K>>,
572 tuple<sequence<1, 0>>,
573 tuple<sequence<1, 0>>,
576 constexpr
auto a_block_dstr_encode =
579 make_tensor_view<address_space_enum::lds>(
590 constexpr
auto smem_buf_size =
592 static_assert(a_sld.get_num_of_access() == 8);
595 return number<a_sld.get_bottom_linear_offset(i_access) *
sizeof(
ADataType)>{};
597 number<a_sld.get_num_of_access()>{});
607 #pragma clang diagnostic push
608 #pragma clang diagnostic ignored "-Winline-asm"
611 #define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
612 #define CK_TILE_FLATMM_UK_2B 1
616 [s_res_b4]
"s"(res_b[4]),
617 [s_res_b5]
"s"(res_b[5]),
618 [s_res_b6]
"s"(res_b[6]),
619 [s_res_b7]
"s"(res_b[7])
623 #pragma clang diagnostic pop
627 for(
auto i = 0; i < 16; i++)
629 c.at(number<0>{}).get_thread_buffer()[4 * i + 0] = v_acc[i].x;
630 c.at(number<0>{}).get_thread_buffer()[4 * i + 1] = v_acc[i].y;
631 c.at(number<0>{}).get_thread_buffer()[4 * i + 2] = v_acc[i].z;
632 c.at(number<0>{}).get_thread_buffer()[4 * i + 3] = v_acc[i].w;
634 for(
auto i = 0; i < 16; i++)
636 c.at(number<1>{}).get_thread_buffer()[4 * i + 0] = v_acc[16 + i].x;
637 c.at(number<1>{}).get_thread_buffer()[4 * i + 1] = v_acc[16 + i].y;
638 c.at(number<1>{}).get_thread_buffer()[4 * i + 2] = v_acc[16 + i].z;
639 c.at(number<1>{}).get_thread_buffer()[4 * i + 3] = v_acc[16 + i].w;
649 #pragma clang diagnostic push
650 #pragma clang diagnostic ignored "-Winline-asm"
653 #define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
660 #pragma clang diagnostic pop
664 for(
auto i = 0; i < 16; i++)
666 c.get_thread_buffer()[4 * i + 0] = v_acc[i].x;
667 c.get_thread_buffer()[4 * i + 1] = v_acc[i].y;
668 c.get_thread_buffer()[4 * i + 2] = v_acc[i].z;
669 c.get_thread_buffer()[4 * i + 3] = v_acc[i].w;
675 #undef _EXPAND_ASM_ARGS_OUT_ONE_ACC
676 #undef _EXPAND_ASM_ARGS_OUT_TWO_ACC
677 #undef _EXPAND_ASM_ARGS_IN
678 #undef _EXPAND_ASM_ARGS_CLOBBER
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_LDS_ADDR
Definition: config.hpp:58
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
#define _EXPAND_ASM_ARGS_OUT_TWO_ACC
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:279
#define _EXPAND_ASM_ARGS_CLOBBER
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:357
#define _EXPAND_ASM_ARGS_IN
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:315
#define _EXPAND_ASM_ARGS_OUT_ONE_ACC
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:259
constexpr CK_TILE_HOST_DEVICE auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition: tile_distribution_encoding.hpp:457
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_descriptor.hpp:268
_Float16 fp16_t
Definition: half.hpp:110
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1615
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1558
constant< v > number
Definition: integral_constant.hpp:37
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition: tensor_descriptor.hpp:197
CK_TILE_DEVICE auto get_async_store_smem_info(LdsTileWindow_ &&lds_tile)
Definition: tile_window_utils.hpp:31
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_DEVICE auto make_tile_window_linear(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, LinearBottomDims_={})
Definition: tile_window_linear.hpp:993
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
float fp32x4_t
Definition: vector_type.hpp:117
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:480
__host__ constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:42
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:27
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:401
bf16_t ADataType
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:402
bf16_t BDataType
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:403
CK_TILE_DEVICE auto operator()(const ARes &res_a, const ACoords &cached_coords_a, const BRes &res_b, const BCoords &cached_coords_b, CK_TILE_LDS_ADDR void *smem, index_t k, index_t tile_offset_a, index_t tile_offset_b, bool_constant< Is2B >={})
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:411
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:38
static constexpr index_t NumWarps
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:47
static constexpr index_t WarpPerBlock_N
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:44
static constexpr CK_TILE_DEVICE auto MakeCBlockTile()
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:172
static constexpr index_t SubKPacks
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:55
static constexpr index_t WarpPerBlock_K
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:45
static constexpr index_t Block_K
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:41
static constexpr index_t Block_N
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:40
static constexpr CK_TILE_HOST_DEVICE auto MakeLdsLoadDesc_A()
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:197
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:251
static constexpr index_t Block_W
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:58
static constexpr auto GetGemm_AWarpEnc()
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:234
static constexpr index_t Warp_M
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:49
static constexpr CK_TILE_HOST_DEVICE auto MakeLdsStoreDesc_A()
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:180
static constexpr index_t Block_M
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:39
static constexpr index_t Block_Kr
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:60
static constexpr index_t Repeat_K
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:64
static constexpr CK_TILE_DEVICE auto MakeCBlockDist()
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:154
static constexpr index_t Warp_N
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:50
static constexpr index_t Block_Nr
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:59
static constexpr index_t Warp_K
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:51
static constexpr index_t Repeat_M
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:62
static constexpr index_t BlockSize
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:53
static constexpr index_t WarpPerBlock_M
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:43
static constexpr index_t Repeat_N
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:63
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:540
CK_TILE_DEVICE auto operator()(const ARes &res_a, const ACoords &cached_coords_a, const BRes &res_b, const BCoords &cached_coords_b, CK_TILE_LDS_ADDR void *smem, index_t k, index_t tile_offset_a, index_t tile_offset_b, bool_constant< Is2B >={})
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:548
fp16_t ADataType
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:541
fp16_t BDataType
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:542
Definition: warp_gemm_impl.hpp:11
Definition: integral_constant.hpp:13
Definition: sequence.hpp:49
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192