/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp Source File
gemm_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <string>
8 
9 #include "ck_tile/core.hpp"
10 #include "ck_tile/ops/common.hpp"
11 #include "ck_tile/host/concat.hpp"
17 
18 namespace ck_tile {
19 
29 {
31  CK_TILE_HOST GemmHostArgs(const void* a_ptr_,
32  const void* b_ptr_,
33  void* e_ptr_,
34  index_t k_batch_,
35  index_t M_,
36  index_t N_,
37  index_t K_,
38  index_t stride_A_,
39  index_t stride_B_,
40  index_t stride_E_)
41  : a_ptr(a_ptr_),
42  b_ptr(b_ptr_),
43  e_ptr(e_ptr_),
44  M(M_),
45  N(N_),
46  K(K_),
47  stride_A(stride_A_),
48  stride_B(stride_B_),
49  stride_E(stride_E_),
50  k_batch(k_batch_)
51  {
52  }
53 
54  const void* a_ptr;
55  const void* b_ptr;
56  union
57  {
58  void* e_ptr;
59  void* c_ptr;
60  };
61 
67 
68  union
69  {
72  };
73 
75 };
76 
77 template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
78 struct GemmKernel
79 {
84 
88 
93 
98 
100  static_assert(
102  "ALayout and ADataType must be scalars. Multiple parameters are not currently supported.");
103 
105  static_assert(
107  "BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");
108 
110  static_assert(!is_detected<is_tuple, ELayout>::value &&
112  "C/ELayout and C/EDataType must be scalars.");
113 
114  static constexpr index_t NumATensor = 1;
115  static constexpr index_t NumBTensor = 1;
117 
118  CK_TILE_HOST static auto GetName() -> const std::string
119  {
121  }
122 
123  CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) -> dim3
124  {
125  return UniversalGemmKernel::GridSize(M, N, KBatch);
126  }
127 
128  CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
129  {
131  }
132 
133  CK_TILE_HOST static constexpr auto BlockSize() -> dim3
134  {
136  }
137 
138  CK_TILE_HOST static constexpr auto MakeKernelArgs(const GemmHostArgs& hostArgs) ->
140  {
145  {hostArgs.a_ptr},
146  {hostArgs.b_ptr},
147  {/*hostArgs.ds_ptr*/},
148  hostArgs.e_ptr,
149  hostArgs.k_batch,
150  hostArgs.M,
151  hostArgs.N,
152  hostArgs.K,
153  {hostArgs.stride_A},
154  {hostArgs.stride_B},
155  {/*hostArgs.stride_Ds*/},
156  hostArgs.stride_E));
157  }
158 
159  CK_TILE_HOST static auto
161  {
163  }
164 
165  CK_TILE_DEVICE auto operator()(typename UniversalGemmKernel::KernelArgs kargs) const -> void
166  {
167  UniversalGemmKernel{}.template operator()(kargs);
168  }
169 };
170 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition: type_traits.hpp:67
The GEMM kernel host arguments.
Definition: gemm_kernel.hpp:29
CK_TILE_HOST GemmHostArgs()=default
void * c_ptr
Definition: gemm_kernel.hpp:59
index_t stride_E
Definition: gemm_kernel.hpp:70
index_t stride_B
Definition: gemm_kernel.hpp:66
index_t stride_C
Definition: gemm_kernel.hpp:71
void * e_ptr
Definition: gemm_kernel.hpp:58
index_t K
Definition: gemm_kernel.hpp:64
index_t M
Definition: gemm_kernel.hpp:62
CK_TILE_HOST GemmHostArgs(const void *a_ptr_, const void *b_ptr_, void *e_ptr_, index_t k_batch_, index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_E_)
Definition: gemm_kernel.hpp:31
index_t stride_A
Definition: gemm_kernel.hpp:65
const void * a_ptr
Definition: gemm_kernel.hpp:54
const void * b_ptr
Definition: gemm_kernel.hpp:55
index_t N
Definition: gemm_kernel.hpp:63
index_t k_batch
Definition: gemm_kernel.hpp:74
Definition: gemm_kernel.hpp:79
remove_cvref_t< typename EpiloguePipeline::ODataType > EDataType
Definition: gemm_kernel.hpp:97
static constexpr CK_TILE_HOST auto BlockSize() -> dim3
Definition: gemm_kernel.hpp:133
UniversalGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ > UniversalGemmKernel
Type alias for the UniversalGemmKernel instantiation to which GemmKernel delegates all of its functions (it is an alias, not a base class).
Definition: gemm_kernel.hpp:83
remove_cvref_t< typename GemmPipeline::ADataType > ADataType
Specify the data type configurations for A, B, E and D.
Definition: gemm_kernel.hpp:95
remove_cvref_t< typename GemmPipeline::CLayout > ELayout
Definition: gemm_kernel.hpp:92
static constexpr CK_TILE_HOST auto MakeKernelArgs(const GemmHostArgs &hostArgs) -> typename UniversalGemmKernel::KernelArgs
Definition: gemm_kernel.hpp:138
static constexpr CK_TILE_HOST auto GridSize(index_t M, index_t N, index_t KBatch) -> dim3
Definition: gemm_kernel.hpp:123
remove_cvref_t< typename GemmPipeline::ALayout > ALayout
Specify the layout configurations for A, B, E and D.
Definition: gemm_kernel.hpp:90
static constexpr index_t NumBTensor
Definition: gemm_kernel.hpp:115
remove_cvref_t< typename GemmPipeline::BDataType > BDataType
Definition: gemm_kernel.hpp:96
static constexpr index_t kBlockSize
Definition: gemm_kernel.hpp:116
CK_TILE_DEVICE auto operator()(typename UniversalGemmKernel::KernelArgs kargs) const -> void
Definition: gemm_kernel.hpp:165
static constexpr index_t NumATensor
ALayout and ADataType are expected to be scalars, not a tuple.
Definition: gemm_kernel.hpp:114
static CK_TILE_HOST auto MaxOccupancyGridSize(const stream_config &s) -> dim3
Definition: gemm_kernel.hpp:128
static CK_TILE_HOST auto IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs &kargs) -> bool
Definition: gemm_kernel.hpp:160
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition: gemm_kernel.hpp:86
remove_cvref_t< typename GemmPipeline::BLayout > BLayout
Definition: gemm_kernel.hpp:91
static CK_TILE_HOST auto GetName() -> const std::string
Definition: gemm_kernel.hpp:118
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: gemm_kernel.hpp:85
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: gemm_kernel.hpp:87
The Universal GEMM kernel host arguments.
Definition: universal_gemm_kernel.hpp:32
The GEMM kernel device arguments.
Definition: universal_gemm_kernel.hpp:86
static CK_TILE_HOST const std::string GetName()
Definition: universal_gemm_kernel.hpp:257
static constexpr CK_TILE_HOST auto GridSize(index_t M, index_t N, index_t KBatch)
Definition: universal_gemm_kernel.hpp:264
static CK_TILE_HOST auto BlockSize()
Definition: universal_gemm_kernel.hpp:287
static CK_TILE_HOST auto MaxOccupancyGridSize(const stream_config &s) -> dim3
Get the maximum occupancy grid size for the persistent kernel on the current device.
Definition: universal_gemm_kernel.hpp:275
static CK_TILE_HOST bool IsSupportedArgument(const KernelArgs &kargs)
Definition: universal_gemm_kernel.hpp:370
static constexpr CK_TILE_HOST KernelArgs MakeKernelArgs(const UniversalGemmHostArgs< NumATensor, NumBTensor, NumDTensor > &hostArgs)
Definition: universal_gemm_kernel.hpp:300
static constexpr index_t kBlockSize
Definition: universal_gemm_kernel.hpp:199
Definition: stream_config.hpp:30