/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp Source File
gemm_group_quant_utils.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 
8 namespace ck_tile {
9 
10 template <typename Problem, typename DataType, index_t YPerTile, index_t XPerTile>
11 CK_TILE_HOST_DEVICE static constexpr auto GetABQGlobalVectorLoadSize()
12 {
13  using I1 = number<1>;
14  constexpr index_t NWarps = Problem::BlockGemmShape::BlockWarps::at(I1{});
15 
16  constexpr index_t BlockSize = Problem::kBlockSize;
17 
18  // Data is replicated across warps along NWarps, so we divide BlockSize by NWarps
19  constexpr index_t elements_per_thread = (YPerTile * XPerTile) / (BlockSize / NWarps);
20  constexpr index_t PackedSize = ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
21 
22  // Define vector load candidates in descending order of priority
23  constexpr std::array<index_t, 5> candidates{
24  PackedSize * 32 / sizeof(DataType),
25  PackedSize * 16 / sizeof(DataType),
26  PackedSize * 8 / sizeof(DataType),
27  PackedSize * 4 / sizeof(DataType),
28  PackedSize * 2 / sizeof(DataType),
29  };
30 
31  for(const auto vec_size : candidates)
32  {
33  if(vec_size <= 0 || XPerTile % vec_size != 0 || elements_per_thread % vec_size != 0)
34  continue;
35  bool is_valid = (vec_size > 0) && (XPerTile % vec_size == 0) &&
36  (elements_per_thread % vec_size == 0) && vec_size != candidates[4];
37  if(is_valid)
38  {
39  return vec_size;
40  }
41  }
42  return PackedSize; // Absolute fallback
43 }
44 
45 // AQ holds groupquant scale data for A. Data is loaded from DRAM and partitioned across
46 // threads. Post mfma scales are shuffled across threads in the warp and applied to
47 // accum registers.
// NOTE(review): this rendered listing is missing source line 56 (the struct declaration),
// line 73 (the make_2d_static_tile_distribution() signature) and lines 84-89 / 105-110
// (the tile-distribution encoding bodies). See the original header for the full definition.
//
// Template parameters, as used by the visible lines:
//   BlockGemmShape  - provides BlockWarps (warp counts along dims 0/1/2) and kM.
//   WarpGemm        - provides kM, the M extent of one warp-level GEMM tile.
//   BlockSize       - threads per block; BlockSize / warp size must equal
//                     MWarps * NWarps * KWarps.
//   YPerTile        - scale-tile extent along Y (covered by the M-direction factors below).
//   XPerTile        - scale-tile extent along X; must be a multiple of VecSize.
//   KPerBlockAQ     - not referenced by any of the visible lines.
//   VecSize         - vector width; only checked here to divide XPerTile.
//   PreshuffleQuant - selects which of the two Y/X factorisations below is used.
48 template <typename BlockGemmShape,
49  typename WarpGemm,
50  index_t BlockSize,
51  index_t YPerTile,
52  index_t XPerTile,
53  index_t KPerBlockAQ,
54  index_t VecSize,
55  bool PreshuffleQuant>
57 {
58  static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
59  static constexpr index_t warp_size = get_warp_size();
60  static constexpr index_t num_warps = BlockSize / get_warp_size();
61 
62  static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(number<0>{});
63  static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(number<1>{});
64  static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(number<2>{});
65 
// Warp-GEMM tiles per warp along M. Integer division: this is 0 when
// BlockGemmShape::kM < MWarps * WarpGemm::kM.
66  static constexpr index_t MIterPerWarp = BlockGemmShape::kM / (MWarps * WarpGemm::kM);
67 
68  static_assert(num_warps == MWarps * NWarps * KWarps);
69 
70  // KWarps > 1 isn't supported
71  static_assert(KWarps == 1);
72 
74  {
75  if constexpr(PreshuffleQuant)
76  {
77  // # of elements per thread
// X is factored as X0 (= XPerTile / warp_size) x X1 (= warp_size), which requires
// XPerTile to be a non-zero multiple of warp_size.
78  static_assert(XPerTile >= warp_size && XPerTile % warp_size == 0);
79  constexpr index_t X1 = warp_size;
80  constexpr index_t X0 = XPerTile / warp_size;
81 
// Y is factored as Y0 x Y1 with Y1 = MWarps. NOTE(review): Y0 uses integer division,
// and no visible static_assert checks that YPerTile is a multiple of MWarps.
82  constexpr index_t Y1 = MWarps;
83  constexpr index_t Y0 = YPerTile / Y1;
90  sequence<0, 0>>{});
91  }
92  else
93  {
94  // # of elements per thread
95  constexpr index_t X = XPerTile;
96 
// Y is factored as Y0 x Y1 x Y2 x Y3 = 1 x MIterPerWarp x MWarps x WarpGemm::kM.
97  constexpr index_t Y0 = 1;
// Clamp to 1 so the factorisation never contains a zero when MIterPerWarp is 0.
98  constexpr index_t Y1 = MIterPerWarp ? MIterPerWarp : 1;
99  constexpr index_t Y2 = MWarps;
100  constexpr index_t Y3 = WarpGemm::kM;
// NOTE(review): Y3 is defined as WarpGemm::kM, so this assertion always holds as written.
101  static_assert(Y3 >= WarpGemm::kM,
102  "Scales for all rows must be available within the warp.");
103  static_assert(Y0 * Y1 * Y2 * Y3 == YPerTile,
104  "Y0, Y1, Y2, Y3 must cover the blocktile along Y.");
111  sequence<1, 0>>{});
112  }
113  }
114 };
115 
// NOTE(review): this rendered listing is missing source lines 122-123 (the struct
// declaration) and lines 156 / 158-163 (the make_2d_static_tile_distribution() signature
// and its encoding). See the original header for the full definition.
//
// Factors a YPerTile x XPerTile scale tile along M:
//   Y = Y0 (MIterPerWarp) x Y1 (MWarps) x Y2 (WarpGemm::kM), X = XPerTile.
// Only KWarps == 1 is accepted.
116 template <typename BlockGemmShape,
117  typename WarpGemm,
118  index_t BlockSize,
119  index_t YPerTile,
120  index_t XPerTile,
121  index_t VecSize>
124 {
125  // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
126  static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
127  static constexpr index_t warp_size = get_warp_size();
128  static constexpr index_t num_warps = BlockSize / get_warp_size();
129 
130  static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(number<0>{});
131  static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(number<1>{});
132  static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(number<2>{});
133 
// Warp-GEMM tiles per warp along M (integer division).
134  static constexpr index_t MIterPerWarp = BlockGemmShape::kM / (MWarps * WarpGemm::kM);
135 
136  static_assert(num_warps == MWarps * NWarps * KWarps);
137 
138  // KWarps > 1 isn't supported
139  static_assert(KWarps == 1);
140 
141  // # of elements per thread
142  static constexpr index_t X = XPerTile;
// NOTE(review): XR is a fixed factor of 2 consumed by the (omitted) encoding. Its meaning
// is not visible here; presumably a replication count - TODO confirm.
143  static constexpr index_t XR = 2;
144 
145  // Number of iters per warp
146  // MIters are indexed using (Y0, Y1)
147  static constexpr index_t Y0 = MIterPerWarp;
148 
149  // # of warps in Y dim
150  static constexpr index_t Y1 = MWarps;
151 
// M extent of one warp-level GEMM tile.
152  static constexpr index_t Y2 = WarpGemm::kM;
153 
154  static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover the blocktile along Y.");
155 
157  {
164  sequence<0, 0>>{});
165  }
166 };
167 
168 // TODO: might need to update
// NOTE(review): this rendered listing is missing source line 175 (the struct declaration)
// and lines 208 / 210-215 (the make_2d_static_tile_distribution() signature and its
// encoding). See the original header for the full definition.
//
// N-direction counterpart of the M-direction pattern: factors a YPerTile x XPerTile scale
// tile as Y = Y0 (NIterPerWarp) x Y1 (NWarps) x Y2 (WarpGemm::kN), X = XPerTile.
// Only KWarps == 1 is accepted.
169 template <typename BlockGemmShape,
170  typename WarpGemm,
171  index_t BlockSize,
172  index_t YPerTile,
173  index_t XPerTile,
174  index_t VecSize>
176 {
177  // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
178  static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
179  static constexpr index_t warp_size = get_warp_size();
180  static constexpr index_t num_warps = BlockSize / get_warp_size();
181 
182  static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(number<0>{});
183  static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(number<1>{});
184  static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(number<2>{});
185 
// Warp-GEMM tiles per warp along N (integer division).
186  static constexpr index_t NIterPerWarp = BlockGemmShape::kN / (NWarps * WarpGemm::kN);
187 
188  static_assert(num_warps == MWarps * NWarps * KWarps);
189 
190  // KWarps > 1 isn't supported
191  static_assert(KWarps == 1);
192 
193  // # of elements per thread
194  static constexpr index_t X = XPerTile;
// NOTE(review): XR is a fixed factor of 2 consumed by the (omitted) encoding. Its meaning
// is not visible here; presumably a replication count - TODO confirm.
195  static constexpr index_t XR = 2;
196 
197  // Number of iters per warp
198  // NIters are indexed using (Y0, Y1)
199  static constexpr index_t Y0 = NIterPerWarp;
200 
201  // # of warps in Y dim
202  static constexpr index_t Y1 = NWarps;
203 
// N extent of one warp-level GEMM tile.
204  static constexpr index_t Y2 = WarpGemm::kN;
205 
206  static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover the blocktile along Y.");
207 
209  {
216  sequence<0, 0>>{});
217  }
218 };
219 
220 } // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:480
__host__ constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:42
Definition: integral_constant.hpp:13
Definition: numeric.hpp:81
Definition: sequence.hpp:49
Definition: gemm_group_quant_utils.hpp:124
static constexpr index_t NWarps
Definition: gemm_group_quant_utils.hpp:131
static constexpr CK_TILE_HOST_DEVICE auto make_2d_static_tile_distribution()
Definition: gemm_group_quant_utils.hpp:156
static constexpr index_t MWarps
Definition: gemm_group_quant_utils.hpp:130
static constexpr index_t MIterPerWarp
Definition: gemm_group_quant_utils.hpp:134
static constexpr index_t X
Definition: gemm_group_quant_utils.hpp:142
static constexpr index_t KWarps
Definition: gemm_group_quant_utils.hpp:132
static constexpr index_t Y0
Definition: gemm_group_quant_utils.hpp:147
static constexpr index_t num_warps
Definition: gemm_group_quant_utils.hpp:128
static constexpr index_t Y2
Definition: gemm_group_quant_utils.hpp:152
static constexpr index_t Y1
Definition: gemm_group_quant_utils.hpp:150
static constexpr index_t XR
Definition: gemm_group_quant_utils.hpp:143
static constexpr index_t warp_size
Definition: gemm_group_quant_utils.hpp:127
Definition: gemm_group_quant_utils.hpp:57
static constexpr index_t KWarps
Definition: gemm_group_quant_utils.hpp:64
static constexpr index_t MWarps
Definition: gemm_group_quant_utils.hpp:62
static constexpr CK_TILE_HOST_DEVICE auto make_2d_static_tile_distribution()
Definition: gemm_group_quant_utils.hpp:73
static constexpr index_t warp_size
Definition: gemm_group_quant_utils.hpp:59
static constexpr index_t NWarps
Definition: gemm_group_quant_utils.hpp:63
static constexpr index_t num_warps
Definition: gemm_group_quant_utils.hpp:60
static constexpr index_t MIterPerWarp
Definition: gemm_group_quant_utils.hpp:66
Definition: gemm_group_quant_utils.hpp:176
static constexpr index_t warp_size
Definition: gemm_group_quant_utils.hpp:179
static constexpr CK_TILE_HOST_DEVICE auto make_2d_static_tile_distribution()
Definition: gemm_group_quant_utils.hpp:208
static constexpr index_t X
Definition: gemm_group_quant_utils.hpp:194
static constexpr index_t NWarps
Definition: gemm_group_quant_utils.hpp:183
static constexpr index_t Y2
Definition: gemm_group_quant_utils.hpp:204
static constexpr index_t num_warps
Definition: gemm_group_quant_utils.hpp:180
static constexpr index_t Y0
Definition: gemm_group_quant_utils.hpp:199
static constexpr index_t KWarps
Definition: gemm_group_quant_utils.hpp:184
static constexpr index_t NIterPerWarp
Definition: gemm_group_quant_utils.hpp:186
static constexpr index_t MWarps
Definition: gemm_group_quant_utils.hpp:182
static constexpr index_t XR
Definition: gemm_group_quant_utils.hpp:195
static constexpr index_t Y1
Definition: gemm_group_quant_utils.hpp:202
Definition: static_encoding_pattern.hpp:108
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192