/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops.hpp Source File
blockwise_gemm_pipeline_wmmaops.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 
8 namespace ck {
9 
10 template <index_t BlockSize,
11  index_t MPerBlock,
12  index_t NPerBlock,
13  index_t KPerBlock,
14  index_t ABufferLoadWidth,
15  index_t BBufferLoadWidth,
16  index_t ALDSWriteWidth,
17  index_t BLDSWriteWidth,
18  index_t ALDSReadWidth,
19  index_t BLDSReadWidth,
20  index_t MRepeat,
21  index_t NRepeat,
22  index_t MPerWmma,
23  index_t NPerWmma,
24  index_t KPerWmma>
26 {
27  static constexpr index_t WaveSize = 32;
28  static constexpr index_t WaveNumM = MPerBlock / (MRepeat * MPerWmma);
29  static constexpr index_t WaveNumN = NPerBlock / (NRepeat * NPerWmma);
30 
31  static constexpr index_t A_LDS_Read_Width = ALDSReadWidth;
32  static constexpr index_t B_LDS_Read_Width = BLDSReadWidth;
33 
34  static constexpr index_t A_Buffer_Load_Inst_Num =
35  MPerBlock * KPerBlock / (BlockSize * ABufferLoadWidth);
36  static constexpr index_t B_Buffer_Load_Inst_Num =
37  NPerBlock * KPerBlock / (BlockSize * BBufferLoadWidth);
38 
39  static constexpr index_t A_LDS_Write_Inst_Num =
40  MPerBlock * KPerBlock / (BlockSize * ALDSWriteWidth);
41  static constexpr index_t B_LDS_Write_Inst_Num =
42  NPerBlock * KPerBlock / (BlockSize * BLDSWriteWidth);
43 
44  static constexpr index_t A_LDS_Read_Inst_Num =
45  WaveNumN * MPerBlock * KPerBlock / (BlockSize * ALDSReadWidth);
46  static constexpr index_t B_LDS_Read_Inst_Num =
47  WaveNumM * NPerBlock * KPerBlock / (BlockSize * BLDSReadWidth);
48 
49  static constexpr index_t C_WMMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
50  (BlockSize / WaveSize) /
51  (MPerWmma * NPerWmma * KPerWmma);
52 
53  static constexpr auto Print()
54  {
55  printf(" Blk/Wave Size: %d, %d, M/N/K PerBlk: %d, %d, %d, M/N/K PerWmma: %d, %d, %d\n",
56  BlockSize,
57  WaveSize,
58  MPerBlock,
59  NPerBlock,
60  KPerBlock,
61  MPerWmma,
62  NPerWmma,
63  KPerWmma);
64 
65  printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: "
66  "%d, %d\n C WMMA inst: %d\n"
67  "A/B LDS read width: %d, %d, A/B LDS write width: %d, %d, A/B buffer load width: "
68  "%d, %d\n",
78  ALDSWriteWidth,
79  BLDSWriteWidth,
80  ABufferLoadWidth,
81  BBufferLoadWidth);
82  }
83 };
84 
85 } // namespace ck
Definition: ck.hpp:267
int32_t index_t
Definition: ck.hpp:298
Definition: blockwise_gemm_pipeline_wmmaops.hpp:26
static constexpr index_t B_LDS_Read_Inst_Num
Definition: blockwise_gemm_pipeline_wmmaops.hpp:46
static constexpr index_t B_LDS_Read_Width
Definition: blockwise_gemm_pipeline_wmmaops.hpp:32
static constexpr index_t A_LDS_Write_Inst_Num
Definition: blockwise_gemm_pipeline_wmmaops.hpp:39
static constexpr index_t WaveSize
Definition: blockwise_gemm_pipeline_wmmaops.hpp:27
static constexpr index_t WaveNumN
Definition: blockwise_gemm_pipeline_wmmaops.hpp:29
static constexpr index_t B_Buffer_Load_Inst_Num
Definition: blockwise_gemm_pipeline_wmmaops.hpp:36
static constexpr index_t A_LDS_Read_Inst_Num
Definition: blockwise_gemm_pipeline_wmmaops.hpp:44
static constexpr index_t A_Buffer_Load_Inst_Num
Definition: blockwise_gemm_pipeline_wmmaops.hpp:34
static constexpr auto Print()
Definition: blockwise_gemm_pipeline_wmmaops.hpp:53
static constexpr index_t C_WMMA_Inst_Num
Definition: blockwise_gemm_pipeline_wmmaops.hpp:49
static constexpr index_t B_LDS_Write_Inst_Num
Definition: blockwise_gemm_pipeline_wmmaops.hpp:41
static constexpr index_t A_LDS_Read_Width
Definition: blockwise_gemm_pipeline_wmmaops.hpp:31
static constexpr index_t WaveNumM
Definition: blockwise_gemm_pipeline_wmmaops.hpp:28