template <typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
static constexpr int MXdlPack = MXFlatmmPipeline::MXdlPack;
static constexpr int NXdlPack = MXFlatmmPipeline::NXdlPack;
static constexpr int KXdlPack = MXFlatmmPipeline::KXdlPack;
static_assert(DsLayout::size() == DsDataType::size(),
              "The size of DsLayout and DsDataType should be the same");
static CK_TILE_HOST const std::string GetName()

return concat(
    '_', "mx_flatmm_gemm", gemm_prec_str<ADataType, BDataType>, MXFlatmmPipeline::GetName());
template <class ScaleM, class ScaleN>
static constexpr CK_TILE_HOST auto GridSize(const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs)
int dync_smem_size       = 0;
int maxActiveBlocksPerCU = 0;
hipDeviceProp_t prop;

if(hipGetDeviceProperties(&prop, deviceId) != hipSuccess)
    throw std::runtime_error(std::string("hipGetDeviceProperties failed: ") +
                             hipGetErrorName(hipGetLastError()));

if(hipOccupancyMaxActiveBlocksPerMultiprocessor(&maxActiveBlocksPerCU,
                                                reinterpret_cast<void*>(...),
                                                ...,
                                                dync_smem_size) != hipSuccess)
    throw std::runtime_error(
        std::string("hipOccupancyMaxActiveBlocksPerMultiprocessor failed: ") +
        hipGetErrorName(hipGetLastError()));
const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
const int total_work_tile_cnt   = TilePartitioner::GridSize(kargs.M, kargs.N);
...
if(kargs.k_batch != 1)
    throw std::runtime_error("Wrong! k_batch != 1 not supported in persistent kernel");
return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kargs.k_batch);
...
return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch);
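// Grid sizing thus has two paths: the persistent path caps the grid at
// multiProcessorCount x maxActiveBlocksPerCU blocks that then loop over work
// tiles, while the non-persistent path launches one block per tile. Host-side
// sketch of how GridSize()/BlockSize() are consumed ("Kernel", "args", "smem"
// and "stream" are placeholder names, not part of this header):
//
//   const dim3 grids  = Kernel::GridSize(args);
//   const dim3 blocks = Kernel::BlockSize();
//   // launch the library's kernel entry (kentry) with <grids, blocks>,
//   // passing args; the exact launch wrapper depends on the host framework.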
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, class KernelArgs>
static CK_TILE_DEVICE auto MakeGemmTensorViews(const ADataType* a_ptr,
                                               const BDataType* b_flat_ptr,
                                               const std::array<const void*, NumDTensor>& ds_ptr,
                                               EDataType* e_ptr,
                                               const KernelArgs& kargs,
                                               const SplitKBatchOffset& splitk_batch_offset)
const auto& a_tensor_view = [&]() {
    static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>,
                  "A tensor for mx must be RowMajor");
    return make_naive_tensor_view<address_space_enum::global>(
        a_ptr,
        make_tuple(kargs.M, splitk_batch_offset.splitted_k),
        ...,
        number<MXFlatmmPipeline::GetVectorSizeA()>{},
        ...);
}();
constexpr index_t kKPerBlock    = MXFlatmmPipeline::kKPerBlock;
constexpr index_t kNWarpTile    = BlockGemmShape::WarpTile::at(I1);
constexpr index_t flatKPerBlock = kKPerBlock * kNWarpTile;

const index_t kFlatKBlocks = kargs.K / kKPerBlock;
const index_t kFlatN       = kargs.N / kNWarpTile;
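// Worked example with illustrative numbers (not fixed by this header): with
// kKPerBlock = 128 and a warp tile of kNWarpTile = 16, one flat row packs
// flatKPerBlock = 128 * 16 = 2048 contiguous B elements; K = 4096 and
// N = 1024 then give kFlatKBlocks = 32 and kFlatN = 64. The flat view below
// indexes the pre-shuffled B by (flat-N row, flat-K block) rather than (n, k),
// which keeps each block's B reads contiguous.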
const auto& b_flat_tensor_view = [&]() {
    static_assert(flatKPerBlock % MXFlatmmPipeline::GetVectorSizeB() == 0,
                  "wrong! vector size for B tensor");
    ...
    return make_tensor_view<address_space_enum::global>(b_flat_ptr, desc);
}();
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
    return make_naive_tensor_view<address_space_enum::global>(
        static_cast<const DDataType_*>(ds_ptr[i]),
        ...,
        number<EpiloguePipeline::GetVectorSizeD(i)>{},
        ...);
}
else
{
    return make_naive_tensor_view<address_space_enum::global>(
        static_cast<const DDataType_*>(ds_ptr[i]),
        ...,
        number<EpiloguePipeline::GetVectorSizeD(i)>{},
        ...);
}
const auto& e_tensor_view = [&]() {
    if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
    {
        return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
            ...,
            number<EpiloguePipeline::GetVectorSizeC()>{},
            ...);
    }
    else
    {
        return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(...);
    }
}();
auto scale_a = kargs.scale_m_ptr;
auto scale_b = kargs.scale_n_ptr;

static constexpr int BlockScaleSize = 32;
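// MX ("microscaling") block formats carry one shared e8m0 scale per 32 data
// elements, hence BlockScaleSize = 32: e.g. a row of K = 4096 elements owns
// 4096 / 32 = 128 scale bytes along K. The views below reinterpret the scale
// buffers as int32_t, which packs four one-byte scales (a KXdlPack x MXdlPack
// or KXdlPack x NXdlPack group, assuming those factors multiply to 4) into a
// single 32-bit load.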
const auto& scale_a_tensor_view = [&]() {
    ...
    return make_tensor_view<address_space_enum::global>(
        reinterpret_cast<const int32_t*>(scale_a.ptr), scale_a_desc);
}();

const auto& scale_b_tensor_view = [&]() {
    ...
    return make_tensor_view<address_space_enum::global>(
        reinterpret_cast<const int32_t*>(scale_b.ptr), scale_b_desc);
}();

return make_tuple(...,
                  scale_a_tensor_view,
                  scale_b_tensor_view);
template <typename TensorView>
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView& views)
const auto& a_pad_view = [&]() {
    const auto& a_tensor_view = views.at(I0);
    static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>,
                  "A tensor for mx must be RowMajor");
    ...
}();

const auto& b_flat_tensor_view = views.at(I1);
...
const auto& d_tensor_view = views.at(I2);
...
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
...

const auto& e_pad_view = [&]() {
    const auto& e_tensor_view = views.at(I3);
    if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
    ...
}();

return make_tuple(
    a_pad_view, b_flat_tensor_view, ds_pad_view, e_pad_view, views.at(I4), views.at(I5));
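// The pad views wrap each raw view in pad_tensor_view so problems whose sizes
// are not multiples of the block tile still read and write safely at the
// edges. A minimal sketch of the pattern (tile lengths and pad flags are
// illustrative, not the exact ones used above):
//
//   const auto a_padded = pad_tensor_view(
//       a_tensor_view,
//       make_tuple(number<TilePartitioner::MPerBlock>{},
//                  number<TilePartitioner::KPerBlock>{}),
//       sequence<false, true>{}); // e.g. pad only the K dimension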
template <typename PadView>
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
const auto& a_pad_view      = views.at(I0);
const auto& b_flat_pad_view = views.at(I1);
const auto& ds_pad_view     = views.at(I2);
const auto& e_pad_view      = views.at(I3);

const auto& a_block_window = [&]() {
    static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>,
                  "A tensor for mx must be RowMajor");
    ...
}();

const auto& b_flat_block_window =
    make_tile_window(b_flat_pad_view,
                     ...,
                     {static_cast<int>(i_n / BlockGemmShape::WarpTile::at(I1)), 0});
...
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
...

static constexpr int BlockScaleSize = 32;
...
    make_tuple(...,
               number<TilePartitioner::KPerBlock / (BlockScaleSize * KXdlPack)>{}),
...
    make_tuple(...,
               number<TilePartitioner::KPerBlock / (BlockScaleSize * KXdlPack)>{}),
...

return make_tuple(...,
                  scale_a_block_window,
                  scale_b_block_window);
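// Each block's scale window spans
// TilePartitioner::KPerBlock / (BlockScaleSize * KXdlPack) packed int32 words
// along K; e.g. KPerBlock = 256 with BlockScaleSize = 32 and KXdlPack = 2
// gives 256 / 64 = 4 words per block-K tile (numbers illustrative). Sketch of
// the corresponding window construction, with the M-side length assumed:
//
//   auto scale_a_block_window = make_tile_window(
//       scale_a_pad_view,
//       make_tuple(number<TilePartitioner::MPerBlock / MXdlPack>{},
//                  number<TilePartitioner::KPerBlock / (BlockScaleSize * KXdlPack)>{}),
//       {i_m / MXdlPack, 0});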
template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
static CK_TILE_DEVICE void RunFlatmm(const ADataType* a_ptr,
                                     const BDataType* b_flat_ptr,
                                     const std::array<const void*, NumDTensor>& ds_ptr,
                                     EDataType* e_ptr,
                                     void* smem_ptr_ping,
                                     void* smem_ptr_pong,
                                     const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs,
                                     const SplitKBatchOffset& splitk_batch_offset,
                                     const index_t block_idx_m,
                                     const index_t block_idx_n)

const auto& gemm_tensor_views_tuple =
    MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
        a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
const auto& a_block_window       = gemm_tile_windows.at(I0);
const auto& b_flat_block_window  = gemm_tile_windows.at(I1);
const auto& d_block_window       = gemm_tile_windows.at(I2);
const auto& scale_a_block_window = gemm_tile_windows.at(I4);
const auto& scale_b_block_window = gemm_tile_windows.at(I5);
static_assert(ScaleM::GranularityK == ScaleN::GranularityK ||
                  ScaleM::GranularityMN == -1 || ScaleN::GranularityMN == -1,
              "ScaleM and ScaleN should have the same GranularityK");

constexpr bool DoEpiScale =
    (ScaleM::GranularityMN != -1 && ScaleM::GranularityK == 0) ||
    (ScaleN::GranularityMN != -1 && ScaleN::GranularityK == 0);
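// Granularity encoding, as implied by the checks above: GranularityMN == -1
// means no per-M/N scale on that operand, and GranularityK == 0 means the
// scale does not vary along K. Such a K-invariant per-row/per-column scale
// cannot ride along with the mainloop's per-32-element MX scales, so it is
// applied once in the epilogue instead (DoEpiScale == true), using
// kargs.scale_m_ptr / kargs.scale_n_ptr offset by the block indices, e.g.:
//
//   ScaleM{GranularityMN = 1, GranularityK = 0}  ->  per-row epilogue scale
//   ScaleM{GranularityMN = -1}                   ->  no epilogue scale on M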
auto a_block_window_with_distr =
    make_tile_window(a_block_window.get_bottom_tensor_view(),
                     a_block_window.get_window_lengths(),
                     a_block_window.get_window_origin(),
                     MXFlatmmPipeline::GetADramTileDistribution());

const auto& c_block_tile = MXFlatmmPipeline{}(a_block_window_with_distr,
                                              ...,
                                              scale_a_block_window,
                                              scale_b_block_window,
                                              ...);
if constexpr(DoEpiScale)
{
    auto& c_block_window = gemm_tile_windows.at(I3);
    EpiloguePipeline{}(c_block_window,
                       c_block_tile,
                       d_block_window,
                       smem_ptr_ping,
                       kargs.scale_m_ptr + block_idx_m,
                       kargs.scale_n_ptr + block_idx_n);
}
else if(UseDefaultScheduler || (get_warp_id() == 0))
{
    auto& c_block_window = gemm_tile_windows.at(I3);
    EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
}
template <class ScaleM, class ScaleN>
CK_TILE_DEVICE void operator()(FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()> kargs,
                               int partition_idx = blockIdx.x) const
int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
...
const auto [iM, iN] =
    ...;
const auto a_ptr = static_cast<const ADataType*>(kargs.a_ptr) +
                   splitk_batch_offset.a_k_split_offset / APackedSize;
const auto b_flat_ptr = static_cast<const BDataType*>(kargs.b_ptr) +
                        splitk_batch_offset.b_k_split_offset / BPackedSize;
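// APackedSize and BPackedSize fold sub-byte storage into the offset math:
// for a 4-bit element type two values share one byte (packed size 2, an
// illustrative value), so a K-split offset counted in elements is divided by
// the packing factor to land on the right byte.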
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
               EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
               ...))
{
    constexpr auto scheduler_type = (MXFlatmmPipeline::NumWaveGroups == 1);
    RunFlatmm<ScaleM, ScaleN, scheduler_type>(a_ptr,
                                              ...);
}
else
{
    static_assert(...,
                  "Unimplemented: atomic_add with odd vector size for fp16/bf16");
}
partition_idx += gridDim.x;
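// In the persistent configuration this increment is the tail of the
// work-stealing loop: each block starts at partition_idx = blockIdx.x and
// strides by gridDim.x over the M x N work tiles. Sketch of the implied
// control flow (reconstructed, not verbatim):
//
//   do
//   {
//       const auto [iM, iN] = /* map partition_idx to a work tile */;
//       /* run mainloop + epilogue for tile (iM, iN) */
//       partition_idx += gridDim.x;
//   } while(UsePersistentKernel && partition_idx < total_work_tile_cnt);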