30 template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
34 const std::array<const void*, NumBTensor>& bs_ptr_,
35 const std::array<const void*, NumDTensor>& ds_ptr_,
41 const std::array<index_t, NumATensor>& stride_As_,
42 const std::array<index_t, NumBTensor>& stride_Bs_,
43 const std::array<index_t, NumDTensor>& stride_Ds_,
60 const std::array<const void*, NumATensor> as_ptr;
61 const std::array<const void*, NumBTensor> bs_ptr;
62 const std::array<const void*, NumDTensor> ds_ptr;
84 template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
88 const std::array<const void*, NumATensor> as_ptr;
90 const std::array<const void*, NumBTensor> bs_ptr;
92 const std::array<const void*, NumDTensor> ds_ptr;
152 template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
204 template <typename T>
207 static constexpr bool value = []() {
209 return GemmPipeline::UsePersistentKernel;
219 template <typename T, typename KernelArgs>
221 decltype(T::GetOutputOffset(std::declval<KernelArgs>(), std::declval<index_t>()));
223 static constexpr bool value = []() {
245 static_assert(AsLayout::size() == AsDataType::size(),
246 "The size of AsLayout and AsDataType should be the same");
248 static_assert(BsLayout::size() == BsDataType::size(),
249 "The size of BsLayout and BsDataType should be the same");
251 static_assert(DsLayout::size() == DsDataType::size(),
252 "The size of DsLayout and DsDataType should be the same");
260 return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>(), GemmPipeline::GetName());
266 return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
278 const auto kernel = kentry<1, Kernel, KernelArgs>;
281 hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize().x, 0));
283 const int grid_size = get_available_compute_units(s) * occupancy;
284 return dim3(grid_size, 1, 1);
289 if(ck_tile::is_wave32())
318 return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
325 constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
326 const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1);
327 const index_t KRead = __builtin_amdgcn_readfirstlane((kargs.K + K_t - 1) / K_t * K1);
331 if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, AiLayout>)
335 else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, AiLayout>)
338 __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_As[index]);
344 if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BiLayout>)
347 __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_Bs[index]);
349 else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BiLayout>)
357 splitted_k = __builtin_amdgcn_readfirstlane(KRead);
372 if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
385 const auto vectorSizeA = is_wave32() ? GemmPipeline::template GetVectorSizeA<true>()
386 : GemmPipeline::template GetVectorSizeA<false>();
387 bool AsTesnorIsValid = {true};
390 if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
392 if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
393 GemmPipeline::kPadK == false)
398 "Can't support K that is not a multiple of k_batch * KPerBlock "
401 AsTesnorIsValid = false;
403 if(kargs.K % vectorSizeA != 0)
407 CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
409 AsTesnorIsValid = false;
414 if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
419 "Can't support M that is not a multiple of MPerBlock without padding!");
421 AsTesnorIsValid = false;
423 if(kargs.M % vectorSizeA != 0)
427 CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
429 AsTesnorIsValid = false;
434 bool BsTesnorIsValid = {true};
435 const auto vectorSizeB = is_wave32() ? GemmPipeline::template GetVectorSizeB<true>()
436 : GemmPipeline::template GetVectorSizeB<false>();
439 if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::RowMajor>)
441 if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
446 "Can't support N that is not a multiple of NPerBlock without padding!");
448 BsTesnorIsValid = false;
450 if(kargs.N % vectorSizeB != 0)
454 CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
456 BsTesnorIsValid = false;
461 if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
462 GemmPipeline::kPadK == false)
467 "Can't support K that is not a multiple of k_batch * KPerBlock "
470 BsTesnorIsValid = false;
472 if(kargs.K % vectorSizeB != 0)
476 CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!");
478 BsTesnorIsValid = false;
483 bool DTesnorIsValid = {true};
486 if(std::is_same_v<DiLayout, ELayout> == false)
488 DTesnorIsValid = false;
490 if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
492 if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
496 CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of "
497 "NPerBlock without padding!");
499 DTesnorIsValid = false;
501 if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0)
505 CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!");
507 DTesnorIsValid = false;
512 if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
516 CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of "
517 "MPerBlock without padding!");
519 DTesnorIsValid = false;
521 if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0)
525 CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!");
527 DTesnorIsValid = false;
532 if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
534 if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
539 "Can't support N that is not a multiple of NPerBlock without padding!");
543 if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
547 CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!");
554 if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
559 "Can't support M that is not a multiple of MPerBlock without padding!");
563 if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
567 CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!");
572 return AsTesnorIsValid && BsTesnorIsValid && DTesnorIsValid;
575 template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
578 const std::array<const BDataType*, NumBTensor>& bs_ptr,
579 const std::array<const void*, NumDTensor>& ds_ptr,
584 static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
590 if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
592 return make_naive_tensor_view<address_space_enum::global>(
593 static_cast<const AiDataType*>(as_ptr[i]),
596 number<GemmPipeline::GetVectorSizeA()>{},
601 return make_naive_tensor_view<address_space_enum::global>(
602 static_cast<const AiDataType*>(as_ptr[i]),
605 number<GemmPipeline::GetVectorSizeA()>{},
615 if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::RowMajor>)
617 if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
619 constexpr index_t K1 = GemmPipeline::GetSmemPackB();
621 constexpr index_t VectorSizeB =
622 std::min(K1, GemmPipeline::GetVectorSizeB());
623 const auto b_k0_n_k1_desc =
634 return make_tensor_view<address_space_enum::global>(
635 static_cast<const BiDataType*>(bs_ptr[i]), b_n_k_desc);
639 return make_naive_tensor_view<address_space_enum::global>(
643 number<GemmPipeline::GetVectorSizeB()>{},
649 if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
651 constexpr index_t K1 = GemmPipeline::GetSmemPackB();
653 constexpr index_t VectorSizeB =
654 std::min(K1, GemmPipeline::GetVectorSizeB());
655 const auto b_k0_n_k1_desc =
666 return make_tensor_view<address_space_enum::global>(
667 static_cast<const BiDataType*>(bs_ptr[i]), b_n_k_desc);
671 if constexpr(GemmPipeline::Preshuffle)
674 GemmPipeline::BlockGemmShape::flatKPerWarp *
676 TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}));
677 index_t kFlatN = kargs.N * kargs.K / kFlatK;
679 return make_naive_tensor_view<address_space_enum::global>(
683 number<GemmPipeline::GetVectorSizeB()>{},
688 return make_naive_tensor_view<address_space_enum::global>(
692 number<GemmPipeline::GetVectorSizeB()>{},
704 if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
706 return make_naive_tensor_view<address_space_enum::global>(
707 static_cast<const DDataType_*>(ds_ptr[i]),
710 number<EpiloguePipeline::GetVectorSizeD(i)>{},
715 return make_naive_tensor_view<address_space_enum::global>(
716 static_cast<const DDataType_*>(ds_ptr[i]),
719 number<EpiloguePipeline::GetVectorSizeD(i)>{},
726 const auto& e_tensor_view = [&]() {
727 if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
729 return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
733 number<EpiloguePipeline::GetVectorSizeC()>{},
738 return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
747 return make_tuple(as_tensor_view, bs_tensor_view, ds_tensor_view, e_tensor_view);
750 template <typename TensorView>
755 const auto& a_tensor_view = views.at(I0);
757 if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
774 const auto& b_flat_pad_view = views.at(I1);
778 const auto& b_tensor_view = views.at(I1);
780 if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::ColumnMajor>)
799 const auto& d_tensor_view = views.at(I2);
801 if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
819 const auto& e_pad_view = [&]() {
820 const auto& e_tensor_view = views.at(I3);
821 if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
837 if constexpr(GemmPipeline::Preshuffle)
840 return make_tuple(as_pad_view, b_flat_pad_view, ds_pad_view, e_pad_view);
844 return make_tuple(as_pad_view, bs_pad_view, ds_pad_view, e_pad_view);
848 template <typename PadView>
852 const auto& as_pad_view = views.at(I0);
853 const auto& bs_pad_view = views.at(I1);
854 const auto& ds_pad_view = views.at(I2);
855 const auto& e_pad_view = views.at(I3);
860 if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
880 if constexpr(GemmPipeline::Preshuffle)
886 {static_cast<int>(i_n / GemmPipeline::BlockGemmShape::WarpTile::at(I1)),
891 if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::ColumnMajor>)
912 if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
934 return make_tuple(as_block_window, bs_block_window, ds_block_window, e_block_window);
951 template <bool UseDefaultScheduler = true>
953 const std::array<const BDataType*, NumBTensor>& bs_ptr,
954 const std::array<const void*, NumDTensor>& ds_ptr,
963 const auto& gemm_tensor_views_tuple =
964 MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
965 as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
970 const index_t num_loop = __builtin_amdgcn_readfirstlane(
971 TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
974 const auto& as_block_window = gemm_tile_windows.at(I0);
975 const auto& bs_block_window = gemm_tile_windows.at(I1);
976 const auto& ds_block_window = gemm_tile_windows.at(I2);
978 const auto& c_block_tile =
979 GemmPipeline{}(as_block_window[I0], bs_block_window[I0], num_loop, smem_ptr_0);
981 if(UseDefaultScheduler || (get_warp_id() == 0))
984 auto& c_block_window = gemm_tile_windows.at(I3);
986 EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
1008 const std::array<const BDataType*, NumBTensor>& bs_ptr,
1009 const std::array<const void*, NumDTensor>& ds_ptr,
1011 void* __restrict__ smem_ptr_0,
1012 void* __restrict__ smem_ptr_1,
1019 const auto& gemm_tensor_views_tuple =
1020 MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
1021 as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
1026 const index_t num_loop = __builtin_amdgcn_readfirstlane(
1027 TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
1030 const auto& as_block_window = gemm_tile_windows.at(I0);
1031 const auto& bs_block_window = gemm_tile_windows.at(I1);
1032 const auto& ds_block_window = gemm_tile_windows.at(I2);
1035 as_block_window[I0], bs_block_window[I0], num_loop, smem_ptr_0, smem_ptr_1);
1038 auto& c_block_window = gemm_tile_windows.at(I3);
1040 EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
1044 template <bool U = !PersistentKernel, typename = std::enable_if_t<U>>
1047 const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x);
1048 const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
1049 const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
1050 const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
1055 std::array<const ADataType*, NumATensor> as_ptr;
1061 std::array<const BDataType*, NumBTensor> bs_ptr;
1071 const index_t output_offset = TilePartitioner::GetOutputOffset(kargs, blockIdx.z);
1072 e_ptr += output_offset;
1078 if constexpr(GemmPipeline::DoubleSmemBuffer == true)
1082 EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
1092 splitk_batch_offset,
1100 EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
1103 constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1);
1104 RunGemm<scheduler_type>(as_ptr,
1110 splitk_batch_offset,
1118 template <bool U = PersistentKernel, typename = std::enable_if_t<U>, typename = void>
1121 const auto grid_size = __builtin_amdgcn_readfirstlane(get_grid_size());
1122 const auto num_tiles =
1123 __builtin_amdgcn_readfirstlane(TilePartitioner::GridSize(kargs.M, kargs.N));
1124 const auto num_work = __builtin_amdgcn_readfirstlane(num_tiles * kargs.k_batch);
1125 auto block_id = __builtin_amdgcn_readfirstlane(get_block_id());
1127 while(block_id < num_work)
1130 const auto tile_idx = __builtin_amdgcn_readfirstlane(block_id % num_tiles);
1131 const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(tile_idx);
1132 const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
1133 const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
1136 const auto k_batch = __builtin_amdgcn_readfirstlane(block_id / num_tiles);
1139 std::array<const ADataType*, NumATensor> as_ptr;
1145 std::array<const BDataType*, NumBTensor> bs_ptr;
1155 const index_t output_offset = TilePartitioner::GetOutputOffset(kargs, k_batch);
1156 e_ptr += output_offset;
1162 if constexpr(GemmPipeline::DoubleSmemBuffer == true)
1165 if constexpr(!(EpiloguePipeline::MemoryOperation ==
1167 EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
1177 splitk_batch_offset,
1184 if constexpr(!(EpiloguePipeline::MemoryOperation ==
1186 EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
1195 splitk_batch_offset,
1201 block_id += grid_size;
1202 if(block_id >= num_work)
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_descriptor.hpp:268
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:156
void CK_TILE_ERROR(Args &&... args) noexcept
Definition: env.hpp:12
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1615
CK_TILE_HOST void hip_check_error(hipError_t x)
Definition: hip_check_error.hpp:13
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1558
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition: tensor_descriptor.hpp:197
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition: type_traits.hpp:67
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
__device__ index_t get_grid_size()
Definition: get_id.hpp:60
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
__device__ X atomic_add(X *p_dst, const X &x)
unsigned int uint32_t
Definition: stdint.h:126
The Universal GEMM kernel host arguments.
Definition: universal_gemm_kernel.hpp:32
void * c_ptr
Definition: universal_gemm_kernel.hpp:66
const std::array< index_t, NumDTensor > stride_Ds
Definition: universal_gemm_kernel.hpp:73
const std::array< index_t, NumBTensor > stride_Bs
Definition: universal_gemm_kernel.hpp:72
CK_TILE_HOST UniversalGemmHostArgs(const std::array< const void *, NumATensor > &as_ptr_, const std::array< const void *, NumBTensor > &bs_ptr_, const std::array< const void *, NumDTensor > &ds_ptr_, void *e_ptr_, index_t k_batch_, index_t M_, index_t N_, index_t K_, const std::array< index_t, NumATensor > &stride_As_, const std::array< index_t, NumBTensor > &stride_Bs_, const std::array< index_t, NumDTensor > &stride_Ds_, index_t stride_E_)
Definition: universal_gemm_kernel.hpp:33
index_t K
Definition: universal_gemm_kernel.hpp:70
void * e_ptr
Definition: universal_gemm_kernel.hpp:65
index_t M
Definition: universal_gemm_kernel.hpp:68
const std::array< const void *, NumDTensor > ds_ptr
Definition: universal_gemm_kernel.hpp:62
const std::array< const void *, NumATensor > as_ptr
Definition: universal_gemm_kernel.hpp:60
const std::array< index_t, NumATensor > stride_As
Definition: universal_gemm_kernel.hpp:71
index_t N
Definition: universal_gemm_kernel.hpp:69
index_t stride_E
Definition: universal_gemm_kernel.hpp:76
const std::array< const void *, NumBTensor > bs_ptr
Definition: universal_gemm_kernel.hpp:61
index_t stride_C
Definition: universal_gemm_kernel.hpp:77
index_t k_batch
Definition: universal_gemm_kernel.hpp:80
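For illustration only, a minimal host-side sketch of filling these arguments for a single-A/single-B GEMM with no D tensors; the pointers p_a, p_b, p_e, the problem sizes and the row-major strides below are assumptions, not part of this header:
    // NumATensor = NumBTensor = 1, NumDTensor = 0 (the template defaults from line 30).
    ck_tile::UniversalGemmHostArgs<1, 1, 0> host_args(
        {p_a},             // as_ptr_: device pointers of the A tensors
        {p_b},             // bs_ptr_: device pointers of the B tensors
        {},                // ds_ptr_: no D tensors
        p_e,               // e_ptr_:  device pointer of the E output
        /*k_batch_=*/1,
        /*M_=*/3840, /*N_=*/4096, /*K_=*/2048,
        {2048},            // stride_As_: leading dimension of a row-major A
        {4096},            // stride_Bs_: leading dimension of a row-major B
        {},                // stride_Ds_
        /*stride_E_=*/4096);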
Definition: universal_gemm_kernel.hpp:322
std::array< index_t, NumATensor > as_k_split_offset
Definition: universal_gemm_kernel.hpp:365
index_t splitted_k
Definition: universal_gemm_kernel.hpp:367
__device__ SplitKBatchOffset(const KernelArgs &kargs, const std::size_t k_id=blockIdx.z)
Definition: universal_gemm_kernel.hpp:323
std::array< index_t, NumBTensor > bs_k_split_offset
Definition: universal_gemm_kernel.hpp:366
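A worked example of the split-K read size computed in this constructor (listing lines 325-327); K1 below stands in for WarpTile::at(number<2>{}) and the numbers are arbitrary:
    constexpr ck_tile::index_t K1      = 16;
    constexpr ck_tile::index_t K       = 1000, k_batch = 4;
    constexpr ck_tile::index_t K_t     = k_batch * K1;             // 64
    constexpr ck_tile::index_t KRead   = (K + K_t - 1) / K_t * K1; // (1063 / 64) * 16 = 256
    // A row-major A tensor for split-K batch k_id then starts k_id * KRead elements into K;
    // for the strided layouts the offset is additionally scaled by stride_As[i] or
    // stride_Bs[i], as in listing lines 338 and 347.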
Definition: universal_gemm_kernel.hpp:203
static constexpr bool value
Definition: universal_gemm_kernel.hpp:207
decltype(T::UsePersistentKernel) has_persistent_type
Definition: universal_gemm_kernel.hpp:205
Definition: universal_gemm_kernel.hpp:218
decltype(T::GetOutputOffset(std::declval< KernelArgs >(), std::declval< index_t >())) has_get_output_offset_t
Definition: universal_gemm_kernel.hpp:221
static constexpr bool value
Definition: universal_gemm_kernel.hpp:223
The GEMM kernel device arguments.
Definition: universal_gemm_kernel.hpp:86
void * e_ptr
The E output tensor's pointer to device memory.
Definition: universal_gemm_kernel.hpp:94
std::array< index_t, NumBTensor > stride_Bs
The distance between consecutive elements of non-contiguous dimension (in memory) of Bs tensor.
Definition: universal_gemm_kernel.hpp:106
const std::array< const void *, NumDTensor > ds_ptr
The Ds input tensor's pointer to device memory.
Definition: universal_gemm_kernel.hpp:92
std::array< index_t, NumATensor > stride_As
The distance between consecutive elements of non-contiguous dimension (in memory) of As tensor.
Definition: universal_gemm_kernel.hpp:103
const std::array< const void *, NumATensor > as_ptr
The As input tensor's pointer to device memory.
Definition: universal_gemm_kernel.hpp:88
index_t k_batch
Definition: universal_gemm_kernel.hpp:113
index_t N
GEMM's N dimension size.
Definition: universal_gemm_kernel.hpp:98
index_t stride_E
The distance between consecutive elements of non-contiguous dimension (in memory) of E tensor.
Definition: universal_gemm_kernel.hpp:112
index_t K
GEMM's K dimension size.
Definition: universal_gemm_kernel.hpp:100
const std::array< const void *, NumBTensor > bs_ptr
The Bs input tensor's pointer to device memory.
Definition: universal_gemm_kernel.hpp:90
std::array< index_t, NumDTensor > stride_Ds
The distance between consecutive elements of non-contiguous dimension (in memory) of Ds tensor.
Definition: universal_gemm_kernel.hpp:109
index_t M
GEMM's M dimension size.
Definition: universal_gemm_kernel.hpp:96
The Universal GEMM kernel template.
Definition: universal_gemm_kernel.hpp:154
CK_TILE_DEVICE void operator()(KernelArgs kargs) const
Definition: universal_gemm_kernel.hpp:1045
std::conditional_t< BDataTypeIsTuple, remove_cvref_t< typename GemmPipeline::BDataType >, remove_cvref_t< tuple< typename GemmPipeline::BDataType > >> BsDataType
Definition: universal_gemm_kernel.hpp:189
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition: universal_gemm_kernel.hpp:156
static CK_TILE_HOST const std::string GetName()
Definition: universal_gemm_kernel.hpp:257
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: universal_gemm_kernel.hpp:155
CK_TILE_DEVICE void operator()(KernelArgs kargs) const
Definition: universal_gemm_kernel.hpp:1119
static CK_TILE_DEVICE void RunGemm(const std::array< const ADataType *, NumATensor > &as_ptr, const std::array< const BDataType *, NumBTensor > &bs_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, void *smem_ptr_0, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Runs single GEMM problem cooperatively by whole workgroup.
Definition: universal_gemm_kernel.hpp:952
static constexpr bool BDataTypeIsTuple
Definition: universal_gemm_kernel.hpp:161
static constexpr auto I2
Definition: universal_gemm_kernel.hpp:235
static constexpr bool BLayoutIsTuple
Definition: universal_gemm_kernel.hpp:167
static CK_TILE_DEVICE auto MakeGemmTensorViews(const std::array< const ADataType *, NumATensor > &as_ptr, const std::array< const BDataType *, NumBTensor > &bs_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset)
Definition: universal_gemm_kernel.hpp:577
std::conditional_t< BLayoutIsTuple, remove_cvref_t< typename GemmPipeline::BLayout >, remove_cvref_t< tuple< typename GemmPipeline::BLayout > >> BsLayout
Definition: universal_gemm_kernel.hpp:177
static constexpr index_t NumATensor
Definition: universal_gemm_kernel.hpp:238
static constexpr bool ALayoutIsTuple
Definition: universal_gemm_kernel.hpp:165
remove_cvref_t< std::tuple_element_t< I0, AsDataType > > ADataType
Definition: universal_gemm_kernel.hpp:242
static CK_TILE_DEVICE void RunGemm2LDS(const std::array< const ADataType *, NumATensor > &as_ptr, const std::array< const BDataType *, NumBTensor > &bs_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, void *__restrict__ smem_ptr_0, void *__restrict__ smem_ptr_1, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Runs single GEMM problem cooperatively by whole workgroup.
Definition: universal_gemm_kernel.hpp:1007
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n)
Definition: universal_gemm_kernel.hpp:850
static constexpr auto I3
Definition: universal_gemm_kernel.hpp:236
std::conditional_t< DDataTypeIsTuple, remove_cvref_t< typename EpiloguePipeline::DsDataType >, remove_cvref_t< tuple< typename EpiloguePipeline::DsDataType > >> DsDataType
Definition: universal_gemm_kernel.hpp:194
static constexpr bool ADataTypeIsTuple
Definition: universal_gemm_kernel.hpp:159
static constexpr bool has_tile_partitioner_output_offset
Definition: universal_gemm_kernel.hpp:230
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views)
Definition: universal_gemm_kernel.hpp:751
remove_cvref_t< typename GemmPipeline::CLayout > ELayout
Definition: universal_gemm_kernel.hpp:196
static constexpr index_t NumDTensor
Definition: universal_gemm_kernel.hpp:240
UniversalGemmKernelArgs< AsLayout::size(), BsLayout::size(), DsLayout::size()> KernelArgs
Definition: universal_gemm_kernel.hpp:255
static constexpr bool DDataTypeIsTuple
Definition: universal_gemm_kernel.hpp:163
static constexpr bool PersistentKernel
Definition: universal_gemm_kernel.hpp:214
static constexpr auto I1
Definition: universal_gemm_kernel.hpp:234
static constexpr CK_TILE_HOST auto GridSize(index_t M, index_t N, index_t KBatch)
Definition: universal_gemm_kernel.hpp:264
static CK_TILE_HOST auto BlockSize()
Definition: universal_gemm_kernel.hpp:287
std::conditional_t< ADataTypeIsTuple, remove_cvref_t< typename GemmPipeline::ADataType >, remove_cvref_t< tuple< typename GemmPipeline::ADataType > >> AsDataType
Definition: universal_gemm_kernel.hpp:185
remove_cvref_t< std::tuple_element_t< I0, BsDataType > > BDataType
Definition: universal_gemm_kernel.hpp:243
static CK_TILE_HOST auto MaxOccupancyGridSize(const stream_config &s) -> dim3
Get the maximum occupancy grid size for the persistent kernel on the current device.
Definition: universal_gemm_kernel.hpp:275
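A possible way to size a persistent launch with this helper; Kernel is assumed to be a concrete instantiation of this kernel template and s a ck_tile::stream_config. Each launched block then strides over output tiles in steps of the grid size, as in the persistent operator() above (listing lines 1125-1202):
    const auto blocks = Kernel::BlockSize();
    const dim3 grids  = Kernel::MaxOccupancyGridSize(s); // compute units * max blocks per CU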
static constexpr index_t NumBTensor
Definition: universal_gemm_kernel.hpp:239
static constexpr auto I0
Definition: universal_gemm_kernel.hpp:233
static CK_TILE_HOST bool IsSupportedArgument(const KernelArgs &kargs)
Definition: universal_gemm_kernel.hpp:370
std::conditional_t< ALayoutIsTuple, remove_cvref_t< typename GemmPipeline::ALayout >, remove_cvref_t< tuple< typename GemmPipeline::ALayout > >> AsLayout
Definition: universal_gemm_kernel.hpp:174
std::conditional_t< DLayoutIsTuple, remove_cvref_t< typename EpiloguePipeline::DsLayout >, remove_cvref_t< tuple< typename EpiloguePipeline::DsLayout > >> DsLayout
Definition: universal_gemm_kernel.hpp:181
static constexpr bool DLayoutIsTuple
Definition: universal_gemm_kernel.hpp:169
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: universal_gemm_kernel.hpp:157
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: universal_gemm_kernel.hpp:316
static constexpr CK_TILE_HOST KernelArgs MakeKernelArgs(const UniversalGemmHostArgs< NumATensor, NumBTensor, NumDTensor > &hostArgs)
Definition: universal_gemm_kernel.hpp:300
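Taken together with the helpers above, a minimal host-side sketch of the argument and launch-parameter flow; Kernel and host_args are assumptions as before, and the device entry point that ultimately receives kargs is outside this excerpt:
    const auto kargs = Kernel::MakeKernelArgs(host_args);
    if(!Kernel::IsSupportedArgument(kargs))
    {
        // fall back or report an unsupported problem/pipeline combination
    }
    const dim3 grids  = Kernel::GridSize(kargs.M, kargs.N, kargs.k_batch);
    const auto blocks = Kernel::BlockSize();
    constexpr ck_tile::index_t lds_size = Kernel::GetSmemSize(); // LDS required per workgroup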
static constexpr index_t kBlockSize
Definition: universal_gemm_kernel.hpp:199
remove_cvref_t< typename EpiloguePipeline::ODataType > EDataType
Definition: universal_gemm_kernel.hpp:197
Definition: integral_constant.hpp:13
Definition: type_traits.hpp:115
Definition: sequence.hpp:49
Definition: functional.hpp:43
Definition: stream_config.hpp:30
#define CK_TILE_ENV(name)
Definition: env.hpp:145