Composable Kernel documentation — source listing for include/ck/host_utility/flush_cache.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
#include <hip/hip_runtime.h>

#include <array>
#include <cstddef>
#include <iostream>
#include <set>
#include <vector>

#include "ck/ck.hpp"
#include "ck/stream_config.hpp"
14 namespace ck {
15 namespace utility {
16 
17 template <typename Argument, typename DsDataType>
19 {
20  static constexpr index_t NumDs = DsDataType::Size();
21 
22  using ADataType = decltype(Argument::p_a_grid);
23  using BDataType = decltype(Argument::p_b_grid);
24  using DsGridPointer = decltype(Argument::p_ds_grid);
25 
27  RotatingMemWrapperMultiD(Argument& arg_,
28  std::size_t rotating_count_,
29  std::size_t size_a_,
30  std::size_t size_b_,
31  std::array<std::size_t, NumDs> size_ds_)
32  : arg(arg_),
33  rotating_count(rotating_count_),
34  size_a(size_a_),
35  size_b(size_b_),
36  size_ds(size_ds_)
37  {
38  p_a_grids.push_back(arg.p_a_grid);
39  p_b_grids.push_back(arg.p_b_grid);
40  p_ds_grids.push_back(arg.p_ds_grid);
41  for(size_t i = 1; i < rotating_count; i++)
42  {
43  {
44  void* pADeviceBuf;
45  hip_check_error(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
46  hip_check_error(hipMemcpy(static_cast<void*>(pADeviceBuf),
47  const_cast<void*>(p_a_grids[0]),
48  size_a_,
49  hipMemcpyDeviceToDevice));
50  p_a_grids.push_back(pADeviceBuf);
51  }
52 
53  {
54  void* pBDeviceBuf;
55  hip_check_error(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
56  hip_check_error(hipMemcpy(static_cast<void*>(pBDeviceBuf),
57  const_cast<void*>(p_b_grids[0]),
58  size_b_,
59  hipMemcpyDeviceToDevice));
60  p_b_grids.push_back(pBDeviceBuf);
61  }
62 
63  {
64 
65  DsGridPointer ds_buffer;
66  static_for<0, NumDs, 1>{}([&](auto j) {
67  void* pDDeviceBuf;
68  hip_check_error(hipMalloc(static_cast<void**>(&pDDeviceBuf), size_ds_[j]));
69  hip_check_error(hipMemcpy(static_cast<void*>(pDDeviceBuf),
70  static_cast<const void*>(p_ds_grids[0][j]),
71  size_ds_[j],
72  hipMemcpyDeviceToDevice));
73 
74  using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
75 
76  ds_buffer(j) = static_cast<const DDataType*>(pDDeviceBuf);
77  });
78 
79  p_ds_grids.push_back(ds_buffer);
80  }
81  }
82  }
83 
84  void Next()
85  {
86  if(rotating_count > 1)
87  {
88  std::size_t idx = iter++ % rotating_count;
89  arg.p_a_grid = reinterpret_cast<ADataType>(p_a_grids[idx]);
90  arg.p_b_grid = reinterpret_cast<BDataType>(p_b_grids[idx]);
91  arg.p_ds_grid = p_ds_grids[idx];
92  }
93  }
94  void Print()
95  {
96  std::cout << "RotatingMemWrapperMultiD: { size_a: " << size_a << ", size_b: " << size_b
97  << ", rotating_count: " << rotating_count << "}" << std::endl;
98  }
100  {
101  if(rotating_count > 1)
102  {
103  // restore ptr
104  arg.p_a_grid = reinterpret_cast<ADataType>(p_a_grids[0]);
105  arg.p_b_grid = reinterpret_cast<BDataType>(p_b_grids[0]);
106  arg.p_ds_grid = p_ds_grids[0];
107 
108  // free device mem
109  for(size_t i = 1; i < rotating_count; i++)
110  {
111  hip_check_error(hipFree(const_cast<void*>(p_a_grids[i])));
112  hip_check_error(hipFree(const_cast<void*>(p_b_grids[i])));
113 
114  static_for<0, NumDs, 1>{}([&](auto j) {
115  using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
117  hipFree(static_cast<void*>(const_cast<DDataType*>(p_ds_grids[i][j]))));
118  });
119  }
120  }
121  }
122 
123  private:
124  Argument& arg;
125  std::size_t iter = 0;
126  std::size_t rotating_count = 1;
127  std::size_t size_a = 0;
128  std::size_t size_b = 0;
129  std::array<std::size_t, NumDs> size_ds = {0};
130  std::vector<const void*> p_a_grids;
131  std::vector<const void*> p_b_grids;
132  std::vector<DsGridPointer> p_ds_grids;
133 };
134 
135 template <typename Argument>
137 {
138  using ADataType = decltype(Argument::p_a_grid);
139  using BDataType = decltype(Argument::p_b_grid);
140 
141  RotatingMemWrapper() = delete;
142  RotatingMemWrapper(Argument& arg_,
143  std::size_t rotating_count_,
144  std::size_t size_a_,
145  std::size_t size_b_)
146  : arg(arg_), rotating_count(rotating_count_), size_a(size_a_), size_b(size_b_)
147  {
148  p_a_grids.push_back(arg.p_a_grid);
149  p_b_grids.push_back(arg.p_b_grid);
150  for(size_t i = 1; i < rotating_count; i++)
151  {
152  {
153  void* pADeviceBuf;
154  hip_check_error(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
155  hip_check_error(hipMemcpy(static_cast<void*>(pADeviceBuf),
156  const_cast<void*>(p_a_grids[0]),
157  size_a_,
158  hipMemcpyDeviceToDevice));
159  p_a_grids.push_back(pADeviceBuf);
160  }
161 
162  {
163  void* pBDeviceBuf;
164  hip_check_error(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
165  hip_check_error(hipMemcpy(static_cast<void*>(pBDeviceBuf),
166  const_cast<void*>(p_b_grids[0]),
167  size_b_,
168  hipMemcpyDeviceToDevice));
169  p_b_grids.push_back(pBDeviceBuf);
170  }
171  }
172  }
173 
174  void Next()
175  {
176  if(rotating_count > 1)
177  {
178  std::size_t idx = iter++ % rotating_count;
179  arg.p_a_grid = reinterpret_cast<ADataType>(p_a_grids[idx]);
180  arg.p_b_grid = reinterpret_cast<BDataType>(p_b_grids[idx]);
181  }
182  }
183  void Print()
184  {
185  std::cout << "RotatingMemWrapper: { size_a: " << size_a << ", size_b: " << size_b
186  << ", rotating_count: " << rotating_count << "}" << std::endl;
187  }
189  {
190  if(rotating_count > 1)
191  {
192  // restore ptr
193  arg.p_a_grid = reinterpret_cast<ADataType>(p_a_grids[0]);
194  arg.p_b_grid = reinterpret_cast<BDataType>(p_b_grids[0]);
195 
196  // free device mem
197  for(size_t i = 1; i < rotating_count; i++)
198  {
199  hip_check_error(hipFree(const_cast<void*>(p_a_grids[i])));
200  hip_check_error(hipFree(const_cast<void*>(p_b_grids[i])));
201  }
202  }
203  }
204 
205  private:
206  Argument& arg;
207  std::size_t iter = 0;
208  std::size_t rotating_count = 1;
209  std::size_t size_a = 0;
210  std::size_t size_b = 0;
211  std::vector<const void*> p_a_grids;
212  std::vector<const void*> p_b_grids;
213 };
214 
215 inline void flush_icache()
216 {
217  hipDeviceProp_t deviceProps;
218  hip_check_error(hipGetDeviceProperties(&deviceProps, 0));
219  int32_t gpu_block3 = deviceProps.multiProcessorCount * 60;
220 
221  ck::flush_icache<<<dim3(gpu_block3), dim3(64), 0, nullptr>>>();
222  hip_check_error(hipGetLastError());
223 }
224 // if TimePrePress == false, return time does not include preprocess's time
225 template <bool TimePreprocess,
226  typename GemmArgs,
227  typename... Args,
228  typename F,
229  typename PreProcessFunc>
231  PreProcessFunc preprocess,
232  F kernel,
233  dim3 grid_dim,
234  dim3 block_dim,
235  std::size_t lds_byte,
236  GemmArgs& gemm_args,
237  Args... args)
238 {
239 #if CK_TIME_KERNEL
240 #define MEDIAN 0
241  if(stream_config.time_kernel_)
242  {
243  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
244  {
245  printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n",
246  __func__,
247  grid_dim.x,
248  grid_dim.y,
249  grid_dim.z,
250  block_dim.x,
251  block_dim.y,
252  block_dim.z);
253 
254  printf("Warm up %d times\n", stream_config.cold_niters_);
255  }
256  // warm up
257  for(int i = 0; i < stream_config.cold_niters_; ++i)
258  {
259  kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
260  hip_check_error(hipGetLastError());
261  }
262 
263  const int nrepeat = stream_config.nrepeat_;
264  if(nrepeat == 0)
265  {
266  return 0.0;
267  }
268  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
269  {
270  printf("Start running %d times...\n", nrepeat);
271  }
272 
273 #if MEDIAN
274  std::set<float> times;
275 #else
276  float total_time = 0;
277 #endif
278  hipEvent_t start, stop;
279 
280  hip_check_error(hipEventCreate(&start));
281  hip_check_error(hipEventCreate(&stop));
282 
283  hip_check_error(hipDeviceSynchronize());
284  hip_check_error(hipEventRecord(start, stream_config.stream_id_));
285 
286  for(int i = 0; i < nrepeat; ++i)
287  {
288  if constexpr(!TimePreprocess)
289  {
290  preprocess();
291  }
292 
293  // hipEvent_t start, stop;
294 
295  // hip_check_error(hipEventCreate(&start));
296  // hip_check_error(hipEventCreate(&stop));
297 
298  // hip_check_error(hipDeviceSynchronize());
299  // hip_check_error(hipEventRecord(start, stream_config.stream_id_));
300  // calculate preprocess time
301  if constexpr(TimePreprocess)
302  {
303  preprocess();
304  }
305  // run real kernel
306  kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
307  hip_check_error(hipGetLastError());
308  // end real kernel
309 
310  // hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
311  // hip_check_error(hipEventSynchronize(stop));
312  // float cur_time = 0;
313  // hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
314  // #if MEDIAN
315  // times.insert(cur_time);
316  // #else
317  // total_time += cur_time;
318  // #endif
319 
320  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
321  {
322  // std::cout << "i: " << i << " cur_time: " << cur_time << std::endl;
323 
324  printf("gemm_args.p_a_grid: %p, gemm_args.p_b_grid:%p\n",
325  static_cast<const void*>(gemm_args.p_a_grid),
326  static_cast<const void*>(gemm_args.p_b_grid));
327  }
328  }
329  hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
330  hip_check_error(hipEventSynchronize(stop));
331  float cur_time = 0;
332  hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
333 #if MEDIAN
334  times.insert(cur_time);
335 #else
336  total_time += cur_time;
337 #endif
338 
339 #if MEDIAN
340  auto mid = times.begin();
341  std::advance(mid, (nrepeat - 1) / 2);
342  if(nrepeat % 2 == 1)
343  {
344  return *mid;
345  }
346  else
347  {
348  auto mid_next = mid;
349  std::advance(mid_next, 1);
350  return (*mid + *mid_next) / 2;
351  }
352 #else
353  // return total_time / nrepeat;
354  hipDeviceProp_t deviceProps;
355  hip_check_error(hipGetDeviceProperties(&deviceProps, 0));
356  float preprocess_offset = deviceProps.multiProcessorCount == 80 ? 0.005 : 0.01;
357  return (total_time - preprocess_offset * nrepeat) / nrepeat;
358 #endif
359  }
360  else
361  {
362  preprocess();
363  kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
364  hip_check_error(hipGetLastError());
365 
366  return 0;
367  }
368 #else
369  kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
370  hip_check_error(hipGetLastError());
371 
372  return 0;
373 #endif
374 }
375 
376 } // namespace utility
377 } // namespace ck
#define CK_ENV(name)
Definition: env.hpp:128
void hip_check_error(hipError_t x)
Definition: hip_check_error.hpp:10
void flush_icache()
Definition: flush_cache.hpp:215
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition: flush_cache.hpp:230
Definition: ck.hpp:264
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:139
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:300
int32_t index_t
Definition: ck.hpp:289
Definition: stream_config.hpp:10
int cold_niters_
Definition: stream_config.hpp:14
bool time_kernel_
Definition: stream_config.hpp:12
int nrepeat_
Definition: stream_config.hpp:15
hipStream_t stream_id_
Definition: stream_config.hpp:11
Definition: functional2.hpp:31
Definition: flush_cache.hpp:137
decltype(Argument::p_a_grid) ADataType
Definition: flush_cache.hpp:138
RotatingMemWrapper(Argument &arg_, std::size_t rotating_count_, std::size_t size_a_, std::size_t size_b_)
Definition: flush_cache.hpp:142
decltype(Argument::p_b_grid) BDataType
Definition: flush_cache.hpp:139
~RotatingMemWrapper()
Definition: flush_cache.hpp:188
void Print()
Definition: flush_cache.hpp:183
void Next()
Definition: flush_cache.hpp:174
Definition: flush_cache.hpp:19
decltype(Argument::p_b_grid) BDataType
Definition: flush_cache.hpp:23
void Print()
Definition: flush_cache.hpp:94
static constexpr index_t NumDs
Definition: flush_cache.hpp:20
decltype(Argument::p_a_grid) ADataType
Definition: flush_cache.hpp:22
~RotatingMemWrapperMultiD()
Definition: flush_cache.hpp:99
decltype(Argument::p_ds_grid) DsGridPointer
Definition: flush_cache.hpp:24
RotatingMemWrapperMultiD(Argument &arg_, std::size_t rotating_count_, std::size_t size_a_, std::size_t size_b_, std::array< std::size_t, NumDs > size_ds_)
Definition: flush_cache.hpp:27
void Next()
Definition: flush_cache.hpp:84