// NOTE(review): this text is a rendered source listing, not a compilable file. The
// integer at the start of most lines is the listing's own line number, and many
// numbered lines are absent (for example 17-20, 22-31 and 39-46), so the definition
// below is incomplete. The comments describe only what is visible.
//
// Judging by the declaration echoed at the end of this listing
// (reference_batched_contraction.hpp:23) this is presumably
// calculate_reference_flat_indexing - TODO confirm against the real header.
//
// Visible behavior:
//   * prints a progress message to std::cout;
//   * defines a lambda f_gm(g_flat, m_flat) that accumulates products of A and B
//     elements into `sum`, converts the sum to EDataType, optionally passes it
//     through cde_elementwise together with 1 to 4 D-tensor values, and stores the
//     result in e_full_dims_host_ref.
// How f_gm is invoked, and the declarations of n_flat, k_flat, sum, a_val and b_val,
// are on lines that are not visible here.
16 template <
typename ADataType,
 
   21           typename CDEElementWise>
 
   32     const CDEElementWise& cde_elementwise)
 
   34     std::cout << 
"Calculating reference using optimized flat indexing with parallel processing..." 
   38     auto f_gm = [&](
auto g_flat, 
auto m_flat) {
 
          // A is read at offset g*M_total*K_total + m*K_total + k and B at offset
          // g*N_total*K_total + n*K_total + k, directly from the tensors' mData storage.
          // NOTE(review): the echoed declaration gives the *_total parameters as
          // ck_tile::index_t, which this listing shows as int32_t. The types of
          // g_flat / m_flat are deduced (auto), so whether these products can overflow
          // 32 bits for large tensors cannot be determined from here - confirm.
   47                     a_full_dims.
mData[g_flat * M_total * K_total + m_flat * K_total + k_flat];
 
   49                     b_full_dims.
mData[g_flat * N_total * K_total + n_flat * K_total + k_flat];
 
          // Both operands are converted to AccDataType before multiplying.
   50                 sum += 
static_cast<AccDataType
>(a_val) * 
static_cast<AccDataType
>(b_val);
 
          // Default output: the accumulated sum converted to EDataType.
   54             EDataType result = 
static_cast<EDataType
>(sum);
 
          // Dispatch on the number of D tensors supplied at run time. The body of the
          // size() == 0 branch is not visible in this fragment.
   55             if(ds_full_dims_host.size() == 0)
 
   59             else if(ds_full_dims_host.size() == 1)
 
          // cde_elementwise receives `result`, the sum converted to float, and the D
          // value at offset g*M_total*N_total + m*N_total + n converted to float.
          // NOTE(review): `sum` is converted to float whatever AccDataType is; if
          // AccDataType is wider than float this narrows the value - confirm intent.
   61                 cde_elementwise(result,
 
   62                                 ck_tile::type_convert<float>(sum),
 
   63                                 ck_tile::type_convert<float>(
 
   64                                     ds_full_dims_host[0].mData[g_flat * M_total * N_total +
 
   65                                                                m_flat * N_total + n_flat]));
 
          // The 2-, 3- and 4-tensor branches below list the same arguments with one
          // extra D value each, all read at the same offset. The call line and the
          // expressions selecting each D tensor are not visible; presumably they mirror
          // the single-tensor branch above - TODO confirm.
   67             else if(ds_full_dims_host.size() == 2)
 
   71                     ck_tile::type_convert<float>(sum),
 
   72                     ck_tile::type_convert<float>(
 
   74                             .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
 
   75                     ck_tile::type_convert<float>(
 
   77                             .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
 
   79             else if(ds_full_dims_host.size() == 3)
 
   83                     ck_tile::type_convert<float>(sum),
 
   84                     ck_tile::type_convert<float>(
 
   86                             .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
 
   87                     ck_tile::type_convert<float>(
 
   89                             .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
 
   90                     ck_tile::type_convert<float>(
 
   92                             .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
 
   94             else if(ds_full_dims_host.size() == 4)
 
   98                     ck_tile::type_convert<float>(sum),
 
   99                     ck_tile::type_convert<float>(
 
  101                             .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
 
  102                     ck_tile::type_convert<float>(
 
  104                             .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
 
  105                     ck_tile::type_convert<float>(
 
  107                             .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
 
  108                     ck_tile::type_convert<float>(
 
  110                             .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
 
          // A std::runtime_error is thrown here; the enclosing `else` is not visible, but
          // the message indicates this is the path for any other D-tensor count.
  114                 throw std::runtime_error(
"Unsupported NumDTensor for reference calculation");
 
          // Store the final value in E at offset g*M_total*N_total + m*N_total + n.
          // NOTE(review): `result` is already declared as EDataType above, so this
          // static_cast looks redundant.
  118             e_full_dims_host_ref.
mData[g_flat * M_total * N_total + m_flat * N_total + n_flat] =
 
  119                 static_cast<EDataType
>(result);
 
  // NOTE(review): this text is a rendered source listing, not a compilable file. The
  // integer at the start of most lines is the listing's own line number, and many
  // numbered lines are absent (for example 129-131, 134-138 and 188-196), so the
  // definition below is incomplete. The comments describe only what is visible.
  //
  // Judging by the declaration echoed at the end of this listing
  // (reference_batched_contraction.hpp:134) this is presumably
  // calculate_reference_multi_dimensional - TODO confirm against the real header.
  //
  // Visible behavior:
  //   * prints a progress message to std::cout;
  //   * iterates flat indices over the G, M and N dimension groups, decomposing each
  //     flat index into per-dimension indices;
  //   * builds multi-dimensional index vectors for A (g, m, k), B (g, n, k) and
  //     E (g, m, n), accumulates A*B products into `sum`, optionally applies
  //     cde_elementwise with 1 to 4 D-tensor values, and stores the result in
  //     e_full_dims_host_ref.
  // The declarations of `sum` and `temp`, the loop over K and all braces are on lines
  // that are not visible here.
  128 template <
typename ADataType,
 
  132           typename AccDataType,
 
  133           typename CDEElementWise>
 
  139     const std::vector<index_t>& G_dims,
 
  140     const std::vector<index_t>& M_dims,
 
  141     const std::vector<index_t>& N_dims,
 
  142     const std::vector<index_t>& K_dims,
 
  143     const std::vector<index_t>& A_dims,
 
  144     const std::vector<index_t>& B_dims,
 
  145     const std::vector<index_t>& E_dims,
 
  146     const CDEElementWise& cde_elementwise)
 
  148     std::cout << 
"Calculating reference using multi-dimensional indexing..." << std::endl;
 
          // Per-group index scratch vectors, sized to the number of dimensions per group.
  150     std::vector<std::size_t> g_idx(G_dims.size());
 
  151     std::vector<std::size_t> m_idx(M_dims.size());
 
  152     std::vector<std::size_t> n_idx(N_dims.size());
 
  153     std::vector<std::size_t> k_idx(K_dims.size());
 
  154     std::vector<std::size_t> a_idx, b_idx, e_idx;
 
          // A_dims / B_dims / E_dims are used in the visible code only for these
          // capacity reservations.
  156     a_idx.reserve(A_dims.size());
 
  157     b_idx.reserve(B_dims.size());
 
  158     e_idx.reserve(E_dims.size());
 
          // Product of all extents in `dims`; an empty vector yields the initial value 1.
          // NOTE(review): std::accumulate accumulates in the type of its initial value,
          // here the int literal 1. This listing shows index_t as int32_t, so the width is
          // presumably the same, but large extents could overflow - confirm.
  160     auto calculate_total_elements = [](
const std::vector<ck_tile::index_t>& dims) {
 
  161         return std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<ck_tile::index_t>());
 
          // NOTE(review): calculate_total_elements(...) appears in each loop condition
          // below, so the product is recomputed on every iteration although the dims
          // vectors are const references; it could be hoisted into a local.
  164     for(
ck_tile::index_t g_flat = 0; g_flat < calculate_total_elements(G_dims); ++g_flat)
 
          // Decompose the flat index into per-dimension indices, walking from the last
          // dimension to the first. `temp` is declared and (presumably) divided on lines
          // that are not visible.
          // NOTE(review): size() is unsigned and the result is stored in an int; for an
          // empty dims vector this yields -1 on typical targets, so the loop is skipped.
  167         for(
int i = G_dims.size() - 1; i >= 0; --i)
 
  169             g_idx[i] = temp % G_dims[i];
 
  173         for(
ck_tile::index_t m_flat = 0; m_flat < calculate_total_elements(M_dims); ++m_flat)
 
  176             for(
int i = M_dims.size() - 1; i >= 0; --i)
 
  178                 m_idx[i] = temp % M_dims[i];
 
  182             for(
ck_tile::index_t n_flat = 0; n_flat < calculate_total_elements(N_dims); ++n_flat)
 
  185                 for(
int i = N_dims.size() - 1; i >= 0; --i)
 
  187                     n_idx[i] = temp % N_dims[i];
 
  197                     for(
int i = K_dims.size() - 1; i >= 0; --i)
 
  199                         k_idx[i] = temp % K_dims[i];
 
          // Assemble the A index as (g..., m..., k...) and the B index as (g..., n..., k...).
          // NOTE(review): insert() only appends, so a_idx / b_idx must be emptied before
          // each assembly; that step is not visible in this fragment - confirm it exists.
  206                     a_idx.insert(a_idx.end(), g_idx.begin(), g_idx.end());
 
  207                     a_idx.insert(a_idx.end(), m_idx.begin(), m_idx.end());
 
  208                     a_idx.insert(a_idx.end(), k_idx.begin(), k_idx.end());
 
  210                     b_idx.insert(b_idx.end(), g_idx.begin(), g_idx.end());
 
  211                     b_idx.insert(b_idx.end(), n_idx.begin(), n_idx.end());
 
  212                     b_idx.insert(b_idx.end(), k_idx.begin(), k_idx.end());
 
          // Elements are fetched by calling the tensors with the index vector.
  214                     auto a_val = a_full_dims(a_idx);
 
  215                     auto b_val = b_full_dims(b_idx);
 
          // Both operands are converted to AccDataType before multiplying.
  217                     sum += 
static_cast<AccDataType
>(a_val) * 
static_cast<AccDataType
>(b_val);
 
          // Assemble the E index as (g..., m..., n...); the same append-only caveat as
          // for a_idx / b_idx applies to e_idx.
  221                 e_idx.insert(e_idx.end(), g_idx.begin(), g_idx.end());
 
  222                 e_idx.insert(e_idx.end(), m_idx.begin(), m_idx.end());
 
  223                 e_idx.insert(e_idx.end(), n_idx.begin(), n_idx.end());
 
          // Default output: the accumulated sum converted to EDataType.
  225                 EDataType result = 
static_cast<EDataType
>(sum);
 
          // Dispatch on the number of D tensors supplied at run time. Each branch passes
          // `result`, the sum converted to float, and one float-converted value per D
          // tensor read at e_idx. The body of the size() == 0 branch is not visible.
          // NOTE(review): `sum` is converted to float whatever AccDataType is; if
          // AccDataType is wider than float this narrows the value - confirm intent.
  226                 if(ds_full_dims_host.size() == 0)
 
  230                 else if(ds_full_dims_host.size() == 1)
 
  232                     cde_elementwise(result,
 
  233                                     ck_tile::type_convert<float>(sum),
 
  234                                     ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)));
 
  236                 else if(ds_full_dims_host.size() == 2)
 
  238                     cde_elementwise(result,
 
  239                                     ck_tile::type_convert<float>(sum),
 
  240                                     ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
 
  241                                     ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)));
 
  243                 else if(ds_full_dims_host.size() == 3)
 
  245                     cde_elementwise(result,
 
  246                                     ck_tile::type_convert<float>(sum),
 
  247                                     ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
 
  248                                     ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)),
 
  249                                     ck_tile::type_convert<float>(ds_full_dims_host[2](e_idx)));
 
  251                 else if(ds_full_dims_host.size() == 4)
 
  253                     cde_elementwise(result,
 
  254                                     ck_tile::type_convert<float>(sum),
 
  255                                     ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
 
  256                                     ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)),
 
  257                                     ck_tile::type_convert<float>(ds_full_dims_host[2](e_idx)),
 
  258                                     ck_tile::type_convert<float>(ds_full_dims_host[3](e_idx)));
 
          // A std::runtime_error is thrown here; the enclosing `else` is not visible, but
          // the message indicates this is the path for any other D-tensor count.
  262                     throw std::runtime_error(
"Unsupported NumDTensor for reference calculation");
 
          // Store the final value in E at the assembled (g, m, n) index.
          // NOTE(review): `result` is already declared as EDataType above, so this
          // static_cast looks redundant.
  265                 e_full_dims_host_ref(e_idx) = 
static_cast<EDataType
>(result);
 
Definition: cluster_descriptor.hpp:13
 
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
 
int32_t index_t
Definition: integer.hpp:9
 
void calculate_reference_flat_indexing(const ck_tile::HostTensor< ADataType > &a_full_dims, const ck_tile::HostTensor< BDataType > &b_full_dims, const std::vector< ck_tile::HostTensor< DDataType >> &ds_full_dims_host, ck_tile::HostTensor< EDataType > &e_full_dims_host_ref, ck_tile::index_t G_total, ck_tile::index_t M_total, ck_tile::index_t N_total, ck_tile::index_t K_total, const CDEElementWise &cde_elementwise)
Definition: reference_batched_contraction.hpp:23
 
void calculate_reference_multi_dimensional(const HostTensor< ADataType > &a_full_dims, const HostTensor< BDataType > &b_full_dims, const std::vector< HostTensor< DDataType >> &ds_full_dims_host, HostTensor< EDataType > &e_full_dims_host_ref, const std::vector< index_t > &G_dims, const std::vector< index_t > &M_dims, const std::vector< index_t > &N_dims, const std::vector< index_t > &K_dims, const std::vector< index_t > &A_dims, const std::vector< index_t > &B_dims, const std::vector< index_t > &E_dims, const CDEElementWise &cde_elementwise)
Definition: reference_batched_contraction.hpp:134
 
Definition: host_tensor.hpp:336
 
Data mData
Definition: host_tensor.hpp:801