/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/amd_gemm_dpp.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/amd_gemm_dpp.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/amd_gemm_dpp.hpp Source File
amd_gemm_dpp.hpp
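Go to the documentation of this file. (Note: listing lines 6, 8, 15, 18, 36, 38 and 50 are missing from this extracted copy of the source listing.)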
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 #include "ck/utility/math.hpp"
9 
10 namespace ck {
11 
12 namespace dpp8 {
13 
// Input-type configuration trait. For an A/B element type it names:
//   - the A, B and C (accumulator) element types used with the DPP inner product;
//   - `k_per_instr`, the number of K elements reduced by one inner-product instruction.
// Presumably this is the type aliased as `datatypes_conf` in the GEMM struct further below.
// NOTE(review): listing lines 15 and 18 are missing from this extracted copy; these are the
// declaration lines that follow each `template` header. Consult the header for the exact
// names.
14 template <class ABDataType>
// Primary template: no body is visible here. Presumably it is declaration-only, so
// unsupported input types fail to compile. Confirm against the missing line 15.
16 
17 template <>
// Explicit specialization; its template argument is on the missing line 18. The members
// below describe the half-precision case: half_t operands with float accumulation.
19 {
20  // Dot product of `half2_t` and `half2_t` to get `float`. Reducing 2 elements from K in a
21  // single instruction.
22  using a_dtype = half_t;
23  using b_dtype = half_t;
24  using c_dtype = float;
 // K elements consumed per inner-product instruction.
 // Run() iterates KPerThread / datatypes_conf::k_per_instr chunks.
25  static constexpr index_t k_per_instr = 2;
26 };
27 
// Thread-level GEMM step built on ck::dpp8::inner_product_dpp. Run() updates each of this
// thread's C elements with dot products of its A and B K-vectors, one K chunk at a time.
//   MPerThread, NPerThread : C elements handled per thread: MPerThread when ShareA, else
//                            NPerThread.
//   KPerThread             : number of A/B elements along K held by this thread.
//   BaseInputType          : not referenced in the visible lines. NOTE(review): presumably it
//                            selects the `datatypes_conf` trait on missing line 38; confirm.
//   AVecDataType, BVecDataType, CVecDataType
//                          : types of Run()'s a_vec / b_vec / c_vec arguments. The A and B
//                            types are also the per-chunk views passed to inner_product_dpp.
//   ShareA                 : chooses the C element count (see above) and is forwarded to
//                            inner_product_dpp, which defines its exact lane semantics.
// NOTE(review): listing lines 36 (struct declaration), 38 (`datatypes_conf` alias) and
// 50 (outer loop header) are missing from this extracted copy.
28 template <index_t MPerThread,
29  index_t NPerThread,
30  index_t KPerThread,
31  class BaseInputType,
32  class AVecDataType,
33  class BVecDataType,
34  class CVecDataType,
35  bool ShareA>
37 {
39  using ADataType = typename datatypes_conf::a_dtype;
40  using BDataType = typename datatypes_conf::b_dtype;
41  using CDataType = typename datatypes_conf::c_dtype;
42 
 // a_vec, b_vec : this thread's A and B operands, wrapped below as KPerThread-element vectors.
 // c_vec        : accumulators. Each element is read into `c`, passed to inner_product_dpp
 //                (which presumably accumulates into it), then written back in place.
43  __device__ void Run(const AVecDataType& a_vec, const BVecDataType& b_vec, CVecDataType& c_vec)
44  {
 // Number of C elements this thread updates.
45  constexpr index_t num_c_elems_per_thread = ShareA ? MPerThread : NPerThread;
46 
 // Wrap the operands as KPerThread-element vector_types so they can be re-viewed below
 // through AsType<AVecDataType>() / AsType<BVecDataType>(), indexed by K chunk.
47  const vector_type<ADataType, KPerThread> a_vector{a_vec};
48  const vector_type<BDataType, KPerThread> b_vector{b_vec};
49 
 // NOTE(review): listing line 50 is missing. Three things suggest it opens a static_for over
 // c_idx in [0, num_c_elems_per_thread): the closing `});` on line 62, the uses of `c_idx`,
 // and the otherwise-unused `num_c_elems_per_thread`. Confirm against the header.
51  float c = c_vec.template AsType<CDataType>()(c_idx);
52  // Next `c_idx` implies that we need to pull data from the next lane.
53  constexpr index_t source_lane = c_idx;
 // Reduce K in KPerThread / k_per_instr chunks (integer division).
 // NOTE(review): any remainder KPerThread % k_per_instr is silently skipped. Presumably
 // callers guarantee divisibility; a static_assert would make that explicit.
54  static_for<0, KPerThread / datatypes_conf::k_per_instr, 1>{}([&](auto k_chunk) {
55  const auto a_k_vec = a_vector.template AsType<AVecDataType>()[k_chunk];
56  const auto b_k_vec = b_vector.template AsType<BVecDataType>()[k_chunk];
57  ck::dpp8::
58  inner_product_dpp<AVecDataType, BVecDataType, CDataType, source_lane, ShareA>(
59  a_k_vec, b_k_vec, c);
60  });
 // Store the updated accumulator back into C element c_idx.
61  c_vec.template AsType<CDataType>()(c_idx) = c;
62  });
63  }
64 };
65 
66 } // namespace dpp8
67 
68 } // namespace ck
Definition: ck.hpp:267
_Float16 half_t
Definition: data_type.hpp:30
int32_t index_t
Definition: ck.hpp:298
Definition: amd_gemm_dpp.hpp:37
typename datatypes_conf::c_dtype CDataType
Definition: amd_gemm_dpp.hpp:41
typename datatypes_conf::a_dtype ADataType
Definition: amd_gemm_dpp.hpp:39
__device__ void Run(const AVecDataType &a_vec, const BVecDataType &b_vec, CVecDataType &c_vec)
Definition: amd_gemm_dpp.hpp:43
typename datatypes_conf::b_dtype BDataType
Definition: amd_gemm_dpp.hpp:40
half_t a_dtype
Definition: amd_gemm_dpp.hpp:22
half_t b_dtype
Definition: amd_gemm_dpp.hpp:23
float c_dtype
Definition: amd_gemm_dpp.hpp:24
Definition: amd_gemm_dpp.hpp:15
Definition: functional2.hpp:33
Definition: dtype_vector.hpp:10