// Excerpts from gemm_quant_kernel.hpp (ck_tile): a quantized GEMM kernel with optional
// per-group A/B scales, preshuffled operands, and split-K support. "..." marks elided code.

namespace detail {

// Fall-back traits: read an optional nested type from the pipeline, defaulting when the
// member is absent (void_t detection idiom; the std::void_t specialization arguments are
// reconstructed from the surviving primary templates).
template <typename, typename Default, typename = void>
struct get_aq_layout_or
{
    using type = Default;
};

template <typename T, typename Default>
struct get_aq_layout_or<T, Default, std::void_t<typename T::AQLayout>>
{
    using type = typename T::AQLayout;
};
template <typename, typename Default, typename = void>
struct get_bq_layout_or
{
    using type = Default;
};

template <typename T, typename Default>
struct get_bq_layout_or<T, Default, std::void_t<typename T::BQLayout>>
{
    using type = typename T::BQLayout;
};
template <typename, typename Default, typename = void>
struct get_aq_data_type_or
{
    using type = Default;
};

template <typename T, typename Default>
struct get_aq_data_type_or<T, Default, std::void_t<typename T::AQDataType>>
{
    using type = typename T::AQDataType;
};
template <typename, typename Default, typename = void>
struct get_bq_data_type_or
{
    using type = Default;
};

template <typename T, typename Default>
struct get_bq_data_type_or<T, Default, std::void_t<typename T::BQDataType>>
{
    using type = typename T::BQDataType;
};
// Boolean fall-backs for optional pipeline flags (the trait names below are reconstructed;
// only the bodies survive in this listing): false unless T defines the member.
template <typename, typename = void>
struct get_preshuffle_quant_or
{
    static constexpr bool value = false;
};

template <typename T>
struct get_preshuffle_quant_or<T, std::void_t<decltype(T::PreshuffleQuant)>>
{
    static constexpr bool value = T::PreshuffleQuant;
};
template <typename, typename = void>
struct get_preshuffle_b_or
{
    static constexpr bool value = false;
};

template <typename T>
struct get_preshuffle_b_or<T, std::void_t<decltype(T::PreshuffleB)>>
{
    static constexpr bool value = T::PreshuffleB;
};

} // namespace detail
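// [Illustration, not from the header] All six detail traits above are instances of the
// C++17 void_t detection idiom. A minimal self-contained equivalent, with hypothetical
// types P and Q standing in for pipeline types:
#include <type_traits>

template <typename T, typename Default, typename = void>
struct get_member_or { using type = Default; };                // no T::Member: fall back

template <typename T, typename Default>
struct get_member_or<T, Default, std::void_t<typename T::Member>>
{ using type = typename T::Member; };                          // T::Member exists: use it

struct P { using Member = float; };
struct Q {};
static_assert(std::is_same_v<get_member_or<P, int>::type, float>);
static_assert(std::is_same_v<get_member_or<Q, int>::type, int>);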
// Host-side argument pack (excerpt): QuantGemmHostArgs extends QuantGemmProblem (sizes
// M/N/K, quant-group counts QK_A/QK_B, and the five strides) with the device pointers and
// the split-K factor k_batch.
struct QuantGemmHostArgs : public QuantGemmProblem
{
    CK_TILE_HOST QuantGemmHostArgs(const void* a_ptr_, const void* b_ptr_, void* c_ptr_,
                                   const void* aq_ptr_, const void* bq_ptr_, index_t k_batch_,
                                   index_t M_, index_t N_, index_t K_, index_t QK_A_,
                                   index_t QK_B_, index_t stride_A_, index_t stride_B_,
                                   index_t stride_C_, index_t stride_AQ_, index_t stride_BQ_)
        : QuantGemmProblem(
              M_, N_, K_, QK_A_, QK_B_, stride_A_, stride_B_, stride_C_, stride_AQ_, stride_BQ_),
          a_ptr(a_ptr_), b_ptr(b_ptr_), aq_ptr(aq_ptr_), bq_ptr(bq_ptr_), c_ptr(c_ptr_),
          k_batch(k_batch_)
    {
    }
    // ... pointer members (a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr) and k_batch elided ...
};
// The kernel: composed from a tile partitioner, a GEMM pipeline, and an epilogue pipeline.
// It consumes QuantGemmKernelArgs, the device-side mirror of the host args. Its members
// are excerpted individually below; the struct name is reconstructed from context.
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct QuantGemmKernel;
static CK_TILE_HOST const std::string GetName()
{
    return concat(
        '_', "gemm_quant", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
}
static constexpr CK_TILE_HOST auto GridSize(index_t M, index_t N, index_t KBatch)
{
    return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
}
// The mainloop and the epilogue reuse one shared-memory allocation: take the larger need.
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
{
    return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
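// [Illustration, not from the header] The launch geometry in plain C++. Tile sizes are
// hypothetical; the mapping of one workgroup per MPerBlock x NPerBlock output tile, with
// split-K batches along grid.z, is an assumption consistent with GridSize above.
#include <algorithm>

constexpr int kMPerBlock = 128, kNPerBlock = 128;

constexpr int divide_ceil(int x, int y) { return (x + y - 1) / y; }

// grid.x as TilePartitioner::GridSize(M, N) would report it under the assumed mapping
constexpr int tile_count(int M, int N)
{ return divide_ceil(M, kMPerBlock) * divide_ceil(N, kNPerBlock); }

// GetSmemSize: mainloop and epilogue share one allocation, so take the max requirement
constexpr int smem_size(int gemm_smem, int epilogue_smem)
{ return std::max(gemm_smem, epilogue_smem); }

static_assert(tile_count(1000, 1000) == 8 * 8); // 1000/128 rounds up to 8 tiles per axis
static_assert(smem_size(32768, 16384) == 32768);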
// Builds the global view for preshuffled quant scales: flatten the scales into a
// (groups, row) grid, pad each block tile, then pad each wave tile up to a warp-size
// multiple so every lane reads a valid, aligned address.
template <index_t KPerBlockBQ, index_t NPerBlockBQ, index_t NPerBlock, index_t WarpTileN,
          index_t VectorSizeBQ, typename BQDataType_>
static CK_TILE_DEVICE auto MakePreshuffledQuantTensorView(/* bq_ptr + problem sizes elided */)
{
    const auto bq_x = N * KPerBlockBQ;
    const auto bq_y = QK_B / KPerBlockBQ;
    // ...
    const auto block_tile_size = NPerBlockBQ * KPerBlockBQ;
    // ... right-pad bq_x to whole block tiles -> bq_pad0_desc ...
    const auto pad_bq_x = bq_pad0_desc.get_lengths()[I1];
    const auto wave_tile_size = ((QN_B <= WarpTileN) ? (WarpTileN / QN_B) : 1) * KPerBlockBQ;
    // ... unmerge the padded row into wave tiles -> bq_unmerge_pad0_desc ...
    const auto bq_merge_pad1_desc = transform_tensor_descriptor(
        bq_unmerge_pad0_desc,
        /* merge and pad-wave-tile transforms elided */
        make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
        make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
    // ...
    return make_tensor_view<address_space_enum::global>(bq_ptr, bq_merge_pad1_desc);
}
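// [Illustration, not from the header] The warp-size padding the descriptors above rely on.
// get_padding_size is assumed to behave as integer_least_multiple(x, y) - x; its actual
// definition is not part of this listing:
constexpr int integer_least_multiple(int x, int y) { return ((x + y - 1) / y) * y; }
constexpr int get_padding_size(int x, int y) { return integer_least_multiple(x, y) - x; }

// e.g. a 48-element wave tile on a 64-lane wavefront gets 16 elements of right-padding:
static_assert(get_padding_size(48, 64) == 16);
static_assert(get_padding_size(128, 64) == 0); // already aligned: no padding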
// Split-K bookkeeping: per-batch offsets into A and B along K, plus this batch's K extent.
struct SplitKBatchOffset
{
    __device__ SplitKBatchOffset(const QuantGemmKernelArgs& kargs,
                                 const std::size_t k_id = blockIdx.z)
    {
        constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(I2);
        // ... splitted_k derived from kargs.K, kargs.k_batch and K1 ...
        if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
        { /* ... */ }
        else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
        { /* ... */ }
        if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
        { /* ... */ }
        else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
        { /* ... */ }
    }

    index_t a_k_split_offset;
    index_t b_k_split_offset;
    index_t splitted_k;
};
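// [Illustration, not from the header] The bookkeeping SplitKBatchOffset performs, as plain
// integers. The field names match the header; the formulas are an assumption for the simple
// row-major case (the real constructor also aligns each share to the warp-tile K extent K1
// and branches on ALayout/BLayout as shown above):
struct SplitKSketch { int a_k_split_offset, b_k_split_offset, splitted_k; };

constexpr SplitKSketch split_k(int K, int k_batch, int k_id)
{
    const int k_per_split = K / k_batch;          // this batch's nominal share of K
    const int offset      = k_id * k_per_split;   // element offset along K into A and B
    const int extent      = (k_id == k_batch - 1) // last batch absorbs the remainder
                                ? K - k_per_split * (k_batch - 1)
                                : k_per_split;
    return {offset, offset, extent};
}

static_assert(split_k(/*K=*/512, /*k_batch=*/2, /*k_id=*/1).a_k_split_offset == 256);
static_assert(split_k(/*K=*/500, /*k_batch=*/3, /*k_id=*/2).splitted_k == 500 - 2 * 166);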
// A tensor: naive global view, padded per kPadM/kPadK, then an MPerBlock x KPerBlock
// window at block row i_m (the column-major branches are symmetric and partly elided).
static CK_TILE_DEVICE auto MakeABlockWindow(const ADataType* a_ptr,
                                            const QuantGemmKernelArgs& kargs,
                                            const index_t k_size, const index_t i_m)
{
    const auto& a_tensor_view = [&]() {
        if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            return make_naive_tensor_view<address_space_enum::global>(
                a_ptr, /* (M, k_size) lengths/strides elided */
                number<GemmPipeline::GetVectorSizeA()>{}, number<1>{});
        else
            return make_naive_tensor_view<address_space_enum::global>(
                a_ptr, /* (k_size, M) lengths/strides elided */
                number<GemmPipeline::GetVectorSizeA()>{}, number<1>{});
    }();
    const auto& a_pad_view = [&]() { /* pad_tensor_view per kPadM/kPadK; elided */ }();
    const auto& a_block_window = [&]() {
        /* make_tile_window over (MPerBlock, KPerBlock) at {i_m, 0}; elided */ }();
    return a_block_window;
}
// AQ (A-scale) tensor: global view plus per-block window. With PreshuffleQuant the scales
// live in a wave-tile-swizzled layout built like MakePreshuffledQuantTensorView above;
// otherwise plain naive views are used.
static CK_TILE_DEVICE auto MakeAQBlockWindow(const AQDataType* aq_ptr,
                                             const QuantGemmKernelArgs& kargs,
                                             const index_t i_m, const index_t i_n)
{
    const auto& aq_tensor_view = [&]() {
        if constexpr(PreshuffleQuant)
        {
            static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
            const auto aq_x = kargs.M * GemmPipeline::KPerBlockAQ;
            const auto aq_y = kargs.QK_A / GemmPipeline::KPerBlockAQ;
            // ... naive (aq_y, aq_x) descriptor with guaranteed vector length
            //     number<GemmPipeline::GetVectorSizeAQ()>{} ...
            const auto block_tile_size = GemmPipeline::MPerBlock * GemmPipeline::KPerBlockAQ;
            // ... right-pad aq_x to whole block tiles -> aq_pad0_desc ...
            const auto pad_aq_x = aq_pad0_desc.get_lengths()[I1];
            const auto wave_tile_size =
                GemmPipeline::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ;
            // const auto wave_tile_count_x = ... (tiles along the padded row; elided)
            // ... unmerge into wave tiles (aq_unmerge_pad0_desc), right-pad each one by
            //     get_padding_size(wave_tile_size, get_warp_size()) (pad_wave_size), and
            //     merge back -> aq_merge_pad1_desc ...
            return make_tensor_view<address_space_enum::global>(aq_ptr, aq_merge_pad1_desc);
        }
        else if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
        {
            return make_naive_tensor_view<address_space_enum::global>(
                aq_ptr, /* row-major lengths/strides elided */
                number<GemmPipeline::GetVectorSizeAQ()>{}, number<1>{});
        }
        else
        {
            return make_naive_tensor_view<address_space_enum::global>(
                aq_ptr, /* column-major lengths/strides elided */
                number<GemmPipeline::GetVectorSizeAQ()>{}, number<1>{});
        }
        // (a further kQuantType-dependent naive view exists in the full header)
    }();

    const auto& aq_block_window = [&]() {
        if constexpr(PreshuffleQuant)
        {
            static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
            constexpr auto block_m = TilePartitioner::MPerBlock;
            constexpr auto warp_m  = GemmPipeline::BlockGemmShape::WarpTile::at(I0);
            constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
            // constexpr auto tile_window_width = ... (elided)
            constexpr auto tile_window_height = block_m / warp_m;
            auto block_m_idx = i_m / block_m;
            return make_tile_window(aq_tensor_view,
                                    /* window lengths elided */
                                    {block_m_idx * tile_window_height, 0});
        }
        else if constexpr(kQuantType == QuantType::AQuantGrouped) // enumerator name assumed
        {
            constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
            constexpr auto block_m = TilePartitioner::MPerBlock;
            if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
            { /* ... window over (block_m, aqk_per_block) elided ... */ }
            // ...
        }
        else
        {
            static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
            constexpr auto block_m = TilePartitioner::MPerBlock;
            constexpr auto block_k = TilePartitioner::KPerBlock;
            // ...
        }
    }();
    return aq_block_window;
}
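// [Illustration, not from the header] Window-origin arithmetic for the preshuffled AQ path
// above, with hypothetical tile sizes. Each output-tile row advances the scale window by
// tile_window_height rows of wave tiles:
constexpr int aq_block_m = 128; // TilePartitioner::MPerBlock
constexpr int aq_warp_m  = 32;  // warp-tile M extent (WarpTile::at(I0))
constexpr int aq_tile_window_height = aq_block_m / aq_warp_m; // wave-tile rows per block

constexpr int aq_origin_row(int i_m) { return (i_m / aq_block_m) * aq_tile_window_height; }

static_assert(aq_tile_window_height == 4);
static_assert(aq_origin_row(256) == 8); // third block row starts at wave-tile row 8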
// B tensor: four storage schemes are handled: plain row/column-major, PermuteB (K split
// into K0 x K1 slabs with K1 = GetSmemPackB()), PreshuffleB (flattened warp-tile layout),
// and packed fp4 (pk_fp4_raw_t, two values per byte, so K extents are halved).
static CK_TILE_DEVICE auto MakeBBlockWindow(const BDataType* b_ptr,
                                            const QuantGemmKernelArgs& kargs,
                                            const index_t k_size, const index_t i_n)
{
    const auto& b_tensor_view = [&]() {
        if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
        {
            if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
            {
                constexpr index_t K1          = GemmPipeline::GetSmemPackB();
                const index_t K0              = k_size / K1;
                constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
                const auto b_k0_n_k1_desc =
                    // ... naive (K0, N, K1) descriptor, merged back into an (N, K) view ...
                return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    b_ptr, /* (k_size, N) lengths/strides elided */
                    number<GemmPipeline::GetVectorSizeB()>{}, number<1>{});
            }
        }
        else // ColumnMajor
        {
            if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
            {
                constexpr index_t K1          = GemmPipeline::GetSmemPackB();
                const index_t K0              = k_size / K1;
                constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
                const auto b_k0_n_k1_desc =
                    // ... as above, with column-major strides ...
                return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
            }
            else if constexpr(PreshuffleB)
            {
                const index_t kFlatK =
                    GemmPipeline::flatKPerWarp *
                    (k_size / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}));
                const index_t kFlatN = kargs.N * kargs.K / kFlatK;
                return make_naive_tensor_view<address_space_enum::global>(
                    b_ptr, /* (kFlatN, kFlatK) lengths/strides elided */
                    number<GemmPipeline::GetVectorSizeB()>{}, number<1>{});
            }
            else if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    b_ptr, /* (N, k_size / 2) packed lengths/strides elided */
                    number<GemmPipeline::GetVectorSizeB()>{}, number<1>{});
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    b_ptr, /* (N, k_size) lengths/strides elided */
                    number<GemmPipeline::GetVectorSizeB()>{}, number<1>{});
            }
        }
    }();

    const auto& b_pad_view = [&]() {
        if constexpr(PreshuffleB) // branch condition reconstructed
        {
            return b_tensor_view; // preshuffled B is pre-tiled; no padding needed
        }
        else if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
        {
            if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
                return pad_tensor_view(b_tensor_view,
                                       make_tuple(number<TilePartitioner::NPerBlock>{},
                                                  number<TilePartitioner::KPerBlock / 2>{}),
                                       /* DoPads elided */);
            // ... remaining layout/type cases elided ...
        }
    }();

    const auto& b_block_window = [&]() {
        if constexpr(PreshuffleB) // branch condition reconstructed
        {
            return make_tile_window(
                b_pad_view, /* flat warp-tile window lengths elided */
                {static_cast<int>(i_n / GemmPipeline::BlockGemmShape::WarpTile::at(I1)), 0});
        }
        else if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
        {
            if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
                return make_tile_window(b_pad_view,
                                        make_tuple(number<TilePartitioner::NPerBlock>{},
                                                   number<TilePartitioner::KPerBlock / 2>{}),
                                        /* origin elided */);
            // ... remaining cases elided ...
        }
    }();
    return b_block_window;
}
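// [Illustration, not from the header] Index math behind the PermuteB (K0, N, K1) descriptor
// above, assuming the usual interpretation that element (n, k) of the logical N x K matrix
// lands in K0-major slabs of N x K1 contiguous elements (K1 = GetSmemPackB()):
constexpr long permuted_b_offset(long n, long k, long N, long K1)
{
    const long k0 = k / K1, k1 = k % K1;
    return k0 * (N * K1) + n * K1 + k1; // one N x K1 slab per k0
}

// K-adjacent elements inside one K1 pack stay adjacent in memory...
static_assert(permuted_b_offset(0, 1, 64, 8) == permuted_b_offset(0, 0, 64, 8) + 1);
// ...while crossing a K1 boundary jumps a whole N x K1 slab:
static_assert(permuted_b_offset(0, 8, 64, 8) == 64 * 8);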
// BQ (B-scale) tensor: global view plus per-block window. PreshuffleQuant routes through
// MakePreshuffledQuantTensorView; the remaining kQuantType cases use plain naive views.
static CK_TILE_DEVICE auto MakeBQBlockWindow(const BQDataType* bq_ptr,
                                             const QuantGemmKernelArgs& kargs,
                                             const index_t i_m, const index_t i_n)
{
    const auto& bq_tensor_view = [&]() {
        if constexpr(/* kQuantType case elided */)
        {
            return make_naive_tensor_view<address_space_enum::global>(
                bq_ptr /* lengths/strides elided */);
        }
        else if constexpr(PreshuffleQuant)
        {
            static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>,
                          "PreshuffleQuant with BQuantGrouped currently only supports "
                          "ColumnMajor BQ layout");
            return MakePreshuffledQuantTensorView<
                GemmPipeline::KPerBlockBQ,
                GemmPipeline::NPerBlockBQ,
                GemmPipeline::NPerBlock,
                TilePartitioner::BlockGemmShape::WarpTile::at(I1),
                GemmPipeline::GetVectorSizeBQ()>(/* bq_ptr + problem sizes elided */);
        }
        else if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
        {
            return make_naive_tensor_view<address_space_enum::global>(
                bq_ptr, /* row-major lengths/strides elided */
                number<GemmPipeline::GetVectorSizeBQ()>{}, number<1>{});
        }
        else
        {
            static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
            return make_naive_tensor_view<address_space_enum::global>(
                bq_ptr, /* column-major lengths/strides elided */
                number<GemmPipeline::GetVectorSizeBQ()>{}, number<1>{});
        }
        // (one more ColumnMajor-only naive view exists in the full header)
    }();

    const auto& bq_block_window = [&]() {
        if constexpr(PreshuffleQuant)
        {
            static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
            constexpr auto block_n = TilePartitioner::NPerBlock / /* divisor elided */;
            constexpr auto warp_n  = TilePartitioner::BlockGemmShape::WarpTile::at(I1);
            constexpr auto warp_per_group =
                (QuantGroupSize::kN < warp_n)
                    ? (warp_n / QuantGroupSize::kN)
                    : (QuantGroupSize::kN / warp_n);
            constexpr auto bqk_per_block =
                TilePartitioner::KPerBlock / QuantGroupSize::kK;
            // ... window lengths elided ...
            constexpr auto tile_window_height =
                (QuantGroupSize::kN < warp_n) ? block_n / warp_per_group : block_n;
            const auto block_n_idx = i_n / TilePartitioner::NPerBlock;
            return make_tile_window(bq_tensor_view,
                                    /* window lengths elided */
                                    {block_n_idx * tile_window_height, 0});
        }
        else if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
        {
            return make_tile_window(
                bq_tensor_view,
                make_tuple(/* K-extent elided */,
                           number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{}),
                {0, i_n / QuantGroupSize::kN});
        }
        else if constexpr(kQuantType == QuantType::BQuantGrouped)
        {
            static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
            return make_tile_window(
                bq_tensor_view,
                make_tuple(/* N-extent elided */,
                           number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
                {i_n / QuantGroupSize::kN, 0});
        }
        else
        {
            static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
            return make_tile_window(
                bq_tensor_view,
                make_tuple(/* N-extent elided */,
                           number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
                {i_n / QuantGroupSize::kN, 0});
        }
    }();
    return bq_block_window;
}
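// [Illustration, not from the header] The warp_per_group logic in the preshuffled BQ window
// above: when N-quant groups are finer than the warp tile, several groups share one warp
// tile and the window height shrinks accordingly (tile sizes hypothetical):
constexpr int bq_warp_per_group(int group_n, int warp_n)
{ return (group_n < warp_n) ? (warp_n / group_n) : (group_n / warp_n); }

constexpr int bq_window_height(int block_n, int group_n, int warp_n)
{ return (group_n < warp_n) ? block_n / bq_warp_per_group(group_n, warp_n) : block_n; }

static_assert(bq_warp_per_group(16, 32) == 2);   // two groups per warp tile
static_assert(bq_window_height(128, 16, 32) == 64);
static_assert(bq_window_height(128, 64, 32) == 128);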
// C tensor: the destination view is parameterized on the memory operation so split-K can
// swap plain stores for atomic adds.
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
static CK_TILE_DEVICE auto MakeCBlockWindow(CDataType* c_ptr, const QuantGemmKernelArgs& kargs,
                                            const index_t i_m, const index_t i_n)
{
    const auto& c_tensor_view = [&]() {
        if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
            return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
                c_ptr, /* (M, N) lengths/strides elided */
                number<EpiloguePipeline::GetVectorSizeC()>{}, number<1>{});
        else
            return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
                c_ptr /* column-major lengths/strides elided */);
    }();
    const auto& c_pad_view = [&]() {
        if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
        { /* ... pad per kPadN/kPadM, both layouts elided ... */ }
    }();
    // ... window over (MPerBlock, NPerBlock) at {i_m, i_n} elided ...
    return c_block_window;
}
// Host-side shape validation: every tensor's contiguous dimension must either divide the
// block tile (or be padded) and must be a multiple of the vector access width.
static CK_TILE_HOST bool IsSupportedArgument(const QuantGemmKernelArgs& kargs)
{
    if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
    {
        if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
           GemmPipeline::kPadK == false)
        {
            CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
                          "without padding!");
            return false;
        }
        if(kargs.K % GemmPipeline::GetVectorSizeA() != 0)
        {
            CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
            return false;
        }
    }
    else
    {
        if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
        {
            CK_TILE_ERROR(
                "Can't support M that is not a multiple of MPerBlock without padding!");
            return false;
        }
        if(kargs.M % GemmPipeline::GetVectorSizeA() != 0)
        {
            CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
            return false;
        }
    }

    if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
    {
        if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
        {
            CK_TILE_ERROR(
                "Can't support N that is not a multiple of NPerBlock without padding!");
            return false;
        }
        if(kargs.N % GemmPipeline::GetVectorSizeB() != 0)
        {
            CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
            return false;
        }
    }
    else
    {
        if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
           GemmPipeline::kPadK == false)
        {
            CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
                          "without padding!");
            return false;
        }
        if(kargs.K % GemmPipeline::GetVectorSizeB() != 0)
        {
            CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!");
            return false;
        }
    }

    if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
    {
        if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
        {
            CK_TILE_ERROR(
                "Can't support N that is not a multiple of NPerBlock without padding!");
            return false;
        }
        if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
        {
            CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!");
            return false;
        }
    }
    else
    {
        if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
        {
            CK_TILE_ERROR(
                "Can't support M that is not a multiple of MPerBlock without padding!");
            return false;
        }
        if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
        {
            CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!");
            return false;
        }
    }
    return true;
}
// Runs one GEMM problem cooperatively across the workgroup: build the four block windows,
// run the mainloop, then store through the epilogue (atomically for split-K).
static CK_TILE_DEVICE void RunGemm(const ADataType* a_ptr, const BDataType* b_ptr,
                                   const AQDataType* aq_ptr, const BQDataType* bq_ptr,
                                   CDataType* c_ptr, void* smem_ptr,
                                   const QuantGemmKernelArgs& kargs,
                                   const SplitKBatchOffset& splitk_batch_offset,
                                   const index_t block_idx_m, const index_t block_idx_n)
{
    const auto& a_block_window =
        MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
    const auto& b_block_window =
        MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
    const auto& aq_block_window = MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n);
    const auto& bq_block_window = MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n);

    // Mainloop: which scale window the pipeline consumes depends on kQuantType
    // (num_loop, m and n are computed earlier in the full function).
    const auto& c_block_tile = [&]() {
        if constexpr(/* grouped A-quant */)
            return GemmPipeline{}(
                a_block_window, b_block_window, aq_block_window, num_loop, smem_ptr, m);
        else if constexpr(/* grouped B-quant */)
            return GemmPipeline{}(
                a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr, n);
        else
            return GemmPipeline{}(a_block_window, b_block_window, num_loop, smem_ptr);
    }();

    if(/* single K batch; condition elided */)
    {
        auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
            c_ptr, kargs, block_idx_m, block_idx_n);
        if constexpr(/* no per-tensor scales */)
            EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr);
        else
        {
            // per-tensor quantization: fold the two scalar scales in during the epilogue
            const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
            const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
            EpiloguePipeline{}(
                c_block_window, c_block_tile, c_block_window, smem_ptr, aq_scale, bq_scale);
        }
    }
    else
    {
        // split-K: partial products are combined with atomic adds
        auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
            c_ptr, kargs, block_idx_m, block_idx_n);
        if constexpr(/* no per-tensor scales */)
            EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr);
        else
        {
            const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
            const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
            EpiloguePipeline{}(
                c_block_window, c_block_tile, c_block_window, smem_ptr, aq_scale, bq_scale);
        }
    }
}
// Kernel entry point: map this workgroup to an output tile, apply the split-K offsets,
// and run the GEMM.
CK_TILE_DEVICE void operator()(QuantGemmKernelArgs kargs) const
{
    const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
    // ... blockId is the wave-uniform block index; the i_m/i_n tile origins, split-K
    //     pointer adjustment (SplitKBatchOffset) and shared-memory pointer are elided ...
    RunGemm(
        a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
}
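// [Illustration, not from the header] End-to-end host-side launch sketch. QuantGemmHostArgs,
// MakeKernelArgs, GridSize, BlockSize and GetSmemSize are this header's API; "kentry<Kernel>"
// is a hypothetical __global__ trampoline that invokes Kernel{}(kargs), standing in for
// CK's own launch helpers, and BlockSize() is assumed to be dim3-convertible:
template <typename Kernel>
void launch_quant_gemm(const QuantGemmHostArgs& host_args, hipStream_t stream)
{
    const auto kargs = Kernel::MakeKernelArgs(host_args);
    if(!Kernel::IsSupportedArgument(kargs))
        return; // shape/padding combination rejected by the checks above
    const dim3 grid  = Kernel::GridSize(host_args.M, host_args.N, host_args.k_batch);
    const dim3 block = Kernel::BlockSize();
    kentry<Kernel><<<grid, block, 0, stream>>>(kargs); // hypothetical trampoline
}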