/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_pipeline.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_pipeline.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_pipeline.hpp Source File
batched_transpose_lds_pipeline.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 namespace ck_tile {
7 
// Block-level pipeline that moves one tile from an input window to an output
// window through a shared-memory (LDS) staging buffer, reading it back with
// load_tile_transpose. All layouts and distributions come from Policy,
// parameterised on Problem.
//
// NOTE(review): the leading integer on each line is the line number of the
// rendered listing, not C++ source. Listing lines 9, 11, 12, 14, 48, 52 and 59
// are absent from this extract; notes at each gap say what the symbol index at
// the bottom of the page suggests belongs there. Confirm against the real header.
8 template <typename Problem_, typename Policy_>
// NOTE(review): listing line 9 (the struct declaration itself) is missing here.
10 {
// NOTE(review): listing lines 11, 12 and 14 are missing. The symbol index records
// `Problem = remove_cvref_t<Problem_>` at :11, `Policy = remove_cvref_t<Policy_>`
// at :12 and `DataType = remove_cvref_t<typename Problem::DataType>` at :14.
13 
15 
// Compile-time sizes forwarded from Problem.
16  static constexpr index_t kBlockSize = Problem::kBlockSize;
17  static constexpr index_t kLeadSizePerBlock = Problem::kLeadSizePerBlock;
18  static constexpr index_t kSecondSizePerBlock = Problem::kSecondSizePerBlock;
19 
// Returns the vector size chosen by Policy for this Problem.
20  static constexpr index_t GetVectorSize() { return Policy::template GetVectorSize<Problem>(); }
21 
// Returns the shared-memory size requested by Policy for this Problem; used
// below as the element count of the `char` staging array.
22  CK_TILE_DEVICE static constexpr index_t GetSmemSize()
23  {
24  return Policy::template GetSmemSize<Problem>();
25  }
26 
// Runs the pipeline once.
//   input_window  - window the source tile is loaded from (read-only).
//   output_window - window the result tile is stored to.
27  template <typename InputTileWindow, typename OutputTileWindow>
28  CK_TILE_DEVICE void operator()(const InputTileWindow& input_window,
29  OutputTileWindow& output_window)
30  {
// Staging buffer in shared memory, sized by the policy.
31  __shared__ char smem[GetSmemSize()];
// Re-wrap the caller's windows with the policy's input/output distributions.
32  auto input_tile_window =
33  make_tile_window(input_window, Policy::template MakeInputDistribution<Problem>());
34  auto output_tile_window =
35  make_tile_window(output_window, Policy::template MakeOutputDistribution<Problem>());
36 
// Two tensor views over the SAME shared buffer (both use p_lds_ptr): one with
// the policy's "store" descriptor for writing, one with its "load" descriptor
// for reading back.
37  DataType* p_lds_ptr = reinterpret_cast<DataType*>(smem);
38  constexpr auto in_lds_block_desc = Policy::template MakeLdsStoreBlockDescriptor<Problem>();
39  auto input_lds_block =
40  make_tensor_view<address_space_enum::lds>(p_lds_ptr, in_lds_block_desc);
41 
42  constexpr auto out_lds_block_desc = Policy::template MakeLdsLoadBlockDescriptor<Problem>();
43  auto output_lds_block =
44  make_tensor_view<address_space_enum::lds>(p_lds_ptr, out_lds_block_desc);
45 
// Windows over the two LDS views, both anchored at origin {0, 0}.
// NOTE(review): listing lines 48 and 52 (an argument between the view and the
// origin in each call) are missing. The symbol index lists make_tuple, so these
// are presumably the window-length tuples - confirm against the real header.
46  auto copy_to_lds_window =
47  make_tile_window(input_lds_block,
49  {0, 0});
50  auto load_from_lds_window =
51  make_tile_window(output_lds_block,
53  {0, 0},
54  Policy::template MakeLdsLoadTileDistribution<Problem>());
55 
// Load the input tile into a distributed tensor.
56  auto x = load_tile(input_tile_window);
57 
// Write the tile into the shared buffer through the "store" view.
58  store_tile(copy_to_lds_window, x);
// NOTE(review): listing line 59 is missing between the LDS store and the LDS
// read-back. The symbol index lists block_sync_lds(), so a block-wide LDS
// synchronisation presumably sits here - confirm against the real header.
60 
// Read the tile back through the "load" view with a transposing load.
61  auto y = load_tile_transpose(load_from_lds_window);
62 
// Write the transposed tile to the output window.
63  store_tile(output_tile_window, y);
64  }
65 };
66 
67 } // namespace ck_tile
CK_TILE_DEVICE void block_sync_lds()
Definition: arch.hpp:190
#define CK_TILE_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
CK_TILE_DEVICE auto load_tile_transpose(const tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window)
transpose loads tile from a tensor and returns the resulting tensor with a new (transposed) tile dist...
Definition: load_tile_transpose.hpp:403
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:23
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:22
Definition: batched_transpose_lds_pipeline.hpp:10
CK_TILE_DEVICE void operator()(const InputTileWindow &input_window, OutputTileWindow &output_window)
Definition: batched_transpose_lds_pipeline.hpp:28
static constexpr index_t kLeadSizePerBlock
Definition: batched_transpose_lds_pipeline.hpp:17
remove_cvref_t< typename Problem::DataType > DataType
Definition: batched_transpose_lds_pipeline.hpp:14
static constexpr index_t GetVectorSize()
Definition: batched_transpose_lds_pipeline.hpp:20
remove_cvref_t< Problem_ > Problem
Definition: batched_transpose_lds_pipeline.hpp:11
static constexpr index_t kBlockSize
Definition: batched_transpose_lds_pipeline.hpp:16
remove_cvref_t< Policy_ > Policy
Definition: batched_transpose_lds_pipeline.hpp:12
static constexpr CK_TILE_DEVICE index_t GetSmemSize()
Definition: batched_transpose_lds_pipeline.hpp:22
static constexpr index_t kSecondSizePerBlock
Definition: batched_transpose_lds_pipeline.hpp:18
Definition: integral_constant.hpp:13