/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp Source File
gridwise_elementwise_2d.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
15 
16 namespace ck {
17 
// Entry-point kernel for a single gridwise elementwise operation.
// All arguments are forwarded unchanged to GridwiseElementwiseFunctor::Run;
// the functor maps each workgroup onto a tile via block_2_tile_map.
// NOTE(review): the __launch_bounds__ line inside the #if block was lost in
// the documentation text extraction; restored here from the CK_MAX_THREAD_PER_BLOCK /
// CK_MIN_BLOCK_PER_CU macros this file is documented to reference.
template <typename GridwiseElementwiseFunctor,
          typename InGridDescTuple,
          typename OutGridDescTuple,
          typename InDataTypePointerTuple,
          typename OutDataTypePointerTuple,
          typename Block2TileMap,
          typename ElementwiseOperation>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_elementwise(const InGridDescTuple in_grid_desc_tuple,
                           const OutGridDescTuple out_grid_desc_tuple,
                           const InDataTypePointerTuple p_in_global_tuple,
                           const OutDataTypePointerTuple p_out_global_tuple,
                           const Block2TileMap block_2_tile_map,
                           const ElementwiseOperation elementwise_op)
{
    GridwiseElementwiseFunctor::Run(in_grid_desc_tuple,
                                    out_grid_desc_tuple,
                                    p_in_global_tuple,
                                    p_out_global_tuple,
                                    block_2_tile_map,
                                    elementwise_op);
}
43 
// Fused launch of two independent elementwise problems ("A" and "B") in one grid.
// Blocks with id < a_grid_size run problem A; the remaining blocks run problem B
// with their block id rebased to start from zero. Both problems share one
// elementwise_op instance.
// NOTE(review): the __launch_bounds__ line inside the #if block was lost in
// the documentation text extraction; restored from the macros referenced by this file.
template <typename GridwiseElementwiseFunctorA,
          typename GridwiseElementwiseFunctorB,
          typename InAGridDescTuple,
          typename InBGridDescTuple,
          typename OutAGridDescTuple,
          typename OutBGridDescTuple,
          typename InADataTypePointerTuple,
          typename InBDataTypePointerTuple,
          typename OutADataTypePointerTuple,
          typename OutBDataTypePointerTuple,
          typename Block2TileMapA,
          typename Block2TileMapB,
          typename ElementwiseOperation>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_elementwise_dual(const InAGridDescTuple in_grid_desc_tuple_a,
                                const InBGridDescTuple in_grid_desc_tuple_b,
                                const OutAGridDescTuple out_grid_desc_tuple_a,
                                const OutBGridDescTuple out_grid_desc_tuple_b,
                                const InADataTypePointerTuple p_in_global_tuple_a,
                                const InBDataTypePointerTuple p_in_global_tuple_b,
                                const OutADataTypePointerTuple p_out_global_tuple_a,
                                const OutBDataTypePointerTuple p_out_global_tuple_b,
                                const Block2TileMapA block_2_tile_map_a,
                                const Block2TileMapB block_2_tile_map_b,
                                const ElementwiseOperation elementwise_op,
                                const index_t a_grid_size)
{
    if(get_block_1d_id() < a_grid_size)
    {
        // First a_grid_size blocks handle problem A with their natural block id.
        GridwiseElementwiseFunctorA::Run(in_grid_desc_tuple_a,
                                         out_grid_desc_tuple_a,
                                         p_in_global_tuple_a,
                                         p_out_global_tuple_a,
                                         block_2_tile_map_a,
                                         elementwise_op,
                                         get_block_1d_id());
    }
    else
    {
        // Remaining blocks handle problem B; rebase the block id so B's
        // tile map sees ids starting at zero.
        GridwiseElementwiseFunctorB::Run(in_grid_desc_tuple_b,
                                         out_grid_desc_tuple_b,
                                         p_in_global_tuple_b,
                                         p_out_global_tuple_b,
                                         block_2_tile_map_b,
                                         elementwise_op,
                                         get_block_1d_id() - a_grid_size);
    }
}
95 
// Batched variant of kernel_elementwise_dual. The grid is split into an "A"
// region ([0, a_grid_size)) and a "B" region ([a_grid_size, grid_size)); each
// region is divided evenly among its batch count, and per-batch base pointers
// are formed by adding batch-stride offsets (widened to long_index_t so the
// byte offset cannot overflow 32 bits) before delegating to the gridwise functor.
// Preconditions (implied by the divisions below): a_grid_size is a multiple of
// batch_count_a and (grid_size - a_grid_size) is a multiple of batch_count_b.
// NOTE(review): the __launch_bounds__ line inside the #if block was lost in
// the documentation text extraction; restored from the macros referenced by this file.
template <typename GridwiseElementwiseFunctorA,
          typename GridwiseElementwiseFunctorB,
          typename InAGridDescTuple,
          typename InBGridDescTuple,
          typename OutAGridDescTuple,
          typename OutBGridDescTuple,
          typename InADataTypePointerTuple,
          typename InBDataTypePointerTuple,
          typename OutADataTypePointerTuple,
          typename OutBDataTypePointerTuple,
          typename Block2TileMapA,
          typename Block2TileMapB,
          typename ElementwiseOperation,
          index_t NumInputsA,
          index_t NumInputsB,
          index_t NumOutputsA,
          index_t NumOutputsB>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_elementwise_batched_dual(
            const InAGridDescTuple in_grid_desc_tuple_a,
            const InBGridDescTuple in_grid_desc_tuple_b,
            const OutAGridDescTuple out_grid_desc_tuple_a,
            const OutBGridDescTuple out_grid_desc_tuple_b,
            const InADataTypePointerTuple p_in_global_tuple_a,
            const InBDataTypePointerTuple p_in_global_tuple_b,
            const OutADataTypePointerTuple p_out_global_tuple_a,
            const OutBDataTypePointerTuple p_out_global_tuple_b,
            const Block2TileMapA block_2_tile_map_a,
            const Block2TileMapB block_2_tile_map_b,
            const ElementwiseOperation elementwise_op,
            const index_t a_grid_size,
            const index_t batch_count_a,
            const index_t batch_count_b,
            const std::array<index_t, NumInputsA> input_batch_strides_a,
            const std::array<index_t, NumInputsB> input_batch_strides_b,
            const std::array<index_t, NumOutputsA> output_batch_strides_a,
            const std::array<index_t, NumOutputsB> output_batch_strides_b)
{
    // Each descriptor/pointer tuple must agree with its stride-array extent.
    static_assert(InAGridDescTuple::Size() == NumInputsA &&
                  InADataTypePointerTuple::Size() == NumInputsA);
    static_assert(OutAGridDescTuple::Size() == NumOutputsA &&
                  OutADataTypePointerTuple::Size() == NumOutputsA);
    static_assert(InBGridDescTuple::Size() == NumInputsB &&
                  InBDataTypePointerTuple::Size() == NumInputsB);
    static_assert(OutBGridDescTuple::Size() == NumOutputsB &&
                  OutBDataTypePointerTuple::Size() == NumOutputsB);

    // block_id is uniform across the wavefront; readfirstlane keeps it scalar.
    const index_t block_id = __builtin_amdgcn_readfirstlane(get_block_1d_id());

    if(block_id < a_grid_size)
    {
        const index_t num_blocks_per_batch =
            __builtin_amdgcn_readfirstlane(a_grid_size / batch_count_a);
        // g_idx: which batch of problem A this block belongs to.
        const index_t g_idx = __builtin_amdgcn_readfirstlane(block_id / num_blocks_per_batch);

        InADataTypePointerTuple p_in_global_with_offset_tuple;
        OutADataTypePointerTuple p_out_global_with_offset_tuple;

        static_for<0, InADataTypePointerTuple::Size(), 1>{}([&](auto i) {
            p_in_global_with_offset_tuple(i) =
                p_in_global_tuple_a.At(i) +
                type_convert<long_index_t>(input_batch_strides_a[i]) * g_idx;
        });

        static_for<0, OutADataTypePointerTuple::Size(), 1>{}([&](auto i) {
            p_out_global_with_offset_tuple(i) =
                p_out_global_tuple_a.At(i) +
                type_convert<long_index_t>(output_batch_strides_a[i]) * g_idx;
        });

        GridwiseElementwiseFunctorA::Run(in_grid_desc_tuple_a,
                                         out_grid_desc_tuple_a,
                                         p_in_global_with_offset_tuple,
                                         p_out_global_with_offset_tuple,
                                         block_2_tile_map_a,
                                         elementwise_op,
                                         block_id);
    }
    else
    {
        // B region: its block count is whatever remains of the grid.
        const index_t num_blocks_per_batch =
            __builtin_amdgcn_readfirstlane((get_grid_size() - a_grid_size) / batch_count_b);
        const index_t g_idx =
            __builtin_amdgcn_readfirstlane((block_id - a_grid_size) / num_blocks_per_batch);

        InBDataTypePointerTuple p_in_global_with_offset_tuple;
        OutBDataTypePointerTuple p_out_global_with_offset_tuple;

        static_for<0, InBDataTypePointerTuple::Size(), 1>{}([&](auto i) {
            p_in_global_with_offset_tuple(i) =
                p_in_global_tuple_b.At(i) +
                type_convert<long_index_t>(input_batch_strides_b[i]) * g_idx;
        });

        static_for<0, OutBDataTypePointerTuple::Size(), 1>{}([&](auto i) {
            p_out_global_with_offset_tuple(i) =
                p_out_global_tuple_b.At(i) +
                type_convert<long_index_t>(output_batch_strides_b[i]) * g_idx;
        });

        GridwiseElementwiseFunctorB::Run(in_grid_desc_tuple_b,
                                         out_grid_desc_tuple_b,
                                         p_in_global_with_offset_tuple,
                                         p_out_global_with_offset_tuple,
                                         block_2_tile_map_b,
                                         elementwise_op,
                                         block_id - a_grid_size);
    }
}
207 
// Batched elementwise kernel: the grid is divided evenly among batch_count
// batches; each block derives its batch index g_idx, offsets every input and
// output base pointer by stride * g_idx (widened to long_index_t to avoid
// 32-bit overflow of the element offset), then delegates to the functor.
// Precondition implied by the division: grid_size is a multiple of batch_count.
// NOTE(review): the __launch_bounds__ line inside the #if block was lost in
// the documentation text extraction; restored from the macros referenced by this file.
template <typename GridwiseElementwiseFunctor,
          typename InGridDescTuple,
          typename OutGridDescTuple,
          typename InDataTypePointerTuple,
          typename OutDataTypePointerTuple,
          typename Block2TileMap,
          typename ElementwiseOperation,
          index_t NumInputs,
          index_t NumOutputs>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_batched_elementwise(const InGridDescTuple in_grid_desc_tuple,
                                   const OutGridDescTuple out_grid_desc_tuple,
                                   const InDataTypePointerTuple p_in_global_tuple,
                                   const OutDataTypePointerTuple p_out_global_tuple,
                                   const Block2TileMap block_2_tile_map,
                                   const ElementwiseOperation elementwise_op,
                                   const index_t batch_count,
                                   const std::array<index_t, NumInputs> input_batch_strides,
                                   const std::array<index_t, NumOutputs> output_batch_strides)
{
    // Tuple sizes must agree with the stride-array extents.
    static_assert(InGridDescTuple::Size() == NumInputs &&
                  InDataTypePointerTuple::Size() == NumInputs);
    static_assert(OutGridDescTuple::Size() == NumOutputs &&
                  OutDataTypePointerTuple::Size() == NumOutputs);

    // Values are uniform across the wavefront; readfirstlane keeps them scalar.
    const index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

    InDataTypePointerTuple p_in_global_with_offset_tuple;
    OutDataTypePointerTuple p_out_global_with_offset_tuple;

    static_for<0, InDataTypePointerTuple::Size(), 1>{}([&](auto i) {
        p_in_global_with_offset_tuple(i) =
            p_in_global_tuple.At(i) + type_convert<long_index_t>(input_batch_strides[i]) * g_idx;
    });

    static_for<0, OutDataTypePointerTuple::Size(), 1>{}([&](auto i) {
        p_out_global_with_offset_tuple(i) =
            p_out_global_tuple.At(i) + type_convert<long_index_t>(output_batch_strides[i]) * g_idx;
    });

    GridwiseElementwiseFunctor::Run(in_grid_desc_tuple,
                                    out_grid_desc_tuple,
                                    p_in_global_with_offset_tuple,
                                    p_out_global_with_offset_tuple,
                                    block_2_tile_map,
                                    elementwise_op);
}
260 
// Gridwise 2D elementwise functor (doxygen anchor: gridwise_elementwise_2d.hpp:278,
// members NumInput/NumOutput/I0/I1/Run per the cross-reference index below).
// Run() maps one workgroup to an (M0PerBlock x M1PerBlock) tile chosen by
// block_2_tile_map, wraps the raw global pointers in dynamic buffers, and
// performs the transfer + elementwise_op via ThreadGroupTensorSliceTransfer_v4r2.
// NOTE(review): this block is a documentation-site text extraction; the embedded
// numbers are the original file's line numbers. Gaps in that numbering below mark
// source lines the extraction dropped — do not compile this block as-is.
261 template <typename InGridDescTuple,
262  typename OutGridDescTuple,
263  typename InDataTypePointerTuple,
264  typename OutDataTypePointerTuple,
265  typename Block2TileMap,
266  typename ElementwiseOperation,
267  index_t BlockSize,
268  index_t M0PerBlock,
269  index_t M1PerBlock,
270  index_t M0PerThread,
271  index_t M1PerThread,
272  typename ThreadClusterArrangeOrder,
273  typename InScalarPerVectorSeq,
274  typename OutScalarPerVectorSeq,
275  index_t SrcVectorDim,
276  index_t DstVectorDim>
// NOTE(review): original line 277 (the "struct <name>" declaration line) was lost
// in extraction; the index entry at hpp:278 refers to this struct's definition.
278 {
279  static constexpr index_t NumInput = InDataTypePointerTuple::Size();
280  static constexpr index_t NumOutput = OutDataTypePointerTuple::Size();
281 
282  static_assert(NumInput == InScalarPerVectorSeq::Size() &&
283  NumOutput == OutScalarPerVectorSeq::Size() &&
284  NumInput == InGridDescTuple::Size() && NumOutput == OutGridDescTuple::Size(),
285  "Tuple size is inconsistent with the number of in/out!");
286 
287  static constexpr auto I0 = Number<0>{};
288  static constexpr auto I1 = Number<1>{};
289 
290  static_assert((SrcVectorDim == I0 || SrcVectorDim == I1) &&
291  (DstVectorDim == I0 || DstVectorDim == I1),
292  "Vector dim must be equal to 0 or 1.");
293 
// NOTE(review): original line 294 missing here (extraction gap); the index lists a
// ThisThreadBlock reference (thread_group.hpp:12) for this file — likely that alias.
295 
// Run: executes the tile transfer for the block given by block_id (defaults to
// this block's 1D id so single-problem kernels can omit the argument).
296  __device__ static void Run(const InGridDescTuple& in_grid_desc_tuple,
297  const OutGridDescTuple& out_grid_desc_tuple,
298  const InDataTypePointerTuple& p_in_global_tuple,
299  const OutDataTypePointerTuple& p_out_global_tuple,
300  const Block2TileMap& block_2_tile_map,
301  const ElementwiseOperation& elementwise_op,
302  const index_t block_id = get_block_1d_id())
303  {
304 
// src_datas/dst_datas: value-initialized instances of each pointee type, used
// only to carry the element types into the transfer template (decltype below).
305  constexpr auto src_datas = generate_tuple(
306  [&](auto I) {
307  using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
// NOTE(review): original line 308 missing (extraction gap) — by symmetry with the
// dst lambda below it was presumably the "using DataType = ..." alias; confirm.
309 
310  return DataType{};
311  },
312  Number<NumInput>{});
313 
314  constexpr auto dst_datas = generate_tuple(
315  [&](auto I) {
316  using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
317  using DataType = remove_pointer_t<DataTypePointer>;
318 
319  return DataType{};
320  },
// NOTE(review): original line 321 missing (extraction gap) — presumably the
// "Number<NumOutput>{});" closer, mirroring line 312 above.
322 
// Wrap each raw global pointer in a dynamic buffer sized by its descriptor.
323  const auto in_global_buf_tuple = generate_tuple(
324  [&](auto I) {
325  return make_dynamic_buffer<AddressSpaceEnum::Global>(
326  p_in_global_tuple[I], in_grid_desc_tuple[I].GetElementSpaceSize());
327  },
328  Number<NumInput>{});
329 
330  auto out_global_buf_tuple = generate_tuple(
331  [&](auto I) {
332  return make_dynamic_buffer<AddressSpaceEnum::Global>(
333  p_out_global_tuple[I], out_grid_desc_tuple[I].GetElementSpaceSize());
334  },
// NOTE(review): original line 335 missing (extraction gap) — presumably
// "Number<NumOutput>{});", mirroring line 328 above.
336 
// Map this block id to its 2D tile origin; readfirstlane keeps the uniform
// tile coordinates in scalar registers.
337  const auto block_work_idx =
338  block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id));
339 
340  const index_t m0_block_data_idx_on_grid =
341  __builtin_amdgcn_readfirstlane(block_work_idx[I0] * M0PerBlock);
342  const index_t m1_block_data_idx_on_grid =
343  __builtin_amdgcn_readfirstlane(block_work_idx[I1] * M1PerBlock);
// Every input/output uses the same tile origin as its thread-grid offset.
344  const auto input_thread_grid_offset = generate_tuple(
345  [&](auto) {
346  return make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid);
347  },
348  Number<NumInput>{});
349  const auto output_thread_grid_offset = generate_tuple(
350  [&](auto) {
351  return make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid);
352  },
// NOTE(review): original lines 353 and 355 missing (extraction gap) — 353 is
// presumably "Number<NumOutput>{});", mirroring line 348 above.
354 
356  // If src and dst have same vector dim, then:
357  // M0 dim - for src and dst vector load/store
358  // else:
359  // M0 dim - for dst vector load
360  // M1 dim - for src vector store
361  using SrcDimAccessOrder =
362  std::conditional_t<SrcVectorDim == I1, Sequence<0, 1>, Sequence<1, 0>>;
363  using DstDimAccessOrder =
364  std::conditional_t<DstVectorDim == I1, Sequence<0, 1>, Sequence<1, 0>>;
365 
366  using ThreadClusterLengths =
367  Sequence<Number<M0PerBlock / M0PerThread>{}, Number<M1PerBlock / M1PerThread>{}>;
368 
// Blockwise global->global transfer that applies elementwise_op in flight
// (see thread_group_tensor_slice_transfer_v4r2.hpp:45 per the index below).
// NOTE(review): original lines 370, 372-373 and 386-388 of this template
// argument list are missing (extraction gaps); consult the real header for
// the full parameter list before relying on this rendering.
369  auto global_to_global_transfer = ThreadGroupTensorSliceTransfer_v4r2<
371  ElementwiseOperation,
374  ThreadClusterLengths,
375  ThreadClusterArrangeOrder,
376  decltype(src_datas),
377  decltype(dst_datas),
378  InGridDescTuple,
379  OutGridDescTuple,
380  SrcDimAccessOrder,
381  DstDimAccessOrder,
382  SrcVectorDim,
383  DstVectorDim,
384  InScalarPerVectorSeq,
385  OutScalarPerVectorSeq,
389  uniform_sequence_gen_t<NumOutput, false>>{in_grid_desc_tuple,
390  input_thread_grid_offset,
391  out_grid_desc_tuple,
392  output_thread_grid_offset,
393  elementwise_op};
394  global_to_global_transfer.Run(
395  in_grid_desc_tuple, in_global_buf_tuple, out_grid_desc_tuple, out_global_buf_tuple, I0);
396  }
397 };
398 
399 } // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition: ck.hpp:30
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:29
Definition: ck.hpp:267
__global__ void kernel_batched_elementwise(const InGridDescTuple in_grid_desc_tuple, const OutGridDescTuple out_grid_desc_tuple, const InDataTypePointerTuple p_in_global_tuple, const OutDataTypePointerTuple p_out_global_tuple, const Block2TileMap block_2_tile_map, const ElementwiseOperation elementwise_op, const index_t batch_count, const std::array< index_t, NumInputs > input_batch_strides, const std::array< index_t, NumOutputs > output_batch_strides)
Definition: gridwise_elementwise_2d.hpp:221
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__device__ index_t get_grid_size()
Definition: get_id.hpp:60
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition: sequence.hpp:928
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
typename remove_pointer< T >::type remove_pointer_t
Definition: type.hpp:300
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:58
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
int32_t index_t
Definition: ck.hpp:298
__global__ void kernel_elementwise_dual(const InAGridDescTuple in_grid_desc_tuple_a, const InBGridDescTuple in_grid_desc_tuple_b, const OutAGridDescTuple out_grid_desc_tuple_a, const OutBGridDescTuple out_grid_desc_tuple_b, const InADataTypePointerTuple p_in_global_tuple_a, const InBDataTypePointerTuple p_in_global_tuple_b, const OutADataTypePointerTuple p_out_global_tuple_a, const OutBDataTypePointerTuple p_out_global_tuple_b, const Block2TileMapA block_2_tile_map_a, const Block2TileMapB block_2_tile_map_b, const ElementwiseOperation elementwise_op, const index_t a_grid_size)
Definition: gridwise_elementwise_2d.hpp:61
__global__ void kernel_elementwise_batched_dual(const InAGridDescTuple in_grid_desc_tuple_a, const InBGridDescTuple in_grid_desc_tuple_b, const OutAGridDescTuple out_grid_desc_tuple_a, const OutBGridDescTuple out_grid_desc_tuple_b, const InADataTypePointerTuple p_in_global_tuple_a, const InBDataTypePointerTuple p_in_global_tuple_b, const OutADataTypePointerTuple p_out_global_tuple_a, const OutBDataTypePointerTuple p_out_global_tuple_b, const Block2TileMapA block_2_tile_map_a, const Block2TileMapB block_2_tile_map_b, const ElementwiseOperation elementwise_op, const index_t a_grid_size, const index_t batch_count_a, const index_t batch_count_b, const std::array< index_t, NumInputsA > input_batch_strides_a, const std::array< index_t, NumInputsB > input_batch_strides_b, const std::array< index_t, NumOutputsA > output_batch_strides_a, const std::array< index_t, NumOutputsB > output_batch_strides_b)
Definition: gridwise_elementwise_2d.hpp:117
typename remove_cv< T >::type remove_cv_t
Definition: type.hpp:295
__global__ void kernel_elementwise(const InGridDescTuple in_grid_desc_tuple, const OutGridDescTuple out_grid_desc_tuple, const InDataTypePointerTuple p_in_global_tuple, const OutDataTypePointerTuple p_out_global_tuple, const Block2TileMap block_2_tile_map, const ElementwiseOperation elementwise_op)
Definition: gridwise_elementwise_2d.hpp:29
Definition: gridwise_elementwise_2d.hpp:278
static constexpr index_t NumInput
Definition: gridwise_elementwise_2d.hpp:279
static constexpr auto I1
Definition: gridwise_elementwise_2d.hpp:288
static __device__ void Run(const InGridDescTuple &in_grid_desc_tuple, const OutGridDescTuple &out_grid_desc_tuple, const InDataTypePointerTuple &p_in_global_tuple, const OutDataTypePointerTuple &p_out_global_tuple, const Block2TileMap &block_2_tile_map, const ElementwiseOperation &elementwise_op, const index_t block_id=get_block_1d_id())
Definition: gridwise_elementwise_2d.hpp:296
static constexpr auto I0
Definition: gridwise_elementwise_2d.hpp:287
static constexpr index_t NumOutput
Definition: gridwise_elementwise_2d.hpp:280
Definition: sequence.hpp:43
Definition: thread_group.hpp:12
Blockwise data transfer.
Definition: thread_group_tensor_slice_transfer_v4r2.hpp:45
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33
Definition: unary_element_wise_operation.hpp:334