include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp Source File

include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp Source File#

Composable Kernel: include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp Source File
gemm_pipeline_problem.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 
9 namespace ck_tile {
10 
11 template <typename ADataType_,
12  typename BDataType_,
13  typename CDataType_,
14  typename BlockGemmShape_,
15  typename Traits_>
17 {
19 
23 
25 
29 
// Forwarded unchanged from the Traits type.
30  static constexpr bool TransposeC = Traits::TransposeC;
31 
// Threads per block: the block shape's warp count multiplied by the warp size.
32  static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
33 
// Padding flags for the M, N and K dimensions, forwarded from Traits.
// The VectorSize* members below fall back to 1 when the relevant flag is set.
34  static constexpr bool kPadM = Traits::kPadM;
35  static constexpr bool kPadN = Traits::kPadN;
36  static constexpr bool kPadK = Traits::kPadK;
37 
// This problem type always uses the Default scheduler.
38  static constexpr auto Scheduler = GemmPipelineScheduler::Default;
39 
// Taken from Traits::_VectorSize. The GetAlignment* helpers divide it by
// sizeof(element type), so it is presumably a size in bytes -- TODO confirm.
40  static constexpr index_t VectorLoadSize = Traits::_VectorSize;
41  CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
42  {
43  if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
44  {
45  constexpr index_t pixels_per_thread =
46  BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize;
47  return pixels_per_thread < VectorLoadSize / sizeof(ADataType)
48  ? pixels_per_thread
49  : VectorLoadSize / sizeof(ADataType);
50  }
51  else
52  {
53  return VectorLoadSize / sizeof(ADataType);
54  }
55  }
56 
57  CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentB()
58  {
59  if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
60  {
61  constexpr index_t pixels_per_thread =
62  BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize;
63  return pixels_per_thread < VectorLoadSize / sizeof(BDataType)
64  ? pixels_per_thread
65  : VectorLoadSize / sizeof(BDataType);
66  }
67  else
68  {
69  return VectorLoadSize / sizeof(BDataType);
70  }
71  }
72 
73  CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentC()
74  {
75  if constexpr(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
76  {
77  constexpr index_t N1 = kBlockSize / get_warp_size();
78  constexpr index_t N2 = std::min(BlockGemmShape::kN / N1, get_warp_size());
79  constexpr index_t M0 = get_warp_size() / N2;
80  constexpr index_t M1 = BlockGemmShape::kM / M0;
81 
82  return std::min(M1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
83  }
84  else
85  {
86  constexpr index_t M1 = kBlockSize / get_warp_size();
87  constexpr index_t M2 = std::min(BlockGemmShape::kM / M1, get_warp_size());
88  constexpr index_t N0 = get_warp_size() / M2;
89  constexpr index_t N1 = BlockGemmShape::kN / N0;
90 
91  return std::min(N1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
92  }
93  }
94 
95  static constexpr index_t VectorSizeA = []() {
96  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
97  {
98  return kPadK ? 1 : GetAlignmentA();
99  }
100  else
101  {
102  return kPadM ? 1 : GetAlignmentA();
103  }
104  }();
105 
106  static constexpr index_t VectorSizeB = []() {
107  if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
108  {
109  return kPadN ? 1 : GetAlignmentB();
110  }
111  else
112  {
113  return kPadK ? 1 : GetAlignmentB();
114  }
115  }();
116  static constexpr index_t VectorSizeC = []() {
117  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
118  {
119  return kPadN ? 1 : GetAlignmentC();
120  }
121  else
122  {
123  return kPadM ? 1 : GetAlignmentC();
124  }
125  }();
126 };
127 
128 // Alias for GemmPipelineProblem
129 template <typename ADataType_,
130  typename BDataType_,
131  typename CDataType_,
132  typename BlockGemmShape_,
133  typename Traits_>
136 
137 template <typename ADataType_,
138  typename BDataType_,
139  typename CDataType_,
140  typename BlockGemmShape_,
141  typename Traits_,
143  bool HasHotLoop_ = true,
144  TailNumber TailNum_ = TailNumber::Full>
146 {
148 
152 
154 
158 
// Threads per block: the block shape's warp count multiplied by the warp size.
159  static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
160 
// Padding flags for the M, N and K dimensions, forwarded from Traits.
161  static constexpr bool kPadM = Traits::kPadM;
162  static constexpr bool kPadN = Traits::kPadN;
163  static constexpr bool kPadK = Traits::kPadK;
164 
// Re-export the template arguments as members.
// NOTE(review): the declaration of Scheduler_ is not visible in this listing;
// it is presumably a template parameter next to HasHotLoop_ and TailNum_ -- confirm.
165  static constexpr auto Scheduler = Scheduler_;
166  static constexpr auto HasHotLoop = HasHotLoop_;
167  static constexpr auto TailNum = TailNum_;
168 
// Forwarded unchanged from the Traits type.
169  static constexpr bool TransposeC = Traits::TransposeC;
170 };
171 
172 } // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE index_t get_warp_size()
Definition: arch.hpp:51
TailNumber
Definition: gemm_pipeline_ag_bg_cr_scheduler.hpp:20
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
GemmPipelineScheduler
Definition: gemm_pipeline_ag_bg_cr_scheduler.hpp:13
Definition: gemm_pipeline_problem.hpp:17
remove_cvref_t< ADataType_ > ADataType
Definition: gemm_pipeline_problem.hpp:20
static constexpr CK_TILE_HOST_DEVICE auto GetAlignmentC()
Definition: gemm_pipeline_problem.hpp:73
static constexpr auto Scheduler
Definition: gemm_pipeline_problem.hpp:38
remove_cvref_t< typename Traits::CLayout > CLayout
Definition: gemm_pipeline_problem.hpp:28
remove_cvref_t< CDataType_ > CDataType
Definition: gemm_pipeline_problem.hpp:22
remove_cvref_t< typename Traits::ALayout > ALayout
Definition: gemm_pipeline_problem.hpp:26
remove_cvref_t< typename Traits::BLayout > BLayout
Definition: gemm_pipeline_problem.hpp:27
remove_cvref_t< BlockGemmShape_ > BlockGemmShape
Definition: gemm_pipeline_problem.hpp:24
static constexpr CK_TILE_HOST_DEVICE auto GetAlignmentB()
Definition: gemm_pipeline_problem.hpp:57
static constexpr index_t VectorLoadSize
Definition: gemm_pipeline_problem.hpp:40
static constexpr index_t VectorSizeB
Definition: gemm_pipeline_problem.hpp:106
static constexpr index_t VectorSizeA
Definition: gemm_pipeline_problem.hpp:95
remove_cvref_t< Traits_ > Traits
Definition: gemm_pipeline_problem.hpp:18
remove_cvref_t< BDataType_ > BDataType
Definition: gemm_pipeline_problem.hpp:21
static constexpr bool kPadM
Definition: gemm_pipeline_problem.hpp:34
static constexpr index_t kBlockSize
Definition: gemm_pipeline_problem.hpp:32
static constexpr bool TransposeC
Definition: gemm_pipeline_problem.hpp:30
static constexpr bool kPadK
Definition: gemm_pipeline_problem.hpp:36
static constexpr index_t VectorSizeC
Definition: gemm_pipeline_problem.hpp:116
static constexpr bool kPadN
Definition: gemm_pipeline_problem.hpp:35
static constexpr CK_TILE_HOST_DEVICE auto GetAlignmentA()
Definition: gemm_pipeline_problem.hpp:41
Definition: gemm_pipeline_problem.hpp:146
remove_cvref_t< Traits_ > Traits
Definition: gemm_pipeline_problem.hpp:147
remove_cvref_t< typename Traits::BLayout > BLayout
Definition: gemm_pipeline_problem.hpp:156
static constexpr bool kPadN
Definition: gemm_pipeline_problem.hpp:162
static constexpr auto HasHotLoop
Definition: gemm_pipeline_problem.hpp:166
static constexpr auto Scheduler
Definition: gemm_pipeline_problem.hpp:165
static constexpr index_t kBlockSize
Definition: gemm_pipeline_problem.hpp:159
static constexpr auto TailNum
Definition: gemm_pipeline_problem.hpp:167
static constexpr bool kPadK
Definition: gemm_pipeline_problem.hpp:163
remove_cvref_t< typename Traits::ALayout > ALayout
Definition: gemm_pipeline_problem.hpp:155
remove_cvref_t< CDataType_ > CDataType
Definition: gemm_pipeline_problem.hpp:151
static constexpr bool kPadM
Definition: gemm_pipeline_problem.hpp:161
remove_cvref_t< typename Traits::CLayout > CLayout
Definition: gemm_pipeline_problem.hpp:157
remove_cvref_t< ADataType_ > ADataType
Definition: gemm_pipeline_problem.hpp:149
remove_cvref_t< BlockGemmShape_ > BlockGemmShape
Definition: gemm_pipeline_problem.hpp:153
remove_cvref_t< BDataType_ > BDataType
Definition: gemm_pipeline_problem.hpp:150
static constexpr bool TransposeC
Definition: gemm_pipeline_problem.hpp:169