/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp Source File
gridwise_gemm_pipeline_v2.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 
8 namespace ck {
9 
11 {
12  __host__ __device__ static constexpr bool IsSupported(const index_t num_loop)
13  {
14  // TODO: improve applicability
15  return num_loop % 2 == 0;
16  }
17 
18  __host__ __device__ static constexpr bool CalculateHasMainLoop(const index_t num_loop)
19  {
20  return (num_loop / 2) > 1;
21  }
22 
23  template <bool HasMainLoop,
24  typename AGridDesc,
25  typename ABlockDesc,
26  typename ABlockTransfer,
27  typename AGridBuffer,
28  typename ABlockBuffer,
29  typename ABlockTransferStep,
30  typename BGridDesc,
31  typename BBlockDesc,
32  typename BBlockTransfer,
33  typename BGridBuffer,
34  typename BBlockBuffer,
35  typename BBlockTransferStep,
36  typename BlockwiseGemm,
37  typename CThreadBuffer>
38  __device__ static void Run(const AGridDesc& a_grid_desc,
39  const ABlockDesc& a_block_desc,
40  ABlockTransfer& a_blockwise_copy,
41  const AGridBuffer& a_grid_buf,
42  ABlockBuffer& a_block_buf,
43  const ABlockTransferStep& a_block_copy_step,
44  const BGridDesc& b_grid_desc,
45  const BBlockDesc& b_block_desc,
46  BBlockTransfer& b_blockwise_copy,
47  const BGridBuffer& b_grid_buf,
48  BBlockBuffer& b_block_buf,
49  const BBlockTransferStep& b_block_copy_step,
50  const BlockwiseGemm& blockwise_gemm,
51  CThreadBuffer& c_thread_buf,
52  index_t num_loop)
53  {
54  // global read 0
55  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
56  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
57 
58  // move to 1
59  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
60  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
61 
62  // Initialize C
63  c_thread_buf.Clear();
64 
65  // LDS write 0
66  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
67  // global Read 1
68  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
69 
70  // LDS write 0
71  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
72  // global Read 1
73  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
74 
75  // main body
76  if constexpr(HasMainLoop)
77  {
78  index_t i = 0;
79 
80  do
81  {
82 #if CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT
83  __builtin_amdgcn_iglp_opt(CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT);
84 #endif
85 
87 
88  // GEMM i
89  blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
90 
92 
93  // move to i + 2
94  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
95  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
96 
97  // LDS write i + 1
98  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
99  // global read i + 2
100  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
101 
102  // LDS write i + 1
103  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
104  // global read i + 2
105  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
106 
107  ++i;
108  } while(i < (num_loop - 2));
109  }
110 
111  // tail
112  {
113  block_sync_lds();
114 
115  // GEMM num_loop - 2
116  blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
117 
118  block_sync_lds();
119 
120  // LDS write num_loop - 1
121  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
122  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
123 
124  block_sync_lds();
125 
126  // GEMM num_loop - 1
127  blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
128  }
129  }
130 };
131 
132 } // namespace ck
#define CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT
Definition: ck.hpp:217
Definition: ck.hpp:267
int32_t index_t
Definition: ck.hpp:298
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
Definition: gridwise_gemm_pipeline_v2.hpp:11
__host__ static constexpr __device__ bool CalculateHasMainLoop(const index_t num_loop)
Definition: gridwise_gemm_pipeline_v2.hpp:18
__host__ static constexpr __device__ bool IsSupported(const index_t num_loop)
Definition: gridwise_gemm_pipeline_v2.hpp:12
static __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, const BlockwiseGemm &blockwise_gemm, CThreadBuffer &c_thread_buf, index_t num_loop)
Definition: gridwise_gemm_pipeline_v2.hpp:38