XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma > Struct Template Reference

XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma > Struct Template Reference

Composable Kernel: ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma > Struct Template Reference
ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma > Struct Template Reference

#include <xdlops_gemm.hpp>

Public Types

using CIndex = MultiIndex< 2 >
 
using CIndex4D = MultiIndex< 4 >
 

Public Member Functions

__host__ constexpr __device__ XdlopsGemm ()
 
template<class FloatA , class FloatB , class FloatC >
__device__ void Run (const FloatA &p_a_wave, const FloatB &p_b_wave, FloatC &p_c_thread) const
 
template<index_t OpselA, index_t OpselB, class FloatA , class ScaleA , class FloatB , class ScaleB , class FloatC >
__device__ void Run (const FloatA &p_a_wave, const ScaleA &a_scale_thread, const FloatB &p_b_wave, const ScaleB &b_scale_thread, FloatC &p_c_thread) const
 

Static Public Member Functions

static constexpr __device__ index_t GetNumBlks ()
 
static constexpr __device__ index_t GetNumXdlops ()
 
template<typename CDesc_M0_N0_M1_N1_M2_N2 >
__host__ static constexpr __device__ auto MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2 (const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
 
template<typename CDesc_M0_N0_M1_N1_M2_N2 >
__host__ static constexpr __device__ auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 (const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
 
template<typename CDesc_M0_N0_M1_N1_M2_N2 >
__host__ static constexpr __device__ auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4 (const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
 
template<typename CDesc_G_M0_N0_M1_N1_M2_N2 >
__host__ static constexpr __device__ auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2 (const CDesc_G_M0_N0_M1_N1_M2_N2 &c_desc_g_m0_n0_m1_n1_m2_n2)
 
__device__ static constexpr __host__ index_t GetRegSizePerXdlops ()
 
static constexpr __device__ index_t GetWaveSize ()
 
static __device__ auto GetLaneId ()
 
static __device__ auto GetBlkIdx ()
 
template<bool SwizzleA>
static __device__ auto GetGfx11InputBlkIdx ()
 
__host__ static __device__ auto CalculateAThreadOriginDataIndex ()
 
__host__ static __device__ auto CalculateBThreadOriginDataIndex ()
 
static __device__ CIndex GetBeginOfThreadBlk (index_t xdlops_i, index_t blk_i)
 
static __device__ CIndex4D GetBeginOfThreadBlk4D (index_t, index_t)
 
__host__ static constexpr __device__ auto GetCM0M1M2NThreadBlkLengths ()
 

Static Public Attributes

static constexpr auto I0 = Number<0>{}
 
static constexpr auto I1 = Number<1>{}
 
static constexpr auto I2 = Number<2>{}
 
static constexpr auto I3 = Number<3>{}
 
static constexpr auto I4 = Number<4>{}
 
static constexpr auto I5 = Number<5>{}
 
static constexpr bool is_single_rate_mfma
 
static constexpr auto mfma
 
static constexpr auto mfma_instr = mfma.selected_mfma
 
static constexpr auto KPerXdlops = mfma.GetKPerXdlops()
 
static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops()
 
static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops
 

Member Typedef Documentation

◆ CIndex

template<typename base_type , index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
using ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::CIndex = MultiIndex<2>

◆ CIndex4D

template<typename base_type , index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
using ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::CIndex4D = MultiIndex<4>

Constructor & Destructor Documentation

◆ XdlopsGemm()

template<typename base_type , index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
__host__ constexpr __device__ ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::XdlopsGemm ( )
inline constexpr

Member Function Documentation

◆ CalculateAThreadOriginDataIndex()

template<typename base_type , index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
__host__ static __device__ auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::CalculateAThreadOriginDataIndex ( )
inline static

◆ CalculateBThreadOriginDataIndex()

template<typename base_type , index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
__host__ static __device__ auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::CalculateBThreadOriginDataIndex ( )
inline static

◆ GetBeginOfThreadBlk()

template<typename base_type , index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
static __device__ CIndex ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::GetBeginOfThreadBlk ( index_t  xdlops_i,
index_t  blk_i 
)
inline static

◆ GetBeginOfThreadBlk4D()

template<typename base_type , index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
static __device__ CIndex4D ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::GetBeginOfThreadBlk4D ( index_t  ,
index_t   
)
inline static

◆ GetBlkIdx()

template<typename base_type , index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
static __device__ auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::GetBlkIdx ( )
inline static

◆ GetCM0M1M2NThreadBlkLengths()

template<typename base_type , index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
__host__ static constexpr __device__ auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::GetCM0M1M2NThreadBlkLengths ( )
inline static constexpr

◆ GetGfx11InputBlkIdx()

template<typename base_type , index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
template<bool SwizzleA>
static __device__ auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::GetGfx11InputBlkIdx ( )
inline static

◆ GetLaneId()

template<typename base_type , index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
static __device__ auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::GetLaneId ( )
inline static

◆ GetNumBlks()

template<typename base_type , index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
static constexpr __device__ index_t ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::GetNumBlks ( )
inline static constexpr

◆ GetNumXdlops()

template<typename base_type , index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
static constexpr __device__ index_t ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::GetNumXdlops ( )
inline static constexpr

◆ GetRegSizePerXdlops()

template<typename base_type , index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
__device__ static constexpr __host__ index_t ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::GetRegSizePerXdlops ( )
inline static constexpr

◆ GetWaveSize()

template<typename base_type , index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
static constexpr __device__ index_t ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::GetWaveSize ( )
inline static constexpr

◆ MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()

template<typename base_type , index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
template<typename CDesc_G_M0_N0_M1_N1_M2_N2 >
__host__ static constexpr __device__ auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2 ( const CDesc_G_M0_N0_M1_N1_M2_N2 &  c_desc_g_m0_n0_m1_n1_m2_n2)
inline static constexpr

◆ MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()

template<typename base_type , index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
template<typename CDesc_M0_N0_M1_N1_M2_N2 >
__host__ static constexpr __device__ auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2 ( const CDesc_M0_N0_M1_N1_M2_N2 &  c_desc_m0_n0_m1_n1_m2_n2)
inline static constexpr

◆ MakeCDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3()

template<typename base_type , index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
template<typename CDesc_M0_N0_M1_N1_M2_N2 >
__host__ static constexpr __device__ auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::MakeCDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 ( const CDesc_M0_N0_M1_N1_M2_N2 &  c_desc_m0_n0_m1_n1_m2_n2)
inline static constexpr

◆ MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()

template<typename base_type , index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
template<typename CDesc_M0_N0_M1_N1_M2_N2 >
__host__ static constexpr __device__ auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4 ( const CDesc_M0_N0_M1_N1_M2_N2 &  c_desc_m0_n0_m1_n1_m2_n2)
inline static constexpr

◆ Run() [1/2]

template<typename base_type , index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
template<class FloatA , class FloatB , class FloatC >
__device__ void ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::Run ( const FloatA &  p_a_wave,
const FloatB &  p_b_wave,
FloatC &  p_c_thread 
) const
inline

◆ Run() [2/2]

template<typename base_type , index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
template<index_t OpselA, index_t OpselB, class FloatA , class ScaleA , class FloatB , class ScaleB , class FloatC >
__device__ void ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::Run ( const FloatA &  p_a_wave,
const ScaleA &  a_scale_thread,
const FloatB &  p_b_wave,
const ScaleB &  b_scale_thread,
FloatC &  p_c_thread 
) const
inline

Member Data Documentation

◆ I0

template<typename base_type , index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
constexpr auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::I0 = Number<0>{}
static constexpr

◆ I1

template<typename base_type , index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
constexpr auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::I1 = Number<1>{}
static constexpr

◆ I2

template<typename base_type , index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
constexpr auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::I2 = Number<2>{}
static constexpr

◆ I3

template<typename base_type , index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
constexpr auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::I3 = Number<3>{}
static constexpr

◆ I4

template<typename base_type , index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
constexpr auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::I4 = Number<4>{}
static constexpr

◆ I5

template<typename base_type , index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
constexpr auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::I5 = Number<5>{}
static constexpr

◆ is_single_rate_mfma

template<typename base_type , index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
constexpr bool ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::is_single_rate_mfma
static constexpr
Initial value: (not captured in this extract; see the definition at xdlops_gemm.hpp:2072)

◆ K0PerXdlops

template<typename base_type , index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
constexpr auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::K0PerXdlops = KPerXdlops / K1PerXdlops
static constexpr

◆ K1PerXdlops

template<typename base_type , index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
constexpr auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::K1PerXdlops = mfma.GetK1PerXdlops()
static constexpr

◆ KPerXdlops

template<typename base_type , index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
constexpr auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::KPerXdlops = mfma.GetKPerXdlops()
static constexpr

◆ mfma

template<typename base_type , index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
constexpr auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::mfma
static constexpr
Initial value:
= MfmaSelector<base_type,
MPerXdlops,
NPerXdlops,
additional_type,
is_single_rate_mfma,
is_scale_mfma>{}

References: static constexpr bool is_single_rate_mfma (Definition: xdlops_gemm.hpp:2072)

◆ mfma_instr

template<typename base_type , index_t MPerXdlops, index_t NPerXdlops, index_t KPack, typename additional_type = base_type, bool TransposeC = false, bool is_scale_mfma = false>
constexpr auto ck::XdlopsGemm< base_type, MPerXdlops, NPerXdlops, KPack, additional_type, TransposeC, is_scale_mfma >::mfma_instr = mfma.selected_mfma
static constexpr

The documentation for this struct was generated from the following file:
  • /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp