/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp Source File
flatmm_32x512x128_1x4x1_16x16x32.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 namespace ck_tile {
11 
12 // A: async load to LDS; B: loaded directly into AGPRs
13 // B matrix preshuffled in br*kr*w
14 // requires 4 waves, occupancy=1c
15 // agpr usage: 256
16 // vgpr usage: 64(A local) + 64(acc) + 8(os_a) + 8(os_b) = 144 (rem: 112)
17 //
18 // for this gemm, 4 16x16x16 transposed layout
19 // input A vgpr layout
20 // v0-v15: [ 0:15](gemm_m)x128(gemm_k)
21 // v16-v31: [16:31](gemm_m)x128(gemm_k)
22 
23 // input B register layout (16 registers per 16-wide N slice, 8 slices)
24 // v0-v15: [ 0: 15](gemm_n)x128(gemm_k)
25 // v16-v31: [ 64: 79](gemm_n)x128(gemm_k)
26 // ......................
27 // v112-v127: [448:463](gemm_n)x128(gemm_k)
28 
29 // output C vgpr layout
30 // v0-v3 : [ 0:15](gemm_m)x[ 0: 15](gemm_n)
31 // v4-v7 : [16:31](gemm_m)x[ 0: 15](gemm_n)
32 // v8-v11: [ 0:15](gemm_m)x[64: 79](gemm_n)
33 // v12-v15: [16:31](gemm_m)x[64: 79](gemm_n)
34 // ......................
35 // v56-v59: [ 0:15](gemm_m)x[448:463](gemm_n)
36 // v60-v63: [16:31](gemm_m)x[448:463](gemm_n)
38 {
39  static constexpr index_t Block_M = 32;
40  static constexpr index_t Block_N = 512;
41  static constexpr index_t Block_K = 128;
42 
43  static constexpr index_t WarpPerBlock_M = 1;
44  static constexpr index_t WarpPerBlock_N = 4;
45  static constexpr index_t WarpPerBlock_K = 1;
46 
47  static constexpr index_t NumWarps = 4;
48 
49  static constexpr index_t Warp_M = 16;
50  static constexpr index_t Warp_N = 16;
51  static constexpr index_t Warp_K = 32; // 16 * SubKPacks
52 
53  static constexpr index_t BlockSize = 256;
54 
55  static constexpr index_t SubKPacks = 2; // this is used to gurantee every threads can do dwordx4
56 
57  // TODO: note Nr/Kr/W need consider SubKPacks
58  static constexpr index_t Block_W = Warp_N * Warp_K; // 512 element
59  static constexpr index_t Block_Nr = Block_N / Warp_N; // 32 element, 4 per wave
60  static constexpr index_t Block_Kr = Block_K / Warp_K; // 4
61 
62  static constexpr index_t Repeat_M = Block_M / (Warp_M * WarpPerBlock_M); // 2
63  static constexpr index_t Repeat_N = Block_N / (Warp_N * WarpPerBlock_N); // 8
64  static constexpr index_t Repeat_K = Block_K / (Warp_K * WarpPerBlock_K); // 8/2=4
65 
66  private:
// LdsStoreDescSelector<LanesPerK, WarpSize>::MakeDesc<NumWarps, Block_M, Block_K, KVector, KPad>()
// builds the LDS descriptor used as the destination of the A-tile store window.
// LanesPerK = number of lanes needed to cover one Block_K row at KVector elements per lane.
// The two partial specializations are chosen by comparing LanesPerK with WarpSize.
67  template <index_t LanesPerK, index_t WarpSize, typename = void>
68  struct LdsStoreDescSelector;
69 
// Case LanesPerK >= WarpSize: one K row spans wavesPerK whole waves, and the remaining
// waves (wavesPerM = NumWarps / wavesPerK) are spread along M.
// Raw descriptor dims: m0 = issue (NumIssues), m1 = wave along M (wavesPerM),
// k0 = wave along K (wavesPerK), k1 = lane (WarpSize), k2 = element in a lane (KVector).
// Each wave's WarpSize * KVector elements are followed by KPad padding elements.
// The returned descriptor is 3-D:
// [issue, merged(wavesPerM, wavesPerK), merged(WarpSize, KVector)].
70  template <index_t LanesPerK, index_t WarpSize>
71  struct LdsStoreDescSelector<LanesPerK, WarpSize, std::enable_if_t<(LanesPerK >= WarpSize)>>
72  {
73  template <index_t NumWarps, index_t Block_M, index_t Block_K, index_t KVector, index_t KPad>
74  static CK_TILE_HOST_DEVICE constexpr auto MakeDesc()
75  {
76  // need multiple waves to load K
77  static_assert(LanesPerK % WarpSize == 0);
78  constexpr index_t wavesPerK = LanesPerK / WarpSize;
79  if constexpr(wavesPerK > NumWarps)
80  {
// NOTE(review): this branch is empty, so MakeDesc() deduces a void return type when
// wavesPerK > NumWarps. Any caller that uses the result would then fail with an unrelated
// error. A static_assert with a message here would make the limitation explicit.
81  // TODO: need multiple issues along K to load all data
82  }
83  else
84  {
85  constexpr index_t wavesPerM = NumWarps / wavesPerK;
// NOTE(review): integer division; assumes Block_M is a multiple of wavesPerM (not checked).
86  constexpr index_t NumIssues = Block_M / wavesPerM;
87  constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
89  number<wavesPerM>{}, // m1
90  number<wavesPerK>{}, // k0
91  number<WarpSize>{}, // k1
92  number<KVector>{}), // k2
93  make_tuple(number<NumWarps*(WarpSize * KVector + KPad)>{}, // m0
94  number<wavesPerK*(WarpSize * KVector + KPad)>{}, // m1
95  number<WarpSize * KVector + KPad>{}, // k0
96  number<KVector>{}, // k1
97  number<1>{}), // k2
98  number<KVector>{}, // lds store vector(actually no explicit store)
99  number<1>{});
100 
// Collapse the 5 raw dims into (issue, warp, lane-element) for the async store window.
101  constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
102  lds_block_desc_0,
103  make_tuple(
104  make_pass_through_transform(number<NumIssues>{}),
105  make_merge_transform(make_tuple(number<wavesPerM>{}, number<wavesPerK>{})),
106  make_merge_transform(make_tuple(number<WarpSize>{}, number<KVector>{}))),
107  make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}),
108  make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
109 
110  return lds_block_desc_issues_warps_lanes;
111  }
112  }
113  };
114 
// Case LanesPerK < WarpSize: one wave covers LaneGroups = WarpSize / LanesPerK rows of M,
// and each row is covered by LanesPerK lanes x KVector elements.
// Raw descriptor dims: m0 = issue (NumIssues), m1 = lane group within a wave (row stride
// Block_K), m2 = wave (NumWarps), k0 = lane within a row (LanesPerK), k1 = element (KVector).
// Each wave's WarpSize * KVector elements are followed by KPad padding elements.
// The returned descriptor is 3-D: [issue, wave, combined (LaneGroups, LanesPerK, KVector)].
115  template <index_t LanesPerK, index_t WarpSize>
116  struct LdsStoreDescSelector<LanesPerK, WarpSize, std::enable_if_t<(LanesPerK < WarpSize)>>
117  {
118  template <index_t NumWarps, index_t Block_M, index_t Block_K, index_t KVector, index_t KPad>
119  static CK_TILE_HOST_DEVICE constexpr auto MakeDesc()
120  {
121  // lanes within a wave load different M but same K
122  static_assert(WarpSize % LanesPerK == 0);
123  constexpr index_t LaneGroups = WarpSize / LanesPerK; // along m
// NOTE(review): integer division; assumes Block_M is a multiple of LaneGroups * NumWarps
// (not checked).
124  constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps);
125 
126  constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
127  make_tuple(number<NumIssues>{}, // m0
128  number<LaneGroups>{}, // m1
129  number<NumWarps>{}, // m2
130  number<LanesPerK>{}, // k0
131  number<KVector>{}), // k1
132  make_tuple(number<NumWarps*(WarpSize * KVector + KPad)>{}, // m0
133  number<Block_K>{}, // m1
134  number<WarpSize * KVector + KPad>{}, // m2
135  number<KVector>{}, // k0
136  number<1>{}), // k1
137  number<KVector>{}, // lds store vector(actually no explicit store)
138  number<1>{});
139 
// Output dims: 0 <- m0 (issue), 1 <- m2 (wave), 2 <- (m1, k0, k1) (element within the wave).
140  constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
141  lds_block_desc_0,
142  make_tuple(make_pass_through_transform(number<NumIssues>{}),
143  make_pass_through_transform(number<NumWarps>{}),
145  number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
146  make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
147  make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
148 
149  return lds_block_desc_issues_warps_lanes;
150  }
151  };
152 
153  public:
// Block-level tile distribution of the C accumulator. The outer (block) encoding is
// embedded with the warp-GEMM's C encoding (WG::CWarpDstrEncoding) via
// detail::make_embed_tile_distribution_encoding.
// NOTE(review): "note here is different" presumably contrasts the Y ordering sequence<2, 1>
// with the sequence<1, 2> used for the A distribution in operator(); confirm. If the last Y
// dim is innermost in the thread buffer, consecutive accumulator registers step along M
// before N, which matches the "output C vgpr layout" comment at the top of the file.
154  static CK_TILE_DEVICE constexpr auto MakeCBlockDist()
155  {
156  constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
157  sequence<>,
161  sequence<2, 1>, // !! note here is different
162  sequence<0, 0>>{};
163 
165 
166  constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
167  c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
168  constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
169  return c_block_dstr;
170  }
171 
172  static CK_TILE_DEVICE constexpr auto MakeCBlockTile()
173  {
174  using CDataType = float;
175  constexpr auto c_block_dstr = MakeCBlockDist();
176  auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
177  return c_block_tensor;
178  }
179 
180  CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A()
181  {
182  // A async->LDS
183  constexpr index_t WarpSize = ck_tile::get_warp_size();
184 
185  constexpr index_t KPack_ = 8; // GetSmemKPack_A<Problem>(); // LDS
186  constexpr index_t KVector = 2; // GetAlignment_A<Problem>(); // async copy 1 dword
187  constexpr index_t KPad = KPack_; // pad between warps
188 
189  static_assert(Block_K % KVector == 0);
190  constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
191 
192  return LdsStoreDescSelector<LanesPerK, WarpSize>::
193  template MakeDesc<NumWarps, Block_M, Block_K, KVector, KPad>();
194  }
195 
196  // template <typename Problem>
// LDS read-side descriptor for A (LDS -> registers); every wave uses the same layout.
// Each M row occupies Block_K + KPad elements (KPad = 8 padding elements per row).
// Within a row, K is split as Repeat_K (k0) x kABKLane (k1) x KPack_ (k2) = 4 x 4 x 8 = 128.
// NOTE(review): the store side (MakeLdsStoreDesc_A) inserts its KPad after every
// WarpSize * KVector elements. That matches the per-row padding here only when
// WarpSize * KVector == Block_K (64 * 2 == 128 with 64-lane waves). Re-check both
// descriptors if any of these constants change.
197  CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadDesc_A()
198  {
199  // load from LDS to register, every wave has same layout
200  constexpr index_t KPack_ = 8; // GetSmemKPack_A<Problem>(); // LDS
201  constexpr index_t KPad = KPack_; // pad between warps
202 
203  constexpr index_t kAMLane = 16;
204  constexpr index_t kABKLane = 4;
205  constexpr index_t kABKPerLane = 4;
206  constexpr index_t kKIter = 2;
207  static_assert(KPack_ == (kABKPerLane * kKIter));
208 
209  constexpr auto lds_block_desc_0 =
211  number<kAMLane>{}, // m1 p
212  number<Repeat_K>{}, // k0 y
213  number<kABKLane>{}, // k1 p
214  number<KPack_>{}), // k2 y-vector
215  make_tuple(number<kAMLane*(Block_K + KPad)>{}, // m0
216  number<Block_K + KPad>{}, // m1
218  number<KPack_>{}, // k1
219  number<1>{}), // k2
220  number<KPack_>{}, // lds load vector
221  number<1>{});
222 
223  constexpr auto lds_desc_m_k = transform_tensor_descriptor(
224  lds_block_desc_0,
230 
231  return lds_desc_m_k;
232  }
233 
// Warp-level distribution encoding for the A operand read from LDS into registers.
// operator() embeds it into the block-level A encoding via
// detail::make_embed_tile_distribution_encoding. The lane constants match those in
// MakeLdsLoadDesc_A (16 lanes along M, 4 lanes along K, 4 elements per lane x 2 iterations).
234  static constexpr auto GetGemm_AWarpEnc()
235  {
236  constexpr index_t kAMLane = 16;
237  constexpr index_t kABKLane = 4;
238  constexpr index_t kABKPerLane = 4;
239  constexpr index_t kKIter = 2;
240 
241  using enc_ = tile_distribution_encoding<
242  sequence<>,
246  sequence<2>,
247  sequence<1>>;
248  return enc_{};
249  }
250 
252  {
// LDS bytes required: two A buffers, each sized from MakeLdsLoadDesc_A()'s element space.
// NOTE(review): this always uses sizeof(bf16_t); the fp16 path's fp16_t (_Float16) is also
// 16-bit, so the size is the same for both element types.
253  // return 32 * (128 + 8) * sizeof(bf16_t);
254  return MakeLdsLoadDesc_A().get_element_space_size() * sizeof(bf16_t) * 2; // 2 lds buffers
255  }
256 };
257 
258 // clang-format off
// Inline-asm in/out ("+") operand list for the single-B path: the loop counter in an SGPR,
// 16 fp32x4 accumulators (v_acc[0..15]) in VGPRs, and the LDS pointer smem ("+r").
// Requires loop_cnt, v_acc and smem to be in scope at the expansion site.
// NOTE(review): identifiers that start with an underscore followed by an uppercase letter
// are reserved to the implementation in C++. Consider renaming the _EXPAND_ASM_ARGS_*
// macros; every use and every #undef must be changed together.
259 #define _EXPAND_ASM_ARGS_OUT_ONE_ACC \
260  [s_loop_cnt]"+s"(loop_cnt), \
261  [v_acc_0]"+v"(v_acc[0]), \
262  [v_acc_1]"+v"(v_acc[1]), \
263  [v_acc_2]"+v"(v_acc[2]), \
264  [v_acc_3]"+v"(v_acc[3]), \
265  [v_acc_4]"+v"(v_acc[4]), \
266  [v_acc_5]"+v"(v_acc[5]), \
267  [v_acc_6]"+v"(v_acc[6]), \
268  [v_acc_7]"+v"(v_acc[7]), \
269  [v_acc_8]"+v"(v_acc[8]), \
270  [v_acc_9]"+v"(v_acc[9]), \
271  [v_acc_10]"+v"(v_acc[10]), \
272  [v_acc_11]"+v"(v_acc[11]), \
273  [v_acc_12]"+v"(v_acc[12]), \
274  [v_acc_13]"+v"(v_acc[13]), \
275  [v_acc_14]"+v"(v_acc[14]), \
276  [v_acc_15]"+v"(v_acc[15]), \
277  [s_mem_]"+r"(smem)
278 
// Inline-asm in/out ("+") operand list for the Is2B path: the loop counter in an SGPR,
// 32 fp32x4 accumulators (v_acc[0..31]) in VGPRs, and the LDS pointer smem ("+r").
// Requires loop_cnt, v_acc (at least 32 entries) and smem to be in scope at the expansion site.
279 #define _EXPAND_ASM_ARGS_OUT_TWO_ACC \
280  [s_loop_cnt]"+s"(loop_cnt), \
281  [v_acc_0]"+v"(v_acc[0]), \
282  [v_acc_1]"+v"(v_acc[1]), \
283  [v_acc_2]"+v"(v_acc[2]), \
284  [v_acc_3]"+v"(v_acc[3]), \
285  [v_acc_4]"+v"(v_acc[4]), \
286  [v_acc_5]"+v"(v_acc[5]), \
287  [v_acc_6]"+v"(v_acc[6]), \
288  [v_acc_7]"+v"(v_acc[7]), \
289  [v_acc_8]"+v"(v_acc[8]), \
290  [v_acc_9]"+v"(v_acc[9]), \
291  [v_acc_10]"+v"(v_acc[10]), \
292  [v_acc_11]"+v"(v_acc[11]), \
293  [v_acc_12]"+v"(v_acc[12]), \
294  [v_acc_13]"+v"(v_acc[13]), \
295  [v_acc_14]"+v"(v_acc[14]), \
296  [v_acc_15]"+v"(v_acc[15]), \
297  [v_acc_16]"+v"(v_acc[16]), \
298  [v_acc_17]"+v"(v_acc[17]), \
299  [v_acc_18]"+v"(v_acc[18]), \
300  [v_acc_19]"+v"(v_acc[19]), \
301  [v_acc_20]"+v"(v_acc[20]), \
302  [v_acc_21]"+v"(v_acc[21]), \
303  [v_acc_22]"+v"(v_acc[22]), \
304  [v_acc_23]"+v"(v_acc[23]), \
305  [v_acc_24]"+v"(v_acc[24]), \
306  [v_acc_25]"+v"(v_acc[25]), \
307  [v_acc_26]"+v"(v_acc[26]), \
308  [v_acc_27]"+v"(v_acc[27]), \
309  [v_acc_28]"+v"(v_acc[28]), \
310  [v_acc_29]"+v"(v_acc[29]), \
311  [v_acc_30]"+v"(v_acc[30]), \
312  [v_acc_31]"+v"(v_acc[31]), \
313  [s_mem_]"+r"(smem)
314 
// Inline-asm input operand list shared by both paths:
//  - res_a[0..3], res_b[0..3]            : SGPR operands ("s")
//  - v_os_a0..7 / v_os_b0..7             : per-access A/B offsets converted to bytes (VGPR)
//  - v_os_slda                           : byte offset of the first LDS-read access for A
//  - s_m0_init, s_size_per_issue         : values from get_async_store_smem_info(a_sst)
//  - smem_sz, sld_os_0..7                : compile-time immediates ("n"): one LDS buffer's
//                                          byte size and the 8 LDS-read byte offsets
//  - s_tile_os_a / s_tile_os_b           : per-tile byte strides
// Requires res_a, res_b, cached_coords_a/b, a_sld, m0_init_value, size_per_issue,
// smem_buf_size, sld_os, tile_offset_a_bytes, tile_offset_b_bytes, ADataType and BDataType
// to be in scope at the expansion site.
315 #define _EXPAND_ASM_ARGS_IN \
316  [s_res_a0]"s"(res_a[0]), \
317  [s_res_a1]"s"(res_a[1]), \
318  [s_res_a2]"s"(res_a[2]), \
319  [s_res_a3]"s"(res_a[3]), \
320  [s_res_b0]"s"(res_b[0]), \
321  [s_res_b1]"s"(res_b[1]), \
322  [s_res_b2]"s"(res_b[2]), \
323  [s_res_b3]"s"(res_b[3]), \
324  [v_os_a0]"v"(static_cast<index_t>(cached_coords_a[number<0>{}] * sizeof(ADataType))), \
325  [v_os_a1]"v"(static_cast<index_t>(cached_coords_a[number<1>{}] * sizeof(ADataType))), \
326  [v_os_a2]"v"(static_cast<index_t>(cached_coords_a[number<2>{}] * sizeof(ADataType))), \
327  [v_os_a3]"v"(static_cast<index_t>(cached_coords_a[number<3>{}] * sizeof(ADataType))), \
328  [v_os_a4]"v"(static_cast<index_t>(cached_coords_a[number<4>{}] * sizeof(ADataType))), \
329  [v_os_a5]"v"(static_cast<index_t>(cached_coords_a[number<5>{}] * sizeof(ADataType))), \
330  [v_os_a6]"v"(static_cast<index_t>(cached_coords_a[number<6>{}] * sizeof(ADataType))), \
331  [v_os_a7]"v"(static_cast<index_t>(cached_coords_a[number<7>{}] * sizeof(ADataType))), \
332  \
333  [v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))), \
334  [v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))), \
335  [v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))), \
336  [v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))), \
337  [v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))), \
338  [v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))), \
339  [v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))), \
340  [v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))), \
341  \
342  [v_os_slda]"v"(static_cast<index_t>(a_sld.cached_coords_[number<0>{}].get_offset() * sizeof(ADataType))),\
343  [s_m0_init]"s"(m0_init_value), \
344  [s_size_per_issue]"s"(size_per_issue), \
345  [smem_sz]"n"(smem_buf_size), \
346  [sld_os_0]"n"(sld_os[number<0>{}].value), \
347  [sld_os_1]"n"(sld_os[number<1>{}].value), \
348  [sld_os_2]"n"(sld_os[number<2>{}].value), \
349  [sld_os_3]"n"(sld_os[number<3>{}].value), \
350  [sld_os_4]"n"(sld_os[number<4>{}].value), \
351  [sld_os_5]"n"(sld_os[number<5>{}].value), \
352  [sld_os_6]"n"(sld_os[number<6>{}].value), \
353  [sld_os_7]"n"(sld_os[number<7>{}].value), \
354  [s_tile_os_a]"s"(tile_offset_a_bytes), \
355  [s_tile_os_b]"s"(tile_offset_b_bytes)
356 
// Clobber list shared by both asm paths: "memory", all 256 AGPRs (a0-a255), SGPRs s16-s23
// and s86, and VGPRs v64-v127. The Is2B call sites additionally clobber s24-s27.
357 #define _EXPAND_ASM_ARGS_CLOBBER \
358  "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9", \
359  "a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19", \
360  "a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29", \
361  "a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39", \
362  "a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49", \
363  "a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59", \
364  "a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69", \
365  "a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79", \
366  "a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89", \
367  "a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99", \
368  "a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107", \
369  "a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115", \
370  "a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123", \
371  "a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131", \
372  "a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139", \
373  "a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147", \
374  "a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155", \
375  "a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163", \
376  "a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171", \
377  "a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179", \
378  "a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187", \
379  "a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195", \
380  "a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203", \
381  "a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211", \
382  "a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219", \
383  "a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227", \
384  "a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235", \
385  "a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243", \
386  "a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251", \
387  "a252", "a253", "a254", "a255", \
388  "s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23", \
389  "s86", \
390  "v64", "v65", "v66", "v67", "v68", "v69", \
391  "v70", "v71", "v72", "v73", "v74", "v75", "v76", "v77", "v78", "v79", \
392  "v80", "v81", "v82", "v83", "v84", "v85", "v86", "v87", "v88", "v89", \
393  "v90", "v91", "v92", "v93", "v94", "v95", "v96", "v97", "v98", "v99", \
394  "v100", "v101", "v102", "v103", "v104", "v105", "v106", "v107", \
395  "v108", "v109", "v110", "v111", "v112", "v113", "v114", "v115", \
396  "v116", "v117", "v118", "v119", "v120", "v121", "v122", "v123", \
397  "v124", "v125", "v126", "v127"
398 // clang-format on
399 
401 {
// Element types of A and B for this kernel. They scale every byte offset below
// (tile_offset_*_bytes, v_os_* operands, sld_os, smem_buf_size).
402  using ADataType = bf16_t;
403  using BDataType = bf16_t;
404 
// operator(): bf16 A/B flatmm main loop for one Block_M x Block_N (32x512) tile, fp32 accumulate.
// A is staged through LDS: a_sst is the LDS store window, a_sld the LDS load window.
// The K loop itself is inline assembly whose operands come from the _EXPAND_ASM_ARGS_* macros.
//   res_a, res_b       : passed to the asm as SGPR ("s") operands, presumably buffer resource
//                        descriptors; res_b[4..7] are used additionally when Is2B is true
//   cached_coords_a/_b : per-access element offsets, converted to byte offsets for the asm
//   smem               : LDS base pointer shared by the store and load windows
//   k                  : total K extent; the loop count passed to the asm is k / Block_K
//   tile_offset_a/_b   : element offsets to advance per unroll, passed to the asm in bytes
// Returns the C accumulator tile built by MakeCBlockTile(). When Is2B is true, it returns an
// object whose c.at(number<0>{}) / c.at(number<1>{}) hold the two accumulator sets
// (presumably for B0 / B1).
405  // TODO: need paired with tile_window_linear!
406  // TODO: need call init_raw() before call this function!
407  // Is2B: originally for B matrix we have 2 prefetch buffers. If set this to true
408  // we can support one A matrix serving 2 B matrices, B0/B1; B0 and B1 each keep the same tile size
409  template <typename ARes, typename ACoords, typename BRes, typename BCoords, bool Is2B = false>
410  CK_TILE_DEVICE auto
411  operator()(const ARes& res_a,
412  const ACoords& cached_coords_a,
413  const BRes& res_b,
414  const BCoords& cached_coords_b,
415  CK_TILE_LDS_ADDR void* smem,
416  index_t k,
417  index_t tile_offset_a, // for each tile, the offset to move for each unroll
418  index_t tile_offset_b, // for each tile, the offset to move for each unroll
419  bool_constant<Is2B> = {}) // compile-time tag selecting the two-B-matrix variant
420  {
421  static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 2 /*2x per dword*/); // 8
422  static_assert(BCoords::size() == Repeat_N);
423 
// LDS store window for the async A copy (3-D issue/warp/lane view from MakeLdsStoreDesc_A).
424  auto a_sst = make_tile_window(
425  make_tensor_view<address_space_enum::lds>(
426  reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsStoreDesc_A()),
427  MakeLdsStoreDesc_A().get_lengths(),
428  {0, 0, 0});
429 
// LDS load window for A (2-D M x K view from MakeLdsLoadDesc_A), distributed by the block
// encoding (Repeat_M x WarpPerBlock_M along M, Repeat_K along K, replicated over
// WarpPerBlock_N) embedded with the warp encoding from GetGemm_AWarpEnc().
430  auto a_sld = [&]() {
431  constexpr auto a_warp_enc_ = GetGemm_AWarpEnc();
432  constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
433  sequence<WarpPerBlock_N>,
434  tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_K>>,
435  tuple<sequence<1, 0>>,
436  tuple<sequence<1, 0>>,
437  sequence<1, 2>,
438  sequence<0, 0>>{};
439  constexpr auto a_block_dstr_encode =
440  detail::make_embed_tile_distribution_encoding(a_outer_dstr_enc, a_warp_enc_);
442  make_tensor_view<address_space_enum::lds>(
443  reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsLoadDesc_A()),
444  MakeLdsLoadDesc_A().get_lengths(),
445  {0, 0},
446  make_static_tile_distribution(a_block_dstr_encode));
447  }();
448 
449  const index_t tile_offset_a_bytes = tile_offset_a * sizeof(ADataType);
450  const index_t tile_offset_b_bytes = tile_offset_b * sizeof(BDataType);
451 
452  const auto [m0_init_value, size_per_issue] = get_async_store_smem_info(a_sst);
// Byte size of one LDS A buffer (GetSmemSize() reserves two of these).
453  constexpr auto smem_buf_size =
454  MakeLdsLoadDesc_A().get_element_space_size() * sizeof(ADataType);
455  static_assert(a_sld.get_num_of_access() == 8);
// Compile-time byte offsets of the 8 LDS-read accesses, passed to the asm as "n" immediates.
456  constexpr auto sld_os = generate_tuple(
457  [&](auto i_access) {
458  return number<a_sld.get_bottom_linear_offset(i_access) * sizeof(ADataType)>{};
459  },
460  number<a_sld.get_num_of_access()>{});
461 
// NOTE(review): integer division drops any remainder, so k is presumably required to be a
// multiple of Block_K (128); confirm at the call site.
462  index_t loop_cnt = k / Block_K;
463 
464  if constexpr(Is2B)
465  {
466  // this is the acc thread buffer
467  fp32x4_t v_acc[32]{.0f};
468 
469  // B nr->kr
470 #pragma clang diagnostic push
471 #pragma clang diagnostic ignored "-Winline-asm"
472  // clang-format off
473  asm volatile(
474 #define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
475 #define CK_TILE_FLATMM_UK_2B 1
479  [s_res_b4]"s"(res_b[4]),
480  [s_res_b5]"s"(res_b[5]),
481  [s_res_b6]"s"(res_b[6]),
482  [s_res_b7]"s"(res_b[7])
483  : _EXPAND_ASM_ARGS_CLOBBER, "s24", "s25", "s26", "s27"
484  );
485  // clang-format on
486 #pragma clang diagnostic pop
487 
488  // return local scratch
// Unpack v_acc[0..15] into c.at(number<0>{}) and v_acc[16..31] into c.at(number<1>{});
// each fp32x4 fills 4 consecutive thread-buffer elements (.x, .y, .z, .w).
490  for(auto i = 0; i < 16; i++)
491  {
492  c.at(number<0>{}).get_thread_buffer()[4 * i + 0] = v_acc[i].x;
493  c.at(number<0>{}).get_thread_buffer()[4 * i + 1] = v_acc[i].y;
494  c.at(number<0>{}).get_thread_buffer()[4 * i + 2] = v_acc[i].z;
495  c.at(number<0>{}).get_thread_buffer()[4 * i + 3] = v_acc[i].w;
496  }
497  for(auto i = 0; i < 16; i++)
498  {
499  c.at(number<1>{}).get_thread_buffer()[4 * i + 0] = v_acc[16 + i].x;
500  c.at(number<1>{}).get_thread_buffer()[4 * i + 1] = v_acc[16 + i].y;
501  c.at(number<1>{}).get_thread_buffer()[4 * i + 2] = v_acc[16 + i].z;
502  c.at(number<1>{}).get_thread_buffer()[4 * i + 3] = v_acc[16 + i].w;
503  }
504  return c;
505  }
506  else
507  {
508  // this is the acc thread buffer
509  fp32x4_t v_acc[16]{.0f};
510 
511  // B nr->kr
512 #pragma clang diagnostic push
513 #pragma clang diagnostic ignored "-Winline-asm"
514  // clang-format off
515  asm volatile(
516 #define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
521  );
522  // clang-format on
523 #pragma clang diagnostic pop
524 
525  // return local scratch
// Unpack the 16 fp32x4 accumulators into the 64-element C thread buffer in order.
526  auto c = MakeCBlockTile();
527  for(auto i = 0; i < 16; i++)
528  {
529  c.get_thread_buffer()[4 * i + 0] = v_acc[i].x;
530  c.get_thread_buffer()[4 * i + 1] = v_acc[i].y;
531  c.get_thread_buffer()[4 * i + 2] = v_acc[i].z;
532  c.get_thread_buffer()[4 * i + 3] = v_acc[i].w;
533  }
534  return c;
535  }
536  }
537 };
538 
540 {
// Element types of A and B for this kernel. They scale every byte offset below
// (tile_offset_*_bytes, v_os_* operands, sld_os, smem_buf_size).
541  using ADataType = fp16_t;
542  using BDataType = fp16_t;
543 
// operator(): fp16 A/B flatmm main loop for one Block_M x Block_N (32x512) tile, fp32 accumulate.
// Identical in structure to the bf16 variant; only the element types and the selected MFMA
// macro (CK_TILE_FLATMM_UK_MFMA_FP16) differ.
// A is staged through LDS: a_sst is the LDS store window, a_sld the LDS load window.
// The K loop itself is inline assembly whose operands come from the _EXPAND_ASM_ARGS_* macros.
//   res_a, res_b       : passed to the asm as SGPR ("s") operands, presumably buffer resource
//                        descriptors; res_b[4..7] are used additionally when Is2B is true
//   cached_coords_a/_b : per-access element offsets, converted to byte offsets for the asm
//   smem               : LDS base pointer shared by the store and load windows
//   k                  : total K extent; the loop count passed to the asm is k / Block_K
//   tile_offset_a/_b   : element offsets to advance per unroll, passed to the asm in bytes
//   Is2B               : when true, one A tile is multiplied against two B tiles (B0/B1)
// Returns the C accumulator tile built by MakeCBlockTile(). When Is2B is true, it returns an
// object whose c.at(number<0>{}) / c.at(number<1>{}) hold the two accumulator sets.
544  // TODO: need paired with tile_window_linear!
545  // TODO: need call init_raw() before call this function!
546  template <typename ARes, typename ACoords, typename BRes, typename BCoords, bool Is2B = false>
547  CK_TILE_DEVICE auto
548  operator()(const ARes& res_a,
549  const ACoords& cached_coords_a,
550  const BRes& res_b,
551  const BCoords& cached_coords_b,
552  CK_TILE_LDS_ADDR void* smem,
553  index_t k,
554  index_t tile_offset_a, // for each tile, the offset to move for each unroll
555  index_t tile_offset_b, // for each tile, the offset to move for each unroll
556  bool_constant<Is2B> = {})
557  {
558  static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 2 /*2x per dword*/); // 8
559  static_assert(BCoords::size() == Repeat_N);
560 
// LDS store window for the async A copy (3-D issue/warp/lane view from MakeLdsStoreDesc_A).
561  auto a_sst = make_tile_window(
562  make_tensor_view<address_space_enum::lds>(
563  reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsStoreDesc_A()),
564  MakeLdsStoreDesc_A().get_lengths(),
565  {0, 0, 0});
566 
// LDS load window for A (2-D M x K view from MakeLdsLoadDesc_A), distributed by the block
// encoding embedded with the warp encoding from GetGemm_AWarpEnc().
567  auto a_sld = [&]() {
568  constexpr auto a_warp_enc_ = GetGemm_AWarpEnc();
569  constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
570  sequence<WarpPerBlock_N>,
571  tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_K>>,
572  tuple<sequence<1, 0>>,
573  tuple<sequence<1, 0>>,
574  sequence<1, 2>,
575  sequence<0, 0>>{};
576  constexpr auto a_block_dstr_encode =
577  detail::make_embed_tile_distribution_encoding(a_outer_dstr_enc, a_warp_enc_);
579  make_tensor_view<address_space_enum::lds>(
580  reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsLoadDesc_A()),
581  MakeLdsLoadDesc_A().get_lengths(),
582  {0, 0},
583  make_static_tile_distribution(a_block_dstr_encode));
584  }();
585 
586  const index_t tile_offset_a_bytes = tile_offset_a * sizeof(ADataType);
587  const index_t tile_offset_b_bytes = tile_offset_b * sizeof(BDataType);
588 
589  const auto [m0_init_value, size_per_issue] = get_async_store_smem_info(a_sst);
// Byte size of one LDS A buffer (GetSmemSize() reserves two of these).
590  constexpr auto smem_buf_size =
591  MakeLdsLoadDesc_A().get_element_space_size() * sizeof(ADataType);
592  static_assert(a_sld.get_num_of_access() == 8);
// Compile-time byte offsets of the 8 LDS-read accesses, passed to the asm as "n" immediates.
593  constexpr auto sld_os = generate_tuple(
594  [&](auto i_access) {
595  return number<a_sld.get_bottom_linear_offset(i_access) * sizeof(ADataType)>{};
596  },
597  number<a_sld.get_num_of_access()>{});
598 
// NOTE(review): integer division drops any remainder, so k is presumably required to be a
// multiple of Block_K (128); confirm at the call site.
599  index_t loop_cnt = k / Block_K;
600 
601  if constexpr(Is2B)
602  {
603  // this is the acc thread buffer
604  fp32x4_t v_acc[32]{.0f};
605 
606  // B nr->kr
607 #pragma clang diagnostic push
608 #pragma clang diagnostic ignored "-Winline-asm"
609  // clang-format off
610  asm volatile(
611 #define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
612 #define CK_TILE_FLATMM_UK_2B 1
616  [s_res_b4]"s"(res_b[4]),
617  [s_res_b5]"s"(res_b[5]),
618  [s_res_b6]"s"(res_b[6]),
619  [s_res_b7]"s"(res_b[7])
620  : _EXPAND_ASM_ARGS_CLOBBER, "s24", "s25", "s26", "s27"
621  );
622  // clang-format on
623 #pragma clang diagnostic pop
624 
625  // return local scratch
// Unpack v_acc[0..15] into c.at(number<0>{}) and v_acc[16..31] into c.at(number<1>{});
// each fp32x4 fills 4 consecutive thread-buffer elements (.x, .y, .z, .w).
627  for(auto i = 0; i < 16; i++)
628  {
629  c.at(number<0>{}).get_thread_buffer()[4 * i + 0] = v_acc[i].x;
630  c.at(number<0>{}).get_thread_buffer()[4 * i + 1] = v_acc[i].y;
631  c.at(number<0>{}).get_thread_buffer()[4 * i + 2] = v_acc[i].z;
632  c.at(number<0>{}).get_thread_buffer()[4 * i + 3] = v_acc[i].w;
633  }
634  for(auto i = 0; i < 16; i++)
635  {
636  c.at(number<1>{}).get_thread_buffer()[4 * i + 0] = v_acc[16 + i].x;
637  c.at(number<1>{}).get_thread_buffer()[4 * i + 1] = v_acc[16 + i].y;
638  c.at(number<1>{}).get_thread_buffer()[4 * i + 2] = v_acc[16 + i].z;
639  c.at(number<1>{}).get_thread_buffer()[4 * i + 3] = v_acc[16 + i].w;
640  }
641  return c;
642  }
643  else
644  {
645  // this is the acc thread buffer
646  fp32x4_t v_acc[16]{.0f};
647 
648  // B nr->kr
649 #pragma clang diagnostic push
650 #pragma clang diagnostic ignored "-Winline-asm"
651  // clang-format off
652  asm volatile(
653 #define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
658  );
659  // clang-format on
660 #pragma clang diagnostic pop
661 
662  // return local scratch
// Unpack the 16 fp32x4 accumulators into the 64-element C thread buffer in order.
663  auto c = MakeCBlockTile();
664  for(auto i = 0; i < 16; i++)
665  {
666  c.get_thread_buffer()[4 * i + 0] = v_acc[i].x;
667  c.get_thread_buffer()[4 * i + 1] = v_acc[i].y;
668  c.get_thread_buffer()[4 * i + 2] = v_acc[i].z;
669  c.get_thread_buffer()[4 * i + 3] = v_acc[i].w;
670  }
671  return c;
672  }
673  }
674 };
675 #undef _EXPAND_ASM_ARGS_OUT_ONE_ACC
676 #undef _EXPAND_ASM_ARGS_OUT_TWO_ACC
677 #undef _EXPAND_ASM_ARGS_IN
678 #undef _EXPAND_ASM_ARGS_CLOBBER
679 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_LDS_ADDR
Definition: config.hpp:58
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
#define _EXPAND_ASM_ARGS_OUT_TWO_ACC
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:279
#define _EXPAND_ASM_ARGS_CLOBBER
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:357
#define _EXPAND_ASM_ARGS_IN
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:315
#define _EXPAND_ASM_ARGS_OUT_ONE_ACC
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:259
constexpr CK_TILE_HOST_DEVICE auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition: tile_distribution_encoding.hpp:457
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_descriptor.hpp:268
_Float16 fp16_t
Definition: half.hpp:110
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1615
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1558
constant< v > number
Definition: integral_constant.hpp:37
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition: tensor_descriptor.hpp:197
CK_TILE_DEVICE auto get_async_store_smem_info(LdsTileWindow_ &&lds_tile)
Definition: tile_window_utils.hpp:31
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_DEVICE auto make_tile_window_linear(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, LinearBottomDims_={})
Definition: tile_window_linear.hpp:993
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
float fp32x4_t
Definition: vector_type.hpp:117
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:480
__host__ constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:42
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:27
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:401
bf16_t ADataType
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:402
bf16_t BDataType
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:403
CK_TILE_DEVICE auto operator()(const ARes &res_a, const ACoords &cached_coords_a, const BRes &res_b, const BCoords &cached_coords_b, CK_TILE_LDS_ADDR void *smem, index_t k, index_t tile_offset_a, index_t tile_offset_b, bool_constant< Is2B >={})
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:411
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:38
static constexpr index_t NumWarps
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:47
static constexpr index_t WarpPerBlock_N
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:44
static constexpr CK_TILE_DEVICE auto MakeCBlockTile()
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:172
static constexpr index_t SubKPacks
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:55
static constexpr index_t WarpPerBlock_K
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:45
static constexpr index_t Block_K
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:41
static constexpr index_t Block_N
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:40
static constexpr CK_TILE_HOST_DEVICE auto MakeLdsLoadDesc_A()
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:197
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:251
static constexpr index_t Block_W
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:58
static constexpr auto GetGemm_AWarpEnc()
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:234
static constexpr index_t Warp_M
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:49
static constexpr CK_TILE_HOST_DEVICE auto MakeLdsStoreDesc_A()
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:180
static constexpr index_t Block_M
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:39
static constexpr index_t Block_Kr
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:60
static constexpr index_t Repeat_K
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:64
static constexpr CK_TILE_DEVICE auto MakeCBlockDist()
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:154
static constexpr index_t Warp_N
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:50
static constexpr index_t Block_Nr
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:59
static constexpr index_t Warp_K
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:51
static constexpr index_t Repeat_M
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:62
static constexpr index_t BlockSize
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:53
static constexpr index_t WarpPerBlock_M
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:43
static constexpr index_t Repeat_N
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:63
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:540
CK_TILE_DEVICE auto operator()(const ARes &res_a, const ACoords &cached_coords_a, const BRes &res_b, const BCoords &cached_coords_b, CK_TILE_LDS_ADDR void *smem, index_t k, index_t tile_offset_a, index_t tile_offset_b, bool_constant< Is2B >={})
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:548
fp16_t ADataType
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:541
fp16_t BDataType
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:542
Definition: warp_gemm_impl.hpp:11
Definition: integral_constant.hpp:13
Definition: sequence.hpp:49
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192