Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d_scale.hpp Source File
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"

namespace ck {

template <typename GridwiseElementwise1dFunctor,
          typename InGrid1dDescTuple,
          typename OutGrid1dDescTuple,
          typename InDataTypePointerTuple,
          typename OutDataTypePointerTuple,
          typename ElementwiseOperation,
          typename UnaryOperation,
          typename Scale>
__global__ void kernel_elementwise_1d(const InGrid1dDescTuple in_grid_1d_desc_tuple,
                                      const OutGrid1dDescTuple out_grid_1d_desc_tuple,
                                      const InDataTypePointerTuple p_in_global_tuple,
                                      const OutDataTypePointerTuple p_out_global_tuple,
                                      const ElementwiseOperation elementwise_op,
                                      const UnaryOperation unary_op,
                                      const Scale scale_op)
{
    GridwiseElementwise1dFunctor::Run(in_grid_1d_desc_tuple,
                                      out_grid_1d_desc_tuple,
                                      p_in_global_tuple,
                                      p_out_global_tuple,
                                      elementwise_op,
                                      unary_op,
                                      scale_op);
}
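
// A minimal launch sketch (hypothetical launch parameters and tuple/functor
// names; in practice a CK device_op implementation configures and launches
// this wrapper):
//
//   kernel_elementwise_1d<GridwiseElementwise, InDescTuple, OutDescTuple,
//                         InPtrTuple, OutPtrTuple, ElementwiseOp, UnaryOp, ScaleOp>
//       <<<grid_size, block_size, 0, stream>>>(in_descs, out_descs,
//                                              p_ins, p_outs,
//                                              elementwise_op, unary_op, scale_op);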

template <typename InGrid1dDescTuple,
          typename OutGrid1dDescTuple,
          typename InDataTypePointerTuple,
          typename OutDataTypePointerTuple,
          typename ElementwiseOperation,
          typename UnaryOperation,
          typename Scale,
          index_t MPerThread,
          typename InScalarPerVectorSeq,
          typename OutScalarPerVectorSeq>
struct GridwiseElementwise_1D
{
    static constexpr index_t NumInput  = InDataTypePointerTuple::Size();
    static constexpr index_t NumOutput = OutDataTypePointerTuple::Size();

    static_assert(NumInput == InScalarPerVectorSeq::Size() &&
                      NumOutput == OutScalarPerVectorSeq::Size() &&
                      NumInput == InGrid1dDescTuple::Size() &&
                      NumOutput == OutGrid1dDescTuple::Size(),
                  "Tuple size is inconsistent with the number of in/out!");

    static constexpr auto I0 = Number<0>{};

    static constexpr auto thread_buffer_desc_m =
        make_naive_tensor_descriptor_packed(make_tuple(Number<MPerThread>{}));

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    __device__ static void Run(const InGrid1dDescTuple in_grid_1d_desc_tuple,
                               const OutGrid1dDescTuple out_grid_1d_desc_tuple,
                               const InDataTypePointerTuple p_in_global_tuple,
                               const OutDataTypePointerTuple p_out_global_tuple,
                               const ElementwiseOperation elementwise_op,
                               const UnaryOperation unary_op,
                               const Scale scale_op)
    {
        const index_t thread_global_id = get_thread_global_1d_id();

        // per-thread VGPR staging buffers, one per input / output tensor
        auto in_thread_buf_tuple = generate_tuple(
            [&](auto I) {
                using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
                using DataType        = remove_cv_t<remove_pointer_t<DataTypePointer>>;

                return StaticBuffer<AddressSpaceEnum::Vgpr, DataType, MPerThread, true>{};
            },
            Number<NumInput>{});

        auto out_thread_buf_tuple = generate_tuple(
            [&](auto I) {
                using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
                using DataType        = remove_pointer_t<DataTypePointer>;

                return StaticBuffer<AddressSpaceEnum::Vgpr, DataType, MPerThread, true>{};
            },
            Number<NumOutput>{});

        auto in_global_buf_tuple = generate_tuple(
            [&](auto I) {
                static_assert(in_grid_1d_desc_tuple[I].GetNumOfDimension() == 1);

                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_in_global_tuple[I], in_grid_1d_desc_tuple[I].GetElementSpaceSize());
            },
            Number<NumInput>{});

        auto out_global_buf_tuple = generate_tuple(
            [&](auto I) {
                static_assert(out_grid_1d_desc_tuple[I].GetNumOfDimension() == 1);

                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_out_global_tuple[I], out_grid_1d_desc_tuple[I].GetElementSpaceSize());
            },
            Number<NumOutput>{});

        const auto thread_global_offset = make_multi_index(thread_global_id * MPerThread);

        // grid-stride loop setup: every pass, the whole grid advances by
        // blockPerGrid * blockSize * MPerThread elements
        const index_t blockSize    = get_block_size();
        const index_t blockPerGrid = get_grid_size();
        const auto M               = in_grid_1d_desc_tuple[I0].GetLength(I0);
        const index_t loop_step    = blockPerGrid * blockSize * MPerThread;
        const auto loop_step_index = make_multi_index(loop_step);
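
        // Worked example (hypothetical sizes): with blockPerGrid = 1080,
        // blockSize = 256 and MPerThread = 8, loop_step = 1080 * 256 * 8
        // = 2'211'840 elements per pass, so a tensor of M = 4'423'680
        // elements is covered in num_iter = 2 passes. The division
        // num_iter = M / loop_step further down assumes M is a whole
        // multiple of loop_step.
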
        auto in_global_load_tuple = generate_tuple(
            [&](auto I) {
                using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
                using DataType        = remove_cv_t<remove_pointer_t<DataTypePointer>>;

                return ThreadwiseTensorSliceTransfer_v2<DataType,
                                                        DataType,
                                                        decltype(in_grid_1d_desc_tuple[I]),
                                                        decltype(thread_buffer_desc_m),
                                                        Sequence<MPerThread>,        // SliceLengths
                                                        Sequence<0>,                 // DimAccessOrder
                                                        0,                           // SrcVectorDim
                                                        InScalarPerVectorSeq::At(I), // ScalarPerVector
                                                        1,     // SrcScalarStrideInVector
                                                        false>{in_grid_1d_desc_tuple[I],
                                                               thread_global_offset};
            },
            Number<NumInput>{});
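
        // InScalarPerVectorSeq::At(I) elements are coalesced into a single
        // vector load along dim 0, so MPerThread must be divisible by each
        // input's ScalarPerVector (and likewise for the stores below).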

        auto out_global_store_tuple = generate_tuple(
            [&](auto I) {
                using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
                using DataType        = remove_pointer_t<DataTypePointer>;

                return ThreadwiseTensorSliceTransfer_v1r3<DataType,
                                                          DataType,
                                                          decltype(thread_buffer_desc_m),
                                                          decltype(out_grid_1d_desc_tuple[I]),
                                                          PassThroughOp,
                                                          Sequence<MPerThread>,         // SliceLengths
                                                          Sequence<0>,                  // DimAccessOrder
                                                          0,                            // DstVectorDim
                                                          OutScalarPerVectorSeq::At(I), // ScalarPerVector
                                                          InMemoryDataOperationEnum::Set,
                                                          1, // DstScalarStrideInVector
                                                          false>(
                    out_grid_1d_desc_tuple[I], thread_global_offset, PassThroughOp{});
            },
            Number<NumOutput>{});

        index_t num_iter = M / loop_step;
        do
        {
            // load MPerThread elements from each input into registers and
            // advance each source window by the grid stride
            static_for<0, NumInput, 1>{}([&](auto I) {
                in_global_load_tuple(I).Run(in_grid_1d_desc_tuple[I],
                                            in_global_buf_tuple[I],
                                            thread_buffer_desc_m,
                                            make_tuple(I0),
                                            in_thread_buf_tuple(I));

                in_global_load_tuple(I).MoveSrcSliceWindow(in_grid_1d_desc_tuple[I],
                                                           loop_step_index);
            });

            // per element: unary_op and scale_op update the input registers in
            // place, then elementwise_op maps them to the output registers
            static_for<0, MPerThread, 1>{}([&](auto iM) {
                // get reference to in data
                auto uop_data_refs = generate_tie(
                    // return type should be lvalue
                    [&](auto I) -> auto& { return in_thread_buf_tuple(I)(iM); },
                    Number<NumInput>{});

                // get reference to dst data
                auto out_data_refs = generate_tie(
                    // return type should be lvalue
                    [&](auto I) -> auto& { return out_thread_buf_tuple(I)(iM); },
                    Number<NumOutput>{});

                unpack2(unary_op, uop_data_refs, uop_data_refs);

                auto sop_in_data_refs = generate_tie(
                    // return type should be lvalue
                    [&](auto I) -> auto& { return in_thread_buf_tuple(I)(iM); },
                    Number<NumInput>{});

                auto sop_out_data_refs = generate_tie(
                    // return type should be lvalue
                    [&](auto I) -> auto& { return in_thread_buf_tuple(I)(iM); },
                    Number<NumInput>{});

                unpack2(scale_op, sop_out_data_refs, sop_in_data_refs);

                const auto in_data_refs = generate_tie(
                    // return type should be lvalue
                    [&](auto I) -> const auto& { return in_thread_buf_tuple(I)(iM); },
                    Number<NumInput>{});

                unpack2(elementwise_op, out_data_refs, in_data_refs);
            });

            // store the results and advance each destination window
            static_for<0, NumOutput, 1>{}([&](auto I) {
                out_global_store_tuple(I).Run(thread_buffer_desc_m,
                                              make_tuple(I0),
                                              out_thread_buf_tuple[I],
                                              out_grid_1d_desc_tuple[I],
                                              out_global_buf_tuple(I));

                out_global_store_tuple(I).MoveDstSliceWindow(out_grid_1d_desc_tuple[I],
                                                             loop_step_index);
            });
        } while(--num_iter);
    }
};

} // namespace ck
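
Each element follows the same path through Run: it is loaded into registers, unary_op and then scale_op are applied to it in place, and elementwise_op writes the result into the output registers. A minimal scalar model of that per-element dataflow, using hypothetical stand-in functors (Square, ScaleBy, PassThrough) rather than the real CK element-wise operations:

// Hypothetical stand-ins for ck::tensor_operation::element_wise functors.
struct Square      { void operator()(float& y, const float& x) const { y = x * x; } };
struct ScaleBy     { float alpha; void operator()(float& y, const float& x) const { y = alpha * x; } };
struct PassThrough { void operator()(float& y, const float& x) const { y = x; } };

// Scalar model of one element through GridwiseElementwise_1D::Run,
// i.e. out = elementwise_op(scale_op(unary_op(in))).
inline float model_one_element(float in, float alpha)
{
    float reg = in;           // global -> VGPR load (ThreadwiseTensorSliceTransfer_v2)
    Square{}(reg, reg);       // unary_op, in place
    ScaleBy{alpha}(reg, reg); // scale_op, in place
    float out;
    PassThrough{}(out, reg);  // elementwise_op into the output thread buffer
    return out;               // VGPR -> global store (ThreadwiseTensorSliceTransfer_v1r3)
}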