/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/blkgemmpipe_scheduler.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/blkgemmpipe_scheduler.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/blkgemmpipe_scheduler.hpp Source File
blkgemmpipe_scheduler.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
9 
10 namespace ck {
11 
// Bit-mask identifiers for the instruction categories used when scheduling.
// Each enumerator occupies a distinct bit, so values can be OR-ed together.
// NOTE(review): the values look like the mask bits taken by the AMDGPU
// `__builtin_amdgcn_sched_group_barrier` builtin -- confirm against callers.
enum SchedulerGroup : uint32_t
{
    SCHED_GROUP_MFMA      = 0x008, // Matrix FMA instructions
    SCHED_GROUP_VMEM      = 0x020, // Global memory operations
    SCHED_GROUP_LDS_READ  = 0x100, // LDS read operations
    SCHED_GROUP_LDS_WRITE = 0x200  // LDS write operations
};
19 
20 template <index_t BlockSize,
21  index_t MPerBlock,
22  index_t NPerBlock,
23  index_t KPerBlock,
24  index_t ABufferLoadWidth,
25  index_t BBufferLoadWidth,
26  index_t ALDSWriteWidth,
27  index_t BLDSWriteWidth,
28  index_t ALDSReadWidth,
29  index_t BLDSReadWidth,
30  index_t MRepeat,
31  index_t NRepeat,
32  index_t MPerXDL,
33  index_t NPerXDL,
34  index_t KPerXDL,
35  bool IsF4F6 = false>
36 struct BlockwiseGemmXdlops_pipeline_hotloop_inst
37 {
38  static constexpr index_t WaveNumM = MPerBlock / (MRepeat * MPerXDL);
39  static constexpr index_t WaveNumN = NPerBlock / (NRepeat * NPerXDL);
40  static constexpr index_t WaveSize = BlockSize / WaveNumM / WaveNumN;
41 
42  static constexpr index_t A_LDS_Read_Width = ALDSReadWidth;
43  static constexpr index_t B_LDS_Read_Width = BLDSReadWidth;
44 
45  static constexpr index_t A_Buffer_Load_Inst_Num =
46  MPerBlock * KPerBlock / (BlockSize * ABufferLoadWidth);
47  static constexpr index_t B_Buffer_Load_Inst_Num =
48  NPerBlock * KPerBlock / (BlockSize * BBufferLoadWidth);
49 
50  static constexpr index_t A_LDS_Write_Inst_Num =
51  MPerBlock * KPerBlock / (BlockSize * ALDSWriteWidth);
52  static constexpr index_t B_LDS_Write_Inst_Num =
53  NPerBlock * KPerBlock / (BlockSize * BLDSWriteWidth);
54 
55  static constexpr index_t A_LDS_Read_Inst_Num =
56  WaveNumN * MPerBlock * KPerBlock / (BlockSize * ALDSReadWidth);
57  static constexpr index_t B_LDS_Read_Inst_Num =
58  WaveNumM * NPerBlock * KPerBlock / (BlockSize * BLDSReadWidth);
59 
60  static constexpr index_t C_MFMA_Inst_Num =
61  MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
62 
63  static constexpr index_t C_MFMA_SpeedUp = IsF4F6 ? 2 : 1;
64 
65  static constexpr index_t C_MFMA_Inst_Cycle = []() {
66  if constexpr(NPerXDL == 16)
67  {
68  return KPerXDL == 128 ? 32 / C_MFMA_SpeedUp : 16 / C_MFMA_SpeedUp;
69  }
70  else if constexpr(NPerXDL == 32)
71  {
72  return KPerXDL == 64 ? 64 / C_MFMA_SpeedUp : 32 / C_MFMA_SpeedUp;
73  }
74  }();
75 
76  static constexpr auto Print()
77  {
78  printf(" Blk/Wave Size: %d, %d, M/N/K PerBlk: %d, %d, %d, M/N/K PerXdl: %d, %d, %d\n",
79  BlockSize,
80  WaveSize,
81  MPerBlock,
82  NPerBlock,
83  KPerBlock,
84  MPerXDL,
85  NPerXDL,
86  KPerXDL);
87 
88  printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: "
89  "%d, %d\n C MFMA inst: %d C MFMA cycle: %d\n"
90  "A/B LDS read width: %d, %d, A/B LDS write width: %d, %d, A/B buffer load width: "
91  "%d/ %d\n",
102  ALDSWriteWidth,
103  BLDSWriteWidth,
104  ABufferLoadWidth,
105  BBufferLoadWidth);
106  }
107 };
108 
109 } // namespace ck
Definition: ck.hpp:270
SchedulerGroup
Definition: blkgemmpipe_scheduler.hpp:13
@ SCHED_GROUP_LDS_READ
Definition: blkgemmpipe_scheduler.hpp:16
@ SCHED_GROUP_MFMA
Definition: blkgemmpipe_scheduler.hpp:14
@ SCHED_GROUP_LDS_WRITE
Definition: blkgemmpipe_scheduler.hpp:17
@ SCHED_GROUP_VMEM
Definition: blkgemmpipe_scheduler.hpp:15
int32_t index_t
Definition: ck.hpp:301
unsigned int uint32_t
Definition: stdint.h:126
static constexpr index_t B_LDS_Write_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:46
static constexpr index_t A_LDS_Read_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:49
static constexpr index_t A_LDS_Read_Width
Definition: blkgemmpipe_scheduler.hpp:42
static constexpr index_t B_LDS_Read_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:51
static constexpr index_t A_LDS_Write_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:44
static constexpr index_t C_MFMA_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:54
static constexpr index_t C_MFMA_Inst_Cycle
Definition: blkgemmpipe_scheduler.hpp:65
static constexpr index_t A_Buffer_Load_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:39
static constexpr index_t C_MFMA_SpeedUp
Definition: blkgemmpipe_scheduler.hpp:63
static constexpr index_t WaveSize
Definition: blockwise_gemm_pipeline_xdlops.hpp:37
static constexpr index_t B_Buffer_Load_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:41
static constexpr auto Print()
Definition: blkgemmpipe_scheduler.hpp:76
static constexpr index_t WaveNumN
Definition: blockwise_gemm_pipeline_xdlops.hpp:36
static constexpr index_t WaveNumM
Definition: blockwise_gemm_pipeline_xdlops.hpp:35
static constexpr index_t B_LDS_Read_Width
Definition: blkgemmpipe_scheduler.hpp:43