/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp Source File
batched_transpose_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
10 #include <string>
11 #include <type_traits>
12 
13 namespace ck_tile {
14 
16 {
17  const void* p_input;
18  void* p_output;
25 };
26 
27 template <typename Pipeline_>
29 {
30 
34 
35  using Type = typename Problem::DataType;
36 
37  static constexpr index_t kBlockSize = Problem::kBlockSize;
38 
40  {
41  const void* p_input;
42  void* p_output;
47  };
48 
51 
52  CK_TILE_HOST static constexpr auto GridSize(const Hargs& host_args)
53  {
54  const size_t grid_size_x =
55  ck_tile::integer_divide_ceil(host_args.height, host_args.dim_block_h);
56  const size_t grid_size_y =
57  ck_tile::integer_divide_ceil(host_args.width, host_args.dim_block_w);
58  const size_t grid_size_z = host_args.batch;
59  return dim3(grid_size_x, grid_size_y, grid_size_z);
60  }
61 
62  CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
63  {
64  Kargs k;
65  k.p_input = h.p_input;
66  k.p_output = h.p_output;
67  k.batch = h.batch;
68  k.height = h.height;
69  k.width = h.width;
70  k.dim_stride = h.dim_stride;
71  return k;
72  }
73 
74  CK_TILE_HOST static constexpr auto BlockSize() { return Problem::kBlockSize; }
75 
76  CK_TILE_DEVICE void operator()(Kargs kargs) const
77  {
78  static constexpr ck_tile::index_t kMPerBlock = Problem::kMPerBlock;
79  static constexpr ck_tile::index_t kNPerBlock = Problem::kNPerBlock;
80  static constexpr bool kPadM = Problem::kPadM;
81  static constexpr bool kPadN = Problem::kPadN;
82  static constexpr ck_tile::index_t VectorSizeInput = Problem::VectorSizeInput;
83  static constexpr ck_tile::index_t VectorStrideInput = 1;
84  static constexpr ck_tile::index_t VectorSizeOutput = Problem::VectorSizeOutput;
85  static constexpr ck_tile::index_t VectorStrideOutput = 1;
86 
87  const auto iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kMPerBlock);
88  const auto iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kNPerBlock);
89  const auto offset = __builtin_amdgcn_readfirstlane(blockIdx.z * kargs.height * kargs.width);
90 
91  const auto x_m_n = [&]() {
92  const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>(
93  static_cast<const Type*>(kargs.p_input) + offset,
94  make_tuple(kargs.height, kargs.width),
95  make_tuple(kargs.width, 1),
98 
99  return pad_tensor_view(x_dram_naive,
102  }();
103 
104  const auto y_n_m = [&]() {
105  const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>(
106  static_cast<Type*>(kargs.p_output) + offset,
107  make_tuple(kargs.width, kargs.height),
108  make_tuple(kargs.height, 1),
111 
112  return pad_tensor_view(y_dram_naive,
115  }();
116 
117  auto x_block_window = make_tile_window(
118  x_m_n,
120  {static_cast<ck_tile::index_t>(iM), static_cast<ck_tile::index_t>(iN)});
121 
122  auto y_block_window = make_tile_window(
123  y_n_m,
125  {static_cast<ck_tile::index_t>(iN), static_cast<ck_tile::index_t>(iM)});
126 
127  Pipeline{}(x_block_window, y_block_window);
128  }
129 };
130 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
Definition: batched_transpose_kernel.hpp:16
index_t height
Definition: batched_transpose_kernel.hpp:20
index_t batch
Definition: batched_transpose_kernel.hpp:19
void * p_output
Definition: batched_transpose_kernel.hpp:18
index_t dim_block_w
Definition: batched_transpose_kernel.hpp:24
index_t dim_stride
Definition: batched_transpose_kernel.hpp:22
const void * p_input
Definition: batched_transpose_kernel.hpp:17
index_t width
Definition: batched_transpose_kernel.hpp:21
index_t dim_block_h
Definition: batched_transpose_kernel.hpp:23
Definition: batched_transpose_kernel.hpp:40
index_t width
Definition: batched_transpose_kernel.hpp:45
index_t height
Definition: batched_transpose_kernel.hpp:44
index_t dim_stride
Definition: batched_transpose_kernel.hpp:46
index_t batch
Definition: batched_transpose_kernel.hpp:43
const void * p_input
Definition: batched_transpose_kernel.hpp:41
void * p_output
Definition: batched_transpose_kernel.hpp:42
Definition: batched_transpose_kernel.hpp:29
static constexpr CK_TILE_HOST auto GridSize(const Hargs &host_args)
Definition: batched_transpose_kernel.hpp:52
remove_cvref_t< typename Pipeline::Problem > Problem
Definition: batched_transpose_kernel.hpp:33
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: batched_transpose_kernel.hpp:76
static constexpr CK_TILE_HOST auto BlockSize()
Definition: batched_transpose_kernel.hpp:74
static constexpr CK_TILE_HOST auto MakeKargs(const Hargs &h)
Definition: batched_transpose_kernel.hpp:62
typename Problem::DataType Type
Definition: batched_transpose_kernel.hpp:35
static CK_TILE_DEVICE index_t counter
Definition: batched_transpose_kernel.hpp:31
remove_cvref_t< Pipeline_ > Pipeline
Definition: batched_transpose_kernel.hpp:32
static constexpr index_t kBlockSize
Definition: batched_transpose_kernel.hpp:37
Definition: integral_constant.hpp:13
Definition: coordinate_transform.hpp:1392
Definition: sequence.hpp:49