include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp Source File

include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp Source File

Composable Kernel: include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp Source File
batched_transpose_problem.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include <string>
8 #include <type_traits>
9 
10 #define VectorLoadSize 16
11 
12 namespace ck_tile {
13 
14 template <typename InputType_,
15  typename BlockTile, // Sequence<...
16  typename WarpTile, // Sequence<...
17  typename ThreadTile, // Sequence<...
18  bool kPadM_ = true,
19  bool kPadN_ = true>
21 {
23 
24  static constexpr index_t kMPerThread = ThreadTile::at(number<0>{});
25  static constexpr index_t kNPerThread = ThreadTile::at(number<1>{});
26 
27  static constexpr index_t kMPerWarp = WarpTile::at(number<0>{});
28  static constexpr index_t kNPerWarp = WarpTile::at(number<1>{});
29 
32 
33  static constexpr index_t kMPerBlock = BlockTile::at(number<0>{});
34  static constexpr index_t kNPerBlock = BlockTile::at(number<1>{});
35 
38 
39  static constexpr index_t kBlockSize =
41 
42  static constexpr bool kPadM = kPadM_;
43  static constexpr bool kPadN = kPadN_;
44 
45  static constexpr index_t AlignmentM = kPadM ? VectorLoadSize / sizeof(InputType) : 1; // TODO
46  static constexpr index_t AlignmentN = kPadN ? VectorLoadSize / sizeof(InputType) : 1;
47 };
48 } // namespace ck_tile
#define VectorLoadSize
Definition: batched_transpose_problem.hpp:10
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
Definition: batched_transpose_problem.hpp:21
remove_cvref_t< InputType_ > InputType
Definition: batched_transpose_problem.hpp:22
static constexpr index_t kNPerBlock
Definition: batched_transpose_problem.hpp:34
static constexpr index_t kMThreadPerWarp
Definition: batched_transpose_problem.hpp:30
static constexpr index_t kNPerThread
Definition: batched_transpose_problem.hpp:25
static constexpr index_t kMPerWarp
Definition: batched_transpose_problem.hpp:27
static constexpr index_t kNPerWarp
Definition: batched_transpose_problem.hpp:28
static constexpr index_t kMPerBlock
Definition: batched_transpose_problem.hpp:33
static constexpr index_t kNThreadPerWarp
Definition: batched_transpose_problem.hpp:31
static constexpr index_t kMPerThread
Definition: batched_transpose_problem.hpp:24
static constexpr bool kPadM
Definition: batched_transpose_problem.hpp:42
static constexpr index_t AlignmentN
Definition: batched_transpose_problem.hpp:46
static constexpr index_t kNWarpPerBlock
Definition: batched_transpose_problem.hpp:37
static constexpr index_t AlignmentM
Definition: batched_transpose_problem.hpp:45
static constexpr bool kPadN
Definition: batched_transpose_problem.hpp:43
static constexpr index_t kMWarpPerBlock
Definition: batched_transpose_problem.hpp:36
static constexpr index_t kBlockSize
Definition: batched_transpose_problem.hpp:39
Definition: integral_constant.hpp:13