include/ck_tile/host/kernel_launch.hpp Source File

include/ck_tile/host/kernel_launch.hpp Source File

Composable Kernel: include/ck_tile/host/kernel_launch.hpp Source File
kernel_launch.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
9 #include "ck_tile/host/timer.hpp"
10 #include <hip/hip_runtime.h>
11 #include <cstddef>
12 
13 namespace ck_tile {
14 template <int MaxThreadPerBlock, int MinBlockPerCu, typename Kernel, typename... Args>
15 #if CK_TILE_USE_LAUNCH_BOUNDS
16 __launch_bounds__(MaxThreadPerBlock, MinBlockPerCu)
17 #endif
18  __global__ void kentry(Args... args)
19 {
20  Kernel{}(args...);
21 }
22 
23 //
24 // return a anonymous functor(lambda) to be called later
25 // the KernelImpl should be a class without non-static data member, or let's say
26 // can be instantiate with "KernelImpl{}"
27 //
28 // the "static __device__ operator()(some_arg)" is the entry point of KernelImpl
29 //
30 template <int MaxThreadPerBlock = CK_TILE_MAX_THREAD_PER_BLOCK,
31  int MinBlockPerCu = CK_TILE_MIN_BLOCK_PER_CU,
32  typename KernelImpl,
33  typename... Args>
34 CK_TILE_HOST auto
35 make_kernel(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
36 {
37  const auto kernel = kentry<MaxThreadPerBlock, MinBlockPerCu, KernelImpl, Args...>;
38 
39  return [=](const stream_config& s) {
40  kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
41  };
42 }
43 
44 // clang-format off
45 /*
46  * launch_kernel()
47  *
48  * this is the function to launch arbitrary number of kernels with optional timer(selected by stream_config)
49  * the callables should have signature as "operator()(const stream_config& s){ ... }" to call
50  *
51  * the simplest way is pass in a lambda function, with "[=](const stream_config& s){ call_your_kernel_here() }"
52  * as signature, for the callable (pay attention to the capture list)
53  *
54  * e.g.
55  * ck_tile::launch_kernel(s,
56  * [=](const stream_config& s){ hipMemset(ptr, 0, size) },
57  * [=](const stream_config& s){ some_kernel<<<grids, blocks>>>(arg); }
58  * );
59  *
60  * if you use ck_tile kernel, or similiar to this style (structure with "static __device__ operator()(...){}")
61  * you can pass your kernel to ck_tile::make_kernel(), which will create a anonymous functor for you,
62  * then pass it to ck_tile::launch_kernel()
63  *
64  * e.g.
65  * ck_tile::launch_kernel(s,
66  * ck_tile::make_kernel<T0, B0>(kernel_0{}, grids0, blocks0, 0, kargs0),
67  * ck_tile::make_kernel<T0, B1>(kernel_1{}, grids1, blocks1, 0, kargs1),
68  * ...);
69  **/
70 // clang-format on
71 template <typename... Callables>
72 CK_TILE_HOST float launch_kernel(const stream_config& s, Callables... callables)
73 {
74  // clang-format off
75  if(!s.time_kernel_) {
76  (callables(s),...); HIP_CHECK_ERROR(hipGetLastError());
77  return 0;
78  }
79  if(s.is_gpu_timer_) {
80  gpu_timer timer {};
81 
82  // warmup
83  for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError());
84 
85  timer.start(s.stream_id_);
86  for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError());
87  timer.stop(s.stream_id_);
88 
89  return timer.duration() / s.nrepeat_;
90  }
91  else {
92  cpu_timer timer {};
93 
94  // warmup
95  for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError());
96 
97  timer.start(s.stream_id_);
98  for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError());
99  timer.stop(s.stream_id_);
100 
101  return timer.duration() / s.nrepeat_;
102  }
103  // clang-format on
104 }
105 
106 } // namespace ck_tile
#define CK_TILE_MIN_BLOCK_PER_CU
Definition: config.hpp:113
#define CK_TILE_MAX_THREAD_PER_BLOCK
Definition: config.hpp:112
#define CK_TILE_HOST
Definition: config.hpp:39
#define HIP_CHECK_ERROR(retval_or_funcall)
Definition: hip_check_error.hpp:22
Definition: cluster_descriptor.hpp:13
__global__ void kentry(Args... args)
Definition: kernel_launch.hpp:18
CK_TILE_HOST auto make_kernel(KernelImpl, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:35
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables... callables)
Definition: kernel_launch.hpp:72
Definition: timer.hpp:52
Definition: timer.hpp:15
Definition: stream_config.hpp:26
hipStream_t stream_id_
Definition: stream_config.hpp:27
int cold_niters_
Definition: stream_config.hpp:30
bool time_kernel_
Definition: stream_config.hpp:28
int nrepeat_
Definition: stream_config.hpp:31
bool is_gpu_timer_
Definition: stream_config.hpp:32