/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/tensor_shuffle_utils.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/tensor_shuffle_utils.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/tensor_shuffle_utils.hpp Source File
tensor_shuffle_utils.hpp
Go to the documentation of this file.
#pragma once
#include <algorithm>
#include <cassert>
#include <stdexcept>
// NOTE(review): ck_tile::HostTensor and ck_tile::reference_permute are used below but their
// headers are not included here; presumably the including translation unit provides them — confirm.
3 
4 namespace ck_tile {
5 template <typename T>
6 auto shuffle_aq(const ck_tile::HostTensor<T>* t, int block_aq_k)
7 {
8  if(t->get_lengths().size() != 2)
9  {
10  throw std::runtime_error("Host tensor is not rank 2 tensor.");
11  }
12  int m_ = t->get_lengths()[0];
13  int aqk_ = t->get_lengths()[1];
14  if(aqk_ % block_aq_k != 0)
15  {
16  throw std::runtime_error("shuffle_aq needs a aqk of multiple times of block_aq_k.");
17  }
18  ck_tile::HostTensor<T> t_view({m_, aqk_ / block_aq_k, block_aq_k});
19  std::copy(t->begin(), t->end(), t_view.begin());
20  return ck_tile::reference_permute(t_view, {1, 0, 2});
21 }
22 
23 template <typename GemmConfig, typename T>
25 {
26  assert(t.get_lengths().size() == 2);
27  int n_ = t.get_lengths()[1];
28  int k_ = t.get_lengths()[0];
29  constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
30  ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
31  GemmConfig::N_Warp_Tile,
32  k_ / GemmConfig::K_Warp_Tile,
33  divisor,
34  GemmConfig::K_Warp_Tile / divisor});
35  std::copy(t.begin(), t.end(), t_view.begin());
36  return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
37 }
38 
39 template <typename GemmConfig, typename T>
41 {
42  assert(t.get_lengths().size() == 2);
43 
44  int n_ = t.get_lengths()[1];
45  int bqk_ = t.get_lengths()[0];
46  constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
47 
49  {n_ / GemmConfig::N_Tile, GemmConfig::N_Warp, GemmConfig::N_Warp_Tile, NRepeat, bqk_});
50  std::copy(t.begin(), t.end(), t_view.begin());
51  return ck_tile::reference_permute(t_view, {0, 3, 1, 2, 4});
52 }
53 
54 template <typename GemmConfig, typename T>
56 {
57  assert(t.get_lengths().size() == 2);
58 
59  int n_ = t.get_lengths()[1];
60  int k_ = t.get_lengths()[0];
61  constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
62  constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
63 
64  ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Tile,
65  GemmConfig::N_Warp,
66  GemmConfig::N_Warp_Tile,
67  NRepeat,
68  k_ / GemmConfig::K_Warp_Tile,
69  divisor,
70  GemmConfig::K_Warp_Tile / divisor});
71 
72  std::copy(t.begin(), t.end(), t_view.begin());
73  return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
74 }
75 } // namespace ck_tile
auto copy(InputRange &&range, OutputIterator iter) -> decltype(std::copy(std::begin(std::forward< InputRange >(range)), std::end(std::forward< InputRange >(range)), iter))
Definition: algorithm.hpp:14
Definition: cluster_descriptor.hpp:13
auto shuffle_b(const ck_tile::HostTensor< T > &t)
Definition: tensor_shuffle_utils.hpp:24
auto shuffle_bq_permuteN(const ck_tile::HostTensor< T > &t)
Definition: tensor_shuffle_utils.hpp:40
auto shuffle_b_permuteN(const ck_tile::HostTensor< T > &t)
Definition: tensor_shuffle_utils.hpp:55
auto shuffle_aq(const ck_tile::HostTensor< T > *t, int block_aq_k)
Definition: tensor_shuffle_utils.hpp:6
CK_TILE_HOST void reference_permute(const HostTensor< DataType > &x, HostTensor< DataType > &y, std::vector< index_t > perm)
Definition: reference_permute.hpp:19
Definition: host_tensor.hpp:336
decltype(auto) get_lengths() const
Definition: host_tensor.hpp:390
Data::iterator end()
Definition: host_tensor.hpp:589
Data::iterator begin()
Definition: host_tensor.hpp:587