/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp Source File
warp_gemm_attribute_smfmac_impl.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 
8 namespace ck_tile {
9 
10 // fp16 2:4 structured sparsity
11 
// Warp-GEMM attribute struct wrapping the 32x32x16 fp16 smfmac builtin
// (fp16 A/B element types, float accumulator element type).
// NOTE(review): this rendered listing omits embedded source lines 13, 18, 21-23 and 46.
// Per the cross-reference index at the end of this page those are the struct declaration,
// IdxDataType (int32_t), AVecType (ext_vector_t<fp16_t, 4>), BVecType (ext_vector_t<fp16_t, 8>),
// CVecType (ext_vector_t<float, 16>) and the `CK_TILE_DEVICE void operator()(CVecType& c_vec,`
// line -- confirm against the original header.
12 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
14 {
    // Re-exports the template argument; it is not read anywhere in the visible body.
15  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
16  using ADataType = fp16_t;
17  using BDataType = fp16_t;
19  using CDataType = float;
20 
24 
    // Tile extents; they match the 32x32x16 suffix of the builtin called in operator().
25  static constexpr index_t kM = 32;
26  static constexpr index_t kN = 32;
27  static constexpr index_t kK = 16;
28 
29  static constexpr index_t kAMBlock = 1;
30  static constexpr index_t kBNBlock = 1;
31 
    // A/B lane decomposition: kABKLane * kABKPerLane == kK (2 * 8 == 16) and
    // kAMLane * kABKLane == 64 -- presumably one 64-lane wavefront; TODO confirm.
32  static constexpr index_t kAMLane = 32;
33  static constexpr index_t kBNLane = 32;
34  static constexpr index_t kABKLane = 2;
35  static constexpr index_t kABKPerLane = 8;
36 
    // C lane decomposition: kCMLane * kCM0PerLane * kCM1PerLane == kM (2 * 4 * 4 == 32).
37  static constexpr index_t kCMLane = 2;
38  static constexpr index_t kCNLane = 32;
39  static constexpr index_t kCM0PerLane = 4;
40  static constexpr index_t kCM1PerLane = 4;
41 
    // Presumably the 2:4 sparsity compression factor applied to A
    // (kABKPerLane / CompressionRatio == 4, the AVecType length in the index) -- TODO confirm.
42  static constexpr index_t CompressionRatio = 2;
43 
    // Accumulates into c_vec in place by forwarding a_vec, b_vec, c_vec and idx to the
    // smfmac builtin. The trailing bool_constant<post_nop_> parameter is unnamed and
    // unused in the visible body.
44  // c_vec += a_vec * b_vec[idx]
45  template <bool post_nop_ = false>
47  const AVecType& a_vec,
48  const BVecType& b_vec,
49  const int32_t& idx,
50  bool_constant<post_nop_> = {}) const
51  {
    // NOTE(review): the guard macros end in a single underscore (__gfx94_ / __gfx95_).
    // Confirm the project config defines exactly these names; if not, the builtin branch
    // is never compiled and this operator silently does nothing.
52 #if defined(__gfx94_) or defined(__gfx95_)
    // The last two builtin arguments are passed as literal 0.
53  c_vec = __builtin_amdgcn_smfmac_f32_32x32x16_f16(a_vec, b_vec, c_vec, idx, 0, 0);
54 #else
    // Fallback: every argument is discarded through ck_tile::ignore; c_vec is left unchanged.
55  ck_tile::ignore = c_vec;
56  ck_tile::ignore = a_vec;
57  ck_tile::ignore = b_vec;
58  ck_tile::ignore = idx;
59 #endif
60  }
61 };
62 
// Warp-GEMM attribute struct wrapping the 16x16x32 fp16 smfmac builtin
// (fp16 A/B element types, float accumulator element type).
// NOTE(review): this rendered listing omits embedded source lines 64, 69, 72-74 and 97.
// Per the cross-reference index at the end of this page those are the struct declaration,
// IdxDataType (int32_t), AVecType (ext_vector_t<fp16_t, 4>), BVecType (ext_vector_t<fp16_t, 8>),
// CVecType (ext_vector_t<float, 4>) and the `CK_TILE_DEVICE void operator()(CVecType& c_vec,`
// line -- confirm against the original header.
63 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
65 {
    // Re-exports the template argument; it is not read anywhere in the visible body.
66  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
67  using ADataType = fp16_t;
68  using BDataType = fp16_t;
70  using CDataType = float;
71 
75 
    // Tile extents; they match the 16x16x32 suffix of the builtin called in operator().
76  static constexpr index_t kM = 16;
77  static constexpr index_t kN = 16;
78  static constexpr index_t kK = 32;
79 
80  static constexpr index_t kAMBlock = 1;
81  static constexpr index_t kBNBlock = 1;
82 
    // A/B lane decomposition: kABKLane * kABKPerLane == kK (4 * 8 == 32) and
    // kAMLane * kABKLane == 64 -- presumably one 64-lane wavefront; TODO confirm.
83  static constexpr index_t kAMLane = 16;
84  static constexpr index_t kBNLane = 16;
85  static constexpr index_t kABKLane = 4;
86  static constexpr index_t kABKPerLane = 8;
87 
    // C lane decomposition: kCMLane * kCM0PerLane * kCM1PerLane == kM (4 * 1 * 4 == 16).
88  static constexpr index_t kCMLane = 4;
89  static constexpr index_t kCNLane = 16;
90  static constexpr index_t kCM0PerLane = 1;
91  static constexpr index_t kCM1PerLane = 4;
92 
    // Presumably the 2:4 sparsity compression factor applied to A
    // (kABKPerLane / CompressionRatio == 4, the AVecType length in the index) -- TODO confirm.
93  static constexpr index_t CompressionRatio = 2;
94 
    // Accumulates into c_vec in place by forwarding a_vec, b_vec, c_vec and idx to the
    // smfmac builtin. The trailing bool_constant<post_nop_> parameter is unnamed and
    // unused in the visible body.
95  // c_vec += a_vec * b_vec[idx]
96  template <bool post_nop_ = false>
98  const AVecType& a_vec,
99  const BVecType& b_vec,
100  const int32_t& idx,
101  bool_constant<post_nop_> = {}) const
102  {
    // NOTE(review): the guard macros end in a single underscore (__gfx94_ / __gfx95_).
    // Confirm the project config defines exactly these names; if not, the builtin branch
    // is never compiled and this operator silently does nothing.
103 #if defined(__gfx94_) or defined(__gfx95_)
    // The last two builtin arguments are passed as literal 0.
104  c_vec = __builtin_amdgcn_smfmac_f32_16x16x32_f16(a_vec, b_vec, c_vec, idx, 0, 0);
105 #else
    // Fallback: every argument is discarded through ck_tile::ignore; c_vec is left unchanged.
106  ck_tile::ignore = c_vec;
107  ck_tile::ignore = a_vec;
108  ck_tile::ignore = b_vec;
109  ck_tile::ignore = idx;
110 #endif
111  }
112 };
113 
114 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
WGAttrCtlEnum
Definition: warp_gemm_attribute_mfma_impl.hpp:15
_Float16 fp16_t
Definition: half.hpp:110
int32_t index_t
Definition: integer.hpp:9
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
typename impl::ext_vector< T, N >::type ext_vector_t
Definition: vector_type.hpp:83
int32_t int32_t
Definition: integer.hpp:10
Definition: warp_gemm_attribute_smfmac_impl.hpp:65
ext_vector_t< fp16_t, 8 > BVecType
Definition: warp_gemm_attribute_smfmac_impl.hpp:73
int32_t IdxDataType
Definition: warp_gemm_attribute_smfmac_impl.hpp:69
static constexpr index_t kM
Definition: warp_gemm_attribute_smfmac_impl.hpp:76
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_smfmac_impl.hpp:83
float CDataType
Definition: warp_gemm_attribute_smfmac_impl.hpp:70
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_smfmac_impl.hpp:86
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_smfmac_impl.hpp:90
static constexpr index_t CompressionRatio
Definition: warp_gemm_attribute_smfmac_impl.hpp:93
fp16_t ADataType
Definition: warp_gemm_attribute_smfmac_impl.hpp:67
ext_vector_t< fp16_t, 4 > AVecType
Definition: warp_gemm_attribute_smfmac_impl.hpp:72
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, const int32_t &idx, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_smfmac_impl.hpp:97
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_smfmac_impl.hpp:91
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_smfmac_impl.hpp:66
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_smfmac_impl.hpp:80
static constexpr index_t kN
Definition: warp_gemm_attribute_smfmac_impl.hpp:77
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_smfmac_impl.hpp:88
fp16_t BDataType
Definition: warp_gemm_attribute_smfmac_impl.hpp:68
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_smfmac_impl.hpp:89
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_smfmac_impl.hpp:84
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_smfmac_impl.hpp:74
static constexpr index_t kK
Definition: warp_gemm_attribute_smfmac_impl.hpp:78
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_smfmac_impl.hpp:85
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_smfmac_impl.hpp:81
Definition: warp_gemm_attribute_smfmac_impl.hpp:14
static constexpr index_t kN
Definition: warp_gemm_attribute_smfmac_impl.hpp:26
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_smfmac_impl.hpp:35
static constexpr index_t kK
Definition: warp_gemm_attribute_smfmac_impl.hpp:27
ext_vector_t< fp16_t, 4 > AVecType
Definition: warp_gemm_attribute_smfmac_impl.hpp:21
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_smfmac_impl.hpp:40
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, const int32_t &idx, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_smfmac_impl.hpp:46
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_smfmac_impl.hpp:34
static constexpr index_t kM
Definition: warp_gemm_attribute_smfmac_impl.hpp:25
ext_vector_t< float, 16 > CVecType
Definition: warp_gemm_attribute_smfmac_impl.hpp:23
float CDataType
Definition: warp_gemm_attribute_smfmac_impl.hpp:19
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_smfmac_impl.hpp:39
fp16_t BDataType
Definition: warp_gemm_attribute_smfmac_impl.hpp:17
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_smfmac_impl.hpp:32
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_smfmac_impl.hpp:38
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_smfmac_impl.hpp:37
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_smfmac_impl.hpp:33
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_smfmac_impl.hpp:30
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_smfmac_impl.hpp:29
int32_t IdxDataType
Definition: warp_gemm_attribute_smfmac_impl.hpp:18
ext_vector_t< fp16_t, 8 > BVecType
Definition: warp_gemm_attribute_smfmac_impl.hpp:22
fp16_t ADataType
Definition: warp_gemm_attribute_smfmac_impl.hpp:16
static constexpr index_t CompressionRatio
Definition: warp_gemm_attribute_smfmac_impl.hpp:42
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_smfmac_impl.hpp:15
Definition: integral_constant.hpp:13