/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp Source File
tile_gemm_traits.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 
8 namespace ck_tile {
9 
10 template <bool kPadM_,
11  bool kPadN_,
12  bool kPadK_,
13  typename AsLayout_,
14  typename BsLayout_,
15  typename CLayout_,
16  index_t NumWaveGroups_ = 1>
18 {
19  static constexpr bool kPadM = kPadM_;
20  static constexpr bool kPadN = kPadN_;
21  static constexpr bool kPadK = kPadK_;
22 
23  // TODO this can't be hardcoded here! Should be in policy!
24  static constexpr int _VectorSize = 16;
25 
26  using AsLayout = AsLayout_;
27  using BsLayout = BsLayout_;
28  using CLayout = CLayout_;
29 
30  static constexpr bool TransposeC = false;
31  static constexpr bool UseStructuredSparsity = false;
32  static constexpr index_t NumWaveGroups = NumWaveGroups_;
33 };
34 
35 template <bool kPadM_,
36  bool kPadN_,
37  bool kPadK_,
38  bool DoubleSmemBuffer_,
39  typename AsLayout_,
40  typename BsLayout_,
41  typename CLayout_,
42  bool TransposeC_ = false,
43  bool UseStructuredSparsity_ = false,
44  bool UsePersistentKernel_ = false,
45  index_t NumWaveGroups_ = 1,
46  bool Preshuffle_ = false,
47  int VectorSize_ = 16>
49 {
50  static constexpr bool kPadM = kPadM_;
51  static constexpr bool kPadN = kPadN_;
52  static constexpr bool kPadK = kPadK_;
53  static constexpr int _VectorSize = VectorSize_;
54  static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
55 
56  using AsLayout = AsLayout_;
57  using BsLayout = BsLayout_;
58  using CLayout = CLayout_;
59  static constexpr bool TransposeC = TransposeC_;
60 
61  static constexpr bool UseStructuredSparsity = UseStructuredSparsity_;
62  static constexpr bool UsePersistentKernel = UsePersistentKernel_;
63  static constexpr index_t NumWaveGroups = NumWaveGroups_;
64  static constexpr bool Preshuffle = Preshuffle_;
65 };
66 
67 template <bool kPadM_,
68  bool kPadN_,
69  bool kPadK_,
70  bool DoubleSmemBuffer_,
71  typename AsLayout_,
72  typename BsLayout_,
73  typename CLayout_,
74  bool TransposeC_ = false,
75  bool UseStructuredSparsity_ = false>
77  kPadN_,
78  kPadK_,
79  DoubleSmemBuffer_,
80  AsLayout_,
81  BsLayout_,
82  CLayout_,
83  TransposeC_,
84  UseStructuredSparsity_,
85  true>;
86 
87 } // namespace ck_tile
ck_tile
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
ck_tile::TileGemmTraits
Definition: tile_gemm_traits.hpp:18
static constexpr index_t NumWaveGroups
Definition: tile_gemm_traits.hpp:32
CLayout_ CLayout
Definition: tile_gemm_traits.hpp:28
static constexpr bool TransposeC
Definition: tile_gemm_traits.hpp:30
AsLayout_ AsLayout
Definition: tile_gemm_traits.hpp:26
static constexpr bool kPadM
Definition: tile_gemm_traits.hpp:19
static constexpr int _VectorSize
Definition: tile_gemm_traits.hpp:24
static constexpr bool kPadN
Definition: tile_gemm_traits.hpp:20
static constexpr bool UseStructuredSparsity
Definition: tile_gemm_traits.hpp:31
static constexpr bool kPadK
Definition: tile_gemm_traits.hpp:21
BsLayout_ BsLayout
Definition: tile_gemm_traits.hpp:27
ck_tile::TileGemmUniversalTraits
Definition: tile_gemm_traits.hpp:49
static constexpr bool Preshuffle
Definition: tile_gemm_traits.hpp:64
BsLayout_ BsLayout
Definition: tile_gemm_traits.hpp:57
static constexpr index_t NumWaveGroups
Definition: tile_gemm_traits.hpp:63
static constexpr bool UsePersistentKernel
Definition: tile_gemm_traits.hpp:62
static constexpr bool UseStructuredSparsity
Definition: tile_gemm_traits.hpp:61
AsLayout_ AsLayout
Definition: tile_gemm_traits.hpp:56
static constexpr int _VectorSize
Definition: tile_gemm_traits.hpp:53
static constexpr bool DoubleSmemBuffer
Definition: tile_gemm_traits.hpp:54
CLayout_ CLayout
Definition: tile_gemm_traits.hpp:58
static constexpr bool kPadK
Definition: tile_gemm_traits.hpp:52
static constexpr bool kPadN
Definition: tile_gemm_traits.hpp:51
static constexpr bool TransposeC
Definition: tile_gemm_traits.hpp:59
static constexpr bool kPadM
Definition: tile_gemm_traits.hpp:50