/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/host_utility/flush_cache.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/host_utility/flush_cache.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/host_utility/flush_cache.hpp Source File
flush_cache.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <hip/hip_runtime.h>
7 #include <numeric>
8 #include <set>
9 #include <vector>
10 
11 #include "ck/ck.hpp"
12 #include "ck/utility/env.hpp"
13 #include "ck/stream_config.hpp"
16 namespace ck {
17 namespace utility {
18 
// RotatingMemWrapperMultiABD: keeps several sets of A/B/D device buffers for an
// Argument whose p_as_grid / p_bs_grid / p_ds_grid members are indexable
// collections of pointers, and switches the argument between those sets on
// every Next() call. Slot 0 always holds the pointers the argument had at
// construction; slots 1..rotating_count-1 are hipMalloc'd copies of slot 0.
//
// NOTE(review): this listing is a rendered page and has dropped lines (the
// gutter skips 20, 30-31, 120, 123, 127, 141, 147, 153). Presumably these are
// the `struct` header, the deleted default constructor and constructor head,
// the static_for heads in Print(), the destructor signature and the
// `hip_check_error(` wrappers -- confirm against the real header.
// NOTE(review): the destructor frees device memory but no copy/move control is
// visible here; copying an instance would presumably free the same buffers
// twice -- confirm.
19 template <typename Argument, typename AsDataType, typename BsDataType, typename DsDataType>
21 {
    // Number of A, B and D tensors, taken from the data-type tuples.
22  static constexpr index_t NumAs = AsDataType::Size();
23  static constexpr index_t NumBs = BsDataType::Size();
24  static constexpr index_t NumDs = DsDataType::Size();
25 
    // Types of the argument's pointer collections (one entry per tensor).
26  using AsGridPointer = decltype(Argument::p_as_grid);
27  using BsGridPointer = decltype(Argument::p_bs_grid);
28  using DsGridPointer = decltype(Argument::p_ds_grid);
29 
    // Constructor parameters (the head of the signature is missing from this
    // listing):
    //   arg_                - argument whose grid pointers will be rotated; held by reference
    //   rotating_count_hint - requested number of buffer sets, may be reduced below
    //   size_as_/bs_/ds_    - byte sizes passed to hipMalloc/hipMemcpy for each tensor
32  std::size_t rotating_count_hint,
33  std::array<std::size_t, NumAs> size_as_,
34  std::array<std::size_t, NumBs> size_bs_,
35  std::array<std::size_t, NumDs> size_ds_)
36  : arg(arg_),
37  rotating_count(rotating_count_hint),
38  size_as(size_as_),
39  size_bs(size_bs_),
40  size_ds(size_ds_)
41  {
    // Slot 0 = the caller's original buffers (never freed by this class).
42  p_as_grids.push_back(arg.p_as_grid);
43  p_bs_grids.push_back(arg.p_bs_grid);
44  p_ds_grids.push_back(arg.p_ds_grid);
45 
46  // limit the rotating count to prevent oom
    // footprint is the byte size of one full A+B+D set; the cap keeps
    // rotating_count * footprint at or below 2^31 bytes.
    // NOTE(review): if every size is 0 the division below divides by zero.
    // If footprint exceeds 2^31 the cap is 0, so no copies are made and
    // Next()/the destructor become no-ops.
47  const uint64_t footprint = std::accumulate(size_as.begin(), size_as.end(), 0UL) +
48  std::accumulate(size_bs.begin(), size_bs.end(), 0UL) +
49  std::accumulate(size_ds.begin(), size_ds.end(), 0UL);
50  const uint64_t max_rotating_count = (1ULL << 31) / footprint;
51  rotating_count = std::min(rotating_count, max_rotating_count);
52 
    // Build slots 1..rotating_count-1: for each tensor allocate a device
    // buffer and fill it with a device-to-device copy of the slot-0 buffer.
53  for(size_t i = 1; i < rotating_count; i++)
54  {
55  {
56  AsGridPointer as_buffer;
57  static_for<0, NumAs, 1>{}([&](auto j) {
58  void* pADeviceBuf;
59  hip_check_error(hipMalloc(static_cast<void**>(&pADeviceBuf), size_as_[j]));
60  hip_check_error(hipMemcpy(static_cast<void*>(pADeviceBuf),
61  static_cast<const void*>(p_as_grids[0][j]),
62  size_as_[j],
63  hipMemcpyDeviceToDevice));
64  using ADataType = remove_cvref_t<tuple_element_t<j.value, AsDataType>>;
65 
66  as_buffer(j) = static_cast<const ADataType*>(pADeviceBuf);
67  });
68  p_as_grids.push_back(as_buffer);
69  }
70 
71  {
72  BsGridPointer bs_buffer;
73  static_for<0, NumBs, 1>{}([&](auto j) {
74  void* pBDeviceBuf;
75  hip_check_error(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_bs_[j]));
76  hip_check_error(hipMemcpy(static_cast<void*>(pBDeviceBuf),
77  static_cast<const void*>(p_bs_grids[0][j]),
78  size_bs_[j],
79  hipMemcpyDeviceToDevice));
80  using BDataType = remove_cvref_t<tuple_element_t<j.value, BsDataType>>;
81 
82  bs_buffer(j) = static_cast<const BDataType*>(pBDeviceBuf);
83  });
84  p_bs_grids.push_back(bs_buffer);
85  }
86 
87  {
88  DsGridPointer ds_buffer;
89  static_for<0, NumDs, 1>{}([&](auto j) {
90  void* pDDeviceBuf;
91  hip_check_error(hipMalloc(static_cast<void**>(&pDDeviceBuf), size_ds_[j]));
92  hip_check_error(hipMemcpy(static_cast<void*>(pDDeviceBuf),
93  static_cast<const void*>(p_ds_grids[0][j]),
94  size_ds_[j],
95  hipMemcpyDeviceToDevice));
96 
97  using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
98 
99  ds_buffer(j) = static_cast<const DDataType*>(pDDeviceBuf);
100  });
101 
102  p_ds_grids.push_back(ds_buffer);
103  }
104  }
105  }
106 
    // Point the argument at the next buffer set, wrapping around after
    // rotating_count calls. The first call selects slot 0 (iter starts at 0).
    // Does nothing when only one set exists.
107  void Next()
108  {
109  if(rotating_count > 1)
110  {
111  std::size_t idx = iter++ % rotating_count;
112  arg.p_as_grid = p_as_grids[idx];
113  arg.p_bs_grid = p_bs_grids[idx];
114  arg.p_ds_grid = p_ds_grids[idx];
115  }
116  }
    // Write the A and B byte sizes and the effective rotating_count to std::cout.
    // NOTE(review): the label printed is "RotatingMemWrapperMultiD" although
    // this is the MultiABD wrapper, and size_ds is not printed.
117  void Print()
118  {
119  std::cout << "RotatingMemWrapperMultiD: { size_a: {";
121  [&](auto j) { std::cout << size_as[j] << (j.value < NumAs - 1 ? ", " : ""); });
122  std::cout << "}, size_b: {";
124  [&](auto j) { std::cout << size_bs[j] << (j.value < NumBs - 1 ? ", " : ""); });
125  std::cout << "}, rotating_count: " << rotating_count << "}" << std::endl;
126  }
    // Destructor body (signature line missing from this listing): restores
    // the argument's original pointers, then frees every buffer in slots
    // 1..rotating_count-1. Slot 0 belongs to the caller and is left alone.
128  {
129  if(rotating_count > 1)
130  {
131  // restore ptr
132  arg.p_as_grid = p_as_grids[0];
133  arg.p_bs_grid = p_bs_grids[0];
134  arg.p_ds_grid = p_ds_grids[0];
135 
136  // free device mem
137  for(size_t i = 1; i < rotating_count; i++)
138  {
139  static_for<0, NumAs, 1>{}([&](auto j) {
140  using ADataType = remove_cvref_t<tuple_element_t<j.value, AsDataType>>;
142  hipFree(static_cast<void*>(const_cast<ADataType*>(p_as_grids[i][j]))));
143  });
144 
145  static_for<0, NumBs, 1>{}([&](auto j) {
146  using BDataType = remove_cvref_t<tuple_element_t<j.value, BsDataType>>;
148  hipFree(static_cast<void*>(const_cast<BDataType*>(p_bs_grids[i][j]))));
149  });
150 
151  static_for<0, NumDs, 1>{}([&](auto j) {
152  using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
154  hipFree(static_cast<void*>(const_cast<DDataType*>(p_ds_grids[i][j]))));
155  });
156  }
157  }
158  }
159 
160  private:
    // arg            - argument being rotated (not owned)
    // iter           - number of Next() calls that performed a switch
    // rotating_count - effective number of buffer sets, including slot 0
    // size_*         - per-tensor byte sizes
    // p_*_grids      - pointer sets, index 0 = original buffers
161  Argument& arg;
162  std::size_t iter = 0;
163  std::size_t rotating_count = 1;
164  std::array<std::size_t, NumAs> size_as = {0};
165  std::array<std::size_t, NumBs> size_bs = {0};
166  std::array<std::size_t, NumDs> size_ds = {0};
167  std::vector<AsGridPointer> p_as_grids;
168  std::vector<BsGridPointer> p_bs_grids;
169  std::vector<DsGridPointer> p_ds_grids;
170 };
171 
// RotatingMemWrapperMultiD: rotating-buffer helper for arguments with a single
// A pointer (p_a_grid), a single B pointer (p_b_grid) and a collection of D
// pointers (p_ds_grid). Slot 0 holds the argument's original pointers; slots
// 1..rotating_count-1 are hipMalloc'd device-to-device copies of slot 0. Next()
// switches the argument to the following slot; the destructor restores slot 0
// and frees the copies.
//
// NOTE(review): this listing has dropped lines (the gutter skips 173, 181, 261,
// 278), presumably the `struct` header, the deleted default constructor, the
// destructor signature and a `hip_check_error(` wrapper -- confirm against the
// real header.
// NOTE(review): no copy/move control is visible here; copying an instance
// would presumably free the same buffers twice -- confirm.
172 template <typename Argument, typename DsDataType>
174 {
    // Number of D tensors, taken from the data-type tuple.
175  static constexpr index_t NumDs = DsDataType::Size();
176 
    // Despite the names, ADataType/BDataType are the types of the argument's
    // p_a_grid/p_b_grid members, i.e. the pointer types that Next() casts back to.
177  using ADataType = decltype(Argument::p_a_grid);
178  using BDataType = decltype(Argument::p_b_grid);
179  using DsGridPointer = decltype(Argument::p_ds_grid);
180 
    // arg_                - argument whose grid pointers will be rotated; held by reference
    // rotating_count_hint - requested number of buffer sets, may be reduced below
    // size_a_/size_b_     - byte sizes of the A and B buffers
    // size_ds_            - byte size of each D buffer
182  RotatingMemWrapperMultiD(Argument& arg_,
183  std::size_t rotating_count_hint,
184  std::size_t size_a_,
185  std::size_t size_b_,
186  std::array<std::size_t, NumDs> size_ds_)
187  : arg(arg_),
188  rotating_count(rotating_count_hint),
189  size_a(size_a_),
190  size_b(size_b_),
191  size_ds(size_ds_)
192  {
    // Slot 0 = the caller's original buffers (never freed by this class).
193  p_a_grids.push_back(arg.p_a_grid);
194  p_b_grids.push_back(arg.p_b_grid);
195  p_ds_grids.push_back(arg.p_ds_grid);
196 
197  // limit the rotating count to prevent oom
    // footprint is the byte size of one full A+B+D set; the cap keeps
    // rotating_count * footprint at or below 2^31 bytes.
    // NOTE(review): if every size is 0 the division below divides by zero.
198  const uint64_t footprint =
199  std::accumulate(size_ds.begin(), size_ds.end(), 0UL) + (size_a + size_b);
200  const uint64_t max_rotating_count = (1ULL << 31) / footprint;
201  rotating_count = std::min(rotating_count, max_rotating_count);
202 
    // Build slots 1..rotating_count-1 as copies of slot 0.
203  for(size_t i = 1; i < rotating_count; i++)
204  {
205  {
206  void* pADeviceBuf;
207  hip_check_error(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
208  hip_check_error(hipMemcpy(static_cast<void*>(pADeviceBuf),
209  const_cast<void*>(p_a_grids[0]),
210  size_a_,
211  hipMemcpyDeviceToDevice));
212  p_a_grids.push_back(pADeviceBuf);
213  }
214 
215  {
216  void* pBDeviceBuf;
217  hip_check_error(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
218  hip_check_error(hipMemcpy(static_cast<void*>(pBDeviceBuf),
219  const_cast<void*>(p_b_grids[0]),
220  size_b_,
221  hipMemcpyDeviceToDevice));
222  p_b_grids.push_back(pBDeviceBuf);
223  }
224 
225  {
226 
227  DsGridPointer ds_buffer;
228  static_for<0, NumDs, 1>{}([&](auto j) {
229  void* pDDeviceBuf;
230  hip_check_error(hipMalloc(static_cast<void**>(&pDDeviceBuf), size_ds_[j]));
231  hip_check_error(hipMemcpy(static_cast<void*>(pDDeviceBuf),
232  static_cast<const void*>(p_ds_grids[0][j]),
233  size_ds_[j],
234  hipMemcpyDeviceToDevice));
235 
236  using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
237 
238  ds_buffer(j) = static_cast<const DDataType*>(pDDeviceBuf);
239  });
240 
241  p_ds_grids.push_back(ds_buffer);
242  }
243  }
244  }
245 
    // Point the argument at the next buffer set, wrapping around after
    // rotating_count calls. The first call selects slot 0. A and B are stored
    // as const void* and cast back to the argument's pointer types here.
    // Does nothing when only one set exists.
246  void Next()
247  {
248  if(rotating_count > 1)
249  {
250  std::size_t idx = iter++ % rotating_count;
251  arg.p_a_grid = reinterpret_cast<ADataType>(p_a_grids[idx]);
252  arg.p_b_grid = reinterpret_cast<BDataType>(p_b_grids[idx]);
253  arg.p_ds_grid = p_ds_grids[idx];
254  }
255  }
    // Write size_a, size_b and the effective rotating_count to std::cout.
    // size_ds is not printed.
256  void Print()
257  {
258  std::cout << "RotatingMemWrapperMultiD: { size_a: " << size_a << ", size_b: " << size_b
259  << ", rotating_count: " << rotating_count << "}" << std::endl;
260  }
    // Destructor body (signature line missing from this listing): restores
    // the argument's original pointers, then frees every buffer in slots
    // 1..rotating_count-1. Slot 0 belongs to the caller and is left alone.
262  {
263  if(rotating_count > 1)
264  {
265  // restore ptr
266  arg.p_a_grid = reinterpret_cast<ADataType>(p_a_grids[0]);
267  arg.p_b_grid = reinterpret_cast<BDataType>(p_b_grids[0]);
268  arg.p_ds_grid = p_ds_grids[0];
269 
270  // free device mem
271  for(size_t i = 1; i < rotating_count; i++)
272  {
273  hip_check_error(hipFree(const_cast<void*>(p_a_grids[i])));
274  hip_check_error(hipFree(const_cast<void*>(p_b_grids[i])));
275 
276  static_for<0, NumDs, 1>{}([&](auto j) {
277  using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
279  hipFree(static_cast<void*>(const_cast<DDataType*>(p_ds_grids[i][j]))));
280  });
281  }
282  }
283  }
284 
285  private:
    // arg            - argument being rotated (not owned)
    // iter           - number of Next() calls that performed a switch
    // rotating_count - effective number of buffer sets, including slot 0
    // size_*         - byte sizes of the A, B and D buffers
    // p_*_grids      - pointer sets, index 0 = original buffers
286  Argument& arg;
287  std::size_t iter = 0;
288  std::size_t rotating_count = 1;
289  std::size_t size_a = 0;
290  std::size_t size_b = 0;
291  std::array<std::size_t, NumDs> size_ds = {0};
292  std::vector<const void*> p_a_grids;
293  std::vector<const void*> p_b_grids;
294  std::vector<DsGridPointer> p_ds_grids;
295 };
296 
// RotatingMemWrapper: rotating-buffer helper for arguments with one A pointer
// (p_a_grid) and one B pointer (p_b_grid). Slot 0 holds the argument's original
// pointers; slots 1..rotating_count-1 are hipMalloc'd device-to-device copies
// of slot 0. Next() switches the argument to the following slot; the
// destructor restores slot 0 and frees the copies.
//
// NOTE(review): this listing has dropped lines (the gutter skips 298 and 356),
// presumably the `struct` header and the destructor signature -- confirm
// against the real header.
// NOTE(review): no copy/move control is visible here; copying an instance
// would presumably free the same buffers twice -- confirm.
297 template <typename Argument>
299 {
    // Despite the names, these are the types of the argument's p_a_grid and
    // p_b_grid members, i.e. the pointer types that Next() casts back to.
300  using ADataType = decltype(Argument::p_a_grid);
301  using BDataType = decltype(Argument::p_b_grid);
302 
    // An argument to rotate is mandatory.
303  RotatingMemWrapper() = delete;
    // arg_                - argument whose grid pointers will be rotated; held by reference
    // rotating_count_hint - requested number of buffer sets, may be reduced below
    // size_a_/size_b_     - byte sizes of the A and B buffers
304  RotatingMemWrapper(Argument& arg_,
305  std::size_t rotating_count_hint,
306  std::size_t size_a_,
307  std::size_t size_b_)
308  : arg(arg_), rotating_count(rotating_count_hint), size_a(size_a_), size_b(size_b_)
309  {
    // Slot 0 = the caller's original buffers (never freed by this class).
310  p_a_grids.push_back(arg.p_a_grid);
311  p_b_grids.push_back(arg.p_b_grid);
312 
313  // limit the rotating count to prevent oom
    // The cap keeps rotating_count * (size_a + size_b) at or below 2^31 bytes.
    // NOTE(review): if size_a + size_b is 0 the division below divides by zero.
314  const uint64_t footprint = (size_a + size_b);
315  const uint64_t max_rotating_count = (1ULL << 31) / footprint;
316  rotating_count = std::min(rotating_count, max_rotating_count);
317 
    // Build slots 1..rotating_count-1 as copies of slot 0.
318  for(size_t i = 1; i < rotating_count; i++)
319  {
320  {
321  void* pADeviceBuf;
322  hip_check_error(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
323  hip_check_error(hipMemcpy(static_cast<void*>(pADeviceBuf),
324  const_cast<void*>(p_a_grids[0]),
325  size_a_,
326  hipMemcpyDeviceToDevice));
327  p_a_grids.push_back(pADeviceBuf);
328  }
329 
330  {
331  void* pBDeviceBuf;
332  hip_check_error(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
333  hip_check_error(hipMemcpy(static_cast<void*>(pBDeviceBuf),
334  const_cast<void*>(p_b_grids[0]),
335  size_b_,
336  hipMemcpyDeviceToDevice));
337  p_b_grids.push_back(pBDeviceBuf);
338  }
339  }
340  }
341 
    // Point the argument at the next buffer pair, wrapping around after
    // rotating_count calls. The first call selects slot 0.
    // Does nothing when only one pair exists.
342  void Next()
343  {
344  if(rotating_count > 1)
345  {
346  std::size_t idx = iter++ % rotating_count;
347  arg.p_a_grid = reinterpret_cast<ADataType>(p_a_grids[idx]);
348  arg.p_b_grid = reinterpret_cast<BDataType>(p_b_grids[idx]);
349  }
350  }
    // Write size_a, size_b and the effective rotating_count to std::cout.
351  void Print()
352  {
353  std::cout << "RotatingMemWrapper: { size_a: " << size_a << ", size_b: " << size_b
354  << ", rotating_count: " << rotating_count << "}" << std::endl;
355  }
    // Destructor body (signature line missing from this listing): restores
    // the argument's original pointers, then frees the buffers in slots
    // 1..rotating_count-1. Slot 0 belongs to the caller and is left alone.
357  {
358  if(rotating_count > 1)
359  {
360  // restore ptr
361  arg.p_a_grid = reinterpret_cast<ADataType>(p_a_grids[0]);
362  arg.p_b_grid = reinterpret_cast<BDataType>(p_b_grids[0]);
363 
364  // free device mem
365  for(size_t i = 1; i < rotating_count; i++)
366  {
367  hip_check_error(hipFree(const_cast<void*>(p_a_grids[i])));
368  hip_check_error(hipFree(const_cast<void*>(p_b_grids[i])));
369  }
370  }
371  }
372 
373  private:
    // arg            - argument being rotated (not owned)
    // iter           - number of Next() calls that performed a switch
    // rotating_count - effective number of buffer pairs, including slot 0
    // size_a/size_b  - byte sizes of the A and B buffers
    // p_*_grids      - pointer lists, index 0 = original buffers
374  Argument& arg;
375  std::size_t iter = 0;
376  std::size_t rotating_count = 1;
377  std::size_t size_a = 0;
378  std::size_t size_b = 0;
379  std::vector<const void*> p_a_grids;
380  std::vector<const void*> p_b_grids;
381 };
382 
// Launch the ck::flush_icache kernel with (multiProcessorCount * 60) blocks of
// 64 threads on the null stream, then check the launch for errors. The function
// returns right after the launch check; it does not synchronise with the device.
// NOTE(review): device properties are always read from device 0, not from the
// currently selected device -- confirm this is intended on multi-GPU hosts.
383 inline void flush_icache()
384 {
385  hipDeviceProp_t deviceProps;
386  hip_check_error(hipGetDeviceProperties(&deviceProps, 0));
    // Grid size: 60 blocks per reported multiprocessor.
387  int32_t gpu_block3 = deviceProps.multiProcessorCount * 60;
388 
389  ck::flush_icache<<<dim3(gpu_block3), dim3(64), 0, nullptr>>>();
390  hip_check_error(hipGetLastError());
391 }
// launch_and_time_kernel_with_preprocess: run `preprocess()` followed by
// `kernel` nrepeat_ times and return an averaged time, or just launch once
// when timing is disabled.
//
// NOTE(review): the gutter skips line 398, so the function's return type, name
// and first parameter are missing from this listing; the page index gives the
// signature as `float launch_and_time_kernel_with_preprocess(const StreamConfig&
// stream_config, ...)` -- confirm against the real header.
//
// Parameters visible here:
//   preprocess            - callable invoked with no arguments before each kernel launch
//   kernel                - kernel launched as kernel<<<grid_dim, block_dim, lds_byte, stream>>>(gemm_args, args...)
//   grid_dim / block_dim  - launch dimensions
//   lds_byte              - third launch-configuration value (dynamic shared memory size)
//   gemm_args             - first kernel argument, taken by reference
//   args                  - remaining kernel arguments, passed by value
//
// Returns 0 when timing is disabled or nrepeat_ == 0; otherwise
// (elapsed - preprocess_offset * nrepeat) / nrepeat, where elapsed is the value
// reported by hipEventElapsedTime for the whole repeat loop.
//
392 // if TimePreprocess == false, the returned time does not include the preprocess time
// NOTE(review): both `if constexpr` branches below call preprocess() between
// the start and stop events, so the measured interval always contains it; the
// only compensation is the fixed preprocess_offset subtracted at the end.
393 template <bool TimePreprocess,
394  typename GemmArgs,
395  typename... Args,
396  typename F,
397  typename PreProcessFunc>
399  PreProcessFunc preprocess,
400  F kernel,
401  dim3 grid_dim,
402  dim3 block_dim,
403  std::size_t lds_byte,
404  GemmArgs& gemm_args,
405  Args... args)
406 {
407 #if CK_TIME_KERNEL
    // NOTE(review): MEDIAN is defined inside the function and never #undef'd,
    // so it stays defined for the rest of the translation unit.
408 #define MEDIAN 0
409  if(stream_config.time_kernel_)
410  {
411  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
412  {
413  printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n",
414  __func__,
415  grid_dim.x,
416  grid_dim.y,
417  grid_dim.z,
418  block_dim.x,
419  block_dim.y,
420  block_dim.z);
421 
422  printf("Warm up %d times\n", stream_config.cold_niters_);
423  }
424  // warm up
    // The warm-up launches do not call preprocess().
425  for(int i = 0; i < stream_config.cold_niters_; ++i)
426  {
427  kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
428  hip_check_error(hipGetLastError());
429  }
430 
431  const int nrepeat = stream_config.nrepeat_;
432  if(nrepeat == 0)
433  {
434  return 0.0;
435  }
436  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
437  {
438  printf("Start running %d times...\n", nrepeat);
439  }
440 
441 #if MEDIAN
442  std::set<float> times;
443 #else
444  float total_time = 0;
445 #endif
    // NOTE(review): the events created here are never passed to
    // hipEventDestroy anywhere in this function, so each timed call leaks
    // two events.
446  hipEvent_t start, stop;
447 
448  hip_check_error(hipEventCreate(&start));
449  hip_check_error(hipEventCreate(&stop));
450 
    // One start/stop pair brackets the whole repeat loop; the per-iteration
    // timing that used to live inside the loop is commented out below.
451  hip_check_error(hipDeviceSynchronize());
452  hip_check_error(hipEventRecord(start, stream_config.stream_id_));
453 
454  for(int i = 0; i < nrepeat; ++i)
455  {
456  if constexpr(!TimePreprocess)
457  {
458  preprocess();
459  }
460 
461  // hipEvent_t start, stop;
462 
463  // hip_check_error(hipEventCreate(&start));
464  // hip_check_error(hipEventCreate(&stop));
465 
466  // hip_check_error(hipDeviceSynchronize());
467  // hip_check_error(hipEventRecord(start, stream_config.stream_id_));
468  // calculate preprocess time
469  if constexpr(TimePreprocess)
470  {
471  preprocess();
472  }
473  // run real kernel
474  kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
475  hip_check_error(hipGetLastError());
476  // end real kernel
477 
478  // hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
479  // hip_check_error(hipEventSynchronize(stop));
480  // float cur_time = 0;
481  // hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
482  // #if MEDIAN
483  // times.insert(cur_time);
484  // #else
485  // total_time += cur_time;
486  // #endif
487 
    // When logging is enabled, print the A/B pointers used by this iteration
    // (compiled out when CK_USE_WMMA is defined).
488 #if !defined(CK_USE_WMMA)
489  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
490  {
491  // std::cout << "i: " << i << " cur_time: " << cur_time << std::endl;
492 
493  printf("gemm_args.p_a_grid: %p, gemm_args.p_b_grid:%p\n",
494  static_cast<const void*>(gemm_args.p_a_grid),
495  static_cast<const void*>(gemm_args.p_b_grid));
496  }
497 #endif
498  }
499  hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
500  hip_check_error(hipEventSynchronize(stop));
501  float cur_time = 0;
502  hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
503 #if MEDIAN
504  times.insert(cur_time);
505 #else
506  total_time += cur_time;
507 #endif
508 
    // NOTE(review): the MEDIAN path is disabled (MEDIAN is 0). If enabled as
    // written, `times` would hold a single sample, so advancing the iterator
    // by (nrepeat - 1) / 2 would run past the end for nrepeat > 1.
509 #if MEDIAN
510  auto mid = times.begin();
511  std::advance(mid, (nrepeat - 1) / 2);
512  if(nrepeat % 2 == 1)
513  {
514  return *mid;
515  }
516  else
517  {
518  auto mid_next = mid;
519  std::advance(mid_next, 1);
520  return (*mid + *mid_next) / 2;
521  }
522 #else
523  // return total_time / nrepeat;
    // Subtract a fixed per-iteration allowance for preprocess(): 0.005 when
    // device 0 reports 80 multiprocessors, 0.01 otherwise.
    // NOTE(review): the constants look like empirical estimates and the
    // device index is hard-coded to 0 -- confirm both.
524  hipDeviceProp_t deviceProps;
525  hip_check_error(hipGetDeviceProperties(&deviceProps, 0));
526  float preprocess_offset = deviceProps.multiProcessorCount == 80 ? 0.005 : 0.01;
527  return (total_time - preprocess_offset * nrepeat) / nrepeat;
528 #endif
529  }
530  else
531  {
    // Timing disabled at run time: one preprocess + one launch, no measurement.
532  preprocess();
533  kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
534  hip_check_error(hipGetLastError());
535 
536  return 0;
537  }
538 #else
    // Timing compiled out: a single launch.
    // NOTE(review): preprocess() is not called on this path, unlike the
    // run-time "timing disabled" branch above -- confirm this is intended.
539  kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
540  hip_check_error(hipGetLastError());
541 
542  return 0;
543 #endif
544 }
545 
546 } // namespace utility
547 } // namespace ck
void hip_check_error(hipError_t x)
Definition: hip_check_error.hpp:10
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
void flush_icache()
Definition: flush_cache.hpp:383
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition: flush_cache.hpp:398
Definition: ck.hpp:268
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:140
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
int32_t index_t
Definition: ck.hpp:299
signed int int32_t
Definition: stdint.h:123
unsigned __int64 uint64_t
Definition: stdint.h:136
Definition: stream_config.hpp:10
int cold_niters_
Definition: stream_config.hpp:14
bool time_kernel_
Definition: stream_config.hpp:12
int nrepeat_
Definition: stream_config.hpp:15
hipStream_t stream_id_
Definition: stream_config.hpp:11
Definition: functional2.hpp:33
Definition: flush_cache.hpp:299
decltype(Argument::p_a_grid) ADataType
Definition: flush_cache.hpp:300
RotatingMemWrapper(Argument &arg_, std::size_t rotating_count_hint, std::size_t size_a_, std::size_t size_b_)
Definition: flush_cache.hpp:304
decltype(Argument::p_b_grid) BDataType
Definition: flush_cache.hpp:301
~RotatingMemWrapper()
Definition: flush_cache.hpp:356
void Print()
Definition: flush_cache.hpp:351
void Next()
Definition: flush_cache.hpp:342
Definition: flush_cache.hpp:21
static constexpr index_t NumBs
Definition: flush_cache.hpp:23
RotatingMemWrapperMultiABD(Argument &arg_, std::size_t rotating_count_hint, std::array< std::size_t, NumAs > size_as_, std::array< std::size_t, NumBs > size_bs_, std::array< std::size_t, NumDs > size_ds_)
Definition: flush_cache.hpp:31
static constexpr index_t NumDs
Definition: flush_cache.hpp:24
decltype(Argument::p_bs_grid) BsGridPointer
Definition: flush_cache.hpp:27
decltype(Argument::p_ds_grid) DsGridPointer
Definition: flush_cache.hpp:28
void Print()
Definition: flush_cache.hpp:117
void Next()
Definition: flush_cache.hpp:107
static constexpr index_t NumAs
Definition: flush_cache.hpp:22
decltype(Argument::p_as_grid) AsGridPointer
Definition: flush_cache.hpp:26
~RotatingMemWrapperMultiABD()
Definition: flush_cache.hpp:127
Definition: flush_cache.hpp:174
decltype(Argument::p_b_grid) BDataType
Definition: flush_cache.hpp:178
void Print()
Definition: flush_cache.hpp:256
static constexpr index_t NumDs
Definition: flush_cache.hpp:175
decltype(Argument::p_a_grid) ADataType
Definition: flush_cache.hpp:177
~RotatingMemWrapperMultiD()
Definition: flush_cache.hpp:261
decltype(Argument::p_ds_grid) DsGridPointer
Definition: flush_cache.hpp:179
RotatingMemWrapperMultiD(Argument &arg_, std::size_t rotating_count_hint, std::size_t size_a_, std::size_t size_b_, std::array< std::size_t, NumDs > size_ds_)
Definition: flush_cache.hpp:182
void Next()
Definition: flush_cache.hpp:246
#define CK_ENV(name)
Definition: env.hpp:129