/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp Source File
gemm_multi_abd_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <string>
8 
9 #include "ck_tile/core.hpp"
10 #include "ck_tile/ops/common.hpp"
11 #include "ck_tile/host/concat.hpp"
17 
18 namespace ck_tile {
19 
31 template <index_t NumATensor, index_t NumBTensor, index_t NumDTensor>
33 {
34  CK_TILE_HOST GemmMultiABDHostArgs(const std::array<const void*, NumATensor>& as_ptr_,
35  const std::array<const void*, NumBTensor>& bs_ptr_,
36  const std::array<const void*, NumDTensor>& ds_ptr_,
37  void* e_ptr_,
38  index_t k_batch_,
39  index_t M_,
40  index_t N_,
41  index_t K_,
42  const std::array<index_t, NumATensor>& stride_As_,
43  const std::array<index_t, NumBTensor>& stride_Bs_,
44  const std::array<index_t, NumDTensor>& stride_Ds_,
45  index_t stride_E_)
46  : as_ptr(as_ptr_),
47  bs_ptr(bs_ptr_),
48  ds_ptr(ds_ptr_),
49  e_ptr(e_ptr_),
50  M(M_),
51  N(N_),
52  K(K_),
53  stride_As(stride_As_),
54  stride_Bs(stride_Bs_),
55  stride_Ds(stride_Ds_),
56  stride_E(stride_E_),
57  k_batch(k_batch_)
58  {
59  }
60 
61  const std::array<const void*, NumATensor> as_ptr;
62  const std::array<const void*, NumBTensor> bs_ptr;
63  const std::array<const void*, NumDTensor> ds_ptr;
64  union
65  {
66  void* e_ptr;
67  void* c_ptr;
68  };
72  const std::array<index_t, NumATensor> stride_As;
73  const std::array<index_t, NumBTensor> stride_Bs;
74  const std::array<index_t, NumDTensor> stride_Ds;
75  union
76  {
79  };
80 
82 };
83 
84 template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
86 {
92 
96 
102 
108 
112  "ALayout and ADataType must be a tuple.");
113 
117  "BLayout and BDataType must be a tuple.");
118 
120  static_assert(!is_detected<is_tuple, CLayout>::value &&
122  "CLayout and EDataType must be a scalar.");
123 
127  DsLayout::size() == DsDataType::size() && DsLayout::size() > 0,
128  "DsLayout and DsDataType must be tuples and must have the same size.");
129 
131  static constexpr index_t NumATensor = AsDataType::size();
132  static constexpr index_t NumBTensor = BsDataType::size();
133  static constexpr index_t NumDTensor = DsDataType::size();
134 
138 
139  CK_TILE_HOST static auto GetName() -> const std::string
140  {
142  }
143 
144  CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) -> dim3
145  {
146  return UniversalGemmKernel::GridSize(M, N, KBatch);
147  }
148 
149  CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
150  {
152  }
153 
154  CK_TILE_HOST static constexpr auto BlockSize() -> dim3
155  {
157  }
158 
159  CK_TILE_HOST static constexpr auto
162  {
167  hostArgs.bs_ptr,
168  hostArgs.ds_ptr,
169  hostArgs.e_ptr,
170  hostArgs.k_batch,
171  hostArgs.M,
172  hostArgs.N,
173  hostArgs.K,
174  hostArgs.stride_As,
175  hostArgs.stride_Bs,
176  hostArgs.stride_Ds,
177  hostArgs.stride_E));
178  }
179 
180  CK_TILE_HOST static auto
182  {
183  // Currently MultiABD kernel doesn't support k_batch > 1
184  if(kargs.k_batch > 1)
185  {
186  return false;
187  }
188  // Currently MultiABD kernel doesn't support F8 data type on gfx950
189  if(ck_tile::get_device_name() == "gfx950" &&
193  {
194  return false;
195  }
196 
198  }
199 
200  CK_TILE_DEVICE auto operator()(typename UniversalGemmKernel::KernelArgs kargs) const -> void
201  {
202  UniversalGemmKernel{}.template operator()(kargs);
203  }
204 };
205 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
Definition: cluster_descriptor.hpp:13
std::string get_device_name()
Definition: device_prop.hpp:19
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition: type_traits.hpp:67
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1350
Definition: gemm_multi_abd_kernel.hpp:86
CK_TILE_DEVICE auto operator()(typename UniversalGemmKernel::KernelArgs kargs) const -> void
Definition: gemm_multi_abd_kernel.hpp:200
remove_cvref_t< typename GemmPipeline::AsLayout > AsLayout
Specify the layout configurations for A, B, E and D.
Definition: gemm_multi_abd_kernel.hpp:98
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition: gemm_multi_abd_kernel.hpp:94
remove_cvref_t< typename EpiloguePipeline::ODataType > EDataType
Definition: gemm_multi_abd_kernel.hpp:106
remove_cvref_t< std::tuple_element_t< 0, AsDataType > > ADataType
Definition: gemm_multi_abd_kernel.hpp:135
remove_cvref_t< std::tuple_element_t< 0, DsDataType > > DDataType
Definition: gemm_multi_abd_kernel.hpp:137
remove_cvref_t< std::tuple_element_t< 0, BsDataType > > BDataType
Definition: gemm_multi_abd_kernel.hpp:136
UniversalGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ > UniversalGemmKernel
Inject the UniversalGemmKernel base class to support execution of all necessary functions.
Definition: gemm_multi_abd_kernel.hpp:90
remove_cvref_t< typename EpiloguePipeline::DsDataType > DsDataType
Definition: gemm_multi_abd_kernel.hpp:107
static constexpr CK_TILE_HOST auto BlockSize() -> dim3
Definition: gemm_multi_abd_kernel.hpp:154
static constexpr index_t kBlockSize
Definition: gemm_multi_abd_kernel.hpp:91
remove_cvref_t< typename EpiloguePipeline::DsLayout > DsLayout
Definition: gemm_multi_abd_kernel.hpp:101
static constexpr CK_TILE_HOST auto MakeKernelArgs(const GemmMultiABDHostArgs< NumATensor, NumBTensor, NumDTensor > &hostArgs) -> typename UniversalGemmKernel::KernelArgs
Definition: gemm_multi_abd_kernel.hpp:160
static constexpr index_t NumDTensor
Definition: gemm_multi_abd_kernel.hpp:133
remove_cvref_t< typename GemmPipeline::CLayout > CLayout
Definition: gemm_multi_abd_kernel.hpp:100
static constexpr CK_TILE_HOST auto GridSize(index_t M, index_t N, index_t KBatch) -> dim3
Definition: gemm_multi_abd_kernel.hpp:144
static constexpr index_t NumATensor
AsLayout and AsDataType are expected to be tuples, not scalars.
Definition: gemm_multi_abd_kernel.hpp:131
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: gemm_multi_abd_kernel.hpp:93
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: gemm_multi_abd_kernel.hpp:95
static CK_TILE_HOST auto MaxOccupancyGridSize(const stream_config &s) -> dim3
Definition: gemm_multi_abd_kernel.hpp:149
static CK_TILE_HOST auto IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs &kargs) -> bool
Definition: gemm_multi_abd_kernel.hpp:181
remove_cvref_t< typename GemmPipeline::AsDataType > AsDataType
Specify the data type configurations for A, B, E and D.
Definition: gemm_multi_abd_kernel.hpp:104
static constexpr index_t NumBTensor
Definition: gemm_multi_abd_kernel.hpp:132
remove_cvref_t< typename GemmPipeline::BsLayout > BsLayout
Definition: gemm_multi_abd_kernel.hpp:99
static CK_TILE_HOST auto GetName() -> const std::string
Definition: gemm_multi_abd_kernel.hpp:139
remove_cvref_t< typename GemmPipeline::BsDataType > BsDataType
Definition: gemm_multi_abd_kernel.hpp:105
The MultiABD GEMM kernel host arguments.
Definition: gemm_multi_abd_kernel.hpp:33
const std::array< index_t, NumDTensor > stride_Ds
Definition: gemm_multi_abd_kernel.hpp:74
const std::array< const void *, NumATensor > as_ptr
Definition: gemm_multi_abd_kernel.hpp:61
index_t stride_C
Definition: gemm_multi_abd_kernel.hpp:78
index_t stride_E
Definition: gemm_multi_abd_kernel.hpp:77
void * e_ptr
Definition: gemm_multi_abd_kernel.hpp:66
index_t M
Definition: gemm_multi_abd_kernel.hpp:69
const std::array< index_t, NumATensor > stride_As
Definition: gemm_multi_abd_kernel.hpp:72
const std::array< const void *, NumDTensor > ds_ptr
Definition: gemm_multi_abd_kernel.hpp:63
index_t k_batch
Definition: gemm_multi_abd_kernel.hpp:81
const std::array< const void *, NumBTensor > bs_ptr
Definition: gemm_multi_abd_kernel.hpp:62
CK_TILE_HOST GemmMultiABDHostArgs(const std::array< const void *, NumATensor > &as_ptr_, const std::array< const void *, NumBTensor > &bs_ptr_, const std::array< const void *, NumDTensor > &ds_ptr_, void *e_ptr_, index_t k_batch_, index_t M_, index_t N_, index_t K_, const std::array< index_t, NumATensor > &stride_As_, const std::array< index_t, NumBTensor > &stride_Bs_, const std::array< index_t, NumDTensor > &stride_Ds_, index_t stride_E_)
Definition: gemm_multi_abd_kernel.hpp:34
index_t K
Definition: gemm_multi_abd_kernel.hpp:71
index_t N
Definition: gemm_multi_abd_kernel.hpp:70
const std::array< index_t, NumBTensor > stride_Bs
Definition: gemm_multi_abd_kernel.hpp:73
void * c_ptr
Definition: gemm_multi_abd_kernel.hpp:67
The Universal GEMM kernel host arguments.
Definition: universal_gemm_kernel.hpp:32
The GEMM kernel device arguments.
Definition: universal_gemm_kernel.hpp:86
static CK_TILE_HOST const std::string GetName()
Definition: universal_gemm_kernel.hpp:260
static constexpr CK_TILE_HOST auto GridSize(index_t M, index_t N, index_t KBatch)
Definition: universal_gemm_kernel.hpp:267
static CK_TILE_HOST auto BlockSize()
Definition: universal_gemm_kernel.hpp:290
static CK_TILE_HOST auto MaxOccupancyGridSize(const stream_config &s) -> dim3
Get the maximum occupancy grid size for the persistent kernel on the current device.
Definition: universal_gemm_kernel.hpp:278
static CK_TILE_HOST bool IsSupportedArgument(const KernelArgs &kargs)
Definition: universal_gemm_kernel.hpp:373
static constexpr CK_TILE_HOST KernelArgs MakeKernelArgs(const UniversalGemmHostArgs< NumATensor, NumBTensor, NumDTensor > &hostArgs)
Definition: universal_gemm_kernel.hpp:303
static constexpr index_t kBlockSize
Definition: universal_gemm_kernel.hpp:202
Definition: stream_config.hpp:30