Composable Kernel: include/ck_tile/host/reference/reference_gemm.hpp Source File

reference_gemm.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>
#include <thread>

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

namespace ck_tile {

template <typename ADataType,
          typename QDataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          uint32_t QuantGroupSize,
          bool aquant,
          typename AElementOp = ck_tile::identity,
          typename BElementOp = ck_tile::identity,
          typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_gemm_quant(const HostTensor<ADataType>& a_m_k,
                                       const HostTensor<QDataType>& q,
                                       const HostTensor<BDataType>& b_k_n,
                                       HostTensor<CDataType>& c_m_n,
                                       const AElementOp& a_element_op = {},
                                       const BElementOp& b_element_op = {},
                                       const ACCElementOp& acc_element_op = {})
{
    const std::size_t M = a_m_k.get_length(0);
    const std::size_t N = b_k_n.get_length(1);
    const std::size_t K = a_m_k.get_length(1);

    auto f_mn = [&](auto m, auto n) {
        AccDataType v_acc = 0, v_block_acc = 0;

        static_assert(std::is_same_v<ADataType, pk_int4_t> || std::is_same_v<ADataType, fp8_t> ||
                      std::is_same_v<ADataType, bf8_t>);
        static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t> ||
                      std::is_same_v<BDataType, pk_int4_t>);
        static_assert(std::is_same_v<AccDataType, float>);
        static_assert(std::is_same_v<CDataType, float> ||
                      std::is_same_v<CDataType, ck_tile::half_t>);
        for(std::size_t k = 0; k < K; ++k)
        {
            AccDataType v_a;
            AccDataType v_b;
            if constexpr(std::is_same_v<ADataType, pk_int4_t>)
            {
                const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
                if(k % 2 == 1)
                    v_a = fp32_val.hi;
                else
                    v_a = fp32_val.lo;
            }
            else
            {
                v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
            }
            if constexpr(std::is_same_v<BDataType, pk_int4_t>)
            {
                const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
                if(k % 2 == 1)
                    v_b = fp32_val.hi;
                else
                    v_b = fp32_val.lo;
            }
            else if constexpr(std::is_same_v<BDataType, fp8_t>)
            {
                v_b = fp8_to_float_raw(b_element_op(b_k_n(k, n)));
            }
            else
            {
                v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
            }
            v_block_acc += v_a * v_b;

            // Apply group dequant scale
            if((k + 1) % QuantGroupSize == 0)
            {
                float scale = 0.f;
                index_t outer_dim = (aquant) ? m : k / QuantGroupSize;
                index_t inner_dim = (aquant) ? k / QuantGroupSize : n;

                if constexpr(std::is_same_v<QDataType, float>)
                {
                    scale = q(outer_dim, inner_dim);
                }
                else if constexpr(std::is_same_v<QDataType, ck_tile::fp8_t>)
                {
                    scale = fp8_to_float_raw(q(outer_dim, inner_dim));
                }
                else if constexpr(std::is_same_v<QDataType, ck_tile::bf8_t>)
                {
                    scale = bf8_to_float_raw(q(outer_dim, inner_dim));
                }
                else
                {
                    static_assert(false, "Unexpected Q datatype.");
                }
                v_block_acc *= scale;
                v_acc += v_block_acc;
                v_block_acc = 0;
            }
        }

        c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
    };

    make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
    std::cout << std::endl;
}

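// Illustrative usage sketch for reference_gemm_quant (not part of the upstream
// header): fp8 A and B with one float dequant scale per QuantGroupSize-wide
// block of B along K, selected by aquant = false. The extents and the
// HostTensor brace-list constructor are assumptions made for this example.
inline void example_reference_gemm_quant()
{
    constexpr std::size_t M = 16, N = 16, K = 128;
    constexpr uint32_t group_size = 64;

    ck_tile::HostTensor<ck_tile::fp8_t> a_m_k({M, K});
    ck_tile::HostTensor<ck_tile::fp8_t> b_k_n({K, N});
    ck_tile::HostTensor<float> q({K / group_size, N}); // per-group scales for B
    ck_tile::HostTensor<float> c_m_n({M, N});

    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t k = 0; k < K; ++k)
            a_m_k(m, k) = ck_tile::type_convert<ck_tile::fp8_t>(0.25f);
    for(std::size_t k = 0; k < K; ++k)
        for(std::size_t n = 0; n < N; ++n)
            b_k_n(k, n) = ck_tile::type_convert<ck_tile::fp8_t>(0.5f);
    for(std::size_t g = 0; g < K / group_size; ++g)
        for(std::size_t n = 0; n < N; ++n)
            q(g, n) = 2.0f; // dequant scale applied after each completed K-group

    // ADataType, QDataType, BDataType, AccDataType, CDataType, QuantGroupSize, aquant.
    // Each output is 2 groups x (64 x 0.25 x 0.5) x 2.0 = 32.
    ck_tile::reference_gemm_quant<ck_tile::fp8_t,
                                  float,
                                  ck_tile::fp8_t,
                                  float,
                                  float,
                                  group_size,
                                  false>(a_m_k, q, b_k_n, c_m_n);
}
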
template <typename ADataType,
          typename AQDataType,
          typename BDataType,
          typename BQDataType,
          typename AccDataType,
          typename CDataType,
          typename AElementOp = ck_tile::identity,
          typename BElementOp = ck_tile::identity,
          typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_gemm_rowcol_quant(const HostTensor<ADataType>& a_m_k,
                                              const HostTensor<AQDataType>& aq_m_1,
                                              const HostTensor<BDataType>& b_k_n,
                                              const HostTensor<BQDataType>& bq_1_n,
                                              HostTensor<CDataType>& c_m_n,
                                              const AElementOp& a_element_op = {},
                                              const BElementOp& b_element_op = {},
                                              const ACCElementOp& acc_element_op = {})
{
    static_assert(std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t>);
    static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t>);
    static_assert(std::is_same_v<AccDataType, float>);
    static_assert(std::is_same_v<CDataType, float> || std::is_same_v<CDataType, ck_tile::half_t>);
    static_assert(std::is_same_v<AQDataType, float> && std::is_same_v<BQDataType, float>);
    const std::size_t M = a_m_k.get_length(0);
    const std::size_t N = b_k_n.get_length(1);
    const std::size_t K = a_m_k.get_length(1);

    auto f_mn = [&](auto m, auto n) {
        // Init accumulator
        AccDataType v_acc = 0;
        // Get row scale for A and column scale for B
        float a_scale = aq_m_1(m, 0);
        float b_scale = bq_1_n(0, n);

        // Compute the dot product
        for(std::size_t k = 0; k < K; ++k)
        {
            AccDataType v_a;
            AccDataType v_b;

            // Process A data
            if constexpr(std::is_same_v<ADataType, pk_int4_t>)
            {
                const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val);
                if(k % 2 == 1)
                    v_a = fp32_val.hi;
                else
                    v_a = fp32_val.lo;
            }
            else
            {
                v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
            }

            // Process B data
            if constexpr(std::is_same_v<BDataType, pk_int4_t>)
            {
                const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val);
                if(k % 2 == 1)
                    v_b = fp32_val.hi;
                else
                    v_b = fp32_val.lo;
            }
            else
            {
                v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
            }

            v_acc += v_a * v_b;
        }

        v_acc = v_acc * a_scale * b_scale;

        c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
    };

    make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
}

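// Illustrative usage sketch for reference_gemm_rowcol_quant (not part of the
// upstream header): fp8 A and B with one float scale per row of A and one per
// column of B. Extents and the HostTensor brace-list constructor are
// assumptions made for this example.
inline void example_reference_gemm_rowcol_quant()
{
    constexpr std::size_t M = 8, N = 8, K = 32;

    ck_tile::HostTensor<ck_tile::fp8_t> a_m_k({M, K});
    ck_tile::HostTensor<float> aq_m_1({M, 1}); // row scales for A
    ck_tile::HostTensor<ck_tile::fp8_t> b_k_n({K, N});
    ck_tile::HostTensor<float> bq_1_n({1, N}); // column scales for B
    ck_tile::HostTensor<ck_tile::half_t> c_m_n({M, N});

    for(std::size_t m = 0; m < M; ++m)
    {
        aq_m_1(m, 0) = 0.5f;
        for(std::size_t k = 0; k < K; ++k)
            a_m_k(m, k) = ck_tile::type_convert<ck_tile::fp8_t>(1.0f);
    }
    for(std::size_t n = 0; n < N; ++n)
    {
        bq_1_n(0, n) = 0.25f;
        for(std::size_t k = 0; k < K; ++k)
            b_k_n(k, n) = ck_tile::type_convert<ck_tile::fp8_t>(1.0f);
    }

    // AccDataType (float) cannot be deduced, so the leading types are spelled
    // out; CDataType is deduced from c_m_n. Each output is 32 x 0.5 x 0.25 = 4.
    ck_tile::reference_gemm_rowcol_quant<ck_tile::fp8_t, float, ck_tile::fp8_t, float, float>(
        a_m_k, aq_m_1, b_k_n, bq_1_n, c_m_n);
}
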
template <typename ADataType,
          typename AQDataType,
          typename BDataType,
          typename BQDataType,
          typename AccDataType,
          typename CDataType,
          typename AElementOp = ck_tile::identity,
          typename BElementOp = ck_tile::identity,
          typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_gemm_tensor_quant(const HostTensor<ADataType>& a_m_k,
                                              const HostTensor<AQDataType>& aq_1_1,
                                              const HostTensor<BDataType>& b_k_n,
                                              const HostTensor<BQDataType>& bq_1_1,
                                              HostTensor<CDataType>& c_m_n,
                                              const AElementOp& a_element_op = {},
                                              const BElementOp& b_element_op = {},
                                              const ACCElementOp& acc_element_op = {})
{
    static_assert(std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t>);
    static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t>);
    static_assert(std::is_same_v<AccDataType, float>);
    static_assert(std::is_same_v<CDataType, float> || std::is_same_v<CDataType, ck_tile::half_t>);
    static_assert(std::is_same_v<AQDataType, float> && std::is_same_v<BQDataType, float>);
    const std::size_t M = a_m_k.get_length(0);
    const std::size_t N = b_k_n.get_length(1);
    const std::size_t K = a_m_k.get_length(1);

    auto f_mn = [&](auto m, auto n) {
        // Init accumulator
        AccDataType v_acc = 0;
        // Get scale for A and scale for B
        const AccDataType a_scale = ck_tile::type_convert<AccDataType>(aq_1_1(0, 0));
        const AccDataType b_scale = ck_tile::type_convert<AccDataType>(bq_1_1(0, 0));

        // Compute the dot product
        for(std::size_t k = 0; k < K; ++k)
        {
            AccDataType v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
            AccDataType v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));

            v_acc += v_a * v_b;
        }

        v_acc = v_acc * a_scale * b_scale;

        c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
    };

    make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
}

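// Illustrative usage sketch for reference_gemm_tensor_quant (not part of the
// upstream header): a single per-tensor float scale for A and for B, stored in
// 1x1 host tensors. Extents and the HostTensor brace-list constructor are
// assumptions made for this example.
inline void example_reference_gemm_tensor_quant()
{
    constexpr std::size_t M = 8, N = 8, K = 32;

    ck_tile::HostTensor<ck_tile::fp8_t> a_m_k({M, K});
    ck_tile::HostTensor<float> aq_1_1({1, 1});
    ck_tile::HostTensor<ck_tile::fp8_t> b_k_n({K, N});
    ck_tile::HostTensor<float> bq_1_1({1, 1});
    ck_tile::HostTensor<float> c_m_n({M, N});

    aq_1_1(0, 0) = 0.5f;
    bq_1_1(0, 0) = 0.125f;
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t k = 0; k < K; ++k)
            a_m_k(m, k) = ck_tile::type_convert<ck_tile::fp8_t>(1.0f);
    for(std::size_t k = 0; k < K; ++k)
        for(std::size_t n = 0; n < N; ++n)
            b_k_n(k, n) = ck_tile::type_convert<ck_tile::fp8_t>(1.0f);

    // Every output element is (sum over K of a*b) * aq * bq = 32 * 0.5 * 0.125 = 2.
    ck_tile::reference_gemm_tensor_quant<ck_tile::fp8_t, float, ck_tile::fp8_t, float, float>(
        a_m_k, aq_1_1, b_k_n, bq_1_1, c_m_n);
}
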
template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename AElementOp = ck_tile::identity,
          typename BElementOp = ck_tile::identity,
          typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
                                 const HostTensor<BDataType>& b_k_n,
                                 HostTensor<CDataType>& c_m_n,
                                 const AElementOp& a_element_op = {},
                                 const BElementOp& b_element_op = {},
                                 const ACCElementOp& acc_element_op = {})
{
    const std::size_t M = a_m_k.get_length(0);
    const std::size_t N = b_k_n.get_length(1);
    const std::size_t K = a_m_k.get_length(1);

    auto f_mn = [&](auto m, auto n) {
        AccDataType v_acc = 0;

        for(std::size_t k = 0; k < K; ++k)
        {
            AccDataType v_a;
            AccDataType v_b;
            if constexpr(std::is_same_v<ADataType, pk_int4_t>)
            {
                const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
                if(k % 2 == 1)
                    v_a = fp32_val.hi;
                else
                    v_a = fp32_val.lo;
            }
            else
            {
                v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
            }
            if constexpr(std::is_same_v<BDataType, pk_int4_t>)
            {
                const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
                if(k % 2 == 1)
                    v_b = fp32_val.hi;
                else
                    v_b = fp32_val.lo;
            }
            else
            {
                v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
            }
            v_acc += v_a * v_b;
        }

        c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
    };

    make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
}

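// Illustrative usage sketch for reference_gemm (not part of the upstream
// header): half-precision inputs, float accumulation, float output. The
// extents and the HostTensor brace-list constructor are assumptions; the
// element-wise operators default to ck_tile::identity.
inline void example_reference_gemm()
{
    constexpr std::size_t M = 64, N = 64, K = 32;

    ck_tile::HostTensor<ck_tile::half_t> a_m_k({M, K});
    ck_tile::HostTensor<ck_tile::half_t> b_k_n({K, N});
    ck_tile::HostTensor<float> c_m_n({M, N});

    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t k = 0; k < K; ++k)
            a_m_k(m, k) = ck_tile::type_convert<ck_tile::half_t>(0.01f * k);
    for(std::size_t k = 0; k < K; ++k)
        for(std::size_t n = 0; n < N; ++n)
            b_k_n(k, n) = ck_tile::type_convert<ck_tile::half_t>(0.02f * n);

    // AccDataType (float) is not deducible from the arguments, so the first
    // three template arguments are spelled out; CDataType is deduced from c_m_n.
    ck_tile::reference_gemm<ck_tile::half_t, ck_tile::half_t, float>(a_m_k, b_k_n, c_m_n);
}
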
template <typename AsDataType,
          typename BsDataType,
          typename DsDataType,
          typename AccDataType,
          typename CDataType,
          typename AElementOp,
          typename BElementOp,
          typename CDElementOp,
          typename ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>,
          typename BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>,
          typename DDataType = remove_cvref_t<std::tuple_element_t<0, DsDataType>>>
CK_TILE_HOST void
reference_gemm_multiple_abd(const std::array<HostTensor<ADataType>, AsDataType::size()>& as_m_k,
                            const std::array<HostTensor<BDataType>, BsDataType::size()>& bs_k_n,
                            const std::array<HostTensor<DDataType>, DsDataType::size()>& ds_m_n,
                            HostTensor<ADataType>& a_m_k,
                            HostTensor<BDataType>& b_k_n,
                            HostTensor<CDataType>& c_m_n,
                            const AElementOp& a_element_op = {},
                            const BElementOp& b_element_op = {},
                            const CDElementOp& acc_element_op = {})
{
    const std::size_t M = a_m_k.get_length(0);
    const std::size_t N = b_k_n.get_length(1);
    const std::size_t K = a_m_k.get_length(1);

    auto as_m_k_tuple =
        generate_tie([&](auto idx) -> auto& { return as_m_k[idx]; }, number<AsDataType::size()>{});

    auto bs_k_n_tuple =
        generate_tie([&](auto idx) -> auto& { return bs_k_n[idx]; }, number<BsDataType::size()>{});

    auto ds_m_n_tuple =
        generate_tie([&](auto idx) -> auto& { return ds_m_n[idx]; }, number<DsDataType::size()>{});

    // Apply elementwise function to A
    auto a_elementwise_fn = [&](auto i, auto j) {
        ck_tile::apply([&](auto&&... t) { a_element_op(a_m_k(i, j), t(i, j)...); }, as_m_k_tuple);
    };

    make_ParallelTensorFunctor(a_elementwise_fn, M, K)(std::thread::hardware_concurrency());

    // Apply elementwise function to B
    auto b_elementwise_fn = [&](auto i, auto j) {
        ck_tile::apply([&](auto&&... t) { b_element_op(b_k_n(i, j), t(i, j)...); }, bs_k_n_tuple);
    };

    make_ParallelTensorFunctor(b_elementwise_fn, K, N)(std::thread::hardware_concurrency());

    auto f_mk_kn_mn = [&](auto m, auto n) {
        AccDataType v_acc = 0;
        for(std::size_t k = 0; k < K; ++k)
        {
            ADataType v_a = a_m_k(m, k);
            BDataType v_b = b_k_n(k, n);
            v_acc +=
                ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
        }

        CDataType v_c = 0;

        ck_tile::apply(
            [&](auto&&... t) {
                acc_element_op(v_c,
                               ck_tile::type_convert<float>(v_acc),
                               ck_tile::type_convert<float>(t(m, n))...);
            },
            ds_m_n_tuple);

        c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_c);
    };

    make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency());
}

template <typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename AccDataType,
          typename CDataType,
          typename ACCElementOp,
          typename DDataType = remove_cvref_t<std::tuple_element_t<0, DsDataType>>>
CK_TILE_HOST void
reference_gemm_multiple_d(const HostTensor<ADataType>& a_m_k,
                          const HostTensor<BDataType>& b_k_n,
                          const std::array<HostTensor<DDataType>, DsDataType::size()>& ds_m_n,
                          HostTensor<CDataType>& c_m_n,
                          const ACCElementOp& acc_element_op = {})
{
    const std::size_t M = a_m_k.get_length(0);
    const std::size_t N = b_k_n.get_length(1);
    const std::size_t K = a_m_k.get_length(1);

    auto f_mk_kn_mn = [&](auto m, auto n) {
        AccDataType v_acc = 0;
        for(std::size_t k = 0; k < K; ++k)
        {
            ADataType v_a = a_m_k(m, k);
            BDataType v_b = b_k_n(k, n);
            v_acc +=
                ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
        }

        CDataType v_c = 0;
        if constexpr(DsDataType::size() == 0)
        {
            acc_element_op(v_c, ck_tile::type_convert<float>(v_acc));
        }
        else if constexpr(DsDataType::size() == 1)
        {
            acc_element_op(v_c,
                           ck_tile::type_convert<float>(v_acc),
                           ck_tile::type_convert<float>(ds_m_n[0](m, n)));
        }
        else if constexpr(DsDataType::size() == 2)
        {
            acc_element_op(v_c,
                           ck_tile::type_convert<float>(v_acc),
                           ck_tile::type_convert<float>(ds_m_n[0](m, n)),
                           ck_tile::type_convert<float>(ds_m_n[1](m, n)));
        }
        c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_c);
    };

    make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency());
}

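// Illustrative usage sketch for reference_gemm_multiple_d (not part of the
// upstream header): one extra D tensor added to the accumulator through a
// user-supplied ACCElementOp. It assumes DsDataType is a ck_tile::tuple that
// supports std::tuple_element_t (as the DDataType default requires) and that
// HostTensor is copyable and constructible from a brace list of lengths.
struct example_acc_add_d_op
{
    void operator()(float& c, float acc, float d0) const { c = acc + d0; }
};

inline void example_reference_gemm_multiple_d()
{
    constexpr std::size_t M = 16, N = 16, K = 32;

    ck_tile::HostTensor<ck_tile::half_t> a_m_k({M, K});
    ck_tile::HostTensor<ck_tile::half_t> b_k_n({K, N});
    ck_tile::HostTensor<float> d0_m_n({M, N}); // bias-like addend
    ck_tile::HostTensor<float> c_m_n({M, N});

    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t k = 0; k < K; ++k)
            a_m_k(m, k) = ck_tile::type_convert<ck_tile::half_t>(1.0f);
    for(std::size_t k = 0; k < K; ++k)
        for(std::size_t n = 0; n < N; ++n)
            b_k_n(k, n) = ck_tile::type_convert<ck_tile::half_t>(1.0f);
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
            d0_m_n(m, n) = 3.0f;

    using DsDataType = ck_tile::tuple<float>;
    std::array<ck_tile::HostTensor<float>, 1> ds_m_n = {d0_m_n};

    // Expected output per element: K * 1 * 1 + 3 = 35.
    ck_tile::reference_gemm_multiple_d<ck_tile::half_t,
                                       ck_tile::half_t,
                                       DsDataType,
                                       float,
                                       float,
                                       example_acc_add_d_op>(
        a_m_k, b_k_n, ds_m_n, c_m_n, example_acc_add_d_op{});
}
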
template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename LayoutA,
          typename LayoutB,
          typename LayoutC>
__global__ void naive_gemm_kernel(ADataType* A,
                                  BDataType* B,
                                  CDataType* C,
                                  ck_tile::index_t M,
                                  ck_tile::index_t N,
                                  ck_tile::index_t K,
                                  ck_tile::index_t strideA,
                                  ck_tile::index_t strideB,
                                  ck_tile::index_t strideC)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int row = idx / N; // Compute row index
    int col = idx % N; // Compute column index

    if(row < M && col < N)
    {
        AccDataType acc = 0.0;
        for(int k = 0; k < K; ++k)
        {
            // Elements packed per storage value (pk_int4_t packs two 4-bit values)
            constexpr int packed_size_a = std::is_same_v<ADataType, pk_int4_t> ? 2 : 1;
            constexpr int packed_size_b = std::is_same_v<BDataType, pk_int4_t> ? 2 : 1;

            // Adjust indexing based on matrix layout
            int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
                              ? row * strideA + k
                              : k * strideA + row;
            int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
                              ? col * strideB + k
                              : k * strideB + col;

            AccDataType v_a;
            AccDataType v_b;
            if constexpr(std::is_same_v<ADataType, pk_int4_t>)
            {
                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
                if(k % 2 == 1)
                    v_a = fp32_val.hi;
                else
                    v_a = fp32_val.lo;
            }
            else
            {
                v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
            }
            if constexpr(std::is_same_v<BDataType, pk_int4_t>)
            {
                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
                if(k % 2 == 1)
                    v_b = fp32_val.hi;
                else
                    v_b = fp32_val.lo;
            }
            else
            {
                v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
            }
            acc += v_a * v_b;
        }

        int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
                          ? row * strideC + col
                          : col * strideC + row;
        C[c_index] = ck_tile::type_convert<CDataType>(acc);
    }
}

template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename LayoutA,
          typename LayoutB,
          typename LayoutC>
void reference_gemm_gpu(ADataType* a_ptr,
                        BDataType* b_ptr,
                        CDataType* c_ptr,
                        index_t M,
                        index_t N,
                        index_t K,
                        index_t stride_a,
                        index_t stride_b,
                        index_t stride_c)
{
    int totalElements = M * N;
    int numThreadsPerBlock = 256; // Common choice for threads per block
    int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;

    naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
        <<<numBlocks, numThreadsPerBlock>>>(
            a_ptr, b_ptr, c_ptr, M, N, K, stride_a, stride_b, stride_c);

    return;
}

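// Illustrative usage sketch for reference_gemm_gpu (not part of the upstream
// header): float row-major A and C, column-major B, float accumulation. The
// extents are arbitrary and HIP error checking is omitted for brevity.
inline void example_reference_gemm_gpu()
{
    const ck_tile::index_t M = 128, N = 128, K = 64;

    float* a_dev = nullptr;
    float* b_dev = nullptr;
    float* c_dev = nullptr;
    hipMalloc(&a_dev, M * K * sizeof(float));
    hipMalloc(&b_dev, K * N * sizeof(float));
    hipMalloc(&c_dev, M * N * sizeof(float));
    // ... copy host A and B into a_dev / b_dev with hipMemcpy ...

    using RowMajor    = ck_tile::tensor_layout::gemm::RowMajor;
    using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;

    // Row-major A uses stride K, column-major B uses stride K, row-major C uses stride N.
    ck_tile::reference_gemm_gpu<float, float, float, float, RowMajor, ColumnMajor, RowMajor>(
        a_dev, b_dev, c_dev, M, N, K, K, K, N);
    hipDeviceSynchronize();

    // ... copy c_dev back with hipMemcpy, then release the buffers ...
    hipFree(a_dev);
    hipFree(b_dev);
    hipFree(c_dev);
}
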
template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename LayoutA,
          typename LayoutB,
          typename LayoutC>
void reference_batched_gemm_gpu(ADataType* a_ptr,
                                BDataType* b_ptr,
                                CDataType* c_ptr,
                                index_t M,
                                index_t N,
                                index_t K,
                                index_t stride_a,
                                index_t stride_b,
                                index_t stride_c,
                                index_t batch_stride_A,
                                index_t batch_stride_B,
                                index_t batch_stride_C,
                                index_t batch_count)
{
    int totalElements = M * N;
    int numThreadsPerBlock = 256; // Common choice for threads per block
    int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;

    for(index_t batch_id = 0; batch_id < batch_count; ++batch_id)
    {
        ADataType* d_ATemp = a_ptr + batch_id * batch_stride_A;
        BDataType* d_BTemp = b_ptr + batch_id * batch_stride_B;
        CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C;
        naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
            <<<numBlocks, numThreadsPerBlock>>>(
                d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
    }

    return;
}
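
// Illustrative usage sketch for reference_batched_gemm_gpu (not part of the
// upstream header): the same layout and stride choices as the single-batch
// example above, with dense batch strides so consecutive batches are packed.
inline void example_reference_batched_gemm_gpu()
{
    const ck_tile::index_t M = 64, N = 64, K = 32, batch_count = 4;

    float* a_dev = nullptr;
    float* b_dev = nullptr;
    float* c_dev = nullptr;
    hipMalloc(&a_dev, batch_count * M * K * sizeof(float));
    hipMalloc(&b_dev, batch_count * K * N * sizeof(float));
    hipMalloc(&c_dev, batch_count * M * N * sizeof(float));

    using RowMajor    = ck_tile::tensor_layout::gemm::RowMajor;
    using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;

    // One kernel launch per batch; batch strides are whole-matrix element counts.
    ck_tile::reference_batched_gemm_gpu<float, float, float, float, RowMajor, ColumnMajor, RowMajor>(
        a_dev, b_dev, c_dev, M, N, K, K, K, N, M * K, K * N, M * N, batch_count);
    hipDeviceSynchronize();

    hipFree(a_dev);
    hipFree(b_dev);
    hipFree(c_dev);
}
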
} // namespace ck_tile