/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp Source File
warp_gemm_impl.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 namespace ck_tile {
8 
9 template <typename WarpGemmAttribute_>
11 {
13 
14  static constexpr index_t kM = WarpGemmAttribute::kM;
15  static constexpr index_t kN = WarpGemmAttribute::kN;
16  static constexpr index_t kK = WarpGemmAttribute::kK;
17  static constexpr index_t kCMLane = WarpGemmAttribute::kCMLane;
22  static constexpr index_t kKPerThread = WarpGemmAttribute::kKPerThread;
23 
24  using ADataType = typename WarpGemmAttribute::ADataType;
25  using BDataType = typename WarpGemmAttribute::BDataType;
26  using CDataType = typename WarpGemmAttribute::CDataType;
27 
28  using AWarpDstrEncoding = typename WarpGemmAttribute::AWarpDstrEncoding;
29  using BWarpDstrEncoding = typename WarpGemmAttribute::BWarpDstrEncoding;
30  using CWarpDstrEncoding = typename WarpGemmAttribute::CWarpDstrEncoding;
31 
35 
39 
40  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access()
41  {
42  return WarpGemmAttribute_::get_num_of_access();
43  }
44 
45  template <typename CTensor, typename ATensor, typename BTensor, bool post_nop_ = false>
46  CK_TILE_DEVICE void
47  operator()(CTensor& c, const ATensor& a, const BTensor& b, bool_constant<post_nop_> = {}) const
48  {
49  static_assert(detail::is_similiar_distributed_tensor_v<CTensor, CWarpTensor> &&
50  detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
51  detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
52  using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
53  using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
54  using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
55 
56  constexpr auto I0 = number<0>{};
57 
58  const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
59  const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
60  auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
61 
62  // c_vec += a_vec * b_vec
63  WarpGemmAttribute{}(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});
64 
65  c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
66  }
67 
68  template <typename CTensor,
69  typename ATensor,
70  typename BTensor,
71  index_t i_subk,
72  bool post_nop_ = false>
73  CK_TILE_DEVICE void operator()(CTensor& c,
74  const ATensor& a,
75  const BTensor& b,
77  bool_constant<post_nop_> = {}) const
78  {
79  using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
80  using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
81  using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
82 
83  constexpr auto I0 = number<0>{};
84 
85  const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
86  const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
87  auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
88 
89  // c_vec += a_vec * b_vec
90  WarpGemmAttribute{}(c_vec, a_vec, b_vec, number<i_subk>{}, bool_constant<post_nop_>{});
91 
92  c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
93  }
94 
95  template <typename ATensor, typename BTensor>
96  CK_TILE_DEVICE auto operator()(const ATensor& a, const BTensor& b) const
97  {
98  using CTensor = CWarpTensor;
99  static_assert(detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
100  detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
101  CTensor c;
102 
103  using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
104  using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
105  using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
106 
107  constexpr auto I0 = number<0>{};
108 
109  const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
110  const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
111 
112  // c_vec = a_vec * b_vec
113  auto c_vec = WarpGemmAttribute{}(a_vec, b_vec);
114 
115  c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
116 
117  return c;
118  }
119 };
120 
121 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
typename impl::ext_vector< T, N >::type ext_vector_t
Definition: vector_type.hpp:83
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:480
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1249
Definition: warp_gemm_impl.hpp:11
CK_TILE_DEVICE void operator()(CTensor &c, const ATensor &a, const BTensor &b, number< i_subk >, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_impl.hpp:73
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_impl.hpp:40
typename WarpGemmAttribute::CWarpDstrEncoding CWarpDstrEncoding
Definition: warp_gemm_impl.hpp:30
CK_TILE_DEVICE auto operator()(const ATensor &a, const BTensor &b) const
Definition: warp_gemm_impl.hpp:96
typename WarpGemmAttribute::BWarpDstrEncoding BWarpDstrEncoding
Definition: warp_gemm_impl.hpp:29
remove_cvref_t< decltype(make_static_tile_distribution(BWarpDstrEncoding{}))> BWarpDstr
Definition: warp_gemm_impl.hpp:33
static constexpr index_t kKPerThread
The number of elements in K dimension processed by single thread in wavefront.
Definition: warp_gemm_impl.hpp:22
typename WarpGemmAttribute::AWarpDstrEncoding AWarpDstrEncoding
Definition: warp_gemm_impl.hpp:28
CK_TILE_DEVICE void operator()(CTensor &c, const ATensor &a, const BTensor &b, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_impl.hpp:47
static constexpr index_t kM
Definition: warp_gemm_impl.hpp:14
static constexpr index_t kK
Definition: warp_gemm_impl.hpp:16
static constexpr index_t kCMLane
Definition: warp_gemm_impl.hpp:17
typename WarpGemmAttribute::CDataType CDataType
Definition: warp_gemm_impl.hpp:26
typename WarpGemmAttribute::ADataType ADataType
Definition: warp_gemm_impl.hpp:24
remove_cvref_t< decltype(make_static_tile_distribution(CWarpDstrEncoding{}))> CWarpDstr
Definition: warp_gemm_impl.hpp:34
static constexpr index_t kN
Definition: warp_gemm_impl.hpp:15
remove_cvref_t< WarpGemmAttribute_ > WarpGemmAttribute
Definition: warp_gemm_impl.hpp:12
remove_cvref_t< decltype(make_static_tile_distribution(AWarpDstrEncoding{}))> AWarpDstr
Definition: warp_gemm_impl.hpp:32
static_distributed_tensor< CDataType, CWarpDstr > CWarpTensor
Definition: warp_gemm_impl.hpp:38
typename WarpGemmAttribute::BDataType BDataType
Definition: warp_gemm_impl.hpp:25
Definition: integral_constant.hpp:13
Definition: static_distributed_tensor.hpp:21