79 c_block_outer_dstr_encoding,
typename WG::CWarpDstrEncoding{});
86 using CDataType = float;
88 auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
89 return c_block_tensor;
103 constexpr
index_t KPad = KPack_;
105 static_assert(
Block_K % KVector == 0);
107 if constexpr(LanesPerK >= warpSize)
110 static_assert(LanesPerK % warpSize == 0);
111 constexpr
index_t wavesPerK = LanesPerK / warpSize;
127 number<wavesPerK*(warpSize * KVector + KPad)>{},
143 return lds_block_desc_issues_warps_lanes;
149 static_assert(warpSize % LanesPerK == 0);
150 constexpr
index_t LaneGroups = warpSize / LanesPerK;
176 return lds_block_desc_issues_warps_lanes;
185 constexpr
index_t KPad = KPack_;
187 constexpr
index_t kAMLane = 16;
188 constexpr
index_t kABKLane = 4;
189 constexpr
index_t kABKPerLane = 4;
191 static_assert(KPack_ == (kABKPerLane * kKIter));
193 constexpr
auto lds_block_desc_0 =
220 constexpr
index_t kAMLane = 16;
221 constexpr
index_t kABKLane = 4;
222 constexpr
index_t kABKPerLane = 4;
243 #define _EXPAND_ASM_ARGS_OUT_ONE_ACC \
244 [s_loop_cnt]"+s"(loop_cnt), \
245 [v_acc_0]"+v"(v_acc[0]), \
246 [v_acc_1]"+v"(v_acc[1]), \
247 [v_acc_2]"+v"(v_acc[2]), \
248 [v_acc_3]"+v"(v_acc[3]), \
249 [v_acc_4]"+v"(v_acc[4]), \
250 [v_acc_5]"+v"(v_acc[5]), \
251 [v_acc_6]"+v"(v_acc[6]), \
252 [v_acc_7]"+v"(v_acc[7]), \
253 [v_acc_8]"+v"(v_acc[8]), \
254 [v_acc_9]"+v"(v_acc[9]), \
255 [v_acc_10]"+v"(v_acc[10]), \
256 [v_acc_11]"+v"(v_acc[11]), \
257 [v_acc_12]"+v"(v_acc[12]), \
258 [v_acc_13]"+v"(v_acc[13]), \
259 [v_acc_14]"+v"(v_acc[14]), \
260 [v_acc_15]"+v"(v_acc[15]), \
263 #define _EXPAND_ASM_ARGS_OUT_TWO_ACC \
264 [s_loop_cnt]"+s"(loop_cnt), \
265 [v_acc_0]"+v"(v_acc[0]), \
266 [v_acc_1]"+v"(v_acc[1]), \
267 [v_acc_2]"+v"(v_acc[2]), \
268 [v_acc_3]"+v"(v_acc[3]), \
269 [v_acc_4]"+v"(v_acc[4]), \
270 [v_acc_5]"+v"(v_acc[5]), \
271 [v_acc_6]"+v"(v_acc[6]), \
272 [v_acc_7]"+v"(v_acc[7]), \
273 [v_acc_8]"+v"(v_acc[8]), \
274 [v_acc_9]"+v"(v_acc[9]), \
275 [v_acc_10]"+v"(v_acc[10]), \
276 [v_acc_11]"+v"(v_acc[11]), \
277 [v_acc_12]"+v"(v_acc[12]), \
278 [v_acc_13]"+v"(v_acc[13]), \
279 [v_acc_14]"+v"(v_acc[14]), \
280 [v_acc_15]"+v"(v_acc[15]), \
281 [v_acc_16]"+v"(v_acc[16]), \
282 [v_acc_17]"+v"(v_acc[17]), \
283 [v_acc_18]"+v"(v_acc[18]), \
284 [v_acc_19]"+v"(v_acc[19]), \
285 [v_acc_20]"+v"(v_acc[20]), \
286 [v_acc_21]"+v"(v_acc[21]), \
287 [v_acc_22]"+v"(v_acc[22]), \
288 [v_acc_23]"+v"(v_acc[23]), \
289 [v_acc_24]"+v"(v_acc[24]), \
290 [v_acc_25]"+v"(v_acc[25]), \
291 [v_acc_26]"+v"(v_acc[26]), \
292 [v_acc_27]"+v"(v_acc[27]), \
293 [v_acc_28]"+v"(v_acc[28]), \
294 [v_acc_29]"+v"(v_acc[29]), \
295 [v_acc_30]"+v"(v_acc[30]), \
296 [v_acc_31]"+v"(v_acc[31]), \
299 #define _EXPAND_ASM_ARGS_IN \
300 [s_res_a0]"s"(res_a[0]), \
301 [s_res_a1]"s"(res_a[1]), \
302 [s_res_a2]"s"(res_a[2]), \
303 [s_res_a3]"s"(res_a[3]), \
304 [s_res_b0]"s"(res_b[0]), \
305 [s_res_b1]"s"(res_b[1]), \
306 [s_res_b2]"s"(res_b[2]), \
307 [s_res_b3]"s"(res_b[3]), \
308 [v_os_a0]"v"(static_cast<index_t>(cached_coords_a[number<0>{}] * sizeof(ADataType))), \
309 [v_os_a1]"v"(static_cast<index_t>(cached_coords_a[number<1>{}] * sizeof(ADataType))), \
310 [v_os_a2]"v"(static_cast<index_t>(cached_coords_a[number<2>{}] * sizeof(ADataType))), \
311 [v_os_a3]"v"(static_cast<index_t>(cached_coords_a[number<3>{}] * sizeof(ADataType))), \
312 [v_os_a4]"v"(static_cast<index_t>(cached_coords_a[number<4>{}] * sizeof(ADataType))), \
313 [v_os_a5]"v"(static_cast<index_t>(cached_coords_a[number<5>{}] * sizeof(ADataType))), \
314 [v_os_a6]"v"(static_cast<index_t>(cached_coords_a[number<6>{}] * sizeof(ADataType))), \
315 [v_os_a7]"v"(static_cast<index_t>(cached_coords_a[number<7>{}] * sizeof(ADataType))), \
317 [v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))), \
318 [v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))), \
319 [v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))), \
320 [v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))), \
321 [v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))), \
322 [v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))), \
323 [v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))), \
324 [v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))), \
326 [v_os_slda]"v"(static_cast<index_t>(a_sld.cached_coords_[number<0>{}].get_offset() * sizeof(ADataType))),\
327 [s_m0_init]"s"(m0_init_value), \
328 [s_size_per_issue]"s"(size_per_issue), \
329 [smem_sz]"n"(smem_buf_size), \
330 [sld_os_0]"n"(sld_os[number<0>{}].value), \
331 [sld_os_1]"n"(sld_os[number<1>{}].value), \
332 [sld_os_2]"n"(sld_os[number<2>{}].value), \
333 [sld_os_3]"n"(sld_os[number<3>{}].value), \
334 [sld_os_4]"n"(sld_os[number<4>{}].value), \
335 [sld_os_5]"n"(sld_os[number<5>{}].value), \
336 [sld_os_6]"n"(sld_os[number<6>{}].value), \
337 [sld_os_7]"n"(sld_os[number<7>{}].value), \
338 [s_tile_os_a]"s"(tile_offset_a_bytes), \
339 [s_tile_os_b]"s"(tile_offset_b_bytes)
341 #define _EXPAND_ASM_ARGS_CLOBBER \
342 "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9", \
343 "a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19", \
344 "a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29", \
345 "a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39", \
346 "a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49", \
347 "a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59", \
348 "a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69", \
349 "a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79", \
350 "a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89", \
351 "a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99", \
352 "a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107", \
353 "a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115", \
354 "a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123", \
355 "a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131", \
356 "a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139", \
357 "a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147", \
358 "a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155", \
359 "a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163", \
360 "a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171", \
361 "a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179", \
362 "a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187", \
363 "a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195", \
364 "a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203", \
365 "a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211", \
366 "a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219", \
367 "a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227", \
368 "a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235", \
369 "a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243", \
370 "a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251", \
371 "a252", "a253", "a254", "a255", \
372 "s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23", \
374 "v64", "v65", "v66", "v67", "v68", "v69", \
375 "v70", "v71", "v72", "v73", "v74", "v75", "v76", "v77", "v78", "v79", \
376 "v80", "v81", "v82", "v83", "v84", "v85", "v86", "v87", "v88", "v89", \
377 "v90", "v91", "v92", "v93", "v94", "v95", "v96", "v97", "v98", "v99", \
378 "v100", "v101", "v102", "v103", "v104", "v105", "v106", "v107", \
379 "v108", "v109", "v110", "v111", "v112", "v113", "v114", "v115", \
380 "v116", "v117", "v118", "v119", "v120", "v121", "v122", "v123", \
381 "v124", "v125", "v126", "v127"
393 template <
typename ARes,
typename ACoords,
typename BRes,
typename BCoords,
bool Is2B = false>
396 const ACoords& cached_coords_a,
398 const BCoords& cached_coords_b,
406 static_assert(BCoords::size() ==
Repeat_N);
409 make_tensor_view<address_space_enum::lds>(
416 constexpr
auto a_outer_dstr_enc = tile_distribution_encoding<
417 sequence<WarpPerBlock_N>,
418 tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_K>>,
419 tuple<sequence<1, 0>>,
420 tuple<sequence<1, 0>>,
423 constexpr
auto a_block_dstr_encode =
426 make_tensor_view<address_space_enum::lds>(
437 constexpr
auto smem_buf_size =
439 static_assert(a_sld.get_num_of_access() == 8);
442 return number<a_sld.get_bottom_linear_offset(i_access) *
sizeof(
ADataType)>{};
444 number<a_sld.get_num_of_access()>{});
454 #pragma clang diagnostic push
455 #pragma clang diagnostic ignored "-Winline-asm"
458 #define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
459 #define CK_TILE_FLATMM_UK_2B 1
463 [s_res_b4]
"s"(res_b[4]),
464 [s_res_b5]
"s"(res_b[5]),
465 [s_res_b6]
"s"(res_b[6]),
466 [s_res_b7]
"s"(res_b[7])
470 #pragma clang diagnostic pop
474 for(
auto i = 0; i < 16; i++)
476 c.at(number<0>{}).get_thread_buffer()[4 * i + 0] = v_acc[i].x;
477 c.at(number<0>{}).get_thread_buffer()[4 * i + 1] = v_acc[i].y;
478 c.at(number<0>{}).get_thread_buffer()[4 * i + 2] = v_acc[i].z;
479 c.at(number<0>{}).get_thread_buffer()[4 * i + 3] = v_acc[i].w;
481 for(
auto i = 0; i < 16; i++)
483 c.at(number<1>{}).get_thread_buffer()[4 * i + 0] = v_acc[16 + i].x;
484 c.at(number<1>{}).get_thread_buffer()[4 * i + 1] = v_acc[16 + i].y;
485 c.at(number<1>{}).get_thread_buffer()[4 * i + 2] = v_acc[16 + i].z;
486 c.at(number<1>{}).get_thread_buffer()[4 * i + 3] = v_acc[16 + i].w;
496 #pragma clang diagnostic push
497 #pragma clang diagnostic ignored "-Winline-asm"
500 #define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
507 #pragma clang diagnostic pop
511 for(
auto i = 0; i < 16; i++)
513 c.get_thread_buffer()[4 * i + 0] = v_acc[i].x;
514 c.get_thread_buffer()[4 * i + 1] = v_acc[i].y;
515 c.get_thread_buffer()[4 * i + 2] = v_acc[i].z;
516 c.get_thread_buffer()[4 * i + 3] = v_acc[i].w;
530 template <
typename ARes,
typename ACoords,
typename BRes,
typename BCoords,
bool Is2B = false>
533 const ACoords& cached_coords_a,
535 const BCoords& cached_coords_b,
543 static_assert(BCoords::size() ==
Repeat_N);
546 make_tensor_view<address_space_enum::lds>(
553 constexpr
auto a_outer_dstr_enc = tile_distribution_encoding<
554 sequence<WarpPerBlock_N>,
555 tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_K>>,
556 tuple<sequence<1, 0>>,
557 tuple<sequence<1, 0>>,
560 constexpr
auto a_block_dstr_encode =
563 make_tensor_view<address_space_enum::lds>(
574 constexpr
auto smem_buf_size =
576 static_assert(a_sld.get_num_of_access() == 8);
579 return number<a_sld.get_bottom_linear_offset(i_access) *
sizeof(
ADataType)>{};
581 number<a_sld.get_num_of_access()>{});
591 #pragma clang diagnostic push
592 #pragma clang diagnostic ignored "-Winline-asm"
595 #define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
596 #define CK_TILE_FLATMM_UK_2B 1
600 [s_res_b4]
"s"(res_b[4]),
601 [s_res_b5]
"s"(res_b[5]),
602 [s_res_b6]
"s"(res_b[6]),
603 [s_res_b7]
"s"(res_b[7])
607 #pragma clang diagnostic pop
611 for(
auto i = 0; i < 16; i++)
613 c.at(number<0>{}).get_thread_buffer()[4 * i + 0] = v_acc[i].x;
614 c.at(number<0>{}).get_thread_buffer()[4 * i + 1] = v_acc[i].y;
615 c.at(number<0>{}).get_thread_buffer()[4 * i + 2] = v_acc[i].z;
616 c.at(number<0>{}).get_thread_buffer()[4 * i + 3] = v_acc[i].w;
618 for(
auto i = 0; i < 16; i++)
620 c.at(number<1>{}).get_thread_buffer()[4 * i + 0] = v_acc[16 + i].x;
621 c.at(number<1>{}).get_thread_buffer()[4 * i + 1] = v_acc[16 + i].y;
622 c.at(number<1>{}).get_thread_buffer()[4 * i + 2] = v_acc[16 + i].z;
623 c.at(number<1>{}).get_thread_buffer()[4 * i + 3] = v_acc[16 + i].w;
633 #pragma clang diagnostic push
634 #pragma clang diagnostic ignored "-Winline-asm"
637 #define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
644 #pragma clang diagnostic pop
648 for(
auto i = 0; i < 16; i++)
650 c.get_thread_buffer()[4 * i + 0] = v_acc[i].x;
651 c.get_thread_buffer()[4 * i + 1] = v_acc[i].y;
652 c.get_thread_buffer()[4 * i + 2] = v_acc[i].z;
653 c.get_thread_buffer()[4 * i + 3] = v_acc[i].w;
659 #undef _EXPAND_ASM_ARGS_OUT_ONE_ACC
660 #undef _EXPAND_ASM_ARGS_OUT_TWO_ACC
661 #undef _EXPAND_ASM_ARGS_IN
662 #undef _EXPAND_ASM_ARGS_CLOBBER
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_LDS_ADDR
Definition: config.hpp:56
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
#define _EXPAND_ASM_ARGS_OUT_TWO_ACC
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:263
#define _EXPAND_ASM_ARGS_CLOBBER
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:341
#define _EXPAND_ASM_ARGS_IN
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:299
#define _EXPAND_ASM_ARGS_OUT_ONE_ACC
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:243
constexpr CK_TILE_HOST_DEVICE auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition: tile_distribution_encoding.hpp:420
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE index_t get_warp_size()
Definition: arch.hpp:51
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_descriptor.hpp:255
_Float16 fp16_t
Definition: half.hpp:110
bfloat16_t bf16_t
Definition: bfloat16.hpp:106
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1672
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1615
constant< v > number
Definition: integral_constant.hpp:33
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition: tensor_descriptor.hpp:184
CK_TILE_DEVICE auto get_async_store_smem_info(LdsTileWindow_ &&lds_tile)
Definition: tile_window_utils.hpp:24
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:72
constexpr CK_TILE_DEVICE auto make_tile_window_linear(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, LinearBottomDims_={})
Definition: tile_window_linear.hpp:1124
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:400
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
WarpGemmImpl< WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution< WarpGemmAttributeMfmaImplF16F16F32M16N16K16< WGAttrCtlEnum::Default_ >, 2 > > WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution
Definition: warp_gemm.hpp:52
float fp32x4_t
Definition: vector_type.hpp:87
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:498
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:385
bf16_t ADataType
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:386
bf16_t BDataType
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:387
CK_TILE_DEVICE auto operator()(const ARes &res_a, const ACoords &cached_coords_a, const BRes &res_b, const BCoords &cached_coords_b, CK_TILE_LDS_ADDR void *smem, index_t k, index_t tile_offset_a, index_t tile_offset_b, bool_constant< Is2B >={})
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:395
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:38
static constexpr index_t NumWarps
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:47
static constexpr index_t WarpPerBlock_N
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:44
static constexpr CK_TILE_DEVICE auto MakeCBlockTile()
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:84
static constexpr index_t SubKPacks
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:55
static constexpr index_t WarpPerBlock_K
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:45
static constexpr index_t Block_K
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:41
static constexpr index_t Block_N
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:40
static constexpr CK_TILE_HOST_DEVICE auto MakeLdsLoadDesc_A()
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:181
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:235
static constexpr index_t Block_W
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:58
static constexpr auto GetGemm_AWarpEnc()
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:218
static constexpr index_t Warp_M
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:49
static constexpr CK_TILE_HOST_DEVICE auto MakeLdsStoreDesc_A()
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:92
static constexpr index_t Block_M
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:39
static constexpr index_t Block_Kr
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:60
static constexpr index_t Repeat_K
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:64
static constexpr CK_TILE_DEVICE auto MakeCBlockDist()
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:66
static constexpr index_t Warp_N
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:50
static constexpr index_t Block_Nr
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:59
static constexpr index_t Warp_K
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:51
static constexpr index_t Repeat_M
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:62
static constexpr index_t BlockSize
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:53
static constexpr index_t WarpPerBlock_M
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:43
static constexpr index_t Repeat_N
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:63
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:524
CK_TILE_DEVICE auto operator()(const ARes &res_a, const ACoords &cached_coords_a, const BRes &res_b, const BCoords &cached_coords_b, CK_TILE_LDS_ADDR void *smem, index_t k, index_t tile_offset_a, index_t tile_offset_b, bool_constant< Is2B >={})
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:532
fp16_t ADataType
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:525
fp16_t BDataType
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:526
Definition: integral_constant.hpp:13
Definition: sequence.hpp:52
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192