6 #include <hip/hip_runtime.h>
18 template <
typename Argument,
typename DsDataType>
29 std::size_t rotating_count_,
32 std::array<std::size_t, NumDs> size_ds_)
34 rotating_count(rotating_count_),
39 p_a_grids.push_back(arg.p_a_grid);
40 p_b_grids.push_back(arg.p_b_grid);
41 p_ds_grids.push_back(arg.p_ds_grid);
42 for(
size_t i = 1; i < rotating_count; i++)
48 const_cast<void*
>(p_a_grids[0]),
50 hipMemcpyDeviceToDevice));
51 p_a_grids.push_back(pADeviceBuf);
58 const_cast<void*
>(p_b_grids[0]),
60 hipMemcpyDeviceToDevice));
61 p_b_grids.push_back(pBDeviceBuf);
69 hip_check_error(hipMalloc(
static_cast<void**
>(&pDDeviceBuf), size_ds_[j]));
71 static_cast<const void*
>(p_ds_grids[0][j]),
73 hipMemcpyDeviceToDevice));
77 ds_buffer(j) =
static_cast<const DDataType*
>(pDDeviceBuf);
80 p_ds_grids.push_back(ds_buffer);
87 if(rotating_count > 1)
89 std::size_t idx = iter++ % rotating_count;
90 arg.p_a_grid =
reinterpret_cast<ADataType>(p_a_grids[idx]);
91 arg.p_b_grid =
reinterpret_cast<BDataType>(p_b_grids[idx]);
92 arg.p_ds_grid = p_ds_grids[idx];
97 std::cout <<
"RotatingMemWrapperMultiD: { size_a: " << size_a <<
", size_b: " << size_b
98 <<
", rotating_count: " << rotating_count <<
"}" << std::endl;
102 if(rotating_count > 1)
105 arg.p_a_grid =
reinterpret_cast<ADataType>(p_a_grids[0]);
106 arg.p_b_grid =
reinterpret_cast<BDataType>(p_b_grids[0]);
107 arg.p_ds_grid = p_ds_grids[0];
110 for(
size_t i = 1; i < rotating_count; i++)
118 hipFree(
static_cast<void*
>(
const_cast<DDataType*
>(p_ds_grids[i][j]))));
126 std::size_t iter = 0;
127 std::size_t rotating_count = 1;
128 std::size_t size_a = 0;
129 std::size_t size_b = 0;
130 std::array<std::size_t, NumDs> size_ds = {0};
131 std::vector<const void*> p_a_grids;
132 std::vector<const void*> p_b_grids;
133 std::vector<DsGridPointer> p_ds_grids;
136 template <
typename Argument>
144 std::size_t rotating_count_,
147 : arg(arg_), rotating_count(rotating_count_), size_a(size_a_), size_b(size_b_)
149 p_a_grids.push_back(arg.p_a_grid);
150 p_b_grids.push_back(arg.p_b_grid);
151 for(
size_t i = 1; i < rotating_count; i++)
155 hip_check_error(hipMalloc(
static_cast<void**
>(&pADeviceBuf), size_a_));
157 const_cast<void*
>(p_a_grids[0]),
159 hipMemcpyDeviceToDevice));
160 p_a_grids.push_back(pADeviceBuf);
165 hip_check_error(hipMalloc(
static_cast<void**
>(&pBDeviceBuf), size_b_));
167 const_cast<void*
>(p_b_grids[0]),
169 hipMemcpyDeviceToDevice));
170 p_b_grids.push_back(pBDeviceBuf);
177 if(rotating_count > 1)
179 std::size_t idx = iter++ % rotating_count;
180 arg.p_a_grid =
reinterpret_cast<ADataType>(p_a_grids[idx]);
181 arg.p_b_grid =
reinterpret_cast<BDataType>(p_b_grids[idx]);
186 std::cout <<
"RotatingMemWrapper: { size_a: " << size_a <<
", size_b: " << size_b
187 <<
", rotating_count: " << rotating_count <<
"}" << std::endl;
191 if(rotating_count > 1)
194 arg.p_a_grid =
reinterpret_cast<ADataType>(p_a_grids[0]);
195 arg.p_b_grid =
reinterpret_cast<BDataType>(p_b_grids[0]);
198 for(
size_t i = 1; i < rotating_count; i++)
208 std::size_t iter = 0;
209 std::size_t rotating_count = 1;
210 std::size_t size_a = 0;
211 std::size_t size_b = 0;
212 std::vector<const void*> p_a_grids;
213 std::vector<const void*> p_b_grids;
218 hipDeviceProp_t deviceProps;
220 int32_t gpu_block3 = deviceProps.multiProcessorCount * 60;
222 ck::flush_icache<<<dim3(gpu_block3), dim3(64), 0,
nullptr>>>();
226 template <
bool TimePreprocess,
230 typename PreProcessFunc>
232 PreProcessFunc preprocess,
236 std::size_t lds_byte,
246 printf(
"%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n",
255 printf(
"Warm up %d times\n", stream_config.
cold_niters_);
260 kernel<<<grid_dim, block_dim, lds_byte, stream_config.
stream_id_>>>(gemm_args, args...);
264 const int nrepeat = stream_config.
nrepeat_;
271 printf(
"Start running %d times...\n", nrepeat);
275 std::set<float> times;
277 float total_time = 0;
279 hipEvent_t start, stop;
287 for(
int i = 0; i < nrepeat; ++i)
289 if constexpr(!TimePreprocess)
302 if constexpr(TimePreprocess)
307 kernel<<<grid_dim, block_dim, lds_byte, stream_config.
stream_id_>>>(gemm_args, args...);
325 printf(
"gemm_args.p_a_grid: %p, gemm_args.p_b_grid:%p\n",
326 static_cast<const void*
>(gemm_args.p_a_grid),
327 static_cast<const void*
>(gemm_args.p_b_grid));
335 times.insert(cur_time);
337 total_time += cur_time;
341 auto mid = times.begin();
342 std::advance(mid, (nrepeat - 1) / 2);
350 std::advance(mid_next, 1);
351 return (*mid + *mid_next) / 2;
355 hipDeviceProp_t deviceProps;
357 float preprocess_offset = deviceProps.multiProcessorCount == 80 ? 0.005 : 0.01;
358 return (total_time - preprocess_offset * nrepeat) / nrepeat;
364 kernel<<<grid_dim, block_dim, lds_byte, stream_config.
stream_id_>>>(gemm_args, args...);
370 kernel<<<grid_dim, block_dim, lds_byte, stream_config.
stream_id_>>>(gemm_args, args...);
void hip_check_error(hipError_t x)
Definition: hip_check_error.hpp:10
void flush_icache()
Definition: flush_cache.hpp:216
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition: flush_cache.hpp:231
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:140
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
int32_t index_t
Definition: ck.hpp:298
signed int int32_t
Definition: stdint.h:123
Definition: stream_config.hpp:10
int cold_niters_
Definition: stream_config.hpp:14
bool time_kernel_
Definition: stream_config.hpp:12
int nrepeat_
Definition: stream_config.hpp:15
hipStream_t stream_id_
Definition: stream_config.hpp:11
Definition: functional2.hpp:33
Definition: flush_cache.hpp:138
decltype(Argument::p_a_grid) ADataType
Definition: flush_cache.hpp:139
RotatingMemWrapper(Argument &arg_, std::size_t rotating_count_, std::size_t size_a_, std::size_t size_b_)
Definition: flush_cache.hpp:143
decltype(Argument::p_b_grid) BDataType
Definition: flush_cache.hpp:140
~RotatingMemWrapper()
Definition: flush_cache.hpp:189
RotatingMemWrapper()=delete
void Print()
Definition: flush_cache.hpp:184
void Next()
Definition: flush_cache.hpp:175
Definition: flush_cache.hpp:20
decltype(Argument::p_b_grid) BDataType
Definition: flush_cache.hpp:24
void Print()
Definition: flush_cache.hpp:95
static constexpr index_t NumDs
Definition: flush_cache.hpp:21
decltype(Argument::p_a_grid) ADataType
Definition: flush_cache.hpp:23
RotatingMemWrapperMultiD()=delete
~RotatingMemWrapperMultiD()
Definition: flush_cache.hpp:100
decltype(Argument::p_ds_grid) DsGridPointer
Definition: flush_cache.hpp:25
RotatingMemWrapperMultiD(Argument &arg_, std::size_t rotating_count_, std::size_t size_a_, std::size_t size_b_, std::array< std::size_t, NumDs > size_ds_)
Definition: flush_cache.hpp:28
void Next()
Definition: flush_cache.hpp:85
#define CK_ENV(name)
Definition: env.hpp:129