/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_batched_contraction.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_batched_contraction.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_batched_contraction.hpp Source File
reference_batched_contraction.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
#include <cstdlib>
#include <functional>
#include <iostream>
#include <numeric>
#include <stdexcept>
#include <thread>

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
13 
14 namespace ck_tile {
15 
16 template <typename ADataType,
17  typename BDataType,
18  typename DDataType,
19  typename EDataType,
20  typename AccDataType,
21  typename CDEElementWise>
22 
24  const ck_tile::HostTensor<ADataType>& a_full_dims,
25  const ck_tile::HostTensor<BDataType>& b_full_dims,
26  const std::vector<ck_tile::HostTensor<DDataType>>& ds_full_dims_host,
27  ck_tile::HostTensor<EDataType>& e_full_dims_host_ref,
28  ck_tile::index_t G_total,
29  ck_tile::index_t M_total,
30  ck_tile::index_t N_total,
31  ck_tile::index_t K_total,
32  const CDEElementWise& cde_elementwise)
33 {
34  std::cout << "Calculating reference using optimized flat indexing with parallel processing..."
35  << std::endl;
36 
37  // Parallel computation over G and M dimensions using pattern from reference_batched_gemm.hpp
38  auto f_gm = [&](auto g_flat, auto m_flat) {
39  for(ck_tile::index_t n_flat = 0; n_flat < N_total; ++n_flat)
40  {
41  AccDataType sum = 0;
42 
43  // Compute dot product over K dimension
44  for(ck_tile::index_t k_flat = 0; k_flat < K_total; ++k_flat)
45  {
46  auto a_val =
47  a_full_dims.mData[g_flat * M_total * K_total + m_flat * K_total + k_flat];
48  auto b_val =
49  b_full_dims.mData[g_flat * N_total * K_total + n_flat * K_total + k_flat];
50  sum += static_cast<AccDataType>(a_val) * static_cast<AccDataType>(b_val);
51  }
52 
53  // Apply elementwise operation with D tensors
54  EDataType result = static_cast<EDataType>(sum);
55  if(ds_full_dims_host.size() == 0)
56  {
57  ;
58  }
59  else if(ds_full_dims_host.size() == 1)
60  {
61  cde_elementwise(result,
62  ck_tile::type_convert<float>(sum),
63  ck_tile::type_convert<float>(
64  ds_full_dims_host[0].mData[g_flat * M_total * N_total +
65  m_flat * N_total + n_flat]));
66  }
67  else if(ds_full_dims_host.size() == 2)
68  {
69  cde_elementwise(
70  result,
71  ck_tile::type_convert<float>(sum),
72  ck_tile::type_convert<float>(
73  ds_full_dims_host[0]
74  .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
75  ck_tile::type_convert<float>(
76  ds_full_dims_host[1]
77  .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
78  }
79  else if(ds_full_dims_host.size() == 3)
80  {
81  cde_elementwise(
82  result,
83  ck_tile::type_convert<float>(sum),
84  ck_tile::type_convert<float>(
85  ds_full_dims_host[0]
86  .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
87  ck_tile::type_convert<float>(
88  ds_full_dims_host[1]
89  .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
90  ck_tile::type_convert<float>(
91  ds_full_dims_host[2]
92  .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
93  }
94  else if(ds_full_dims_host.size() == 4)
95  {
96  cde_elementwise(
97  result,
98  ck_tile::type_convert<float>(sum),
99  ck_tile::type_convert<float>(
100  ds_full_dims_host[0]
101  .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
102  ck_tile::type_convert<float>(
103  ds_full_dims_host[1]
104  .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
105  ck_tile::type_convert<float>(
106  ds_full_dims_host[2]
107  .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
108  ck_tile::type_convert<float>(
109  ds_full_dims_host[3]
110  .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
111  }
112  else
113  {
114  throw std::runtime_error("Unsupported NumDTensor for reference calculation");
115  }
116 
117  // Store result
118  e_full_dims_host_ref.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat] =
119  static_cast<EDataType>(result);
120  }
121  };
122 
123  // Execute parallel computation using hardware concurrency
124  // Parallelize over G_total and M_total dimensions for optimal CPU utilization
125  make_ParallelTensorFunctor(f_gm, G_total, M_total)(std::thread::hardware_concurrency());
126 }
127 
128 template <typename ADataType,
129  typename BDataType,
130  typename DDataType,
131  typename EDataType,
132  typename AccDataType,
133  typename CDEElementWise>
135  const HostTensor<ADataType>& a_full_dims,
136  const HostTensor<BDataType>& b_full_dims,
137  const std::vector<HostTensor<DDataType>>& ds_full_dims_host,
138  HostTensor<EDataType>& e_full_dims_host_ref,
139  const std::vector<index_t>& G_dims,
140  const std::vector<index_t>& M_dims,
141  const std::vector<index_t>& N_dims,
142  const std::vector<index_t>& K_dims,
143  const std::vector<index_t>& A_dims,
144  const std::vector<index_t>& B_dims,
145  const std::vector<index_t>& E_dims,
146  const CDEElementWise& cde_elementwise)
147 {
148  std::cout << "Calculating reference using multi-dimensional indexing..." << std::endl;
149 
150  std::vector<std::size_t> g_idx(G_dims.size());
151  std::vector<std::size_t> m_idx(M_dims.size());
152  std::vector<std::size_t> n_idx(N_dims.size());
153  std::vector<std::size_t> k_idx(K_dims.size());
154  std::vector<std::size_t> a_idx, b_idx, e_idx;
155 
156  a_idx.reserve(A_dims.size());
157  b_idx.reserve(B_dims.size());
158  e_idx.reserve(E_dims.size());
159 
160  auto calculate_total_elements = [](const std::vector<ck_tile::index_t>& dims) {
161  return std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<ck_tile::index_t>());
162  };
163 
164  for(ck_tile::index_t g_flat = 0; g_flat < calculate_total_elements(G_dims); ++g_flat)
165  {
166  ck_tile::index_t temp = g_flat;
167  for(int i = G_dims.size() - 1; i >= 0; --i)
168  {
169  g_idx[i] = temp % G_dims[i];
170  temp /= G_dims[i];
171  }
172 
173  for(ck_tile::index_t m_flat = 0; m_flat < calculate_total_elements(M_dims); ++m_flat)
174  {
175  temp = m_flat;
176  for(int i = M_dims.size() - 1; i >= 0; --i)
177  {
178  m_idx[i] = temp % M_dims[i];
179  temp /= M_dims[i];
180  }
181 
182  for(ck_tile::index_t n_flat = 0; n_flat < calculate_total_elements(N_dims); ++n_flat)
183  {
184  temp = n_flat;
185  for(int i = N_dims.size() - 1; i >= 0; --i)
186  {
187  n_idx[i] = temp % N_dims[i];
188  temp /= N_dims[i];
189  }
190 
191  AccDataType sum = 0;
192 
193  for(ck_tile::index_t k_flat = 0; k_flat < calculate_total_elements(K_dims);
194  ++k_flat)
195  {
196  temp = k_flat;
197  for(int i = K_dims.size() - 1; i >= 0; --i)
198  {
199  k_idx[i] = temp % K_dims[i];
200  temp /= K_dims[i];
201  }
202 
203  a_idx.clear();
204  b_idx.clear();
205 
206  a_idx.insert(a_idx.end(), g_idx.begin(), g_idx.end());
207  a_idx.insert(a_idx.end(), m_idx.begin(), m_idx.end());
208  a_idx.insert(a_idx.end(), k_idx.begin(), k_idx.end());
209 
210  b_idx.insert(b_idx.end(), g_idx.begin(), g_idx.end());
211  b_idx.insert(b_idx.end(), n_idx.begin(), n_idx.end());
212  b_idx.insert(b_idx.end(), k_idx.begin(), k_idx.end());
213 
214  auto a_val = a_full_dims(a_idx);
215  auto b_val = b_full_dims(b_idx);
216 
217  sum += static_cast<AccDataType>(a_val) * static_cast<AccDataType>(b_val);
218  }
219 
220  e_idx.clear();
221  e_idx.insert(e_idx.end(), g_idx.begin(), g_idx.end());
222  e_idx.insert(e_idx.end(), m_idx.begin(), m_idx.end());
223  e_idx.insert(e_idx.end(), n_idx.begin(), n_idx.end());
224 
225  EDataType result = static_cast<EDataType>(sum);
226  if(ds_full_dims_host.size() == 0)
227  {
228  ;
229  }
230  else if(ds_full_dims_host.size() == 1)
231  {
232  cde_elementwise(result,
233  ck_tile::type_convert<float>(sum),
234  ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)));
235  }
236  else if(ds_full_dims_host.size() == 2)
237  {
238  cde_elementwise(result,
239  ck_tile::type_convert<float>(sum),
240  ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
241  ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)));
242  }
243  else if(ds_full_dims_host.size() == 3)
244  {
245  cde_elementwise(result,
246  ck_tile::type_convert<float>(sum),
247  ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
248  ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)),
249  ck_tile::type_convert<float>(ds_full_dims_host[2](e_idx)));
250  }
251  else if(ds_full_dims_host.size() == 4)
252  {
253  cde_elementwise(result,
254  ck_tile::type_convert<float>(sum),
255  ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
256  ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)),
257  ck_tile::type_convert<float>(ds_full_dims_host[2](e_idx)),
258  ck_tile::type_convert<float>(ds_full_dims_host[3](e_idx)));
259  }
260  else
261  {
262  throw std::runtime_error("Unsupported NumDTensor for reference calculation");
263  }
264 
265  e_full_dims_host_ref(e_idx) = static_cast<EDataType>(result);
266  }
267  }
268  }
269 }
270 
271 } // namespace ck_tile
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
int32_t index_t
Definition: integer.hpp:9
void calculate_reference_flat_indexing(const ck_tile::HostTensor< ADataType > &a_full_dims, const ck_tile::HostTensor< BDataType > &b_full_dims, const std::vector< ck_tile::HostTensor< DDataType >> &ds_full_dims_host, ck_tile::HostTensor< EDataType > &e_full_dims_host_ref, ck_tile::index_t G_total, ck_tile::index_t M_total, ck_tile::index_t N_total, ck_tile::index_t K_total, const CDEElementWise &cde_elementwise)
Definition: reference_batched_contraction.hpp:23
void calculate_reference_multi_dimensional(const HostTensor< ADataType > &a_full_dims, const HostTensor< BDataType > &b_full_dims, const std::vector< HostTensor< DDataType >> &ds_full_dims_host, HostTensor< EDataType > &e_full_dims_host_ref, const std::vector< index_t > &G_dims, const std::vector< index_t > &M_dims, const std::vector< index_t > &N_dims, const std::vector< index_t > &K_dims, const std::vector< index_t > &A_dims, const std::vector< index_t > &B_dims, const std::vector< index_t > &E_dims, const CDEElementWise &cde_elementwise)
Definition: reference_batched_contraction.hpp:134
Definition: host_tensor.hpp:336
Data mData
Definition: host_tensor.hpp:801