include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp Source File

include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp Source File

Composable Kernel: include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp Source File
flatmm_32x512x128_1x4x1_16x16x32.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 namespace ck_tile {
11 
12 // A async load to LDS, B direct to AGPR
13 // B matrix preshuffled in br*kr*w
14 // requires 4 waves, occupancy=1c
15 // agpr usage: 256
16 // vgpr usage:64(A local) + 64(acc) + 8(os_a) + 8(os_b) = 144 (rem:112)
17 //
18 // for this gemm, 4 16x16x16 transposed layout
19 // input A vgpr layout
20 // v0-v15: [ 0:15](gemm_m)x128(gemm_k)
21 // v16-v31: [16:31](gemm_m)x128(gemm_k)
22 
23 // input B vgpr layout
24 // v0-v15: [ 0: 15](gemm_n)x128(gemm_k)
25 // v16-v31: [ 64: 79](gemm_n)x128(gemm_k)
26 // ......................
27 // v112-v127: [448:463](gemm_n)x128(gemm_k)
28 
29 // output C vgpr layout
30 // v0-v3 : [ 0:15](gemm_m)x[ 0: 15](gemm_n)
31 // v4-v7 : [16:31](gemm_m)x[ 0: 15](gemm_n)
32 // v8-v11: [ 0:15](gemm_m)x[64: 79](gemm_n)
33 // v12-v15: [16:31](gemm_m)x[64: 79](gemm_n)
34 // ......................
35 // v56-v59: [ 0:15](gemm_m)x[448:463](gemm_n)
36 // v60-v63: [16:31](gemm_m)x[448:463](gemm_n)
38 {
39  static constexpr index_t Block_M = 32;
40  static constexpr index_t Block_N = 512;
41  static constexpr index_t Block_K = 128;
42 
43  static constexpr index_t WarpPerBlock_M = 1;
44  static constexpr index_t WarpPerBlock_N = 4;
45  static constexpr index_t WarpPerBlock_K = 1;
46 
47  static constexpr index_t NumWarps = 4;
48 
49  static constexpr index_t Warp_M = 16;
50  static constexpr index_t Warp_N = 16;
51  static constexpr index_t Warp_K = 32; // 16 * SubKPacks
52 
53  static constexpr index_t BlockSize = 256;
54 
55  static constexpr index_t SubKPacks = 2; // this is used to gurantee every threads can do dwordx4
56 
57  // TODO: note Nr/Kr/W need consider SubKPacks
58  static constexpr index_t Block_W = Warp_N * Warp_K; // 512 element
59  static constexpr index_t Block_Nr = Block_N / Warp_N; // 32 element, 4 per wave
60  static constexpr index_t Block_Kr = Block_K / Warp_K; // 4
61 
62  static constexpr index_t Repeat_M = Block_M / (Warp_M * WarpPerBlock_M); // 2
63  static constexpr index_t Repeat_N = Block_N / (Warp_N * WarpPerBlock_N); // 8
64  static constexpr index_t Repeat_K = Block_K / (Warp_K * WarpPerBlock_K); // 8/2=4
65 
66  static CK_TILE_DEVICE constexpr auto MakeCBlockDist()
67  {
68  constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
69  sequence<>,
73  sequence<2, 1>, // !! note here is different
74  sequence<0, 0>>{};
75 
77 
78  constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
79  c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
80  constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
81  return c_block_dstr;
82  }
83 
84  static CK_TILE_DEVICE constexpr auto MakeCBlockTile()
85  {
86  using CDataType = float;
87  constexpr auto c_block_dstr = MakeCBlockDist();
88  auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
89  return c_block_tensor;
90  }
91 
92  CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A()
93  {
94  // A async->LDS
95  // constexpr index_t Block_M = Problem::BlockShape::Block_M0;
96  // constexpr index_t Block_K = Problem::BlockShape::Block_K0;
97  // constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
98  constexpr index_t warpSize = ck_tile::get_warp_size();
99  // constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
100 
101  constexpr index_t KPack_ = 8; // GetSmemKPack_A<Problem>(); // LDS
102  constexpr index_t KVector = 2; // GetAlignment_A<Problem>(); // async copy 1 dword
103  constexpr index_t KPad = KPack_; // pad between warps
104 
105  static_assert(Block_K % KVector == 0);
106  constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
107  if constexpr(LanesPerK >= warpSize)
108  {
109  // need multiple waves to load K
110  static_assert(LanesPerK % warpSize == 0);
111  constexpr index_t wavesPerK = LanesPerK / warpSize;
112  if constexpr(wavesPerK > NumWarps)
113  {
114  // TODO: need multiple issues along K to load all data
115  }
116  else
117  {
118  constexpr index_t wavesPerM = NumWarps / wavesPerK;
119  constexpr index_t NumIssues = Block_M / wavesPerM;
120  constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
122  number<wavesPerM>{}, // m1
123  number<wavesPerK>{}, // k0
124  number<warpSize>{}, // k1
125  number<KVector>{}), // k2
126  make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
127  number<wavesPerK*(warpSize * KVector + KPad)>{}, // m1
129  number<KVector>{}, // k1
130  number<1>{}), // k2
131  number<KVector>{}, // lds store vector(actually no explicit store)
132  number<1>{});
133 
134  constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
135  lds_block_desc_0,
136  make_tuple(
142 
143  return lds_block_desc_issues_warps_lanes;
144  }
145  }
146  else
147  {
148  // lanes within a wave load different M but same K
149  static_assert(warpSize % LanesPerK == 0);
150  constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
151  constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps);
152 
153  constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
155  number<LaneGroups>{}, // m1
156  number<NumWarps>{}, // m2
157  number<LanesPerK>{}, // k0
158  number<KVector>{}), // k1
159  make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
160  number<Block_K>{}, // m1
162  number<KVector>{}, // k0
163  number<1>{}), // k1
164  number<KVector>{}, // lds store vector(actually no explicit store)
165  number<1>{});
166 
167  constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
168  lds_block_desc_0,
175 
176  return lds_block_desc_issues_warps_lanes;
177  }
178  }
179 
180  // template <typename Problem>
181  CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadDesc_A()
182  {
183  // load from LDS to register, every wave has same layout
184  constexpr index_t KPack_ = 8; // GetSmemKPack_A<Problem>(); // LDS
185  constexpr index_t KPad = KPack_; // pad between warps
186 
187  constexpr index_t kAMLane = 16;
188  constexpr index_t kABKLane = 4;
189  constexpr index_t kABKPerLane = 4;
190  constexpr index_t kKIter = 2;
191  static_assert(KPack_ == (kABKPerLane * kKIter));
192 
193  constexpr auto lds_block_desc_0 =
195  number<kAMLane>{}, // m1 p
196  number<Repeat_K>{}, // k0 y
197  number<kABKLane>{}, // k1 p
198  number<KPack_>{}), // k2 y-vector
199  make_tuple(number<kAMLane*(Block_K + KPad)>{}, // m0
200  number<Block_K + KPad>{}, // m1
202  number<KPack_>{}, // k1
203  number<1>{}), // k2
204  number<KPack_>{}, // lds load vector
205  number<1>{});
206 
207  constexpr auto lds_desc_m_k = transform_tensor_descriptor(
208  lds_block_desc_0,
214 
215  return lds_desc_m_k;
216  }
217 
218  static constexpr auto GetGemm_AWarpEnc()
219  {
220  constexpr index_t kAMLane = 16;
221  constexpr index_t kABKLane = 4;
222  constexpr index_t kABKPerLane = 4;
223  constexpr index_t kKIter = 2;
224 
225  using enc_ = tile_distribution_encoding<
226  sequence<>,
230  sequence<2>,
231  sequence<1>>;
232  return enc_{};
233  }
234 
236  {
237  // return 32 * (128 + 8) * sizeof(bf16_t);
238  return MakeLdsLoadDesc_A().get_element_space_size() * sizeof(bf16_t) * 2; // 2 lds buffers
239  }
240 };
241 
242 // clang-format off
// Read-write ("+") output operands for the flatmm inline-asm statement.
// Binds the loop counter `loop_cnt` to an SGPR ("s"), the 16 fp32x4
// accumulators v_acc[0..15] to VGPRs ("v"), and the LDS pointer `smem`
// through the generic "r" register constraint.
// Requires `loop_cnt`, `v_acc` (at least 16 elements) and `smem` to be in
// scope wherever it is expanded.
// NOTE(review): presumably used by the single-B (Is2B == false) path, which
// declares `fp32x4_t v_acc[16]`; confirm against the asm call site.
243 #define _EXPAND_ASM_ARGS_OUT_ONE_ACC \
 244  [s_loop_cnt]"+s"(loop_cnt), \
245  [v_acc_0]"+v"(v_acc[0]), \
246  [v_acc_1]"+v"(v_acc[1]), \
247  [v_acc_2]"+v"(v_acc[2]), \
248  [v_acc_3]"+v"(v_acc[3]), \
249  [v_acc_4]"+v"(v_acc[4]), \
250  [v_acc_5]"+v"(v_acc[5]), \
251  [v_acc_6]"+v"(v_acc[6]), \
252  [v_acc_7]"+v"(v_acc[7]), \
253  [v_acc_8]"+v"(v_acc[8]), \
254  [v_acc_9]"+v"(v_acc[9]), \
255  [v_acc_10]"+v"(v_acc[10]), \
256  [v_acc_11]"+v"(v_acc[11]), \
257  [v_acc_12]"+v"(v_acc[12]), \
258  [v_acc_13]"+v"(v_acc[13]), \
259  [v_acc_14]"+v"(v_acc[14]), \
260  [v_acc_15]"+v"(v_acc[15]), \
261  [s_mem_]"+r"(smem)
262 
// Read-write ("+") output operands for the two-accumulator-set variant.
// Identical to _EXPAND_ASM_ARGS_OUT_ONE_ACC except that it binds 32 fp32x4
// accumulators v_acc[0..31] to VGPRs instead of 16.
// Requires `loop_cnt`, `v_acc` (at least 32 elements) and `smem` to be in
// scope wherever it is expanded.
// NOTE(review): presumably used by the Is2B path, which declares
// `fp32x4_t v_acc[32]` and later copies v_acc[0..15] / v_acc[16..31] into
// two separate C tiles; confirm against the asm call site.
263 #define _EXPAND_ASM_ARGS_OUT_TWO_ACC \
 264  [s_loop_cnt]"+s"(loop_cnt), \
265  [v_acc_0]"+v"(v_acc[0]), \
266  [v_acc_1]"+v"(v_acc[1]), \
267  [v_acc_2]"+v"(v_acc[2]), \
268  [v_acc_3]"+v"(v_acc[3]), \
269  [v_acc_4]"+v"(v_acc[4]), \
270  [v_acc_5]"+v"(v_acc[5]), \
271  [v_acc_6]"+v"(v_acc[6]), \
272  [v_acc_7]"+v"(v_acc[7]), \
273  [v_acc_8]"+v"(v_acc[8]), \
274  [v_acc_9]"+v"(v_acc[9]), \
275  [v_acc_10]"+v"(v_acc[10]), \
276  [v_acc_11]"+v"(v_acc[11]), \
277  [v_acc_12]"+v"(v_acc[12]), \
278  [v_acc_13]"+v"(v_acc[13]), \
279  [v_acc_14]"+v"(v_acc[14]), \
280  [v_acc_15]"+v"(v_acc[15]), \
281  [v_acc_16]"+v"(v_acc[16]), \
282  [v_acc_17]"+v"(v_acc[17]), \
283  [v_acc_18]"+v"(v_acc[18]), \
284  [v_acc_19]"+v"(v_acc[19]), \
285  [v_acc_20]"+v"(v_acc[20]), \
286  [v_acc_21]"+v"(v_acc[21]), \
287  [v_acc_22]"+v"(v_acc[22]), \
288  [v_acc_23]"+v"(v_acc[23]), \
289  [v_acc_24]"+v"(v_acc[24]), \
290  [v_acc_25]"+v"(v_acc[25]), \
291  [v_acc_26]"+v"(v_acc[26]), \
292  [v_acc_27]"+v"(v_acc[27]), \
293  [v_acc_28]"+v"(v_acc[28]), \
294  [v_acc_29]"+v"(v_acc[29]), \
295  [v_acc_30]"+v"(v_acc[30]), \
296  [v_acc_31]"+v"(v_acc[31]), \
297  [s_mem_]"+r"(smem)
298 
// Input operands for the flatmm inline-asm statement. The expansion site
// must have all of the following in scope (operator() sets them up):
//   - res_a / res_b: resource words 0..3 of each are passed in SGPRs ("s").
//     The Is2B path passes res_b[4..7] as additional operands.
//   - cached_coords_a / cached_coords_b: 8 cached per-thread offsets each.
//     They are scaled to bytes by sizeof(ADataType) / sizeof(BDataType) and
//     passed in VGPRs ("v").
//   - a_sld: the byte offset of its first cached coordinate is passed as
//     v_os_slda.
//   - m0_init_value and size_per_issue: both are passed in SGPRs.
//   - tile_offset_a_bytes and tile_offset_b_bytes: per-unroll tile strides
//     in bytes, passed in SGPRs.
// smem_sz and sld_os_0..7 use the "n" (immediate integer) constraint.
// smem_buf_size and every element of sld_os must therefore be
// compile-time constants.
299 #define _EXPAND_ASM_ARGS_IN \
 300  [s_res_a0]"s"(res_a[0]), \
301  [s_res_a1]"s"(res_a[1]), \
302  [s_res_a2]"s"(res_a[2]), \
303  [s_res_a3]"s"(res_a[3]), \
304  [s_res_b0]"s"(res_b[0]), \
305  [s_res_b1]"s"(res_b[1]), \
306  [s_res_b2]"s"(res_b[2]), \
307  [s_res_b3]"s"(res_b[3]), \
308  [v_os_a0]"v"(static_cast<index_t>(cached_coords_a[number<0>{}] * sizeof(ADataType))), \
309  [v_os_a1]"v"(static_cast<index_t>(cached_coords_a[number<1>{}] * sizeof(ADataType))), \
310  [v_os_a2]"v"(static_cast<index_t>(cached_coords_a[number<2>{}] * sizeof(ADataType))), \
311  [v_os_a3]"v"(static_cast<index_t>(cached_coords_a[number<3>{}] * sizeof(ADataType))), \
312  [v_os_a4]"v"(static_cast<index_t>(cached_coords_a[number<4>{}] * sizeof(ADataType))), \
313  [v_os_a5]"v"(static_cast<index_t>(cached_coords_a[number<5>{}] * sizeof(ADataType))), \
314  [v_os_a6]"v"(static_cast<index_t>(cached_coords_a[number<6>{}] * sizeof(ADataType))), \
315  [v_os_a7]"v"(static_cast<index_t>(cached_coords_a[number<7>{}] * sizeof(ADataType))), \
316  \
317  [v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))), \
318  [v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))), \
319  [v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))), \
320  [v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))), \
321  [v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))), \
322  [v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))), \
323  [v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))), \
324  [v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))), \
325  \
326  [v_os_slda]"v"(static_cast<index_t>(a_sld.cached_coords_[number<0>{}].get_offset() * sizeof(ADataType))),\
327  [s_m0_init]"s"(m0_init_value), \
328  [s_size_per_issue]"s"(size_per_issue), \
329  [smem_sz]"n"(smem_buf_size), \
330  [sld_os_0]"n"(sld_os[number<0>{}].value), \
331  [sld_os_1]"n"(sld_os[number<1>{}].value), \
332  [sld_os_2]"n"(sld_os[number<2>{}].value), \
333  [sld_os_3]"n"(sld_os[number<3>{}].value), \
334  [sld_os_4]"n"(sld_os[number<4>{}].value), \
335  [sld_os_5]"n"(sld_os[number<5>{}].value), \
336  [sld_os_6]"n"(sld_os[number<6>{}].value), \
337  [sld_os_7]"n"(sld_os[number<7>{}].value), \
338  [s_tile_os_a]"s"(tile_offset_a_bytes), \
339  [s_tile_os_b]"s"(tile_offset_b_bytes)
340 
// Clobber list for the flatmm inline-asm statement. It covers:
//   - "memory"
//   - all 256 AGPRs, a0-a255
//   - SGPRs s16-s23 and s86
//   - VGPRs v64-v127
// Listing these registers tells the compiler the asm may overwrite them.
// The compiler therefore will not keep live values in them across the
// statement. The Is2B path also clobbers s24-s27, which it appends after
// this macro.
341 #define _EXPAND_ASM_ARGS_CLOBBER \
 342  "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9", \
343  "a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19", \
344  "a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29", \
345  "a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39", \
346  "a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49", \
347  "a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59", \
348  "a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69", \
349  "a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79", \
350  "a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89", \
351  "a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99", \
352  "a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107", \
353  "a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115", \
354  "a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123", \
355  "a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131", \
356  "a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139", \
357  "a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147", \
358  "a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155", \
359  "a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163", \
360  "a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171", \
361  "a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179", \
362  "a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187", \
363  "a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195", \
364  "a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203", \
365  "a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211", \
366  "a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219", \
367  "a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227", \
368  "a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235", \
369  "a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243", \
370  "a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251", \
371  "a252", "a253", "a254", "a255", \
372  "s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23", \
373  "s86", \
374  "v64", "v65", "v66", "v67", "v68", "v69", \
375  "v70", "v71", "v72", "v73", "v74", "v75", "v76", "v77", "v78", "v79", \
376  "v80", "v81", "v82", "v83", "v84", "v85", "v86", "v87", "v88", "v89", \
377  "v90", "v91", "v92", "v93", "v94", "v95", "v96", "v97", "v98", "v99", \
378  "v100", "v101", "v102", "v103", "v104", "v105", "v106", "v107", \
379  "v108", "v109", "v110", "v111", "v112", "v113", "v114", "v115", \
380  "v116", "v117", "v118", "v119", "v120", "v121", "v122", "v123", \
381  "v124", "v125", "v126", "v127"
382 // clang-format on
383 
385 {
386  using ADataType = bf16_t;
387  using BDataType = bf16_t;
388 
389  // TODO: need paired with tile_window_linear!
390  // TODO: need call init_raw() before call this function!
391  // Is2B: originally for B matrix we have 2 prefetch buffers. If set this to true
392  // we can support A matric serve 2 B matrix, B0/B1, each B0/B1 still have same tile size
393  template <typename ARes, typename ACoords, typename BRes, typename BCoords, bool Is2B = false>
394  CK_TILE_DEVICE auto
395  operator()(const ARes& res_a,
396  const ACoords& cached_coords_a,
397  const BRes& res_b,
398  const BCoords& cached_coords_b,
399  CK_TILE_LDS_ADDR void* smem,
400  index_t k,
401  index_t tile_offset_a, // for each tile, the offset to move for each unroll
402  index_t tile_offset_b,
403  bool_constant<Is2B> = {}) // for each tile, the offset to move for each unroll
404  {
405  static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 2 /*2x per dword*/); // 8
406  static_assert(BCoords::size() == Repeat_N);
407 
408  auto a_sst = make_tile_window(
409  make_tensor_view<address_space_enum::lds>(
410  reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsStoreDesc_A()),
411  MakeLdsStoreDesc_A().get_lengths(),
412  {0, 0, 0});
413 
414  auto a_sld = [&]() {
415  constexpr auto a_warp_enc_ = GetGemm_AWarpEnc();
416  constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
417  sequence<WarpPerBlock_N>,
418  tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_K>>,
419  tuple<sequence<1, 0>>,
420  tuple<sequence<1, 0>>,
421  sequence<1, 2>,
422  sequence<0, 0>>{};
423  constexpr auto a_block_dstr_encode =
424  detail::make_embed_tile_distribution_encoding(a_outer_dstr_enc, a_warp_enc_);
426  make_tensor_view<address_space_enum::lds>(
427  reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsLoadDesc_A()),
428  MakeLdsLoadDesc_A().get_lengths(),
429  {0, 0},
430  make_static_tile_distribution(a_block_dstr_encode));
431  }();
432 
433  const index_t tile_offset_a_bytes = tile_offset_a * sizeof(ADataType);
434  const index_t tile_offset_b_bytes = tile_offset_b * sizeof(BDataType);
435 
436  const auto [m0_init_value, size_per_issue] = get_async_store_smem_info(a_sst);
437  constexpr auto smem_buf_size =
438  MakeLdsLoadDesc_A().get_element_space_size() * sizeof(ADataType);
439  static_assert(a_sld.get_num_of_access() == 8);
440  constexpr auto sld_os = generate_tuple(
441  [&](auto i_access) {
442  return number<a_sld.get_bottom_linear_offset(i_access) * sizeof(ADataType)>{};
443  },
444  number<a_sld.get_num_of_access()>{});
445 
446  index_t loop_cnt = k / Block_K;
447 
448  if constexpr(Is2B)
449  {
450  // this is the acc thread buffer
451  fp32x4_t v_acc[32]{.0f};
452 
453  // B nr->kr
454 #pragma clang diagnostic push
455 #pragma clang diagnostic ignored "-Winline-asm"
456  // clang-format off
457  asm volatile(
458 #define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
459 #define CK_TILE_FLATMM_UK_2B 1
463  [s_res_b4]"s"(res_b[4]),
464  [s_res_b5]"s"(res_b[5]),
465  [s_res_b6]"s"(res_b[6]),
466  [s_res_b7]"s"(res_b[7])
467  : _EXPAND_ASM_ARGS_CLOBBER, "s24", "s25", "s26", "s27"
468  );
469  // clang-format on
470 #pragma clang diagnostic pop
471 
472  // return local scratch
474  for(auto i = 0; i < 16; i++)
475  {
476  c.at(number<0>{}).get_thread_buffer()[4 * i + 0] = v_acc[i].x;
477  c.at(number<0>{}).get_thread_buffer()[4 * i + 1] = v_acc[i].y;
478  c.at(number<0>{}).get_thread_buffer()[4 * i + 2] = v_acc[i].z;
479  c.at(number<0>{}).get_thread_buffer()[4 * i + 3] = v_acc[i].w;
480  }
481  for(auto i = 0; i < 16; i++)
482  {
483  c.at(number<1>{}).get_thread_buffer()[4 * i + 0] = v_acc[16 + i].x;
484  c.at(number<1>{}).get_thread_buffer()[4 * i + 1] = v_acc[16 + i].y;
485  c.at(number<1>{}).get_thread_buffer()[4 * i + 2] = v_acc[16 + i].z;
486  c.at(number<1>{}).get_thread_buffer()[4 * i + 3] = v_acc[16 + i].w;
487  }
488  return c;
489  }
490  else
491  {
492  // this is the acc thread buffer
493  fp32x4_t v_acc[16]{.0f};
494 
495  // B nr->kr
496 #pragma clang diagnostic push
497 #pragma clang diagnostic ignored "-Winline-asm"
498  // clang-format off
499  asm volatile(
500 #define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
505  );
506  // clang-format on
507 #pragma clang diagnostic pop
508 
509  // return local scratch
510  auto c = MakeCBlockTile();
511  for(auto i = 0; i < 16; i++)
512  {
513  c.get_thread_buffer()[4 * i + 0] = v_acc[i].x;
514  c.get_thread_buffer()[4 * i + 1] = v_acc[i].y;
515  c.get_thread_buffer()[4 * i + 2] = v_acc[i].z;
516  c.get_thread_buffer()[4 * i + 3] = v_acc[i].w;
517  }
518  return c;
519  }
520  }
521 };
522 
524 {
525  using ADataType = fp16_t;
526  using BDataType = fp16_t;
527 
528  // TODO: need paired with tile_window_linear!
529  // TODO: need call init_raw() before call this function!
530  template <typename ARes, typename ACoords, typename BRes, typename BCoords, bool Is2B = false>
531  CK_TILE_DEVICE auto
532  operator()(const ARes& res_a,
533  const ACoords& cached_coords_a,
534  const BRes& res_b,
535  const BCoords& cached_coords_b,
536  CK_TILE_LDS_ADDR void* smem,
537  index_t k,
538  index_t tile_offset_a, // for each tile, the offset to move for each unroll
539  index_t tile_offset_b, // for each tile, the offset to move for each unroll
540  bool_constant<Is2B> = {})
541  {
542  static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 2 /*2x per dword*/); // 8
543  static_assert(BCoords::size() == Repeat_N);
544 
545  auto a_sst = make_tile_window(
546  make_tensor_view<address_space_enum::lds>(
547  reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsStoreDesc_A()),
548  MakeLdsStoreDesc_A().get_lengths(),
549  {0, 0, 0});
550 
551  auto a_sld = [&]() {
552  constexpr auto a_warp_enc_ = GetGemm_AWarpEnc();
553  constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
554  sequence<WarpPerBlock_N>,
555  tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_K>>,
556  tuple<sequence<1, 0>>,
557  tuple<sequence<1, 0>>,
558  sequence<1, 2>,
559  sequence<0, 0>>{};
560  constexpr auto a_block_dstr_encode =
561  detail::make_embed_tile_distribution_encoding(a_outer_dstr_enc, a_warp_enc_);
563  make_tensor_view<address_space_enum::lds>(
564  reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsLoadDesc_A()),
565  MakeLdsLoadDesc_A().get_lengths(),
566  {0, 0},
567  make_static_tile_distribution(a_block_dstr_encode));
568  }();
569 
570  const index_t tile_offset_a_bytes = tile_offset_a * sizeof(ADataType);
571  const index_t tile_offset_b_bytes = tile_offset_b * sizeof(BDataType);
572 
573  const auto [m0_init_value, size_per_issue] = get_async_store_smem_info(a_sst);
574  constexpr auto smem_buf_size =
575  MakeLdsLoadDesc_A().get_element_space_size() * sizeof(ADataType);
576  static_assert(a_sld.get_num_of_access() == 8);
577  constexpr auto sld_os = generate_tuple(
578  [&](auto i_access) {
579  return number<a_sld.get_bottom_linear_offset(i_access) * sizeof(ADataType)>{};
580  },
581  number<a_sld.get_num_of_access()>{});
582 
583  index_t loop_cnt = k / Block_K;
584 
585  if constexpr(Is2B)
586  {
587  // this is the acc thread buffer
588  fp32x4_t v_acc[32]{.0f};
589 
590  // B nr->kr
591 #pragma clang diagnostic push
592 #pragma clang diagnostic ignored "-Winline-asm"
593  // clang-format off
594  asm volatile(
595 #define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
596 #define CK_TILE_FLATMM_UK_2B 1
600  [s_res_b4]"s"(res_b[4]),
601  [s_res_b5]"s"(res_b[5]),
602  [s_res_b6]"s"(res_b[6]),
603  [s_res_b7]"s"(res_b[7])
604  : _EXPAND_ASM_ARGS_CLOBBER, "s24", "s25", "s26", "s27"
605  );
606  // clang-format on
607 #pragma clang diagnostic pop
608 
609  // return local scratch
611  for(auto i = 0; i < 16; i++)
612  {
613  c.at(number<0>{}).get_thread_buffer()[4 * i + 0] = v_acc[i].x;
614  c.at(number<0>{}).get_thread_buffer()[4 * i + 1] = v_acc[i].y;
615  c.at(number<0>{}).get_thread_buffer()[4 * i + 2] = v_acc[i].z;
616  c.at(number<0>{}).get_thread_buffer()[4 * i + 3] = v_acc[i].w;
617  }
618  for(auto i = 0; i < 16; i++)
619  {
620  c.at(number<1>{}).get_thread_buffer()[4 * i + 0] = v_acc[16 + i].x;
621  c.at(number<1>{}).get_thread_buffer()[4 * i + 1] = v_acc[16 + i].y;
622  c.at(number<1>{}).get_thread_buffer()[4 * i + 2] = v_acc[16 + i].z;
623  c.at(number<1>{}).get_thread_buffer()[4 * i + 3] = v_acc[16 + i].w;
624  }
625  return c;
626  }
627  else
628  {
629  // this is the acc thread buffer
630  fp32x4_t v_acc[16]{.0f};
631 
632  // B nr->kr
633 #pragma clang diagnostic push
634 #pragma clang diagnostic ignored "-Winline-asm"
635  // clang-format off
636  asm volatile(
637 #define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
642  );
643  // clang-format on
644 #pragma clang diagnostic pop
645 
646  // return local scratch
647  auto c = MakeCBlockTile();
648  for(auto i = 0; i < 16; i++)
649  {
650  c.get_thread_buffer()[4 * i + 0] = v_acc[i].x;
651  c.get_thread_buffer()[4 * i + 1] = v_acc[i].y;
652  c.get_thread_buffer()[4 * i + 2] = v_acc[i].z;
653  c.get_thread_buffer()[4 * i + 3] = v_acc[i].w;
654  }
655  return c;
656  }
657  }
658 };
// Undefine the asm operand and clobber helper macros defined earlier in
// this header. This keeps them from leaking into code that includes it.
659 #undef _EXPAND_ASM_ARGS_OUT_ONE_ACC
660 #undef _EXPAND_ASM_ARGS_OUT_TWO_ACC
661 #undef _EXPAND_ASM_ARGS_IN
662 #undef _EXPAND_ASM_ARGS_CLOBBER
663 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_LDS_ADDR
Definition: config.hpp:56
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
#define _EXPAND_ASM_ARGS_OUT_TWO_ACC
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:263
#define _EXPAND_ASM_ARGS_CLOBBER
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:341
#define _EXPAND_ASM_ARGS_IN
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:299
#define _EXPAND_ASM_ARGS_OUT_ONE_ACC
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:243
constexpr CK_TILE_HOST_DEVICE auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition: tile_distribution_encoding.hpp:420
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE index_t get_warp_size()
Definition: arch.hpp:51
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_descriptor.hpp:255
_Float16 fp16_t
Definition: half.hpp:110
bfloat16_t bf16_t
Definition: bfloat16.hpp:106
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1672
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1615
constant< v > number
Definition: integral_constant.hpp:33
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition: tensor_descriptor.hpp:184
CK_TILE_DEVICE auto get_async_store_smem_info(LdsTileWindow_ &&lds_tile)
Definition: tile_window_utils.hpp:24
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:72
constexpr CK_TILE_DEVICE auto make_tile_window_linear(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, LinearBottomDims_={})
Definition: tile_window_linear.hpp:1124
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:400
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
WarpGemmImpl< WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution< WarpGemmAttributeMfmaImplF16F16F32M16N16K16< WGAttrCtlEnum::Default_ >, 2 > > WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution
Definition: warp_gemm.hpp:52
float fp32x4_t
Definition: vector_type.hpp:87
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:498
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:385
bf16_t ADataType
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:386
bf16_t BDataType
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:387
CK_TILE_DEVICE auto operator()(const ARes &res_a, const ACoords &cached_coords_a, const BRes &res_b, const BCoords &cached_coords_b, CK_TILE_LDS_ADDR void *smem, index_t k, index_t tile_offset_a, index_t tile_offset_b, bool_constant< Is2B >={})
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:395
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:38
static constexpr index_t NumWarps
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:47
static constexpr index_t WarpPerBlock_N
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:44
static constexpr CK_TILE_DEVICE auto MakeCBlockTile()
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:84
static constexpr index_t SubKPacks
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:55
static constexpr index_t WarpPerBlock_K
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:45
static constexpr index_t Block_K
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:41
static constexpr index_t Block_N
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:40
static constexpr CK_TILE_HOST_DEVICE auto MakeLdsLoadDesc_A()
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:181
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:235
static constexpr index_t Block_W
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:58
static constexpr auto GetGemm_AWarpEnc()
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:218
static constexpr index_t Warp_M
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:49
static constexpr CK_TILE_HOST_DEVICE auto MakeLdsStoreDesc_A()
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:92
static constexpr index_t Block_M
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:39
static constexpr index_t Block_Kr
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:60
static constexpr index_t Repeat_K
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:64
static constexpr CK_TILE_DEVICE auto MakeCBlockDist()
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:66
static constexpr index_t Warp_N
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:50
static constexpr index_t Block_Nr
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:59
static constexpr index_t Warp_K
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:51
static constexpr index_t Repeat_M
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:62
static constexpr index_t BlockSize
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:53
static constexpr index_t WarpPerBlock_M
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:43
static constexpr index_t Repeat_N
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:63
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:524
CK_TILE_DEVICE auto operator()(const ARes &res_a, const ACoords &cached_coords_a, const BRes &res_b, const BCoords &cached_coords_b, CK_TILE_LDS_ADDR void *smem, index_t k, index_t tile_offset_a, index_t tile_offset_b, bool_constant< Is2B >={})
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:532
fp16_t ADataType
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:525
fp16_t BDataType
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:526
Definition: integral_constant.hpp:13
Definition: sequence.hpp:52
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192