/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp Source File
gemm_group_quant_utils.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
7 
8 namespace ck_tile {
9 
10 template <typename Problem, typename DataType, index_t YPerTile, index_t XPerTile>
11 CK_TILE_HOST_DEVICE static constexpr auto GetABQGlobalVectorLoadSize()
12 {
13  using I1 = number<1>;
14  constexpr index_t NWarps = Problem::BlockGemmShape::BlockWarps::at(I1{});
15 
16  constexpr index_t BlockSize = Problem::kBlockSize;
17 
18  // Data is replicated across warps along NWarps, so we divide BlockSize by NWarps
19  constexpr index_t elements_per_thread = (YPerTile * XPerTile) / (BlockSize / NWarps);
20  constexpr index_t PackedSize = ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
21 
22  // Define vector load candidates in descending order of priority
23  constexpr std::array<index_t, 5> candidates{
24  PackedSize * 32 / sizeof(DataType),
25  PackedSize * 16 / sizeof(DataType),
26  PackedSize * 8 / sizeof(DataType),
27  PackedSize * 4 / sizeof(DataType),
28  PackedSize * 2 / sizeof(DataType),
29  };
30 
31  for(const auto vec_size : candidates)
32  {
33  if(vec_size <= 0 || XPerTile % vec_size != 0 || elements_per_thread % vec_size != 0)
34  continue;
35  bool is_valid = (vec_size > 0) && (XPerTile % vec_size == 0) &&
36  (elements_per_thread % vec_size == 0) && vec_size != candidates[4];
37  if(is_valid)
38  {
39  return vec_size;
40  }
41  }
42  return PackedSize; // Absolute fallback
43 }
44 
45 // AQ holds groupquant scale data for A. Data is loaded from DRAM and partitioned across
46 // threads. Post mfma scales are shuffled across threads in the warp and applied to
47 // accum registers.
// Tile-distribution helper for AQ (group-quant scales of A). Partitions the
// YPerTile x XPerTile scale tile across threads for global loads; NWarps-wide
// replication is implied by the Y/X factorizations below.
// NOTE(review): this extraction is incomplete — the struct declaration line
// (hpp:56-57, including the struct's name) and the interiors of the
// make_static_tile_distribution(...) calls (hpp:84-89, 104-109, 126-131) are
// missing; only their trailing `sequence<0, 0>>{});` fragments survive.
// Reconstruct from the repository before editing code.
48 template <typename BlockGemmShape,
49  typename WarpGemm,
50  index_t BlockSize,
51  index_t YPerTile,
52  index_t XPerTile,
53  index_t KPerBlockAQ,
54  index_t VecSize,
55  bool PreshuffleQuant>
57 {
58  static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
59  static constexpr index_t warp_size = get_warp_size();
60  static constexpr index_t num_warps = BlockSize / get_warp_size();
61 
62  static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(number<0>{});
63  static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(number<1>{});
64  static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(number<2>{});
65 
66  static constexpr index_t MIterPerWarp = BlockGemmShape::kM / (MWarps * WarpGemm::kM);
67 
68  static_assert(num_warps == MWarps * NWarps * KWarps);
69 
70  // KWarps > 1 isn't supported
71  static_assert(KWarps == 1);
72 
// Per the trailing index, this is make_2d_static_tile_distribution() (hpp:73).
74  {
75  if constexpr(PreshuffleQuant)
76  {
77  // # of elements per thread
78  static_assert(XPerTile >= warp_size && XPerTile % warp_size == 0);
79  constexpr index_t X1 = warp_size;
80  constexpr index_t X0 = XPerTile / warp_size;
81 
82  constexpr index_t Y1 = MWarps;
83  constexpr index_t Y0 = YPerTile / Y1;
// NOTE(review): hpp:84-89 (the distribution-encoding arguments) are missing here.
90  sequence<0, 0>>{});
91  }
92  else
93  {
94  // # of elements per thread
95  constexpr index_t X = XPerTile;
96 
97  constexpr index_t YR = 1;
98  constexpr index_t Y0 = MIterPerWarp ? MIterPerWarp : 1;
99  constexpr index_t Y1 = MWarps;
100  constexpr index_t Y2 = WarpGemm::kM;
101  static_assert(Y2 >= WarpGemm::kM,
102  "Scales for all rows must be available within the warp.");
103  static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover the blocktile along Y.");
// NOTE(review): hpp:104-109 (the distribution-encoding arguments) are missing here.
110  sequence<0, 0>>{});
111  }
112  }
// Per the trailing index, this is make_2d_static_tile_distribution_transposed()
// (hpp:113); its declaration line is missing from the extraction.
114  {
115 
116  constexpr index_t Y0 = YPerTile;
117  constexpr index_t X0 = 1;
118  constexpr index_t X1 = MIterPerWarp ? MIterPerWarp : 1;
119  constexpr index_t X2 = MWarps;
120  constexpr index_t X3 = WarpGemm::kM;
121 
122  static_assert(X3 >= WarpGemm::kM, "Scales for all rows must be available within the warp.");
123  static_assert(X0 * X1 * X2 * X3 == XPerTile,
124  "X0, X1, X2, X3 must cover the blocktile along X.");
125 
// NOTE(review): hpp:126-131 (the distribution-encoding arguments) are missing here.
132  sequence<1, 0>>{});
133  }
134 };
135 
// Second AQ tile-distribution helper (no PreshuffleQuant variant); factors the
// Y dimension as (Y0 = MIterPerWarp) x (Y1 = MWarps) x (Y2 = WarpGemm::kM),
// with X replicated twice (XR = 2).
// NOTE(review): the struct declaration line (hpp:142-144, including the
// struct's name) and the body of make_2d_static_tile_distribution()
// (hpp:176-183) are missing from this extraction — only the trailing
// `sequence<0, 0>>{});` fragment survives. Reconstruct from the repository
// before editing code.
136 template <typename BlockGemmShape,
137  typename WarpGemm,
138  index_t BlockSize,
139  index_t YPerTile,
140  index_t XPerTile,
141  index_t VecSize>
144 {
145  // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
146  static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
147  static constexpr index_t warp_size = get_warp_size();
148  static constexpr index_t num_warps = BlockSize / get_warp_size();
149 
150  static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(number<0>{});
151  static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(number<1>{});
152  static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(number<2>{});
153 
154  static constexpr index_t MIterPerWarp = BlockGemmShape::kM / (MWarps * WarpGemm::kM);
155 
156  static_assert(num_warps == MWarps * NWarps * KWarps);
157 
158  // KWarps > 1 isn't supported
159  static_assert(KWarps == 1);
160 
161  // # of elements per thread
162  static constexpr index_t X = XPerTile;
163  static constexpr index_t XR = 2;
164 
165  // Number of iters per warp
166  // MIters are indexed using (Y0, Y1)
167  static constexpr index_t Y0 = MIterPerWarp;
168 
169  // # of warps in Y dim
170  static constexpr index_t Y1 = MWarps;
171 
172  static constexpr index_t Y2 = WarpGemm::kM;
173 
174  static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover the blocktile along Y.");
175 
// Per the trailing index, this is make_2d_static_tile_distribution() (hpp:176);
// its declaration and encoding arguments (hpp:177-183) are missing here.
177  {
184  sequence<0, 0>>{});
185  }
186 };
187 
188 // TODO:: might need to update
189 template <typename BlockGemmShape,
190  typename WarpGemm,
191  index_t BlockSize,
192  index_t KPerTile,
193  index_t NPerTile,
194  index_t NPerQ,
195  typename BQLayout = tensor_layout::gemm::ColumnMajor,
196  bool PreshuffleQuant = false>
198 {
199  static constexpr index_t warp_size = get_warp_size();
200  static constexpr index_t num_warps = BlockSize / get_warp_size();
201 
202  static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(number<0>{});
203  static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(number<1>{});
204  static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(number<2>{});
205 
206  static constexpr index_t NIterPerWarp = BlockGemmShape::kN / (NWarps * WarpGemm::kN);
207 
208  static_assert(num_warps == MWarps * NWarps * KWarps);
209  static_assert(KWarps == 1);
210 
237  {
238  // Preshuffle only supported for ColumnMajor currently
239  static_assert(!(PreshuffleQuant && std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>),
240  "PreshuffleQuant only supported for ColumnMajor BQLayout");
241 
242  if constexpr(PreshuffleQuant)
243  {
244  // ColumnMajor only for preshuffle
245  constexpr index_t X1 = warp_size;
246  constexpr index_t X0 = NPerTile / warp_size;
247  constexpr index_t Y1 = NWarps;
248  constexpr index_t Y0 = KPerTile / Y1;
249 
256  sequence<0, 0>>{});
257  }
258  else
259  {
260  if constexpr(NPerQ < WarpGemm::kN)
261  {
262  // Case 1: Fine-grained - multiple quantization scales within a single warp
263  // N dimension needs to be partitioned the same way regardless of layout
264  constexpr index_t NR = 1; // No N replication needed
265  constexpr index_t N0 = NIterPerWarp; // Iterations per warp in N-dim
266  constexpr index_t N1 = NWarps; // Number of warps in N-dim
267  constexpr index_t N2 = WarpGemm::kN / NPerQ; // Number of scales per warp
268 
269  static_assert(N0 * N1 * N2 == NPerTile,
270  "N0, N1, N2 must cover the blocktile along N dimension.");
271 
272  if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>)
273  {
274  // ColumnMajor: [(N0, N1, N2), K] - N on Y-axis, partition Y
281  sequence<0, 0>>{});
282  }
283  else
284  {
285  // RowMajor: [K, (N0, N1, N2)] - N on X-axis, partition X
292  sequence<0, 0>>{});
293  }
294  }
295  else if constexpr(NPerQ <= WarpGemm::kN * NWarps)
296  {
297  // Case 2: Medium-grained - one quantization scale per warp
298  constexpr auto NR = NPerQ / WarpGemm::kN; // Scale replication factor
299  constexpr auto N1 = NWarps / NR; // Warps per unique scale
300  constexpr auto N0 = NPerTile / N1; // Iterations to cover N dimension
301 
302  if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>)
303  {
304  // ColumnMajor: [(N0, N1), K] - N on Y-axis
311  sequence<0, 0>>{});
312  }
313  else
314  {
315  // RowMajor: [K, (N0, N1)] - N on X-axis
322  sequence<0, 0>>{});
323  }
324  }
325  else // NPerQ > WarpGemm::kN * NWarps
326  {
327  // Case 3: Coarse-grained - quantization group spans all warps
328  // All warps in N-dimension share the same quantization scale
329  if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>)
330  {
331  // ColumnMajor: [N, K]
338  sequence<0, 0>>{});
339  }
340  else
341  {
342  // RowMajor: [K, N]
349  sequence<0, 0>>{});
350  }
351  }
352  }
353  }
354 };
355 
// Compile-time quantization-group shape: extracts M/N/K group sizes from a
// GroupSizes sequence and provides a kernel-name fragment.
// NOTE(review): the struct declaration line (hpp:357, including the struct's
// name) is missing from this extraction; the trailing index places the struct
// definition at gemm_group_quant_utils.hpp:358. Reconstruct the name from the
// repository before editing code.
356 template <typename GroupSizes>
358 {
359  static constexpr index_t kM = GroupSizes::at(number<0>{});
360  static constexpr index_t kN = GroupSizes::at(number<1>{});
361  static constexpr index_t kK = GroupSizes::at(number<2>{});
362 
// Host-only name fragment, e.g. "quant_group_shape_<kM>x<kN>x<kK>", built via
// the project's concat() helper.
363  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
364  {
365  return concat('_', "quant_group_shape", concat('x', kM, kN, kK));
366  }
367 };
368 
369 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:44
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:495
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
Definition: gemm_group_quant_utils.hpp:358
static constexpr index_t kM
Definition: gemm_group_quant_utils.hpp:359
static constexpr index_t kK
Definition: gemm_group_quant_utils.hpp:361
static constexpr index_t kN
Definition: gemm_group_quant_utils.hpp:360
static CK_TILE_HOST const std::string GetName()
Definition: gemm_group_quant_utils.hpp:363
Definition: integral_constant.hpp:13
Definition: numeric.hpp:81
Definition: sequence.hpp:49
Definition: gemm_group_quant_utils.hpp:144
static constexpr index_t NWarps
Definition: gemm_group_quant_utils.hpp:151
static constexpr CK_TILE_HOST_DEVICE auto make_2d_static_tile_distribution()
Definition: gemm_group_quant_utils.hpp:176
static constexpr index_t MWarps
Definition: gemm_group_quant_utils.hpp:150
static constexpr index_t MIterPerWarp
Definition: gemm_group_quant_utils.hpp:154
static constexpr index_t X
Definition: gemm_group_quant_utils.hpp:162
static constexpr index_t KWarps
Definition: gemm_group_quant_utils.hpp:152
static constexpr index_t Y0
Definition: gemm_group_quant_utils.hpp:167
static constexpr index_t num_warps
Definition: gemm_group_quant_utils.hpp:148
static constexpr index_t Y2
Definition: gemm_group_quant_utils.hpp:172
static constexpr index_t Y1
Definition: gemm_group_quant_utils.hpp:170
static constexpr index_t XR
Definition: gemm_group_quant_utils.hpp:163
static constexpr index_t warp_size
Definition: gemm_group_quant_utils.hpp:147
Definition: gemm_group_quant_utils.hpp:57
static constexpr index_t KWarps
Definition: gemm_group_quant_utils.hpp:64
static constexpr index_t MWarps
Definition: gemm_group_quant_utils.hpp:62
static constexpr CK_TILE_HOST_DEVICE auto make_2d_static_tile_distribution()
Definition: gemm_group_quant_utils.hpp:73
static constexpr index_t warp_size
Definition: gemm_group_quant_utils.hpp:59
static constexpr index_t NWarps
Definition: gemm_group_quant_utils.hpp:63
static constexpr CK_TILE_HOST_DEVICE auto make_2d_static_tile_distribution_transposed()
Definition: gemm_group_quant_utils.hpp:113
static constexpr index_t num_warps
Definition: gemm_group_quant_utils.hpp:60
static constexpr index_t MIterPerWarp
Definition: gemm_group_quant_utils.hpp:66
Definition: gemm_group_quant_utils.hpp:198
static constexpr CK_TILE_HOST_DEVICE auto make_2d_static_tile_distribution()
Creates a 2D tile distribution for BQ (B-matrix quantization scales)
Definition: gemm_group_quant_utils.hpp:236
static constexpr index_t NWarps
Definition: gemm_group_quant_utils.hpp:203
static constexpr index_t warp_size
Definition: gemm_group_quant_utils.hpp:199
static constexpr index_t NIterPerWarp
Definition: gemm_group_quant_utils.hpp:206
static constexpr index_t num_warps
Definition: gemm_group_quant_utils.hpp:200
static constexpr index_t MWarps
Definition: gemm_group_quant_utils.hpp:202
static constexpr index_t KWarps
Definition: gemm_group_quant_utils.hpp:204
Definition: static_encoding_pattern.hpp:108
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192