6 #include <hip/hip_runtime.h> 
   19 template <
typename Argument, 
typename AsDataType, 
typename BsDataType, 
typename DsDataType>
 
   32                                std::size_t rotating_count_hint,
 
   33                                std::array<std::size_t, NumAs> size_as_,
 
   34                                std::array<std::size_t, NumBs> size_bs_,
 
   35                                std::array<std::size_t, NumDs> size_ds_)
 
   37           rotating_count(rotating_count_hint),
 
   42         p_as_grids.push_back(arg.p_as_grid);
 
   43         p_bs_grids.push_back(arg.p_bs_grid);
 
   44         p_ds_grids.push_back(arg.p_ds_grid);
 
   47         const uint64_t footprint = std::accumulate(size_as.begin(), size_as.end(), 0UL) +
 
   48                                    std::accumulate(size_bs.begin(), size_bs.end(), 0UL) +
 
   49                                    std::accumulate(size_ds.begin(), size_ds.end(), 0UL);
 
   50         const uint64_t max_rotating_count = (1ULL << 31) / footprint;
 
   51         rotating_count                    = 
std::min(rotating_count, max_rotating_count);
 
   53         for(
size_t i = 1; i < rotating_count; i++)
 
   59                     hip_check_error(hipMalloc(
static_cast<void**
>(&pADeviceBuf), size_as_[j]));
 
   61                                               static_cast<const void*
>(p_as_grids[0][j]),
 
   63                                               hipMemcpyDeviceToDevice));
 
   66                     as_buffer(j) = 
static_cast<const ADataType*
>(pADeviceBuf);
 
   68                 p_as_grids.push_back(as_buffer);
 
   75                     hip_check_error(hipMalloc(
static_cast<void**
>(&pBDeviceBuf), size_bs_[j]));
 
   77                                               static_cast<const void*
>(p_bs_grids[0][j]),
 
   79                                               hipMemcpyDeviceToDevice));
 
   82                     bs_buffer(j) = 
static_cast<const BDataType*
>(pBDeviceBuf);
 
   84                 p_bs_grids.push_back(bs_buffer);
 
   91                     hip_check_error(hipMalloc(
static_cast<void**
>(&pDDeviceBuf), size_ds_[j]));
 
   93                                               static_cast<const void*
>(p_ds_grids[0][j]),
 
   95                                               hipMemcpyDeviceToDevice));
 
   99                     ds_buffer(j) = 
static_cast<const DDataType*
>(pDDeviceBuf);
 
  102                 p_ds_grids.push_back(ds_buffer);
 
  109         if(rotating_count > 1)
 
  111             std::size_t idx = iter++ % rotating_count;
 
  112             arg.p_as_grid   = p_as_grids[idx];
 
  113             arg.p_bs_grid   = p_bs_grids[idx];
 
  114             arg.p_ds_grid   = p_ds_grids[idx];
 
  119         std::cout << 
"RotatingMemWrapperMultiD: { size_a: {";
 
  121             [&](
auto j) { std::cout << size_as[j] << (j.value < 
NumAs - 1 ? 
", " : 
""); });
 
  122         std::cout << 
"}, size_b: {";
 
  124             [&](
auto j) { std::cout << size_bs[j] << (j.value < 
NumBs - 1 ? 
", " : 
""); });
 
  125         std::cout << 
"}, rotating_count: " << rotating_count << 
"}" << std::endl;
 
  129         if(rotating_count > 1)
 
  132             arg.p_as_grid = p_as_grids[0];
 
  133             arg.p_bs_grid = p_bs_grids[0];
 
  134             arg.p_ds_grid = p_ds_grids[0];
 
  137             for(
size_t i = 1; i < rotating_count; i++)
 
  142                         hipFree(
static_cast<void*
>(
const_cast<ADataType*
>(p_as_grids[i][j]))));
 
  148                         hipFree(
static_cast<void*
>(
const_cast<BDataType*
>(p_bs_grids[i][j]))));
 
  154                         hipFree(
static_cast<void*
>(
const_cast<DDataType*
>(p_ds_grids[i][j]))));
 
  162     std::size_t iter                       = 0;
 
  163     std::size_t rotating_count             = 1;
 
  164     std::array<std::size_t, NumAs> size_as = {0};
 
  165     std::array<std::size_t, NumBs> size_bs = {0};
 
  166     std::array<std::size_t, NumDs> size_ds = {0};
 
  167     std::vector<AsGridPointer> p_as_grids;
 
  168     std::vector<BsGridPointer> p_bs_grids;
 
  169     std::vector<DsGridPointer> p_ds_grids;
 
  172 template <
typename Argument, 
typename DsDataType>
 
  183                              std::size_t rotating_count_hint,
 
  186                              std::array<std::size_t, NumDs> size_ds_)
 
  188           rotating_count(rotating_count_hint),
 
  193         p_a_grids.push_back(arg.p_a_grid);
 
  194         p_b_grids.push_back(arg.p_b_grid);
 
  195         p_ds_grids.push_back(arg.p_ds_grid);
 
  199             std::accumulate(size_ds.begin(), size_ds.end(), 0UL) + (size_a + size_b);
 
  200         const uint64_t max_rotating_count = (1ULL << 31) / footprint;
 
  201         rotating_count                    = 
std::min(rotating_count, max_rotating_count);
 
  203         for(
size_t i = 1; i < rotating_count; i++)
 
  207                 hip_check_error(hipMalloc(
static_cast<void**
>(&pADeviceBuf), size_a_));
 
  209                                           const_cast<void*
>(p_a_grids[0]),
 
  211                                           hipMemcpyDeviceToDevice));
 
  212                 p_a_grids.push_back(pADeviceBuf);
 
  217                 hip_check_error(hipMalloc(
static_cast<void**
>(&pBDeviceBuf), size_b_));
 
  219                                           const_cast<void*
>(p_b_grids[0]),
 
  221                                           hipMemcpyDeviceToDevice));
 
  222                 p_b_grids.push_back(pBDeviceBuf);
 
  230                     hip_check_error(hipMalloc(
static_cast<void**
>(&pDDeviceBuf), size_ds_[j]));
 
  232                                               static_cast<const void*
>(p_ds_grids[0][j]),
 
  234                                               hipMemcpyDeviceToDevice));
 
  238                     ds_buffer(j) = 
static_cast<const DDataType*
>(pDDeviceBuf);
 
  241                 p_ds_grids.push_back(ds_buffer);
 
  248         if(rotating_count > 1)
 
  250             std::size_t idx = iter++ % rotating_count;
 
  251             arg.p_a_grid    = 
reinterpret_cast<ADataType>(p_a_grids[idx]);
 
  252             arg.p_b_grid    = 
reinterpret_cast<BDataType>(p_b_grids[idx]);
 
  253             arg.p_ds_grid   = p_ds_grids[idx];
 
  258         std::cout << 
"RotatingMemWrapperMultiD: { size_a: " << size_a << 
", size_b: " << size_b
 
  259                   << 
", rotating_count: " << rotating_count << 
"}" << std::endl;
 
  263         if(rotating_count > 1)
 
  266             arg.p_a_grid  = 
reinterpret_cast<ADataType>(p_a_grids[0]);
 
  267             arg.p_b_grid  = 
reinterpret_cast<BDataType>(p_b_grids[0]);
 
  268             arg.p_ds_grid = p_ds_grids[0];
 
  271             for(
size_t i = 1; i < rotating_count; i++)
 
  279                         hipFree(
static_cast<void*
>(
const_cast<DDataType*
>(p_ds_grids[i][j]))));
 
  287     std::size_t iter                       = 0;
 
  288     std::size_t rotating_count             = 1;
 
  289     std::size_t size_a                     = 0;
 
  290     std::size_t size_b                     = 0;
 
  291     std::array<std::size_t, NumDs> size_ds = {0};
 
  292     std::vector<const void*> p_a_grids;
 
  293     std::vector<const void*> p_b_grids;
 
  294     std::vector<DsGridPointer> p_ds_grids;
 
  297 template <
typename Argument>
 
  305                        std::size_t rotating_count_hint,
 
  308         : arg(arg_), rotating_count(rotating_count_hint), size_a(size_a_), size_b(size_b_)
 
  310         p_a_grids.push_back(arg.p_a_grid);
 
  311         p_b_grids.push_back(arg.p_b_grid);
 
  314         const uint64_t footprint          = (size_a + size_b);
 
  315         const uint64_t max_rotating_count = (1ULL << 31) / footprint;
 
  316         rotating_count                    = 
std::min(rotating_count, max_rotating_count);
 
  318         for(
size_t i = 1; i < rotating_count; i++)
 
  322                 hip_check_error(hipMalloc(
static_cast<void**
>(&pADeviceBuf), size_a_));
 
  324                                           const_cast<void*
>(p_a_grids[0]),
 
  326                                           hipMemcpyDeviceToDevice));
 
  327                 p_a_grids.push_back(pADeviceBuf);
 
  332                 hip_check_error(hipMalloc(
static_cast<void**
>(&pBDeviceBuf), size_b_));
 
  334                                           const_cast<void*
>(p_b_grids[0]),
 
  336                                           hipMemcpyDeviceToDevice));
 
  337                 p_b_grids.push_back(pBDeviceBuf);
 
  344         if(rotating_count > 1)
 
  346             std::size_t idx = iter++ % rotating_count;
 
  347             arg.p_a_grid    = 
reinterpret_cast<ADataType>(p_a_grids[idx]);
 
  348             arg.p_b_grid    = 
reinterpret_cast<BDataType>(p_b_grids[idx]);
 
  353         std::cout << 
"RotatingMemWrapper: { size_a: " << size_a << 
", size_b: " << size_b
 
  354                   << 
", rotating_count: " << rotating_count << 
"}" << std::endl;
 
  358         if(rotating_count > 1)
 
  361             arg.p_a_grid = 
reinterpret_cast<ADataType>(p_a_grids[0]);
 
  362             arg.p_b_grid = 
reinterpret_cast<BDataType>(p_b_grids[0]);
 
  365             for(
size_t i = 1; i < rotating_count; i++)
 
  375     std::size_t iter           = 0;
 
  376     std::size_t rotating_count = 1;
 
  377     std::size_t size_a         = 0;
 
  378     std::size_t size_b         = 0;
 
  379     std::vector<const void*> p_a_grids;
 
  380     std::vector<const void*> p_b_grids;
 
  385     hipDeviceProp_t deviceProps;
 
  387     int32_t gpu_block3 = deviceProps.multiProcessorCount * 60;
 
  389     ck::flush_icache<<<dim3(gpu_block3), dim3(64), 0, 
nullptr>>>();
 
  393 template <
bool TimePreprocess,
 
  397           typename PreProcessFunc>
 
  399                                              PreProcessFunc preprocess,
 
  403                                              std::size_t lds_byte,
 
  413             printf(
"%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n",
 
  422             printf(
"Warm up %d times\n", stream_config.
cold_niters_);
 
  427             kernel<<<grid_dim, block_dim, lds_byte, stream_config.
stream_id_>>>(gemm_args, args...);
 
  431         const int nrepeat = stream_config.
nrepeat_;
 
  438             printf(
"Start running %d times...\n", nrepeat);
 
  442         std::set<float> times;
 
  444         float total_time = 0;
 
  446         hipEvent_t start, stop;
 
  454         for(
int i = 0; i < nrepeat; ++i)
 
  456             if constexpr(!TimePreprocess)
 
  469             if constexpr(TimePreprocess)
 
  474             kernel<<<grid_dim, block_dim, lds_byte, stream_config.
stream_id_>>>(gemm_args, args...);
 
  488 #if !defined(CK_USE_WMMA) 
  493                 printf(
"gemm_args.p_a_grid: %p, gemm_args.p_b_grid:%p\n",
 
  494                        static_cast<const void*
>(gemm_args.p_a_grid),
 
  495                        static_cast<const void*
>(gemm_args.p_b_grid));
 
  504         times.insert(cur_time);
 
  506         total_time += cur_time;
 
  510         auto mid = times.begin();
 
  511         std::advance(mid, (nrepeat - 1) / 2);
 
  519             std::advance(mid_next, 1);
 
  520             return (*mid + *mid_next) / 2;
 
  524         hipDeviceProp_t deviceProps;
 
  526         float preprocess_offset = deviceProps.multiProcessorCount == 80 ? 0.005 : 0.01;
 
  527         return (total_time - preprocess_offset * nrepeat) / nrepeat;
 
  533         kernel<<<grid_dim, block_dim, lds_byte, stream_config.
stream_id_>>>(gemm_args, args...);
 
  539     kernel<<<grid_dim, block_dim, lds_byte, stream_config.
stream_id_>>>(gemm_args, args...);
 
void hip_check_error(hipError_t x)
Definition: hip_check_error.hpp:10
 
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
 
void flush_icache()
Definition: flush_cache.hpp:383
 
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition: flush_cache.hpp:398
 
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
 
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:140
 
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
 
int32_t index_t
Definition: ck.hpp:299
 
signed int int32_t
Definition: stdint.h:123
 
unsigned __int64 uint64_t
Definition: stdint.h:136
 
Definition: stream_config.hpp:10
 
int cold_niters_
Definition: stream_config.hpp:14
 
bool time_kernel_
Definition: stream_config.hpp:12
 
int nrepeat_
Definition: stream_config.hpp:15
 
hipStream_t stream_id_
Definition: stream_config.hpp:11
 
Definition: functional2.hpp:33
 
Definition: flush_cache.hpp:299
 
decltype(Argument::p_a_grid) ADataType
Definition: flush_cache.hpp:300
 
RotatingMemWrapper(Argument &arg_, std::size_t rotating_count_hint, std::size_t size_a_, std::size_t size_b_)
Definition: flush_cache.hpp:304
 
decltype(Argument::p_b_grid) BDataType
Definition: flush_cache.hpp:301
 
~RotatingMemWrapper()
Definition: flush_cache.hpp:356
 
RotatingMemWrapper()=delete
 
void Print()
Definition: flush_cache.hpp:351
 
void Next()
Definition: flush_cache.hpp:342
 
Definition: flush_cache.hpp:21
 
static constexpr index_t NumBs
Definition: flush_cache.hpp:23
 
RotatingMemWrapperMultiABD(Argument &arg_, std::size_t rotating_count_hint, std::array< std::size_t, NumAs > size_as_, std::array< std::size_t, NumBs > size_bs_, std::array< std::size_t, NumDs > size_ds_)
Definition: flush_cache.hpp:31
 
static constexpr index_t NumDs
Definition: flush_cache.hpp:24
 
decltype(Argument::p_bs_grid) BsGridPointer
Definition: flush_cache.hpp:27
 
decltype(Argument::p_ds_grid) DsGridPointer
Definition: flush_cache.hpp:28
 
void Print()
Definition: flush_cache.hpp:117
 
void Next()
Definition: flush_cache.hpp:107
 
static constexpr index_t NumAs
Definition: flush_cache.hpp:22
 
decltype(Argument::p_as_grid) AsGridPointer
Definition: flush_cache.hpp:26
 
RotatingMemWrapperMultiABD()=delete
 
~RotatingMemWrapperMultiABD()
Definition: flush_cache.hpp:127
 
Definition: flush_cache.hpp:174
 
decltype(Argument::p_b_grid) BDataType
Definition: flush_cache.hpp:178
 
void Print()
Definition: flush_cache.hpp:256
 
static constexpr index_t NumDs
Definition: flush_cache.hpp:175
 
decltype(Argument::p_a_grid) ADataType
Definition: flush_cache.hpp:177
 
RotatingMemWrapperMultiD()=delete
 
~RotatingMemWrapperMultiD()
Definition: flush_cache.hpp:261
 
decltype(Argument::p_ds_grid) DsGridPointer
Definition: flush_cache.hpp:179
 
RotatingMemWrapperMultiD(Argument &arg_, std::size_t rotating_count_hint, std::size_t size_a_, std::size_t size_b_, std::array< std::size_t, NumDs > size_ds_)
Definition: flush_cache.hpp:182
 
void Next()
Definition: flush_cache.hpp:246
 
#define CK_ENV(name)
Definition: env.hpp:129