/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp Source File
warp_gemm_smfmac_impl.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 namespace ck_tile {
8 
9 template <typename WarpGemmAttribute_>
11 {
13 
// Warp-level tile sizes (M, N, K), forwarded unchanged from the warp-GEMM attribute.
14  static constexpr index_t kM = WarpGemmAttribute::kM;
15  static constexpr index_t kN = WarpGemmAttribute::kN;
16  static constexpr index_t kK = WarpGemmAttribute::kK;
// The number of elements in the K dimension processed by a single thread in a wavefront.
21  static constexpr index_t kKPerThread = WarpGemmAttribute::kKPerThread;
22 
// Element types of the A, B and C operands, forwarded from the attribute.
23  using ADataType = typename WarpGemmAttribute::ADataType;
24  using BDataType = typename WarpGemmAttribute::BDataType;
25  using CDataType = typename WarpGemmAttribute::CDataType;
26 
// Per-warp tile distribution encodings for A, B and C, forwarded from the attribute.
27  using AWarpDstrEncoding = typename WarpGemmAttribute::AWarpDstrEncoding;
28  using BWarpDstrEncoding = typename WarpGemmAttribute::BWarpDstrEncoding;
29  using CWarpDstrEncoding = typename WarpGemmAttribute::CWarpDstrEncoding;
30 
34 
38 
39  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access()
40  {
41  return WarpGemmAttribute_::get_num_of_access();
42  }
43 
44  //----------------------------------------------------------------------------------------------
52  template <typename AVec>
53  CK_TILE_DEVICE int32_t compress_a(AVec& a_vec) const
54  {
55  int32_t idx = 0b11101110;
56 
57  static_for<0, 2, 1>{}([&](auto i) {
58  ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]};
59  int32_t non_zero_pos = 0;
60 
61  static_for<0, 3, 1>{}([&](auto j) {
62  if(a_vec[i * 4 + j] != 0.0f)
63  {
64  nonzero_elems[non_zero_pos] = a_vec[i * 4 + j];
65  idx &= ~(0b11 << 2 * (i * 2 + non_zero_pos));
66  idx |= j << 2 * (i * 2 + non_zero_pos);
67  ++non_zero_pos;
68  }
69  });
70  a_vec[i * 2] = nonzero_elems[0];
71  a_vec[i * 2 + 1] = nonzero_elems[1];
72  });
73 
74  return idx;
75  }
76 
77  template <typename CTensor, typename ATensor, typename BTensor, bool post_nop_ = false>
78  CK_TILE_DEVICE void
79  operator()(CTensor& c, const ATensor& a, const BTensor& b, bool_constant<post_nop_> = {}) const
80  {
81  static_assert(detail::is_similiar_distributed_tensor_v<CTensor, CWarpTensor> &&
82  detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
83  detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
84  constexpr auto CompressionRatio = WarpGemmAttribute::kCompressionRatio;
85 
86  using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
87  using AVecCompressed =
88  ext_vector_t<ADataType, ATensor::get_thread_buffer_size() / CompressionRatio>;
89  using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
90  using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
91 
92  constexpr auto I0 = number<0>{};
93 
94  auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
95  const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
96  auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
97 
98  const int32_t idx = compress_a(a_vec);
99 
100  // @TODO can we simply set a_vec_pruned to a_vec[0:3]?
101  const AVecCompressed a_vec_pruned = {a_vec[0], a_vec[1], a_vec[2], a_vec[3]};
102 
103  // c_vec += a_vec * b_vec[idx]
104  WarpGemmAttribute{}(c_vec, a_vec_pruned, b_vec, idx, bool_constant<post_nop_>{});
105 
106  c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
107  }
108 };
109 
110 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
typename impl::ext_vector< T, N >::type ext_vector_t
Definition: vector_type.hpp:83
int32_t int32_t
Definition: integer.hpp:10
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:480
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1249
signed int int32_t
Definition: stdint.h:123
Definition: warp_gemm_smfmac_impl.hpp:11
CK_TILE_DEVICE void operator()(CTensor &c, const ATensor &a, const BTensor &b, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_smfmac_impl.hpp:79
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_smfmac_impl.hpp:39
static constexpr index_t kK
Definition: warp_gemm_smfmac_impl.hpp:16
typename WarpGemmAttribute::BDataType BDataType
Definition: warp_gemm_smfmac_impl.hpp:24
typename WarpGemmAttribute::BWarpDstrEncoding BWarpDstrEncoding
Definition: warp_gemm_smfmac_impl.hpp:28
typename WarpGemmAttribute::ADataType ADataType
Definition: warp_gemm_smfmac_impl.hpp:23
remove_cvref_t< decltype(make_static_tile_distribution(BWarpDstrEncoding{}))> BWarpDstr
Definition: warp_gemm_smfmac_impl.hpp:32
static constexpr index_t kM
Definition: warp_gemm_smfmac_impl.hpp:14
remove_cvref_t< decltype(make_static_tile_distribution(CWarpDstrEncoding{}))> CWarpDstr
Definition: warp_gemm_smfmac_impl.hpp:33
static constexpr index_t kKPerThread
The number of elements in the K dimension processed by a single thread in a wavefront.
Definition: warp_gemm_smfmac_impl.hpp:21
typename WarpGemmAttribute::CWarpDstrEncoding CWarpDstrEncoding
Definition: warp_gemm_smfmac_impl.hpp:29
CK_TILE_DEVICE int32_t compress_a(AVec &a_vec) const
Compress A vector for 2:4 structured sparsity instruction by moving all non-zero elements into the lower half of a_vec, halving its effective size.
Definition: warp_gemm_smfmac_impl.hpp:53
typename WarpGemmAttribute::CDataType CDataType
Definition: warp_gemm_smfmac_impl.hpp:25
remove_cvref_t< decltype(make_static_tile_distribution(AWarpDstrEncoding{}))> AWarpDstr
Definition: warp_gemm_smfmac_impl.hpp:31
typename WarpGemmAttribute::AWarpDstrEncoding AWarpDstrEncoding
Definition: warp_gemm_smfmac_impl.hpp:27
static constexpr index_t kN
Definition: warp_gemm_smfmac_impl.hpp:15
remove_cvref_t< WarpGemmAttribute_ > WarpGemmAttribute
Definition: warp_gemm_smfmac_impl.hpp:12
Definition: integral_constant.hpp:13
Definition: static_distributed_tensor.hpp:21
Definition: functional.hpp:43