/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp Source File
block_gemm_quant_common.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 
8 namespace ck_tile {
9 
10 // Common utilities for quantized GEMM block operations
11 template <typename CDataType,
12  typename WarpGemmType,
13  index_t MIterPerWarp,
14  index_t MWarp,
15  index_t NIterPerWarp,
16  index_t NWarp>
18 {
19  CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
20  {
21  constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
22  sequence<>,
27  sequence<0, 0>>{};
28 
29  constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
30  c_block_outer_dstr_encoding, typename WarpGemmType::CWarpDstrEncoding{});
31  constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
32  auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
33 
34  return c_block_tensor;
35  }
36 };
37 
39 {
40  template <typename QDataType, typename T>
41  CK_TILE_DEVICE static float cvt_scale_to_fp32(T scale)
42  {
43  float scale_reg_f = 0.f;
44  if constexpr(std::is_same_v<QDataType, ck_tile::fp8_t>)
45  {
46  scale_reg_f = __builtin_amdgcn_cvt_f32_fp8(static_cast<uint32_t>(scale), 0);
47  }
48  else if constexpr(std::is_same_v<QDataType, ck_tile::bf8_t>)
49  {
50  scale_reg_f = __builtin_amdgcn_cvt_f32_bf8(static_cast<uint32_t>(scale), 0);
51  }
52  else if constexpr(std::is_same_v<QDataType, float>)
53  {
54  scale_reg_f = ck_tile::bit_cast<float>(scale);
55  }
56  else
57  {
58  static_assert(!std::is_same_v<QDataType, QDataType>,
59  "QDataType must be float, fp8_t or bf8_t.");
60  }
61  return scale_reg_f;
62  }
63 };
64 
65 template <typename AQBlockTensor, typename GemmTraits_, int32_t mIter, int32_t kQScale>
67 {
72 
73  CK_TILE_DEVICE static float exchange_quant_value_across_lanes(float scale_reg,
74  index_t pull_from_lane)
75  {
76  // cross lane ops
77  uint32_t scale_reg_dword;
78 
79  if constexpr(std::is_same_v<AQDataType, float>)
80  {
81  scale_reg_dword = ck_tile::bit_cast<uint32_t>(scale_reg);
82  }
83  else
84  {
85  scale_reg_dword = static_cast<uint32_t>(scale_reg);
86  }
87 
88  int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
89  pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword));
90  return Base::cvt_scale_to_fp32<typename Traits::AQDataType>(gathered_scale_reg);
91  }
92 
94  AQPickerCommon(AQBlockTensor& aq_block_tensor_) : aq_block_tensor(aq_block_tensor_)
95  {
96  if constexpr(Traits::TransposeC) // transposed C
97  {
98  index_t reg_offset =
99  Traits::PreshuffleQuant ? mIter : mIter * Traits::AQPerBlock + kQScale;
100  auto scale_reg = aq_block_tensor.get_thread_buffer()[reg_offset];
101  if constexpr(Traits::PreshuffleQuant)
102  {
103  auto pull_from_lane =
104  (__lane_id() & (Traits::WarpGemm::kN - 1)) * Traits::AQPerBlock + kQScale;
105 
106  scale_reg_f = exchange_quant_value_across_lanes(scale_reg, pull_from_lane);
107  }
108  else
109  {
110  scale_reg_f = Base::cvt_scale_to_fp32<typename Traits::AQDataType>(scale_reg);
111  }
112  }
113  }
114  template <uint32_t c_row = 0>
116  {
117  if constexpr(Traits::TransposeC)
118  {
119  // pre-computed scale_reg_f is shared by entire column when TransposeC is true
120  return scale_reg_f;
121  }
122  else
123  {
124  if constexpr(Traits::PreshuffleQuant)
125  {
126  // A view is created on top of the preshuffled AQ, where each row of
127  // the view is composed of a row from a warp tile within an AQ block
128  // tile. Multiple warp tile rows that belong to the same block tile
129  // are laid out as consecutive rows.
130  //
131  // When we need to multiply a C warp tile with an AQ warp tile,
132  // thread 0 in the warp will load AQ_warp_tile[0], thread 1 will
133  // load AQ_warp_tile[1], and so on, up to thread 63, which will load
134  // AQ_warp_tile[63]. The VGPR file in the warp acts similarly to LDS
135  // in this context, but we use cross-lane operations to access the
136  // data. (Cross-lane operations are faster than using LDS.)
137  //
138  // Note that when the size of the AQ warp tile is smaller than the
139  // warp size, you need to pad the rows in the view to ensure that
140  // each thread can read one element.
141 
142  // For a warp tile of [16x16x32], take thread 0 as an
143  // example. Its VGPR[0] stores the value from C_tile[0,0],
144  // VGPR[1] stores C_tile[1,0], VGPR[2] stores C_tile[2,0],
145  // and VGPR[3] stores C_tile[3,0]. This means VGPR[0] should
146  // be multiplied by AQ_tile[0, 0], VGPR[1] by AQ_tile[1, 0],
147  // VGPR[2] by AQ_tile[2, 0], and VGPR[3] by AQ_tile[3, 0].
148 
149  // Thread 0 can read AQ_tile[0, 0] from itself, AQ_tile[1,
150  // 0] from thread 1, ..., and AQ_tile[3, 0] from thread 3.
151 
152  constexpr uint32_t kTileRowsOfCPerThread = (get_warp_size() == 64) ? 4 : 8;
153  decltype(threadIdx.x) pull_from_lane = 0;
154  if constexpr(WarpGemm::kM == 16)
155  {
156  pull_from_lane =
157  (__lane_id() / Traits::WarpGemm::kN * kTileRowsOfCPerThread + c_row) *
158  Traits::QScalesPerBlockRow +
159  kQScale;
160  }
161  else if constexpr(WarpGemm::kM == 32)
162  {
163  pull_from_lane = (__lane_id() / Traits::WarpGemm::kN * kTileRowsOfCPerThread +
164  ((c_row >> 2) << 3) + (c_row & 0b11)) *
165  Traits::QScalesPerBlockRow +
166  kQScale;
167  }
168  else
169  {
170  static_assert(false, "WarpGemm::kM is not 16 nor 32.");
171  }
172  auto& scale_reg = aq_block_tensor.get_thread_buffer()[mIter];
173  return exchange_quant_value_across_lanes(scale_reg, pull_from_lane);
174  }
175  else
176  {
177  // Need to multiply aquant with accumulated C
178  //
179  // The accumulated C tile has the standard distribution. For example, a
180  // 32x32 C lane 0 holds elements [0,0], [1,0], [2,0], [3,0], [8,0], [9,0],
181  // [10,0], [11,0], [16,0], [17,0], [18,0], [19,0], [24,0], [25,0],
182  // [26,0], [27,0].
183  //
184  // These elements are in different rows, need to get the scale value
185  // for the corresponding row.
186  // Based on aquant's tile distribution, it can be inferred which
187  // lane holds the relevant scale. For example, the scales
188  // corresponding to the 16 elements held by lane 0 are held by lanes
189  // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27
190  // respectively.
191  //
192  // These scales can be obtained using __builtin_amdgcn_ds_bpermute.
193 
194  // Reg block offset based on mIter
195  // Each thread stores AQPerBlock scale values per M iteration.
196  constexpr index_t reg_block_offset = mIter * Traits::AQPerBlock;
197  constexpr index_t src_reg_offset = reg_block_offset + kQScale;
198  auto& scale_reg = aq_block_tensor.get_thread_buffer()[src_reg_offset];
199 
200  // Divide M dimension of C Warp tile into groups of
201  // (WarpGemm::kCMLane * WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane)
202  // m_base_offset_of_c_row indicates which group the current c_row belongs
203  // to.
204  constexpr index_t m_base_offset_of_c_row =
205  (c_row / WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane) *
206  (WarpGemm::kCMLane * WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane);
207 
208  // M offset of each thread within its group (see comment above)
209  index_t m_base_offset_of_lane =
210  (get_lane_id() / WarpGemm::kN * WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane);
211 
212  // M offset wrt. c_row in the subgroup of kCM1PerLane
213  constexpr index_t m_offset_of_c_row =
214  c_row & (WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane - 1);
215 
216  uint32_t src_lane_idx =
217  m_base_offset_of_c_row + m_base_offset_of_lane + m_offset_of_c_row;
218 
219  return exchange_quant_value_across_lanes(scale_reg, src_lane_idx);
220  }
221  }
222  }
223  AQBlockTensor& aq_block_tensor;
224  float scale_reg_f = 0.0f;
225 };
226 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:45
constexpr CK_TILE_HOST_DEVICE auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition: tile_distribution_encoding.hpp:457
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:495
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
unsigned int uint32_t
Definition: stdint.h:126
Definition: block_gemm_quant_common.hpp:67
static CK_TILE_DEVICE float exchange_quant_value_across_lanes(float scale_reg, index_t pull_from_lane)
Definition: block_gemm_quant_common.hpp:73
CK_TILE_DEVICE float pick()
Definition: block_gemm_quant_common.hpp:115
CK_TILE_DEVICE AQPickerCommon(AQBlockTensor &aq_block_tensor_)
Definition: block_gemm_quant_common.hpp:94
AQBlockTensor & aq_block_tensor
Definition: block_gemm_quant_common.hpp:223
remove_cvref_t< GemmTraits_ > Traits
Definition: block_gemm_quant_common.hpp:69
remove_cvref_t< typename Traits::AQDataType > AQDataType
Definition: block_gemm_quant_common.hpp:71
float scale_reg_f
Definition: block_gemm_quant_common.hpp:224
remove_cvref_t< typename Traits::WarpGemm > WarpGemm
Definition: block_gemm_quant_common.hpp:70
Definition: block_gemm_quant_common.hpp:39
static CK_TILE_DEVICE float cvt_scale_to_fp32(T scale)
Definition: block_gemm_quant_common.hpp:41
Definition: block_gemm_quant_common.hpp:18
static constexpr CK_TILE_DEVICE auto MakeCBlockTile()
Definition: block_gemm_quant_common.hpp:19
Definition: sequence.hpp:49
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192