Composable Kernel: include/ck_tile/host/reference/reference_gemm.hpp Source File

reference_gemm.hpp
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include <cstdlib>
7 #include <thread>
8 
9 #include "ck_tile/core.hpp"
10 #include "ck_tile/host/host_tensor.hpp"
11 
12 namespace ck_tile {
13 
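// reference_gemm_quant: host reference for a group-quantized GEMM C(m, n) = sum_k A(m, k) * B(k, n).
// The K dimension is processed in groups of QuantGroupSize::kK; each group's partial sum is multiplied
// by a scale from q, indexed as (m / kM, k_group) when aquant == true (A-side scales) or
// (k_group, n / kN) when aquant == false (B-side scales). pk_int4_t inputs are unpacked two values per
// element, fp8/bf8 scales are converted to float, and the (m, n) loop runs in parallel via
// make_ParallelTensorFunctor.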
14 template <typename ADataType,
15  typename QDataType,
16  typename BDataType,
17  typename AccDataType,
18  typename CDataType,
19  typename QuantGroupSize,
20  bool aquant,
21  typename AElementOp = ck_tile::identity,
22  typename BElementOp = ck_tile::identity,
23  typename ACCElementOp = ck_tile::identity>
24 CK_TILE_HOST void reference_gemm_quant(const HostTensor<ADataType>& a_m_k,
25  const HostTensor<QDataType>& q,
26  const HostTensor<BDataType>& b_k_n,
27  HostTensor<CDataType>& c_m_n,
28  const AElementOp& a_element_op = {},
29  const BElementOp& b_element_op = {},
30  const ACCElementOp& acc_element_op = {})
31 {
32  const std::size_t M = a_m_k.get_length(0);
33  const std::size_t N = b_k_n.get_length(1);
34  const std::size_t K = a_m_k.get_length(1);
35 
36  auto f_mn = [&](auto m, auto n) {
37  AccDataType v_acc = 0;
38 
39  constexpr std::size_t kGroupK = QuantGroupSize::kK;
40 
41  // ---- A loader: dequant A(m,k) into AccDataType ----
42  auto load_a = [&](std::size_t k) -> AccDataType {
43  if constexpr(std::is_same_v<ADataType, pk_int4_t>)
44  {
45  const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
46  const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
47  return (k & 1) ? fp32_val.hi : fp32_val.lo;
48  }
49  else
50  {
51  return ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
52  }
53  };
54 
55  // ---- B loader: dequant B(k,n) into AccDataType ----
56  auto load_b = [&](std::size_t k) -> AccDataType {
57  if constexpr(std::is_same_v<BDataType, pk_int4_t>)
58  {
59  const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
60  const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
61  return (k & 1) ? fp32_val.hi : fp32_val.lo;
62  }
63  else if constexpr(std::is_same_v<BDataType, fp8_t>)
64  {
65  return fp8_to_float_raw(b_element_op(b_k_n(k, n)));
66  }
67  else
68  {
69  return ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
70  }
71  };
72 
73  // ---- scale loader for a given K-group index ----
74  auto load_scale = [&](ck_tile::index_t k_group) -> float {
75  const ck_tile::index_t outer_dim = aquant ? (m / QuantGroupSize::kM) : k_group;
76  const ck_tile::index_t inner_dim = aquant ? k_group : (n / QuantGroupSize::kN);
77 
78  if constexpr(std::is_same_v<QDataType, float>)
79  {
80  return q(outer_dim, inner_dim);
81  }
82  else if constexpr(std::is_same_v<QDataType, ck_tile::fp8_t>)
83  {
84  return fp8_to_float_raw(q(outer_dim, inner_dim));
85  }
86  else // QDataType is expected to be bf8_t here
87  {
88  return bf8_to_float_raw(q(outer_dim, inner_dim));
89  }
90  };
91 
92  // ---- Loop over K by groups (full and tail) ----
93  for(std::size_t k_begin = 0; k_begin < K; k_begin += kGroupK)
94  {
95  const std::size_t k_end = std::min<std::size_t>(k_begin + kGroupK, K);
96 
97  AccDataType v_block_acc = 0;
98 
99  // unscaled accumulation within this K-group
100  for(std::size_t k = k_begin; k < k_end; ++k)
101  {
102  const AccDataType v_a = load_a(k);
103  const AccDataType v_b = load_b(k);
104  v_block_acc += v_a * v_b;
105  }
106 
107  const ck_tile::index_t k_group = static_cast<ck_tile::index_t>(k_begin / kGroupK);
108  const float scale = load_scale(k_group);
109 
110  v_acc += v_block_acc * scale;
111  }
112 
113  c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
114  };
115 
116  make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
117  std::cout << std::endl;
118 }
119 
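// reference_gemm_abquant: both A and B carry per-group scales. For each K-group the unscaled partial
// sum is multiplied by scale_a (from a_q at (m / AQuantGroupSize::kM, k_group)) and scale_b
// (from b_q at (k_group, n / BQuantGroupSize::kN)) before being added to the accumulator.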
120 template <typename ADataType,
121  typename AQDataType,
122  typename BDataType,
123  typename BQDataType,
124  typename AccDataType,
125  typename CDataType,
126  typename AQuantGroupSize,
127  typename BQuantGroupSize,
128  typename AElementOp = ck_tile::identity,
129  typename BElementOp = ck_tile::identity,
130  typename ACCElementOp = ck_tile::identity>
131 CK_TILE_HOST void reference_gemm_abquant(const HostTensor<ADataType>& a_m_k,
132  const HostTensor<AQDataType>& a_q,
133  const HostTensor<BDataType>& b_k_n,
134  const HostTensor<BQDataType>& b_q,
135  HostTensor<CDataType>& c_m_n,
136  const AElementOp& a_element_op = {},
137  const BElementOp& b_element_op = {},
138  const ACCElementOp& acc_element_op = {})
139 {
140  const std::size_t M = a_m_k.get_length(0);
141  const std::size_t N = b_k_n.get_length(1);
142  const std::size_t K = a_m_k.get_length(1);
143 
144  auto f_mn = [&](auto m, auto n) {
145  AccDataType v_acc = 0;
146 
147  constexpr std::size_t kGroupK = BQuantGroupSize::kK;
148 
149  // ---- A loader: dequant A(m,k) into AccDataType ----
150  auto load_a = [&](std::size_t k) -> AccDataType {
151  if constexpr(std::is_same_v<ADataType, pk_int4_t>)
152  {
153  const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
154  const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
155  return (k & 1) ? fp32_val.hi : fp32_val.lo;
156  }
157  else
158  {
159  return ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
160  }
161  };
162 
163  // ---- B loader: dequant B(k,n) into AccDataType ----
164  auto load_b = [&](std::size_t k) -> AccDataType {
165  if constexpr(std::is_same_v<BDataType, pk_int4_t>)
166  {
167  const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
168  const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
169  return (k & 1) ? fp32_val.hi : fp32_val.lo;
170  }
171  else if constexpr(std::is_same_v<BDataType, fp8_t>)
172  {
173  return fp8_to_float_raw(b_element_op(b_k_n(k, n)));
174  }
175  else
176  {
177  return ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
178  }
179  };
180 
181  // ---- a scale loader for a given K-group index ----
182  auto load_scale_a = [&](ck_tile::index_t k_group) -> float {
183  const ck_tile::index_t outer_dim = m / AQuantGroupSize::kM;
184  const ck_tile::index_t inner_dim = k_group;
185 
186  if constexpr(std::is_same_v<AQDataType, float>)
187  {
188  return a_q(outer_dim, inner_dim);
189  }
190  else if constexpr(std::is_same_v<AQDataType, ck_tile::fp8_t>)
191  {
192  return fp8_to_float_raw(a_q(outer_dim, inner_dim));
193  }
194  else // AQDataType is expected to be bf8_t here
195  {
196  return bf8_to_float_raw(a_q(outer_dim, inner_dim));
197  }
198  };
199  // ---- b scale loader for a given K-group index ----
200  auto load_scale_b = [&](ck_tile::index_t k_group) -> float {
201  const ck_tile::index_t outer_dim = k_group;
202  const ck_tile::index_t inner_dim = n / BQuantGroupSize::kN;
203 
204  if constexpr(std::is_same_v<BQDataType, float>)
205  {
206  return b_q(outer_dim, inner_dim);
207  }
208  else if constexpr(std::is_same_v<BQDataType, ck_tile::fp8_t>)
209  {
210  return fp8_to_float_raw(b_q(outer_dim, inner_dim));
211  }
212  else // BQDataType is expected to be bf8_t here
213  {
214  return bf8_to_float_raw(b_q(outer_dim, inner_dim));
215  }
216  };
217  // ---- Loop over K by groups (full and tail) ----
218  for(std::size_t k_begin = 0; k_begin < K; k_begin += kGroupK)
219  {
220  const std::size_t k_end = std::min<std::size_t>(k_begin + kGroupK, K);
221 
222  AccDataType v_block_acc = 0;
223 
224  // unscaled accumulation within this K-group
225  for(std::size_t k = k_begin; k < k_end; ++k)
226  {
227  const AccDataType v_a = load_a(k);
228  const AccDataType v_b = load_b(k);
229  v_block_acc += v_a * v_b;
230  }
231 
232  const ck_tile::index_t k_group = static_cast<ck_tile::index_t>(k_begin / kGroupK);
233  const float scale_a = load_scale_a(k_group);
234  const float scale_b = load_scale_b(k_group);
235 
236  v_acc += v_block_acc * scale_a * scale_b;
237  }
238 
239  c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
240  };
241 
242  make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
243 }
244 
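// reference_gemm_rowcol_quant: row/column-wise quantization. aq_m_1 holds one float scale per row of A
// and bq_1_n one float scale per column of B; the full K dot product is computed first and then scaled
// by a_scale * b_scale. The static_asserts below restrict inputs to fp8_t/bf8_t, accumulation to float,
// and output to float or half_t.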
245 template <typename ADataType,
246  typename AQDataType,
247  typename BDataType,
248  typename BQDataType,
249  typename AccDataType,
250  typename CDataType,
251  typename AElementOp = ck_tile::identity,
252  typename BElementOp = ck_tile::identity,
253  typename ACCElementOp = ck_tile::identity>
254 CK_TILE_HOST void reference_gemm_rowcol_quant(const HostTensor<ADataType>& a_m_k,
255  const HostTensor<AQDataType>& aq_m_1,
256  const HostTensor<BDataType>& b_k_n,
257  const HostTensor<BQDataType>& bq_1_n,
258  HostTensor<CDataType>& c_m_n,
259  const AElementOp& a_element_op = {},
260  const BElementOp& b_element_op = {},
261  const ACCElementOp& acc_element_op = {})
262 {
263  static_assert(std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t>);
264  static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t>);
265  static_assert(std::is_same_v<AccDataType, float>);
266  static_assert(std::is_same_v<CDataType, float> || std::is_same_v<CDataType, ck_tile::half_t>);
267  static_assert(std::is_same_v<AQDataType, float> && std::is_same_v<BQDataType, float>);
268  const std::size_t M = a_m_k.get_length(0);
269  const std::size_t N = b_k_n.get_length(1);
270  const std::size_t K = a_m_k.get_length(1);
271 
272  auto f_mn = [&](auto m, auto n) {
273  // Init accumulator
274  AccDataType v_acc = 0;
275  // Get row scale for A and column scale for B
276  float a_scale = aq_m_1(m, 0);
277  float b_scale = bq_1_n(0, n);
278 
279  // Compute the dot product
280  for(std::size_t k = 0; k < K; ++k)
281  {
282  AccDataType v_a;
283  AccDataType v_b;
284 
285  // Process A data
286  if constexpr(std::is_same_v<ADataType, pk_int4_t>)
287  {
288  const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
289  const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val);
290  if(k % 2 == 1)
291  v_a = fp32_val.hi;
292  else
293  v_a = fp32_val.lo;
294  }
295  else
296  {
297  v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
298  }
299 
300  // Process B data
301  if constexpr(std::is_same_v<BDataType, pk_int4_t>)
302  {
303  const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
304  const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val);
305  if(k % 2 == 1)
306  v_b = fp32_val.hi;
307  else
308  v_b = fp32_val.lo;
309  }
310  else
311  {
312  v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
313  }
314 
315  v_acc += v_a * v_b;
316  }
317 
318  v_acc = v_acc * a_scale * b_scale;
319 
320  c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
321  };
322 
323  make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
324 }
325 
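// reference_gemm_tensor_quant: tensor-wise quantization. aq_1_1 and bq_1_1 each hold a single float
// scale for the whole A and B tensor; the dot product over K is scaled once by a_scale * b_scale.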
326 template <typename ADataType,
327  typename AQDataType,
328  typename BDataType,
329  typename BQDataType,
330  typename AccDataType,
331  typename CDataType,
332  typename AElementOp = ck_tile::identity,
333  typename BElementOp = ck_tile::identity,
334  typename ACCElementOp = ck_tile::identity>
335 CK_TILE_HOST void reference_gemm_tensor_quant(const HostTensor<ADataType>& a_m_k,
336  const HostTensor<AQDataType>& aq_1_1,
337  const HostTensor<BDataType>& b_k_n,
338  const HostTensor<BQDataType>& bq_1_1,
339  HostTensor<CDataType>& c_m_n,
340  const AElementOp& a_element_op = {},
341  const BElementOp& b_element_op = {},
342  const ACCElementOp& acc_element_op = {})
343 {
344  static_assert(std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t>);
345  static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t>);
346  static_assert(std::is_same_v<AccDataType, float>);
347  static_assert(std::is_same_v<CDataType, float> || std::is_same_v<CDataType, ck_tile::half_t>);
348  static_assert(std::is_same_v<AQDataType, float> && std::is_same_v<BQDataType, float>);
349  const std::size_t M = a_m_k.get_length(0);
350  const std::size_t N = b_k_n.get_length(1);
351  const std::size_t K = a_m_k.get_length(1);
352 
353  auto f_mn = [&](auto m, auto n) {
354  // Init accumulator
355  AccDataType v_acc = 0;
356  // Get scale for A and scale for B
357  const AccDataType a_scale = ck_tile::type_convert<AccDataType>(aq_1_1(0, 0));
358  const AccDataType b_scale = ck_tile::type_convert<AccDataType>(bq_1_1(0, 0));
359 
360  // Compute the dot product
361  for(std::size_t k = 0; k < K; ++k)
362  {
363  AccDataType v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
364  AccDataType v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
365 
366  v_acc += v_a * v_b;
367  }
368 
369  v_acc = v_acc * a_scale * b_scale;
370 
371  c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
372  };
373 
374  make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
375 }
376 
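// reference_mxfp4gemm_quant: reference for a GEMM with B stored as packed fp4 (pk_fp4) and e8m0 block
// scales in q. The loop walks K two elements at a time (one packed pair per iteration); the group scale
// is decoded as 2^(q - 127) and applied to both unpacked fp4 values before accumulation.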
377 template <typename ADataType,
378  typename QDataType,
379  typename BDataType,
380  typename AccDataType,
381  typename CDataType,
382  typename QuantGroupSize,
383  bool aquant,
384  typename AElementOp = ck_tile::identity,
385  typename BElementOp = ck_tile::identity,
386  typename ACCElementOp = ck_tile::identity>
387 CK_TILE_HOST void reference_mxfp4gemm_quant(const HostTensor<ADataType>& a_m_k,
388  const HostTensor<QDataType>& q,
389  const HostTensor<BDataType>& b_k_n,
390  HostTensor<CDataType>& c_m_n,
391  const AElementOp& a_element_op = {},
392  const BElementOp& b_element_op = {},
393  const ACCElementOp& acc_element_op = {})
394 {
395  const std::size_t M = a_m_k.get_length(0);
396  const std::size_t N = b_k_n.get_length(1);
397  const std::size_t K = a_m_k.get_length(1);
398 
399  auto f_mn = [&](auto m, auto n) {
400  AccDataType v_acc = 0;
401  AccDataType partial_sum = 0;
402  for(std::size_t k = 0; k < (K / 2); k++)
403  {
404  using ComputeType = float;
405  auto b_scale = type_convert<int32_t>(q((2 * k) / QuantGroupSize::kK, n)) - 127;
406  ComputeType v_a_0, v_a_1;
407  ComputeType v_b_0, v_b_1;
408 
409  v_a_0 = ck_tile::type_convert<ComputeType>((a_element_op(a_m_k(m, 2 * k))));
410  v_a_1 = ck_tile::type_convert<ComputeType>((a_element_op(a_m_k(m, 2 * k + 1))));
411 
412  if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
413  {
414  auto b_pack = type_convert<pk_fp4_t>(b_element_op(b_k_n(k, n)));
415  auto b_scale_fp4 = type_convert<float>(std::pow(2.0f, b_scale));
416 
417  auto b_f4_lo = type_convert<pk_fp4_t>(b_pack.unpack(number<0>{}));
418  auto b_f4_hi = type_convert<pk_fp4_t>(b_pack.unpack(number<1>{}));
419 
420  v_b_0 = type_convert<ComputeType>(b_f4_lo) * b_scale_fp4;
421  v_b_1 = type_convert<ComputeType>(b_f4_hi) * b_scale_fp4;
422  }
423 
424  partial_sum = v_a_0 * v_b_0 + v_a_1 * v_b_1;
425  v_acc += partial_sum;
426  }
427  c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
428  };
429 
430  make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
431  std::cout << std::endl;
432 }
433 
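// reference_gemm: plain host reference GEMM, C(m, n) = acc_element_op(sum_k a_element_op(A(m, k)) *
// b_element_op(B(k, n))). pk_int4_t inputs are unpacked into two fp32 values per element, selecting
// the high or low half based on the parity of k.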
434 template <typename ADataType,
435  typename BDataType,
436  typename AccDataType,
437  typename CDataType,
438  typename AElementOp = ck_tile::identity,
439  typename BElementOp = ck_tile::identity,
440  typename ACCElementOp = ck_tile::identity>
441 CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
442  const HostTensor<BDataType>& b_k_n,
443  HostTensor<CDataType>& c_m_n,
444  const AElementOp& a_element_op = {},
445  const BElementOp& b_element_op = {},
446  const ACCElementOp& acc_element_op = {})
447 {
448  const std::size_t M = a_m_k.get_length(0);
449  const std::size_t N = b_k_n.get_length(1);
450  const std::size_t K = a_m_k.get_length(1);
451 
452  auto f_mn = [&](auto m, auto n) {
453  AccDataType v_acc = 0;
454 
455  for(std::size_t k = 0; k < K; ++k)
456  {
457  AccDataType v_a;
458  AccDataType v_b;
459  if constexpr(std::is_same_v<ADataType, pk_int4_t>)
460  {
461  const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
462  const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
463  if(k % 2 == 1)
464  v_a = fp32_val.hi;
465  else
466  v_a = fp32_val.lo;
467  }
468  else
469  {
470  v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
471  }
472  if constexpr(std::is_same_v<BDataType, pk_int4_t>)
473  {
474  const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
475  const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
476  if(k % 2 == 1)
477  v_b = fp32_val.hi;
478  else
479  v_b = fp32_val.lo;
480  }
481  else
482  {
483  v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
484  }
485  v_acc += v_a * v_b;
486  }
487 
488  c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
489  };
490 
491  make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
492 }
493 
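// reference_gemm_multiple_abd: multiple-A/B/D reference. The A (and B) input tensors are first combined
// element-wise into the scratch tensors a_m_k (and b_k_n) via a_element_op / b_element_op, a plain GEMM
// accumulates into AccDataType, and the epilogue applies acc_element_op to the accumulator together with
// the D tensors.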
494 template <typename AsDataType,
495  typename BsDataType,
496  typename DsDataType,
497  typename AccDataType,
498  typename CDataType,
499  typename AElementOp,
500  typename BElementOp,
501  typename CDElementOp,
502  typename ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>,
503  typename BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>,
504  typename DDataType = remove_cvref_t<std::tuple_element_t<0, DsDataType>>>
505 CK_TILE_HOST void
506 reference_gemm_multiple_abd(const std::array<HostTensor<ADataType>, AsDataType::size()>& as_m_k,
507  const std::array<HostTensor<BDataType>, BsDataType::size()>& bs_k_n,
508  const std::array<HostTensor<DDataType>, DsDataType::size()>& ds_m_n,
509  HostTensor<ADataType>& a_m_k,
510  HostTensor<BDataType>& b_k_n,
511  HostTensor<CDataType>& c_m_n,
512  const AElementOp& a_element_op = {},
513  const BElementOp& b_element_op = {},
514  const CDElementOp& acc_element_op = {})
515 {
516  const std::size_t M = a_m_k.get_length(0);
517  const std::size_t N = b_k_n.get_length(1);
518  const std::size_t K = a_m_k.get_length(1);
519 
520  auto as_m_k_tuple =
521  generate_tie([&](auto idx) -> auto& { return as_m_k[idx]; }, number<AsDataType::size()>{});
522 
523  auto bs_k_n_tuple =
524  generate_tie([&](auto idx) -> auto& { return bs_k_n[idx]; }, number<BsDataType::size()>{});
525 
526  auto ds_m_n_tuple =
527  generate_tie([&](auto idx) -> auto& { return ds_m_n[idx]; }, number<DsDataType::size()>{});
528 
529  // Apply elementwise function to A
530  auto a_elementwise_fn = [&](auto i, auto j) {
531  ck_tile::apply([&](auto&&... t) { a_element_op(a_m_k(i, j), t(i, j)...); }, as_m_k_tuple);
532  };
533 
534  make_ParallelTensorFunctor(a_elementwise_fn, M, K)(std::thread::hardware_concurrency());
535 
536  // Apply elementwise function to B
537  auto b_elementwise_fn = [&](auto i, auto j) {
538  ck_tile::apply([&](auto&&... t) { b_element_op(b_k_n(i, j), t(i, j)...); }, bs_k_n_tuple);
539  };
540 
541  make_ParallelTensorFunctor(b_elementwise_fn, K, N)(std::thread::hardware_concurrency());
542 
543  auto f_mk_kn_mn = [&](auto m, auto n) {
544  AccDataType v_acc = 0;
545  for(std::size_t k = 0; k < K; ++k)
546  {
547  ADataType v_a = a_m_k(m, k);
548  BDataType v_b = b_k_n(k, n);
549  v_acc +=
550  ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
551  }
552 
553  CDataType v_c = 0;
554 
555  ck_tile::apply(
556  [&](auto&&... t) {
557  acc_element_op(v_c,
558  ck_tile::type_convert<float>(v_acc),
559  ck_tile::type_convert<float>(t(m, n))...);
560  },
561  ds_m_n_tuple);
562 
563  c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_c);
564  };
565 
566  make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency());
567 }
568 
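// reference_mx_gemm: microscaling (MX) reference. A and B are first dequantized into AccDataType
// tensors using the block scales scale_a / scale_b (block size ScaleBlockSize = K / scale_a.get_length(1)),
// with pk_fp4_t elements unpacked two at a time, and the result is passed to reference_gemm.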
569 template <typename ADataType,
570  typename BDataType,
571  typename ScaleDataType,
572  typename AccDataType,
573  typename CDataType,
574  typename AElementOp = ck_tile::identity,
575  typename BElementOp = ck_tile::identity,
576  typename ACCElementOp = ck_tile::identity>
577 CK_TILE_HOST void reference_mx_gemm(const HostTensor<ADataType>& a_m_k,
578  const HostTensor<BDataType>& b_k_n,
579  HostTensor<CDataType>& c_m_n,
580  const HostTensor<ScaleDataType>& scale_a,
581  const HostTensor<ScaleDataType>& scale_b,
582  const AElementOp& = {},
583  const BElementOp& = {},
584  const ACCElementOp& = {})
585 {
586  static_assert(std::is_same_v<AElementOp, ck_tile::identity>);
587  static_assert(std::is_same_v<BElementOp, ck_tile::identity>);
588  static_assert(std::is_same_v<ACCElementOp, ck_tile::identity>);
589 
590  const std::size_t M = a_m_k.get_length(0);
591  const std::size_t N = b_k_n.get_length(1);
592  const std::size_t K = a_m_k.get_length(1);
593 
594  const std::size_t ScaleBlockSize = K / scale_a.get_length(1);
595 
596  HostTensor<AccDataType> a_m_k_scaled({std::size_t(M), std::size_t(K)},
597  {std::size_t(K), std::size_t(1)});
598  HostTensor<AccDataType> b_k_n_scaled({std::size_t(K), std::size_t(N)},
599  {std::size_t(1), std::size_t(K)});
600 
601  for(std::size_t m = 0; m < M; ++m)
602  {
603  for(std::size_t k = 0; k < K; ++k)
604  {
605  if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
606  {
607  if(k % 2 == 1)
608  continue; // skip odd k
609 
610  auto a_f4x2 = a_m_k(m, k);
611  auto a_scale = ck_tile::type_convert<AccDataType>(scale_a(m, k / ScaleBlockSize));
612  auto a_f4_lo =
613  ck_tile::type_convert<AccDataType>(a_f4x2.template unpack<>(number<0>{}));
614  auto a_f4_hi =
615  ck_tile::type_convert<AccDataType>(a_f4x2.template unpack<>(number<1>{}));
616 
617  a_m_k_scaled(m, k) = a_f4_lo * a_scale;
618  a_m_k_scaled(m, k + 1) = a_f4_hi * a_scale;
619  }
620  else
621  {
622  a_m_k_scaled(m, k) =
623  ck_tile::type_convert<AccDataType>((a_m_k(m, k))) *
624  ck_tile::type_convert<AccDataType>(scale_a(m, k / ScaleBlockSize));
625  }
626  }
627  }
628 
629  for(std::size_t n = 0; n < N; n++)
630  {
631  for(std::size_t k = 0; k < K; k++)
632  {
633  if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
634  {
635  if(k % 2 == 1)
636  continue; // skip odd k
637 
638  auto b_f4x2 = b_k_n(k, n);
639  auto b_scale = ck_tile::type_convert<AccDataType>(scale_b(k / ScaleBlockSize, n));
640  auto b_f4_lo =
641  ck_tile::type_convert<AccDataType>(b_f4x2.template unpack<>(number<0>{}));
642  auto b_f4_hi =
643  ck_tile::type_convert<AccDataType>(b_f4x2.template unpack<>(number<1>{}));
644 
645  b_k_n_scaled(k, n) = b_f4_lo * b_scale;
646  b_k_n_scaled(k + 1, n) = b_f4_hi * b_scale;
647  }
648  else
649  {
650  b_k_n_scaled(k, n) =
651  ck_tile::type_convert<AccDataType>((b_k_n(k, n))) *
652  ck_tile::type_convert<AccDataType>(scale_b(k / ScaleBlockSize, n));
653  }
654  }
655  }
656 
657  // call reference gemm
658  reference_gemm<AccDataType, AccDataType, AccDataType, CDataType>(
659  a_m_k_scaled, b_k_n_scaled, c_m_n);
660 }
661 
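// reference_gemm_multiple_d: GEMM followed by an epilogue that combines the accumulator with up to
// two D tensors through acc_element_op (DsDataType::size() of 0, 1, or 2 is supported).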
662 template <typename ADataType,
663  typename BDataType,
664  typename DsDataType,
665  typename AccDataType,
666  typename CDataType,
667  typename ACCElementOp,
668  typename DDataType = remove_cvref_t<std::tuple_element_t<0, DsDataType>>>
669 CK_TILE_HOST void
670 reference_gemm_multiple_d(const HostTensor<ADataType>& a_m_k,
671  const HostTensor<BDataType>& b_k_n,
672  const std::array<HostTensor<DDataType>, DsDataType::size()>& ds_m_n,
673  HostTensor<CDataType>& c_m_n,
674  const ACCElementOp& acc_element_op = {})
675 {
676  const std::size_t M = a_m_k.get_length(0);
677  const std::size_t N = b_k_n.get_length(1);
678  const std::size_t K = a_m_k.get_length(1);
679 
680  auto f_mk_kn_mn = [&](auto m, auto n) {
681  AccDataType v_acc = 0;
682  for(std::size_t k = 0; k < K; ++k)
683  {
684  ADataType v_a = a_m_k(m, k);
685  BDataType v_b = b_k_n(k, n);
686  v_acc +=
687  ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
688  }
689 
690  CDataType v_c = 0;
691  if constexpr(DsDataType::size() == 0)
692  {
693  acc_element_op(v_c, ck_tile::type_convert<float>(v_acc));
694  }
695  else if constexpr(DsDataType::size() == 1)
696  {
697  acc_element_op(v_c,
698  ck_tile::type_convert<float>(v_acc),
699  ck_tile::type_convert<float>(ds_m_n[0](m, n)));
700  }
701  else if constexpr(DsDataType::size() == 2)
702  {
703  acc_element_op(v_c,
704  ck_tile::type_convert<float>(v_acc),
705  ck_tile::type_convert<float>(ds_m_n[0](m, n)),
706  ck_tile::type_convert<float>(ds_m_n[1](m, n)));
707  }
708  c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_c);
709  };
710 
711  make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency());
712 }
713 
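// naive_gemm_kernel: GPU reference kernel. Each thread computes one C(row, col) element; a_index and
// b_index are derived from the Row/ColumnMajor layouts and strides. packed_size_a / packed_size_b denote
// the number of logical elements per stored element (2 for pk_int4_t and pk_fp4_t), so packed inputs are
// addressed with index / packed_size and unpacked based on the parity of k.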
714 template <typename ADataType,
715  typename BDataType,
716  typename AccDataType,
717  typename CDataType,
718  typename LayoutA,
719  typename LayoutB,
720  typename LayoutC>
721 __global__ void naive_gemm_kernel(ADataType* A,
722  BDataType* B,
723  CDataType* C,
724  ck_tile::index_t M,
725  ck_tile::index_t N,
726  ck_tile::index_t K,
727  ck_tile::index_t strideA,
728  ck_tile::index_t strideB,
729  ck_tile::index_t strideC)
730 {
731  int idx = blockIdx.x * blockDim.x + threadIdx.x;
732  int row = idx / N; // Compute row index
733  int col = idx % N; // Compute column index
734 
735  if(row < M && col < N)
736  {
737  AccDataType acc = 0.0;
738  for(int k = 0; k < K; ++k)
739  {
742  // Adjust indexing based on matrix layout
743  int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
744  ? row * strideA + k
745  : k * strideA + row;
746  int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
747  ? col * strideB + k
748  : k * strideB + col;
749 
750  AccDataType v_a;
751  AccDataType v_b;
752  if constexpr(std::is_same_v<ADataType, pk_int4_t>)
753  {
754  const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
755  if(k % 2 == 1)
756  v_a = fp32_val.hi;
757  else
758  v_a = fp32_val.lo;
759  }
760  else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
761  {
762  const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a], 1.0f);
763  if(k % 2 == 1)
764  v_a = fp32_val.hi;
765  else
766  v_a = fp32_val.lo;
767  }
768  else
769  {
770  v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
771  }
772  if constexpr(std::is_same_v<BDataType, pk_int4_t>)
773  {
774  const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
775  if(k % 2 == 1)
776  v_b = fp32_val.hi;
777  else
778  v_b = fp32_val.lo;
779  }
780  else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
781  {
782  const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b], 1.0f);
783  if(k % 2 == 1)
784  v_b = fp32_val.hi;
785  else
786  v_b = fp32_val.lo;
787  }
788  else
789  {
790  v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
791  }
792  acc += v_a * v_b;
793  }
794 
795  int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
796  ? row * strideC + col
797  : col * strideC + row;
798  C[c_index] = ck_tile::type_convert<CDataType>(acc);
799  }
800 }
801 
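// blockwise_gemm_kernel: GPU reference kernel with block-wise scaling. Partial sums are accumulated in
// acc_temp and folded into acc with scale_A * scale_B whenever k crosses a scale_granularity_k boundary;
// scales are read from scale_A_ptr / scale_B_ptr with one entry per (scale_granularity_m x scale_granularity_k)
// block of A and per (scale_granularity_n x scale_granularity_k) block of B.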
802 template <typename ADataType,
803  typename BDataType,
804  typename AccDataType,
805  typename CDataType,
806  typename LayoutA,
807  typename LayoutB,
808  typename LayoutC>
809 __global__ void blockwise_gemm_kernel(ADataType* A,
810  BDataType* B,
811  CDataType* C,
812  ck_tile::index_t M,
813  ck_tile::index_t N,
814  ck_tile::index_t K,
815  ck_tile::index_t strideA,
816  ck_tile::index_t strideB,
817  ck_tile::index_t strideC,
818  ck_tile::index_t scale_granularity_m,
819  ck_tile::index_t scale_granularity_n,
820  ck_tile::index_t scale_granularity_k,
821  float* scale_A_ptr,
822  float* scale_B_ptr)
823 {
824  int idx = blockIdx.x * blockDim.x + threadIdx.x;
825  int row = idx / N; // Compute row index
826  int col = idx % N; // Compute column index
827 
828  if(row < M && col < N)
829  {
830  AccDataType acc = 0.0, acc_temp = 0.0;
831 
832  index_t scale_A_stride = (M + scale_granularity_m - 1) / scale_granularity_m;
833  index_t scale_B_stride = (N + scale_granularity_n - 1) / scale_granularity_n;
834 
835  float scale_A = 0;
836  float scale_B = 0;
837 
838  for(int k = 0; k < K; ++k)
839  {
840  if(k % scale_granularity_k == 0)
841  {
842  // update acc
843  acc += acc_temp * scale_A * scale_B;
844  acc_temp = 0.0;
845  // update scale factors
846  scale_A = scale_A_ptr[(row / scale_granularity_m) +
847  (k / scale_granularity_k) * scale_A_stride];
848  scale_B = scale_B_ptr[(col / scale_granularity_n) +
849  (k / scale_granularity_k) * scale_B_stride];
850  }
851 
854  // Adjust indexing based on matrix layout
855  int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
856  ? row * strideA + k
857  : k * strideA + row;
858  int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
859  ? col * strideB + k
860  : k * strideB + col;
861 
862  AccDataType v_a;
863  AccDataType v_b;
864  if constexpr(std::is_same_v<ADataType, pk_int4_t>)
865  {
866  const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
867  if(k % 2 == 1)
868  v_a = fp32_val.hi;
869  else
870  v_a = fp32_val.lo;
871  }
872  else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
873  {
874  const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a], 1.0f);
875  if(k % 2 == 1)
876  v_a = fp32_val.hi;
877  else
878  v_a = fp32_val.lo;
879  }
880  else
881  {
882  v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
883  }
884 
885  if constexpr(std::is_same_v<BDataType, pk_int4_t>)
886  {
887  const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
888  if(k % 2 == 1)
889  v_b = fp32_val.hi;
890  else
891  v_b = fp32_val.lo;
892  }
893  else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
894  {
895  const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b], 1.0f);
896  if(k % 2 == 1)
897  v_b = fp32_val.hi;
898  else
899  v_b = fp32_val.lo;
900  }
901  else
902  {
903  v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
904  }
905  acc_temp += v_a * v_b;
906  }
907  // final accumulation
908  acc += acc_temp * scale_A * scale_B;
909 
910  int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
911  ? row * strideC + col
912  : col * strideC + row;
913  C[c_index] = ck_tile::type_convert<CDataType>(acc);
914  }
915 }
916 
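// reference_gemm_gpu: host-side launcher for naive_gemm_kernel, using one thread per C element and
// 256 threads per block.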
917 template <typename ADataType,
918  typename BDataType,
919  typename AccDataType,
920  typename CDataType,
921  typename LayoutA,
922  typename LayoutB,
923  typename LayoutC>
924 void reference_gemm_gpu(ADataType* a_ptr,
925  BDataType* b_ptr,
926  CDataType* c_ptr,
927  index_t M,
928  index_t N,
929  index_t K,
930  index_t stride_a,
931  index_t stride_b,
932  index_t stride_c)
933 {
934  int totalElements = M * N;
935  int numThreadsPerBlock = 256; // Common choice for threads per block
936  int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
937 
938  naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
939  <<<numBlocks, numThreadsPerBlock>>>(
940  a_ptr, b_ptr, c_ptr, M, N, K, stride_a, stride_b, stride_c);
941 
942  return;
943 }
944 
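// reference_blockwise_gemm_gpu: host-side launcher for blockwise_gemm_kernel; forwards the scale
// granularities and scale pointers in addition to the usual GEMM sizes and strides.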
945 template <typename ADataType,
946  typename BDataType,
947  typename AccDataType,
948  typename CDataType,
949  typename LayoutA,
950  typename LayoutB,
951  typename LayoutC>
952 void reference_blockwise_gemm_gpu(ADataType* a_ptr,
953  BDataType* b_ptr,
954  CDataType* c_ptr,
955  index_t M,
956  index_t N,
957  index_t K,
958  index_t stride_a,
959  index_t stride_b,
960  index_t stride_c,
961  index_t scale_granularity_m,
962  index_t scale_granularity_n,
963  index_t scale_granularity_k,
964  float* scale_A_ptr,
965  float* scale_B_ptr)
966 {
967  int totalElements = M * N;
968  int numThreadsPerBlock = 256; // Common choice for threads per block
969  int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
970 
971  blockwise_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
972  <<<numBlocks, numThreadsPerBlock>>>(a_ptr,
973  b_ptr,
974  c_ptr,
975  M,
976  N,
977  K,
978  stride_a,
979  stride_b,
980  stride_c,
981  scale_granularity_m,
982  scale_granularity_n,
983  scale_granularity_k,
984  scale_A_ptr,
985  scale_B_ptr);
986 
987  return;
988 }
989 
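// reference_batched_gemm_gpu: launches naive_gemm_kernel once per batch, offsetting the A/B/C pointers
// by their respective batch strides.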
990 template <typename ADataType,
991  typename BDataType,
992  typename AccDataType,
993  typename CDataType,
994  typename LayoutA,
995  typename LayoutB,
996  typename LayoutC>
997 void reference_batched_gemm_gpu(ADataType* a_ptr,
998  BDataType* b_ptr,
999  CDataType* c_ptr,
1000  index_t M,
1001  index_t N,
1002  index_t K,
1003  index_t stride_a,
1004  index_t stride_b,
1005  index_t stride_c,
1006  index_t batch_stride_A,
1007  index_t batch_stride_B,
1008  index_t batch_stride_C,
1009  index_t batch_count)
1010 {
1011  int totalElements = M * N;
1012  int numThreadsPerBlock = 256; // Common choice for threads per block
1013  int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
1014 
1015  for(index_t batch_id = 0; batch_id < batch_count; ++batch_id)
1016  {
1017  ADataType* d_ATemp = a_ptr + batch_id * batch_stride_A;
1018  BDataType* d_BTemp = b_ptr + batch_id * batch_stride_B;
1019  CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C;
1020  naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
1021  <<<numBlocks, numThreadsPerBlock>>>(
1022  d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
1023  }
1024 
1025  return;
1026 }
1027 
1028 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:44
__host__ T pow(T x, T gamma)
Definition: math_v2.hpp:427
Definition: cluster_descriptor.hpp:13
void reference_batched_gemm_gpu(ADataType *a_ptr, BDataType *b_ptr, CDataType *c_ptr, index_t M, index_t N, index_t K, index_t stride_a, index_t stride_b, index_t stride_c, index_t batch_stride_A, index_t batch_stride_B, index_t batch_stride_C, index_t batch_count)
Definition: reference_gemm.hpp:997
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
constexpr decltype(auto) apply(F &&f, Tuple &&t)
Definition: tuple.hpp:526
__global__ void naive_gemm_kernel(ADataType *A, BDataType *B, CDataType *C, ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K, ck_tile::index_t strideA, ck_tile::index_t strideB, ck_tile::index_t strideC)
Definition: reference_gemm.hpp:721
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t &x)
Definition: pk_int4.hpp:105
CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t)
Definition: float8.hpp:751
CK_TILE_HOST void reference_gemm_quant(const HostTensor< ADataType > &a_m_k, const HostTensor< QDataType > &q, const HostTensor< BDataType > &b_k_n, HostTensor< CDataType > &c_m_n, const AElementOp &a_element_op={}, const BElementOp &b_element_op={}, const ACCElementOp &acc_element_op={})
Definition: reference_gemm.hpp:24
CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t)
Definition: float8.hpp:764
CK_TILE_HOST void reference_gemm_multiple_abd(const std::array< HostTensor< ADataType >, AsDataType::size()> &as_m_k, const std::array< HostTensor< BDataType >, BsDataType::size()> &bs_k_n, const std::array< HostTensor< DDataType >, DsDataType::size()> &ds_m_n, HostTensor< ADataType > &a_m_k, HostTensor< BDataType > &b_k_n, HostTensor< CDataType > &c_m_n, const AElementOp &a_element_op={}, const BElementOp &b_element_op={}, const CDElementOp &acc_element_op={})
Definition: reference_gemm.hpp:506
float fp32x2_t
Definition: bfloat16.hpp:434
void reference_blockwise_gemm_gpu(ADataType *a_ptr, BDataType *b_ptr, CDataType *c_ptr, index_t M, index_t N, index_t K, index_t stride_a, index_t stride_b, index_t stride_c, index_t scale_granularity_m, index_t scale_granularity_n, index_t scale_granularity_k, float *scale_A_ptr, float *scale_B_ptr)
Definition: reference_gemm.hpp:952
int32_t index_t
Definition: integer.hpp:9
CK_TILE_HOST void reference_gemm_rowcol_quant(const HostTensor< ADataType > &a_m_k, const HostTensor< AQDataType > &aq_m_1, const HostTensor< BDataType > &b_k_n, const HostTensor< BQDataType > &bq_1_n, HostTensor< CDataType > &c_m_n, const AElementOp &a_element_op={}, const BElementOp &b_element_op={}, const ACCElementOp &acc_element_op={})
Definition: reference_gemm.hpp:254
constexpr CK_TILE_HOST_DEVICE auto generate_tie(F &&f, number< N >)
Definition: tuple.hpp:435
constant< v > number
Definition: integral_constant.hpp:37
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t_signed_conversion(const pk_int4_t &x)
Definition: pk_int4.hpp:120
CK_TILE_HOST void reference_gemm_abquant(const HostTensor< ADataType > &a_m_k, const HostTensor< AQDataType > &a_q, const HostTensor< BDataType > &b_k_n, const HostTensor< BQDataType > &b_q, HostTensor< CDataType > &c_m_n, const AElementOp &a_element_op={}, const BElementOp &b_element_op={}, const ACCElementOp &acc_element_op={})
Definition: reference_gemm.hpp:131
__global__ void blockwise_gemm_kernel(ADataType *A, BDataType *B, CDataType *C, ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K, ck_tile::index_t strideA, ck_tile::index_t strideB, ck_tile::index_t strideC, ck_tile::index_t scale_granularity_m, ck_tile::index_t scale_granularity_n, ck_tile::index_t scale_granularity_k, float *scale_A_ptr, float *scale_B_ptr)
Definition: reference_gemm.hpp:809
constexpr CK_TILE_HOST_DEVICE fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t &x, float scale)
Definition: pk_fp4.hpp:350
void reference_gemm_gpu(ADataType *a_ptr, BDataType *b_ptr, CDataType *c_ptr, index_t M, index_t N, index_t K, index_t stride_a, index_t stride_b, index_t stride_c)
Definition: reference_gemm.hpp:924
CK_TILE_HOST void reference_gemm_multiple_d(const HostTensor< ADataType > &a_m_k, const HostTensor< BDataType > &b_k_n, const std::array< HostTensor< DDataType >, DsDataType::size()> &ds_m_n, HostTensor< CDataType > &c_m_n, const ACCElementOp &acc_element_op={})
Definition: reference_gemm.hpp:670
CK_TILE_HOST void reference_mxfp4gemm_quant(const HostTensor< ADataType > &a_m_k, const HostTensor< QDataType > &q, const HostTensor< BDataType > &b_k_n, HostTensor< CDataType > &c_m_n, const AElementOp &a_element_op={}, const BElementOp &b_element_op={}, const ACCElementOp &acc_element_op={})
Definition: reference_gemm.hpp:387
CK_TILE_HOST void reference_gemm(const HostTensor< ADataType > &a_m_k, const HostTensor< BDataType > &b_k_n, HostTensor< CDataType > &c_m_n, const AElementOp &a_element_op={}, const BElementOp &b_element_op={}, const ACCElementOp &acc_element_op={})
Definition: reference_gemm.hpp:441
CK_TILE_HOST void reference_mx_gemm(const HostTensor< ADataType > &a_m_k, const HostTensor< BDataType > &b_k_n, HostTensor< CDataType > &c_m_n, const HostTensor< ScaleDataType > &scale_a, const HostTensor< ScaleDataType > &scale_b, const AElementOp &={}, const BElementOp &={}, const ACCElementOp &={})
Definition: reference_gemm.hpp:577
CK_TILE_HOST void reference_gemm_tensor_quant(const HostTensor< ADataType > &a_m_k, const HostTensor< AQDataType > &aq_1_1, const HostTensor< BDataType > &b_k_n, const HostTensor< BQDataType > &bq_1_1, HostTensor< CDataType > &c_m_n, const AElementOp &a_element_op={}, const BElementOp &b_element_op={}, const ACCElementOp &acc_element_op={})
Definition: reference_gemm.hpp:335
Definition: host_tensor.hpp:336
std::size_t get_length(std::size_t dim) const
Definition: host_tensor.hpp:388
Definition: functional.hpp:114
Definition: numeric.hpp:81