/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/blkgemmpipe_scheduler.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/blkgemmpipe_scheduler.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/blkgemmpipe_scheduler.hpp Source File
blkgemmpipe_scheduler.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
8 
9 namespace ck {
10 
12 {
13  // For GEMM
14  v1, // Naive
15  v2, // Mem
16  v3, // Comp
17  v4, // Comp, double lds buffer
18  v5, // Comp, double global prefetch register buffer
19 
20  // For GEMM with preshuffled weight
21  // v1, single lds buffer
22  // v2, double lds buffer
23 };
25 {
26  Intrawave,
27  Interwave,
28 };
29 
30 enum struct TailNumber // NOTE(review): names the loop-tail case a pipeline must handle after its main loop; confirm exact meaning at the pipeline use sites
31 {
32  // Single / Double buffer pipeline
33  Odd,
34  Even,
35 
36  // Long prefetch pipeline, up to 8 (NOTE(review): enumerators stop at Seven; presumably tail counts 1..7 - confirm)
37  One,
38  Two,
39  Three,
40  Four,
41  Five,
42  Six,
43  Seven,
44 
45  // Unroll stages > Prefetch stages: the loop count is a multiple of the unroll stages
46  Empty,
47  // Unroll stages <= Prefetch stages: the loop count is a multiple of the unroll stages
48  // plus the prefetch stages
49  Full,
50 };
51 
53 {
54  SCHED_GROUP_MFMA = 0x008, // Matrix FMA instructions
55  SCHED_GROUP_VMEM = 0x020, // Global memory operations
56  SCHED_GROUP_LDS_READ = 0x100, // LDS read operations
57  SCHED_GROUP_LDS_WRITE = 0x200 // LDS write operations
58 };
59 
60 template <index_t BlockSize,
61  index_t MPerBlock,
62  index_t NPerBlock,
63  index_t KPerBlock,
64  index_t ABufferLoadWidth,
65  index_t BBufferLoadWidth,
66  index_t ALDSWriteWidth,
67  index_t BLDSWriteWidth,
68  index_t ALDSReadWidth,
69  index_t BLDSReadWidth,
70  index_t MRepeat,
71  index_t NRepeat,
72  index_t MPerXDL,
73  index_t NPerXDL,
74  index_t KPerXDL,
75  bool IsF4F6 = false>
76 struct BlockwiseGemmXdlops_pipeline_hotloop_inst
77 {
78  static constexpr index_t WaveNumM = MPerBlock / (MRepeat * MPerXDL);
79  static constexpr index_t WaveNumN = NPerBlock / (NRepeat * NPerXDL);
80  static constexpr index_t WaveSize = BlockSize / WaveNumM / WaveNumN;
81 
82  static constexpr index_t A_LDS_Read_Width = ALDSReadWidth;
83  static constexpr index_t B_LDS_Read_Width = BLDSReadWidth;
84 
85  static constexpr index_t A_Buffer_Load_Inst_Num =
86  MPerBlock * KPerBlock / (BlockSize * ABufferLoadWidth);
87  static constexpr index_t B_Buffer_Load_Inst_Num =
88  NPerBlock * KPerBlock / (BlockSize * BBufferLoadWidth);
89 
90  static constexpr index_t A_LDS_Write_Inst_Num =
91  MPerBlock * KPerBlock / (BlockSize * ALDSWriteWidth);
92  static constexpr index_t B_LDS_Write_Inst_Num =
93  NPerBlock * KPerBlock / (BlockSize * BLDSWriteWidth);
94 
95  static constexpr index_t A_LDS_Read_Inst_Num =
96  WaveNumN * MPerBlock * KPerBlock / (BlockSize * ALDSReadWidth);
97  static constexpr index_t B_LDS_Read_Inst_Num =
98  WaveNumM * NPerBlock * KPerBlock / (BlockSize * BLDSReadWidth);
99 
100  static constexpr index_t C_MFMA_Inst_Num =
101  MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
102 
103  static constexpr index_t C_MFMA_SpeedUp = IsF4F6 ? 2 : 1;
104 
105  static constexpr index_t C_MFMA_Inst_Cycle = []() {
106  if constexpr(NPerXDL == 16)
107  {
108  return KPerXDL == 128 ? 32 / C_MFMA_SpeedUp : 16 / C_MFMA_SpeedUp;
109  }
110  else if constexpr(NPerXDL == 32)
111  {
112  return KPerXDL == 64 ? 64 / C_MFMA_SpeedUp : 32 / C_MFMA_SpeedUp;
113  }
114  }();
115 
116  static constexpr auto Print()
117  {
118  printf(" Blk/Wave Size: %d, %d, M/N/K PerBlk: %d, %d, %d, M/N/K PerXdl: %d, %d, %d\n",
119  BlockSize,
120  WaveSize,
121  MPerBlock,
122  NPerBlock,
123  KPerBlock,
124  MPerXDL,
125  NPerXDL,
126  KPerXDL);
127 
128  printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: "
129  "%d, %d\n C MFMA inst: %d C MFMA cycle: %d\n"
130  "A/B LDS read width: %d, %d, A/B LDS write width: %d, %d, A/B buffer load width: "
131  "%d/ %d\n",
142  ALDSWriteWidth,
143  BLDSWriteWidth,
144  ABufferLoadWidth,
145  BBufferLoadWidth);
146  }
147 };
148 
149 } // namespace ck
Definition: ck.hpp:267
BlockGemmPipelineVersion
Definition: blkgemmpipe_scheduler.hpp:12
TailNumber
Definition: blkgemmpipe_scheduler.hpp:31
SchedulerGroup
Definition: blkgemmpipe_scheduler.hpp:53
@ SCHED_GROUP_LDS_READ
Definition: blkgemmpipe_scheduler.hpp:56
@ SCHED_GROUP_MFMA
Definition: blkgemmpipe_scheduler.hpp:54
@ SCHED_GROUP_LDS_WRITE
Definition: blkgemmpipe_scheduler.hpp:57
@ SCHED_GROUP_VMEM
Definition: blkgemmpipe_scheduler.hpp:55
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
int32_t index_t
Definition: ck.hpp:298
unsigned int uint32_t
Definition: stdint.h:126
static constexpr index_t B_LDS_Write_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:46
static constexpr index_t A_LDS_Read_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:49
static constexpr index_t A_LDS_Read_Width
Definition: blkgemmpipe_scheduler.hpp:82
static constexpr index_t B_LDS_Read_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:51
static constexpr index_t A_LDS_Write_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:44
static constexpr index_t C_MFMA_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:54
static constexpr index_t C_MFMA_Inst_Cycle
Definition: blkgemmpipe_scheduler.hpp:105
static constexpr index_t A_Buffer_Load_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:39
static constexpr index_t C_MFMA_SpeedUp
Definition: blkgemmpipe_scheduler.hpp:103
static constexpr index_t WaveSize
Definition: blockwise_gemm_pipeline_xdlops.hpp:37
static constexpr index_t B_Buffer_Load_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:41
static constexpr auto Print()
Definition: blkgemmpipe_scheduler.hpp:116
static constexpr index_t WaveNumN
Definition: blockwise_gemm_pipeline_xdlops.hpp:36
static constexpr index_t WaveNumM
Definition: blockwise_gemm_pipeline_xdlops.hpp:35
static constexpr index_t B_LDS_Read_Width
Definition: blkgemmpipe_scheduler.hpp:83