/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/rotating_buffers.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/rotating_buffers.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/rotating_buffers.hpp Source File
rotating_buffers.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
8 #include <hip/hip_runtime.h>
9 
10 namespace ck_tile {
11 
12 template <typename ADataType, typename BDataType>
14 {
15  RotatingMemWrapper() = delete;
16  RotatingMemWrapper(const void* a_ptr_,
17  const void* b_ptr_,
18  std::size_t rotating_count_,
19  std::size_t size_a_,
20  std::size_t size_b_)
21  : a_ptr(a_ptr_),
22  b_ptr(b_ptr_),
23  rotating_count(rotating_count_),
24  size_a(size_a_),
25  size_b(size_b_)
26  {
27  p_a_grids.push_back(a_ptr);
28  p_b_grids.push_back(b_ptr);
29  for(size_t i = 1; i < rotating_count; i++)
30  {
31  {
32  void* pADeviceBuf;
33  HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
34  HIP_CHECK_ERROR(hipMemcpy(static_cast<void*>(pADeviceBuf),
35  const_cast<void*>(p_a_grids[0]),
36  size_a_,
37  hipMemcpyDeviceToDevice));
38  p_a_grids.push_back(pADeviceBuf);
39  }
40 
41  {
42  void* pBDeviceBuf;
43  HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
44  HIP_CHECK_ERROR(hipMemcpy(static_cast<void*>(pBDeviceBuf),
45  const_cast<void*>(p_b_grids[0]),
46  size_b_,
47  hipMemcpyDeviceToDevice));
48  p_b_grids.push_back(pBDeviceBuf);
49  }
50  }
51  }
52  void Next()
53  {
54  if(rotating_count > 1)
55  {
56  std::size_t idx = iter++ % rotating_count;
57  a_ptr = p_a_grids[idx];
58  b_ptr = p_b_grids[idx];
59  }
60  }
61  void Print()
62  {
63  std::cout << "RotatingMemWrapper: { size_a: " << size_a << ", size_b: " << size_b
64  << ", rotating_count: " << rotating_count << "}" << std::endl;
65  }
67  {
68  if(rotating_count > 1)
69  {
70  // restore ptr
71  a_ptr = p_a_grids[0];
72  b_ptr = p_b_grids[0];
73 
74  // free device mem
75  for(size_t i = 1; i < rotating_count; i++)
76  {
77  ck_tile::hip_check_error(hipFree(const_cast<void*>(p_a_grids[i])));
78  ck_tile::hip_check_error(hipFree(const_cast<void*>(p_b_grids[i])));
79  }
80  }
81  }
82 
83  private:
84  const void* a_ptr;
85  const void* b_ptr;
86  std::size_t iter = 0;
87  std::size_t rotating_count = 1;
88  std::size_t size_a = 0;
89  std::size_t size_b = 0;
90  std::vector<const void*> p_a_grids;
91  std::vector<const void*> p_b_grids;
92 };
93 inline void flush_icache()
94 {
95  hipDeviceProp_t deviceProps;
96  HIP_CHECK_ERROR(hipGetDeviceProperties(&deviceProps, 0));
97  int32_t gpu_block3 = deviceProps.multiProcessorCount * 60;
98 
99  ck_tile::flush_cache<<<dim3(gpu_block3), dim3(64), 0, nullptr>>>();
100  HIP_CHECK_ERROR(hipGetLastError());
101 }
102 } // namespace ck_tile
#define HIP_CHECK_ERROR(retval_or_funcall)
Definition: hip_check_error.hpp:21
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST void hip_check_error(hipError_t x)
Definition: hip_check_error.hpp:13
int32_t int32_t
Definition: integer.hpp:10
void flush_icache()
Definition: rotating_buffers.hpp:93
Definition: rotating_buffers.hpp:14
void Print()
Definition: rotating_buffers.hpp:61
void Next()
Definition: rotating_buffers.hpp:52
RotatingMemWrapper(const void *a_ptr_, const void *b_ptr_, std::size_t rotating_count_, std::size_t size_a_, std::size_t size_b_)
Definition: rotating_buffers.hpp:16
~RotatingMemWrapper() noexcept
Definition: rotating_buffers.hpp:66