67     template <index_t LanesPerK, index_t WarpSize, 
typename = 
void>
 
   68     struct LdsStoreDescSelector;
 
   70     template <index_t LanesPerK, index_t WarpSize>
 
   71     struct LdsStoreDescSelector<LanesPerK, WarpSize, std::
enable_if_t<(LanesPerK >= WarpSize)>>
 
   73         template <index_t NumWarps, index_t Block_M, index_t Block_K, index_t KVector, index_t KPad>
 
   77             static_assert(LanesPerK % WarpSize == 0);
 
   78             constexpr 
index_t wavesPerK = LanesPerK / WarpSize;
 
   94                                number<wavesPerK*(WarpSize * KVector + KPad)>{}, 
 
   95                                number<WarpSize * KVector + KPad>{},             
 
  107                     make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}),
 
  108                     make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
 
  110                 return lds_block_desc_issues_warps_lanes;
 
  115     template <index_t LanesPerK, index_t WarpSize>
 
  116     struct LdsStoreDescSelector<LanesPerK, WarpSize, std::
enable_if_t<(LanesPerK < WarpSize)>>
 
  118         template <index_t NumWarps, index_t Block_M, index_t Block_K, index_t KVector, index_t KPad>
 
  122             static_assert(WarpSize % LanesPerK == 0);
 
  123             constexpr 
index_t LaneGroups = WarpSize / LanesPerK; 
 
  128                            number<LaneGroups>{},                           
 
  134                            number<WarpSize * KVector + KPad>{},            
 
  145                                number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
 
  146                 make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
 
  147                 make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
 
  149             return lds_block_desc_issues_warps_lanes;
 
  167             c_block_outer_dstr_encoding, 
typename WG::CWarpDstrEncoding{});
 
  174         using CDataType             = float;
 
  176         auto c_block_tensor         = make_static_distributed_tensor<CDataType>(c_block_dstr);
 
  177         return c_block_tensor;
 
  187         constexpr 
index_t KPad    = KPack_; 
 
  189         static_assert(
Block_K % KVector == 0);
 
  192         return LdsStoreDescSelector<LanesPerK, WarpSize>::
 
  193             template MakeDesc<NumWarps, Block_M, Block_K, KVector, KPad>();
 
  201         constexpr 
index_t KPad   = KPack_; 
 
  203         constexpr 
index_t kAMLane     = 16;
 
  204         constexpr 
index_t kABKLane    = 4;
 
  205         constexpr 
index_t kABKPerLane = 4;
 
  207         static_assert(KPack_ == (kABKPerLane * kKIter));
 
  209         constexpr 
auto lds_block_desc_0 =
 
  236         constexpr 
index_t kAMLane     = 16;
 
  237         constexpr 
index_t kABKLane    = 4;
 
  238         constexpr 
index_t kABKPerLane = 4;
 
  259 #define _EXPAND_ASM_ARGS_OUT_ONE_ACC        \ 
  260             [s_loop_cnt]"+s"(loop_cnt),     \
 
  261                 [v_acc_0]"+v"(v_acc[0]),    \
 
  262                 [v_acc_1]"+v"(v_acc[1]),    \
 
  263                 [v_acc_2]"+v"(v_acc[2]),    \
 
  264                 [v_acc_3]"+v"(v_acc[3]),    \
 
  265                 [v_acc_4]"+v"(v_acc[4]),    \
 
  266                 [v_acc_5]"+v"(v_acc[5]),    \
 
  267                 [v_acc_6]"+v"(v_acc[6]),    \
 
  268                 [v_acc_7]"+v"(v_acc[7]),    \
 
  269                 [v_acc_8]"+v"(v_acc[8]),    \
 
  270                 [v_acc_9]"+v"(v_acc[9]),    \
 
  271                 [v_acc_10]"+v"(v_acc[10]),    \
 
  272                 [v_acc_11]"+v"(v_acc[11]),    \
 
  273                 [v_acc_12]"+v"(v_acc[12]),    \
 
  274                 [v_acc_13]"+v"(v_acc[13]),    \
 
  275                 [v_acc_14]"+v"(v_acc[14]),    \
 
  276                 [v_acc_15]"+v"(v_acc[15]),    \
 
  279 #define _EXPAND_ASM_ARGS_OUT_TWO_ACC        \ 
  280             [s_loop_cnt]"+s"(loop_cnt),     \
 
  281                 [v_acc_0]"+v"(v_acc[0]),    \
 
  282                 [v_acc_1]"+v"(v_acc[1]),    \
 
  283                 [v_acc_2]"+v"(v_acc[2]),    \
 
  284                 [v_acc_3]"+v"(v_acc[3]),    \
 
  285                 [v_acc_4]"+v"(v_acc[4]),    \
 
  286                 [v_acc_5]"+v"(v_acc[5]),    \
 
  287                 [v_acc_6]"+v"(v_acc[6]),    \
 
  288                 [v_acc_7]"+v"(v_acc[7]),    \
 
  289                 [v_acc_8]"+v"(v_acc[8]),    \
 
  290                 [v_acc_9]"+v"(v_acc[9]),    \
 
  291                 [v_acc_10]"+v"(v_acc[10]),    \
 
  292                 [v_acc_11]"+v"(v_acc[11]),    \
 
  293                 [v_acc_12]"+v"(v_acc[12]),    \
 
  294                 [v_acc_13]"+v"(v_acc[13]),    \
 
  295                 [v_acc_14]"+v"(v_acc[14]),    \
 
  296                 [v_acc_15]"+v"(v_acc[15]),    \
 
  297                 [v_acc_16]"+v"(v_acc[16]),    \
 
  298                 [v_acc_17]"+v"(v_acc[17]),    \
 
  299                 [v_acc_18]"+v"(v_acc[18]),    \
 
  300                 [v_acc_19]"+v"(v_acc[19]),    \
 
  301                 [v_acc_20]"+v"(v_acc[20]),    \
 
  302                 [v_acc_21]"+v"(v_acc[21]),    \
 
  303                 [v_acc_22]"+v"(v_acc[22]),    \
 
  304                 [v_acc_23]"+v"(v_acc[23]),    \
 
  305                 [v_acc_24]"+v"(v_acc[24]),    \
 
  306                 [v_acc_25]"+v"(v_acc[25]),    \
 
  307                 [v_acc_26]"+v"(v_acc[26]),    \
 
  308                 [v_acc_27]"+v"(v_acc[27]),    \
 
  309                 [v_acc_28]"+v"(v_acc[28]),    \
 
  310                 [v_acc_29]"+v"(v_acc[29]),    \
 
  311                 [v_acc_30]"+v"(v_acc[30]),    \
 
  312                 [v_acc_31]"+v"(v_acc[31]),    \
 
  315 #define _EXPAND_ASM_ARGS_IN     \ 
  316               [s_res_a0]"s"(res_a[0]),    \
 
  317                 [s_res_a1]"s"(res_a[1]),    \
 
  318                 [s_res_a2]"s"(res_a[2]),    \
 
  319                 [s_res_a3]"s"(res_a[3]),    \
 
  320                 [s_res_b0]"s"(res_b[0]),    \
 
  321                 [s_res_b1]"s"(res_b[1]),    \
 
  322                 [s_res_b2]"s"(res_b[2]),    \
 
  323                 [s_res_b3]"s"(res_b[3]),    \
 
  324                 [v_os_a0]"v"(static_cast<index_t>(cached_coords_a[number<0>{}] * sizeof(ADataType))),    \
 
  325                 [v_os_a1]"v"(static_cast<index_t>(cached_coords_a[number<1>{}] * sizeof(ADataType))),    \
 
  326                 [v_os_a2]"v"(static_cast<index_t>(cached_coords_a[number<2>{}] * sizeof(ADataType))),    \
 
  327                 [v_os_a3]"v"(static_cast<index_t>(cached_coords_a[number<3>{}] * sizeof(ADataType))),    \
 
  328                 [v_os_a4]"v"(static_cast<index_t>(cached_coords_a[number<4>{}] * sizeof(ADataType))),    \
 
  329                 [v_os_a5]"v"(static_cast<index_t>(cached_coords_a[number<5>{}] * sizeof(ADataType))),    \
 
  330                 [v_os_a6]"v"(static_cast<index_t>(cached_coords_a[number<6>{}] * sizeof(ADataType))),    \
 
  331                 [v_os_a7]"v"(static_cast<index_t>(cached_coords_a[number<7>{}] * sizeof(ADataType))),    \
 
  333                 [v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))),    \
 
  334                 [v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),    \
 
  335                 [v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),    \
 
  336                 [v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),    \
 
  337                 [v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),    \
 
  338                 [v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),    \
 
  339                 [v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),    \
 
  340                 [v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),    \
 
  342                 [v_os_slda]"v"(static_cast<index_t>(a_sld.cached_coords_[number<0>{}].get_offset() * sizeof(ADataType))),\
 
  343                 [s_m0_init]"s"(m0_init_value),    \
 
  344                 [s_size_per_issue]"s"(size_per_issue),    \
 
  345                 [smem_sz]"n"(smem_buf_size),   \
 
  346                 [sld_os_0]"n"(sld_os[number<0>{}].value),    \
 
  347                 [sld_os_1]"n"(sld_os[number<1>{}].value),    \
 
  348                 [sld_os_2]"n"(sld_os[number<2>{}].value),    \
 
  349                 [sld_os_3]"n"(sld_os[number<3>{}].value),    \
 
  350                 [sld_os_4]"n"(sld_os[number<4>{}].value),    \
 
  351                 [sld_os_5]"n"(sld_os[number<5>{}].value),    \
 
  352                 [sld_os_6]"n"(sld_os[number<6>{}].value),    \
 
  353                 [sld_os_7]"n"(sld_os[number<7>{}].value),    \
 
  354                 [s_tile_os_a]"s"(tile_offset_a_bytes),    \
 
  355                 [s_tile_os_b]"s"(tile_offset_b_bytes)
 
  357 #define _EXPAND_ASM_ARGS_CLOBBER     \ 
  358           "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",    \
 
  359           "a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",    \
 
  360           "a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",    \
 
  361           "a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",    \
 
  362           "a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49",    \
 
  363           "a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59",    \
 
  364           "a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69",    \
 
  365           "a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79",    \
 
  366           "a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89",    \
 
  367           "a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99",    \
 
  368           "a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107",    \
 
  369           "a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115",    \
 
  370           "a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123",    \
 
  371           "a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131",    \
 
  372           "a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139",    \
 
  373           "a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147",    \
 
  374           "a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155",    \
 
  375           "a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163",    \
 
  376           "a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171",    \
 
  377           "a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179",    \
 
  378           "a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187",    \
 
  379           "a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195",    \
 
  380           "a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203",    \
 
  381           "a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211",    \
 
  382           "a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219",    \
 
  383           "a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227",    \
 
  384           "a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235",    \
 
  385           "a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",    \
 
  386           "a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",    \
 
  387           "a252", "a253", "a254", "a255",     \
 
  388           "s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23",    \
 
  390           "v64", "v65", "v66", "v67", "v68", "v69",                 \
 
  391           "v70", "v71", "v72", "v73", "v74", "v75", "v76", "v77", "v78", "v79",     \
 
  392           "v80", "v81", "v82", "v83", "v84", "v85", "v86", "v87", "v88", "v89",    \
 
  393           "v90", "v91", "v92", "v93", "v94", "v95", "v96", "v97", "v98", "v99",    \
 
  394           "v100", "v101", "v102", "v103", "v104", "v105", "v106", "v107",    \
 
  395           "v108", "v109", "v110", "v111", "v112", "v113", "v114", "v115",    \
 
  396           "v116", "v117", "v118", "v119", "v120", "v121", "v122", "v123",    \
 
  397           "v124", "v125", "v126", "v127"
 
  409     template <
typename ARes, 
typename ACoords, 
typename BRes, 
typename BCoords, 
bool Is2B = false>
 
  412                const ACoords& cached_coords_a,
 
  414                const BCoords& cached_coords_b,
 
  422         static_assert(BCoords::size() == 
Repeat_N);
 
  425             make_tensor_view<address_space_enum::lds>(
 
  432             constexpr 
auto a_outer_dstr_enc = tile_distribution_encoding<
 
  433                 sequence<WarpPerBlock_N>,
 
  434                 tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_K>>,
 
  435                 tuple<sequence<1, 0>>,
 
  436                 tuple<sequence<1, 0>>,
 
  439             constexpr 
auto a_block_dstr_encode =
 
  442                 make_tensor_view<address_space_enum::lds>(
 
  453         constexpr 
auto smem_buf_size =
 
  455         static_assert(a_sld.get_num_of_access() == 8);
 
  458                 return number<a_sld.get_bottom_linear_offset(i_access) * 
sizeof(
ADataType)>{};
 
  460             number<a_sld.get_num_of_access()>{});
 
  470 #pragma clang diagnostic push 
  471 #pragma clang diagnostic ignored "-Winline-asm" 
  474 #define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16 
  475 #define CK_TILE_FLATMM_UK_2B 1 
  479                     [s_res_b4]
"s"(res_b[4]), 
 
  480                     [s_res_b5]
"s"(res_b[5]),
 
  481                     [s_res_b6]
"s"(res_b[6]),
 
  482                     [s_res_b7]
"s"(res_b[7])
 
  486 #pragma clang diagnostic pop 
  490             for(
auto i = 0; i < 16; i++)
 
  492                 c.at(number<0>{}).get_thread_buffer()[4 * i + 0] = v_acc[i].x;
 
  493                 c.at(number<0>{}).get_thread_buffer()[4 * i + 1] = v_acc[i].y;
 
  494                 c.at(number<0>{}).get_thread_buffer()[4 * i + 2] = v_acc[i].z;
 
  495                 c.at(number<0>{}).get_thread_buffer()[4 * i + 3] = v_acc[i].w;
 
  497             for(
auto i = 0; i < 16; i++)
 
  499                 c.at(number<1>{}).get_thread_buffer()[4 * i + 0] = v_acc[16 + i].x;
 
  500                 c.at(number<1>{}).get_thread_buffer()[4 * i + 1] = v_acc[16 + i].y;
 
  501                 c.at(number<1>{}).get_thread_buffer()[4 * i + 2] = v_acc[16 + i].z;
 
  502                 c.at(number<1>{}).get_thread_buffer()[4 * i + 3] = v_acc[16 + i].w;
 
  512 #pragma clang diagnostic push 
  513 #pragma clang diagnostic ignored "-Winline-asm" 
  516 #define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16 
  523 #pragma clang diagnostic pop 
  527             for(
auto i = 0; i < 16; i++)
 
  529                 c.get_thread_buffer()[4 * i + 0] = v_acc[i].x;
 
  530                 c.get_thread_buffer()[4 * i + 1] = v_acc[i].y;
 
  531                 c.get_thread_buffer()[4 * i + 2] = v_acc[i].z;
 
  532                 c.get_thread_buffer()[4 * i + 3] = v_acc[i].w;
 
  546     template <
typename ARes, 
typename ACoords, 
typename BRes, 
typename BCoords, 
bool Is2B = false>
 
  549                const ACoords& cached_coords_a,
 
  551                const BCoords& cached_coords_b,
 
  559         static_assert(BCoords::size() == 
Repeat_N);
 
  562             make_tensor_view<address_space_enum::lds>(
 
  569             constexpr 
auto a_outer_dstr_enc = tile_distribution_encoding<
 
  570                 sequence<WarpPerBlock_N>,
 
  571                 tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_K>>,
 
  572                 tuple<sequence<1, 0>>,
 
  573                 tuple<sequence<1, 0>>,
 
  576             constexpr 
auto a_block_dstr_encode =
 
  579                 make_tensor_view<address_space_enum::lds>(
 
  590         constexpr 
auto smem_buf_size =
 
  592         static_assert(a_sld.get_num_of_access() == 8);
 
  595                 return number<a_sld.get_bottom_linear_offset(i_access) * 
sizeof(
ADataType)>{};
 
  597             number<a_sld.get_num_of_access()>{});
 
  607 #pragma clang diagnostic push 
  608 #pragma clang diagnostic ignored "-Winline-asm" 
  611 #define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16 
  612 #define CK_TILE_FLATMM_UK_2B 1 
  616                     [s_res_b4]
"s"(res_b[4]), 
 
  617                     [s_res_b5]
"s"(res_b[5]),
 
  618                     [s_res_b6]
"s"(res_b[6]),
 
  619                     [s_res_b7]
"s"(res_b[7])
 
  623 #pragma clang diagnostic pop 
  627             for(
auto i = 0; i < 16; i++)
 
  629                 c.at(number<0>{}).get_thread_buffer()[4 * i + 0] = v_acc[i].x;
 
  630                 c.at(number<0>{}).get_thread_buffer()[4 * i + 1] = v_acc[i].y;
 
  631                 c.at(number<0>{}).get_thread_buffer()[4 * i + 2] = v_acc[i].z;
 
  632                 c.at(number<0>{}).get_thread_buffer()[4 * i + 3] = v_acc[i].w;
 
  634             for(
auto i = 0; i < 16; i++)
 
  636                 c.at(number<1>{}).get_thread_buffer()[4 * i + 0] = v_acc[16 + i].x;
 
  637                 c.at(number<1>{}).get_thread_buffer()[4 * i + 1] = v_acc[16 + i].y;
 
  638                 c.at(number<1>{}).get_thread_buffer()[4 * i + 2] = v_acc[16 + i].z;
 
  639                 c.at(number<1>{}).get_thread_buffer()[4 * i + 3] = v_acc[16 + i].w;
 
  649 #pragma clang diagnostic push 
  650 #pragma clang diagnostic ignored "-Winline-asm" 
  653 #define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16 
  660 #pragma clang diagnostic pop 
  664             for(
auto i = 0; i < 16; i++)
 
  666                 c.get_thread_buffer()[4 * i + 0] = v_acc[i].x;
 
  667                 c.get_thread_buffer()[4 * i + 1] = v_acc[i].y;
 
  668                 c.get_thread_buffer()[4 * i + 2] = v_acc[i].z;
 
  669                 c.get_thread_buffer()[4 * i + 3] = v_acc[i].w;
 
  675 #undef _EXPAND_ASM_ARGS_OUT_ONE_ACC 
  676 #undef _EXPAND_ASM_ARGS_OUT_TWO_ACC 
  677 #undef _EXPAND_ASM_ARGS_IN 
  678 #undef _EXPAND_ASM_ARGS_CLOBBER 
#define CK_TILE_DEVICE
Definition: config.hpp:41
 
#define CK_TILE_LDS_ADDR
Definition: config.hpp:58
 
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
 
#define _EXPAND_ASM_ARGS_OUT_TWO_ACC
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:279
 
#define _EXPAND_ASM_ARGS_CLOBBER
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:357
 
#define _EXPAND_ASM_ARGS_IN
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:315
 
#define _EXPAND_ASM_ARGS_OUT_ONE_ACC
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:259
 
constexpr CK_TILE_HOST_DEVICE auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition: tile_distribution_encoding.hpp:457
 
Definition: cluster_descriptor.hpp:13
 
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_descriptor.hpp:274
 
_Float16 fp16_t
Definition: half.hpp:110
 
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
 
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1615
 
int32_t index_t
Definition: integer.hpp:9
 
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1558
 
constant< v > number
Definition: integral_constant.hpp:37
 
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition: tensor_descriptor.hpp:203
 
CK_TILE_DEVICE auto get_async_store_smem_info(LdsTileWindow_ &&lds_tile)
Definition: tile_window_utils.hpp:31
 
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
 
constexpr CK_TILE_DEVICE auto make_tile_window_linear(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, LinearBottomDims_={})
Definition: tile_window_linear.hpp:993
 
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
 
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
 
float fp32x4_t
Definition: vector_type.hpp:128
 
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:480
 
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
 
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:27
 
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:401
 
bf16_t ADataType
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:402
 
bf16_t BDataType
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:403
 
CK_TILE_DEVICE auto operator()(const ARes &res_a, const ACoords &cached_coords_a, const BRes &res_b, const BCoords &cached_coords_b, CK_TILE_LDS_ADDR void *smem, index_t k, index_t tile_offset_a, index_t tile_offset_b, bool_constant< Is2B >={})
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:411
 
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:38
 
static constexpr index_t NumWarps
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:47
 
static constexpr index_t WarpPerBlock_N
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:44
 
static constexpr CK_TILE_DEVICE auto MakeCBlockTile()
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:172
 
static constexpr index_t SubKPacks
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:55
 
static constexpr index_t WarpPerBlock_K
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:45
 
static constexpr index_t Block_K
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:41
 
static constexpr index_t Block_N
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:40
 
static constexpr CK_TILE_HOST_DEVICE auto MakeLdsLoadDesc_A()
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:197
 
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:251
 
static constexpr index_t Block_W
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:58
 
static constexpr auto GetGemm_AWarpEnc()
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:234
 
static constexpr index_t Warp_M
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:49
 
static constexpr CK_TILE_HOST_DEVICE auto MakeLdsStoreDesc_A()
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:180
 
static constexpr index_t Block_M
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:39
 
static constexpr index_t Block_Kr
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:60
 
static constexpr index_t Repeat_K
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:64
 
static constexpr CK_TILE_DEVICE auto MakeCBlockDist()
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:154
 
static constexpr index_t Warp_N
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:50
 
static constexpr index_t Block_Nr
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:59
 
static constexpr index_t Warp_K
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:51
 
static constexpr index_t Repeat_M
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:62
 
static constexpr index_t BlockSize
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:53
 
static constexpr index_t WarpPerBlock_M
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:43
 
static constexpr index_t Repeat_N
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:63
 
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:540
 
CK_TILE_DEVICE auto operator()(const ARes &res_a, const ACoords &cached_coords_a, const BRes &res_b, const BCoords &cached_coords_b, CK_TILE_LDS_ADDR void *smem, index_t k, index_t tile_offset_a, index_t tile_offset_b, bool_constant< Is2B >={})
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:548
 
fp16_t ADataType
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:541
 
fp16_t BDataType
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:542
 
Definition: warp_gemm_impl.hpp:11
 
Definition: integral_constant.hpp:13
 
Definition: sequence.hpp:49
 
Definition: tile_distribution_encoding.hpp:26
 
Definition: tuple.hpp:192