/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_transpose.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_transpose.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_transpose.hpp Source File
reference_transpose.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 #include <thread>
9 
10 namespace ck_tile {
11 
12 template <typename ADataType, typename BDataType>
14 {
15  ck_tile::index_t M = static_cast<ck_tile::index_t>(a.mDesc.get_lengths()[0]);
16  ck_tile::index_t N = static_cast<ck_tile::index_t>(a.mDesc.get_lengths()[1]);
17 
18  // Ensure the b tensor is sized correctly for N x M
19  if(static_cast<ck_tile::index_t>(b.mDesc.get_lengths()[0]) != N ||
20  static_cast<ck_tile::index_t>(b.mDesc.get_lengths()[1]) != M)
21  {
22  throw std::runtime_error("Output tensor b has incorrect dimensions for transpose.");
23  }
24 
25  auto f = [&](auto i, auto j) {
26  auto v_a = a(i, j);
27  b(j, i) = ck_tile::type_convert<BDataType>(v_a);
28  };
29 
30  make_ParallelTensorFunctor(f, M, N)(std::thread::hardware_concurrency());
31 }
32 
33 } // namespace ck_tile
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
int32_t index_t
Definition: integer.hpp:9
void reference_transpose_elementwise(const HostTensor< ADataType > &a, HostTensor< BDataType > &b)
Definition: reference_transpose.hpp:13
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1249
const std::vector< std::size_t > & get_lengths() const
Definition: host_tensor.hpp:198
Definition: host_tensor.hpp:336
Descriptor mDesc
Definition: host_tensor.hpp:800