include/ck/utility/blkgemmpipe_scheduler.hpp Source File

include/ck/utility/blkgemmpipe_scheduler.hpp Source File#

Composable Kernel: include/ck/utility/blkgemmpipe_scheduler.hpp Source File
blkgemmpipe_scheduler.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
8 
9 namespace ck {
10 
/// Selector with two enumerators, Intrawave (0) and Interwave (1).
///
/// NOTE(review): the declaration line was absent from this listing. The type name is taken
/// from the file's own symbol index (BlockGemmPipelineScheduler) and the `enum struct` form
/// mirrors TailNumber below -- confirm against the upstream header.
enum struct BlockGemmPipelineScheduler
{
    Intrawave,
    Interwave,
};
16 
/// Tail selector. Enumerators are grouped by the pipeline flavour they apply to (see the
/// per-group comments); underlying values are implicit and run from Odd = 0 to Full = 10.
enum struct TailNumber
{
    // Single / double buffer pipeline
    Odd,
    Even,

    // Long prefetch pipeline, up to 8
    One,
    Two,
    Three,
    Four,
    Five,
    Six,
    Seven,

    // Unroll stages > prefetch stages: the loop count is a multiple of the unroll stages
    Empty,
    // Unroll stages <= prefetch stages: the loop count is a multiple of the unroll stages
    // plus the prefetch stages
    Full,
};
38 template <index_t BlockSize,
39  index_t MPerBlock,
40  index_t NPerBlock,
41  index_t KPerBlock,
42  index_t ABufferLoadWidth,
43  index_t BBufferLoadWidth,
44  index_t ALDSWriteWidth,
45  index_t BLDSWriteWidth,
46  index_t ALDSReadWidth,
47  index_t BLDSReadWidth,
48  index_t MRepeat,
49  index_t NRepeat,
50  index_t MPerXDL,
51  index_t NPerXDL,
52  index_t KPerXDL>
53 struct BlockwiseGemmXdlops_pipeline_hotloop_inst
54 {
55  static constexpr index_t WaveSize = 64;
56  static constexpr index_t WaveNumM = MPerBlock / (MRepeat * MPerXDL);
57  static constexpr index_t WaveNumN = NPerBlock / (NRepeat * NPerXDL);
58 
59  static constexpr index_t A_LDS_Read_Width = ALDSReadWidth;
60  static constexpr index_t B_LDS_Read_Width = BLDSReadWidth;
61 
62  static constexpr index_t A_Buffer_Load_Inst_Num =
63  MPerBlock * KPerBlock / (BlockSize * ABufferLoadWidth);
64  static constexpr index_t B_Buffer_Load_Inst_Num =
65  NPerBlock * KPerBlock / (BlockSize * BBufferLoadWidth);
66 
67  static constexpr index_t A_LDS_Write_Inst_Num =
68  MPerBlock * KPerBlock / (BlockSize * ALDSWriteWidth);
69  static constexpr index_t B_LDS_Write_Inst_Num =
70  NPerBlock * KPerBlock / (BlockSize * BLDSWriteWidth);
71 
72  static constexpr index_t A_LDS_Read_Inst_Num =
73  WaveNumN * MPerBlock * KPerBlock / (BlockSize * ALDSReadWidth);
74  static constexpr index_t B_LDS_Read_Inst_Num =
75  WaveNumM * MPerBlock * KPerBlock / (BlockSize * BLDSReadWidth);
76 
77  static constexpr index_t C_MFMA_Inst_Num =
78  MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
79 
80  static constexpr auto Print()
81  {
82  printf(" Blk/Wave Size: %d, %d, M/N/K PerBlk: %d, %d, %d, M/N/K PerXdl: %d, %d, %d\n",
83  BlockSize,
84  WaveSize,
85  MPerBlock,
86  NPerBlock,
87  KPerBlock,
88  MPerXDL,
89  NPerXDL,
90  KPerXDL);
91 
92  printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: "
93  "%d, %d\n C MFMA inst: %d\n"
94  "A/B LDS read width: %d, %d, A/B LDS write width: %d, %d, A/B buffer load width: "
95  "%d/ %d\n",
105  ALDSWriteWidth,
106  BLDSWriteWidth,
107  ABufferLoadWidth,
108  BBufferLoadWidth);
109  }
110 };
111 
112 } // namespace ck
Definition: ck.hpp:264
TailNumber
Definition: blkgemmpipe_scheduler.hpp:18
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:12
int32_t index_t
Definition: ck.hpp:289
static constexpr index_t B_LDS_Write_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:46
static constexpr index_t A_LDS_Read_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:49
static constexpr index_t A_LDS_Read_Width
Definition: blkgemmpipe_scheduler.hpp:59
static constexpr index_t B_LDS_Read_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:51
static constexpr index_t A_LDS_Write_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:44
static constexpr index_t C_MFMA_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:54
static constexpr index_t A_Buffer_Load_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:39
static constexpr index_t WaveSize
Definition: blockwise_gemm_pipeline_xdlops.hpp:35
static constexpr index_t B_Buffer_Load_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:41
static constexpr auto Print()
Definition: blkgemmpipe_scheduler.hpp:80
static constexpr index_t WaveNumN
Definition: blockwise_gemm_pipeline_xdlops.hpp:37
static constexpr index_t WaveNumM
Definition: blockwise_gemm_pipeline_xdlops.hpp:36
static constexpr index_t B_LDS_Read_Width
Definition: blkgemmpipe_scheduler.hpp:60