element_wise Namespace Reference

Composable Kernel: ck_tile::element_wise Namespace Reference

Classes

struct Add
struct PassThroughPack8
struct DequantPack8
struct PassThroughPack2
struct PassThrough
struct AddScale
struct MultiDMultiply
struct MultiDAdd
struct UnaryConvert
struct Scale
struct ScaleAndResetNaNToMinusInfinity
struct UnaryDivide
struct UnarySquare
struct UnaryAbs
struct UnarySqrt
struct Relu
struct FastGelu
struct FastGeluAsm
struct Gelu
struct Sigmoid
struct Silu
struct TanH
struct ACos
struct Neg
struct ATan
struct Sin
struct ASinH
struct Cos
struct ACosH
struct Tan
struct ATanH
struct SinH
struct Ceil
struct Exp
struct CosH
struct Floor
struct Log
struct ASin
struct Rcp
struct Swish
struct SoftRelu
struct Power
struct ClippedRelu
struct LeakyRelu
struct Elu
struct Logistic
struct ConvInvscale
struct ConvScale
struct ConvScaleRelu
struct Cast

Functions

template<typename T , std::size_t N, typename F , std::size_t... Is>
constexpr std::array< T, N > make_lookup_table_impl (F &&func, std::index_sequence< Is... >)
 
template<typename T , std::size_t N, typename F >
constexpr std::array< T, N > make_lookup_table (F &&func)
 
CK_TILE_DEVICE fp16x4_t i4_to_half4 (int q)
 Fast int4x4 to fp16x4_t data type conversion based on the paper "Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production".
 
CK_TILE_DEVICE fp16x4_t i4_to_half4_scale (int q, const fp16x2_t &scale)
 This function dequantizes 4 int4 values into 4 fp16 values and applies scaling.
 
CK_TILE_DEVICE bf16x4_t i4_to_bhalf4 (int q)
 This function converts four 4-bit integers into four bf16 values.
 
CK_TILE_DEVICE fp8x8_t amd_assembly_i4_to_fp8x8 (int a)
 This function converts 8 packed 4-bit integers into 8 fp8 values.
 
CK_TILE_DEVICE float amd_assembly_fp8_to_fp32 (uint32_t src)
 
CK_TILE_DEVICE float amd_assembly_bf8_to_fp32 (uint32_t src)
 
CK_TILE_DEVICE bf8x8_t amd_assembly_i4_to_bf8x8 (uint32_t a)
 This function converts 8 packed 4-bit integers into 8 bf8 values.
 

Function Documentation

◆ amd_assembly_bf8_to_fp32()

CK_TILE_DEVICE float ck_tile::element_wise::amd_assembly_bf8_to_fp32 ( uint32_t  src)

◆ amd_assembly_fp8_to_fp32()

CK_TILE_DEVICE float ck_tile::element_wise::amd_assembly_fp8_to_fp32 ( uint32_t  src)

◆ amd_assembly_i4_to_bf8x8()

CK_TILE_DEVICE bf8x8_t ck_tile::element_wise::amd_assembly_i4_to_bf8x8 ( uint32_t  a)

This function converts 8 packed 4-bit integers into 8 bf8 values.

Note
The 32-bit input a contains 4 bytes; each byte packs two int4 values.
This function assumes pk_int4_t has a bias of 8, meaning 0b0000 is converted to bf8(-8).
The output ordering differs from the input ordering. For example, when the input is 0x76543210, the output sequence is bf8(7, 3, 6, 2, 5, 1, 4, 0). Therefore, the input tensor must be preprocessed with permute_vectors_i4x4_b on the host side before this function is used.
See also
permute_vectors_i4x4_b
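
The bias-of-8 convention in the note above can be illustrated with a small host-side scalar reference. This is only a sketch, not the library's implementation: the helper name decode_biased_int4x8 is hypothetical, the results are plain integers rather than bf8 values, and nibble 0 is assumed to be the least significant 4 bits. The permuted output order of the device function is deliberately not modelled.

```cpp
#include <array>
#include <cstdint>

// Hypothetical host-side reference (illustration only, not part of ck_tile).
// Decodes the eight 4-bit fields of `a` using the documented bias of 8:
// a stored 0b0000 decodes to -8 and a stored 0b1111 decodes to 7.
// Results are in plain nibble order (nibble 0 = least significant 4 bits);
// the permuted output order of the device function is NOT modelled.
inline std::array<int, 8> decode_biased_int4x8(std::uint32_t a)
{
    std::array<int, 8> out{};
    for(int i = 0; i < 8; ++i)
    {
        const int stored = static_cast<int>((a >> (4 * i)) & 0xFu);
        out[i]           = stored - 8; // remove the bias
    }
    return out;
}
```

With this convention the representable range is -8 to 7, and a stored value of 0b1000 decodes to 0.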

◆ amd_assembly_i4_to_fp8x8()

CK_TILE_DEVICE fp8x8_t ck_tile::element_wise::amd_assembly_i4_to_fp8x8 ( int  a)

This function converts 8 packed 4-bit integers into 8 fp8 values.

Note
The 32-bit input a contains 4 bytes; each byte packs two int4 values.
This function assumes pk_int4_t has a bias of 8, meaning 0b0000 is converted to fp8(-8).
The output ordering differs from the input ordering. For example, when the input is 0x76543210, the output sequence is fp8(7, 3, 6, 2, 5, 1, 4, 0). Therefore, the input tensor must be preprocessed with permute_vectors_i4x4_b on the host side before this function is used.
See also
permute_vectors_i4x4_b
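
The ordering note above can be modelled on the host to see why a pre-permutation is needed. The sketch below rests on two assumptions that this page does not confirm: the documented sequence is read left to right as output elements 0 to 7, and the digit n in 0x76543210 labels the nibble at bits 4n to 4n+3. The helpers model_output_nibbles and pack_for_model are hypothetical and are not permute_vectors_i4x4_b; they only show that placing logical element k into the nibble that output position k reads from restores the logical order.

```cpp
#include <array>
#include <cstdint>

// Assumed reading of the note: output element k is taken from input nibble
// kOutputOrder[k]. This table is copied from the documented example.
constexpr std::array<int, 8> kOutputOrder = {7, 3, 6, 2, 5, 1, 4, 0};

// Hypothetical model of the ordering only: returns raw nibbles (no bias,
// no fp8 conversion) in the order the note describes.
inline std::array<int, 8> model_output_nibbles(std::uint32_t a)
{
    std::array<int, 8> out{};
    for(int k = 0; k < 8; ++k)
        out[k] = static_cast<int>((a >> (4 * kOutputOrder[k])) & 0xFu);
    return out;
}

// Hypothetical host-side packing: store logical element k in the nibble that
// output position k reads from, so the modelled output is in logical order.
inline std::uint32_t pack_for_model(const std::array<int, 8>& logical)
{
    std::uint32_t packed = 0;
    for(int k = 0; k < 8; ++k)
        packed |= static_cast<std::uint32_t>(logical[k] & 0xF) << (4 * kOutputOrder[k]);
    return packed;
}
```

Under these assumptions, feeding 0x76543210 to the model reproduces the documented sequence, and packing the logical values 0 to 7 first yields them back in order.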

◆ i4_to_bhalf4()

CK_TILE_DEVICE bf16x4_t ck_tile::element_wise::i4_to_bhalf4 ( int  q)

This function converts four 4-bit integers into four bf16 values.

Note
int q contains 4 bytes; the low 4 bits of each byte represent an int4.
This function assumes pk_int4_t has a bias of 8, meaning 0b0000 is converted to bf16(-8).
The output ordering differs from the input ordering. For example, when the input is 0x76543210, the output sequence is bf16(7, 3, 6, 2, 5, 1, 4, 0). Therefore, the input tensor must be preprocessed with permute_vectors_i4x4_b on the host side before this function is used.
See also
permute_vectors_i4x4_b

◆ i4_to_half4()

CK_TILE_DEVICE fp16x4_t ck_tile::element_wise::i4_to_half4 ( int  q)

Fast int4x4 to fp16x4_t data type conversion based on the paper "Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production".

See also
https://arxiv.org/abs/2211.10017
https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h

This function converts four 4-bit integers into four fp16 values.

Note
int q contains 4 bytes; the low 4 bits of each byte represent an int4.
This function assumes pk_int4_t has a bias of 8, meaning 0b0000 is converted to fp16(-8).
The output ordering differs from the input ordering. For example, when the input is 0x76543210, the output sequence is fp16(7, 3, 6, 2, 5, 1, 4, 0). Therefore, the input tensor must be preprocessed with permute_vectors_i4x4_b on the host side before this function is used.
See also
permute_vectors_i4x4_b
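
Fast int4-to-fp16 conversions of this kind typically avoid an integer-to-float instruction by using a "magic number": the 4-bit value is OR-ed into the mantissa of the fp16 constant 1024.0 (bit pattern 0x6400), which yields 1024 + n, and a constant is then subtracted. The host-side sketch below illustrates that idea for one nibble with a bias of 8. It is an assumption about the general technique, not a transcription of this function; the helper names are hypothetical, and the subtraction is done in float for portability (the values involved are exactly representable in fp16 as well).

```cpp
#include <cmath>
#include <cstdint>

// Decode IEEE 754 binary16 bits to float.
// Handles normal numbers only, which is all this demonstration needs.
inline float half_bits_to_float(std::uint16_t h)
{
    const int sign        = (h >> 15) & 0x1;
    const int exponent    = (h >> 10) & 0x1F;
    const int mantissa    = h & 0x3FF;
    const float magnitude = std::ldexp(1.0f + mantissa / 1024.0f, exponent - 15);
    return sign ? -magnitude : magnitude;
}

// Hypothetical scalar illustration of the magic-number trick:
//   0x6400 is fp16 1024.0; OR-ing a nibble n into its mantissa gives 1024 + n.
//   Subtracting 1032 (= 1024 + 8) then gives n - 8, i.e. the bias-8 decoding.
inline float int4_to_float_via_magic(unsigned nibble)
{
    const std::uint16_t bits = static_cast<std::uint16_t>(0x6400u | (nibble & 0xFu));
    return half_bits_to_float(bits) - 1032.0f;
}
```

The trick works because fp16 values in the range 1024 to 2047 have a mantissa step of exactly 1, so the low mantissa bits can hold a small integer directly.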

◆ i4_to_half4_scale()

CK_TILE_DEVICE fp16x4_t ck_tile::element_wise::i4_to_half4_scale ( int  q,
const fp16x2_t &  scale 
)

This function dequantizes 4 int4 values into 4 fp16 values and applies scaling.

Note
int q contains 4 bytes; the low 4 bits of each byte represent an int4.
This function assumes pk_int4_t has a bias of 8, meaning 0b0000 is converted to fp16(-8).
The output ordering differs from the input ordering. For example, when the input is 0x76543210, the output sequence is fp16(7, 3, 6, 2, 5, 1, 4, 0). Therefore, the input tensor must be preprocessed with permute_vectors_i4x4_b on the host side before this function is used.
See also
permute_vectors_i4x4_b
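
The dequantize-and-scale step described above can be sketched as a host-side scalar reference. Several things here are assumptions rather than documented behavior: the helper name dequant_int4x4_scaled is hypothetical, a single float scale stands in for the two-lane fp16x2_t argument (how its lanes map to the four outputs is not stated on this page), results are in plain byte order, and the permuted output order is not modelled.

```cpp
#include <array>
#include <cstdint>

// Hypothetical host-side reference (illustration only, not part of ck_tile).
// Reads the low 4 bits of each of the 4 bytes of `q`, removes the bias of 8,
// and multiplies by `scale`. Output i corresponds to byte i (byte 0 = least
// significant byte).
inline std::array<float, 4> dequant_int4x4_scaled(std::uint32_t q, float scale)
{
    std::array<float, 4> out{};
    for(int byte = 0; byte < 4; ++byte)
    {
        const int stored = static_cast<int>((q >> (8 * byte)) & 0xFu); // low nibble of this byte
        out[byte]        = static_cast<float>(stored - 8) * scale;     // unbias, then scale
    }
    return out;
}
```

For example, with a scale of 0.5 the stored nibbles 0x0, 0x1, 0x8 and 0xF map to -4, -3.5, 0 and 3.5 respectively.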

◆ make_lookup_table()

template<typename T , std::size_t N, typename F >
constexpr std::array<T, N> ck_tile::element_wise::make_lookup_table ( F &&  func)
constexpr

◆ make_lookup_table_impl()

template<typename T , std::size_t N, typename F , std::size_t... Is>
constexpr std::array<T, N> ck_tile::element_wise::make_lookup_table_impl ( F &&  func,
std::index_sequence< Is... >   
)
constexpr
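
The two signatures above follow the common index-sequence pattern for building a compile-time table: the public function generates the indices 0 to N-1 and the _impl overload expands them into an array initializer. The sketch below shows how such a pair is typically written. It lives in a hypothetical namespace sketch and is not the library's source; in particular, calling func with each index as std::size_t and casting the result to T are assumptions.

```cpp
#include <array>
#include <cstddef>
#include <utility>

namespace sketch { // hypothetical namespace, to avoid any claim about ck_tile's own code

// Expands the index pack: element i of the result is func(i), converted to T.
template <typename T, std::size_t N, typename F, std::size_t... Is>
constexpr std::array<T, N> make_lookup_table_impl(F&& func, std::index_sequence<Is...>)
{
    return {{static_cast<T>(func(Is))...}};
}

// Generates the indices 0..N-1 and forwards to the implementation overload.
template <typename T, std::size_t N, typename F>
constexpr std::array<T, N> make_lookup_table(F&& func)
{
    return make_lookup_table_impl<T, N>(std::forward<F>(func), std::make_index_sequence<N>{});
}

// Example generator: a constexpr-callable functor (usable from C++14 onwards).
struct Square
{
    constexpr int operator()(std::size_t i) const { return static_cast<int>(i * i); }
};

} // namespace sketch
```

A call such as sketch::make_lookup_table<int, 4>(sketch::Square{}) would then produce a table holding the squares of 0 through 3, and can be evaluated at compile time when the generator is constexpr-callable.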