template <ck_tile::index_t NumDTensor = 0>
struct BatchedContractionHostArgs
{
    // Host-side problem description: raw pointers plus per-tensor dimension and stride
    // vectors laid out as [G..., M..., K...] for A, [G..., N..., K...] for B and
    // [G..., M..., N...] for E and every D tensor.
    BatchedContractionHostArgs(const void* a_ptr_,
                               const void* b_ptr_,
                               const std::array<const void*, NumDTensor>& ds_ptr_,
                               void* e_ptr_,
                               ck_tile::index_t k_batch_,
                               const std::vector<ck_tile::index_t>& A_dims_,
                               const std::vector<ck_tile::index_t>& B_dims_,
                               const std::array<std::vector<ck_tile::index_t>, NumDTensor>& Ds_dims_,
                               const std::vector<ck_tile::index_t>& E_dims_,
                               const std::vector<ck_tile::index_t>& A_strides_,
                               const std::vector<ck_tile::index_t>& B_strides_,
                               const std::array<std::vector<ck_tile::index_t>, NumDTensor>& Ds_strides_,
                               const std::vector<ck_tile::index_t>& E_strides_)
        : a_ptr(a_ptr_), b_ptr(b_ptr_), ds_ptr(ds_ptr_), e_ptr(e_ptr_), k_batch(k_batch_),
          A_dims(A_dims_), B_dims(B_dims_), Ds_dims(Ds_dims_), E_dims(E_dims_),
          A_strides(A_strides_), B_strides(B_strides_), Ds_strides(Ds_strides_),
          E_strides(E_strides_)
    {
    }

    const void* a_ptr;
    const void* b_ptr;
    std::array<const void*, NumDTensor> ds_ptr;
    void* e_ptr;
    ck_tile::index_t k_batch;

    const std::vector<ck_tile::index_t> A_dims;
    const std::vector<ck_tile::index_t> B_dims;
    const std::array<std::vector<ck_tile::index_t>, NumDTensor> Ds_dims;
    const std::vector<ck_tile::index_t> E_dims;
    const std::vector<ck_tile::index_t> A_strides;
    const std::vector<ck_tile::index_t> B_strides;
    const std::array<std::vector<ck_tile::index_t>, NumDTensor> Ds_strides;
    const std::vector<ck_tile::index_t> E_strides;
};
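// --- Hypothetical usage sketch (not part of this header) ---------------------------------
// Filling the host arguments for a contraction with one G dim, two M dims, two N dims, two
// K dims and a single D tensor, all tensors packed row-major. The device pointers
// (a_dev, b_dev, d_dev, e_dev), the concrete extents and the parameter order follow the
// reconstruction above and are illustrative assumptions only.
//
//   std::array<const void*, 1> ds{d_dev};
//   ck_tile::BatchedContractionHostArgs<1> host_args(
//       a_dev, b_dev, ds, e_dev, /*k_batch=*/1,
//       /*A_dims    =*/{4, 8, 16, 32, 2},        // [G0, M0, M1, K0, K1]
//       /*B_dims    =*/{4, 4, 8, 32, 2},         // [G0, N0, N1, K0, K1]
//       /*Ds_dims   =*/{{{4, 8, 16, 4, 8}}},     // [G0, M0, M1, N0, N1]
//       /*E_dims    =*/{4, 8, 16, 4, 8},         // [G0, M0, M1, N0, N1]
//       /*A_strides =*/{8192, 1024, 64, 2, 1},   // packed row-major strides
//       /*B_strides =*/{2048, 512, 64, 2, 1},
//       /*Ds_strides=*/{{{4096, 512, 32, 8, 1}}},
//       /*E_strides =*/{4096, 512, 32, 8, 1});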
 
template <ck_tile::index_t NumDimG, ck_tile::index_t NumDimM, ck_tile::index_t NumDimN,
          ck_tile::index_t NumDimK, ck_tile::index_t NumDTensor = 0>
struct BatchedContractionKernelArgs
{
    std::array<const void*, NumDTensor> ds_ptr;                // D tensor base pointers
    std::array<ck_tile::index_t, NumDTensor> batch_stride_Ds;  // per-D stride between G batches
    std::array<ck_tile::index_t, NumDTensor> stride_Ds;        // per-D leading-dimension stride
    // (remaining members - a_ptr, b_ptr, e_ptr, k_batch, the G/M/N/K dimension arrays and
    //  totals, and the A/B/E batch and leading strides - are elided in this listing)
};
 
template <typename Problem_,
          typename TilePartitioner_,
          typename GemmPipeline_,
          typename EpiloguePipeline_>
struct BatchedContractionKernel
{
    // NumDimG/NumDimM/NumDimN/NumDimK, NumDTensor and the A/B/Ds/E data types are taken from
    // the problem description Problem_ (those alias declarations are elided in this listing).
    using TilePartitioner  = ck_tile::remove_cvref_t<TilePartitioner_>;
    using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
    using UniversalGemmKernel =
        /* the underlying universal GEMM kernel specialization (elided in this listing) */;

    using KernelArgs =
        BatchedContractionKernelArgs<NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor>;

    CK_TILE_HOST static constexpr auto GetKernelName() { return "batched_contraction_kernel"; }

    CK_TILE_HOST static constexpr bool IsSupportedArguments(const KernelArgs& kargs)
    {
        /* body elided in this listing; it delegates to the underlying universal GEMM
           kernel's argument checks */
    }

    CK_TILE_HOST static constexpr auto GridSize(const KernelArgs& kargs)
    {
        // x: output tiles over (M_total, N_total); y: flattened G batch; z: split-K slices
        return dim3(
            TilePartitioner::GridSize(kargs.M_total, kargs.N_total), kargs.G_total, kargs.k_batch);
    }
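    // Hypothetical launch-geometry illustration (values and tile sizes assumed, not taken
    // from this header): if the partitioner linearizes ceil(M_total/MPerBlock) *
    // ceil(N_total/NPerBlock) output tiles into blockIdx.x, then M_total = 128, N_total = 32
    // with 128x32 tiles, G_total = 4 and k_batch = 2 give grid = dim3(1, 4, 2): one output
    // tile, computed for each of the 4 batches and each of the 2 split-K slices.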
 
    CK_TILE_HOST static KernelArgs
    MakeKernelArgs(const BatchedContractionHostArgs<NumDTensor>& host_args)
    {
        // Every tensor must carry NumDimG leading batch dimensions followed by its own
        // M/N/K groups; validate the ranks of the dimension and stride vectors first.
        const auto expected_A_dims = NumDimG + NumDimM + NumDimK;
        const auto expected_B_dims = NumDimG + NumDimN + NumDimK;
        const auto expected_E_dims = NumDimG + NumDimM + NumDimN;

        if(host_args.A_dims.size() != expected_A_dims ||
           host_args.A_strides.size() != expected_A_dims)
        {
            throw std::invalid_argument("A dimension size mismatch");
        }
        if(host_args.B_dims.size() != expected_B_dims ||
           host_args.B_strides.size() != expected_B_dims)
        {
            throw std::invalid_argument("B dimension size mismatch");
        }
        if(host_args.E_dims.size() != expected_E_dims ||
           host_args.E_strides.size() != expected_E_dims)
        {
            throw std::invalid_argument("E dimension size mismatch");
        }
        for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
        {
            if(host_args.Ds_dims[d].size() != expected_E_dims ||
               host_args.Ds_strides[d].size() != expected_E_dims)
            {
                throw std::invalid_argument("D dimension size mismatch");
            }
        }

        KernelArgs kargs;
        kargs.a_ptr   = host_args.a_ptr;
        kargs.b_ptr   = host_args.b_ptr;
        kargs.ds_ptr  = host_args.ds_ptr;
        kargs.e_ptr   = host_args.e_ptr;
        kargs.k_batch = host_args.k_batch;

        // G (batch) dimensions must agree across A, B and E.
        for(ck_tile::index_t i = 0; i < NumDimG; ++i)
        {
            if(host_args.A_dims[i] != host_args.B_dims[i] ||
               host_args.A_dims[i] != host_args.E_dims[i])
            {
                throw std::invalid_argument(
                    "All tensors must have identical G dimensions for valid contraction");
            }
            kargs.G_dims[i] = host_args.A_dims[i];
        }

        // Stride of the innermost G dimension = distance between consecutive flattened batches.
        kargs.batch_stride_A = host_args.A_strides[NumDimG - 1];
        kargs.batch_stride_B = host_args.B_strides[NumDimG - 1];
        kargs.batch_stride_E = host_args.E_strides[NumDimG - 1];

        // M dimensions are shared by A and E.
        for(ck_tile::index_t i = 0; i < NumDimM; ++i)
        {
            kargs.M_dims[i] = host_args.A_dims[NumDimG + i];
            if(kargs.M_dims[i] != host_args.E_dims[NumDimG + i])
            {
                throw std::invalid_argument("M dimension mismatch between A and E tensors");
            }
        }
        // N dimensions are shared by B and E.
        for(ck_tile::index_t i = 0; i < NumDimN; ++i)
        {
            kargs.N_dims[i] = host_args.B_dims[NumDimG + i];
            if(kargs.N_dims[i] != host_args.E_dims[NumDimG + NumDimM + i])
            {
                throw std::invalid_argument("N dimension mismatch between B and E tensors");
            }
        }
        // K dimensions are shared by A and B.
        for(ck_tile::index_t i = 0; i < NumDimK; ++i)
        {
            kargs.K_dims[i] = host_args.A_dims[NumDimG + NumDimM + i];
            if(kargs.K_dims[i] != host_args.B_dims[NumDimG + NumDimN + i])
            {
                throw std::invalid_argument("K dimension mismatch between A and B tensors");
            }
        }

        // Flatten each index group into a single GEMM extent.
        kargs.G_total = 1;
        for(ck_tile::index_t i = 0; i < NumDimG; ++i)
            kargs.G_total *= kargs.G_dims[i];
        kargs.M_total = 1;
        for(ck_tile::index_t i = 0; i < NumDimM; ++i)
            kargs.M_total *= kargs.M_dims[i];
        kargs.N_total = 1;
        for(ck_tile::index_t i = 0; i < NumDimN; ++i)
            kargs.N_total *= kargs.N_dims[i];
        kargs.K_total = 1;
        for(ck_tile::index_t i = 0; i < NumDimK; ++i)
            kargs.K_total *= kargs.K_dims[i];

        // Leading-dimension strides of the flattened, packed views:
        // A is M_total x K_total, B is N_total x K_total, E is M_total x N_total.
        kargs.stride_A = kargs.K_total;
        kargs.stride_B = kargs.K_total;
        kargs.stride_E = kargs.N_total;

        // D tensors share E's G dimensions and flattened layout.
        for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
        {
            for(ck_tile::index_t i = 0; i < NumDimG; ++i)
            {
                if(host_args.Ds_dims[d][i] != host_args.A_dims[i])
                {
                    throw std::invalid_argument(
                        "D tensor G dimensions must match A/B/E tensor G dimensions");
                }
            }
            kargs.batch_stride_Ds[d] = host_args.Ds_strides[d][NumDimG - 1];
            kargs.stride_Ds[d]       = kargs.N_total;
        }

        return kargs;
    }
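    // Worked example (hypothetical extents, matching the sketch after BatchedContractionHostArgs):
    //   G_dims = {4}      -> G_total = 4        M_dims = {8, 16} -> M_total = 128
    //   N_dims = {4, 8}   -> N_total = 32       K_dims = {32, 2} -> K_total = 64
    // The flattened GEMM then sees A as a 128 x 64 matrix with stride_A = K_total = 64,
    // B as 32 x 64 with stride_B = K_total = 64, and E (and each D) as 128 x 32 with
    // stride_E = stride_Ds[d] = N_total = 32, while batch_stride_A/B/E/Ds step between the
    // G_total = 4 batches.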
 
    CK_TILE_DEVICE void operator()(KernelArgs kargs) const
    {
        // Map the 1D block index onto an (M, N) output tile of the flattened GEMM.
        const auto [iM, iN] =
            TilePartitioner{kargs.M_total, kargs.N_total}.GetOutputTileIndex(blockIdx.x);
        const ck_tile::index_t i_m =
            __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
        const ck_tile::index_t i_n =
            __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);

        // blockIdx.y enumerates the flattened G batch, blockIdx.z the split-K slice.
        const auto i_batch_flat = __builtin_amdgcn_readfirstlane(blockIdx.y);
        const auto i_splitk     = __builtin_amdgcn_readfirstlane(blockIdx.z);

        // Advance the base pointers to the current batch.
        const auto batch_offset_A = i_batch_flat * kargs.batch_stride_A;
        const auto batch_offset_B = i_batch_flat * kargs.batch_stride_B;
        const auto batch_offset_E = i_batch_flat * kargs.batch_stride_E;

        const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr) + batch_offset_A;
        const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr) + batch_offset_B;
        EDataType* e_ptr       = static_cast<EDataType*>(kargs.e_ptr) + batch_offset_E;

        std::array<const void*, NumDTensor> ds_batch_ptr;
        static_for<0, NumDTensor, 1>{}([&](auto i) {
            using DDataType           = typename std::tuple_element<i.value, DsDataType>::type;
            const auto batch_offset_D = i_batch_flat * kargs.batch_stride_Ds[i];
            ds_batch_ptr[i] = static_cast<const DDataType*>(kargs.ds_ptr[i]) + batch_offset_D;
        });

        // (assembly of the per-batch UniversalGemmKernel argument struct `gemm_kargs` is
        //  elided in this listing)

        const typename UniversalGemmKernel::SplitKBatchOffset splitk_batch_offset(gemm_kargs,
                                                                                  i_splitk);

        const ADataType* a_ptr_final = a_ptr + splitk_batch_offset.as_k_split_offset[0];
        const BDataType* b_ptr_final = b_ptr + splitk_batch_offset.bs_k_split_offset[0];

        __shared__ char smem_ptr[GetSmemSize()];

        // (the final UniversalGemmKernel::RunGemm dispatch over a_ptr_final, b_ptr_final,
        //  ds_batch_ptr, e_ptr and smem_ptr at tile offsets i_m/i_n is elided in this listing)
    }
};
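// --- Hypothetical host-side flow (not part of this header) -------------------------------
// How the pieces above fit together, using only members shown in this listing. `Kernel`
// stands for a fully instantiated BatchedContractionKernel<...> and `host_args` for the
// sketch given after BatchedContractionHostArgs; the workgroup size and the actual launch
// mechanism (e.g. ck_tile's launch helpers) are outside the scope of this listing.
//
//   const auto kargs = Kernel::MakeKernelArgs(host_args);   // validates ranks and extents
//   if(!Kernel::IsSupportedArguments(kargs)) { /* report or fall back */ }
//   const dim3 grids = Kernel::GridSize(kargs);             // (output tiles, G_total, k_batch)
//   // ...launch the kernel over `grids` blocks, passing `kargs`...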
 