/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_policy.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_policy.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_policy.hpp Source File
batched_transpose_lds_policy.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 
9 namespace ck_tile {
10 
12 {
// GetSmemSize<Problem>(): compile-time index_t query.
// The visible operands are sizeof(Problem::DataType) multiplied by the
// element-space size of the descriptor returned by
// MakeLdsStoreBlockDescriptor<Problem>(), followed by the literal 16 as the
// last argument of an enclosing two-argument call.
// NOTE(review): listing line 16 is absent from this page extract, so the
// `return` and the callee name are not visible. The cross-reference list at the
// end of the page names integer_least_multiple(x, y); presumably that is the
// call closed by "16);" (i.e. a round-up to a multiple of 16) - confirm against
// the real header before relying on it.
13  template <typename Problem>
14  CK_TILE_DEVICE static constexpr index_t GetSmemSize()
15  {
// [listing line 16 missing here: opening of the call that the next three lines complete]
17  sizeof(typename Problem::DataType) *
18  MakeLdsStoreBlockDescriptor<Problem>().get_element_space_size(),
19  16);
20  }
21 
// MakeOutputDistribution<Problem>(): returns a static tile distribution derived
// from the one produced by MakeLdsLoadTileDistribution<Problem>().
// Steps, as written below:
//   1. obtain the load-side distribution object (only its type is used);
//   2. feed its DstrEncode member type and Problem::DataType into
//      OutputTileDistributionTraits and take the TransposedDstrEncode member type;
//   3. default-construct that encoding and pass it to
//      make_static_tile_distribution.
// Everything is constexpr; nothing here depends on runtime values.
22  template <typename Problem>
23  CK_TILE_DEVICE static constexpr auto MakeOutputDistribution()
24  {
// Only decltype(input_dstr) is consumed below; the value itself is not read.
25  constexpr auto input_dstr = MakeLdsLoadTileDistribution<Problem>();
26 
27  using OutTileDstrEncode =
28  typename OutputTileDistributionTraits<typename decltype(input_dstr)::DstrEncode,
29  typename Problem::DataType>::TransposedDstrEncode;
30  constexpr auto block_dstr = make_static_tile_distribution(OutTileDstrEncode{});
31 
32  return block_dstr;
33  }
34 
// MakeLdsStoreBlockDescriptor<Problem>() - name taken from the page's
// cross-reference list, which places this definition at listing line 36; the
// signature line itself is absent from this extract.
// Visible behaviour: reads kLeadSizePerBlock, kSecondSizePerBlock and
// LDSVectorSize from Problem, builds `lds_block_desc_0` with
// make_naive_tensor_descriptor, derives `lds_block_desc` from it with
// transform_tensor_descriptor, and returns `lds_block_desc`.
// NOTE(review): listing lines 36, 43, 45-47 and 52-56 are missing, so the full
// lengths/strides tuples and the transform/sequence arguments cannot be read
// from this page. Do not infer the layout from the surviving fragments.
35  template <typename Problem>
// [listing line 36 missing here: function signature]
37  {
38  constexpr index_t kLeadDimPerBlock = Problem::kLeadSizePerBlock;
// kSecondDimPerBlock has no surviving use below; presumably it appears in the missing lines.
39  constexpr index_t kSecondDimPerBlock = Problem::kSecondSizePerBlock;
40  constexpr index_t kVectorSize = Problem::LDSVectorSize;
41 
42  constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
// [listing line 43 missing here]
44  number<kLeadDimPerBlock / kVectorSize>{},
// [listing lines 45-47 missing here]
48  number<1>{});
49 
50  constexpr auto lds_block_desc = transform_tensor_descriptor(
51  lds_block_desc_0,
// [listing lines 52-56 missing here: remaining transform_tensor_descriptor arguments]
57 
58  return lds_block_desc;
59  }
60 
// MakeLdsLoadBlockDescriptor<Problem>() - name taken from the page's
// cross-reference list, which places this definition at listing line 62; the
// signature line itself is absent from this extract.
// Visible behaviour: reads kLeadSizePerBlock, kSecondSizePerBlock and
// LDSVectorSize from Problem, builds `lds_block_desc_0` with
// make_naive_tensor_descriptor, derives `lds_block_desc` from it with
// transform_tensor_descriptor, and returns `lds_block_desc`.
// NOTE(review): listing lines 62, 69, 71-73 and 78-82 are missing. The lines
// that survive are token-for-token the same as the surviving lines of the
// function defined at listing line 36; whether the two descriptors differ in
// the missing arguments cannot be determined from this page.
61  template <typename Problem>
// [listing line 62 missing here: function signature]
63  {
64  constexpr index_t kLeadDimPerBlock = Problem::kLeadSizePerBlock;
// kSecondDimPerBlock has no surviving use below; presumably it appears in the missing lines.
65  constexpr index_t kSecondDimPerBlock = Problem::kSecondSizePerBlock;
66  constexpr index_t kVectorSize = Problem::LDSVectorSize;
67 
68  constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
// [listing line 69 missing here]
70  number<kLeadDimPerBlock / kVectorSize>{},
// [listing lines 71-73 missing here]
74  number<1>{});
75 
76  constexpr auto lds_block_desc = transform_tensor_descriptor(
77  lds_block_desc_0,
// [listing lines 78-82 missing here: remaining transform_tensor_descriptor arguments]
83 
84  return lds_block_desc;
85  }
86 
// MakeLdsLoadTileDistribution<Problem>() - name taken from the page's
// cross-reference list, which places this definition at listing line 88; that
// signature line is the only line of this function absent from the extract.
// Builds and returns a static tile distribution in three steps:
//   1. make_transposed_distr_encode<...>() produces a base encoding from
//      DataType, a lane-group size of 16 and repetition counts read from Problem;
//   2. InputTileDistributionEncoding<...>() combines that encoding's type with
//      per-warp iteration counts (both fixed to 1 here) and the warp counts
//      read from Problem;
//   3. make_static_tile_distribution turns the result into the returned object.
// NOTE(review): the meaning of each make_transposed_distr_encode template
// argument is not visible on this page - check amd_transpose_load_encoding.hpp.
87  template <typename Problem>
// [listing line 88 missing here: function signature]
89  {
90  using DataType = typename Problem::DataType;
91 
92  // Calculate block-level dimensions
93  constexpr index_t kLeadIterPerWarp = 1;
94  constexpr index_t kSecondIterPerWarp = 1;
95  constexpr index_t kLeadNumWarps = Problem::kLeadNumWarps;
96  constexpr index_t kSecondNumWarps = Problem::kSecondNumWarps;
97 
98  // Calculate repetitions of base pattern
99  constexpr index_t kLeadRepetitions = Problem::kQuadNumPerLeadDim;
100  constexpr index_t kSecondRepetitions = Problem::kQuadNumPerSecondDim;
101  constexpr index_t kSecondDimIterations = Problem::kIterationsInSecondDim;
// Integer division (index_t); any remainder of kSecondRepetitions / kSecondDimIterations is dropped.
102  constexpr index_t kSecondDimStrSub = kSecondRepetitions / kSecondDimIterations;
103 
104  constexpr index_t kLaneGroupSize = 16;
// Argument order as passed: DataType, lane-group size, kSecondDimStrSub,
// kSecondDimIterations, kLeadRepetitions, and a literal 1.
105  constexpr auto xdllevel_dstr_encoding = make_transposed_distr_encode<DataType,
106  kLaneGroupSize,
107  kSecondDimStrSub,
108  kSecondDimIterations,
109  kLeadRepetitions,
110  1>();
111 
// Only the type of xdllevel_dstr_encoding is forwarded (via decltype).
112  constexpr auto input_tile_encode =
113  InputTileDistributionEncoding<decltype(xdllevel_dstr_encoding),
114  kLeadIterPerWarp,
115  kSecondIterPerWarp,
116  kLeadNumWarps,
117  kSecondNumWarps>();
118  constexpr auto block_dstr = make_static_tile_distribution(input_tile_encode);
119  return block_dstr;
120  }
121 };
122 
123 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_descriptor.hpp:268
constexpr CK_TILE_HOST_DEVICE auto integer_least_multiple(X x, Y y)
Definition: math.hpp:155
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1615
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1558
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition: tensor_descriptor.hpp:197
constexpr CK_TILE_HOST_DEVICE auto InputTileDistributionEncoding()
Definition: load_tile_transpose.hpp:351
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_DEVICE auto make_transposed_distr_encode()
Definition: amd_transpose_load_encoding.hpp:82
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:480
Definition: batched_transpose_common_policy.hpp:11
Definition: batched_transpose_lds_policy.hpp:12
static constexpr CK_TILE_DEVICE auto MakeLdsStoreBlockDescriptor()
Definition: batched_transpose_lds_policy.hpp:36
static constexpr CK_TILE_DEVICE auto MakeLdsLoadTileDistribution()
Definition: batched_transpose_lds_policy.hpp:88
static constexpr CK_TILE_DEVICE auto MakeOutputDistribution()
Definition: batched_transpose_lds_policy.hpp:23
static constexpr CK_TILE_DEVICE index_t GetSmemSize()
Definition: batched_transpose_lds_policy.hpp:14
static constexpr CK_TILE_DEVICE auto MakeLdsLoadBlockDescriptor()
Definition: batched_transpose_lds_policy.hpp:62
Definition: load_tile_transpose.hpp:207
Definition: integral_constant.hpp:13
Definition: sequence.hpp:49