include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp Source File

include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp Source File

Composable Kernel: include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp Source File
batched_transpose_pipeline.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 #include <string>
9 #include <type_traits>
10 
11 namespace ck_tile {
12 
13 template <typename Problem_, typename Policy_ = BatchedTransposePolicy>
15 {
16  // TODO: this kernel only support warp per row
20  static constexpr ck_tile::index_t kMPerBlock = Problem::kMPerBlock;
21  static constexpr ck_tile::index_t kNPerBlock = Problem::kNPerBlock;
22  static constexpr index_t AlignmentM = Problem::AlignmentM;
23  static constexpr index_t AlignmentN = Problem::AlignmentN;
24  static constexpr bool kPadM = Problem::kPadM;
25  static constexpr bool kPadN = Problem::kPadN;
26 
27  template <typename InputWindow, typename OutputWindow>
28  CK_TILE_DEVICE auto operator()(const InputWindow& input_window, OutputWindow& out_window)
29  {
30  auto inp_win =
31  make_tile_window(input_window, Policy::template MakeInputDistribution<Problem>());
32  auto out_win =
33  make_tile_window(out_window, Policy::template MakeOutputDistribution<Problem>());
34 
35  auto x = load_tile(inp_win); // x->thread input_win->block
36 
37  auto y = make_static_distributed_tensor<InputType>(
38  Policy::template MakeOutputDistribution<Problem>());
39 
40  constexpr auto span_2d_x = decltype(x)::get_distributed_spans();
41 
42  sweep_tile_span(span_2d_x[number<0>{}], [&](auto idx0) {
43  sweep_tile_span(span_2d_x[number<1>{}], [&](auto idx1) {
44  constexpr auto i_j_idx = make_tuple(idx1, idx0);
45  y(i_j_idx) = x(i_j_idx);
46  });
47  });
48 
49  store_tile(out_win, y);
50  }
51 };
52 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:27
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition: sweep_tile.hpp:20
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:72
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:23
Definition: batched_transpose_pipeline.hpp:15
remove_cvref_t< Problem_ > Problem
Definition: batched_transpose_pipeline.hpp:17
static constexpr ck_tile::index_t kMPerBlock
Definition: batched_transpose_pipeline.hpp:20
static constexpr index_t AlignmentN
Definition: batched_transpose_pipeline.hpp:23
static constexpr ck_tile::index_t kNPerBlock
Definition: batched_transpose_pipeline.hpp:21
CK_TILE_DEVICE auto operator()(const InputWindow &input_window, OutputWindow &out_window)
Definition: batched_transpose_pipeline.hpp:28
static constexpr index_t AlignmentM
Definition: batched_transpose_pipeline.hpp:22
static constexpr bool kPadM
Definition: batched_transpose_pipeline.hpp:24
ck_tile::remove_cvref_t< typename Problem::InputType > InputType
Definition: batched_transpose_pipeline.hpp:19
remove_cvref_t< Policy_ > Policy
Definition: batched_transpose_pipeline.hpp:18
static constexpr bool kPadN
Definition: batched_transpose_pipeline.hpp:25
Definition: integral_constant.hpp:13