/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp Source File
gemm_group_quant_utils.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 
8 namespace ck_tile {
9 
10 template <typename Problem, typename DataType, index_t YPerTile, index_t XPerTile>
11 CK_TILE_HOST_DEVICE static constexpr auto GetABQGlobalVectorLoadSize()
12 {
13  using I1 = number<1>;
14  constexpr index_t NWarps = Problem::BlockGemmShape::BlockWarps::at(I1{});
15 
16  constexpr index_t BlockSize = Problem::kBlockSize;
17 
18  // Data is replicated across warps along NWarps, so we divide BlockSize by NWarps
19  constexpr index_t elements_per_thread = (YPerTile * XPerTile) / (BlockSize / NWarps);
20  constexpr index_t PackedSize = ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
21 
22  // Define vector load candidates in descending order of priority
23  constexpr std::array<index_t, 5> candidates{
24  PackedSize * 32 / sizeof(DataType),
25  PackedSize * 16 / sizeof(DataType),
26  PackedSize * 8 / sizeof(DataType),
27  PackedSize * 4 / sizeof(DataType),
28  PackedSize * 2 / sizeof(DataType),
29  };
30 
31  for(const auto vec_size : candidates)
32  {
33  if(vec_size <= 0 || XPerTile % vec_size != 0 || elements_per_thread % vec_size != 0)
34  continue;
35  bool is_valid = (vec_size > 0) && (XPerTile % vec_size == 0) &&
36  (elements_per_thread % vec_size == 0) && vec_size != candidates[4];
37  if(is_valid)
38  {
39  return vec_size;
40  }
41  }
42  return PackedSize; // Absolute fallback
43 }
44 
45 // AQ holds the group-quantization scales for A. The data is loaded from DRAM and partitioned
46 // across threads. After the MFMA, the scales are shuffled across the threads of the warp and
47 // applied to the accumulator registers.
48 template <typename BlockGemmShape,
49  typename WarpGemm,
50  index_t BlockSize,
51  index_t YPerTile,
52  index_t XPerTile,
53  index_t KPerBlockAQ,
54  index_t VecSize,
55  bool PreshuffleQuant>
57 {
58  static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
59  static constexpr index_t warp_size = get_warp_size();
60  static constexpr index_t num_warps = BlockSize / get_warp_size();
61 
62  static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(number<0>{});
63  static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(number<1>{});
64  static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(number<2>{});
65 
66  static constexpr index_t MIterPerWarp = BlockGemmShape::kM / (MWarps * WarpGemm::kM);
67 
68  static_assert(num_warps == MWarps * NWarps * KWarps);
69 
70  // KWarps > 1 isn't supported
71  static_assert(KWarps == 1);
72 
74  {
75  if constexpr(PreshuffleQuant)
76  {
77  // # of elements per thread
78  static_assert(XPerTile >= warp_size && XPerTile % warp_size == 0);
79  constexpr index_t X1 = warp_size;
80  constexpr index_t X0 = XPerTile / warp_size;
81 
82  constexpr index_t Y1 = MWarps;
83  constexpr index_t Y0 = YPerTile / Y1;
90  sequence<0, 0>>{});
91  }
92  else
93  {
94  // # of elements per thread
95  constexpr index_t X = XPerTile;
96 
97  constexpr index_t Y0 = 1;
98  constexpr index_t Y1 = MIterPerWarp ? MIterPerWarp : 1;
99  constexpr index_t Y2 = MWarps;
100  constexpr index_t Y3 = WarpGemm::kM;
101  static_assert(Y3 >= WarpGemm::kM,
102  "Scales for all rows must be available within the warp.");
103  static_assert(Y0 * Y1 * Y2 * Y3 == YPerTile,
104  "Y0, Y1, Y2, Y3 must cover the blocktile along Y.");
111  sequence<1, 0>>{});
112  }
113  }
114 };
115 
116 template <typename BlockGemmShape,
117  typename WarpGemm,
118  index_t BlockSize,
119  index_t YPerTile,
120  index_t XPerTile,
121  index_t VecSize>
124 {
125  // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
126  static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
127  static constexpr index_t warp_size = get_warp_size();
128  static constexpr index_t num_warps = BlockSize / get_warp_size();
129 
130  static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(number<0>{});
131  static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(number<1>{});
132  static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(number<2>{});
133 
134  static constexpr index_t MIterPerWarp = BlockGemmShape::kM / (MWarps * WarpGemm::kM);
135 
136  static_assert(num_warps == MWarps * NWarps * KWarps);
137 
138  // KWarps > 1 isn't supported
139  static_assert(KWarps == 1);
140 
141  // # of elements per thread
142  static constexpr index_t X = XPerTile;
143  static constexpr index_t XR = 2;
144 
145  // Number of iters per warp
146  // MIters are indexed using (Y0, Y1)
147  static constexpr index_t Y0 = MIterPerWarp;
148 
149  // # of warps in Y dim
150  static constexpr index_t Y1 = MWarps;
151 
152  static constexpr index_t Y2 = WarpGemm::kM;
153 
154  static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover the blocktile along Y.");
155 
157  {
164  sequence<0, 0>>{});
165  }
166 };
167 
168 // TODO: this may need to be updated.
169 template <typename BlockGemmShape,
170  typename WarpGemm,
171  index_t BlockSize,
172  index_t YPerTile,
173  index_t XPerTile,
174  index_t XPerQ>
176 {
177  static constexpr index_t warp_size = get_warp_size();
178  static constexpr index_t num_warps = BlockSize / get_warp_size();
179 
180  static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(number<0>{});
181  static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(number<1>{});
182  static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(number<2>{});
183 
184  static constexpr index_t NIterPerWarp = BlockGemmShape::kN / (NWarps * WarpGemm::kN);
185 
186  static_assert(num_warps == MWarps * NWarps * KWarps);
187  static_assert(KWarps == 1);
188 
215  {
216  if constexpr(XPerQ < WarpGemm::kN)
217  {
218  // Case 1: Fine-grained - multiple quantization scales within a single warp
219  constexpr index_t Y = YPerTile; // Full Y dimension of tile
220  constexpr index_t YR = 1; // No Y replication needed
221  constexpr index_t X0 = NIterPerWarp; // Iterations per warp in N-dim
222  constexpr index_t X1 = NWarps; // Number of warps in N-dim
223  constexpr index_t X2 = WarpGemm::kN / XPerQ; // Number of scales per warp
224  constexpr index_t XR = XPerQ; // Elements per quantization group
225 
226  static_assert(X0 * X1 * X2 == XPerTile, "X0, X1, X2 must cover the blocktile along X.");
227 
234  sequence<0, 0>>{});
235  }
236  else if constexpr(XPerQ <= WarpGemm::kN * NWarps)
237  {
238  // Case 2: Medium-grained - one quantization scale per warp
239  constexpr auto XR = XPerQ / WarpGemm::kN; // Scale replication factor
240  constexpr auto X1 = NWarps / XR; // Warps per unique scale
241  constexpr auto X0 = XPerTile / X1; // Iterations to cover X dimension
248  sequence<0, 0>>{});
249  }
250  else // XPerQ > WarpGemm::kN * NWarps
251  {
252  // Case 3: Coarse-grained - quantization group spans all warps
253  // All warps in N-dimension share the same quantization scale
260  sequence<0, 0>>{});
261  }
262  }
263 };
264 
265 template <typename GroupSizes>
267 {
268  static constexpr index_t kM = GroupSizes::at(number<0>{});
269  static constexpr index_t kN = GroupSizes::at(number<1>{});
270  static constexpr index_t kK = GroupSizes::at(number<2>{});
271 
272  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
273  {
274  return concat('_', "quant_group_shape", concat('x', kM, kN, kK));
275  }
276 };
277 
278 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:480
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
Definition: gemm_group_quant_utils.hpp:267
static constexpr index_t kM
Definition: gemm_group_quant_utils.hpp:268
static constexpr index_t kK
Definition: gemm_group_quant_utils.hpp:270
static constexpr index_t kN
Definition: gemm_group_quant_utils.hpp:269
static CK_TILE_HOST const std::string GetName()
Definition: gemm_group_quant_utils.hpp:272
Definition: integral_constant.hpp:13
Definition: numeric.hpp:81
Definition: sequence.hpp:49
Definition: gemm_group_quant_utils.hpp:124
static constexpr index_t NWarps
Definition: gemm_group_quant_utils.hpp:131
static constexpr CK_TILE_HOST_DEVICE auto make_2d_static_tile_distribution()
Definition: gemm_group_quant_utils.hpp:156
static constexpr index_t MWarps
Definition: gemm_group_quant_utils.hpp:130
static constexpr index_t MIterPerWarp
Definition: gemm_group_quant_utils.hpp:134
static constexpr index_t X
Definition: gemm_group_quant_utils.hpp:142
static constexpr index_t KWarps
Definition: gemm_group_quant_utils.hpp:132
static constexpr index_t Y0
Definition: gemm_group_quant_utils.hpp:147
static constexpr index_t num_warps
Definition: gemm_group_quant_utils.hpp:128
static constexpr index_t Y2
Definition: gemm_group_quant_utils.hpp:152
static constexpr index_t Y1
Definition: gemm_group_quant_utils.hpp:150
static constexpr index_t XR
Definition: gemm_group_quant_utils.hpp:143
static constexpr index_t warp_size
Definition: gemm_group_quant_utils.hpp:127
Definition: gemm_group_quant_utils.hpp:57
static constexpr index_t KWarps
Definition: gemm_group_quant_utils.hpp:64
static constexpr index_t MWarps
Definition: gemm_group_quant_utils.hpp:62
static constexpr CK_TILE_HOST_DEVICE auto make_2d_static_tile_distribution()
Definition: gemm_group_quant_utils.hpp:73
static constexpr index_t warp_size
Definition: gemm_group_quant_utils.hpp:59
static constexpr index_t NWarps
Definition: gemm_group_quant_utils.hpp:63
static constexpr index_t num_warps
Definition: gemm_group_quant_utils.hpp:60
static constexpr index_t MIterPerWarp
Definition: gemm_group_quant_utils.hpp:66
Definition: gemm_group_quant_utils.hpp:176
static constexpr CK_TILE_HOST_DEVICE auto make_2d_static_tile_distribution()
Creates a 2D tile distribution for BQ (B-matrix quantization scales)
Definition: gemm_group_quant_utils.hpp:214
static constexpr index_t num_warps
Definition: gemm_group_quant_utils.hpp:178
static constexpr index_t NWarps
Definition: gemm_group_quant_utils.hpp:181
static constexpr index_t warp_size
Definition: gemm_group_quant_utils.hpp:177
static constexpr index_t MWarps
Definition: gemm_group_quant_utils.hpp:180
static constexpr index_t KWarps
Definition: gemm_group_quant_utils.hpp:182
static constexpr index_t NIterPerWarp
Definition: gemm_group_quant_utils.hpp:184
Definition: static_encoding_pattern.hpp:108
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192