Composable Kernel: include/ck/host_utility/flush_cache.hpp Source File

flush_cache.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <hip/hip_runtime.h>
7 #include <set>
8 #include <vector>
9 
10 #include "ck/ck.hpp"
11 #include "ck/utility/env.hpp"
12 #include "ck/stream_config.hpp"
15 namespace ck {
16 namespace utility {
17 
18 template <typename Argument, typename DsDataType>
19 struct RotatingMemWrapperMultiD
20 {
21  static constexpr index_t NumDs = DsDataType::Size();
22 
23  using ADataType = decltype(Argument::p_a_grid);
24  using BDataType = decltype(Argument::p_b_grid);
25  using DsGridPointer = decltype(Argument::p_ds_grid);
26 
27  RotatingMemWrapperMultiD() = delete;
28  RotatingMemWrapperMultiD(Argument& arg_,
29  std::size_t rotating_count_,
30  std::size_t size_a_,
31  std::size_t size_b_,
32  std::array<std::size_t, NumDs> size_ds_)
33  : arg(arg_),
34  rotating_count(rotating_count_),
35  size_a(size_a_),
36  size_b(size_b_),
37  size_ds(size_ds_)
38  {
39  p_a_grids.push_back(arg.p_a_grid);
40  p_b_grids.push_back(arg.p_b_grid);
41  p_ds_grids.push_back(arg.p_ds_grid);
42  for(size_t i = 1; i < rotating_count; i++)
43  {
44  {
45  void* pADeviceBuf;
46  hip_check_error(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
47  hip_check_error(hipMemcpy(static_cast<void*>(pADeviceBuf),
48  const_cast<void*>(p_a_grids[0]),
49  size_a_,
50  hipMemcpyDeviceToDevice));
51  p_a_grids.push_back(pADeviceBuf);
52  }
53 
54  {
55  void* pBDeviceBuf;
56  hip_check_error(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
57  hip_check_error(hipMemcpy(static_cast<void*>(pBDeviceBuf),
58  const_cast<void*>(p_b_grids[0]),
59  size_b_,
60  hipMemcpyDeviceToDevice));
61  p_b_grids.push_back(pBDeviceBuf);
62  }
63 
64  {
65 
66  DsGridPointer ds_buffer;
67  static_for<0, NumDs, 1>{}([&](auto j) {
68  void* pDDeviceBuf;
69  hip_check_error(hipMalloc(static_cast<void**>(&pDDeviceBuf), size_ds_[j]));
70  hip_check_error(hipMemcpy(static_cast<void*>(pDDeviceBuf),
71  static_cast<const void*>(p_ds_grids[0][j]),
72  size_ds_[j],
73  hipMemcpyDeviceToDevice));
74 
75  using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
76 
77  ds_buffer(j) = static_cast<const DDataType*>(pDDeviceBuf);
78  });
79 
80  p_ds_grids.push_back(ds_buffer);
81  }
82  }
83  }
84 
85  void Next()
86  {
87  if(rotating_count > 1)
88  {
89  std::size_t idx = iter++ % rotating_count;
90  arg.p_a_grid = reinterpret_cast<ADataType>(p_a_grids[idx]);
91  arg.p_b_grid = reinterpret_cast<BDataType>(p_b_grids[idx]);
92  arg.p_ds_grid = p_ds_grids[idx];
93  }
94  }
95  void Print()
96  {
97  std::cout << "RotatingMemWrapperMultiD: { size_a: " << size_a << ", size_b: " << size_b
98  << ", rotating_count: " << rotating_count << "}" << std::endl;
99  }
100  ~RotatingMemWrapperMultiD()
101  {
102  if(rotating_count > 1)
103  {
104  // restore ptr
105  arg.p_a_grid = reinterpret_cast<ADataType>(p_a_grids[0]);
106  arg.p_b_grid = reinterpret_cast<BDataType>(p_b_grids[0]);
107  arg.p_ds_grid = p_ds_grids[0];
108 
109  // free device mem
110  for(size_t i = 1; i < rotating_count; i++)
111  {
112  hip_check_error(hipFree(const_cast<void*>(p_a_grids[i])));
113  hip_check_error(hipFree(const_cast<void*>(p_b_grids[i])));
114 
115  static_for<0, NumDs, 1>{}([&](auto j) {
116  using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
117  hip_check_error(
118  hipFree(static_cast<void*>(const_cast<DDataType*>(p_ds_grids[i][j]))));
119  });
120  }
121  }
122  }
123 
124  private:
125  Argument& arg;
126  std::size_t iter = 0;
127  std::size_t rotating_count = 1;
128  std::size_t size_a = 0;
129  std::size_t size_b = 0;
130  std::array<std::size_t, NumDs> size_ds = {0};
131  std::vector<const void*> p_a_grids;
132  std::vector<const void*> p_b_grids;
133  std::vector<DsGridPointer> p_ds_grids;
134 };
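
The wrapper above keeps rotating_count device-resident copies of A, B and every D tensor and rotates the pointers inside the caller's Argument, so back-to-back benchmark launches do not simply re-read operands that the previous launch left in cache. The sketch below shows the intended call pattern under stated assumptions: time_with_rotating_buffers, Launcher, and the byte counts are illustrative placeholders, not part of this header; Argument is any device-op argument struct exposing p_a_grid, p_b_grid and p_ds_grid, and DsDataType must be supplied explicitly.

// Sketch only; assumes this flush_cache.hpp header is already included.
#include <array>
#include <cstddef>

template <typename DsDataType, typename Argument, typename Launcher>
void time_with_rotating_buffers(Argument& arg,
                                std::size_t rotating_count,
                                std::size_t a_bytes,
                                std::size_t b_bytes,
                                std::array<std::size_t, DsDataType::Size()> ds_bytes,
                                int nrepeat,
                                Launcher launch)
{
    // Allocates (rotating_count - 1) extra device copies of A, B and each D tensor.
    ck::utility::RotatingMemWrapperMultiD<Argument, DsDataType> rotating_mem(
        arg, rotating_count, a_bytes, b_bytes, ds_bytes);
    rotating_mem.Print();

    for(int i = 0; i < nrepeat; ++i)
    {
        rotating_mem.Next(); // arg now points at copy (i % rotating_count) of A/B/Ds
        launch(arg);         // timed launch reads operands that were not just touched
    }
} // destructor restores the original pointers and hipFree()s the extra copies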
135 
136 template <typename Argument>
137 struct RotatingMemWrapper
138 {
139  using ADataType = decltype(Argument::p_a_grid);
140  using BDataType = decltype(Argument::p_b_grid);
141 
142  RotatingMemWrapper() = delete;
143  RotatingMemWrapper(Argument& arg_,
144  std::size_t rotating_count_,
145  std::size_t size_a_,
146  std::size_t size_b_)
147  : arg(arg_), rotating_count(rotating_count_), size_a(size_a_), size_b(size_b_)
148  {
149  p_a_grids.push_back(arg.p_a_grid);
150  p_b_grids.push_back(arg.p_b_grid);
151  for(size_t i = 1; i < rotating_count; i++)
152  {
153  {
154  void* pADeviceBuf;
155  hip_check_error(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
156  hip_check_error(hipMemcpy(static_cast<void*>(pADeviceBuf),
157  const_cast<void*>(p_a_grids[0]),
158  size_a_,
159  hipMemcpyDeviceToDevice));
160  p_a_grids.push_back(pADeviceBuf);
161  }
162 
163  {
164  void* pBDeviceBuf;
165  hip_check_error(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
166  hip_check_error(hipMemcpy(static_cast<void*>(pBDeviceBuf),
167  const_cast<void*>(p_b_grids[0]),
168  size_b_,
169  hipMemcpyDeviceToDevice));
170  p_b_grids.push_back(pBDeviceBuf);
171  }
172  }
173  }
174 
175  void Next()
176  {
177  if(rotating_count > 1)
178  {
179  std::size_t idx = iter++ % rotating_count;
180  arg.p_a_grid = reinterpret_cast<ADataType>(p_a_grids[idx]);
181  arg.p_b_grid = reinterpret_cast<BDataType>(p_b_grids[idx]);
182  }
183  }
184  void Print()
185  {
186  std::cout << "RotatingMemWrapper: { size_a: " << size_a << ", size_b: " << size_b
187  << ", rotating_count: " << rotating_count << "}" << std::endl;
188  }
189  ~RotatingMemWrapper()
190  {
191  if(rotating_count > 1)
192  {
193  // restore ptr
194  arg.p_a_grid = reinterpret_cast<ADataType>(p_a_grids[0]);
195  arg.p_b_grid = reinterpret_cast<BDataType>(p_b_grids[0]);
196 
197  // free device mem
198  for(size_t i = 1; i < rotating_count; i++)
199  {
200  hip_check_error(hipFree(const_cast<void*>(p_a_grids[i])));
201  hip_check_error(hipFree(const_cast<void*>(p_b_grids[i])));
202  }
203  }
204  }
205 
206  private:
207  Argument& arg;
208  std::size_t iter = 0;
209  std::size_t rotating_count = 1;
210  std::size_t size_a = 0;
211  std::size_t size_b = 0;
212  std::vector<const void*> p_a_grids;
213  std::vector<const void*> p_b_grids;
214 };
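
RotatingMemWrapper is the two-operand variant of the same idea: the extra copies cost (rotating_count - 1) * (size_a + size_b) bytes of device memory, Next() advances arg.p_a_grid and arg.p_b_grid to the next copy, and the destructor restores the original buffers and frees the rest. A hedged sketch of the rotation semantics, with an illustrative function name and buffer sizes:

// Sketch only: "arg" is a device-op argument with p_a_grid / p_b_grid members.
#include <cstddef>

template <typename Argument>
void rotate_ab_example(Argument& arg, std::size_t a_bytes, std::size_t b_bytes)
{
    // Three copies of A and B: costs 2 * (a_bytes + b_bytes) of extra device memory.
    ck::utility::RotatingMemWrapper<Argument> rotating_mem(
        arg, /*rotating_count=*/3, a_bytes, b_bytes);

    rotating_mem.Next(); // iter 0 -> copy 0 (the original buffers)
    rotating_mem.Next(); // iter 1 -> copy 1
    rotating_mem.Next(); // iter 2 -> copy 2
    rotating_mem.Next(); // iter 3 -> wraps back to copy 0
}   // destructor restores the original pointers and hipFree()s copies 1 and 2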
215 
216 inline void flush_icache()
217 {
218  hipDeviceProp_t deviceProps;
219  hip_check_error(hipGetDeviceProperties(&deviceProps, 0));
220  int32_t gpu_block3 = deviceProps.multiProcessorCount * 60;
221 
222  ck::flush_icache<<<dim3(gpu_block3), dim3(64), 0, nullptr>>>();
223  hip_check_error(hipGetLastError());
224 }
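
flush_icache() above launches the ck::flush_icache kernel over multiProcessorCount * 60 workgroups of 64 threads on the null stream, so subsequent timed kernels start from a cold instruction cache. It is typically invoked from the preprocess hook of the timing helper below; a small illustrative sketch (run_timed_iteration is a placeholder):

// Sketch: drop instruction-cache state between the warm-up and the timed run.
ck::utility::flush_icache();              // enqueue the flush kernel on the null stream
hip_check_error(hipDeviceSynchronize());  // (optional) wait until the flush has finished
run_timed_iteration();                    // subsequent launches start with a cold i-cache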
225 // if TimePreprocess == false, the returned time does not include the preprocess time
226 template <bool TimePreprocess,
227  typename GemmArgs,
228  typename... Args,
229  typename F,
230  typename PreProcessFunc>
231 float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
232  PreProcessFunc preprocess,
233  F kernel,
234  dim3 grid_dim,
235  dim3 block_dim,
236  std::size_t lds_byte,
237  GemmArgs& gemm_args,
238  Args... args)
239 {
240 #if CK_TIME_KERNEL
241 #define MEDIAN 0
242  if(stream_config.time_kernel_)
243  {
244  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
245  {
246  printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n",
247  __func__,
248  grid_dim.x,
249  grid_dim.y,
250  grid_dim.z,
251  block_dim.x,
252  block_dim.y,
253  block_dim.z);
254 
255  printf("Warm up %d times\n", stream_config.cold_niters_);
256  }
257  // warm up
258  for(int i = 0; i < stream_config.cold_niters_; ++i)
259  {
260  kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
261  hip_check_error(hipGetLastError());
262  }
263 
264  const int nrepeat = stream_config.nrepeat_;
265  if(nrepeat == 0)
266  {
267  return 0.0;
268  }
269  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
270  {
271  printf("Start running %d times...\n", nrepeat);
272  }
273 
274 #if MEDIAN
275  std::set<float> times;
276 #else
277  float total_time = 0;
278 #endif
279  hipEvent_t start, stop;
280 
281  hip_check_error(hipEventCreate(&start));
282  hip_check_error(hipEventCreate(&stop));
283 
284  hip_check_error(hipDeviceSynchronize());
285  hip_check_error(hipEventRecord(start, stream_config.stream_id_));
286 
287  for(int i = 0; i < nrepeat; ++i)
288  {
289  if constexpr(!TimePreprocess)
290  {
291  preprocess();
292  }
293 
294  // hipEvent_t start, stop;
295 
296  // hip_check_error(hipEventCreate(&start));
297  // hip_check_error(hipEventCreate(&stop));
298 
299  // hip_check_error(hipDeviceSynchronize());
300  // hip_check_error(hipEventRecord(start, stream_config.stream_id_));
301  // calculate preprocess time
302  if constexpr(TimePreprocess)
303  {
304  preprocess();
305  }
306  // run real kernel
307  kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
308  hip_check_error(hipGetLastError());
309  // end real kernel
310 
311  // hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
312  // hip_check_error(hipEventSynchronize(stop));
313  // float cur_time = 0;
314  // hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
315  // #if MEDIAN
316  // times.insert(cur_time);
317  // #else
318  // total_time += cur_time;
319  // #endif
320 
321  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
322  {
323  // std::cout << "i: " << i << " cur_time: " << cur_time << std::endl;
324 
325  printf("gemm_args.p_a_grid: %p, gemm_args.p_b_grid:%p\n",
326  static_cast<const void*>(gemm_args.p_a_grid),
327  static_cast<const void*>(gemm_args.p_b_grid));
328  }
329  }
330  hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
331  hip_check_error(hipEventSynchronize(stop));
332  float cur_time = 0;
333  hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
334 #if MEDIAN
335  times.insert(cur_time);
336 #else
337  total_time += cur_time;
338 #endif
339 
340 #if MEDIAN
341  auto mid = times.begin();
342  std::advance(mid, (nrepeat - 1) / 2);
343  if(nrepeat % 2 == 1)
344  {
345  return *mid;
346  }
347  else
348  {
349  auto mid_next = mid;
350  std::advance(mid_next, 1);
351  return (*mid + *mid_next) / 2;
352  }
353 #else
354  // return total_time / nrepeat;
355  hipDeviceProp_t deviceProps;
356  hip_check_error(hipGetDeviceProperties(&deviceProps, 0));
357  float preprocess_offset = deviceProps.multiProcessorCount == 80 ? 0.005 : 0.01;
358  return (total_time - preprocess_offset * nrepeat) / nrepeat;
359 #endif
360  }
361  else
362  {
363  preprocess();
364  kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
365  hip_check_error(hipGetLastError());
366 
367  return 0;
368  }
369 #else
370  kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
371  hip_check_error(hipGetLastError());
372 
373  return 0;
374 #endif
375 }
376 
377 } // namespace utility
378 } // namespace ck
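
A sketch of how the pieces of this header are typically combined by a device operation's invoker. Only RotatingMemWrapper, flush_icache, launch_and_time_kernel_with_preprocess, StreamConfig and hip_check_error come from this library; GemmKernel, KernelArgs, run_and_time and the numeric choices are illustrative assumptions.

// Sketch only: GemmKernel stands in for a real __global__ kernel taking KernelArgs.
#include <cstddef>

template <typename GemmKernel, typename KernelArgs>
float run_and_time(const StreamConfig& stream_config,
                   GemmKernel gemm_kernel,
                   KernelArgs& gemm_args,
                   dim3 grid_dim,
                   dim3 block_dim,
                   std::size_t a_bytes,
                   std::size_t b_bytes)
{
    // Keep several device copies of A and B so repeated launches cannot simply
    // re-read operands that the previous iteration left in cache.
    ck::utility::RotatingMemWrapper<KernelArgs> rotating_mem(
        gemm_args, /*rotating_count=*/3, a_bytes, b_bytes);
    rotating_mem.Print();

    // Work done before every timed launch: rotate the operands and flush the i-cache.
    auto preprocess = [&]() {
        rotating_mem.Next();
        ck::utility::flush_icache();
    };

    // TimePreprocess = false: per the comment on the helper, the reported time is
    // not meant to include the preprocessing work.
    return ck::utility::launch_and_time_kernel_with_preprocess<false>(
        stream_config, preprocess, gemm_kernel, grid_dim, block_dim, /*lds_byte=*/0, gemm_args);
}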