13 throw std::runtime_error(
"Host tensor is not rank 2 tensor.");
18 if(aqk_ % block_aq_k != 0)
20 throw std::runtime_error(
"shuffle_aq needs a aqk of multiple times of block_aq_k.");
31 const size_t rank = lengths.size();
34 int bqk_dim = (
rank == 5) ? lengths[4] : (
rank == 2) ? lengths[0] : -1;
38 throw std::runtime_error(
"shuffle_bq expects either rank-2 or rank-5 tensor, got rank " +
39 std::to_string(
rank));
42 if(bqk_dim % block_bq_k != 0)
44 throw std::runtime_error(
"shuffle_bq needs bqk dimension to be a multiple of block_bq_k.");
52 static_cast<int>(lengths[1]),
53 static_cast<int>(lengths[2]),
54 static_cast<int>(lengths[3]),
70 template <
typename GemmConfig,
typename T>
79 constexpr
int divisor = 2;
80 constexpr
int kABK1PerLane = 8;
81 constexpr
int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
83 GemmConfig::N_Warp_Tile,
84 k_ / GemmConfig::K_Warp_Tile,
100 assert(is_wave32() ==
false);
101 divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
104 GemmConfig::N_Warp_Tile,
105 k_ / GemmConfig::K_Warp_Tile,
107 GemmConfig::K_Warp_Tile / divisor});
113 template <
typename GemmConfig,
typename T>
120 constexpr
int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
124 GemmConfig::N_Warp_Tile / group_n,
131 template <
typename GemmConfig,
typename T>
137 constexpr
int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
140 constexpr
int divisor = 2;
141 constexpr
int kABK1PerLane = 8;
142 constexpr
int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
145 GemmConfig::N_Warp_Tile,
147 k_ / GemmConfig::K_Warp_Tile,
163 assert(is_wave32() ==
false);
164 divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
168 GemmConfig::N_Warp_Tile,
170 k_ / GemmConfig::K_Warp_Tile,
172 GemmConfig::K_Warp_Tile / divisor});
__host__ constexpr __device__ auto rank([[maybe_unused]] const Layout< Shape, UnrolledDescriptorType > &layout)
Get layout rank (num elements in shape).
Definition: layout_utils.hpp:310
auto copy(InputRange &&range, OutputIterator iter) -> decltype(std::copy(std::begin(std::forward< InputRange >(range)), std::end(std::forward< InputRange >(range)), iter))
Definition: algorithm.hpp:14
Definition: cluster_descriptor.hpp:13
auto shuffle_bq(const ck_tile::HostTensor< T > *t, int block_bq_k)
Definition: tensor_shuffle_utils.hpp:28
bool is_gfx12_supported()
Definition: device_prop.hpp:63
auto shuffle_b(const ck_tile::HostTensor< T > &t)
Definition: tensor_shuffle_utils.hpp:71
auto shuffle_b_permuteN(const ck_tile::HostTensor< T > &t)
Definition: tensor_shuffle_utils.hpp:132
int32_t index_t
Definition: integer.hpp:9
auto shuffle_aq(const ck_tile::HostTensor< T > *t, int block_aq_k)
Definition: tensor_shuffle_utils.hpp:9
bool is_gfx11_supported()
Definition: device_prop.hpp:55
auto bq_permuteN(const ck_tile::HostTensor< T > &t, index_t group_n)
Definition: tensor_shuffle_utils.hpp:114
CK_TILE_HOST void reference_permute(const HostTensor< DataType > &x, HostTensor< DataType > &y, std::vector< index_t > perm)
Definition: reference_permute.hpp:19
Definition: host_tensor.hpp:336
decltype(auto) get_lengths() const
Definition: host_tensor.hpp:390
Data::iterator end()
Definition: host_tensor.hpp:588
Data::iterator begin()
Definition: host_tensor.hpp:586