/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/permute/kernel/generic_permute_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/permute/kernel/generic_permute_kernel.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/permute/kernel/generic_permute_kernel.hpp Source File
generic_permute_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
8 // #include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp"
9 
10 namespace ck_tile {
11 
12 /* Independent host-side arguments (not templated); MakeKargs converts them into the kernel's Kargs.
13  */
15 {
16  static constexpr index_t kMaxRanks = 8; // TODO: hardcoded
17 
18  const void* p_src;
19  void* p_dst;
21  index_t shape[kMaxRanks]; // input shape
22  index_t perm[kMaxRanks]; // permute index
23 };
24 
25 /*
26 Emulates torch.permute; for example, the equivalent of:
27 x_ = x_.view(x.shape[0],
28  x.shape[1]//16, 16,
29  x.shape[2]//32, 4, 8)
30 x_ = x_.permute(0,1,3,4,2,5)
31 x_ = x_.contiguous()
32 x_ = x_.view(x.shape[0], x.shape[1], x.shape[2])
33 
34 This kernel is not intended to be high-performance (only adequate); it provides functional support
35 for permutations of up to kMaxRanks dimensions using a single kernel.
36 
37 */
38 template <typename Problem_>
40 {
42 
44  static constexpr index_t kBlockSize = Problem::kBlockSize;
45  static constexpr index_t kMaxRanks = Problem::kMaxRanks;
46  static constexpr bool KeepLastDim = Problem::KeepLastDim;
47 
48  struct __attribute__((packed)) Kargs
49  {
50  const void* p_src;
51  void* p_dst;
52  // index_t rank;
54  index_t perm_length[kMaxRanks]; // tensor length after permutation
55  index_t perm_stride[kMaxRanks]; // tensor stride after permutation
56  };
57 
59  {
60  index_t n = 1;
61  for(auto i = 0; i < h.rank; i++)
62  {
63  n *= h.shape[i];
64  }
65  return n;
66  }
67 
69  {
70  Kargs a;
71  a.p_src = h.p_src;
72  a.p_dst = h.p_dst;
73 
74  // assert rank <= kMaxRanks
75  index_t i = 0;
76 
77  index_t perm[kMaxRanks];
78  index_t x_shape[kMaxRanks];
79  index_t x_stride[kMaxRanks];
80  // index_t perm_length[kMaxRanks];
81 
82  for(; i < h.rank; i++)
83  {
84  x_shape[i] = h.shape[i];
85  perm[i] = h.perm[i];
86  }
87  for(; i < kMaxRanks; i++)
88  {
89  x_shape[i] = 1;
90  perm[i] = i; // will index to len = 1
91  }
92 
93  index_t stride = 1;
94  for(index_t j = kMaxRanks - 1; j >= 0; j--)
95  {
96  x_stride[j] = stride;
97  stride *= x_shape[j];
98  }
99 
100  for(index_t j = 0; j < kMaxRanks; j++)
101  {
102  a.perm_length[j] = x_shape[perm[j]];
103  a.perm_stride[j] = x_stride[perm[j]];
104  }
105 
106  a.num_elements = TotalElements(h);
107  return a;
108  }
109 
111  {
112  auto total = TotalElements(h);
113  auto grids = dim3((total + BlockSize() - 1) / BlockSize());
114  // printf("### total:%d, grids:%dx%dx%d\n", total, );
115  return grids;
116  }
117 
118  CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::kBlockSize; }
119 
120  CK_TILE_DEVICE void operator()(Kargs kargs) const
121  {
122  index_t id = blockIdx.x * BlockSize() + threadIdx.x;
123 
124  if(id >= kargs.num_elements)
125  return;
126 
127  const auto perm_length =
128  generate_tuple([&](auto I) { return kargs.perm_length[I]; }, number<kMaxRanks>{});
129  const auto perm_stride =
130  generate_tuple([&](auto I) { return kargs.perm_stride[I]; }, number<kMaxRanks>{});
131 
132  const DataType* p_src = reinterpret_cast<const DataType*>(kargs.p_src);
133  DataType* p_dst = reinterpret_cast<DataType*>(kargs.p_dst);
134 
135  const auto src_view_0 = make_naive_tensor_view<address_space_enum::global>(
136  p_src, perm_length, perm_stride, number<1>{}, number<1>{});
137 
138  const auto src_view = transform_tensor_view(
139  src_view_0,
140  make_tuple(make_merge_transform(perm_length)),
143 
144  auto dst_view_0 = make_naive_tensor_view_packed<address_space_enum::global>(
145  p_dst, perm_length, number<1>{});
146 
147  auto dst_view = transform_tensor_view(
148  dst_view_0,
149  make_tuple(make_merge_transform(perm_length)),
152 
153  // TODO: hard code to vector 1
154  using vector_t = thread_buffer<DataType, 1>;
155 
156  const auto src_coord =
157  make_tensor_coordinate(src_view.get_tensor_descriptor(), array<index_t, 1>{id});
158  const auto dst_coord =
159  make_tensor_coordinate(dst_view.get_tensor_descriptor(), array<index_t, 1>{id});
160 
161  // printf("src id:%d, os:%d\n", id, src_coord.get_offset());
162  // printf("dst id:%d, os:%d\n", id, dst_coord.get_offset());
163 
164  const vector_t x = src_view.template get_vectorized_elements<vector_t>(src_coord, 0);
165  dst_view.template set_vectorized_elements<vector_t>(dst_coord, 0, x);
166  }
167 };
168 
169 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_view.hpp:511
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1615
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_tensor_coordinate(const TensorDesc &tensor_desc, const TopIndex &idx_top)
Definition: tensor_coordinate.hpp:60
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1249
Definition: generic_permute_kernel.hpp:49
const void * p_src
Definition: generic_permute_kernel.hpp:50
index_t perm_stride[kMaxRanks]
Definition: generic_permute_kernel.hpp:55
index_t perm_length[kMaxRanks]
Definition: generic_permute_kernel.hpp:54
index_t num_elements
Definition: generic_permute_kernel.hpp:53
void * p_dst
Definition: generic_permute_kernel.hpp:51
Definition: generic_permute_kernel.hpp:15
const void * p_src
Definition: generic_permute_kernel.hpp:18
index_t perm[kMaxRanks]
Definition: generic_permute_kernel.hpp:22
void * p_dst
Definition: generic_permute_kernel.hpp:19
index_t shape[kMaxRanks]
Definition: generic_permute_kernel.hpp:21
static constexpr index_t kMaxRanks
Definition: generic_permute_kernel.hpp:16
index_t rank
Definition: generic_permute_kernel.hpp:20
Definition: generic_permute_kernel.hpp:40
static constexpr CK_TILE_HOST auto GridSize(GenericPermuteHostArgs h)
Definition: generic_permute_kernel.hpp:110
static constexpr index_t kBlockSize
Definition: generic_permute_kernel.hpp:44
static constexpr index_t kMaxRanks
Definition: generic_permute_kernel.hpp:45
remove_cvref_t< typename Problem::DataType > DataType
Definition: generic_permute_kernel.hpp:43
static constexpr CK_TILE_HOST index_t TotalElements(const GenericPermuteHostArgs &h)
Definition: generic_permute_kernel.hpp:58
static constexpr CK_TILE_HOST_DEVICE auto BlockSize()
Definition: generic_permute_kernel.hpp:118
static constexpr CK_TILE_HOST Kargs MakeKargs(const GenericPermuteHostArgs &h)
Definition: generic_permute_kernel.hpp:68
static constexpr bool KeepLastDim
Definition: generic_permute_kernel.hpp:46
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: generic_permute_kernel.hpp:120
ck_tile::remove_cvref_t< Problem_ > Problem
Definition: generic_permute_kernel.hpp:41
typename std::conditional< kHasContent, type0, type1 >::type type
Definition: sequence.hpp:299
A fixed-size array container similar to std::array with additional utilities.
Definition: array.hpp:43
Definition: integral_constant.hpp:13
Definition: sequence.hpp:49
Definition: debug.hpp:67