/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp Source File
tile_gemm_quant_traits.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include <cstdint>
8 
9 namespace ck_tile {
10 
11 enum struct QuantType : std::uint16_t
12 {
13  AQuantGrouped = 0,
14  BQuantGrouped = 1,
15  RowColQuant = 2
16 };
17 
18 template <bool kPadM_,
19  bool kPadN_,
20  bool kPadK_,
21  bool PreshuffleQuant_,
22  typename ALayout_,
23  typename BLayout_,
24  typename CLayout_,
25  QuantType QuantType_,
26  typename AQLayout_ = ALayout_,
27  typename BQLayout_ = BLayout_>
29 {
30  static constexpr bool kPadM = kPadM_;
31  static constexpr bool kPadN = kPadN_;
32  static constexpr bool kPadK = kPadK_;
33 
34  static constexpr QuantType kQuantType = QuantType_;
35 
36  static constexpr int _VectorSize = 16;
37 
38  using ALayout = ALayout_;
39  using BLayout = BLayout_;
40  using CLayout = CLayout_;
41  using AQLayout = AQLayout_;
42  using BQLayout = BQLayout_;
43 
44  static constexpr bool TransposeC = false;
45  static constexpr bool UseStructuredSparsity = false;
46  static constexpr index_t NumWaveGroups = 1;
47 
48  static constexpr bool PreshuffleQuant = PreshuffleQuant_;
49 };
50 
51 } // namespace ck_tile
ck_tile
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
QuantType
Definition: tile_gemm_quant_traits.hpp:12
unsigned short uint16_t
Definition: stdint.h:125
TileGemmQuantTraits
Definition: tile_gemm_quant_traits.hpp:29
static constexpr index_t NumWaveGroups
Definition: tile_gemm_quant_traits.hpp:46
BQLayout_ BQLayout
Definition: tile_gemm_quant_traits.hpp:42
ALayout_ ALayout
Definition: tile_gemm_quant_traits.hpp:38
AQLayout_ AQLayout
Definition: tile_gemm_quant_traits.hpp:41
static constexpr int _VectorSize
Definition: tile_gemm_quant_traits.hpp:36
static constexpr bool kPadN
Definition: tile_gemm_quant_traits.hpp:31
static constexpr bool kPadM
Definition: tile_gemm_quant_traits.hpp:30
static constexpr QuantType kQuantType
Definition: tile_gemm_quant_traits.hpp:34
static constexpr bool TransposeC
Definition: tile_gemm_quant_traits.hpp:44
static constexpr bool UseStructuredSparsity
Definition: tile_gemm_quant_traits.hpp:45
static constexpr bool kPadK
Definition: tile_gemm_quant_traits.hpp:32
static constexpr bool PreshuffleQuant
Definition: tile_gemm_quant_traits.hpp:48
BLayout_ BLayout
Definition: tile_gemm_quant_traits.hpp:39
CLayout_ CLayout
Definition: tile_gemm_quant_traits.hpp:40