template <int SharedGranularityMN, int SharedGranularityK = 0>
template <int SharedGranularityMN>
    : ptr(ptr_), length(length_)
ret.length = length - offset;
return i < length ? ptr[i] : 0;
const float* ptr = nullptr;
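The fragments above come from the granularity-tagged scale-pointer specializations: one variant stores a pointer plus a length and clamps out-of-range reads to 0, another defaults its pointer to nullptr and deletes indexing. A minimal standalone sketch of the same idea (type and member names here are assumptions, not the library's definition), where a granularity of 0 means one scale for the whole tensor, 1 means one scale per row/column, and -1 means scaling is disabled:

// Simplified, hypothetical analogue of the granularity-tagged scale pointer (C++17).
#include <cstdio>

template <int GranularityMN>
struct ScalePtrSketch
{
    static constexpr int granularity = GranularityMN; // 0: per-tensor, 1: per-row/col, -1: disabled
    const float* ptr = nullptr;
    int length       = 0; // number of valid scale entries

    // Out-of-range reads return 0.f, mirroring "return i < length ? ptr[i] : 0;"
    float operator[](int i) const { return i < length ? ptr[i] : 0.f; }
};

int main()
{
    float row_scales[4] = {1.f, 2.f, 4.f, 8.f};
    ScalePtrSketch<1> per_row{row_scales, 4}; // one scale per output row

    float tensor_scale = 0.5f;
    ScalePtrSketch<0> per_tensor{&tensor_scale, 1}; // same scale for every element

    std::printf("row 2 scale: %f\n", per_row[2]);     // 4.0
    std::printf("tensor scale: %f\n", per_tensor[0]); // 0.5
}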
template <index_t NumDTensor = 0>
const std::array<const void*, NumDTensor>& ds_ptr_,
const std::array<index_t, NumDTensor>& stride_Ds_,
const std::array<const void*, NumDTensor> ds_ptr;
const void* b_shuffle_ptr_,
const std::array<const void*, NumDTensor>& ds_ptr_,
const std::array<index_t, NumDTensor>& stride_Ds_,
ScaleM scale_m_ = nullptr,
ScaleN scale_n_ = nullptr)
template <int NumberTensor = 0>
template <class ScaleM, class ScaleN, index_t NumDTensor = 0>
const std::array<const void*, NumDTensor> ds_ptr;
template <typename TilePartitioner_, typename FlatmmPipeline_, typename EpiloguePipeline_>
static_assert(DsLayout::size() == DsDataType::size(),
              "The size of DsLayout and DsDataType should be the same");
return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
template <class ScaleM, class ScaleN>
hipDeviceProp_t prop;
int dync_smem_size = 0;
int maxActiveBlocksPerCU = 0;
[[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);
e = hipOccupancyMaxActiveBlocksPerMultiprocessor(&maxActiveBlocksPerCU, reinterpret_cast<void*>(
const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
const int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
assert(kargs.k_batch == 1);
return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kargs.k_batch);
return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch);
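The persistent GridSize path above launches at most one resident wave of blocks (CUs times occupancy) and never more blocks than there are output tiles; it also asserts k_batch == 1. A minimal host-side sketch of that arithmetic, with the device and occupancy queries replaced by plain integer parameters (the HIP calls are only hinted at in the fragment):

// Host-side arithmetic only; the occupancy inputs stand in for
// hipGetDeviceProperties / hipOccupancyMaxActiveBlocksPerMultiprocessor.
#include <algorithm>
#include <cstdio>

int persistent_grid_x(int multiProcessorCount, int maxActiveBlocksPerCU, int total_work_tile_cnt)
{
    // One resident wave of blocks at most; never more blocks than work tiles.
    const int persistent_block_size = multiProcessorCount * maxActiveBlocksPerCU;
    return std::min(persistent_block_size, total_work_tile_cnt);
}

int main()
{
    // e.g. 80 CUs, 2 resident blocks per CU, 1000 work tiles -> 160 persistent blocks
    std::printf("%d\n", persistent_grid_x(80, 2, 1000)); // 160
    // small problem: 160 candidate blocks but only 48 tiles -> launch 48 blocks
    std::printf("%d\n", persistent_grid_x(80, 2, 48));   // 48
}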
template <class ScaleM, class ScaleN>
return {hostArgs.a_ptr,
return max(FlatmmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
return FlatmmPipeline::GetSmemSize();
template <class KernelArgs>
constexpr auto N1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<1>{});
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
const index_t K_t   = kargs.k_batch * K1;
const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
splitted_k = kargs.K - KRead * (kargs.k_batch - 1);
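The SplitKBatchOffset fragment rounds the per-batch K up to a whole number of warp tiles (K1) and gives the last batch whatever remains. A self-contained sketch of that bookkeeping, using only the quantities visible above (the example values are illustrative):

#include <cstdio>

struct SplitK { int k_read; int splitted_k; };

SplitK split_k_for(int K, int k_batch, int K1, int k_id)
{
    const int K_t   = k_batch * K1;              // K covered by one round over all batches
    const int KRead = (K + K_t - 1) / K_t * K1;  // per-batch K, rounded up to a multiple of K1
    // Every batch except the last reads KRead; the last one takes the remainder.
    const int splitted_k = (k_id < k_batch - 1) ? KRead : K - KRead * (k_batch - 1);
    return {KRead, splitted_k};
}

int main()
{
    // K = 1000, 4-way split-K, warp-tile K1 = 32 -> KRead = 256, last batch gets 232
    for(int k_id = 0; k_id < 4; ++k_id)
        std::printf("k_id %d reads %d\n", k_id, split_k_for(1000, 4, 32, k_id).splitted_k);
}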
template <class KernelArgs>
if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
if(kargs.k_batch != 1)
    std::cerr << "Conditions not met for Kbatch >1 !" << std::endl;
if(kargs.k_batch != 1)
    std::cerr << "Persistent mode doesn't support Kbatch >1 !" << std::endl;
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false)
    std::cerr << "Can't support K that is not a multiple of KPerBlock"
if(kargs.K % FlatmmPipeline::GetVectorSizeA() != 0)
    std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl;
if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
    std::cerr << "Can't support M that is not a multiple of MPerBlock"
if(kargs.M % FlatmmPipeline::GetVectorSizeA() != 0)
    std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl;
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
    std::cerr << "Can't support N that is not a multiple of NPerBlock"
if(kargs.N % FlatmmPipeline::GetVectorSizeB() != 0)
    std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl;
if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false)
    std::cerr << "Can't support K that is not a multiple of KPerBlock"
if(kargs.K % FlatmmPipeline::GetVectorSizeB() != 0)
    std::cerr << "K is not a multiple of vector load size for B tensor!" << std::endl;
bool DTensorIsValid = {true};
if(std::is_same_v<DiLayout, ELayout> == false)
    DTensorIsValid = false;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
    CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of NPerBlock without padding!");
    DTensorIsValid = false;
if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0)
    CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!");
    DTensorIsValid = false;
if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
    CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of MPerBlock without padding!");
    DTensorIsValid = false;
if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0)
    CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!");
    DTensorIsValid = false;
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
    std::cerr << "Can't support N that is not a multiple of NPerBlock"
if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
    std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
    std::cerr << "Can't support M that is not a multiple of MPerBlock"
if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
    std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
return DTensorIsValid;
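IsSupportedArgument repeats one pattern per dimension: the extent must either be a multiple of the block tile (or the corresponding padding flag must be on), and it must be a multiple of the vectorized access width. A condensed, standalone sketch of that check (function and parameter names are stand-ins, not the kernel's API):

#include <cstdio>

bool dim_is_supported(int extent, int per_block, bool pad_enabled, int vector_size)
{
    // Block-tile divisibility can be waived only when padding is enabled.
    if(extent % per_block != 0 && !pad_enabled)
    {
        std::printf("extent %d is not a multiple of %d and padding is off\n", extent, per_block);
        return false;
    }
    // Vectorized loads/stores always require divisibility by the vector width.
    if(extent % vector_size != 0)
    {
        std::printf("extent %d is not a multiple of the vector access size %d\n", extent, vector_size);
        return false;
    }
    return true;
}

int main()
{
    std::printf("%d\n", dim_is_supported(4096, 256, false, 8)); // 1: divisible by both
    std::printf("%d\n", dim_is_supported(1000, 256, false, 8)); // 0: would need padding
}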
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, class KernelArgs>
const std::array<const void*, NumDTensor>& ds_ptr,
const KernelArgs& kargs,
const auto& a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
    return make_naive_tensor_view<address_space_enum::global>(
        number<FlatmmPipeline::GetVectorSizeA()>{},
    return make_naive_tensor_view<address_space_enum::global>(
        number<FlatmmPipeline::GetVectorSizeA()>{},
FlatmmPipeline::flatKPerWarp * (kargs.K / BlockGemmShape::WarpTile::at(I2));
index_t kFlatN = kargs.N * kargs.K / kFlatK;
const auto& b_flat_tensor_view = [&]() {
    return make_naive_tensor_view<address_space_enum::global>(
        number<FlatmmPipeline::GetVectorSizeB()>{},
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
    return make_naive_tensor_view<address_space_enum::global>(static_cast<const DDataType_*>(ds_ptr[i]),
        number<EpiloguePipeline::GetVectorSizeD(i)>{},
    return make_naive_tensor_view<address_space_enum::global>(static_cast<const DDataType_*>(ds_ptr[i]),
        number<EpiloguePipeline::GetVectorSizeD(i)>{},
const auto& e_tensor_view = [&]() {
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
    return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
        number<EpiloguePipeline::GetVectorSizeC()>{},
    return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
constexpr int ScaleGranularityM  = decltype(kargs.scale_m_ptr)::GranularityMN;
constexpr int ScaleGranularityN  = decltype(kargs.scale_n_ptr)::GranularityMN;
constexpr int ScaleGranularityKA = decltype(kargs.scale_m_ptr)::GranularityK;
constexpr int ScaleGranularityKB = decltype(kargs.scale_n_ptr)::GranularityK;
auto scale_stride_m = ScaleGranularityM == 0 ? 0
auto scale_stride_n = ScaleGranularityN == 0 ? 0
static_assert(ScaleGranularityM == 0 || ScaleGranularityM == 1 || ScaleGranularityM == -1,
              "only support per-tensor or per-row scaling");
static_assert(ScaleGranularityN == 0 || ScaleGranularityN == 1 || ScaleGranularityN == -1,
              "only support per-tensor or per-column scaling");
const auto scale_m_view = make_naive_tensor_view<address_space_enum::global>(
    kargs.scale_m_ptr.ptr,
    kargs.M / ScaleGranularityM,
    ScaleGranularityKA == 0 ? 1 : splitk_batch_offset.splitted_k / ScaleGranularityKA),
    number<ScaleGranularityM == 1 ? FlatmmPipeline::GetVectorSizeA() : 1>{},
const auto scale_n_view = make_naive_tensor_view<address_space_enum::global>(
    kargs.scale_n_ptr.ptr,
    ScaleGranularityKB == 0 ? 1 : (splitk_batch_offset.splitted_k / ScaleGranularityKB),
    kargs.N / ScaleGranularityN),
    number<ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1>{},
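For the flat (pre-shuffled) B view above, kFlatK folds flatKPerWarp elements for each whole warp tile of K, and kFlatN is chosen so that kFlatN * kFlatK still equals N * K. A small worked example with assumed tile sizes (the concrete numbers are illustrative only, not taken from any real configuration):

#include <cstdio>

int main()
{
    const int K = 512, N = 1024;
    const int warp_tile_k  = 32; // stand-in for BlockGemmShape::WarpTile::at(I2)
    const int flatKPerWarp = 64; // stand-in for FlatmmPipeline::flatKPerWarp

    const int kFlatK = flatKPerWarp * (K / warp_tile_k); // 64 * 16 = 1024
    const int kFlatN = N * K / kFlatK;                   // 1024 * 512 / 1024 = 512
    std::printf("kFlatK=%d kFlatN=%d (product %d == N*K %d)\n",
                kFlatK, kFlatN, kFlatK * kFlatN, N * K);
}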
template <typename TensorView>
const auto& a_pad_view = [&]() {
    const auto& a_tensor_view = views.at(I0);
    if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
const auto& b_flat_tensor_view = views.at(I1);
const auto& d_tensor_view = views.at(I2);
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
const auto& e_pad_view = [&]() {
    const auto& e_tensor_view = views.at(I3);
    if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
template <typename PadView>
const auto& a_pad_view      = views.at(I0);
const auto& b_flat_pad_view = views.at(I1);
const auto& ds_pad_view     = views.at(I2);
const auto& e_pad_view      = views.at(I3);
const auto& a_block_window = [&]() {
    if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
const auto& b_flat_block_window =
    {static_cast<int>(i_n / BlockGemmShape::WarpTile::at(I1)), 0});
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
constexpr int ScaleGranularityKA = 0;
constexpr int ScaleGranularityKB = 0;
number<ScaleGranularityKA == 0 ? TilePartitioner::NPerBlock : TilePartitioner::KPerBlock>{}),
    ? TilePartitioner::MPerBlock : TilePartitioner::KPerBlock>{},
template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
const std::array<const void*, NumDTensor>& ds_ptr,
const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
    a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
const auto& a_block_window      = gemm_tile_windows.at(I0);
const auto& b_flat_block_window = gemm_tile_windows.at(I1);
const auto& d_block_window      = gemm_tile_windows.at(I2);
a_block_window, b_flat_block_window, num_loop, smem_ptr_ping, smem_ptr_pong);
auto scale_m_window = gemm_tile_windows.at(number<4>{});
auto scale_n_window = gemm_tile_windows.at(number<5>{});
if constexpr(ScaleM::GranularityMN != -1 || ScaleN::GranularityMN != -1)
    auto& c_block_window = gemm_tile_windows.at(I3);
    operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
else if(UseDefaultScheduler || (get_warp_id() == 0))
    auto& c_block_window = gemm_tile_windows.at(I3);
    operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
        c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
template <class ScaleM, class ScaleN>
int partition_idx = blockIdx.x) const
int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
const auto [iM, iN] =
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
RunFlatmm<ScaleM, ScaleN, scheduler_type>(a_ptr,
partition_idx += gridDim.x;
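In the persistent operator() above, each block starts at its own blockIdx.x and strides by gridDim.x over the flattened tile grid until all work tiles are covered. A host-side sketch of that distribution, with a simple row-major stand-in for the TilePartitioner's tile-to-(iM, iN) mapping (an assumption, not the library's partitioner):

#include <cstdio>

void persistent_block(int block_idx, int grid_dim, int tiles_m, int tiles_n)
{
    const int total_work_tile_cnt = tiles_m * tiles_n;
    // Grid-stride loop over output tiles: partition_idx += gridDim.x each pass.
    for(int partition_idx = block_idx; partition_idx < total_work_tile_cnt; partition_idx += grid_dim)
    {
        const int iM = partition_idx / tiles_n; // row-major stand-in mapping
        const int iN = partition_idx % tiles_n;
        std::printf("block %d -> tile (%d, %d)\n", block_idx, iM, iN);
    }
}

int main()
{
    // 2 persistent blocks covering a 2 x 3 tile grid (6 work tiles)
    for(int b = 0; b < 2; ++b)
        persistent_block(b, /*gridDim.x=*/2, /*tiles_m=*/2, /*tiles_n=*/3);
}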