/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/kernel_launch.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/kernel_launch.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/kernel_launch.hpp Source File
kernel_launch.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <numeric>
7 #include <functional>
12 #include "ck_tile/host/timer.hpp"
13 #include <cstddef>
14 #include <hip/hip_runtime.h>
15 
16 namespace ck_tile {
17 
18 template <int MinBlockPerCu, typename Kernel, typename... Args>
19 #if CK_TILE_USE_LAUNCH_BOUNDS
20 __launch_bounds__(Kernel::kBlockSize, MinBlockPerCu)
21 #endif
22  __global__ void kentry(Args... args)
23 {
24 #if defined(__HIP_DEVICE_COMPILE__)
25  Kernel{}(args...);
26 #else
27  (..., (ignore = args, 0));
28 #endif
29 }
30 
31 template <typename Arch, int MinBlockPerCu, typename Kernel, typename... Args>
32 #if CK_TILE_USE_LAUNCH_BOUNDS
33 __launch_bounds__(Kernel::kBlockSize, MinBlockPerCu)
34 #endif
35  __global__ void kentry(Args... args)
36 {
37 #if defined(__HIP_DEVICE_COMPILE__)
38  Kernel{}(args...);
39 #else
40  (..., (ignore = args, 0));
41 #endif
42 }
43 
44 //
45 // return a anonymous functor(lambda) to be called later
46 // the KernelImpl should be a class without non-static data member, or let's say
47 // can be instantiate with "KernelImpl{}"
48 //
49 // the "static __device__ operator()(some_arg)" is the entry point of KernelImpl
50 //
51 // Arch can be used to support linking multiple object files that have the same kernel compiled for
52 // different architectures. In this case each object file has to use a different tag (gfx9_t,
53 // gfx12_t etc.), so the kernel will have different symbols for each architecture.
54 //
55 template <int MinBlockPerCu = CK_TILE_MIN_BLOCK_PER_CU,
56  typename Arch = void,
57  typename KernelImpl,
58  typename... Args>
59 CK_TILE_HOST auto
60 make_kernel(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
61 {
62  const auto kernel = []() {
63  if constexpr(std::is_void_v<Arch>)
64  {
65  return kentry<MinBlockPerCu, KernelImpl, Args...>;
66  }
67  else
68  {
69  return kentry<Arch, MinBlockPerCu, KernelImpl, Args...>;
70  }
71  }();
72  return [=](const stream_config& s) {
73  kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
74  };
75 }
76 
77 template <typename... Callables>
78 CK_TILE_HOST void launch_and_check(const stream_config& sc, Callables&&... callables)
79 {
80  // abort the sequence in case of intermediate error
81  if(!((static_cast<void>(callables(sc)), hipPeekAtLastError() == hipSuccess) && ...))
82  {
83  HIP_CHECK_ERROR(hipGetLastError());
84  }
85 }
86 
87 // Measure the preprocess time during the cold iterations
88 template <typename TimerType, typename PreprocessFunc>
89 CK_TILE_HOST double
90 preprocess_profiling_impl(TimerType timer, const stream_config& s, PreprocessFunc preprocess)
91 {
92  timer.start(s.stream_id_);
93  for(int i = 0; i < s.nrepeat_; i++)
94  {
95  if constexpr(!std::is_same_v<PreprocessFunc, std::nullptr_t>)
96  {
97  preprocess();
98  }
99  }
100  timer.stop(s.stream_id_);
101 
102  return timer.duration() / s.nrepeat_;
103 }
104 
105 template <typename TimerType, typename CallablesFunc, typename PreprocessFunc = std::nullptr_t>
106 CK_TILE_HOST double timing_loop_impl(TimerType timer,
107  const stream_config& s,
108  CallablesFunc&& callables_func,
109  PreprocessFunc preprocess = nullptr)
110 {
111  for(int i = 0; i < s.cold_niters_; i++)
112  {
113  callables_func();
114  }
115  // Only profile preprocess if it's provided
116  auto preprocess_time = 0.0;
117  if constexpr(!std::is_same_v<PreprocessFunc, std::nullptr_t>)
118  {
119  preprocess_time = preprocess_profiling_impl(gpu_timer{}, s, preprocess);
120  }
121 
122  int i = 0;
123  timer.start(s.stream_id_);
124  while(i < s.nrepeat_)
125  {
126  if constexpr(!std::is_same_v<PreprocessFunc, std::nullptr_t>)
127  {
128  preprocess();
129  }
130 
131  callables_func();
132  i++;
133  }
134  timer.stop(s.stream_id_);
135 
136  if(!i)
137  return 0.;
138  return (timer.duration() / s.nrepeat_) - preprocess_time;
139 }
140 
141 // clang-format off
142 /*
143  * launch_kernel()
144  *
145  * this is the function to launch arbitrary number of kernels with optional timer(selected by stream_config)
146  * the callables should have signature as "operator()(const stream_config& s){ ... }" to call
147  *
148  * the simplest way is pass in a lambda function, with "[=](const stream_config& s){ call_your_kernel_here() }"
149  * as signature, for the callable (pay attention to the capture list)
150  *
151  * e.g.
152  * ck_tile::launch_kernel(s,
153  * [=](const stream_config& s){ hipMemset(ptr, 0, size) },
154  * [=](const stream_config& s){ some_kernel<<<grids, blocks>>>(arg); }
155  * );
156  *
157  * if you use ck_tile kernel, or similiar to this style (structure with "static __device__ operator()(...){}")
158  * you can pass your kernel to ck_tile::make_kernel(), which will create a anonymous functor for you,
159  * then pass it to ck_tile::launch_kernel()
160  *
161  * e.g.
162  * ck_tile::launch_kernel(s,
163  * ck_tile::make_kernel<T0, B0>(kernel_0{}, grids0, blocks0, 0, kargs0),
164  * ck_tile::make_kernel<T0, B1>(kernel_1{}, grids1, blocks1, 0, kargs1),
165  * ...);
166  **/
167 // clang-format on
168 template <typename... Callables>
169 CK_TILE_HOST float launch_kernel(const stream_config& s, Callables&&... callables)
170 {
171  static_assert(sizeof...(callables) > 0, "At least one callable is required!");
172 
173  if(!s.time_kernel_)
174  {
175  launch_and_check(s, std::forward<Callables>(callables)...);
176  return 0;
177  }
178 
179  auto callables_func = [&]() { launch_and_check(s, std::forward<Callables>(callables)...); };
180 
181  if(s.is_gpu_timer_)
182  {
183  return timing_loop_impl(gpu_timer{}, s, callables_func);
184  }
185  else
186  {
187  return timing_loop_impl(cpu_timer{}, s, callables_func);
188  }
189 }
190 
191 template <typename PreprocessFunc, typename... Callables>
192 CK_TILE_HOST float
193 launch_kernel_time_mask(const stream_config& s, PreprocessFunc preprocess, Callables&&... callables)
194 {
195  static_assert(sizeof...(callables) > 0, "At least one callable is required!");
196 
197  if(!s.time_kernel_)
198  {
199  preprocess();
200  launch_and_check(s, std::forward<Callables>(callables)...);
201  return 0;
202  }
203 
204  auto callables_func = [&]() { launch_and_check(s, std::forward<Callables>(callables)...); };
205 
206  if(s.is_gpu_timer_)
207  {
208  return timing_loop_impl(gpu_timer{}, s, callables_func, preprocess);
209  }
210  else
211  {
212  return timing_loop_impl(cpu_timer{}, s, callables_func, preprocess);
213  }
214 }
215 } // namespace ck_tile
#define CK_TILE_MIN_BLOCK_PER_CU
Definition: config.hpp:115
#define CK_TILE_HOST
Definition: config.hpp:40
#define HIP_CHECK_ERROR(retval_or_funcall)
Definition: hip_check_error.hpp:21
Definition: cluster_descriptor.hpp:13
__global__ void kentry(Args... args)
Definition: kernel_launch.hpp:22
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
CK_TILE_HOST double timing_loop_impl(TimerType timer, const stream_config &s, CallablesFunc &&callables_func, PreprocessFunc preprocess=nullptr)
Definition: kernel_launch.hpp:106
CK_TILE_HOST auto make_kernel(KernelImpl, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:60
CK_TILE_HOST double preprocess_profiling_impl(TimerType timer, const stream_config &s, PreprocessFunc preprocess)
Definition: kernel_launch.hpp:90
CK_TILE_HOST void launch_and_check(const stream_config &sc, Callables &&... callables)
Definition: kernel_launch.hpp:78
CK_TILE_HOST float launch_kernel_time_mask(const stream_config &s, PreprocessFunc preprocess, Callables &&... callables)
Definition: kernel_launch.hpp:193
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition: kernel_launch.hpp:169
Definition: timer.hpp:52
Definition: timer.hpp:15
Definition: stream_config.hpp:30
hipStream_t stream_id_
Definition: stream_config.hpp:31
int cold_niters_
Definition: stream_config.hpp:34
bool time_kernel_
Definition: stream_config.hpp:32
int nrepeat_
Definition: stream_config.hpp:35
bool is_gpu_timer_
Definition: stream_config.hpp:36