/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/host_utility/flush_cache.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/host_utility/flush_cache.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/host_utility/flush_cache.hpp Source File
flush_cache.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <hip/hip_runtime.h>
7 #include <set>
8 #include <vector>
9 
10 #include "ck/ck.hpp"
11 #include "ck/utility/env.hpp"
12 #include "ck/stream_config.hpp"
15 namespace ck {
16 namespace utility {
17 
18 template <typename Argument, typename AsDataType, typename BsDataType, typename DsDataType>
20 {
21  static constexpr index_t NumAs = AsDataType::Size();
22  static constexpr index_t NumBs = BsDataType::Size();
23  static constexpr index_t NumDs = DsDataType::Size();
24 
25  using AsGridPointer = decltype(Argument::p_as_grid);
26  using BsGridPointer = decltype(Argument::p_bs_grid);
27  using DsGridPointer = decltype(Argument::p_ds_grid);
28 
31  std::size_t rotating_count_,
32  std::array<std::size_t, NumAs> size_as_,
33  std::array<std::size_t, NumBs> size_bs_,
34  std::array<std::size_t, NumDs> size_ds_)
35  : arg(arg_),
36  rotating_count(rotating_count_),
37  size_as(size_as_),
38  size_bs(size_bs_),
39  size_ds(size_ds_)
40  {
41  p_as_grids.push_back(arg.p_as_grid);
42  p_bs_grids.push_back(arg.p_bs_grid);
43  p_ds_grids.push_back(arg.p_ds_grid);
44  for(size_t i = 1; i < rotating_count; i++)
45  {
46  {
47  AsGridPointer as_buffer;
48  static_for<0, NumAs, 1>{}([&](auto j) {
49  void* pADeviceBuf;
50  hip_check_error(hipMalloc(static_cast<void**>(&pADeviceBuf), size_as_[j]));
51  hip_check_error(hipMemcpy(static_cast<void*>(pADeviceBuf),
52  static_cast<const void*>(p_as_grids[0][j]),
53  size_as_[j],
54  hipMemcpyDeviceToDevice));
55  using ADataType = remove_cvref_t<tuple_element_t<j.value, AsDataType>>;
56 
57  as_buffer(j) = static_cast<const ADataType*>(pADeviceBuf);
58  });
59  p_as_grids.push_back(as_buffer);
60  }
61 
62  {
63  BsGridPointer bs_buffer;
64  static_for<0, NumBs, 1>{}([&](auto j) {
65  void* pBDeviceBuf;
66  hip_check_error(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_bs_[j]));
67  hip_check_error(hipMemcpy(static_cast<void*>(pBDeviceBuf),
68  static_cast<const void*>(p_bs_grids[0][j]),
69  size_bs_[j],
70  hipMemcpyDeviceToDevice));
71  using BDataType = remove_cvref_t<tuple_element_t<j.value, BsDataType>>;
72 
73  bs_buffer(j) = static_cast<const BDataType*>(pBDeviceBuf);
74  });
75  p_bs_grids.push_back(bs_buffer);
76  }
77 
78  {
79  DsGridPointer ds_buffer;
80  static_for<0, NumDs, 1>{}([&](auto j) {
81  void* pDDeviceBuf;
82  hip_check_error(hipMalloc(static_cast<void**>(&pDDeviceBuf), size_ds_[j]));
83  hip_check_error(hipMemcpy(static_cast<void*>(pDDeviceBuf),
84  static_cast<const void*>(p_ds_grids[0][j]),
85  size_ds_[j],
86  hipMemcpyDeviceToDevice));
87 
88  using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
89 
90  ds_buffer(j) = static_cast<const DDataType*>(pDDeviceBuf);
91  });
92 
93  p_ds_grids.push_back(ds_buffer);
94  }
95  }
96  }
97 
98  void Next()
99  {
100  if(rotating_count > 1)
101  {
102  std::size_t idx = iter++ % rotating_count;
103  arg.p_as_grid = p_as_grids[idx];
104  arg.p_bs_grid = p_bs_grids[idx];
105  arg.p_ds_grid = p_ds_grids[idx];
106  }
107  }
108  void Print()
109  {
110  std::cout << "RotatingMemWrapperMultiD: { size_a: {";
112  [&](auto j) { std::cout << size_as[j] << (j.value < NumAs - 1 ? ", " : ""); });
113  std::cout << "}, size_b: {";
115  [&](auto j) { std::cout << size_bs[j] << (j.value < NumBs - 1 ? ", " : ""); });
116  std::cout << "}, rotating_count: " << rotating_count << "}" << std::endl;
117  }
119  {
120  if(rotating_count > 1)
121  {
122  // restore ptr
123  arg.p_as_grid = p_as_grids[0];
124  arg.p_bs_grid = p_bs_grids[0];
125  arg.p_ds_grid = p_ds_grids[0];
126 
127  // free device mem
128  for(size_t i = 1; i < rotating_count; i++)
129  {
130  static_for<0, NumAs, 1>{}([&](auto j) {
131  using ADataType = remove_cvref_t<tuple_element_t<j.value, AsDataType>>;
133  hipFree(static_cast<void*>(const_cast<ADataType*>(p_as_grids[i][j]))));
134  });
135 
136  static_for<0, NumBs, 1>{}([&](auto j) {
137  using BDataType = remove_cvref_t<tuple_element_t<j.value, BsDataType>>;
139  hipFree(static_cast<void*>(const_cast<BDataType*>(p_bs_grids[i][j]))));
140  });
141 
142  static_for<0, NumDs, 1>{}([&](auto j) {
143  using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
145  hipFree(static_cast<void*>(const_cast<DDataType*>(p_ds_grids[i][j]))));
146  });
147  }
148  }
149  }
150 
151  private:
152  Argument& arg;
153  std::size_t iter = 0;
154  std::size_t rotating_count = 1;
155  std::array<std::size_t, NumAs> size_as = {0};
156  std::array<std::size_t, NumBs> size_bs = {0};
157  std::array<std::size_t, NumDs> size_ds = {0};
158  std::vector<AsGridPointer> p_as_grids;
159  std::vector<BsGridPointer> p_bs_grids;
160  std::vector<DsGridPointer> p_ds_grids;
161 };
162 
163 template <typename Argument, typename DsDataType>
165 {
166  static constexpr index_t NumDs = DsDataType::Size();
167 
168  using ADataType = decltype(Argument::p_a_grid);
169  using BDataType = decltype(Argument::p_b_grid);
170  using DsGridPointer = decltype(Argument::p_ds_grid);
171 
173  RotatingMemWrapperMultiD(Argument& arg_,
174  std::size_t rotating_count_,
175  std::size_t size_a_,
176  std::size_t size_b_,
177  std::array<std::size_t, NumDs> size_ds_)
178  : arg(arg_),
179  rotating_count(rotating_count_),
180  size_a(size_a_),
181  size_b(size_b_),
182  size_ds(size_ds_)
183  {
184  p_a_grids.push_back(arg.p_a_grid);
185  p_b_grids.push_back(arg.p_b_grid);
186  p_ds_grids.push_back(arg.p_ds_grid);
187  for(size_t i = 1; i < rotating_count; i++)
188  {
189  {
190  void* pADeviceBuf;
191  hip_check_error(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
192  hip_check_error(hipMemcpy(static_cast<void*>(pADeviceBuf),
193  const_cast<void*>(p_a_grids[0]),
194  size_a_,
195  hipMemcpyDeviceToDevice));
196  p_a_grids.push_back(pADeviceBuf);
197  }
198 
199  {
200  void* pBDeviceBuf;
201  hip_check_error(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
202  hip_check_error(hipMemcpy(static_cast<void*>(pBDeviceBuf),
203  const_cast<void*>(p_b_grids[0]),
204  size_b_,
205  hipMemcpyDeviceToDevice));
206  p_b_grids.push_back(pBDeviceBuf);
207  }
208 
209  {
210 
211  DsGridPointer ds_buffer;
212  static_for<0, NumDs, 1>{}([&](auto j) {
213  void* pDDeviceBuf;
214  hip_check_error(hipMalloc(static_cast<void**>(&pDDeviceBuf), size_ds_[j]));
215  hip_check_error(hipMemcpy(static_cast<void*>(pDDeviceBuf),
216  static_cast<const void*>(p_ds_grids[0][j]),
217  size_ds_[j],
218  hipMemcpyDeviceToDevice));
219 
220  using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
221 
222  ds_buffer(j) = static_cast<const DDataType*>(pDDeviceBuf);
223  });
224 
225  p_ds_grids.push_back(ds_buffer);
226  }
227  }
228  }
229 
230  void Next()
231  {
232  if(rotating_count > 1)
233  {
234  std::size_t idx = iter++ % rotating_count;
235  arg.p_a_grid = reinterpret_cast<ADataType>(p_a_grids[idx]);
236  arg.p_b_grid = reinterpret_cast<BDataType>(p_b_grids[idx]);
237  arg.p_ds_grid = p_ds_grids[idx];
238  }
239  }
240  void Print()
241  {
242  std::cout << "RotatingMemWrapperMultiD: { size_a: " << size_a << ", size_b: " << size_b
243  << ", rotating_count: " << rotating_count << "}" << std::endl;
244  }
246  {
247  if(rotating_count > 1)
248  {
249  // restore ptr
250  arg.p_a_grid = reinterpret_cast<ADataType>(p_a_grids[0]);
251  arg.p_b_grid = reinterpret_cast<BDataType>(p_b_grids[0]);
252  arg.p_ds_grid = p_ds_grids[0];
253 
254  // free device mem
255  for(size_t i = 1; i < rotating_count; i++)
256  {
257  hip_check_error(hipFree(const_cast<void*>(p_a_grids[i])));
258  hip_check_error(hipFree(const_cast<void*>(p_b_grids[i])));
259 
260  static_for<0, NumDs, 1>{}([&](auto j) {
261  using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
263  hipFree(static_cast<void*>(const_cast<DDataType*>(p_ds_grids[i][j]))));
264  });
265  }
266  }
267  }
268 
269  private:
270  Argument& arg;
271  std::size_t iter = 0;
272  std::size_t rotating_count = 1;
273  std::size_t size_a = 0;
274  std::size_t size_b = 0;
275  std::array<std::size_t, NumDs> size_ds = {0};
276  std::vector<const void*> p_a_grids;
277  std::vector<const void*> p_b_grids;
278  std::vector<DsGridPointer> p_ds_grids;
279 };
280 
281 template <typename Argument>
283 {
284  using ADataType = decltype(Argument::p_a_grid);
285  using BDataType = decltype(Argument::p_b_grid);
286 
287  RotatingMemWrapper() = delete;
288  RotatingMemWrapper(Argument& arg_,
289  std::size_t rotating_count_,
290  std::size_t size_a_,
291  std::size_t size_b_)
292  : arg(arg_), rotating_count(rotating_count_), size_a(size_a_), size_b(size_b_)
293  {
294  p_a_grids.push_back(arg.p_a_grid);
295  p_b_grids.push_back(arg.p_b_grid);
296  for(size_t i = 1; i < rotating_count; i++)
297  {
298  {
299  void* pADeviceBuf;
300  hip_check_error(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
301  hip_check_error(hipMemcpy(static_cast<void*>(pADeviceBuf),
302  const_cast<void*>(p_a_grids[0]),
303  size_a_,
304  hipMemcpyDeviceToDevice));
305  p_a_grids.push_back(pADeviceBuf);
306  }
307 
308  {
309  void* pBDeviceBuf;
310  hip_check_error(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
311  hip_check_error(hipMemcpy(static_cast<void*>(pBDeviceBuf),
312  const_cast<void*>(p_b_grids[0]),
313  size_b_,
314  hipMemcpyDeviceToDevice));
315  p_b_grids.push_back(pBDeviceBuf);
316  }
317  }
318  }
319 
320  void Next()
321  {
322  if(rotating_count > 1)
323  {
324  std::size_t idx = iter++ % rotating_count;
325  arg.p_a_grid = reinterpret_cast<ADataType>(p_a_grids[idx]);
326  arg.p_b_grid = reinterpret_cast<BDataType>(p_b_grids[idx]);
327  }
328  }
329  void Print()
330  {
331  std::cout << "RotatingMemWrapper: { size_a: " << size_a << ", size_b: " << size_b
332  << ", rotating_count: " << rotating_count << "}" << std::endl;
333  }
335  {
336  if(rotating_count > 1)
337  {
338  // restore ptr
339  arg.p_a_grid = reinterpret_cast<ADataType>(p_a_grids[0]);
340  arg.p_b_grid = reinterpret_cast<BDataType>(p_b_grids[0]);
341 
342  // free device mem
343  for(size_t i = 1; i < rotating_count; i++)
344  {
345  hip_check_error(hipFree(const_cast<void*>(p_a_grids[i])));
346  hip_check_error(hipFree(const_cast<void*>(p_b_grids[i])));
347  }
348  }
349  }
350 
351  private:
352  Argument& arg;
353  std::size_t iter = 0;
354  std::size_t rotating_count = 1;
355  std::size_t size_a = 0;
356  std::size_t size_b = 0;
357  std::vector<const void*> p_a_grids;
358  std::vector<const void*> p_b_grids;
359 };
360 
361 inline void flush_icache()
362 {
363  hipDeviceProp_t deviceProps;
364  hip_check_error(hipGetDeviceProperties(&deviceProps, 0));
365  int32_t gpu_block3 = deviceProps.multiProcessorCount * 60;
366 
367  ck::flush_icache<<<dim3(gpu_block3), dim3(64), 0, nullptr>>>();
368  hip_check_error(hipGetLastError());
369 }
370 // if TimePreprocess == false, the returned time does not include the preprocess time
371 template <bool TimePreprocess,
372  typename GemmArgs,
373  typename... Args,
374  typename F,
375  typename PreProcessFunc>
377  PreProcessFunc preprocess,
378  F kernel,
379  dim3 grid_dim,
380  dim3 block_dim,
381  std::size_t lds_byte,
382  GemmArgs& gemm_args,
383  Args... args)
384 {
385 #if CK_TIME_KERNEL
386 #define MEDIAN 0
387  if(stream_config.time_kernel_)
388  {
389  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
390  {
391  printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n",
392  __func__,
393  grid_dim.x,
394  grid_dim.y,
395  grid_dim.z,
396  block_dim.x,
397  block_dim.y,
398  block_dim.z);
399 
400  printf("Warm up %d times\n", stream_config.cold_niters_);
401  }
402  // warm up
403  for(int i = 0; i < stream_config.cold_niters_; ++i)
404  {
405  kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
406  hip_check_error(hipGetLastError());
407  }
408 
409  const int nrepeat = stream_config.nrepeat_;
410  if(nrepeat == 0)
411  {
412  return 0.0;
413  }
414  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
415  {
416  printf("Start running %d times...\n", nrepeat);
417  }
418 
419 #if MEDIAN
420  std::set<float> times;
421 #else
422  float total_time = 0;
423 #endif
424  hipEvent_t start, stop;
425 
426  hip_check_error(hipEventCreate(&start));
427  hip_check_error(hipEventCreate(&stop));
428 
429  hip_check_error(hipDeviceSynchronize());
430  hip_check_error(hipEventRecord(start, stream_config.stream_id_));
431 
432  for(int i = 0; i < nrepeat; ++i)
433  {
434  if constexpr(!TimePreprocess)
435  {
436  preprocess();
437  }
438 
439  // hipEvent_t start, stop;
440 
441  // hip_check_error(hipEventCreate(&start));
442  // hip_check_error(hipEventCreate(&stop));
443 
444  // hip_check_error(hipDeviceSynchronize());
445  // hip_check_error(hipEventRecord(start, stream_config.stream_id_));
446  // calculate preprocess time
447  if constexpr(TimePreprocess)
448  {
449  preprocess();
450  }
451  // run real kernel
452  kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
453  hip_check_error(hipGetLastError());
454  // end real kernel
455 
456  // hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
457  // hip_check_error(hipEventSynchronize(stop));
458  // float cur_time = 0;
459  // hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
460  // #if MEDIAN
461  // times.insert(cur_time);
462  // #else
463  // total_time += cur_time;
464  // #endif
465 
466 #if !defined(CK_USE_WMMA)
467  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
468  {
469  // std::cout << "i: " << i << " cur_time: " << cur_time << std::endl;
470 
471  printf("gemm_args.p_a_grid: %p, gemm_args.p_b_grid:%p\n",
472  static_cast<const void*>(gemm_args.p_a_grid),
473  static_cast<const void*>(gemm_args.p_b_grid));
474  }
475 #endif
476  }
477  hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
478  hip_check_error(hipEventSynchronize(stop));
479  float cur_time = 0;
480  hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
481 #if MEDIAN
482  times.insert(cur_time);
483 #else
484  total_time += cur_time;
485 #endif
486 
487 #if MEDIAN
488  auto mid = times.begin();
489  std::advance(mid, (nrepeat - 1) / 2);
490  if(nrepeat % 2 == 1)
491  {
492  return *mid;
493  }
494  else
495  {
496  auto mid_next = mid;
497  std::advance(mid_next, 1);
498  return (*mid + *mid_next) / 2;
499  }
500 #else
501  // return total_time / nrepeat;
502  hipDeviceProp_t deviceProps;
503  hip_check_error(hipGetDeviceProperties(&deviceProps, 0));
504  float preprocess_offset = deviceProps.multiProcessorCount == 80 ? 0.005 : 0.01;
505  return (total_time - preprocess_offset * nrepeat) / nrepeat;
506 #endif
507  }
508  else
509  {
510  preprocess();
511  kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
512  hip_check_error(hipGetLastError());
513 
514  return 0;
515  }
516 #else
517  kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
518  hip_check_error(hipGetLastError());
519 
520  return 0;
521 #endif
522 }
523 
524 } // namespace utility
525 } // namespace ck
void hip_check_error(hipError_t x)
Definition: hip_check_error.hpp:10
void flush_icache()
Definition: flush_cache.hpp:361
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition: flush_cache.hpp:376
Definition: ck.hpp:268
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:140
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
int32_t index_t
Definition: ck.hpp:299
signed int int32_t
Definition: stdint.h:123
Definition: stream_config.hpp:10
int cold_niters_
Definition: stream_config.hpp:14
bool time_kernel_
Definition: stream_config.hpp:12
int nrepeat_
Definition: stream_config.hpp:15
hipStream_t stream_id_
Definition: stream_config.hpp:11
Definition: functional2.hpp:33
Definition: flush_cache.hpp:283
decltype(Argument::p_a_grid) ADataType
Definition: flush_cache.hpp:284
RotatingMemWrapper(Argument &arg_, std::size_t rotating_count_, std::size_t size_a_, std::size_t size_b_)
Definition: flush_cache.hpp:288
decltype(Argument::p_b_grid) BDataType
Definition: flush_cache.hpp:285
~RotatingMemWrapper()
Definition: flush_cache.hpp:334
void Print()
Definition: flush_cache.hpp:329
void Next()
Definition: flush_cache.hpp:320
Definition: flush_cache.hpp:20
static constexpr index_t NumBs
Definition: flush_cache.hpp:22
static constexpr index_t NumDs
Definition: flush_cache.hpp:23
decltype(Argument::p_bs_grid) BsGridPointer
Definition: flush_cache.hpp:26
decltype(Argument::p_ds_grid) DsGridPointer
Definition: flush_cache.hpp:27
void Print()
Definition: flush_cache.hpp:108
void Next()
Definition: flush_cache.hpp:98
static constexpr index_t NumAs
Definition: flush_cache.hpp:21
decltype(Argument::p_as_grid) AsGridPointer
Definition: flush_cache.hpp:25
RotatingMemWrapperMultiABD(Argument &arg_, std::size_t rotating_count_, std::array< std::size_t, NumAs > size_as_, std::array< std::size_t, NumBs > size_bs_, std::array< std::size_t, NumDs > size_ds_)
Definition: flush_cache.hpp:30
~RotatingMemWrapperMultiABD()
Definition: flush_cache.hpp:118
Definition: flush_cache.hpp:165
decltype(Argument::p_b_grid) BDataType
Definition: flush_cache.hpp:169
void Print()
Definition: flush_cache.hpp:240
static constexpr index_t NumDs
Definition: flush_cache.hpp:166
decltype(Argument::p_a_grid) ADataType
Definition: flush_cache.hpp:168
~RotatingMemWrapperMultiD()
Definition: flush_cache.hpp:245
decltype(Argument::p_ds_grid) DsGridPointer
Definition: flush_cache.hpp:170
RotatingMemWrapperMultiD(Argument &arg_, std::size_t rotating_count_, std::size_t size_a_, std::size_t size_b_, std::array< std::size_t, NumDs > size_ds_)
Definition: flush_cache.hpp:173
void Next()
Definition: flush_cache.hpp:230
#define CK_ENV(name)
Definition: env.hpp:129