template <index_t NumDTensor = 0>
struct FlatmmHostArgs
{
    CK_TILE_HOST FlatmmHostArgs() = default;
    CK_TILE_HOST FlatmmHostArgs(const void* a_ptr_,
                                const void* b_ptr_,
                                const std::array<const void*, NumDTensor>& ds_ptr_,
                                void* e_ptr_,
                                index_t k_batch_,
                                index_t M_,
                                index_t N_,
                                index_t K_,
                                index_t stride_A_,
                                index_t stride_B_,
                                const std::array<index_t, NumDTensor>& stride_Ds_,
                                index_t stride_E_);
    // ... (member initializers elided)

    const void* a_ptr;
    const void* b_ptr;
    const std::array<const void*, NumDTensor> ds_ptr;
    void* e_ptr;
    index_t M, N, K;
    index_t stride_A, stride_B;
    const std::array<index_t, NumDTensor> stride_Ds;
    index_t stride_E;
    index_t k_batch;
};
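For orientation, here is a minimal host-side sketch of filling these arguments for a plain GEMM with no fused D tensors (NumDTensor = 0). The device pointers and all M/N/K/stride values are hypothetical placeholders; the parameter order follows the constructor above.

// Hypothetical device buffers obtained elsewhere (e.g. via hipMalloc):
//   a_dev          - A matrix, row-major M x K
//   b_shuffled_dev - B matrix pre-shuffled into the pipeline's flat layout
//   e_dev          - E output, row-major M x N
ck_tile::FlatmmHostArgs<0> host_args(a_dev,
                                     b_shuffled_dev,
                                     {},                // ds_ptr: empty, NumDTensor == 0
                                     e_dev,
                                     /*k_batch=*/1,     // no split-K
                                     /*M=*/3840,
                                     /*N=*/4096,
                                     /*K=*/2048,
                                     /*stride_A=*/2048, // row-major A: leading dim = K
                                     /*stride_B=*/2048,
                                     {},                // stride_Ds: empty
                                     /*stride_E=*/4096); // row-major E: leading dim = N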
template <index_t NumDTensor = 0>
struct FlatmmKernelArgs
{
    const void* a_ptr;
    const void* b_ptr;
    const std::array<const void*, NumDTensor> ds_ptr;
    void* e_ptr;
    index_t M, N, K;
    index_t stride_A, stride_B;
    std::array<index_t, NumDTensor> stride_Ds;
    index_t stride_E;
    index_t k_batch;
};
template <typename TilePartitioner_, typename FlatmmPipeline_, typename EpiloguePipeline_>
struct FlatmmKernel
{
    using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
    using FlatmmPipeline   = remove_cvref_t<FlatmmPipeline_>;
    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
    // ... (ALayout/BLayout/ELayout, ADataType/BDataType/EDataType, DsLayout, DsDataType,
    //      BlockGemmShape, kBlockSize, and the index constants I0..I3 are aliased from
    //      the pipelines in elided lines)

    static constexpr index_t NumDTensor = DsDataType::size();

    static_assert(DsLayout::size() == DsDataType::size(),
                  "The size of DsLayout and DsDataType should be the same");

    using KernelArgs = FlatmmKernelArgs<DsLayout::size()>;

    static CK_TILE_HOST const std::string GetName()
    {
        return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
    }

    static constexpr CK_TILE_HOST auto GridSize(index_t M, index_t N, index_t KBatch)
    {
        return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
    }

    // ... (BlockSize() elided)

    static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
    {
        return max(FlatmmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }
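The three host-side queries above combine into a launch configuration. A minimal sketch, assuming a concrete instantiation named Kernel and the host_args object from the earlier example; the actual kernel-entry wrapper is left unspecified:

using Kernel = ck_tile::FlatmmKernel<TilePartitioner, FlatmmPipeline, EpiloguePipeline>;

auto kargs = Kernel::MakeKernelArgs(host_args); // host args -> device-side argument pack
if(!Kernel::IsSupportedArgument(kargs))
    throw std::runtime_error("flatmm: unsupported argument combination");

const dim3 grids  = Kernel::GridSize(host_args.M, host_args.N, host_args.k_batch);
const dim3 blocks = Kernel::BlockSize();
const ck_tile::index_t lds_bytes = Kernel::GetSmemSize(); // max of mainloop and epilogue LDS needs
// Launch (grids, blocks, lds_bytes, kargs) through whatever launcher the
// surrounding project uses, e.g. a ck_tile kernel-entry wrapper or hipLaunchKernelGGL.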
    struct SplitKBatchOffset
    {
        __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
        {
            constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
            // ... (K_t, the per-batch K granularity, is defined in an elided line)
            const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;

            if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            {
                // ... (a_k_split_offset for K-contiguous A)
            }
            else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            {
                // ...
            }

            if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
            {
                // ... (b_k_split_offset)
            }
            else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
            {
                // ...
            }
            // ... (splitted_k for this k_id)
        }

        index_t a_k_split_offset;
        index_t b_k_split_offset;
        index_t splitted_k;
    };
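A worked instance of the KRead rounding, under the assumption (usual for split-K, but elided in this excerpt) that K_t = kargs.k_batch * K1:

// K1 = 16 (warp-tile K extent), k_batch = 4  ->  K_t = 64; K = 1000:
//   KRead = (1000 + 64 - 1) / 64 * 16 = 16 * 16 = 256
// Each of the 4 K-batches reads a 256-wide K slice; 4 * 256 = 1024 >= 1000,
// so the final slice is partially padded.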
    static CK_TILE_HOST bool IsSupportedArgument(const KernelArgs& kargs)
    {
        if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 /* && ... (16-bit output types) */)
        {
            if(kargs.k_batch != 1)
            {
                std::cerr << "Conditions not met for KBatch > 1!" << std::endl;
                return false;
            }
        }
        if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
        {
            if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false)
            {
                std::cerr << "Can't support K that is not a multiple of KPerBlock"
                             " without padding!"
                          << std::endl;
                return false;
            }
            if(kargs.K % FlatmmPipeline::GetVectorSizeA() != 0)
            {
                std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl;
                return false;
            }
        }
        else
        {
            if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
            {
                std::cerr << "Can't support M that is not a multiple of MPerBlock"
                             " without padding!"
                          << std::endl;
                return false;
            }
            if(kargs.M % FlatmmPipeline::GetVectorSizeA() != 0)
            {
                std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl;
                return false;
            }
        }

        if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
        {
            if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
            {
                std::cerr << "Can't support N that is not a multiple of NPerBlock"
                             " without padding!"
                          << std::endl;
                return false;
            }
            if(kargs.N % FlatmmPipeline::GetVectorSizeB() != 0)
            {
                std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl;
                return false;
            }
        }
        else
        {
            if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false)
            {
                std::cerr << "Can't support K that is not a multiple of KPerBlock"
                             " without padding!"
                          << std::endl;
                return false;
            }
            if(kargs.K % FlatmmPipeline::GetVectorSizeB() != 0)
            {
                std::cerr << "K is not a multiple of vector load size for B tensor!" << std::endl;
                return false;
            }
        }
        bool DTensorIsValid = {true};
        // ... (loop over the D tensors; DiLayout is the index-th element of DsLayout)
        {
            if(std::is_same_v<DiLayout, ELayout> == false)
            {
                DTensorIsValid = false;
            }
            if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
            {
                if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
                {
                    CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of "
                                  "NPerBlock without padding!");
                    DTensorIsValid = false;
                }
                if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0)
                {
                    CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!");
                    DTensorIsValid = false;
                }
            }
            else
            {
                if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
                {
                    CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of "
                                  "MPerBlock without padding!");
                    DTensorIsValid = false;
                }
                if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0)
                {
                    CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!");
                    DTensorIsValid = false;
                }
            }
        }
        if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
        {
            if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
            {
                std::cerr << "Can't support N that is not a multiple of NPerBlock"
                             " without padding!"
                          << std::endl;
                return false;
            }
            if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
            {
                std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
                return false;
            }
        }
        else
        {
            if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
            {
                std::cerr << "Can't support M that is not a multiple of MPerBlock"
                             " without padding!"
                          << std::endl;
                return false;
            }
            if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
            {
                std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
                return false;
            }
        }

        return DTensorIsValid;
    }
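To make the divisibility rules concrete, consider hypothetical tile parameters KPerBlock = 64 and GetVectorSizeA() = 8 with a row-major A:

// K = 1000: 1000 % 64 == 40  -> rejected unless FlatmmPipeline::kPadK is set
// K = 1004: 1004 % 8  == 4   -> rejected outright (vector loads would misalign)
// K = 1024: 1024 % 64 == 0 and 1024 % 8 == 0 -> accepted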
    template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
    static CK_TILE_DEVICE auto MakeGemmTensorViews(const ADataType* a_ptr,
                                                   const BDataType* b_flat_ptr,
                                                   const std::array<const void*, NumDTensor>& ds_ptr,
                                                   EDataType* e_ptr,
                                                   const KernelArgs& kargs,
                                                   const SplitKBatchOffset& splitk_batch_offset)
    {
        const auto& a_tensor_view = [&]() {
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    a_ptr,
                    /* (M x K lengths and strides elided) */
                    number<FlatmmPipeline::GetVectorSizeA()>{},
                    /* ... */);
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    a_ptr,
                    /* (K x M lengths and strides elided) */
                    number<FlatmmPipeline::GetVectorSizeA()>{},
                    /* ... */);
            }
        }();

        index_t kFlatK = FlatmmPipeline::flatKPerWarp *
                         (splitk_batch_offset.splitted_k / BlockGemmShape::WarpTile::at(number<2>{}));
        index_t kFlatN = kargs.N * kargs.K / kFlatK;

        const auto& b_flat_tensor_view = [&]() {
            return make_naive_tensor_view<address_space_enum::global>(
                b_flat_ptr,
                /* (kFlatN x kFlatK lengths and strides elided) */
                number<FlatmmPipeline::GetVectorSizeB()>{},
                /* ... */);
        }();

        const auto& ds_tensor_view = generate_tuple(
            [&](auto i) {
                // ... (DiLayout / DDataType_ are the i-th entries of DsLayout / DsDataType)
                if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
                {
                    return make_naive_tensor_view<address_space_enum::global>(
                        static_cast<const DDataType_*>(ds_ptr[i]),
                        /* ... */
                        number<EpiloguePipeline::GetVectorSizeD(i)>{},
                        /* ... */);
                }
                else
                {
                    return make_naive_tensor_view<address_space_enum::global>(
                        static_cast<const DDataType_*>(ds_ptr[i]),
                        /* ... */
                        number<EpiloguePipeline::GetVectorSizeD(i)>{},
                        /* ... */);
                }
            },
            number<NumDTensor>{});

        const auto& e_tensor_view = [&]() {
            if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    e_ptr,
                    /* ... */
                    number<EpiloguePipeline::GetVectorSizeC()>{},
                    /* ... */);
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    e_ptr,
                    /* ... */);
            }
        }();

        return make_tuple(a_tensor_view, b_flat_tensor_view, ds_tensor_view, e_tensor_view);
    }
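The kFlatK / kFlatN arithmetic above reinterprets the pre-shuffled B buffer as a 2-D kFlatN x kFlatK view. A worked instance with hypothetical pipeline constants:

// flatKPerWarp = 512, WarpTile K extent = 16, splitted_k = 256, N = 4096, K = 1024:
//   kFlatK = 512 * (256 / 16)   = 8192   // contiguous elements per flat row
//   kFlatN = 4096 * 1024 / 8192 = 512    // number of flat rows
// Rows * row length = 512 * 8192 = N * K elements, so the view exactly covers B.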
    template <typename TensorView>
    static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView& views)
    {
        const auto& a_pad_view = [&]() {
            const auto& a_tensor_view = views.at(I0);
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            {
                // ... (pad_tensor_view over (M, K))
            }
            else
            {
                // ... (pad_tensor_view over (K, M))
            }
        }();

        // The pre-shuffled flat B view is forwarded without padding.
        const auto& b_flat_tensor_view = views.at(I1);

        const auto& ds_pad_view = generate_tuple(
            [&](auto i) {
                const auto& d_tensor_view = views.at(I2);
                if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
                {
                    // ...
                }
                else
                {
                    // ...
                }
            },
            number<NumDTensor>{});

        const auto& e_pad_view = [&]() {
            const auto& e_tensor_view = views.at(I3);
            if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
            {
                // ...
            }
            else
            {
                // ...
            }
        }();

        return make_tuple(a_pad_view, b_flat_tensor_view, ds_pad_view, e_pad_view);
    }
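Each elided branch applies pad_tensor_view(tensor_view, tile_lengths, DoPads). As a hedged sketch of what the row-major A branch plausibly does (the exact tile lengths and pad flags are the pipeline's choice and are not confirmed by this excerpt):

// Sketch only: pad the (M, K) view up to whole blocks; the sequence selects
// which dimensions are actually padded (here only K, controlled by kPadK).
return pad_tensor_view(a_tensor_view,
                       make_tuple(number<TilePartitioner::MPerBlock>{},
                                  number<TilePartitioner::KPerBlock>{}),
                       sequence<false, FlatmmPipeline::kPadK>{});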
    template <typename PadView>
    static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
    {
        const auto& a_pad_view      = views.at(I0);
        const auto& b_flat_pad_view = views.at(I1);
        const auto& ds_pad_view     = views.at(I2);
        const auto& e_pad_view      = views.at(I3);

        const auto& a_block_window = [&]() {
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            {
                // ... (tile window of (MPerBlock, KPerBlock) at {i_m, 0})
            }
            else
            {
                // ...
            }
        }();

        const auto& b_flat_block_window =
            make_tile_window(b_flat_pad_view,
                             /* (flat window lengths elided) */
                             {static_cast<int>(i_n / BlockGemmShape::WarpTile::at(I1)), 0});

        // ... (per-D windows)
        if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
        {
            // ...
        }
        // ... (e_block_window)

        return make_tuple(a_block_window, b_flat_block_window, ds_block_window, e_block_window);
    }
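The flat B window's row origin is the block's N offset divided by the warp tile's N extent. For a worked (hypothetical) instance:

// WarpTile::at(I1) = 32 and i_n = 256  ->  window origin {256 / 32, 0} = {8, 0},
// i.e. the window starts at the 8th group of warp-shuffled output columns.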
    template <bool UseDefaultScheduler = true>
    static CK_TILE_DEVICE void RunFlatmm(const ADataType* a_ptr,
                                         const BDataType* b_flat_ptr,
                                         const std::array<const void*, NumDTensor>& ds_ptr,
                                         EDataType* e_ptr,
                                         void* smem_ptr,
                                         const KernelArgs& kargs,
                                         const SplitKBatchOffset& splitk_batch_offset,
                                         const index_t block_idx_m,
                                         const index_t block_idx_n)
    {
        // Global views -> padded views -> per-block tile windows.
        const auto& gemm_tensor_views_tuple =
            MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
                a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
        const auto& gemm_pad_views   = MakeGemmPadViews(gemm_tensor_views_tuple);
        auto gemm_tile_windows       = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);

        const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);

        const auto& a_block_window      = gemm_tile_windows.at(I0);
        const auto& b_flat_block_window = gemm_tile_windows.at(I1);
        const auto& d_block_window      = gemm_tile_windows.at(I2);

        // Main loop: the flatmm pipeline accumulates the C block tile in registers.
        const auto& c_block_tile = FlatmmPipeline{}(  // (call head elided in the excerpt)
            a_block_window, b_flat_block_window, num_loop, smem_ptr);

        if(UseDefaultScheduler || (get_warp_id() == 0))
        {
            // Epilogue: fuse the D inputs and write the E block tile.
            auto& c_block_window = gemm_tile_windows.at(I3);
            EpiloguePipeline{}.template
            operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
                c_block_window, c_block_tile, d_block_window, smem_ptr);
        }
    }
    CK_TILE_DEVICE void operator()(KernelArgs kargs) const
    {
        const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
        const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
        const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);

        // ... (SplitKBatchOffset construction, pointer adjustment, and LDS allocation
        //      elided; a guard on EpiloguePipeline::GetVectorSizeC() % 2 != 0 mirrors
        //      the IsSupportedArgument check)

        constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
        RunFlatmm<scheduler_type>(a_ptr,
                                  /* b_flat_ptr, ds_ptr, e_ptr, smem_ptr, kargs,
                                     splitk_batch_offset, i_m, i_n */);
    }
};
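To make the tile-index mapping concrete: with hypothetical MPerBlock = NPerBlock = 128 on an M = 3840, N = 4096 problem, the partitioner covers a 30 x 32 grid of output tiles. If GetOutputTileIndex orders tiles row-major (an assumption; the ordering is the partitioner's choice), then:

// blockIdx.x = 65  ->  (iM, iN) = (65 / 32, 65 % 32) = (2, 1)
//   i_m = 2 * 128 = 256,  i_n = 1 * 128 = 128
// __builtin_amdgcn_readfirstlane broadcasts lane 0's value so i_m / i_n are
// kept in scalar registers (they are wavefront-uniform by construction).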