6 #include <hip/hip_runtime.h>
18 template <
typename Argument,
typename AsDataType,
typename BsDataType,
typename DsDataType>
31 std::size_t rotating_count_,
32 std::array<std::size_t, NumAs> size_as_,
33 std::array<std::size_t, NumBs> size_bs_,
34 std::array<std::size_t, NumDs> size_ds_)
36 rotating_count(rotating_count_),
41 p_as_grids.push_back(arg.p_as_grid);
42 p_bs_grids.push_back(arg.p_bs_grid);
43 p_ds_grids.push_back(arg.p_ds_grid);
44 for(
size_t i = 1; i < rotating_count; i++)
50 hip_check_error(hipMalloc(
static_cast<void**
>(&pADeviceBuf), size_as_[j]));
52 static_cast<const void*
>(p_as_grids[0][j]),
54 hipMemcpyDeviceToDevice));
57 as_buffer(j) =
static_cast<const ADataType*
>(pADeviceBuf);
59 p_as_grids.push_back(as_buffer);
66 hip_check_error(hipMalloc(
static_cast<void**
>(&pBDeviceBuf), size_bs_[j]));
68 static_cast<const void*
>(p_bs_grids[0][j]),
70 hipMemcpyDeviceToDevice));
73 bs_buffer(j) =
static_cast<const BDataType*
>(pBDeviceBuf);
75 p_bs_grids.push_back(bs_buffer);
82 hip_check_error(hipMalloc(
static_cast<void**
>(&pDDeviceBuf), size_ds_[j]));
84 static_cast<const void*
>(p_ds_grids[0][j]),
86 hipMemcpyDeviceToDevice));
90 ds_buffer(j) =
static_cast<const DDataType*
>(pDDeviceBuf);
93 p_ds_grids.push_back(ds_buffer);
100 if(rotating_count > 1)
102 std::size_t idx = iter++ % rotating_count;
103 arg.p_as_grid = p_as_grids[idx];
104 arg.p_bs_grid = p_bs_grids[idx];
105 arg.p_ds_grid = p_ds_grids[idx];
110 std::cout <<
"RotatingMemWrapperMultiD: { size_a: {";
112 [&](
auto j) { std::cout << size_as[j] << (j.value <
NumAs - 1 ?
", " :
""); });
113 std::cout <<
"}, size_b: {";
115 [&](
auto j) { std::cout << size_bs[j] << (j.value <
NumBs - 1 ?
", " :
""); });
116 std::cout <<
"}, rotating_count: " << rotating_count <<
"}" << std::endl;
120 if(rotating_count > 1)
123 arg.p_as_grid = p_as_grids[0];
124 arg.p_bs_grid = p_bs_grids[0];
125 arg.p_ds_grid = p_ds_grids[0];
128 for(
size_t i = 1; i < rotating_count; i++)
133 hipFree(
static_cast<void*
>(
const_cast<ADataType*
>(p_as_grids[i][j]))));
139 hipFree(
static_cast<void*
>(
const_cast<BDataType*
>(p_bs_grids[i][j]))));
145 hipFree(
static_cast<void*
>(
const_cast<DDataType*
>(p_ds_grids[i][j]))));
153 std::size_t iter = 0;
154 std::size_t rotating_count = 1;
155 std::array<std::size_t, NumAs> size_as = {0};
156 std::array<std::size_t, NumBs> size_bs = {0};
157 std::array<std::size_t, NumDs> size_ds = {0};
158 std::vector<AsGridPointer> p_as_grids;
159 std::vector<BsGridPointer> p_bs_grids;
160 std::vector<DsGridPointer> p_ds_grids;
163 template <
typename Argument,
typename DsDataType>
174 std::size_t rotating_count_,
177 std::array<std::size_t, NumDs> size_ds_)
179 rotating_count(rotating_count_),
184 p_a_grids.push_back(arg.p_a_grid);
185 p_b_grids.push_back(arg.p_b_grid);
186 p_ds_grids.push_back(arg.p_ds_grid);
187 for(
size_t i = 1; i < rotating_count; i++)
191 hip_check_error(hipMalloc(
static_cast<void**
>(&pADeviceBuf), size_a_));
193 const_cast<void*
>(p_a_grids[0]),
195 hipMemcpyDeviceToDevice));
196 p_a_grids.push_back(pADeviceBuf);
201 hip_check_error(hipMalloc(
static_cast<void**
>(&pBDeviceBuf), size_b_));
203 const_cast<void*
>(p_b_grids[0]),
205 hipMemcpyDeviceToDevice));
206 p_b_grids.push_back(pBDeviceBuf);
214 hip_check_error(hipMalloc(
static_cast<void**
>(&pDDeviceBuf), size_ds_[j]));
216 static_cast<const void*
>(p_ds_grids[0][j]),
218 hipMemcpyDeviceToDevice));
222 ds_buffer(j) =
static_cast<const DDataType*
>(pDDeviceBuf);
225 p_ds_grids.push_back(ds_buffer);
232 if(rotating_count > 1)
234 std::size_t idx = iter++ % rotating_count;
235 arg.p_a_grid =
reinterpret_cast<ADataType>(p_a_grids[idx]);
236 arg.p_b_grid =
reinterpret_cast<BDataType>(p_b_grids[idx]);
237 arg.p_ds_grid = p_ds_grids[idx];
242 std::cout <<
"RotatingMemWrapperMultiD: { size_a: " << size_a <<
", size_b: " << size_b
243 <<
", rotating_count: " << rotating_count <<
"}" << std::endl;
247 if(rotating_count > 1)
250 arg.p_a_grid =
reinterpret_cast<ADataType>(p_a_grids[0]);
251 arg.p_b_grid =
reinterpret_cast<BDataType>(p_b_grids[0]);
252 arg.p_ds_grid = p_ds_grids[0];
255 for(
size_t i = 1; i < rotating_count; i++)
263 hipFree(
static_cast<void*
>(
const_cast<DDataType*
>(p_ds_grids[i][j]))));
271 std::size_t iter = 0;
272 std::size_t rotating_count = 1;
273 std::size_t size_a = 0;
274 std::size_t size_b = 0;
275 std::array<std::size_t, NumDs> size_ds = {0};
276 std::vector<const void*> p_a_grids;
277 std::vector<const void*> p_b_grids;
278 std::vector<DsGridPointer> p_ds_grids;
281 template <
typename Argument>
289 std::size_t rotating_count_,
292 : arg(arg_), rotating_count(rotating_count_), size_a(size_a_), size_b(size_b_)
294 p_a_grids.push_back(arg.p_a_grid);
295 p_b_grids.push_back(arg.p_b_grid);
296 for(
size_t i = 1; i < rotating_count; i++)
300 hip_check_error(hipMalloc(
static_cast<void**
>(&pADeviceBuf), size_a_));
302 const_cast<void*
>(p_a_grids[0]),
304 hipMemcpyDeviceToDevice));
305 p_a_grids.push_back(pADeviceBuf);
310 hip_check_error(hipMalloc(
static_cast<void**
>(&pBDeviceBuf), size_b_));
312 const_cast<void*
>(p_b_grids[0]),
314 hipMemcpyDeviceToDevice));
315 p_b_grids.push_back(pBDeviceBuf);
322 if(rotating_count > 1)
324 std::size_t idx = iter++ % rotating_count;
325 arg.p_a_grid =
reinterpret_cast<ADataType>(p_a_grids[idx]);
326 arg.p_b_grid =
reinterpret_cast<BDataType>(p_b_grids[idx]);
331 std::cout <<
"RotatingMemWrapper: { size_a: " << size_a <<
", size_b: " << size_b
332 <<
", rotating_count: " << rotating_count <<
"}" << std::endl;
336 if(rotating_count > 1)
339 arg.p_a_grid =
reinterpret_cast<ADataType>(p_a_grids[0]);
340 arg.p_b_grid =
reinterpret_cast<BDataType>(p_b_grids[0]);
343 for(
size_t i = 1; i < rotating_count; i++)
353 std::size_t iter = 0;
354 std::size_t rotating_count = 1;
355 std::size_t size_a = 0;
356 std::size_t size_b = 0;
357 std::vector<const void*> p_a_grids;
358 std::vector<const void*> p_b_grids;
363 hipDeviceProp_t deviceProps;
365 int32_t gpu_block3 = deviceProps.multiProcessorCount * 60;
367 ck::flush_icache<<<dim3(gpu_block3), dim3(64), 0,
nullptr>>>();
371 template <
bool TimePreprocess,
375 typename PreProcessFunc>
377 PreProcessFunc preprocess,
381 std::size_t lds_byte,
391 printf(
"%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n",
400 printf(
"Warm up %d times\n", stream_config.
cold_niters_);
405 kernel<<<grid_dim, block_dim, lds_byte, stream_config.
stream_id_>>>(gemm_args, args...);
409 const int nrepeat = stream_config.
nrepeat_;
416 printf(
"Start running %d times...\n", nrepeat);
420 std::set<float> times;
422 float total_time = 0;
424 hipEvent_t start, stop;
432 for(
int i = 0; i < nrepeat; ++i)
434 if constexpr(!TimePreprocess)
447 if constexpr(TimePreprocess)
452 kernel<<<grid_dim, block_dim, lds_byte, stream_config.
stream_id_>>>(gemm_args, args...);
466 #if !defined(CK_USE_WMMA)
471 printf(
"gemm_args.p_a_grid: %p, gemm_args.p_b_grid:%p\n",
472 static_cast<const void*
>(gemm_args.p_a_grid),
473 static_cast<const void*
>(gemm_args.p_b_grid));
482 times.insert(cur_time);
484 total_time += cur_time;
488 auto mid = times.begin();
489 std::advance(mid, (nrepeat - 1) / 2);
497 std::advance(mid_next, 1);
498 return (*mid + *mid_next) / 2;
502 hipDeviceProp_t deviceProps;
504 float preprocess_offset = deviceProps.multiProcessorCount == 80 ? 0.005 : 0.01;
505 return (total_time - preprocess_offset * nrepeat) / nrepeat;
511 kernel<<<grid_dim, block_dim, lds_byte, stream_config.
stream_id_>>>(gemm_args, args...);
517 kernel<<<grid_dim, block_dim, lds_byte, stream_config.
stream_id_>>>(gemm_args, args...);
void hip_check_error(hipError_t x)
Definition: hip_check_error.hpp:10
void flush_icache()
Definition: flush_cache.hpp:361
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition: flush_cache.hpp:376
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:140
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
int32_t index_t
Definition: ck.hpp:299
signed int int32_t
Definition: stdint.h:123
Definition: stream_config.hpp:10
int cold_niters_
Definition: stream_config.hpp:14
bool time_kernel_
Definition: stream_config.hpp:12
int nrepeat_
Definition: stream_config.hpp:15
hipStream_t stream_id_
Definition: stream_config.hpp:11
Definition: functional2.hpp:33
Definition: flush_cache.hpp:283
decltype(Argument::p_a_grid) ADataType
Definition: flush_cache.hpp:284
RotatingMemWrapper(Argument &arg_, std::size_t rotating_count_, std::size_t size_a_, std::size_t size_b_)
Definition: flush_cache.hpp:288
decltype(Argument::p_b_grid) BDataType
Definition: flush_cache.hpp:285
~RotatingMemWrapper()
Definition: flush_cache.hpp:334
RotatingMemWrapper()=delete
void Print()
Definition: flush_cache.hpp:329
void Next()
Definition: flush_cache.hpp:320
Definition: flush_cache.hpp:20
static constexpr index_t NumBs
Definition: flush_cache.hpp:22
static constexpr index_t NumDs
Definition: flush_cache.hpp:23
decltype(Argument::p_bs_grid) BsGridPointer
Definition: flush_cache.hpp:26
decltype(Argument::p_ds_grid) DsGridPointer
Definition: flush_cache.hpp:27
void Print()
Definition: flush_cache.hpp:108
void Next()
Definition: flush_cache.hpp:98
static constexpr index_t NumAs
Definition: flush_cache.hpp:21
decltype(Argument::p_as_grid) AsGridPointer
Definition: flush_cache.hpp:25
RotatingMemWrapperMultiABD(Argument &arg_, std::size_t rotating_count_, std::array< std::size_t, NumAs > size_as_, std::array< std::size_t, NumBs > size_bs_, std::array< std::size_t, NumDs > size_ds_)
Definition: flush_cache.hpp:30
RotatingMemWrapperMultiABD()=delete
~RotatingMemWrapperMultiABD()
Definition: flush_cache.hpp:118
Definition: flush_cache.hpp:165
decltype(Argument::p_b_grid) BDataType
Definition: flush_cache.hpp:169
void Print()
Definition: flush_cache.hpp:240
static constexpr index_t NumDs
Definition: flush_cache.hpp:166
decltype(Argument::p_a_grid) ADataType
Definition: flush_cache.hpp:168
RotatingMemWrapperMultiD()=delete
~RotatingMemWrapperMultiD()
Definition: flush_cache.hpp:245
decltype(Argument::p_ds_grid) DsGridPointer
Definition: flush_cache.hpp:170
RotatingMemWrapperMultiD(Argument &arg_, std::size_t rotating_count_, std::size_t size_a_, std::size_t size_b_, std::array< std::size_t, NumDs > size_ds_)
Definition: flush_cache.hpp:173
void Next()
Definition: flush_cache.hpp:230
#define CK_ENV(name)
Definition: env.hpp:129