Composable Kernel: include/ck_tile/host/reference/reference_gemm.hpp Source File

reference_gemm.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <cstdlib>
7 #include <thread>
8 
9 #include "ck_tile/core.hpp"
10 #include "ck_tile/host/host_tensor.hpp"
11 
12 namespace ck_tile {
13 
14 template <typename ADataType,
15  typename QDataType,
16  typename BDataType,
17  typename AccDataType,
18  typename CDataType,
19  typename QuantGroupSize,
20  bool aquant,
21  typename AElementOp = ck_tile::identity,
22  typename BElementOp = ck_tile::identity,
23  typename ACCElementOp = ck_tile::identity>
24 CK_TILE_HOST void reference_gemm_quant(const HostTensor<ADataType>& a_m_k,
25  const HostTensor<QDataType>& q,
26  const HostTensor<BDataType>& b_k_n,
27  HostTensor<CDataType>& c_m_n,
28  const AElementOp& a_element_op = {},
29  const BElementOp& b_element_op = {},
30  const ACCElementOp& acc_element_op = {})
31 {
32  const std::size_t M = a_m_k.get_length(0);
33  const std::size_t N = b_k_n.get_length(1);
34  const std::size_t K = a_m_k.get_length(1);
35 
36  auto f_mn = [&](auto m, auto n) {
37  AccDataType v_acc = 0, v_block_acc = 0;
38 
39  static_assert(std::is_same_v<ADataType, pk_int4_t> || std::is_same_v<ADataType, fp8_t> ||
40  std::is_same_v<ADataType, bf8_t>);
41  static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t> ||
42  std::is_same_v<BDataType, pk_int4_t>);
43  static_assert(std::is_same_v<AccDataType, float>);
44  static_assert(std::is_same_v<CDataType, float> ||
45  std::is_same_v<CDataType, ck_tile::half_t>);
46  for(std::size_t k = 0; k < K; ++k)
47  {
48  AccDataType v_a;
49  AccDataType v_b;
50  if constexpr(std::is_same_v<ADataType, pk_int4_t>)
51  {
52  const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
53  const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
54  if(k % 2 == 1)
55  v_a = fp32_val.hi;
56  else
57  v_a = fp32_val.lo;
58  }
59  else
60  {
61  v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
62  }
63  if constexpr(std::is_same_v<BDataType, pk_int4_t>)
64  {
65  const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
66  const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
67  if(k % 2 == 1)
68  v_b = fp32_val.hi;
69  else
70  v_b = fp32_val.lo;
71  }
72  else if constexpr(std::is_same_v<BDataType, fp8_t>)
73  {
74  v_b = fp8_to_float_raw(b_element_op(b_k_n(k, n)));
75  }
76  else
77  {
78  v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
79  }
80  v_block_acc += v_a * v_b;
81 
82  // Apply group dequant scale
83  if((k + 1) % QuantGroupSize::kK == 0)
84  {
85  float scale = 0.f;
86  index_t outer_dim = (aquant) ? (m / QuantGroupSize::kM) : (k / QuantGroupSize::kK);
87  index_t inner_dim = (aquant) ? (k / QuantGroupSize::kK) : (n / QuantGroupSize::kN);
88  if constexpr(std::is_same_v<QDataType, float>)
89  {
90  scale = q(outer_dim, inner_dim);
91  }
92  else if constexpr(std::is_same_v<QDataType, ck_tile::fp8_t>)
93  {
94  scale = fp8_to_float_raw(q(outer_dim, inner_dim));
95  }
96  else if constexpr(std::is_same_v<QDataType, ck_tile::bf8_t>)
97  {
98  scale = bf8_to_float_raw(q(outer_dim, inner_dim));
99  }
100  else
101  {
102  static_assert(false, "Unexpected Q datatype.");
103  }
104  v_block_acc *= scale;
105  v_acc += v_block_acc;
106  v_block_acc = 0;
107  }
108  }
109 
110  c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
111  };
112 
113  make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
114  std::cout << std::endl;
115 }
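
// Usage sketch (illustrative only, not part of this header): calling reference_gemm_quant with
// B-side group quantization (aquant == false). The QuantGroupSize argument only needs the
// kM/kN/kK members read above; the GroupSize struct below is a hypothetical stand-in for the
// library's own group-size type.
//
//   struct GroupSize { static constexpr ck_tile::index_t kM = 1, kN = 1, kK = 128; };
//
//   std::size_t M = 256, N = 256, K = 512;
//   ck_tile::HostTensor<ck_tile::fp8_t> a({M, K}, {K, std::size_t(1)});
//   ck_tile::HostTensor<ck_tile::fp8_t> b({K, N}, {N, std::size_t(1)});
//   // One float scale per kK x kN tile of B, read above as q(k / kK, n / kN).
//   ck_tile::HostTensor<float> q({K / GroupSize::kK, N / GroupSize::kN},
//                                {N / GroupSize::kN, std::size_t(1)});
//   ck_tile::HostTensor<float> c({M, N}, {N, std::size_t(1)});
//   ck_tile::reference_gemm_quant<ck_tile::fp8_t, float, ck_tile::fp8_t, float, float,
//                                 GroupSize, /*aquant=*/false>(a, q, b, c);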
116 
117 template <typename ADataType,
118  typename AQDataType,
119  typename BDataType,
120  typename BQDataType,
121  typename AccDataType,
122  typename CDataType,
123  typename AElementOp = ck_tile::identity,
124  typename BElementOp = ck_tile::identity,
125  typename ACCElementOp = ck_tile::identity>
126 CK_TILE_HOST void reference_gemm_rowcol_quant(const HostTensor<ADataType>& a_m_k,
127  const HostTensor<AQDataType>& aq_m_1,
128  const HostTensor<BDataType>& b_k_n,
129  const HostTensor<BQDataType>& bq_1_n,
130  HostTensor<CDataType>& c_m_n,
131  const AElementOp& a_element_op = {},
132  const BElementOp& b_element_op = {},
133  const ACCElementOp& acc_element_op = {})
134 {
135  static_assert(std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t>);
136  static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t>);
137  static_assert(std::is_same_v<AccDataType, float>);
138  static_assert(std::is_same_v<CDataType, float> || std::is_same_v<CDataType, ck_tile::half_t>);
139  static_assert(std::is_same_v<AQDataType, float> && std::is_same_v<BQDataType, float>);
140  const std::size_t M = a_m_k.get_length(0);
141  const std::size_t N = b_k_n.get_length(1);
142  const std::size_t K = a_m_k.get_length(1);
143 
144  auto f_mn = [&](auto m, auto n) {
145  // Init accumulator
146  AccDataType v_acc = 0;
147  // Get row scale for A and column scale for B
148  float a_scale = aq_m_1(m, 0);
149  float b_scale = bq_1_n(0, n);
150 
151  // Compute the dot product
152  for(std::size_t k = 0; k < K; ++k)
153  {
154  AccDataType v_a;
155  AccDataType v_b;
156 
157  // Process A data
158  if constexpr(std::is_same_v<ADataType, pk_int4_t>)
159  {
160  const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
161  const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val);
162  if(k % 2 == 1)
163  v_a = fp32_val.hi;
164  else
165  v_a = fp32_val.lo;
166  }
167  else
168  {
169  v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
170  }
171 
172  // Process B data
173  if constexpr(std::is_same_v<BDataType, pk_int4_t>)
174  {
175  const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
176  const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val);
177  if(k % 2 == 1)
178  v_b = fp32_val.hi;
179  else
180  v_b = fp32_val.lo;
181  }
182  else
183  {
184  v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
185  }
186 
187  v_acc += v_a * v_b;
188  }
189 
190  v_acc = v_acc * a_scale * b_scale;
191 
192  c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
193  };
194 
195  make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
196 }
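
// Usage sketch (illustrative only, not part of this header): row-wise A scales of shape M x 1
// and column-wise B scales of shape 1 x N, both float, applied once per output element.
//
//   std::size_t M = 128, N = 128, K = 256;
//   ck_tile::HostTensor<ck_tile::fp8_t> a({M, K}, {K, std::size_t(1)});
//   ck_tile::HostTensor<ck_tile::fp8_t> b({K, N}, {N, std::size_t(1)});
//   ck_tile::HostTensor<float> aq({M, std::size_t(1)}, {std::size_t(1), std::size_t(1)});
//   ck_tile::HostTensor<float> bq({std::size_t(1), N}, {N, std::size_t(1)});
//   ck_tile::HostTensor<ck_tile::half_t> c({M, N}, {N, std::size_t(1)});
//   ck_tile::reference_gemm_rowcol_quant<ck_tile::fp8_t, float, ck_tile::fp8_t, float, float,
//                                        ck_tile::half_t>(a, aq, b, bq, c);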
197 
198 template <typename ADataType,
199  typename AQDataType,
200  typename BDataType,
201  typename BQDataType,
202  typename AccDataType,
203  typename CDataType,
204  typename AElementOp = ck_tile::identity,
205  typename BElementOp = ck_tile::identity,
206  typename ACCElementOp = ck_tile::identity>
207 CK_TILE_HOST void reference_gemm_tensor_quant(const HostTensor<ADataType>& a_m_k,
208  const HostTensor<AQDataType>& aq_1_1,
209  const HostTensor<BDataType>& b_k_n,
210  const HostTensor<BQDataType>& bq_1_1,
211  HostTensor<CDataType>& c_m_n,
212  const AElementOp& a_element_op = {},
213  const BElementOp& b_element_op = {},
214  const ACCElementOp& acc_element_op = {})
215 {
216  static_assert(std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t>);
217  static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t>);
218  static_assert(std::is_same_v<AccDataType, float>);
219  static_assert(std::is_same_v<CDataType, float> || std::is_same_v<CDataType, ck_tile::half_t>);
220  static_assert(std::is_same_v<AQDataType, float> && std::is_same_v<BQDataType, float>);
221  const std::size_t M = a_m_k.get_length(0);
222  const std::size_t N = b_k_n.get_length(1);
223  const std::size_t K = a_m_k.get_length(1);
224 
225  auto f_mn = [&](auto m, auto n) {
226  // Init accumulator
227  AccDataType v_acc = 0;
228  // Get scale for A and scale for B
229  const AccDataType a_scale = ck_tile::type_convert<AccDataType>(aq_1_1(0, 0));
230  const AccDataType b_scale = ck_tile::type_convert<AccDataType>(bq_1_1(0, 0));
231 
232  // Compute the dot product
233  for(std::size_t k = 0; k < K; ++k)
234  {
235  AccDataType v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
236  AccDataType v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
237 
238  v_acc += v_a * v_b;
239  }
240 
241  v_acc = v_acc * a_scale * b_scale;
242 
243  c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
244  };
245 
246  make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
247 }
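
// Note (illustrative only, not part of this header): aq_1_1 and bq_1_1 are 1 x 1 float tensors,
// each holding a single per-tensor dequantization scale, e.g.
//
//   ck_tile::HostTensor<float> aq({std::size_t(1), std::size_t(1)}, {std::size_t(1), std::size_t(1)});
//   aq(0, 0) = 0.0625f;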
248 
249 template <typename ADataType,
250  typename BDataType,
251  typename AccDataType,
252  typename CDataType,
253  typename AElementOp = ck_tile::identity,
254  typename BElementOp = ck_tile::identity,
255  typename ACCElementOp = ck_tile::identity>
256 CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
257  const HostTensor<BDataType>& b_k_n,
258  HostTensor<CDataType>& c_m_n,
259  const AElementOp& a_element_op = {},
260  const BElementOp& b_element_op = {},
261  const ACCElementOp& acc_element_op = {})
262 {
263  const std::size_t M = a_m_k.get_length(0);
264  const std::size_t N = b_k_n.get_length(1);
265  const std::size_t K = a_m_k.get_length(1);
266 
267  auto f_mn = [&](auto m, auto n) {
268  AccDataType v_acc = 0;
269 
270  for(std::size_t k = 0; k < K; ++k)
271  {
272  AccDataType v_a;
273  AccDataType v_b;
274  if constexpr(std::is_same_v<ADataType, pk_int4_t>)
275  {
276  const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
277  const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
278  if(k % 2 == 1)
279  v_a = fp32_val.hi;
280  else
281  v_a = fp32_val.lo;
282  }
283  else
284  {
285  v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
286  }
287  if constexpr(std::is_same_v<BDataType, pk_int4_t>)
288  {
289  const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
290  const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
291  if(k % 2 == 1)
292  v_b = fp32_val.hi;
293  else
294  v_b = fp32_val.lo;
295  }
296  else
297  {
298  v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
299  }
300  v_acc += v_a * v_b;
301  }
302 
303  c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
304  };
305 
306  make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
307 }
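
// Usage sketch (illustrative only, not part of this header): the plain reference GEMM computes
// C(m, n) = acc_element_op(sum_k a_element_op(A(m, k)) * b_element_op(B(k, n))), accumulating
// in AccDataType.
//
//   std::size_t M = 64, N = 64, K = 64;
//   ck_tile::HostTensor<ck_tile::half_t> a({M, K}, {K, std::size_t(1)});
//   ck_tile::HostTensor<ck_tile::half_t> b({K, N}, {N, std::size_t(1)});
//   ck_tile::HostTensor<ck_tile::half_t> c({M, N}, {N, std::size_t(1)});
//   ck_tile::reference_gemm<ck_tile::half_t, ck_tile::half_t, float, ck_tile::half_t>(a, b, c);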
308 
309 template <typename AsDataType,
310  typename BsDataType,
311  typename DsDataType,
312  typename AccDataType,
313  typename CDataType,
314  typename AElementOp,
315  typename BElementOp,
316  typename CDElementOp,
317  typename ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>,
318  typename BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>,
319  typename DDataType = remove_cvref_t<std::tuple_element_t<0, DsDataType>>>
320 CK_TILE_HOST void
321 reference_gemm_multiple_abd(const std::array<HostTensor<ADataType>, AsDataType::size()>& as_m_k,
322  const std::array<HostTensor<BDataType>, BsDataType::size()>& bs_k_n,
323  const std::array<HostTensor<DDataType>, DsDataType::size()>& ds_m_n,
324  HostTensor<ADataType>& a_m_k,
325  HostTensor<BDataType>& b_k_n,
326  HostTensor<CDataType>& c_m_n,
327  const AElementOp& a_element_op = {},
328  const BElementOp& b_element_op = {},
329  const CDElementOp& acc_element_op = {})
330 {
331  const std::size_t M = a_m_k.get_length(0);
332  const std::size_t N = b_k_n.get_length(1);
333  const std::size_t K = a_m_k.get_length(1);
334 
335  auto as_m_k_tuple =
336  generate_tie([&](auto idx) -> auto& { return as_m_k[idx]; }, number<AsDataType::size()>{});
337 
338  auto bs_k_n_tuple =
339  generate_tie([&](auto idx) -> auto& { return bs_k_n[idx]; }, number<BsDataType::size()>{});
340 
341  auto ds_m_n_tuple =
342  generate_tie([&](auto idx) -> auto& { return ds_m_n[idx]; }, number<DsDataType::size()>{});
343 
344  // Apply elementwise function to A
345  auto a_elementwise_fn = [&](auto i, auto j) {
346  ck_tile::apply([&](auto&&... t) { a_element_op(a_m_k(i, j), t(i, j)...); }, as_m_k_tuple);
347  };
348 
349  make_ParallelTensorFunctor(a_elementwise_fn, M, K)(std::thread::hardware_concurrency());
350 
351  // Apply elementwise function to B
352  auto b_elementwise_fn = [&](auto i, auto j) {
353  ck_tile::apply([&](auto&&... t) { b_element_op(b_k_n(i, j), t(i, j)...); }, bs_k_n_tuple);
354  };
355 
356  make_ParallelTensorFunctor(b_elementwise_fn, K, N)(std::thread::hardware_concurrency());
357 
358  auto f_mk_kn_mn = [&](auto m, auto n) {
359  AccDataType v_acc = 0;
360  for(std::size_t k = 0; k < K; ++k)
361  {
362  ADataType v_a = a_m_k(m, k);
363  BDataType v_b = b_k_n(k, n);
364  v_acc +=
365  ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
366  }
367 
368  CDataType v_c = 0;
369 
370  ck_tile::apply(
371  [&](auto&&... t) {
372  acc_element_op(v_c,
373  ck_tile::type_convert<float>(v_acc),
374  ck_tile::type_convert<float>(t(m, n))...);
375  },
376  ds_m_n_tuple);
377 
378  c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_c);
379  };
380 
381  make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency());
382 }
383 
384 template <typename ADataType,
385  typename BDataType,
386  typename ScaleDataType,
387  typename AccDataType,
388  typename CDataType,
389  typename AElementOp = ck_tile::identity,
390  typename BElementOp = ck_tile::identity,
391  typename ACCElementOp = ck_tile::identity>
392 CK_TILE_HOST void reference_mx_gemm(const HostTensor<ADataType>& a_m_k,
393  const HostTensor<BDataType>& b_k_n,
394  HostTensor<CDataType>& c_m_n,
395  const HostTensor<ScaleDataType>& scale_a,
396  const HostTensor<ScaleDataType>& scale_b,
397  const AElementOp& = {},
398  const BElementOp& = {},
399  const ACCElementOp& = {})
400 {
401  static_assert(std::is_same_v<AElementOp, ck_tile::identity>);
402  static_assert(std::is_same_v<BElementOp, ck_tile::identity>);
403  static_assert(std::is_same_v<ACCElementOp, ck_tile::identity>);
404 
405  const std::size_t M = a_m_k.get_length(0);
406  const std::size_t N = b_k_n.get_length(1);
407  const std::size_t K = a_m_k.get_length(1);
408 
409  const std::size_t ScaleBlockSize = K / scale_a.get_length(1);
410 
411  HostTensor<AccDataType> a_m_k_scaled({std::size_t(M), std::size_t(K)},
412  {std::size_t(K), std::size_t(1)});
413  HostTensor<AccDataType> b_k_n_scaled({std::size_t(K), std::size_t(N)},
414  {std::size_t(1), std::size_t(K)});
415 
416  for(std::size_t m = 0; m < M; ++m)
417  {
418  for(std::size_t k = 0; k < K; ++k)
419  {
420  if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
421  {
422  if(k % 2 == 1)
423  continue; // skip odd k
424 
425  auto a_f4x2 = a_m_k(m, k);
426  auto a_scale = ck_tile::type_convert<AccDataType>(scale_a(m, k / ScaleBlockSize));
427  auto a_f4_lo =
428  ck_tile::type_convert<AccDataType>(a_f4x2.template unpack<>(number<0>{}));
429  auto a_f4_hi =
430  ck_tile::type_convert<AccDataType>(a_f4x2.template unpack<>(number<1>{}));
431 
432  a_m_k_scaled(m, k) = a_f4_lo * a_scale;
433  a_m_k_scaled(m, k + 1) = a_f4_hi * a_scale;
434  }
435  }
436  }
437 
438  for(std::size_t n = 0; n < N; n++)
439  {
440  for(std::size_t k = 0; k < K; k++)
441  {
442  if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
443  {
444  if(k % 2 == 1)
445  continue; // skip odd k
446 
447  auto b_f4x2 = b_k_n(k, n);
448  auto b_scale = ck_tile::type_convert<AccDataType>(scale_b(k / ScaleBlockSize, n));
449  auto b_f4_lo =
450  ck_tile::type_convert<AccDataType>(b_f4x2.template unpack<>(number<0>{}));
451  auto b_f4_hi =
452  ck_tile::type_convert<AccDataType>(b_f4x2.template unpack<>(number<1>{}));
453 
454  b_k_n_scaled(k, n) = b_f4_lo * b_scale;
455  b_k_n_scaled(k + 1, n) = b_f4_hi * b_scale;
456  }
457  else
458  {
459  b_k_n_scaled(k, n) =
460  ck_tile::type_convert<AccDataType>((b_k_n(k, n))) *
461  ck_tile::type_convert<AccDataType>(scale_b(k / ScaleBlockSize, n));
462  }
463  }
464  }
465 
466  // call reference gemm
467  reference_gemm<AccDataType, AccDataType, AccDataType, CDataType>(
468  a_m_k_scaled, b_k_n_scaled, c_m_n);
469 }
470 
471 template <typename ADataType,
472  typename BDataType,
473  typename DsDataType,
474  typename AccDataType,
475  typename CDataType,
476  typename ACCElementOp,
477  typename DDataType = remove_cvref_t<std::tuple_element_t<0, DsDataType>>>
478 CK_TILE_HOST void
479 reference_gemm_multiple_d(const HostTensor<ADataType>& a_m_k,
480  const HostTensor<BDataType>& b_k_n,
481  const std::array<HostTensor<DDataType>, DsDataType::size()>& ds_m_n,
482  HostTensor<CDataType>& c_m_n,
483  const ACCElementOp& acc_element_op = {})
484 {
485  const std::size_t M = a_m_k.get_length(0);
486  const std::size_t N = b_k_n.get_length(1);
487  const std::size_t K = a_m_k.get_length(1);
488 
489  auto f_mk_kn_mn = [&](auto m, auto n) {
490  AccDataType v_acc = 0;
491  for(std::size_t k = 0; k < K; ++k)
492  {
493  ADataType v_a = a_m_k(m, k);
494  BDataType v_b = b_k_n(k, n);
495  v_acc +=
496  ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
497  }
498 
499  CDataType v_c = 0;
500  if constexpr(DsDataType::size() == 0)
501  {
502  acc_element_op(v_c, ck_tile::type_convert<float>(v_acc));
503  }
504  else if constexpr(DsDataType::size() == 1)
505  {
506  acc_element_op(v_c,
507  ck_tile::type_convert<float>(v_acc),
508  ck_tile::type_convert<float>(ds_m_n[0](m, n)));
509  }
510  else if constexpr(DsDataType::size() == 2)
511  {
512  acc_element_op(v_c,
513  ck_tile::type_convert<float>(v_acc),
514  ck_tile::type_convert<float>(ds_m_n[0](m, n)),
515  ck_tile::type_convert<float>(ds_m_n[1](m, n)));
516  }
517  c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_c);
518  };
519 
520  make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency());
521 }
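
// Note (illustrative only, not part of this header): acc_element_op receives the output element
// by reference, then the float accumulator, then one float per D tensor (up to two are handled
// above). A bias-add epilogue with a single D could be written as:
//
//   auto add_bias = [](auto& c_out, float acc, float d0) { c_out = acc + d0; };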
522 
523 template <typename ADataType,
524  typename BDataType,
525  typename AccDataType,
526  typename CDataType,
527  typename LayoutA,
528  typename LayoutB,
529  typename LayoutC>
530 __global__ void naive_gemm_kernel(ADataType* A,
531  BDataType* B,
532  CDataType* C,
533  ck_tile::index_t M,
534  ck_tile::index_t N,
535  ck_tile::index_t K,
536  ck_tile::index_t strideA,
537  ck_tile::index_t strideB,
538  ck_tile::index_t strideC)
539 {
540  int idx = blockIdx.x * blockDim.x + threadIdx.x;
541  int row = idx / N; // Compute row index
542  int col = idx % N; // Compute column index
543 
544  if(row < M && col < N)
545  {
546  AccDataType acc = 0.0;
547  for(int k = 0; k < K; ++k)
548  {
549  constexpr index_t packed_size_a = (std::is_same_v<ADataType, pk_int4_t> || std::is_same_v<ADataType, pk_fp4_t>) ? 2 : 1;
550  constexpr index_t packed_size_b = (std::is_same_v<BDataType, pk_int4_t> || std::is_same_v<BDataType, pk_fp4_t>) ? 2 : 1;
551  // Adjust indexing based on matrix layout
552  int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
553  ? row * strideA + k
554  : k * strideA + row;
555  int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
556  ? col * strideB + k
557  : k * strideB + col;
558 
559  AccDataType v_a;
560  AccDataType v_b;
561  if constexpr(std::is_same_v<ADataType, pk_int4_t>)
562  {
563  const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
564  if(k % 2 == 1)
565  v_a = fp32_val.hi;
566  else
567  v_a = fp32_val.lo;
568  }
569  else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
570  {
571  const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a]);
572  if(k % 2 == 1)
573  v_a = fp32_val.hi;
574  else
575  v_a = fp32_val.lo;
576  }
577  else
578  {
579  v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
580  }
581  if constexpr(std::is_same_v<BDataType, pk_int4_t>)
582  {
583  const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
584  if(k % 2 == 1)
585  v_b = fp32_val.hi;
586  else
587  v_b = fp32_val.lo;
588  }
589  else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
590  {
591  const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b]);
592  if(k % 2 == 1)
593  v_b = fp32_val.hi;
594  else
595  v_b = fp32_val.lo;
596  }
597  else
598  {
599  v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
600  }
601  acc += v_a * v_b;
602  }
603 
604  int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
605  ? row * strideC + col
606  : col * strideC + row;
607  C[c_index] = ck_tile::type_convert<CDataType>(acc);
608  }
609 }
610 
611 template <typename ADataType,
612  typename BDataType,
613  typename AccDataType,
614  typename CDataType,
615  typename LayoutA,
616  typename LayoutB,
617  typename LayoutC>
618 __global__ void blockwise_gemm_kernel(ADataType* A,
619  BDataType* B,
620  CDataType* C,
621  ck_tile::index_t M,
622  ck_tile::index_t N,
623  ck_tile::index_t K,
624  ck_tile::index_t strideA,
625  ck_tile::index_t strideB,
626  ck_tile::index_t strideC,
627  ck_tile::index_t scale_granularity_m,
628  ck_tile::index_t scale_granularity_n,
629  ck_tile::index_t scale_granularity_k,
630  float* scale_A_ptr,
631  float* scale_B_ptr)
632 {
633  int idx = blockIdx.x * blockDim.x + threadIdx.x;
634  int row = idx / N; // Compute row index
635  int col = idx % N; // Compute column index
636 
637  if(row < M && col < N)
638  {
639  AccDataType acc = 0.0, acc_temp = 0.0;
640 
641  index_t scale_A_stride = (M + scale_granularity_m - 1) / scale_granularity_m;
642  index_t scale_B_stride = (N + scale_granularity_n - 1) / scale_granularity_n;
643 
644  float scale_A = 0;
645  float scale_B = 0;
646 
647  for(int k = 0; k < K; ++k)
648  {
649  if(k % scale_granularity_k == 0)
650  {
651  // update acc
652  acc += acc_temp * scale_A * scale_B;
653  acc_temp = 0.0;
654  // update scale factors
655  scale_A = scale_A_ptr[(row / scale_granularity_m) +
656  (k / scale_granularity_k) * scale_A_stride];
657  scale_B = scale_B_ptr[(col / scale_granularity_n) +
658  (k / scale_granularity_k) * scale_B_stride];
659  }
660 
661  constexpr index_t packed_size_a = (std::is_same_v<ADataType, pk_int4_t> || std::is_same_v<ADataType, pk_fp4_t>) ? 2 : 1;
662  constexpr index_t packed_size_b = (std::is_same_v<BDataType, pk_int4_t> || std::is_same_v<BDataType, pk_fp4_t>) ? 2 : 1;
663  // Adjust indexing based on matrix layout
664  int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
665  ? row * strideA + k
666  : k * strideA + row;
667  int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
668  ? col * strideB + k
669  : k * strideB + col;
670 
671  AccDataType v_a;
672  AccDataType v_b;
673  if constexpr(std::is_same_v<ADataType, pk_int4_t>)
674  {
675  const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
676  if(k % 2 == 1)
677  v_a = fp32_val.hi;
678  else
679  v_a = fp32_val.lo;
680  }
681  else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
682  {
683  const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a]);
684  if(k % 2 == 1)
685  v_a = fp32_val.hi;
686  else
687  v_a = fp32_val.lo;
688  }
689  else
690  {
691  v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
692  }
693 
694  if constexpr(std::is_same_v<BDataType, pk_int4_t>)
695  {
696  const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
697  if(k % 2 == 1)
698  v_b = fp32_val.hi;
699  else
700  v_b = fp32_val.lo;
701  }
702  else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
703  {
704  const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b], 1.0f);
705  if(k % 2 == 1)
706  v_b = fp32_val.hi;
707  else
708  v_b = fp32_val.lo;
709  }
710  else
711  {
712  v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
713  }
714  acc_temp += v_a * v_b;
715  }
716  // final accumulation
717  acc += acc_temp * scale_A * scale_B;
718 
719  int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
720  ? row * strideC + col
721  : col * strideC + row;
722  C[c_index] = ck_tile::type_convert<CDataType>(acc);
723  }
724 }
725 
726 template <typename ADataType,
727  typename BDataType,
728  typename AccDataType,
729  typename CDataType,
730  typename LayoutA,
731  typename LayoutB,
732  typename LayoutC>
733 void reference_gemm_gpu(ADataType* a_ptr,
734  BDataType* b_ptr,
735  CDataType* c_ptr,
736  index_t M,
737  index_t N,
738  index_t K,
739  index_t stride_a,
740  index_t stride_b,
741  index_t stride_c)
742 {
743  int totalElements = M * N;
744  int numThreadsPerBlock = 256; // Common choice for threads per block
745  int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
746 
747  naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
748  <<<numBlocks, numThreadsPerBlock>>>(
749  a_ptr, b_ptr, c_ptr, M, N, K, stride_a, stride_b, stride_c);
750 
751  return;
752 }
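
// Usage sketch (illustrative only, not part of this header): launching the naive reference
// kernel on device buffers with row-major A/C and column-major B; the strides are the leading
// dimensions expected by naive_gemm_kernel above.
//
//   using Row = ck_tile::tensor_layout::gemm::RowMajor;
//   using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
//   ck_tile::index_t M = 256, N = 256, K = 256;
//   ck_tile::half_t* a_dev; ck_tile::half_t* b_dev; float* c_dev;
//   hipMalloc(&a_dev, M * K * sizeof(ck_tile::half_t));
//   hipMalloc(&b_dev, K * N * sizeof(ck_tile::half_t));
//   hipMalloc(&c_dev, M * N * sizeof(float));
//   // ... hipMemcpy A and B to the device ...
//   ck_tile::reference_gemm_gpu<ck_tile::half_t, ck_tile::half_t, float, float, Row, Col, Row>(
//       a_dev, b_dev, c_dev, M, N, K, /*stride_a=*/K, /*stride_b=*/K, /*stride_c=*/N);
//   hipDeviceSynchronize();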
753 
754 template <typename ADataType,
755  typename BDataType,
756  typename AccDataType,
757  typename CDataType,
758  typename LayoutA,
759  typename LayoutB,
760  typename LayoutC>
761 void reference_blockwise_gemm_gpu(ADataType* a_ptr,
762  BDataType* b_ptr,
763  CDataType* c_ptr,
764  index_t M,
765  index_t N,
766  index_t K,
767  index_t stride_a,
768  index_t stride_b,
769  index_t stride_c,
770  index_t scale_granularity_m,
771  index_t scale_granularity_n,
772  index_t scale_granularity_k,
773  float* scale_A_ptr,
774  float* scale_B_ptr)
775 {
776  int totalElements = M * N;
777  int numThreadsPerBlock = 256; // Common choice for threads per block
778  int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
779 
780  blockwise_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
781  <<<numBlocks, numThreadsPerBlock>>>(a_ptr,
782  b_ptr,
783  c_ptr,
784  M,
785  N,
786  K,
787  stride_a,
788  stride_b,
789  stride_c,
790  scale_granularity_m,
791  scale_granularity_n,
792  scale_granularity_k,
793  scale_A_ptr,
794  scale_B_ptr);
795 
796  return;
797 }
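
// Usage sketch (illustrative only, not part of this header; reuses M/N/K, the Row/Col aliases
// and the device buffers from the previous sketch): blockwise_gemm_kernel above reads the
// scales as scale_A[(m / gm) + (k / gk) * ceil(M / gm)] and
// scale_B[(n / gn) + (k / gk) * ceil(N / gn)], so each buffer holds one float per
// (granularity block, K block) pair with the K-block index as the slower dimension.
//
//   ck_tile::index_t gm = 128, gn = 128, gk = 128;
//   std::size_t scale_a_elems = ((M + gm - 1) / gm) * ((K + gk - 1) / gk);
//   std::size_t scale_b_elems = ((N + gn - 1) / gn) * ((K + gk - 1) / gk);
//   float* scale_a_dev; float* scale_b_dev;
//   hipMalloc(&scale_a_dev, scale_a_elems * sizeof(float));
//   hipMalloc(&scale_b_dev, scale_b_elems * sizeof(float));
//   // ... fill and copy the scale buffers, then:
//   ck_tile::reference_blockwise_gemm_gpu<ck_tile::half_t, ck_tile::half_t, float, float,
//                                         Row, Col, Row>(a_dev, b_dev, c_dev, M, N, K,
//                                                        K, K, N, gm, gn, gk,
//                                                        scale_a_dev, scale_b_dev);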
798 
799 template <typename ADataType,
800  typename BDataType,
801  typename AccDataType,
802  typename CDataType,
803  typename LayoutA,
804  typename LayoutB,
805  typename LayoutC>
806 void reference_batched_gemm_gpu(ADataType* a_ptr,
807  BDataType* b_ptr,
808  CDataType* c_ptr,
809  index_t M,
810  index_t N,
811  index_t K,
812  index_t stride_a,
813  index_t stride_b,
814  index_t stride_c,
815  index_t batch_stride_A,
816  index_t batch_stride_B,
817  index_t batch_stride_C,
818  index_t batch_count)
819 {
820  int totalElements = M * N;
821  int numThreadsPerBlock = 256; // Common choice for threads per block
822  int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
823 
824  for(index_t batch_id = 0; batch_id < batch_count; ++batch_id)
825  {
826  ADataType* d_ATemp = a_ptr + batch_id * batch_stride_A;
827  BDataType* d_BTemp = b_ptr + batch_id * batch_stride_B;
828  CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C;
829  naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
830  <<<numBlocks, numThreadsPerBlock>>>(
831  d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
832  }
833 
834  return;
835 }
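
// Note (illustrative only, not part of this header; reuses the names from the sketches above and
// assumes the device buffers hold batch_count matrices back to back): for densely packed batches
// the batch strides are the per-matrix element counts, and the same grid is launched once per
// batch_id.
//
//   ck_tile::index_t batch_count = 8;
//   ck_tile::reference_batched_gemm_gpu<ck_tile::half_t, ck_tile::half_t, float, float,
//                                       Row, Col, Row>(a_dev, b_dev, c_dev, M, N, K, K, K, N,
//                                                      /*batch_stride_A=*/M * K,
//                                                      /*batch_stride_B=*/K * N,
//                                                      /*batch_stride_C=*/M * N, batch_count);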
836 
837 } // namespace ck_tile