// Fragment of the C block tile factory; the cross-reference index of this
// listing names it MakeCBlockTile (block_gemm_quant_common.hpp:19). Several
// original lines are absent here, so only the visible steps are described.
11 template <
typename CDataType,
12 typename WarpGemmType,
// Tail of a call whose arguments are the block-level outer encoding and a
// default-constructed WarpGemmType::CWarpDstrEncoding; the callee's name is on
// a line that is not part of this listing.
30 c_block_outer_dstr_encoding,
typename WarpGemmType::CWarpDstrEncoding{});
// Builds the result with make_static_distributed_tensor<CDataType> from
// c_block_dstr (defined on a line not shown) and returns it.
32 auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
34 return c_block_tensor;
// Fragment of the scale-to-float helper; the cross-reference index lists it as
// `static CK_TILE_DEVICE float cvt_scale_to_fp32(T scale)`. The listing omits
// some original lines (braces, the trailing `else`, the return statement).
// The conversion applied to `scale` is chosen at compile time from QDataType.
40 template <
typename QDataType,
typename T>
// Result variable; the type-specific branches below assign it.
43 float scale_reg_f = 0.f;
44 if constexpr(std::is_same_v<QDataType, ck_tile::fp8_t>)
// fp8 path: `scale` is cast to uint32_t and handed to the AMDGPU fp8->f32
// builtin. NOTE(review): the second argument 0 is presumably the byte selector
// (lowest byte of the dword) - confirm against the builtin's documentation.
46 scale_reg_f = __builtin_amdgcn_cvt_f32_fp8(
static_cast<uint32_t>(scale), 0);
48 else if constexpr(std::is_same_v<QDataType, ck_tile::bf8_t>)
// bf8 path: same shape as the fp8 path, using the bf8->f32 builtin.
50 scale_reg_f = __builtin_amdgcn_cvt_f32_bf8(
static_cast<uint32_t>(scale), 0);
52 else if constexpr(std::is_same_v<QDataType, float>)
// float path: `scale` goes through ck_tile::bit_cast<float>, not a numeric cast.
54 scale_reg_f = ck_tile::bit_cast<float>(scale);
// The condition below is always false but depends on QDataType, so the
// assertion can only fire when the branch containing it is instantiated
// (presumably the final `else`, whose line is not shown in this listing).
58 static_assert(!std::is_same_v<QDataType, QDataType>,
59 "QDataType must be float, fp8_t or bf8_t.");
// Template header of the enclosing picker type (the cross-reference index
// names it AQPickerCommon) followed by a fragment of
// `exchange_quant_value_across_lanes(float scale_reg, index_t pull_from_lane)`.
// Lines between the visible ones are missing from this listing.
65 template <
typename AQBlockTensor,
typename GemmTraits_,
int32_t mIter,
int32_t kQScale>
// Pack the scale into a 32-bit word: bit_cast when AQDataType is float,
// otherwise a static_cast to uint32_t (the `else` line itself is not shown).
79 if constexpr(std::is_same_v<AQDataType, float>)
81 scale_reg_dword = ck_tile::bit_cast<uint32_t>(scale_reg);
85 scale_reg_dword =
static_cast<uint32_t>(scale_reg);
// Fetch the dword from another lane with ds_bpermute. The lane index is shifted
// left by 2 (multiplied by 4) to form the first argument.
// NOTE(review): presumably because the instruction addresses lanes in bytes -
// confirm against the ISA documentation.
88 int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
89 pull_from_lane << 2, __builtin_bit_cast(
int, scale_reg_dword));
// The gathered word is converted to float by the base-class helper for AQDataType.
90 return Base::cvt_scale_to_fp32<typename Traits::AQDataType>(gathered_scale_reg);
// Fragment of the AQPickerCommon constructor body (the cross-reference index
// places the constructor at block_gemm_quant_common.hpp:94). Only some of its
// lines survive in this listing.
96 if constexpr(Traits::TransposeC)
// Right-hand side of an offset computation whose left-hand side is not shown:
// mIter when PreshuffleQuant is set, otherwise mIter * AQPerBlock + kQScale.
99 Traits::PreshuffleQuant ? mIter : mIter * Traits::AQPerBlock + kQScale;
101 if constexpr(Traits::PreshuffleQuant)
// Source lane for the cross-lane read: the lane id masked with (kN - 1),
// scaled by AQPerBlock, plus kQScale.
// NOTE(review): the mask equals `lane % kN` only if kN is a power of two - confirm.
103 auto pull_from_lane =
104 (__lane_id() & (Traits::WarpGemm::kN - 1)) * Traits::AQPerBlock + kQScale;
// Stores the float form of scale_reg in the member scale_reg_f.
110 scale_reg_f = Base::cvt_scale_to_fp32<typename Traits::AQDataType>(scale_reg);
// Fragment of `CK_TILE_DEVICE float pick()` (block_gemm_quant_common.hpp:115
// per the cross-reference index). c_row is a compile-time parameter defaulting
// to 0. Many original lines are missing from this listing.
// NOTE(review): the parameter type below appears split by the extraction
// ("u" + "int32_t"); presumably `uint32_t` - confirm against the original header.
114 template <u
int32_t c_row = 0>
117 if constexpr(Traits::TransposeC)
124 if constexpr(Traits::PreshuffleQuant)
// Lane to read from; declared with the type of threadIdx.x and zero-initialised
// before one of the kM-specific formulas below assigns it.
153 decltype(threadIdx.x) pull_from_lane = 0;
154 if constexpr(WarpGemm::kM == 16)
// kM == 16: (lane / kN * kTileRowsOfCPerThread + c_row) scaled by
// QScalesPerBlockRow; the final addend is on a line not shown.
157 (__lane_id() / Traits::WarpGemm::kN * kTileRowsOfCPerThread + c_row) *
158 Traits::QScalesPerBlockRow +
161 else if constexpr(WarpGemm::kM == 32)
// kM == 32: c_row is remapped as (c_row / 4) * 8 + (c_row % 4) before being
// added to the lane term; the final addend is again on a line not shown.
163 pull_from_lane = (__lane_id() / Traits::WarpGemm::kN * kTileRowsOfCPerThread +
164 ((c_row >> 2) << 3) + (c_row & 0b11)) *
165 Traits::QScalesPerBlockRow +
// NOTE(review): a literal `false` is not dependent on a template parameter.
// Compilers that do not implement CWG2518 reject it even when this branch is
// never instantiated; cvt_scale_to_fp32 in this file uses a dependent
// condition for its equivalent assertion.
170 static_assert(
false,
"WarpGemm::kM is not 16 nor 32.");
// Compile-time register offsets: the block for this mIter, then the entry
// for kQScale inside that block.
196 constexpr
index_t reg_block_offset = mIter * Traits::AQPerBlock;
197 constexpr
index_t src_reg_offset = reg_block_offset + kQScale;
// Row offset contributed by c_row's group: (c_row / kCM1PerLane) groups, each
// spanning kCMLane * kCM1PerLane rows.
204 constexpr
index_t m_base_offset_of_c_row =
205 (c_row / WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane) *
206 (WarpGemm::kCMLane * WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane);
// Row offset contributed by the calling lane; evaluated at run time because it
// depends on get_lane_id().
209 index_t m_base_offset_of_lane =
210 (get_lane_id() / WarpGemm::kN * WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane);
// Position of c_row inside its group, taken with a mask.
// NOTE(review): equals c_row % kCM1PerLane only if kCM1PerLane is a power of
// two - confirm.
213 constexpr
index_t m_offset_of_c_row =
214 c_row & (WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane - 1);
// Sum of the three row-offset components; the assignment target is on a line
// not shown.
217 m_base_offset_of_c_row + m_base_offset_of_lane + m_offset_of_c_row;
#define CK_TILE_DEVICE
Definition: config.hpp:45
constexpr CK_TILE_HOST_DEVICE auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition: tile_distribution_encoding.hpp:457
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:495
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
unsigned int uint32_t
Definition: stdint.h:126
Definition: block_gemm_quant_common.hpp:67
static CK_TILE_DEVICE float exchange_quant_value_across_lanes(float scale_reg, index_t pull_from_lane)
Definition: block_gemm_quant_common.hpp:73
CK_TILE_DEVICE float pick()
Definition: block_gemm_quant_common.hpp:115
CK_TILE_DEVICE AQPickerCommon(AQBlockTensor &aq_block_tensor_)
Definition: block_gemm_quant_common.hpp:94
AQBlockTensor & aq_block_tensor
Definition: block_gemm_quant_common.hpp:223
remove_cvref_t< GemmTraits_ > Traits
Definition: block_gemm_quant_common.hpp:69
remove_cvref_t< typename Traits::AQDataType > AQDataType
Definition: block_gemm_quant_common.hpp:71
float scale_reg_f
Definition: block_gemm_quant_common.hpp:224
remove_cvref_t< typename Traits::WarpGemm > WarpGemm
Definition: block_gemm_quant_common.hpp:70
Definition: block_gemm_quant_common.hpp:39
static CK_TILE_DEVICE float cvt_scale_to_fp32(T scale)
Definition: block_gemm_quant_common.hpp:41
Definition: block_gemm_quant_common.hpp:18
static constexpr CK_TILE_DEVICE auto MakeCBlockTile()
Definition: block_gemm_quant_common.hpp:19
Definition: sequence.hpp:49
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192