/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp Source File
tile_gemm_quant_traits.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
#include "ck_tile/core.hpp"
#include <cstdint>
#include <string>
8 
9 namespace ck_tile {
10 
/// @brief Selects which quantization scheme a quantized GEMM is configured for.
///
/// The underlying type is fixed to std::uint16_t and every enumerator carries an
/// explicit value, so the numeric encoding does not shift if the declaration
/// order is ever changed.
///
/// NOTE(review): the precise meaning of each scheme is defined by the pipelines
/// that consume this enum, not by this header — consult them before relying on
/// anything beyond the names.
enum struct QuantType : std::uint16_t
{
    AQuantGrouped  = 0,
    BQuantGrouped  = 1,
    RowColQuant    = 2,
    TensorQuant    = 3,
    ABQuantGrouped = 4
};
19 
20 inline std::string quant_type_to_string(QuantType quant_type)
21 {
22  switch(quant_type)
23  {
24  case QuantType::AQuantGrouped: return "AQuantGrouped";
25  case QuantType::BQuantGrouped: return "BQuantGrouped";
26  case QuantType::RowColQuant: return "RowColQuant";
27  case QuantType::TensorQuant: return "TensorQuant";
28  case QuantType::ABQuantGrouped: return "ABQuantGrouped";
29  default: return "Unknown";
30  }
31 }
32 
33 template <bool kPadM_,
34  bool kPadN_,
35  bool kPadK_,
36  bool PreshuffleQuant_,
37  bool PreshuffleB_,
38  typename ALayout_,
39  typename BLayout_,
40  typename CLayout_,
41  QuantType QuantType_,
42  typename AQLayout_ = ALayout_,
43  typename BQLayout_ = BLayout_,
44  bool TransposeC_ = false,
45  bool DoubleSmemBuffer_ = false,
46  bool UsePersistentKernel_ = false,
47  int VectorSize_ = 16>
49 {
50  static constexpr bool kPadM = kPadM_;
51  static constexpr bool kPadN = kPadN_;
52  static constexpr bool kPadK = kPadK_;
53 
54  static constexpr QuantType kQuantType = QuantType_;
55 
56  static constexpr int _VectorSize = VectorSize_;
57  static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
58 
59  using ALayout = ALayout_;
60  using BLayout = BLayout_;
61  using CLayout = CLayout_;
62  using AQLayout = AQLayout_;
63  using BQLayout = BQLayout_;
64 
65  // TODO: It should be replaced to single value
66  using AsLayout = ALayout_;
67  using BsLayout = BLayout_;
68 
69  static constexpr bool TransposeC = TransposeC_;
70  static constexpr bool UseStructuredSparsity = false;
71  static constexpr index_t NumWaveGroups = 1;
72  static constexpr bool UsePersistentKernel = UsePersistentKernel_;
73 
74  static constexpr bool PreshuffleQuant = PreshuffleQuant_;
75  static constexpr bool PreshuffleB = PreshuffleB_;
76 };
77 
78 } // namespace ck_tile
ck_tile
Definition: cluster_descriptor.hpp:13
std::string quant_type_to_string(QuantType quant_type)
Definition: tile_gemm_quant_traits.hpp:20
int32_t index_t
Definition: integer.hpp:9
QuantType
Definition: tile_gemm_quant_traits.hpp:12
unsigned short uint16_t
Definition: stdint.h:125
ck_tile::TileGemmQuantTraits
Definition: tile_gemm_quant_traits.hpp:49
static constexpr bool kPadN
Definition: tile_gemm_quant_traits.hpp:51
static constexpr bool UsePersistentKernel
Definition: tile_gemm_quant_traits.hpp:72
AQLayout_ AQLayout
Definition: tile_gemm_quant_traits.hpp:62
CLayout_ CLayout
Definition: tile_gemm_quant_traits.hpp:61
BLayout_ BLayout
Definition: tile_gemm_quant_traits.hpp:60
ALayout_ ALayout
Definition: tile_gemm_quant_traits.hpp:59
static constexpr bool TransposeC
Definition: tile_gemm_quant_traits.hpp:69
static constexpr bool PreshuffleQuant
Definition: tile_gemm_quant_traits.hpp:74
BLayout_ BsLayout
Definition: tile_gemm_quant_traits.hpp:67
static constexpr index_t NumWaveGroups
Definition: tile_gemm_quant_traits.hpp:71
static constexpr bool kPadM
Definition: tile_gemm_quant_traits.hpp:50
static constexpr bool PreshuffleB
Definition: tile_gemm_quant_traits.hpp:75
BQLayout_ BQLayout
Definition: tile_gemm_quant_traits.hpp:63
static constexpr bool DoubleSmemBuffer
Definition: tile_gemm_quant_traits.hpp:57
static constexpr bool kPadK
Definition: tile_gemm_quant_traits.hpp:52
static constexpr QuantType kQuantType
Definition: tile_gemm_quant_traits.hpp:54
static constexpr int _VectorSize
Definition: tile_gemm_quant_traits.hpp:56
static constexpr bool UseStructuredSparsity
Definition: tile_gemm_quant_traits.hpp:70
ALayout_ AsLayout
Definition: tile_gemm_quant_traits.hpp:66