include/ck_tile/host/reference/reference_gemm.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>
#include <thread>

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace ck_tile {

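// Host reference GEMM:
//   C(m, n) = acc_element_op(sum_k a_element_op(A(m, k)) * b_element_op(B(k, n)))
// with accumulation in AccDataType. Each output element is independent, so the
// M x N loop is parallelized across host threads.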
template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename AElementOp   = ck_tile::identity,
          typename BElementOp   = ck_tile::identity,
          typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
                                 const HostTensor<BDataType>& b_k_n,
                                 HostTensor<CDataType>& c_m_n,
                                 const AElementOp& a_element_op     = {},
                                 const BElementOp& b_element_op     = {},
                                 const ACCElementOp& acc_element_op = {})
{
    const std::size_t M = a_m_k.get_length(0);
    const std::size_t N = b_k_n.get_length(1);
    const std::size_t K = a_m_k.get_length(1);

    auto f_mn = [&](auto m, auto n) {
        AccDataType v_acc = 0;

        for(std::size_t k = 0; k < K; ++k)
        {
            ADataType v_a = a_element_op(a_m_k(m, k));
            BDataType v_b = b_element_op(b_k_n(k, n));

            v_acc +=
                ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
        }

        c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
    };

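    // Evaluate f_mn over the full M x N output, parallelized across host threads.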
    make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
}

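// Example usage (illustrative sketch, not part of this header): computing a
// host reference result for a 128 x 128 x 64 problem, assuming HostTensor can
// be constructed from a list of dimension lengths:
//
//   HostTensor<float> a({128, 64});
//   HostTensor<float> b({64, 128});
//   HostTensor<float> c({128, 128});
//   reference_gemm<float, float, float, float>(a, b, c);
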
// Naive device-side GEMM kernel: each thread computes a single element of C.
// strideA/strideB/strideC are the leading dimensions (in elements) of A, B and
// C, interpreted according to the LayoutA/LayoutB/LayoutC template parameters.
template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename LayoutA,
          typename LayoutB,
          typename LayoutC>
__global__ void naive_gemm_kernel(ADataType* A,
                                  BDataType* B,
                                  CDataType* C,
                                  ck_tile::index_t M,
                                  ck_tile::index_t N,
                                  ck_tile::index_t K,
                                  ck_tile::index_t strideA,
                                  ck_tile::index_t strideB,
                                  ck_tile::index_t strideC)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int row = idx / N; // Compute row index
    int col = idx % N; // Compute column index

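    // Guard threads that map outside the M x N output.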
    if(row < M && col < N)
    {
        AccDataType acc = 0.0;
        for(int k = 0; k < K; ++k)
        {
            // Adjust indexing based on matrix layout
            int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
                              ? row * strideA + k
                              : k * strideA + row;
            int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
                              ? col * strideB + k
                              : k * strideB + col;
            acc += ck_tile::type_convert<AccDataType>(A[a_index]) *
                   ck_tile::type_convert<AccDataType>(B[b_index]);
        }

        int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
                          ? row * strideC + col
                          : col * strideC + row;
        C[c_index] = ck_tile::type_convert<CDataType>(acc);
    }
}

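// GPU reference GEMM: launches naive_gemm_kernel with one thread per element
// of the M x N output (256 threads per block).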
template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename LayoutA,
          typename LayoutB,
          typename LayoutC>
void reference_gemm_gpu(ADataType* a_ptr,
                        BDataType* b_ptr,
                        CDataType* c_ptr,
                        index_t M,
                        index_t N,
                        index_t K,
                        index_t stride_a,
                        index_t stride_b,
                        index_t stride_c)
{
    int totalElements = M * N;
    int numThreadsPerBlock = 256; // Common choice for threads per block
    int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;

    naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
        <<<numBlocks, numThreadsPerBlock>>>(
            a_ptr, b_ptr, c_ptr, M, N, K, stride_a, stride_b, stride_c);

    // Note: the launch is asynchronous; callers must synchronize (e.g. with
    // hipDeviceSynchronize()) before reading c_ptr on the host.
}

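// Batched GPU reference GEMM: issues one naive_gemm_kernel launch per batch,
// offsetting each matrix pointer by its batch stride.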
template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename LayoutA,
          typename LayoutB,
          typename LayoutC>
void reference_batched_gemm_gpu(ADataType* a_ptr,
                                BDataType* b_ptr,
                                CDataType* c_ptr,
                                index_t M,
                                index_t N,
                                index_t K,
                                index_t stride_a,
                                index_t stride_b,
                                index_t stride_c,
                                index_t batch_stride_A,
                                index_t batch_stride_B,
                                index_t batch_stride_C,
                                index_t batch_count)
{
    int totalElements = M * N;
    int numThreadsPerBlock = 256; // Common choice for threads per block
    int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;

    for(index_t batch_id = 0; batch_id < batch_count; ++batch_id)
    {
        // Advance each base pointer to the current batch's matrix.
        ADataType* d_ATemp = a_ptr + batch_id * batch_stride_A;
        BDataType* d_BTemp = b_ptr + batch_id * batch_stride_B;
        CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C;
        naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
            <<<numBlocks, numThreadsPerBlock>>>(
                d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
    }
}
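
// Example usage (illustrative sketch, not part of this header): validating a
// row-major device-side result with plain HIP allocations; error checking is
// omitted and the problem sizes M, N, K are assumed to be defined:
//
//   float *a_dev, *b_dev, *c_dev;
//   hipMalloc(&a_dev, M * K * sizeof(float));
//   hipMalloc(&b_dev, K * N * sizeof(float));
//   hipMalloc(&c_dev, M * N * sizeof(float));
//   // ... copy A and B to the device ...
//   reference_gemm_gpu<float, float, float, float,
//                      tensor_layout::gemm::RowMajor,
//                      tensor_layout::gemm::RowMajor,
//                      tensor_layout::gemm::RowMajor>(
//       a_dev, b_dev, c_dev, M, N, K, /*stride_a=*/K, /*stride_b=*/N, /*stride_c=*/N);
//   hipDeviceSynchronize();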
} // namespace ck_tile