Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_gemm_waveletmodel.hpp
gridwise_gemm_waveletmodel.hpp — Source File
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 
8 namespace ck {
9 
10 template <typename TileLoadThreadGroup, index_t NumGemmKPrefetchStage>
12 
13 // 1-stage prefetch
14 template <typename TileLoadThreadGroup>
15 struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1>
16 {
17  __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */)
18  {
19  // TODO: improve applicability
20  return true;
21  }
22 
23  __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
24  {
25  return num_loop > 1;
26  }
27 
28  template <bool HasMainLoop,
29  typename AGridDesc,
30  typename ABlockDesc,
31  typename ABlockTransfer,
32  typename AGridBuffer,
33  typename ABlockBuffer,
34  typename ABlockTransferStep,
35  typename BGridDesc,
36  typename BBlockDesc,
37  typename BBlockTransfer,
38  typename BGridBuffer,
39  typename BBlockBuffer,
40  typename BBlockTransferStep>
41  static __device__ void RunLoadWavePipeline(const AGridDesc& a_grid_desc,
42  const ABlockDesc& a_block_desc,
43  ABlockTransfer& a_blockwise_copy,
44  const AGridBuffer& a_grid_buf,
45  ABlockBuffer& a_block_buf,
46  const ABlockTransferStep& a_block_copy_step,
47  const BGridDesc& b_grid_desc,
48  const BBlockDesc& b_block_desc,
49  BBlockTransfer& b_blockwise_copy,
50  const BGridBuffer& b_grid_buf,
51  BBlockBuffer& b_block_buf,
52  const BBlockTransferStep& b_block_copy_step,
53  index_t num_loop)
54  {
55  // global read 0
56  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
57  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
58 
59  // move to 1
60  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
61  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
62 
63  // LDS write 0
64  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
65  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
66 
67  if constexpr(HasMainLoop)
68  {
69  index_t i = 0;
70 
71  do
72  {
73  // sync for Load threads()
75  // global read i + 1
76  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
77  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
78 
79  // move to i + 2
80  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
81  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
82 
83  // sync with math threads()
85 
86  // LDS write i+1
87  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
88  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
89 
90  ++i;
91  } while(i < (num_loop - 1));
92  }
93 
94  // tail
95  {
97  // GEMM num_loop - 1
98  }
99  }
100 };
101 
102 template <typename TileMathThreadGroup, index_t NumGemmKPrefetchStage>
104 // 1- stage prefetch
105 template <typename TileMathThreadGroup>
106 struct GridwiseGemmMathWave<TileMathThreadGroup, 1>
107 {
108 
109  __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
110 
111  __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
112  {
113  return num_loop > 1;
114  }
115 
116  template <bool HasMainLoop,
117  typename ABlockBuffer,
118  typename BBlockBuffer,
119  typename BlockwiseGemm,
120  typename CThreadBuffer>
121  static __device__ void RunMathWavePipeline(ABlockBuffer& a_block_buf,
122  BBlockBuffer& b_block_buf,
123  const BlockwiseGemm& block_gemm,
124  CThreadBuffer& c_thread_buf,
125  index_t num_loop)
126  {
127  // Initialize C
128  c_thread_buf.Clear();
129 
130  // main body
131  if constexpr(HasMainLoop)
132  {
133  index_t i = 0;
134 
135  do
136  {
137  block_sync_lds();
138 
139  // GEMM i
140  block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
141 
142  block_sync_lds();
143  ++i;
144  } while(i < (num_loop - 1));
145  }
146 
147  // tail
148  {
149  block_sync_lds();
150 
151  // GEMM num_loop - 1
152  block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
153  }
154  }
155 };
156 
157 } // namespace ck
Definition: ck.hpp:267
int32_t index_t
Definition: ck.hpp:298
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
static __device__ void RunLoadWavePipeline(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, index_t num_loop)
Definition: gridwise_gemm_waveletmodel.hpp:41
__host__ static constexpr __device__ bool IsSupported(index_t)
Definition: gridwise_gemm_waveletmodel.hpp:17
__host__ static constexpr __device__ bool CalculateHasMainLoop(index_t num_loop)
Definition: gridwise_gemm_waveletmodel.hpp:23
Definition: gridwise_gemm_waveletmodel.hpp:11
static __device__ void RunMathWavePipeline(ABlockBuffer &a_block_buf, BBlockBuffer &b_block_buf, const BlockwiseGemm &block_gemm, CThreadBuffer &c_thread_buf, index_t num_loop)
Definition: gridwise_gemm_waveletmodel.hpp:121
__host__ static constexpr __device__ bool IsSupported(index_t)
Definition: gridwise_gemm_waveletmodel.hpp:109
__host__ static constexpr __device__ bool CalculateHasMainLoop(index_t num_loop)
Definition: gridwise_gemm_waveletmodel.hpp:111
Definition: gridwise_gemm_waveletmodel.hpp:103