6 #include <hip/hip_runtime.h>
17 template <
typename Argument,
typename DsDataType>
28 std::size_t rotating_count_,
31 std::array<std::size_t, NumDs> size_ds_)
33 rotating_count(rotating_count_),
38 p_a_grids.push_back(arg.p_a_grid);
39 p_b_grids.push_back(arg.p_b_grid);
40 p_ds_grids.push_back(arg.p_ds_grid);
41 for(
size_t i = 1; i < rotating_count; i++)
47 const_cast<void*
>(p_a_grids[0]),
49 hipMemcpyDeviceToDevice));
50 p_a_grids.push_back(pADeviceBuf);
57 const_cast<void*
>(p_b_grids[0]),
59 hipMemcpyDeviceToDevice));
60 p_b_grids.push_back(pBDeviceBuf);
68 hip_check_error(hipMalloc(
static_cast<void**
>(&pDDeviceBuf), size_ds_[j]));
70 static_cast<const void*
>(p_ds_grids[0][j]),
72 hipMemcpyDeviceToDevice));
76 ds_buffer(j) =
static_cast<const DDataType*
>(pDDeviceBuf);
79 p_ds_grids.push_back(ds_buffer);
86 if(rotating_count > 1)
88 std::size_t idx = iter++ % rotating_count;
89 arg.p_a_grid =
reinterpret_cast<ADataType>(p_a_grids[idx]);
90 arg.p_b_grid =
reinterpret_cast<BDataType>(p_b_grids[idx]);
91 arg.p_ds_grid = p_ds_grids[idx];
96 std::cout <<
"RotatingMemWrapperMultiD: { size_a: " << size_a <<
", size_b: " << size_b
97 <<
", rotating_count: " << rotating_count <<
"}" << std::endl;
101 if(rotating_count > 1)
104 arg.p_a_grid =
reinterpret_cast<ADataType>(p_a_grids[0]);
105 arg.p_b_grid =
reinterpret_cast<BDataType>(p_b_grids[0]);
106 arg.p_ds_grid = p_ds_grids[0];
109 for(
size_t i = 1; i < rotating_count; i++)
117 hipFree(
static_cast<void*
>(
const_cast<DDataType*
>(p_ds_grids[i][j]))));
125 std::size_t iter = 0;
126 std::size_t rotating_count = 1;
127 std::size_t size_a = 0;
128 std::size_t size_b = 0;
129 std::array<std::size_t, NumDs> size_ds = {0};
130 std::vector<const void*> p_a_grids;
131 std::vector<const void*> p_b_grids;
132 std::vector<DsGridPointer> p_ds_grids;
135 template <
typename Argument>
143 std::size_t rotating_count_,
146 : arg(arg_), rotating_count(rotating_count_), size_a(size_a_), size_b(size_b_)
148 p_a_grids.push_back(arg.p_a_grid);
149 p_b_grids.push_back(arg.p_b_grid);
150 for(
size_t i = 1; i < rotating_count; i++)
154 hip_check_error(hipMalloc(
static_cast<void**
>(&pADeviceBuf), size_a_));
156 const_cast<void*
>(p_a_grids[0]),
158 hipMemcpyDeviceToDevice));
159 p_a_grids.push_back(pADeviceBuf);
164 hip_check_error(hipMalloc(
static_cast<void**
>(&pBDeviceBuf), size_b_));
166 const_cast<void*
>(p_b_grids[0]),
168 hipMemcpyDeviceToDevice));
169 p_b_grids.push_back(pBDeviceBuf);
176 if(rotating_count > 1)
178 std::size_t idx = iter++ % rotating_count;
179 arg.p_a_grid =
reinterpret_cast<ADataType>(p_a_grids[idx]);
180 arg.p_b_grid =
reinterpret_cast<BDataType>(p_b_grids[idx]);
185 std::cout <<
"RotatingMemWrapper: { size_a: " << size_a <<
", size_b: " << size_b
186 <<
", rotating_count: " << rotating_count <<
"}" << std::endl;
190 if(rotating_count > 1)
193 arg.p_a_grid =
reinterpret_cast<ADataType>(p_a_grids[0]);
194 arg.p_b_grid =
reinterpret_cast<BDataType>(p_b_grids[0]);
197 for(
size_t i = 1; i < rotating_count; i++)
207 std::size_t iter = 0;
208 std::size_t rotating_count = 1;
209 std::size_t size_a = 0;
210 std::size_t size_b = 0;
211 std::vector<const void*> p_a_grids;
212 std::vector<const void*> p_b_grids;
217 hipDeviceProp_t deviceProps;
219 int32_t gpu_block3 = deviceProps.multiProcessorCount * 60;
221 ck::flush_icache<<<dim3(gpu_block3), dim3(64), 0,
nullptr>>>();
225 template <
bool TimePreprocess,
229 typename PreProcessFunc>
231 PreProcessFunc preprocess,
235 std::size_t lds_byte,
245 printf(
"%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n",
254 printf(
"Warm up %d times\n", stream_config.
cold_niters_);
259 kernel<<<grid_dim, block_dim, lds_byte, stream_config.
stream_id_>>>(gemm_args, args...);
263 const int nrepeat = stream_config.
nrepeat_;
270 printf(
"Start running %d times...\n", nrepeat);
274 std::set<float> times;
276 float total_time = 0;
278 hipEvent_t start, stop;
286 for(
int i = 0; i < nrepeat; ++i)
288 if constexpr(!TimePreprocess)
301 if constexpr(TimePreprocess)
306 kernel<<<grid_dim, block_dim, lds_byte, stream_config.
stream_id_>>>(gemm_args, args...);
324 printf(
"gemm_args.p_a_grid: %p, gemm_args.p_b_grid:%p\n",
325 static_cast<const void*
>(gemm_args.p_a_grid),
326 static_cast<const void*
>(gemm_args.p_b_grid));
334 times.insert(cur_time);
336 total_time += cur_time;
340 auto mid = times.begin();
341 std::advance(mid, (nrepeat - 1) / 2);
349 std::advance(mid_next, 1);
350 return (*mid + *mid_next) / 2;
354 hipDeviceProp_t deviceProps;
356 float preprocess_offset = deviceProps.multiProcessorCount == 80 ? 0.005 : 0.01;
357 return (total_time - preprocess_offset * nrepeat) / nrepeat;
363 kernel<<<grid_dim, block_dim, lds_byte, stream_config.
stream_id_>>>(gemm_args, args...);
369 kernel<<<grid_dim, block_dim, lds_byte, stream_config.
stream_id_>>>(gemm_args, args...);
#define CK_ENV(name)
Definition: env.hpp:128
void hip_check_error(hipError_t x)
Definition: hip_check_error.hpp:10
void flush_icache()
Definition: flush_cache.hpp:215
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition: flush_cache.hpp:230
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:139
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:300
int32_t index_t
Definition: ck.hpp:289
Definition: stream_config.hpp:10
int cold_niters_
Definition: stream_config.hpp:14
bool time_kernel_
Definition: stream_config.hpp:12
int nrepeat_
Definition: stream_config.hpp:15
hipStream_t stream_id_
Definition: stream_config.hpp:11
Definition: functional2.hpp:31
Definition: flush_cache.hpp:137
decltype(Argument::p_a_grid) ADataType
Definition: flush_cache.hpp:138
RotatingMemWrapper(Argument &arg_, std::size_t rotating_count_, std::size_t size_a_, std::size_t size_b_)
Definition: flush_cache.hpp:142
decltype(Argument::p_b_grid) BDataType
Definition: flush_cache.hpp:139
~RotatingMemWrapper()
Definition: flush_cache.hpp:188
RotatingMemWrapper()=delete
void Print()
Definition: flush_cache.hpp:183
void Next()
Definition: flush_cache.hpp:174
Definition: flush_cache.hpp:19
decltype(Argument::p_b_grid) BDataType
Definition: flush_cache.hpp:23
void Print()
Definition: flush_cache.hpp:94
static constexpr index_t NumDs
Definition: flush_cache.hpp:20
decltype(Argument::p_a_grid) ADataType
Definition: flush_cache.hpp:22
RotatingMemWrapperMultiD()=delete
~RotatingMemWrapperMultiD()
Definition: flush_cache.hpp:99
decltype(Argument::p_ds_grid) DsGridPointer
Definition: flush_cache.hpp:24
RotatingMemWrapperMultiD(Argument &arg_, std::size_t rotating_count_, std::size_t size_a_, std::size_t size_b_, std::array< std::size_t, NumDs > size_ds_)
Definition: flush_cache.hpp:27
void Next()
Definition: flush_cache.hpp:84