include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp Source File

include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp Source File

Composable Kernel: include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp Source File
warp_gemm_dispatcher.hpp
Go to the documentation of this file.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"

9 namespace ck_tile {
10 
11 namespace impl {
12 template <typename AType,
13  typename BType,
14  typename CType,
15  index_t MPerWave,
16  index_t NPerWave,
17  index_t KPerWave,
18  bool TransposeC,
19  bool SwizzleA = false>
21 
22 // clang-format off
23 // fp16
24 template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF16F16F32M32N32K8; };
26 template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaF16F16F32M32N32K16; };
28 template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaF16F16F32M16N16K16; };
30 template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32; };
32 template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 4, 64, 16, false> { using Type = WarpGemmMfmaF16F16F32M4N64K16; };
33 template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 64, 4, 16, false> { using Type = WarpGemmMfmaF16F16F32M64N4K16; };
34 
35 template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; };
36 template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
37 
38 // bf16
39 template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
41 template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; };
43 template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; };
45 template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; };
47 template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 4, 64, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M4N64K16; };
48 template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 64, 4, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M64N4K16; };
49 
50 template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; };
51 template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };
52 
53 // fp8
54 template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
55 template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; };
56 template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; };
57 template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; };
58 template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; };
59 template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; };
60 template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; };
61 template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; };
62 
63 // clang-format on
64 } // namespace impl
65 
66 template <typename AType,
67  typename BType,
68  typename CType,
69  index_t MPerWave,
70  index_t NPerWave,
71  index_t KPerWave,
72  bool TransposeC,
73  bool SwizzleA = false>
75  BType,
76  CType,
77  MPerWave,
78  NPerWave,
79  KPerWave,
80  TransposeC,
81  SwizzleA>::Type;
82 
83 } // namespace ck_tile
Definition: cluster_descriptor.hpp:13
WarpGemmImpl< WarpGemmAtrributeMfma< WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8< WGAttrCtlEnum::Default_ > >> WarpGemmMfma_f32_32x32x16_bf8_fp8
Definition: warp_gemm.hpp:132
WarpGemmImpl< WarpGemmAtrributeMfmaIterateK< WarpGemmAttributeMfmaImplF16F16F32M4N64K4< WGAttrCtlEnum::Default_ >, 4 > > WarpGemmMfmaF16F16F32M4N64K16
Definition: warp_gemm.hpp:61
WarpGemmImpl< WarpGemmAtrributeMfmaIterateK< WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16< WGAttrCtlEnum::Default_ >, 2 > > WarpGemmMfmaBf16Bf16F32M16N16K32
Definition: warp_gemm.hpp:81
WarpGemmImpl< WarpGemmAtrributeMfma< WarpGemmAttributeMfmaImplF16F16F32M32N32K8< WGAttrCtlEnum::Default_ > >> WarpGemmMfmaF16F16F32M32N32K8
Definition: warp_gemm.hpp:15
WarpGemmImpl< WarpGemmAtrributeMfmaIterateK< WarpGemmAttributeMfmaImplF16F16F32M64N4K4< WGAttrCtlEnum::Default_ >, 4 > > WarpGemmMfmaF16F16F32M64N4K16
Definition: warp_gemm.hpp:65
_BitInt(8) fp8_t
Definition: float8.hpp:204
WarpGemmImpl< WarpGemmAtrributeMfmaIterateK_SwizzleA< WarpGemmAttributeMfmaImplF16F16F32M32N32K8< WGAttrCtlEnum::Default_ >, 2 > > WarpGemmMfmaF16F16F32M32N32K16SwizzleA
Definition: warp_gemm.hpp:34
WarpGemmImpl< WarpGemmAtrributeMfmaIterateK< WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4< WGAttrCtlEnum::Default_ >, 4 > > WarpGemmMfmaBf16Bf16F32M4N64K16
Definition: warp_gemm.hpp:117
WarpGemmImpl< WarpGemmAtrributeMfma< WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8< WGAttrCtlEnum::Default_ > >> WarpGemmMfmaBf16Bf16F32M32N32K8
Definition: warp_gemm.hpp:70
WarpGemmImpl< WarpGemmAtrributeMfmaTransposedCDistribution< WarpGemmAttributeMfmaImplF16F16F32M16N16K16< WGAttrCtlEnum::Default_ > >> WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution
Definition: warp_gemm.hpp:42
WarpGemmImpl< WarpGemmAtrributeMfmaTransposedCDistribution< WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8< WGAttrCtlEnum::Default_ > >> WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed
Definition: warp_gemm.hpp:139
bfloat16_t bf16_t
Definition: bfloat16.hpp:106
WarpGemmImpl< WarpGemmAtrributeMfmaIterateK_SwizzleA< WarpGemmAttributeMfmaImplF16F16F32M32N32K8< WGAttrCtlEnum::Default_ >, 1 > > WarpGemmMfmaF16F16F32M32N32K8SwizzleA
Definition: warp_gemm.hpp:30
int32_t index_t
Definition: integer.hpp:9
typename impl::WarpGemmMfmaDispatcher< AType, BType, CType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA >::Type WarpGemmMfmaDispatcher
Definition: warp_gemm_dispatcher.hpp:81
WarpGemmImpl< WarpGemmAtrributeMfma< WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8< WGAttrCtlEnum::Default_ > >> WarpGemmMfma_f32_32x32x16_fp8_bf8
Definition: warp_gemm.hpp:129
WarpGemmImpl< WarpGemmAtrributeMfmaTransposedCDistribution< WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8< WGAttrCtlEnum::Default_ > >> WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution
Definition: warp_gemm.hpp:94
WarpGemmImpl< WarpGemmAtrributeMfmaIterateK< WarpGemmAttributeMfmaImplF16F16F32M16N16K16< WGAttrCtlEnum::Default_ >, 2 > > WarpGemmMfmaF16F16F32M16N16K32
Definition: warp_gemm.hpp:26
WarpGemmImpl< WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution< WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8< WGAttrCtlEnum::Default_ >, 2 > > WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution
Definition: warp_gemm.hpp:103
WarpGemmImpl< WarpGemmAtrributeMfma< WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16< WGAttrCtlEnum::Default_ > >> WarpGemmMfmaBf16Bf16F32M16N16K16
Definition: warp_gemm.hpp:73
WarpGemmImpl< WarpGemmAtrributeMfmaTransposedCDistribution< WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8< WGAttrCtlEnum::Default_ > >> WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed
Definition: warp_gemm.hpp:143
WarpGemmImpl< WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution< WarpGemmAttributeMfmaImplF16F16F32M32N32K8< WGAttrCtlEnum::Default_ >, 2 > > WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution
Definition: warp_gemm.hpp:47
WarpGemmImpl< WarpGemmAtrributeMfmaIterateK< WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4< WGAttrCtlEnum::Default_ >, 4 > > WarpGemmMfmaBf16Bf16F32M64N4K16
Definition: warp_gemm.hpp:121
WarpGemmImpl< WarpGemmAtrributeMfmaTransposedCDistribution< WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8< WGAttrCtlEnum::Default_ > >> WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed
Definition: warp_gemm.hpp:147
WarpGemmImpl< WarpGemmAtrributeMfmaIterateK< WarpGemmAttributeMfmaImplF16F16F32M32N32K8< WGAttrCtlEnum::Default_ >, 2 > > WarpGemmMfmaF16F16F32M32N32K16
Definition: warp_gemm.hpp:22
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
WarpGemmImpl< WarpGemmAtrributeMfmaIterateK_SwizzleA< WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8< WGAttrCtlEnum::Default_ >, 2 > > WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA
Definition: warp_gemm.hpp:90
WarpGemmImpl< WarpGemmAtrributeMfmaTransposedCDistribution< WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8< WGAttrCtlEnum::Default_ > >> WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed
Definition: warp_gemm.hpp:151
WarpGemmImpl< WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution< WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16< WGAttrCtlEnum::Default_ >, 2 > > WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution
Definition: warp_gemm.hpp:108
WarpGemmImpl< WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution< WarpGemmAttributeMfmaImplF16F16F32M16N16K16< WGAttrCtlEnum::Default_ >, 2 > > WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution
Definition: warp_gemm.hpp:52
WarpGemmImpl< WarpGemmAtrributeMfmaIterateK< WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8< WGAttrCtlEnum::Default_ >, 2 > > WarpGemmMfmaBf16Bf16F32M32N32K16
Definition: warp_gemm.hpp:77
WarpGemmImpl< WarpGemmAtrributeMfmaIterateK_SwizzleA< WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8< WGAttrCtlEnum::Default_ >, 1 > > WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA
Definition: warp_gemm.hpp:85
WarpGemmImpl< WarpGemmAtrributeMfma< WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8< WGAttrCtlEnum::Default_ > >> WarpGemmMfma_f32_32x32x16_fp8_fp8
Definition: warp_gemm.hpp:126
_Float16 half_t
Definition: half.hpp:111
WarpGemmImpl< WarpGemmAtrributeMfmaTransposedCDistribution< WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16< WGAttrCtlEnum::Default_ > >> WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution
Definition: warp_gemm.hpp:98
WarpGemmImpl< WarpGemmAtrributeMfmaTransposedCDistribution< WarpGemmAttributeMfmaImplF16F16F32M32N32K8< WGAttrCtlEnum::Default_ > >> WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution
Definition: warp_gemm.hpp:38
WarpGemmImpl< WarpGemmAtrributeMfma< WarpGemmAttributeMfmaImplF16F16F32M16N16K16< WGAttrCtlEnum::Default_ > >> WarpGemmMfmaF16F16F32M16N16K16
Definition: warp_gemm.hpp:18
WarpGemmImpl< WarpGemmAtrributeMfma< WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8< WGAttrCtlEnum::Default_ > >> WarpGemmMfma_f32_32x32x16_bf8_bf8
Definition: warp_gemm.hpp:135
Definition: warp_gemm_impl.hpp:11
Definition: warp_gemm_dispatcher.hpp:20