Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_gemm.hpp Source File

reference_gemm.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <cstdlib>
7 #include <thread>
8 
9 #include "ck_tile/core.hpp"
10 #include "ck_tile/host/host_tensor.hpp"
11 
12 namespace ck_tile {
13 
14 template <typename ADataType,
15  typename QDataType,
16  typename BDataType,
17  typename AccDataType,
18  typename CDataType,
19  uint32_t QuantGroupSize,
20  bool aquant,
21  typename AElementOp = ck_tile::identity,
22  typename BElementOp = ck_tile::identity,
23  typename ACCElementOp = ck_tile::identity>
24 CK_TILE_HOST void reference_gemm_quant(const HostTensor<ADataType>& a_m_k,
25  const HostTensor<QDataType>& q,
26  const HostTensor<BDataType>& b_k_n,
27  HostTensor<CDataType>& c_m_n,
28  const AElementOp& a_element_op = {},
29  const BElementOp& b_element_op = {},
30  const ACCElementOp& acc_element_op = {})
31 {
32  const std::size_t M = a_m_k.get_length(0);
33  const std::size_t N = b_k_n.get_length(1);
34  const std::size_t K = a_m_k.get_length(1);
35 
36  auto f_mn = [&](auto m, auto n) {
37  AccDataType v_acc = 0, v_block_acc = 0;
38 
39  static_assert(std::is_same_v<ADataType, pk_int4_t> || std::is_same_v<ADataType, fp8_t> ||
40  std::is_same_v<ADataType, bf8_t>);
41  static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t> ||
42  std::is_same_v<BDataType, pk_int4_t>);
43  static_assert(std::is_same_v<AccDataType, float>);
44  static_assert(std::is_same_v<CDataType, float> ||
45  std::is_same_v<CDataType, ck_tile::half_t>);
46  for(std::size_t k = 0; k < K; ++k)
47  {
48  AccDataType v_a;
49  AccDataType v_b;
50  if constexpr(std::is_same_v<ADataType, pk_int4_t>)
51  {
52  const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
53  const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val);
54  if(k % 2 == 1)
55  v_a = fp32_val.hi;
56  else
57  v_a = fp32_val.lo;
58  }
59  else
60  {
61  v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
62  }
63  if constexpr(std::is_same_v<BDataType, pk_int4_t>)
64  {
65  const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
66  const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val);
67  if(k % 2 == 1)
68  v_b = fp32_val.hi;
69  else
70  v_b = fp32_val.lo;
71  }
72  else if constexpr(std::is_same_v<BDataType, fp8_t>)
73  {
74  v_b = fp8_to_float_raw(b_element_op(b_k_n(k, n)));
75  }
76  else
77  {
78  v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
79  }
80  v_block_acc += v_a * v_b;
81 
82  // Apply group dequant scale
83  if((k + 1) % QuantGroupSize == 0)
84  {
85  float scale = 0.f;
86  index_t outer_dim = (aquant) ? m : k / QuantGroupSize;
87  index_t inner_dim = (aquant) ? k / QuantGroupSize : n;
88 
89  if constexpr(std::is_same_v<QDataType, float>)
90  {
91  scale = q(outer_dim, inner_dim);
92  }
93  else if constexpr(std::is_same_v<QDataType, ck_tile::fp8_t>)
94  {
95  scale = fp8_to_float_raw(q(outer_dim, inner_dim));
96  }
97  else if constexpr(std::is_same_v<QDataType, ck_tile::bf8_t>)
98  {
99  scale = bf8_to_float_raw(q(outer_dim, inner_dim));
100  }
101  else
102  {
103  static_assert(false, "Unexpected Q datatype.");
104  }
105  v_block_acc *= scale;
106  v_acc += v_block_acc;
107  v_block_acc = 0;
108  }
109  }
110 
111  c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
112  };
113 
114  make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
115  std::cout << std::endl;
116 }
117 
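For orientation, here is a minimal host-side sketch of how reference_gemm_quant might be driven with fp8 operands, float group scales, and A-side quantization (aquant = true). The ck_tile/host.hpp umbrella include, the length-list HostTensor constructor, and the fill values are assumptions for illustration; only the template-parameter order and the shape of q follow from the code above. The group scale is folded in only on QuantGroupSize boundaries, so K is chosen as a multiple of the group size.

#include <cstddef>
#include <cstdint>
#include "ck_tile/host.hpp" // assumed umbrella header providing HostTensor and the reference GEMMs

int main()
{
    const std::size_t M = 64, N = 64, K = 128;
    constexpr uint32_t group_size = 64; // K is a multiple of the group size

    ck_tile::HostTensor<ck_tile::fp8_t> a({M, K});
    ck_tile::HostTensor<ck_tile::fp8_t> b({K, N});
    ck_tile::HostTensor<float> q({M, K / group_size}); // aquant == true: one scale per (row, k-group)
    ck_tile::HostTensor<ck_tile::half_t> c({M, N});

    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t k = 0; k < K; ++k)
            a(m, k) = ck_tile::type_convert<ck_tile::fp8_t>(0.5f);
    for(std::size_t k = 0; k < K; ++k)
        for(std::size_t n = 0; n < N; ++n)
            b(k, n) = ck_tile::type_convert<ck_tile::fp8_t>(0.5f);
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t g = 0; g < K / group_size; ++g)
            q(m, g) = 0.01f; // per-group dequantization scale

    // <ADataType, QDataType, BDataType, AccDataType, CDataType, QuantGroupSize, aquant>
    ck_tile::reference_gemm_quant<ck_tile::fp8_t, float, ck_tile::fp8_t, float, ck_tile::half_t,
                                  group_size, true>(a, q, b, c);
    return 0;
}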
118 template <typename ADataType,
119  typename AQDataType,
120  typename BDataType,
121  typename BQDataType,
122  typename AccDataType,
123  typename CDataType,
124  typename AElementOp = ck_tile::identity,
125  typename BElementOp = ck_tile::identity,
126  typename ACCElementOp = ck_tile::identity>
127 CK_TILE_HOST void reference_gemm_rowcol_quant(const HostTensor<ADataType>& a_m_k,
128  const HostTensor<AQDataType>& aq_m_1,
129  const HostTensor<BDataType>& b_k_n,
130  const HostTensor<BQDataType>& bq_1_n,
131  HostTensor<CDataType>& c_m_n,
132  const AElementOp& a_element_op = {},
133  const BElementOp& b_element_op = {},
134  const ACCElementOp& acc_element_op = {})
135 {
136  static_assert(std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t>);
137  static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t>);
138  static_assert(std::is_same_v<AccDataType, float>);
139  static_assert(std::is_same_v<CDataType, float> || std::is_same_v<CDataType, ck_tile::half_t>);
140  static_assert(std::is_same_v<AQDataType, float> && std::is_same_v<BQDataType, float>);
141  const std::size_t M = a_m_k.get_length(0);
142  const std::size_t N = b_k_n.get_length(1);
143  const std::size_t K = a_m_k.get_length(1);
144 
145  auto f_mn = [&](auto m, auto n) {
146  // Init accumulator
147  AccDataType v_acc = 0;
148  // Get row scale for A and column scale for B
149  float a_scale = aq_m_1(m, 0);
150  float b_scale = bq_1_n(0, n);
151 
152  // Compute the dot product
153  for(std::size_t k = 0; k < K; ++k)
154  {
155  AccDataType v_a;
156  AccDataType v_b;
157 
158  // Process A data
159  if constexpr(std::is_same_v<ADataType, pk_int4_t>)
160  {
161  const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
162  const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val);
163  if(k % 2 == 1)
164  v_a = fp32_val.hi;
165  else
166  v_a = fp32_val.lo;
167  }
168  else
169  {
170  v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
171  }
172 
173  // Process B data
174  if constexpr(std::is_same_v<BDataType, pk_int4_t>)
175  {
176  const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
177  const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val);
178  if(k % 2 == 1)
179  v_b = fp32_val.hi;
180  else
181  v_b = fp32_val.lo;
182  }
183  else if constexpr(std::is_same_v<BDataType, fp8_t>)
184  {
185  v_b = fp8_to_float_raw(b_element_op(b_k_n(k, n)));
186  }
187  else
188  {
189  v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
190  }
191 
192  v_acc += v_a * v_b;
193  }
194 
195  v_acc = v_acc * a_scale * b_scale;
196 
197  c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
198  };
199 
200  make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
201  std::cout << std::endl;
202 }
203 
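A similar hedged sketch for the row/column-scaled variant, where aq_m_1 carries one float scale per row of A and bq_1_n one float scale per column of B, both applied once after the K loop. The shapes, fill values, and umbrella include are illustrative assumptions.

#include <cstddef>
#include "ck_tile/host.hpp" // assumed umbrella header

int main()
{
    const std::size_t M = 32, N = 32, K = 64, one = 1;

    ck_tile::HostTensor<ck_tile::fp8_t> a({M, K});
    ck_tile::HostTensor<float> aq({M, one}); // one dequant scale per row of A
    ck_tile::HostTensor<ck_tile::fp8_t> b({K, N});
    ck_tile::HostTensor<float> bq({one, N}); // one dequant scale per column of B
    ck_tile::HostTensor<ck_tile::half_t> c({M, N});

    for(std::size_t m = 0; m < M; ++m)
    {
        aq(m, 0) = 0.01f;
        for(std::size_t k = 0; k < K; ++k)
            a(m, k) = ck_tile::type_convert<ck_tile::fp8_t>(1.0f);
    }
    for(std::size_t n = 0; n < N; ++n)
        bq(0, n) = 0.02f;
    for(std::size_t k = 0; k < K; ++k)
        for(std::size_t n = 0; n < N; ++n)
            b(k, n) = ck_tile::type_convert<ck_tile::fp8_t>(1.0f);

    // <ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType>
    ck_tile::reference_gemm_rowcol_quant<ck_tile::fp8_t, float, ck_tile::fp8_t, float,
                                         float, ck_tile::half_t>(a, aq, b, bq, c);
    // Each c(m, n) should be roughly K * 1 * 0.01 * 0.02 = 0.0128.
    return 0;
}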
204 template <typename ADataType,
205  typename BDataType,
206  typename AccDataType,
207  typename CDataType,
208  typename AElementOp = ck_tile::identity,
209  typename BElementOp = ck_tile::identity,
210  typename ACCElementOp = ck_tile::identity>
211 CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
212  const HostTensor<BDataType>& b_k_n,
213  HostTensor<CDataType>& c_m_n,
214  const AElementOp& a_element_op = {},
215  const BElementOp& b_element_op = {},
216  const ACCElementOp& acc_element_op = {})
217 {
218  const std::size_t M = a_m_k.get_length(0);
219  const std::size_t N = b_k_n.get_length(1);
220  const std::size_t K = a_m_k.get_length(1);
221 
222  auto f_mn = [&](auto m, auto n) {
223  AccDataType v_acc = 0;
224 
225  for(std::size_t k = 0; k < K; ++k)
226  {
227  AccDataType v_a;
228  AccDataType v_b;
229  if constexpr(std::is_same_v<ADataType, pk_int4_t>)
230  {
231  const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
232  const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
233  if(k % 2 == 1)
234  v_a = fp32_val.hi;
235  else
236  v_a = fp32_val.lo;
237  }
238  else
239  {
240  v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
241  }
242  if constexpr(std::is_same_v<BDataType, pk_int4_t>)
243  {
244  const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
245  const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
246  if(k % 2 == 1)
247  v_b = fp32_val.hi;
248  else
249  v_b = fp32_val.lo;
250  }
251  else
252  {
253  v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
254  }
255  v_acc += v_a * v_b;
256  }
257 
258  c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
259  };
260 
261  make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
262 }
263 
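A minimal sketch of the unquantized overload with half-precision operands accumulating in float. The umbrella include, the HostTensor length-list construction, and the fill values are assumptions; with all-ones inputs the expected result is easy to check by hand.

#include <cstddef>
#include "ck_tile/host.hpp" // assumed umbrella header

int main()
{
    const std::size_t M = 32, N = 32, K = 64;

    ck_tile::HostTensor<ck_tile::half_t> a({M, K});
    ck_tile::HostTensor<ck_tile::half_t> b({K, N});
    ck_tile::HostTensor<ck_tile::half_t> c({M, N});

    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t k = 0; k < K; ++k)
            a(m, k) = ck_tile::type_convert<ck_tile::half_t>(1.0f);
    for(std::size_t k = 0; k < K; ++k)
        for(std::size_t n = 0; n < N; ++n)
            b(k, n) = ck_tile::type_convert<ck_tile::half_t>(1.0f);

    // <ADataType, BDataType, AccDataType, CDataType>
    ck_tile::reference_gemm<ck_tile::half_t, ck_tile::half_t, float, ck_tile::half_t>(a, b, c);
    // Every c(m, n) should now hold K = 64.
    return 0;
}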
264 template <typename ADataType,
265  typename BDataType,
266  typename DsDataType,
267  typename AccDataType,
268  typename CDataType,
269  typename ACCElementOp,
270  typename DDataType = remove_cvref_t<std::tuple_element_t<0, DsDataType>>>
271 CK_TILE_HOST void
272 reference_gemm_multiple_d(const HostTensor<ADataType>& a_m_k,
273  const HostTensor<BDataType>& b_k_n,
274  const std::array<HostTensor<DDataType>, DsDataType::size()>& ds_m_n,
275  HostTensor<CDataType>& c_m_n,
276  const ACCElementOp& acc_element_op = {})
277 {
278  const std::size_t M = a_m_k.get_length(0);
279  const std::size_t N = b_k_n.get_length(1);
280  const std::size_t K = a_m_k.get_length(1);
281 
282  auto f_mk_kn_mn = [&](auto m, auto n) {
283  AccDataType v_acc = 0;
284  for(std::size_t k = 0; k < K; ++k)
285  {
286  ADataType v_a = a_m_k(m, k);
287  BDataType v_b = b_k_n(k, n);
288  v_acc +=
289  ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
290  }
291 
292  CDataType v_c = 0;
293  if constexpr(DsDataType::size() == 0)
294  {
295  acc_element_op(v_c, ck_tile::type_convert<float>(v_acc));
296  }
297  else if constexpr(DsDataType::size() == 1)
298  {
299  acc_element_op(v_c,
300  ck_tile::type_convert<float>(v_acc),
301  ck_tile::type_convert<float>(ds_m_n[0](m, n)));
302  }
303  else if constexpr(DsDataType::size() == 2)
304  {
305  acc_element_op(v_c,
306  ck_tile::type_convert<float>(v_acc),
307  ck_tile::type_convert<float>(ds_m_n[0](m, n)),
308  ck_tile::type_convert<float>(ds_m_n[1](m, n)));
309  }
310  c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_c);
311  };
312 
313  make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency());
314 }
315 
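A hedged sketch of the multiple-D overload with a single bias tensor. AddBias is a hypothetical functor shaped to match how acc_element_op is invoked above (output reference first, then the accumulator and D values as floats), and ck_tile::tuple is assumed to provide the static size() the dispatch relies on; DDataType is passed explicitly to avoid relying on the std::tuple_element_t default.

#include <array>
#include <cstddef>
#include "ck_tile/host.hpp" // assumed umbrella header

// Hypothetical epilogue: c = acc + d0.
struct AddBias
{
    template <typename C, typename Acc, typename D0>
    void operator()(C& c, const Acc& acc, const D0& d0) const
    {
        c = ck_tile::type_convert<C>(acc + d0);
    }
};

int main()
{
    const std::size_t M = 16, N = 16, K = 32;

    ck_tile::HostTensor<ck_tile::half_t> a({M, K});
    ck_tile::HostTensor<ck_tile::half_t> b({K, N});
    std::array<ck_tile::HostTensor<float>, 1> ds = {ck_tile::HostTensor<float>({M, N})};
    ck_tile::HostTensor<ck_tile::half_t> c({M, N});

    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t k = 0; k < K; ++k)
            a(m, k) = ck_tile::type_convert<ck_tile::half_t>(1.0f);
    for(std::size_t k = 0; k < K; ++k)
        for(std::size_t n = 0; n < N; ++n)
            b(k, n) = ck_tile::type_convert<ck_tile::half_t>(1.0f);
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
            ds[0](m, n) = 1.0f; // bias added to every output element

    using Ds = ck_tile::tuple<float>; // assumed to expose a static size() == 1
    // <ADataType, BDataType, DsDataType, AccDataType, CDataType, ACCElementOp, DDataType>
    ck_tile::reference_gemm_multiple_d<ck_tile::half_t, ck_tile::half_t, Ds, float,
                                       ck_tile::half_t, AddBias, float>(a, b, ds, c, AddBias{});
    // Every c(m, n) should now hold K + 1 = 33.
    return 0;
}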
316 template <typename ADataType,
317  typename BDataType,
318  typename AccDataType,
319  typename CDataType,
320  typename LayoutA,
321  typename LayoutB,
322  typename LayoutC>
323 __global__ void naive_gemm_kernel(ADataType* A,
324  BDataType* B,
325  CDataType* C,
326  ck_tile::index_t M,
327  ck_tile::index_t N,
328  ck_tile::index_t K,
329  ck_tile::index_t strideA,
330  ck_tile::index_t strideB,
331  ck_tile::index_t strideC)
332 {
333  int idx = blockIdx.x * blockDim.x + threadIdx.x;
334  int row = idx / N; // Compute row index
335  int col = idx % N; // Compute column index
336 
337  if(row < M && col < N)
338  {
339  AccDataType acc = 0.0;
340  for(int k = 0; k < K; ++k)
341  {
342  constexpr uint32_t packed_size_a = ck_tile::numeric_traits<ADataType>::PackedSize;
343  constexpr uint32_t packed_size_b = ck_tile::numeric_traits<BDataType>::PackedSize;
344  // Adjust indexing based on matrix layout
345  int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
346  ? row * strideA + k
347  : k * strideA + row;
348  int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
349  ? col * strideB + k
350  : k * strideB + col;
351 
352  AccDataType v_a;
353  AccDataType v_b;
354  if constexpr(std::is_same_v<ADataType, pk_int4_t>)
355  {
356  const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
357  if(k % 2 == 1)
358  v_a = fp32_val.hi;
359  else
360  v_a = fp32_val.lo;
361  }
362  else
363  {
364  v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
365  }
366  if constexpr(std::is_same_v<BDataType, pk_int4_t>)
367  {
368  const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
369  if(k % 2 == 1)
370  v_b = fp32_val.hi;
371  else
372  v_b = fp32_val.lo;
373  }
374  else
375  {
376  v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
377  }
378  acc += v_a * v_b;
379  }
380 
381  int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
382  ? row * strideC + col
383  : col * strideC + row;
384  C[c_index] = ck_tile::type_convert<CDataType>(acc);
385  }
386 }
387 
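Each thread of naive_gemm_kernel owns one output element: the flattened thread id idx maps to row = idx / N and col = idx % N, so with N = 128, thread 300 computes C(2, 44). Threads whose idx reaches past M * N simply fail the bounds check and exit without writing anything.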
388 template <typename ADataType,
389  typename BDataType,
390  typename AccDataType,
391  typename CDataType,
392  typename LayoutA,
393  typename LayoutB,
394  typename LayoutC>
395 void reference_gemm_gpu(ADataType* a_ptr,
396  BDataType* b_ptr,
397  CDataType* c_ptr,
398  index_t M,
399  index_t N,
400  index_t K,
401  index_t stride_a,
402  index_t stride_b,
403  index_t stride_c)
404 {
405  int totalElements = M * N;
406  int numThreadsPerBlock = 256; // Common choice for threads per block
407  int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
408 
409  naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
410  <<<numBlocks, numThreadsPerBlock>>>(
411  a_ptr, b_ptr, c_ptr, M, N, K, stride_a, stride_b, stride_c);
412 
413  return;
414 }
415 
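A hedged sketch of driving the GPU reference for a row-major fp16 problem. Device buffers are managed with standard HIP allocation and copy calls; the sizes, fill values, and umbrella include are illustrative assumptions, while the stride choices follow directly from the indexing in naive_gemm_kernel.

#include <cstddef>
#include <vector>
#include <hip/hip_runtime.h>
#include "ck_tile/host.hpp" // assumed umbrella header

int main()
{
    const ck_tile::index_t M = 256, N = 256, K = 512;
    using Row  = ck_tile::tensor_layout::gemm::RowMajor;
    using Half = ck_tile::half_t;

    std::vector<Half> a_host(static_cast<std::size_t>(M) * K, ck_tile::type_convert<Half>(1.0f));
    std::vector<Half> b_host(static_cast<std::size_t>(K) * N, ck_tile::type_convert<Half>(1.0f));
    std::vector<Half> c_host(static_cast<std::size_t>(M) * N);

    Half *a_dev = nullptr, *b_dev = nullptr, *c_dev = nullptr;
    hipMalloc(reinterpret_cast<void**>(&a_dev), a_host.size() * sizeof(Half));
    hipMalloc(reinterpret_cast<void**>(&b_dev), b_host.size() * sizeof(Half));
    hipMalloc(reinterpret_cast<void**>(&c_dev), c_host.size() * sizeof(Half));
    hipMemcpy(a_dev, a_host.data(), a_host.size() * sizeof(Half), hipMemcpyHostToDevice);
    hipMemcpy(b_dev, b_host.data(), b_host.size() * sizeof(Half), hipMemcpyHostToDevice);

    // Row-major strides: leading dimension of A is K, of B is N, of C is N.
    ck_tile::reference_gemm_gpu<Half, Half, float, Half, Row, Row, Row>(
        a_dev, b_dev, c_dev, M, N, K, K, N, N);

    hipDeviceSynchronize(); // the launch inside reference_gemm_gpu is asynchronous
    hipMemcpy(c_host.data(), c_dev, c_host.size() * sizeof(Half), hipMemcpyDeviceToHost);

    hipFree(a_dev);
    hipFree(b_dev);
    hipFree(c_dev);
    return 0;
}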
416 template <typename ADataType,
417  typename BDataType,
418  typename AccDataType,
419  typename CDataType,
420  typename LayoutA,
421  typename LayoutB,
422  typename LayoutC>
423 void reference_batched_gemm_gpu(ADataType* a_ptr,
424  BDataType* b_ptr,
425  CDataType* c_ptr,
426  index_t M,
427  index_t N,
428  index_t K,
429  index_t stride_a,
430  index_t stride_b,
431  index_t stride_c,
432  index_t batch_stride_A,
433  index_t batch_stride_B,
434  index_t batch_stride_C,
435  index_t batch_count)
436 {
437  int totalElements = M * N;
438  int numThreadsPerBlock = 256; // Common choice for threads per block
439  int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
440 
441  for(index_t batch_id = 0; batch_id < batch_count; ++batch_id)
442  {
443  ADataType* d_ATemp = a_ptr + batch_id * batch_stride_A;
444  BDataType* d_BTemp = b_ptr + batch_id * batch_stride_B;
445  CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C;
446  naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
447  <<<numBlocks, numThreadsPerBlock>>>(
448  d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
449  }
450 
451  return;
452 }
453 } // namespace ck_tile
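The batched wrapper above simply re-launches naive_gemm_kernel once per batch with the base pointers advanced by the batch strides. For packed row-major batches of shape [batch, M, K], [batch, K, N], and [batch, M, N], the natural choices are batch_stride_A = M * K, batch_stride_B = K * N, and batch_stride_C = M * N; for example, with M = 256 and K = 512, batch 3 of the A buffer starts at element 3 * 256 * 512 = 393216.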
#define CK_TILE_HOST
Definition: config.hpp:40
ck_tile
Definition: cluster_descriptor.hpp:13
void reference_batched_gemm_gpu(ADataType *a_ptr, BDataType *b_ptr, CDataType *c_ptr, index_t M, index_t N, index_t K, index_t stride_a, index_t stride_b, index_t stride_c, index_t batch_stride_A, index_t batch_stride_B, index_t batch_stride_C, index_t batch_count)
Definition: reference_gemm.hpp:423
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
__global__ void naive_gemm_kernel(ADataType *A, BDataType *B, CDataType *C, ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K, ck_tile::index_t strideA, ck_tile::index_t strideB, ck_tile::index_t strideC)
Definition: reference_gemm.hpp:323
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t &x)
Definition: pk_int4.hpp:105
CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t)
Definition: float8.hpp:751
CK_TILE_HOST void reference_gemm_quant(const HostTensor< ADataType > &a_m_k, const HostTensor< QDataType > &q, const HostTensor< BDataType > &b_k_n, HostTensor< CDataType > &c_m_n, const AElementOp &a_element_op={}, const BElementOp &b_element_op={}, const ACCElementOp &acc_element_op={})
Definition: reference_gemm.hpp:24
CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t)
Definition: float8.hpp:764
float fp32x2_t
Definition: pk_fp4.hpp:22
int32_t index_t
Definition: integer.hpp:9
CK_TILE_HOST void reference_gemm_rowcol_quant(const HostTensor< ADataType > &a_m_k, const HostTensor< AQDataType > &aq_m_1, const HostTensor< BDataType > &b_k_n, const HostTensor< BQDataType > &bq_1_n, HostTensor< CDataType > &c_m_n, const AElementOp &a_element_op={}, const BElementOp &b_element_op={}, const ACCElementOp &acc_element_op={})
Definition: reference_gemm.hpp:127
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t_signed_conversion(const pk_int4_t &x)
Definition: pk_int4.hpp:120
void reference_gemm_gpu(ADataType *a_ptr, BDataType *b_ptr, CDataType *c_ptr, index_t M, index_t N, index_t K, index_t stride_a, index_t stride_b, index_t stride_c)
Definition: reference_gemm.hpp:395
CK_TILE_HOST void reference_gemm_multiple_d(const HostTensor< ADataType > &a_m_k, const HostTensor< BDataType > &b_k_n, const std::array< HostTensor< DDataType >, DsDataType::size()> &ds_m_n, HostTensor< CDataType > &c_m_n, const ACCElementOp &acc_element_op={})
Definition: reference_gemm.hpp:272
CK_TILE_HOST void reference_gemm(const HostTensor< ADataType > &a_m_k, const HostTensor< BDataType > &b_k_n, HostTensor< CDataType > &c_m_n, const AElementOp &a_element_op={}, const BElementOp &b_element_op={}, const ACCElementOp &acc_element_op={})
Definition: reference_gemm.hpp:211
unsigned int uint32_t
Definition: stdint.h:126
Definition: host_tensor.hpp:336
std::size_t get_length(std::size_t dim) const
Definition: host_tensor.hpp:388
Definition: functional.hpp:86
Definition: numeric.hpp:81