/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/tensor_shuffle_utils.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/tensor_shuffle_utils.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/tensor_shuffle_utils.hpp Source File
tensor_shuffle_utils.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 #include <stdexcept>
6 
7 namespace ck_tile {
8 template <typename T>
9 auto shuffle_aq(const ck_tile::HostTensor<T>* t, int block_aq_k)
10 {
11  if(t->get_lengths().size() != 2)
12  {
13  throw std::runtime_error("Host tensor is not rank 2 tensor.");
14  }
15  int m_ = t->get_lengths()[0];
16  int aqk_ = t->get_lengths()[1];
17 
18  if(aqk_ % block_aq_k != 0)
19  {
20  throw std::runtime_error("shuffle_aq needs a aqk of multiple times of block_aq_k.");
21  }
22  ck_tile::HostTensor<T> t_view({m_, aqk_ / block_aq_k, block_aq_k});
23  std::copy(t->begin(), t->end(), t_view.begin());
24  return ck_tile::reference_permute(t_view, {1, 0, 2});
25 }
26 
27 template <typename T>
28 auto shuffle_bq(const ck_tile::HostTensor<T>* t, int block_bq_k)
29 {
30  const auto& lengths = t->get_lengths();
31  const size_t rank = lengths.size();
32 
33  // Validate block_bq_k divisibility based on rank
34  int bqk_dim = (rank == 5) ? lengths[4] : (rank == 2) ? lengths[0] : -1;
35 
36  if(bqk_dim < 0)
37  {
38  throw std::runtime_error("shuffle_bq expects either rank-2 or rank-5 tensor, got rank " +
39  std::to_string(rank));
40  }
41 
42  if(bqk_dim % block_bq_k != 0)
43  {
44  throw std::runtime_error("shuffle_bq needs bqk dimension to be a multiple of block_bq_k.");
45  }
46 
47  // For TilePermuteN
48  if(rank == 5)
49  {
50  // Handle 5D tensor: [n, nrepeat, nwarp, n_warp_tile, bqk]
51  ck_tile::HostTensor<T> t_view({static_cast<int>(lengths[0]),
52  static_cast<int>(lengths[1]),
53  static_cast<int>(lengths[2]),
54  static_cast<int>(lengths[3]),
55  bqk_dim / block_bq_k,
56  block_bq_k});
57  std::copy(t->begin(), t->end(), t_view.begin());
58  return ck_tile::reference_permute(t_view, {4, 0, 1, 2, 3, 5});
59  }
60  else // rank == 2
61  {
62  // Handle 2D tensor: [bqk, n]
63  int n_ = lengths[1];
64  ck_tile::HostTensor<T> t_view({n_, bqk_dim / block_bq_k, block_bq_k});
65  std::copy(t->begin(), t->end(), t_view.begin());
66  return ck_tile::reference_permute(t_view, {1, 0, 2});
67  }
68 }
69 
// Copies a rank-2 host tensor into a higher-rank tensor whose dimensions are derived from
// GemmConfig::N_Warp_Tile and GemmConfig::K_Warp_Tile, then returns the result of
// ck_tile::reference_permute on it. Lengths are read as [k, n] (index 0 -> k_, index 1 -> n_).
// The rank-2 requirement is only enforced through assert, i.e. not in NDEBUG builds.
// NOTE(review): listing line 71 (the function signature) is missing from this extract. The
// symbol index of this page gives `auto shuffle_b(const ck_tile::HostTensor<T>& t)` as
// defined at tensor_shuffle_utils.hpp:71 -- confirm against the real header.
// NOTE(review): the view lengths use integer division; this presumably assumes n_ and k_
// are multiples of the respective warp-tile sizes, otherwise the view would hold fewer
// elements than std::copy writes -- TODO confirm with callers.
70 template <typename GemmConfig, typename T>
72 {
73  assert(t.get_lengths().size() == 2);
74  int n_ = t.get_lengths()[1];
75  int k_ = t.get_lengths()[0];
76 
// NOTE(review): listing line 77 (the condition selecting this branch) is missing from this
// extract. The page's symbol index references is_gfx12_supported() and is_gfx11_supported(),
// so it is presumably one of those checks -- confirm against the real header.
78  {
// K_Warp_Tile is split as kABK0PerLane x divisor(2) x kABK1PerLane(8).
79  constexpr int divisor = 2;
80  constexpr int kABK1PerLane = 8;
81  constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
82  ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
83  GemmConfig::N_Warp_Tile,
84  k_ / GemmConfig::K_Warp_Tile,
85  kABK0PerLane,
86  divisor,
87  kABK1PerLane});
// Element-for-element copy in iteration order; only the lengths of the destination differ.
88  std::copy(t.begin(), t.end(), t_view.begin());
89  return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
90  }
91  else
92  {
93  int divisor = 1;
// NOTE(review): listing line 94 (the condition of this inner if) is missing from this
// extract; presumably the other gfx-support check named in the symbol index -- confirm.
95  {
96  divisor = 1;
97  }
98  else
99  {
// This branch asserts that is_wave32() is false; divisor is 2 when N_Warp_Tile == 32, else 4.
100  assert(is_wave32() == false);
101  divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
102  }
// K_Warp_Tile is split as divisor x (K_Warp_Tile / divisor).
103  ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
104  GemmConfig::N_Warp_Tile,
105  k_ / GemmConfig::K_Warp_Tile,
106  divisor,
107  GemmConfig::K_Warp_Tile / divisor});
108  std::copy(t.begin(), t.end(), t_view.begin());
109  return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
110  }
111 }
112 
// Copies a rank-2 host tensor into a rank-5 tensor whose lengths are derived from
// GemmConfig::N_Tile, N_Warp, N_Warp_Tile and group_n, then returns the result of
// ck_tile::reference_permute on it with order {0, 3, 1, 2, 4}. Lengths are read as
// [bqk, n] (index 0 -> bqk_, index 1 -> n_). The rank-2 requirement is only enforced
// through assert, i.e. not in NDEBUG builds.
// NOTE(review): listing line 114 (the function signature) is missing from this extract. The
// symbol index of this page gives
// `auto bq_permuteN(const ck_tile::HostTensor<T>& t, index_t group_n)` as defined at
// tensor_shuffle_utils.hpp:114 -- confirm against the real header.
// NOTE(review): group_n is used as a divisor without validation; presumably callers pass a
// positive value that divides N_Tile and N_Warp_Tile -- TODO confirm.
113 template <typename GemmConfig, typename T>
115 {
116  assert(t.get_lengths().size() == 2);
117 
118  int n_ = t.get_lengths()[1];
119  int bqk_ = t.get_lengths()[0];
// Number of warp tiles each warp covers along N within one N_Tile.
120  constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
121 
122  ck_tile::HostTensor<T> t_view({n_ / (GemmConfig::N_Tile / group_n),
123  GemmConfig::N_Warp,
124  GemmConfig::N_Warp_Tile / group_n,
125  NRepeat,
126  bqk_});
// Element-for-element copy in iteration order; only the lengths of the destination differ.
127  std::copy(t.begin(), t.end(), t_view.begin());
128  return ck_tile::reference_permute(t_view, {0, 3, 1, 2, 4});
129 }
130 
// Copies a rank-2 host tensor into a higher-rank tensor whose dimensions are derived from
// GemmConfig::N_Tile, N_Warp, N_Warp_Tile and K_Warp_Tile, then returns the result of
// ck_tile::reference_permute on it. Lengths are read as [k, n] (index 0 -> k_, index 1 -> n_).
// The rank-2 requirement is only enforced through assert, i.e. not in NDEBUG builds.
// NOTE(review): listing line 132 (the function signature) is missing from this extract. The
// symbol index of this page gives `auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)`
// as defined at tensor_shuffle_utils.hpp:132 -- confirm against the real header.
// NOTE(review): the view lengths use integer division; this presumably assumes n_ is a
// multiple of N_Tile and k_ a multiple of K_Warp_Tile -- TODO confirm with callers.
131 template <typename GemmConfig, typename T>
133 {
134  assert(t.get_lengths().size() == 2);
135  int n_ = t.get_lengths()[1];
136  int k_ = t.get_lengths()[0];
// Number of warp tiles each warp covers along N within one N_Tile.
137  constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
// NOTE(review): listing line 138 (the condition selecting this branch) is missing from this
// extract. The page's symbol index references is_gfx12_supported() and is_gfx11_supported(),
// so it is presumably one of those checks -- confirm against the real header.
139  {
// K_Warp_Tile is split as kABK0PerLane x divisor(2) x kABK1PerLane(8).
140  constexpr int divisor = 2;
141  constexpr int kABK1PerLane = 8;
142  constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
143  ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Tile,
144  GemmConfig::N_Warp,
145  GemmConfig::N_Warp_Tile,
146  NRepeat,
147  k_ / GemmConfig::K_Warp_Tile,
148  kABK0PerLane,
149  divisor,
150  kABK1PerLane});
// Element-for-element copy in iteration order; only the lengths of the destination differ.
151  std::copy(t.begin(), t.end(), t_view.begin());
152  return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 6, 5, 2, 7});
153  }
154  else
155  {
156  int divisor = 1;
// NOTE(review): listing line 157 (the condition of this inner if) is missing from this
// extract; presumably the other gfx-support check named in the symbol index -- confirm.
158  {
159  divisor = 1;
160  }
161  else
162  {
// This branch asserts that is_wave32() is false; divisor is 2 when N_Warp_Tile == 32, else 4.
163  assert(is_wave32() == false);
164  divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
165  }
// K_Warp_Tile is split as divisor x (K_Warp_Tile / divisor).
166  ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Tile,
167  GemmConfig::N_Warp,
168  GemmConfig::N_Warp_Tile,
169  NRepeat,
170  k_ / GemmConfig::K_Warp_Tile,
171  divisor,
172  GemmConfig::K_Warp_Tile / divisor});
173  std::copy(t.begin(), t.end(), t_view.begin());
174  return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
175  }
176 }
177 } // namespace ck_tile
__host__ constexpr __device__ auto rank([[maybe_unused]] const Layout< Shape, UnrolledDescriptorType > &layout)
Get layout rank (num elements in shape).
Definition: layout_utils.hpp:310
auto copy(InputRange &&range, OutputIterator iter) -> decltype(std::copy(std::begin(std::forward< InputRange >(range)), std::end(std::forward< InputRange >(range)), iter))
Definition: algorithm.hpp:14
Definition: cluster_descriptor.hpp:13
auto shuffle_bq(const ck_tile::HostTensor< T > *t, int block_bq_k)
Definition: tensor_shuffle_utils.hpp:28
bool is_gfx12_supported()
Definition: device_prop.hpp:63
auto shuffle_b(const ck_tile::HostTensor< T > &t)
Definition: tensor_shuffle_utils.hpp:71
auto shuffle_b_permuteN(const ck_tile::HostTensor< T > &t)
Definition: tensor_shuffle_utils.hpp:132
int32_t index_t
Definition: integer.hpp:9
auto shuffle_aq(const ck_tile::HostTensor< T > *t, int block_aq_k)
Definition: tensor_shuffle_utils.hpp:9
bool is_gfx11_supported()
Definition: device_prop.hpp:55
auto bq_permuteN(const ck_tile::HostTensor< T > &t, index_t group_n)
Definition: tensor_shuffle_utils.hpp:114
CK_TILE_HOST void reference_permute(const HostTensor< DataType > &x, HostTensor< DataType > &y, std::vector< index_t > perm)
Definition: reference_permute.hpp:19
Definition: host_tensor.hpp:336
decltype(auto) get_lengths() const
Definition: host_tensor.hpp:390
Data::iterator end()
Definition: host_tensor.hpp:588
Data::iterator begin()
Definition: host_tensor.hpp:586