template <typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
 
static constexpr int MXdlPack = FlatmmPipeline::MXdlPack;
static constexpr int NXdlPack = FlatmmPipeline::NXdlPack;
static constexpr int KXdlPack = FlatmmPipeline::KXdlPack;
 
static_assert(DsLayout::size() == DsDataType::size(),
              "The size of DsLayout and DsDataType should be the same");
 
static CK_TILE_HOST const std::string GetName()
{
    return concat('_', "mx_flatmm_gemm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
}
 
template <class ScaleM, class ScaleN>
static constexpr CK_TILE_HOST auto GridSize(const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs)
 
{
    ...
    int dync_smem_size       = 0;
    int maxActiveBlocksPerCU = 0;

    if(hipGetDeviceProperties(&prop, deviceId) != hipSuccess)
        throw std::runtime_error(std::string("hipGetDeviceProperties failed: ") +
                                 hipGetErrorName(hipGetLastError()));

    if(hipOccupancyMaxActiveBlocksPerMultiprocessor(
           &maxActiveBlocksPerCU,
           reinterpret_cast<void*>(
               ...),            // kernel entry point (elided)
           ...                  // block size (elided)
           dync_smem_size) != hipSuccess)
        throw std::runtime_error(
            std::string("hipOccupancyMaxActiveBlocksPerMultiprocessor failed: ") +
            hipGetErrorName(hipGetLastError()));

    // one resident "wave" of blocks: CU count times max active blocks per CU
    const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
    const int total_work_tile_cnt   = TilePartitioner::GridSize(kargs.M, kargs.N);
    ...
    // persistent path (UsePersistentKernel): cap the grid at the available work
    if(kargs.k_batch != 1)
        throw std::runtime_error("Wrong! k_batch != 1 not supported in persistent kernel");
    return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kargs.k_batch);
    ...
    // non-persistent path: one block per output tile
    return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch);
}
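In isolation, the occupancy-driven grid sizing looks like the sketch below.
dummy_kernel, BLOCK_SIZE, and the zero dynamic-LDS size are placeholders, not
this kernel's actual values:

#include <hip/hip_runtime.h>
#include <algorithm>
#include <cstdio>

__global__ void dummy_kernel() {}

int main()
{
    constexpr int BLOCK_SIZE = 256; // placeholder for the kernel's block size
    hipDeviceProp_t prop;
    (void)hipGetDeviceProperties(&prop, 0);

    int maxActiveBlocksPerCU = 0;
    (void)hipOccupancyMaxActiveBlocksPerMultiprocessor(
        &maxActiveBlocksPerCU, reinterpret_cast<void*>(dummy_kernel), BLOCK_SIZE, 0);

    // one full wave of resident blocks, capped by the amount of work
    const int persistent_grid = prop.multiProcessorCount * maxActiveBlocksPerCU;
    const int total_tiles     = 1024; // placeholder for TilePartitioner::GridSize(M, N)
    printf("grid.x = %d\n", std::min(persistent_grid, total_tiles));
    return 0;
}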
 
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, class KernelArgs>
 
static CK_TILE_DEVICE auto MakeGemmTensorViews(const ADataType* a_ptr,
                                               const BDataType* b_flat_ptr,
                                               const std::array<const void*, NumDTensor>& ds_ptr,
                                               EDataType* e_ptr,
                                               const KernelArgs& kargs,
                                               const SplitKBatchOffset& splitk_batch_offset)
{
    const auto& a_tensor_view = [&]() {
        if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
        {
            return make_naive_tensor_view<address_space_enum::global>(
                ...,
                make_tuple(kargs.M, splitk_batch_offset.splitted_k),
                ...,
                number<FlatmmPipeline::GetVectorSizeA()>{},
                ...);
        }
        else // column-major: K is the outer dimension
        {
            return make_naive_tensor_view<address_space_enum::global>(
                ...,
                make_tuple(splitk_batch_offset.splitted_k, kargs.M),
                ...,
                number<FlatmmPipeline::GetVectorSizeA()>{},
                ...);
        }
    }();

    // flat (pre-shuffled) B: one flat row packs a warp-tile-wide slice of N
    // across all of K, so kFlatK = K * NPerWarpTile and kFlatN = N / NPerWarpTile
    index_t kFlatK = kargs.K * BlockGemmShape::WarpTile::at(I1);
    index_t kFlatN = kargs.N * kargs.K / kFlatK;
 
    const auto& b_flat_tensor_view = [&]() {
        return make_naive_tensor_view<address_space_enum::global>(
            ...,
            number<FlatmmPipeline::GetVectorSizeB()>{},
            ...);
    }();
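The flat-B dimensions follow directly from the two formulas above. A hedged
numeric check, with NPerWarpTile standing in for BlockGemmShape::WarpTile::at(I1)
and all sizes illustrative:

#include <cassert>

int main()
{
    const int K = 4096, N = 2048, NPerWarpTile = 16; // assumed sizes

    const int kFlatK = K * NPerWarpTile; // 65536 elements per flat row
    const int kFlatN = N * K / kFlatK;   // 128 flat rows == N / NPerWarpTile

    assert(kFlatN == N / NPerWarpTile);
    assert(long(kFlatK) * kFlatN == long(K) * N); // reshape preserves element count
    return 0;
}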
 
    if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
    {
        return make_naive_tensor_view<address_space_enum::global>(
            static_cast<const DDataType_*>(ds_ptr[i]),
            ...,
            number<EpiloguePipeline::GetVectorSizeD(i)>{},
            ...);
    }
    else
    {
        return make_naive_tensor_view<address_space_enum::global>(
            static_cast<const DDataType_*>(ds_ptr[i]),
            ...,
            number<EpiloguePipeline::GetVectorSizeD(i)>{},
            ...);
    }
 
    const auto& e_tensor_view = [&]() {
        if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
        {
            return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
                ...,
                number<EpiloguePipeline::GetVectorSizeC()>{},
                ...);
        }
        else
        {
            return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
                ...);
        }
    }();
 
    auto scale_a = kargs.scale_m_ptr;
    auto scale_b = kargs.scale_n_ptr;

    // MX block scaling: 32 consecutive K-elements share one 8-bit scale, and
    // four such scales are fetched per packed int32 load below
    static constexpr int BlockScaleSize = 32;

    const auto& scale_a_tensor_view = [&]() {
        ...
        return make_tensor_view<address_space_enum::global>(
            reinterpret_cast<const int32_t*>(scale_a.ptr), scale_a_desc);
    }();

    const auto& scale_b_tensor_view = [&]() {
        ...
        return make_tensor_view<address_space_enum::global>(
            reinterpret_cast<const int32_t*>(scale_b.ptr), scale_b_desc);
    }();

    return make_tuple(...,
                      scale_a_tensor_view,
                      scale_b_tensor_view);
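For orientation, a hedged sketch of the scale bookkeeping implied above (plain
C++, not this kernel's code; E8M0 decode follows the OCP MX convention, and the
problem sizes are assumptions):

#include <cmath>
#include <cstdint>
#include <cstdio>

int main()
{
    const int M = 256, K = 4096, BlockScaleSize = 32;      // assumed sizes
    const long num_scales = long(M) * (K / BlockScaleSize); // 32768 E8M0 scales for A

    // an E8M0 scale stores only an exponent: value = 2^(e - 127)
    const uint8_t e8m0  = 130;
    const float   scale = std::ldexp(1.0f, int(e8m0) - 127); // 8.0f
    printf("%ld scales, sample scale = %g\n", num_scales, scale);
    return 0;
}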
 
template <typename TensorView>
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView& views)
{
    const auto& a_pad_view = [&]() {
        const auto& a_tensor_view = views.at(I0);
        if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
        ...
    }();

    const auto& b_flat_tensor_view = views.at(I1); // flat B passes through unpadded

    ...
        const auto& d_tensor_view = views.at(I2);
        if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
        ...

    const auto& e_pad_view = [&]() {
        const auto& e_tensor_view = views.at(I3);
        if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
        ...
    }();

    return make_tuple(
        a_pad_view, b_flat_tensor_view, ds_pad_view, e_pad_view, views.at(I4), views.at(I5));
}
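Logically, pad_tensor_view rounds each padded dimension up to a whole number
of block tiles so edge blocks can be accessed without bounds branches. A
plain-C++ sketch of that arithmetic (all sizes are assumptions):

#include <cstdio>

// round len up to a whole number of tiles
int pad_up(int len, int tile) { return ((len + tile - 1) / tile) * tile; }

int main()
{
    const int M = 1000, K = 500, MPerBlock = 128, KPerBlock = 64; // assumed
    printf("padded A view: %d x %d\n",
           pad_up(M, MPerBlock), pad_up(K, KPerBlock)); // 1024 x 512
    return 0;
}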
 
template <typename PadView>
static CK_TILE_DEVICE auto
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
{
    const auto& a_pad_view      = views.at(I0);
    const auto& b_flat_pad_view = views.at(I1);
    const auto& ds_pad_view     = views.at(I2);
    const auto& e_pad_view      = views.at(I3);

    const auto& a_block_window = [&]() {
        if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
        ...
    }();

    // flat-B origin: the N offset is rescaled to warp-tile granularity
    const auto& b_flat_block_window =
        make_tile_window(...,
                         {static_cast<int>(i_n / BlockGemmShape::WarpTile::at(I1)), 0});

        if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
        ...

    static constexpr int BlockScaleSize = 32;

    // each scale window spans KPerBlock / (BlockScaleSize * KXdlPack) packed scales
    ...
               number<TilePartitioner::KPerBlock / (BlockScaleSize * KXdlPack)>{}),
    ...
               number<TilePartitioner::KPerBlock / (BlockScaleSize * KXdlPack)>{}),

    return make_tuple(...,
                      scale_a_block_window,
                      scale_b_block_window);
}
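A hedged plain-C++ sketch of the origin arithmetic above; the tile sizes and
offsets are assumptions, not a real configuration:

#include <cstdio>

int main()
{
    const int i_m = 384, i_n = 80; // this block's output-tile element offsets (assumed)
    const int NPerWarpTile = 16;   // stands in for BlockGemmShape::WarpTile::at(I1)
    const int KPerBlock = 256, BlockScaleSize = 32, KXdlPack = 2; // assumed

    printf("A window origin      : {%d, 0}\n", i_m);                // {384, 0}
    printf("flat-B window origin : {%d, 0}\n", i_n / NPerWarpTile); // {5, 0}
    printf("scale-A window K len : %d\n",
           KPerBlock / (BlockScaleSize * KXdlPack));                // 4 packed scales
    return 0;
}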
 
template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
static CK_TILE_DEVICE void RunFlatmm(const ADataType* a_ptr,
                                     const BDataType* b_flat_ptr,
                                     const std::array<const void*, NumDTensor>& ds_ptr,
                                     EDataType* e_ptr,
                                     void* smem_ptr_ping,
                                     void* smem_ptr_pong,
                                     const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs,
                                     const SplitKBatchOffset& splitk_batch_offset,
                                     const index_t block_idx_m,
                                     const index_t block_idx_n)
 
{
    const auto& gemm_tensor_views_tuple =
        MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
            a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);

    const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
    ...
    const auto& a_block_window       = gemm_tile_windows.at(I0);
    const auto& b_flat_block_window  = gemm_tile_windows.at(I1);
    const auto& d_block_window       = gemm_tile_windows.at(I2);
    const auto& scale_a_block_window = gemm_tile_windows.at(I4);
    const auto& scale_b_block_window = gemm_tile_windows.at(I5);
 
    static_assert(ScaleM::GranularityK == ScaleN::GranularityK ||
                      ScaleM::GranularityMN == -1 ||
                      ScaleN::GranularityMN == -1,
                  "ScaleM and ScaleN should have the same GranularityK");

    constexpr bool DoEpiScale =
        (ScaleM::GranularityMN != -1 && ScaleM::GranularityK == 0) ||
        (ScaleN::GranularityMN != -1 && ScaleN::GranularityK == 0);
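Reading the predicate: a scale triggers epilogue scaling when it varies along
M or N (GranularityMN != -1) but is constant over K (GranularityK == 0), i.e.
a per-row or per-column factor that can be applied once after the K-reduction.
A self-contained sketch of that predicate; PerRowScale and NoScale are made-up
illustrative types, not ck_tile's:

#include <cstdio>

struct PerRowScale { static constexpr int GranularityMN = 1;  static constexpr int GranularityK = 0; };
struct NoScale     { static constexpr int GranularityMN = -1; static constexpr int GranularityK = 0; };

template <class ScaleM, class ScaleN>
constexpr bool do_epi_scale()
{
    return (ScaleM::GranularityMN != -1 && ScaleM::GranularityK == 0) ||
           (ScaleN::GranularityMN != -1 && ScaleN::GranularityK == 0);
}

int main()
{
    printf("%d %d\n",
           do_epi_scale<PerRowScale, NoScale>(), // 1: scale applied in the epilogue
           do_epi_scale<NoScale, NoScale>());    // 0: no epilogue scaling
    return 0;
}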
 
    auto a_block_window_with_distr =
        make_tile_window(...,
                         a_block_window.get_window_lengths(),
                         a_block_window.get_window_origin(),
                         FlatmmPipeline::GetADramTileDistribution());

    const auto& c_block_tile = FlatmmPipeline{}(a_block_window_with_distr,
                                                ...,
                                                scale_a_block_window,
                                                scale_b_block_window,
                                                ...);

    if constexpr(DoEpiScale)
    {
        auto& c_block_window = gemm_tile_windows.at(I3);
        // per-row / per-column scales are applied inside the epilogue
        EpiloguePipeline{}(...,
                           kargs.scale_m_ptr + block_idx_m,
                           kargs.scale_n_ptr + block_idx_n);
    }
    else if(UseDefaultScheduler || (get_warp_id() == 0))
    {
        auto& c_block_window = gemm_tile_windows.at(I3);
        EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
    }
 
template <class ScaleM, class ScaleN>
CK_TILE_DEVICE void operator()(FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()> kargs,
                               int partition_idx = blockIdx.x) const
{
    int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);

    // persistent loop: each block walks output tiles, stepping by the grid size
    ...
        const auto [iM, iN] =
            ...

        // split-K offsets are in packed-storage units for sub-byte element types
        ... splitk_batch_offset.a_k_split_offset / APackedSize;
        ... splitk_batch_offset.b_k_split_offset / BPackedSize;

        // guard: atomic_add stores need an even C vector size for fp16/bf16
        ...            EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
        ...
            constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
            RunFlatmm<ScaleM, ScaleN, scheduler_type>(a_ptr,
                                                      ...);
        ...
                      "Unimplemented: atomic_add with odd vector size for fp16/bf16");

        partition_idx += gridDim.x;
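The persistent scheduling reduces to a strided loop over tiles. A plain-C++
simulation of the distribution (grid_size and total_tiles are illustrative):

#include <cstdio>

int main()
{
    const int grid_size = 4, total_tiles = 10; // assumed counts
    for(int block = 0; block < grid_size; ++block)
        for(int tile = block; tile < total_tiles; tile += grid_size) // partition_idx += gridDim.x
            printf("block %d -> tile %d\n", block, tile);
    return 0;
}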
 