31 template <int SharedGranularityMN, int SharedGranularityK = 0, typename ScaleType_ = float>
64 template <int SharedGranularityMN, typename ScaleType_>
79     : ptr(ptr_), length(length_)
89     ret.length = length - offset;
103     return i < length ? ptr[i] : 0;
110 template <typename ScaleType_>
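// A minimal sketch (not the library type) of the bounds-checked scale pointer shown above:
// advancing the pointer shrinks the remaining length, and reads past the end return 0.
// The struct and member names below are illustrative only.
struct BoundedScalePtr
{
    const float* ptr;
    int length;

    BoundedScalePtr operator+(int offset) const
    {
        return {ptr + offset, length - offset}; // fewer valid elements remain after advancing
    }
    float operator[](int i) const { return i < length ? ptr[i] : 0.f; } // out-of-range reads yield 0
};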
133 template <index_t NumDTensor = 0>
139 const std::array<const void*, NumDTensor>& ds_ptr_,
147 const std::array<index_t, NumDTensor>& stride_Ds_,
166     const std::array<const void*, NumDTensor> ds_ptr;
193 const void* b_shuffle_ptr_,
194 const std::array<const void*, NumDTensor>& ds_ptr_,
202 const std::array<index_t, NumDTensor>& stride_Ds_,
204                             ScaleM scale_m_ = nullptr,
205                             ScaleN scale_n_ = nullptr)
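// Hedged usage sketch: filling the scaled host-argument struct declared above. The device
// pointers, strides and the concrete ScaleM/ScaleN FlatmmScalePointer types are placeholders;
// only the argument order (matching the ScaleFlatmmHostArgs constructor) is taken from this
// header, and the ck_tile namespace is assumed.
template <typename ScaleM, typename ScaleN>
auto make_example_flatmm_host_args(const void* a_dev, const void* b_shuffle_dev, void* c_dev,
                                   ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K,
                                   ck_tile::index_t stride_A, ck_tile::index_t stride_B,
                                   ck_tile::index_t stride_C,
                                   ScaleM scale_m, ScaleN scale_n)
{
    return ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN, /*NumDTensor=*/0>(
        a_dev, b_shuffle_dev, /*ds_ptr_=*/{}, c_dev,
        /*k_batch_=*/1, M, N, K,
        stride_A, stride_B, /*stride_Ds_=*/{}, stride_C,
        scale_m, scale_n);
}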
226 template <int NumberTensor = 0>
230 template <class ScaleM, class ScaleN, index_t NumDTensor = 0>
236     const std::array<const void*, NumDTensor> ds_ptr;
250 template <typename TilePartitioner_, typename FlatmmPipeline_, typename EpiloguePipeline_>
278 static_assert(DsLayout::size() == DsDataType::size(),
279 "The size of DsLayout and DsDataType should be the same");
285         return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
292 return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
295     template <class ScaleM, class ScaleN>
301 hipDeviceProp_t prop;
305 int dync_smem_size = 0;
306 int maxActiveBlocksPerCU = 0;
308         [[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);
310 e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
311 &maxActiveBlocksPerCU,
312             reinterpret_cast<void*>(
317 const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
318 const int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
324 assert(kargs.k_batch == 1);
325             return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kargs.k_batch);
329 return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch);
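// Hedged sketch of the occupancy-driven persistent grid sizing used above: launch at most
// (CU count) x (max resident blocks per CU) blocks, but never more blocks than there are work
// tiles. Assumes <hip/hip_runtime.h> and <algorithm>; kentry_ptr, block_size and
// total_work_tiles are illustrative parameters, not library symbols.
inline int persistent_grid_size(const void* kentry_ptr, int block_size, int total_work_tiles)
{
    int device_id = 0;
    hipGetDevice(&device_id);
    hipDeviceProp_t prop;
    hipGetDeviceProperties(&prop, device_id);

    int max_blocks_per_cu = 0;
    hipOccupancyMaxActiveBlocksPerMultiprocessor(
        &max_blocks_per_cu, kentry_ptr, block_size, /*dynamic LDS bytes*/ 0);

    const int persistent_blocks = prop.multiProcessorCount * max_blocks_per_cu;
    return std::min(persistent_blocks, total_work_tiles); // idle blocks are never launched
}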
335     template <class ScaleM, class ScaleN>
339 return {hostArgs.a_ptr,
357 return max(FlatmmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
361 return FlatmmPipeline::GetSmemSize();
366     template <class KernelArgs>
369         constexpr auto N1 = BlockGemmShape::WarpTile::at(number<1>{});
370         constexpr auto K1 = BlockGemmShape::WarpTile::at(number<2>{});
371 const index_t K_t = kargs.k_batch * K1;
372 const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
374 if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
378 else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
383 if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
387 else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
392         if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
398 splitted_k = kargs.K - KRead * (kargs.k_batch - 1);
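// Worked example for the split-K offsets above (illustrative numbers): with K = 1000,
// k_batch = 4 and a warp-tile K of K1 = 32, K_t = k_batch * K1 = 128 and
// KRead = ceil(1000 / 128) * 32 = 8 * 32 = 256. The first three batches each process
// splitted_k = 256 elements of K, while the last batch processes the remainder
// splitted_k = 1000 - 3 * 256 = 232, so the four batches together cover all of K.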
407     template <class KernelArgs>
410 if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
413 if(kargs.k_batch != 1)
415                 std::cerr << "Conditions not met for Kbatch >1 !" << std::endl;
421 if(kargs.k_batch != 1)
423                 std::cerr << "Persistent mode doesn't support Kbatch >1 !" << std::endl;
428 if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
430             if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false)
432                 std::cerr << "Can't support K that is not a multiple of KPerBlock"
437             if(kargs.K % FlatmmPipeline::GetVectorSizeA() != 0)
439                 std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl;
445             if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
447                 std::cerr << "Can't support M that is not a multiple of MPerBlock"
452             if(kargs.M % FlatmmPipeline::GetVectorSizeA() != 0)
454                 std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl;
459 if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
461             if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
463                 std::cerr << "Can't support N that is not a multiple of NPerBlock"
468             if(kargs.N % FlatmmPipeline::GetVectorSizeB() != 0)
470                 std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl;
476             if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false)
478                 std::cerr << "Can't support K that is not a multiple of KPerBlock"
483             if(kargs.K % FlatmmPipeline::GetVectorSizeB() != 0)
485                 std::cerr << "K is not a multiple of vector load size for B tensor!" << std::endl;
490         bool DTesnorIsValid = {true};
493             if(std::is_same_v<DiLayout, ELayout> == false)
495                 DTesnorIsValid = false;
497             if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
499                 if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
501                     CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of "
502                                   "NPerBlock without padding!");
503                     DTesnorIsValid = false;
505                 if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0)
507                     CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!");
508                     DTesnorIsValid = false;
513                 if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
515                     CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of "
516                                   "MPerBlock without padding!");
518                     DTesnorIsValid = false;
520                 if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0)
522                     CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!");
523                     DTesnorIsValid = false;
528 if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
530             if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
532                 std::cerr << "Can't support N that is not a multiple of NPerBlock"
537             if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
539                 std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
545             if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
547                 std::cerr << "Can't support M that is not a multiple of MPerBlock"
552             if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
554                 std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
558 return DTesnorIsValid;
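// Hedged sketch of the recurring check pattern above: a problem dimension is supported only if
// it is a multiple of the corresponding tile size (unless the pipeline pads that dimension),
// and the contiguous dimension must always be a multiple of the vector load size.
// dim, tile, vec and pad_enabled are illustrative names, not library symbols.
inline bool dim_is_supported(int dim, int tile, int vec, bool pad_enabled)
{
    if(dim % tile != 0 && !pad_enabled)
        return false;      // tiles would over-run the tensor without padding
    return dim % vec == 0; // vectorized global accesses must divide the dimension evenly
}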
561     template <typename KernelArgs>
563 const KernelArgs& kargs,
568 const auto& a_tensor_view = [&]() {
569 if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
571 return make_naive_tensor_view<address_space_enum::global>(
575 number<FlatmmPipeline::GetVectorSizeA()>{},
580 return make_naive_tensor_view<address_space_enum::global>(
584 number<FlatmmPipeline::GetVectorSizeA()>{},
590 const auto& a_pad_view = [&]() {
591 if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
608 if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
624     template <typename KernelArgs>
626 const KernelArgs& kargs,
631             FlatmmPipeline::flatKPerWarp * (kargs.K / BlockGemmShape::WarpTile::at(I2));
632 index_t kFlatN = kargs.N * kargs.K / kFlatK;
634 const auto& b_flat_tensor_view = make_naive_tensor_view<address_space_enum::global>(
638 number<FlatmmPipeline::GetVectorSizeB()>{},
647             {static_cast<int>(block_idx_n / BlockGemmShape::WarpTile::at(I1)), 0});
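// Worked example for the flat-B extents above (the flatKPerWarp value is hypothetical): with
// K = 4096, a warp-tile K of 32 and flatKPerWarp = 512, kFlatK = 512 * (4096 / 32) = 65536
// and kFlatN = N * K / kFlatK = N / 16. The pre-shuffled B tensor thus holds the same N * K
// elements, regrouped into warp-contiguous slices of length kFlatK.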
650     template <typename KernelArgs>
652 const KernelArgs& kargs,
661 if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
663 return make_naive_tensor_view<address_space_enum::global>(
664                     static_cast<const DDataType_*>(ds_ptr[i]),
667 number<EpiloguePipeline::GetVectorSizeD(i)>{},
672 return make_naive_tensor_view<address_space_enum::global>(
673                     static_cast<const DDataType_*>(ds_ptr[i]),
676 number<EpiloguePipeline::GetVectorSizeD(i)>{},
686 if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
707 if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
712 {block_idx_m, block_idx_n});
719 {block_idx_n, block_idx_m});
725     template <memory_operation_enum DstInMemOp = memory_operation_enum::set, typename KernelArgs>
727 const KernelArgs& kargs,
732 const auto& e_tensor_view = [&]() {
733 if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
735 return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
739 number<EpiloguePipeline::GetVectorSizeC()>{},
744 return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
754 const auto& e_pad_view = [&]() {
755 if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
775 {block_idx_m, block_idx_n});
778     template <typename KernelArgs>
783         constexpr int ScaleGranularityM  = decltype(kargs.scale_m_ptr)::GranularityMN;
784         constexpr int ScaleGranularityKA = decltype(kargs.scale_m_ptr)::GranularityK;
786 auto scale_stride_m = ScaleGranularityM == 0 ? 0
790 const auto scale_m_view = make_naive_tensor_view<address_space_enum::global>(
791 kargs.scale_m_ptr.ptr,
793 ScaleGranularityKA == 0
795                     : (splitk_batch_offset.splitted_k / ScaleGranularityKA)),
797 number < ScaleGranularityM == 1 ? FlatmmPipeline::GetVectorSizeA() : 1 > {},
803 number < ScaleGranularityKA == 0
804 ? TilePartitioner::NPerBlock
805 : TilePartitioner::KPerBlock > {}),
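// Reading of the granularity constants above, as far as this header shows: GranularityMN == -1
// disables scaling (RunFlatmm takes the unscaled epilogue path), 0 broadcasts a single scale
// over the whole tensor (stride 0, unit extent along K), and a value g >= 1 stores one scale
// per g rows/columns, with g == 1 additionally enabling vectorized scale loads.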
809     template <typename KernelArgs>
814         constexpr int ScaleGranularityN  = decltype(kargs.scale_n_ptr)::GranularityMN;
815         constexpr int ScaleGranularityKB = decltype(kargs.scale_n_ptr)::GranularityK;
817 auto scale_stride_n = ScaleGranularityN == 0 ? 0
821 const auto scale_n_view = make_naive_tensor_view<address_space_enum::global>(
822 kargs.scale_n_ptr.ptr,
824                 ScaleGranularityKB == 0 ? 1 : (splitk_batch_offset.splitted_k / ScaleGranularityKB),
825 kargs.N / ScaleGranularityN),
827 number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {},
833 ? TilePartitioner::MPerBlock
834 : TilePartitioner::KPerBlock > {},
839     template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
843 const std::array<const void*, NumDTensor>& ds_ptr,
853 const auto& a_block_window =
856         const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
857         const auto& scale_m_window  = MakeScaleMWindow(kargs, splitk_batch_offset, block_idx_m);
858         const auto& scale_n_window  = MakeScaleNWindow(kargs, splitk_batch_offset, block_idx_n);
860         const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
864 a_block_window, b_flat_block_window, num_loop, smem_ptr_ping, smem_ptr_pong);
867 if constexpr(ScaleM::GranularityMN != -1 || ScaleN::GranularityMN != -1)
869 if(kargs.k_batch == 1)
871 auto e_block_window = MakeEBlockWindow<memory_operation_enum::set>(
872 e_ptr, kargs, block_idx_m, block_idx_n);
874 .template operator()<decltype(e_block_window),
875 decltype(c_block_tile),
876 decltype(ds_block_window)>(e_block_window,
885 auto e_block_window = MakeEBlockWindow<memory_operation_enum::atomic_add>(
886 e_ptr, kargs, block_idx_m, block_idx_n);
888 .template operator()<decltype(e_block_window),
889 decltype(c_block_tile),
890 decltype(ds_block_window)>(e_block_window,
898 else if(UseDefaultScheduler || (get_warp_id() == 0))
900 if(kargs.k_batch == 1)
902 auto e_block_window = MakeEBlockWindow<memory_operation_enum::set>(
903 e_ptr, kargs, block_idx_m, block_idx_n);
905 .template operator()<decltype(e_block_window),
906 decltype(c_block_tile),
907 decltype(ds_block_window)>(
908 e_block_window, c_block_tile, ds_block_window, smem_ptr_ping);
912 auto e_block_window = MakeEBlockWindow<memory_operation_enum::atomic_add>(
913 e_ptr, kargs, block_idx_m, block_idx_n);
915 .template operator()<decltype(e_block_window),
916 decltype(c_block_tile),
917 decltype(ds_block_window)>(
918 e_block_window, c_block_tile, ds_block_window, smem_ptr_ping);
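// Note on the epilogue store mode above: with k_batch == 1 each output tile is written exactly
// once, so the E block window uses memory_operation_enum::set; with split-K (k_batch > 1) several
// blocks accumulate partial results into the same tile, hence memory_operation_enum::atomic_add.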
923     template <class ScaleM, class ScaleN>
925                                int partition_idx = blockIdx.x) const
927 int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
931 const auto [iM, iN] =
948 if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
951             constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
952 RunFlatmm<ScaleM, ScaleN, scheduler_type>(a_ptr,
963 partition_idx += gridDim.x;
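// Hedged sketch of the persistent work-distribution loop above: every block starts from its own
// blockIdx.x and strides by gridDim.x until all work tiles have been processed. The function and
// parameter names are illustrative, not library symbols.
__device__ void persistent_tile_loop(int total_work_tile_cnt)
{
    int partition_idx = blockIdx.x;
    while(partition_idx < total_work_tile_cnt)
    {
        // ... map partition_idx to an (iM, iN) tile and run the flatmm pipeline on it ...
        partition_idx += gridDim.x; // advance to the next tile owned by this block
    }
}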