CShuffleEpilogue< Problem_, Policy_ > Struct Template Reference

CShuffleEpilogue< Problem_, Policy_ > Struct Template Reference

Composable Kernel: ck_tile::CShuffleEpilogue< Problem_, Policy_ > Struct Template Reference
ck_tile::CShuffleEpilogue< Problem_, Policy_ > Struct Template Reference

#include <cshuffle_epilogue.hpp>

Public Types

using Problem = remove_cvref_t< Problem_ >
 
using ADataType = remove_cvref_t< typename Problem::ADataType >
 
using BDataType = remove_cvref_t< typename Problem::BDataType >
 
using AccDataType = remove_cvref_t< typename Problem::AccDataType >
 
using ODataType = remove_cvref_t< typename Problem::ODataType >
 
using DsDataType = remove_cvref_t< typename Problem::DsDataType >
 
using DsLayout = remove_cvref_t< typename Problem::DsLayout >
 
using ATypeToUse = std::conditional_t< std::is_same_v< ADataType, pk_int4_t >, BDataType, ADataType >
 
using BTypeToUse = std::conditional_t< std::is_same_v< BDataType, pk_int4_t >, ADataType, BDataType >
 
using ELayout = remove_cvref_t< typename Problem::ELayout >
 
using CDElementwise = remove_cvref_t< typename Problem::CDElementwise >
 
using WG = WarpGemmDispatcher< ATypeToUse, BTypeToUse, AccDataType, MPerXdl, NPerXdl, KPerXdl, isCTransposed >
 
using CWarpDstr = typename WG::CWarpDstr
 
using CWarpTensor = typename WG::CWarpTensor
 

Public Member Functions

template<typename ODramWindow , typename OAccTile , typename DsDramWindows >
CK_TILE_DEVICE auto operator() (ODramWindow &out_dram_window, const OAccTile &o_acc_tile, const DsDramWindows &ds_dram_windows, void *p_smem)
 

Static Public Member Functions

static constexpr CK_TILE_HOST_DEVICE index_t GetVectorSizeC ()
 Get the vector store size for C tensor. More...
 
template<index_t I>
static constexpr CK_TILE_HOST_DEVICE index_t GetVectorSizeD (number< I > index)
 Get the vector store size for Di tensor. More...
 
template<typename Problem >
static constexpr CK_TILE_HOST_DEVICE auto MakeLdsBlockDescriptor ()
 
static constexpr CK_TILE_DEVICE auto MakeLdsDistributionEncode ()
 
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize ()
 

Static Public Attributes

static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation
 
static constexpr index_t kBlockSize = Problem::kBlockSize
 
static constexpr index_t kMPerBlock = Problem::kMPerBlock
 
static constexpr index_t kNPerBlock = Problem::kNPerBlock
 
static constexpr index_t MWave = Problem::MWave
 
static constexpr index_t NWave = Problem::NWave
 
static constexpr index_t MPerXdl = Problem::MPerXdl
 
static constexpr index_t NPerXdl = Problem::NPerXdl
 
static constexpr index_t KPerXdl = Problem::KPerXdl
 
static constexpr index_t isCTransposed = Problem::isCTransposed
 
static constexpr bool FixedVectorSize = Problem::FixedVectorSize
 
static constexpr index_t VectorSizeC = Problem::VectorSizeC
 
static constexpr index_t MPerIteration = MPerXdl * MWave
 
static constexpr index_t NPerIteration = NPerXdl * NWave
 
static constexpr index_t NumDTensor = Problem::NumDTensor
 
static constexpr auto shuffle_tile_tuple
 Shuffle tile configuration parameters. More...
 
static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple)
 
static constexpr index_t NumNXdlPerWavePerShuffle = std::get<1>(shuffle_tile_tuple)
 
static constexpr auto MNPerIterationShuffle
 
static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle)
 
static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle)
 

Member Typedef Documentation

◆ AccDataType

template<typename Problem_ , typename Policy_ = void>
using ck_tile::CShuffleEpilogue< Problem_, Policy_ >::AccDataType = remove_cvref_t<typename Problem::AccDataType>

◆ ADataType

template<typename Problem_ , typename Policy_ = void>
using ck_tile::CShuffleEpilogue< Problem_, Policy_ >::ADataType = remove_cvref_t<typename Problem::ADataType>

◆ ATypeToUse

template<typename Problem_ , typename Policy_ = void>
using ck_tile::CShuffleEpilogue< Problem_, Policy_ >::ATypeToUse = std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>

◆ BDataType

template<typename Problem_ , typename Policy_ = void>
using ck_tile::CShuffleEpilogue< Problem_, Policy_ >::BDataType = remove_cvref_t<typename Problem::BDataType>

◆ BTypeToUse

template<typename Problem_ , typename Policy_ = void>
using ck_tile::CShuffleEpilogue< Problem_, Policy_ >::BTypeToUse = std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>

◆ CDElementwise

template<typename Problem_ , typename Policy_ = void>
using ck_tile::CShuffleEpilogue< Problem_, Policy_ >::CDElementwise = remove_cvref_t<typename Problem::CDElementwise>

◆ CWarpDstr

template<typename Problem_ , typename Policy_ = void>
using ck_tile::CShuffleEpilogue< Problem_, Policy_ >::CWarpDstr = typename WG::CWarpDstr

◆ CWarpTensor

template<typename Problem_ , typename Policy_ = void>
using ck_tile::CShuffleEpilogue< Problem_, Policy_ >::CWarpTensor = typename WG::CWarpTensor

◆ DsDataType

template<typename Problem_ , typename Policy_ = void>
using ck_tile::CShuffleEpilogue< Problem_, Policy_ >::DsDataType = remove_cvref_t<typename Problem::DsDataType>

◆ DsLayout

template<typename Problem_ , typename Policy_ = void>
using ck_tile::CShuffleEpilogue< Problem_, Policy_ >::DsLayout = remove_cvref_t<typename Problem::DsLayout>

◆ ELayout

template<typename Problem_ , typename Policy_ = void>
using ck_tile::CShuffleEpilogue< Problem_, Policy_ >::ELayout = remove_cvref_t<typename Problem::ELayout>

◆ ODataType

template<typename Problem_ , typename Policy_ = void>
using ck_tile::CShuffleEpilogue< Problem_, Policy_ >::ODataType = remove_cvref_t<typename Problem::ODataType>

◆ Problem

template<typename Problem_ , typename Policy_ = void>
using ck_tile::CShuffleEpilogue< Problem_, Policy_ >::Problem = remove_cvref_t<Problem_>

◆ WG

template<typename Problem_ , typename Policy_ = void>
using ck_tile::CShuffleEpilogue< Problem_, Policy_ >::WG = WarpGemmDispatcher<ATypeToUse, BTypeToUse, AccDataType, MPerXdl, NPerXdl, KPerXdl, isCTransposed>

Member Function Documentation

◆ GetSmemSize()

template<typename Problem_ , typename Policy_ = void>
static constexpr CK_TILE_HOST_DEVICE index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::GetSmemSize ( )
inline static constexpr

◆ GetVectorSizeC()

template<typename Problem_ , typename Policy_ = void>
static constexpr CK_TILE_HOST_DEVICE index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::GetVectorSizeC ( )
inline static constexpr

Get the vector store size for C tensor.

Note
The vector store size for the output C tensor depends on multiple factors, such as its data layout and the warp GEMM C transposition. In general it is the number of consecutive elements in the contiguous C dimension held by a single thread.
Returns
The vector store size for C tensor.

◆ GetVectorSizeD()

template<typename Problem_ , typename Policy_ = void>
template<index_t I>
static constexpr CK_TILE_HOST_DEVICE index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::GetVectorSizeD ( number< I >  index)
inline static constexpr

Get the vector store size for Di tensor.

Returns
The vector store size for Di tensor.

◆ MakeLdsBlockDescriptor()

template<typename Problem_ , typename Policy_ = void>
template<typename Problem >
static constexpr CK_TILE_HOST_DEVICE auto ck_tile::CShuffleEpilogue< Problem_, Policy_ >::MakeLdsBlockDescriptor ( )
inline static constexpr

◆ MakeLdsDistributionEncode()

template<typename Problem_ , typename Policy_ = void>
static constexpr CK_TILE_DEVICE auto ck_tile::CShuffleEpilogue< Problem_, Policy_ >::MakeLdsDistributionEncode ( )
inline static constexpr

◆ operator()()

template<typename Problem_ , typename Policy_ = void>
template<typename ODramWindow , typename OAccTile , typename DsDramWindows >
CK_TILE_DEVICE auto ck_tile::CShuffleEpilogue< Problem_, Policy_ >::operator() ( ODramWindow &  out_dram_window,
const OAccTile &  o_acc_tile,
const DsDramWindows &  ds_dram_windows,
void *  p_smem 
)
inline

Member Data Documentation

◆ FixedVectorSize

template<typename Problem_ , typename Policy_ = void>
constexpr bool ck_tile::CShuffleEpilogue< Problem_, Policy_ >::FixedVectorSize = Problem::FixedVectorSize
static constexpr

◆ isCTransposed

template<typename Problem_ , typename Policy_ = void>
constexpr index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::isCTransposed = Problem::isCTransposed
static constexpr

◆ kBlockSize

template<typename Problem_ , typename Policy_ = void>
constexpr index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::kBlockSize = Problem::kBlockSize
static constexpr

◆ kMPerBlock

template<typename Problem_ , typename Policy_ = void>
constexpr index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::kMPerBlock = Problem::kMPerBlock
static constexpr

◆ kNPerBlock

template<typename Problem_ , typename Policy_ = void>
constexpr index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::kNPerBlock = Problem::kNPerBlock
static constexpr

◆ KPerXdl

template<typename Problem_ , typename Policy_ = void>
constexpr index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::KPerXdl = Problem::KPerXdl
static constexpr

◆ MemoryOperation

template<typename Problem_ , typename Policy_ = void>
constexpr memory_operation_enum ck_tile::CShuffleEpilogue< Problem_, Policy_ >::MemoryOperation = Problem::MemoryOperation
static constexpr

◆ MNPerIterationShuffle

template<typename Problem_ , typename Policy_ = void>
constexpr auto ck_tile::CShuffleEpilogue< Problem_, Policy_ >::MNPerIterationShuffle
static constexpr
Initial value:
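= [] {
    constexpr index_t m_val = MPerXdl * MWave * NumMXdlPerWavePerShuffle;
    constexpr index_t n_val = NPerXdl * NWave * NumNXdlPerWavePerShuffle;
    if constexpr(kMPerBlock % m_val != 0 || kNPerBlock % n_val != 0)
        return std::make_tuple(MPerXdl * MWave, NPerXdl * NWave);
    else
        return std::make_tuple(m_val, n_val);
}()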
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
static constexpr index_t MPerXdl
Definition: cshuffle_epilogue.hpp:84
static constexpr index_t kNPerBlock
Definition: cshuffle_epilogue.hpp:81
static constexpr index_t MWave
Definition: cshuffle_epilogue.hpp:82
static constexpr index_t NumMXdlPerWavePerShuffle
Definition: cshuffle_epilogue.hpp:191
static constexpr index_t NumNXdlPerWavePerShuffle
Definition: cshuffle_epilogue.hpp:192
static constexpr index_t NWave
Definition: cshuffle_epilogue.hpp:83
static constexpr index_t kMPerBlock
Definition: cshuffle_epilogue.hpp:80
static constexpr index_t NPerXdl
Definition: cshuffle_epilogue.hpp:85

◆ MPerIteration

template<typename Problem_ , typename Policy_ = void>
constexpr index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::MPerIteration = MPerXdl * MWave
static constexpr

◆ MPerIterationShuffle

template<typename Problem_ , typename Policy_ = void>
constexpr index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::MPerIterationShuffle = std::get<0>(MNPerIterationShuffle)
static constexpr

◆ MPerXdl

template<typename Problem_ , typename Policy_ = void>
constexpr index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::MPerXdl = Problem::MPerXdl
static constexpr

◆ MWave

template<typename Problem_ , typename Policy_ = void>
constexpr index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::MWave = Problem::MWave
static constexpr

◆ NPerIteration

template<typename Problem_ , typename Policy_ = void>
constexpr index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::NPerIteration = NPerXdl * NWave
static constexpr

◆ NPerIterationShuffle

template<typename Problem_ , typename Policy_ = void>
constexpr index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::NPerIterationShuffle = std::get<1>(MNPerIterationShuffle)
static constexpr

◆ NPerXdl

template<typename Problem_ , typename Policy_ = void>
constexpr index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::NPerXdl = Problem::NPerXdl
static constexpr

◆ NumDTensor

template<typename Problem_ , typename Policy_ = void>
constexpr index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::NumDTensor = Problem::NumDTensor
static constexpr

◆ NumMXdlPerWavePerShuffle

template<typename Problem_ , typename Policy_ = void>
constexpr index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple)
static constexpr

◆ NumNXdlPerWavePerShuffle

template<typename Problem_ , typename Policy_ = void>
constexpr index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::NumNXdlPerWavePerShuffle = std::get<1>(shuffle_tile_tuple)
static constexpr

◆ NWave

template<typename Problem_ , typename Policy_ = void>
constexpr index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::NWave = Problem::NWave
static constexpr

◆ shuffle_tile_tuple

template<typename Problem_ , typename Policy_ = void>
constexpr auto ck_tile::CShuffleEpilogue< Problem_, Policy_ >::shuffle_tile_tuple
static constexpr
Initial value:
= [] {
    constexpr index_t elem_per_thread = MPerXdl * NPerXdl / get_warp_size();
    if constexpr(elem_per_thread >= GetVectorSizeC())
    {
        return std::make_tuple(1, 1);
    }
    else
    {
        constexpr index_t num_xdl_shuffles = GetVectorSizeC() / elem_per_thread;
        if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
        {
            static_assert((kMPerBlock % (MPerXdl * MWave) == 0) &&
                              (kMPerBlock % num_xdl_shuffles == 0),
                          "kMPerBlock must be divisible by MPerXdl*MWave and "
                          "num_xdl_shuffles for CShuffleEpilogue");
            return std::make_tuple(min(num_xdl_shuffles, kMPerBlock / (MPerXdl * MWave)), 1);
        }
        else
        {
            static_assert((kNPerBlock % (NPerXdl * NWave) == 0) &&
                              (kNPerBlock % num_xdl_shuffles == 0),
                          "kNPerBlock must be divisible by NPerXdl*NWave and "
                          "num_xdl_shuffles for CShuffleEpilogue");
            return std::make_tuple(1, min(num_xdl_shuffles, kNPerBlock / (NPerXdl * NWave)));
        }
    }
}()
constexpr CK_TILE_HOST_DEVICE T min(T x)
Definition: math.hpp:210
__host__ constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:42
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
static constexpr CK_TILE_HOST_DEVICE index_t GetVectorSizeC()
Get the vector store size for C tensor.
Definition: cshuffle_epilogue.hpp:106

Shuffle tile configuration parameters.

These parameters control the number of XDL tiles processed per wave in each shuffle iteration:

  • NumMXdlPerWavePerShuffle: Number of XDL tiles in M dimension processed per wave
  • NumNXdlPerWavePerShuffle: Number of XDL tiles in N dimension processed per wave

◆ VectorSizeC

template<typename Problem_ , typename Policy_ = void>
constexpr index_t ck_tile::CShuffleEpilogue< Problem_, Policy_ >::VectorSizeC = Problem::VectorSizeC
static constexpr

The documentation for this struct was generated from the following file:
  • /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp