/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp Source File
gemm_multi_d_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <string>
8 
9 #include "ck_tile/core.hpp"
10 #include "ck_tile/ops/common.hpp"
11 #include "ck_tile/host/concat.hpp"
17 
18 namespace ck_tile {
19 
28 template <index_t NumDTensor = 1>
30 {
32  CK_TILE_HOST GemmMultiDHostArgs(const void* a_ptr_,
33  const void* b_ptr_,
34  const std::array<const void*, NumDTensor>& ds_ptr_,
35  void* e_ptr_,
36  index_t k_batch_,
37  index_t M_,
38  index_t N_,
39  index_t K_,
40  index_t stride_A_,
41  index_t stride_B_,
42  const std::array<index_t, NumDTensor>& stride_Ds_,
43  index_t stride_E_)
44  : a_ptr(a_ptr_),
45  b_ptr(b_ptr_),
46  ds_ptr(ds_ptr_),
47  e_ptr(e_ptr_),
48  M(M_),
49  N(N_),
50  K(K_),
51  stride_A(stride_A_),
52  stride_B(stride_B_),
53  stride_Ds(stride_Ds_),
54  stride_E(stride_E_),
55  k_batch(k_batch_)
56  {
57  }
58 
59  const void* a_ptr;
60  const void* b_ptr;
61  const std::array<const void*, NumDTensor> ds_ptr;
62  union
63  {
64  void* e_ptr;
65  void* c_ptr;
66  };
72  const std::array<index_t, NumDTensor> stride_Ds;
73  union
74  {
77  };
78 
80 };
81 
82 template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
84 {
90 
94 
100 
106 
108  static_assert(!is_detected<is_tuple, ALayout>::value &&
110  "ALayout and ADataType must be scalars.");
111 
113  static_assert(!is_detected<is_tuple, BLayout>::value &&
115  "BLayout and BDataType must be scalars.");
116 
118  static_assert(!is_detected<is_tuple, ELayout>::value &&
120  "ELayout and EDataType must be scalars.");
121 
125  DsLayout::size() == DsDataType::size() && DsLayout::size() > 0,
126  "DsLayout and DsDataType must be tuples and must have the same size.");
127 
130  static constexpr index_t NumATensor = 1;
131  static constexpr index_t NumBTensor = 1;
132  static constexpr index_t NumDTensor = DsDataType::size();
133 
134  CK_TILE_HOST static auto GetName() -> const std::string
135  {
137  }
138 
139  CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) -> dim3
140  {
141  return UniversalGemmKernel::GridSize(M, N, KBatch);
142  }
143 
144  CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
145  {
147  }
148 
149  CK_TILE_HOST static constexpr auto BlockSize() -> dim3
150  {
152  }
153 
154  CK_TILE_HOST static constexpr auto
157  {
162  {hostArgs.b_ptr},
163  hostArgs.ds_ptr,
164  hostArgs.e_ptr,
165  hostArgs.k_batch,
166  hostArgs.M,
167  hostArgs.N,
168  hostArgs.K,
169  {hostArgs.stride_A},
170  {hostArgs.stride_B},
171  hostArgs.stride_Ds,
172  hostArgs.stride_E));
173  }
174 
175  CK_TILE_HOST static auto
177  {
178  // Currently MultiD kernel doesn't support k_batch > 1
179  if(kargs.k_batch > 1)
180  {
181  return false;
182  }
183 
185  }
186 
187  CK_TILE_DEVICE auto operator()(typename UniversalGemmKernel::KernelArgs kargs) const -> void
188  {
189  UniversalGemmKernel{}.template operator()(kargs);
190  }
191 };
192 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition: type_traits.hpp:67
Definition: gemm_multi_d_kernel.hpp:84
static constexpr index_t NumDTensor
Definition: gemm_multi_d_kernel.hpp:132
static constexpr index_t NumATensor
ALayout and ADataType are expected to be scalars, not tuples.
Definition: gemm_multi_d_kernel.hpp:130
static CK_TILE_HOST auto IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs &kargs) -> bool
Definition: gemm_multi_d_kernel.hpp:176
static CK_TILE_HOST auto MaxOccupancyGridSize(const stream_config &s) -> dim3
Definition: gemm_multi_d_kernel.hpp:144
remove_cvref_t< typename EpiloguePipeline::DsLayout > DsLayout
Definition: gemm_multi_d_kernel.hpp:99
remove_cvref_t< typename EpiloguePipeline::DsDataType > DsDataType
Definition: gemm_multi_d_kernel.hpp:105
remove_cvref_t< typename EpiloguePipeline::ODataType > EDataType
Definition: gemm_multi_d_kernel.hpp:104
static constexpr CK_TILE_HOST auto MakeKernelArgs(const GemmMultiDHostArgs< NumDTensor > &hostArgs) -> typename UniversalGemmKernel::KernelArgs
Definition: gemm_multi_d_kernel.hpp:155
UniversalGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ > UniversalGemmKernel
Inject the UniversalGemmKernel base class to support execution of all necessary functions.
Definition: gemm_multi_d_kernel.hpp:88
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: gemm_multi_d_kernel.hpp:91
static constexpr CK_TILE_HOST auto GridSize(index_t M, index_t N, index_t KBatch) -> dim3
Definition: gemm_multi_d_kernel.hpp:139
static constexpr index_t NumBTensor
Definition: gemm_multi_d_kernel.hpp:131
static CK_TILE_HOST auto GetName() -> const std::string
Definition: gemm_multi_d_kernel.hpp:134
remove_cvref_t< typename GemmPipeline::BDataType > BDataType
Definition: gemm_multi_d_kernel.hpp:103
static constexpr index_t kBlockSize
Definition: gemm_multi_d_kernel.hpp:89
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition: gemm_multi_d_kernel.hpp:92
remove_cvref_t< typename GemmPipeline::BLayout > BLayout
Definition: gemm_multi_d_kernel.hpp:97
CK_TILE_DEVICE auto operator()(typename UniversalGemmKernel::KernelArgs kargs) const -> void
Definition: gemm_multi_d_kernel.hpp:187
remove_cvref_t< typename GemmPipeline::CLayout > ELayout
Definition: gemm_multi_d_kernel.hpp:98
static constexpr CK_TILE_HOST auto BlockSize() -> dim3
Definition: gemm_multi_d_kernel.hpp:149
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: gemm_multi_d_kernel.hpp:93
remove_cvref_t< typename GemmPipeline::ALayout > ALayout
Specify the layout configurations for A, B, E and D.
Definition: gemm_multi_d_kernel.hpp:96
remove_cvref_t< typename GemmPipeline::ADataType > ADataType
Specify the data type configurations for A, B, E and D.
Definition: gemm_multi_d_kernel.hpp:102
The MultiD GEMM kernel host arguments.
Definition: gemm_multi_d_kernel.hpp:30
void * c_ptr
Definition: gemm_multi_d_kernel.hpp:65
index_t stride_B
Definition: gemm_multi_d_kernel.hpp:71
CK_TILE_HOST GemmMultiDHostArgs(const void *a_ptr_, const void *b_ptr_, const std::array< const void *, NumDTensor > &ds_ptr_, void *e_ptr_, index_t k_batch_, index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, const std::array< index_t, NumDTensor > &stride_Ds_, index_t stride_E_)
Definition: gemm_multi_d_kernel.hpp:32
index_t k_batch
Definition: gemm_multi_d_kernel.hpp:79
const void * a_ptr
Definition: gemm_multi_d_kernel.hpp:59
index_t stride_A
Definition: gemm_multi_d_kernel.hpp:70
const std::array< const void *, NumDTensor > ds_ptr
Definition: gemm_multi_d_kernel.hpp:61
CK_TILE_HOST GemmMultiDHostArgs()=default
index_t N
Definition: gemm_multi_d_kernel.hpp:68
void * e_ptr
Definition: gemm_multi_d_kernel.hpp:64
index_t stride_C
Definition: gemm_multi_d_kernel.hpp:76
const void * b_ptr
Definition: gemm_multi_d_kernel.hpp:60
const std::array< index_t, NumDTensor > stride_Ds
Definition: gemm_multi_d_kernel.hpp:72
index_t M
Definition: gemm_multi_d_kernel.hpp:67
index_t stride_E
Definition: gemm_multi_d_kernel.hpp:75
index_t K
Definition: gemm_multi_d_kernel.hpp:69
The Universal GEMM kernel host arguments.
Definition: universal_gemm_kernel.hpp:32
The GEMM kernel device arguments.
Definition: universal_gemm_kernel.hpp:86
static CK_TILE_HOST const std::string GetName()
Definition: universal_gemm_kernel.hpp:257
static constexpr CK_TILE_HOST auto GridSize(index_t M, index_t N, index_t KBatch)
Definition: universal_gemm_kernel.hpp:264
static CK_TILE_HOST auto BlockSize()
Definition: universal_gemm_kernel.hpp:287
static CK_TILE_HOST auto MaxOccupancyGridSize(const stream_config &s) -> dim3
Get the maximum occupancy grid size for the persistent kernel on the current device.
Definition: universal_gemm_kernel.hpp:275
static CK_TILE_HOST bool IsSupportedArgument(const KernelArgs &kargs)
Definition: universal_gemm_kernel.hpp:370
static constexpr CK_TILE_HOST KernelArgs MakeKernelArgs(const UniversalGemmHostArgs< NumATensor, NumBTensor, NumDTensor > &hostArgs)
Definition: universal_gemm_kernel.hpp:300
static constexpr index_t kBlockSize
Definition: universal_gemm_kernel.hpp:199
Definition: stream_config.hpp:30