// NOTE(review): this chunk is a garbled doc-extraction of
// gridwise_elementwise_2d.hpp — the original lines 25, 27-29 and 35 (the
// `__global__ void kernel_elementwise(` signature opener, launch-bounds
// attribute, and the remainder of the Run(...) argument list) are missing
// from this view. Comments only; code left byte-identical.
//
// Grid kernel wrapper for a single element-wise operation: forwards the
// input/output grid-descriptor tuples, the global pointer tuples, the
// block-to-tile map and the element-wise functor to
// GridwiseElementwiseFunctor::Run (one tile of work per workgroup —
// presumably; confirm against the full header).
18 template <
typename GridwiseElementwiseFunctor,
19 typename InGridDescTuple,
20 typename OutGridDescTuple,
21 typename InDataTypePointerTuple,
22 typename OutDataTypePointerTuple,
23 typename Block2TileMap,
24 typename ElementwiseOperation>
// Launch-bounds attribute (CK_MAX_THREAD_PER_BLOCK / CK_MIN_BLOCK_PER_CU)
// would follow here when CK_USE_LAUNCH_BOUNDS is enabled; the #endif and
// the kernel signature line are lost in the extraction.
26 #if CK_USE_LAUNCH_BOUNDS
// Kernel parameters are taken by value: device kernels receive copies of
// the (trivially copyable) descriptor/pointer tuples.
30 const OutGridDescTuple out_grid_desc_tuple,
31 const InDataTypePointerTuple p_in_global_tuple,
32 const OutDataTypePointerTuple p_out_global_tuple,
33 const Block2TileMap block_2_tile_map,
34 const ElementwiseOperation elementwise_op)
// Delegates all work to the gridwise functor; remaining arguments of this
// call (out descriptors, pointers, tile map, op) are on the missing lines.
36 GridwiseElementwiseFunctor::Run(in_grid_desc_tuple,
// NOTE(review): garbled doc-extraction fragment — the kernel signature
// opener (`__global__ void kernel_elementwise_dual(`, original line 61 per
// the Doxygen index), the launch-bounds lines, the `a_grid_size` parameter
// line, the block-id branch, and the tails of both Run(...) calls are
// missing from this view. Comments only; code left byte-identical.
//
// Fused "dual" kernel: a single grid launch services two independent
// element-wise problems (A and B). Per the Doxygen signature below, blocks
// are split by `a_grid_size`: the A functor runs on one partition and the
// B functor on the other — TODO confirm the exact branch condition against
// the full header.
44 template <
typename GridwiseElementwiseFunctorA,
45 typename GridwiseElementwiseFunctorB,
46 typename InAGridDescTuple,
47 typename InBGridDescTuple,
48 typename OutAGridDescTuple,
49 typename OutBGridDescTuple,
50 typename InADataTypePointerTuple,
51 typename InBDataTypePointerTuple,
52 typename OutADataTypePointerTuple,
53 typename OutBDataTypePointerTuple,
54 typename Block2TileMapA,
55 typename Block2TileMapB,
56 typename ElementwiseOperation>
// Optional launch-bounds attribute (lost lines follow, as in the kernel
// above).
58 #if CK_USE_LAUNCH_BOUNDS
// Kernel parameters: separate descriptor/pointer tuples and tile maps for
// the A and B problems; one shared element-wise operation.
62 const InBGridDescTuple in_grid_desc_tuple_b,
63 const OutAGridDescTuple out_grid_desc_tuple_a,
64 const OutBGridDescTuple out_grid_desc_tuple_b,
65 const InADataTypePointerTuple p_in_global_tuple_a,
66 const InBGridDescTuple p_in_global_tuple_b,
67 const OutADataTypePointerTuple p_out_global_tuple_a,
68 const OutBDataTypePointerTuple p_out_global_tuple_b,
69 const Block2TileMapA block_2_tile_map_a,
70 const Block2TileMapB block_2_tile_map_b,
71 const ElementwiseOperation elementwise_op,
// Problem A dispatch (argument tail on missing lines 78-84).
76 GridwiseElementwiseFunctorA::Run(in_grid_desc_tuple_a,
77 out_grid_desc_tuple_a,
// Problem B dispatch (argument tail on missing lines 88-94).
86 GridwiseElementwiseFunctorB::Run(in_grid_desc_tuple_b,
87 out_grid_desc_tuple_b,
// NOTE(review): garbled doc-extraction fragment — missing from this view:
// the NumInputsA/NumInputsB/NumOutputsA/NumOutputsB non-type template
// parameter lines (original 109-112), the kernel signature opener
// (`kernel_elementwise_batched_dual(`, line 117 per the Doxygen index),
// the `a_grid_size`/`batch_count_a`/`batch_count_b` parameter lines,
// `block_id` initialization, the lambda closers `});`, the `else` branch
// opener, and the closing braces. Comments only; code left byte-identical.
//
// Batched variant of the dual kernel: in addition to splitting the grid
// between problems A and B at `a_grid_size`, each partition is further
// divided into `batch_count_{a,b}` batches; every batch applies the
// per-tensor batch stride to the raw global pointers before dispatching
// the gridwise functor.
96 template <
typename GridwiseElementwiseFunctorA,
97 typename GridwiseElementwiseFunctorB,
98 typename InAGridDescTuple,
99 typename InBGridDescTuple,
100 typename OutAGridDescTuple,
101 typename OutBGridDescTuple,
102 typename InADataTypePointerTuple,
103 typename InBDataTypePointerTuple,
104 typename OutADataTypePointerTuple,
105 typename OutBDataTypePointerTuple,
106 typename Block2TileMapA,
107 typename Block2TileMapB,
108 typename ElementwiseOperation,
114 #if CK_USE_LAUNCH_BOUNDS
118 const InBGridDescTuple in_grid_desc_tuple_b,
119 const OutAGridDescTuple out_grid_desc_tuple_a,
120 const OutBGridDescTuple out_grid_desc_tuple_b,
121 const InADataTypePointerTuple p_in_global_tuple_a,
122 const InBDataTypePointerTuple p_in_global_tuple_b,
123 const OutADataTypePointerTuple p_out_global_tuple_a,
124 const OutBDataTypePointerTuple p_out_global_tuple_b,
125 const Block2TileMapA block_2_tile_map_a,
126 const Block2TileMapB block_2_tile_map_b,
127 const ElementwiseOperation elementwise_op,
// Per-batch offsets (in elements) for each input/output tensor of the two
// problems; index_t strides are widened to long_index_t before the
// multiply below to avoid 32-bit overflow on large batches.
131 const std::array<index_t, NumInputsA> input_batch_strides_a,
132 const std::array<index_t, NumInputsB> input_batch_strides_b,
133 const std::array<index_t, NumOutputsA> output_batch_strides_a,
134 const std::array<index_t, NumOutputsB> output_batch_strides_b)
// Compile-time guard: tuple sizes must agree with the declared tensor
// counts for both problems.
136 static_assert(InAGridDescTuple::Size() == NumInputsA &&
137 InADataTypePointerTuple::Size() == NumInputsA);
138 static_assert(OutAGridDescTuple::Size() == NumOutputsA &&
139 OutADataTypePointerTuple::Size() == NumOutputsA);
140 static_assert(InBGridDescTuple::Size() == NumInputsB &&
141 InBDataTypePointerTuple::Size() == NumInputsB);
142 static_assert(OutBGridDescTuple::Size() == NumOutputsB &&
143 OutBDataTypePointerTuple::Size() == NumOutputsB);
// Blocks [0, a_grid_size) service problem A; the rest service problem B.
147 if(block_id < a_grid_size)
// readfirstlane broadcasts the (wave-uniform) value from lane 0 so the
// compiler can keep it in a scalar register.
149 const index_t num_blocks_per_batch =
150 __builtin_amdgcn_readfirstlane(a_grid_size / batch_count_a);
151 const index_t g_idx = __builtin_amdgcn_readfirstlane(block_id / num_blocks_per_batch);
// Build pointer tuples shifted to this block's batch `g_idx`.
153 InADataTypePointerTuple p_in_global_with_offset_tuple;
154 OutADataTypePointerTuple p_out_global_with_offset_tuple;
156 static_for<0, InADataTypePointerTuple::Size(), 1>{}([&](
auto i) {
157 p_in_global_with_offset_tuple(i) =
158 p_in_global_tuple_a.At(i) +
159 type_convert<long_index_t>(input_batch_strides_a[i]) * g_idx;
162 static_for<0, OutADataTypePointerTuple::Size(), 1>{}([&](
auto i) {
163 p_out_global_with_offset_tuple(i) =
164 p_out_global_tuple_a.At(i) +
165 type_convert<long_index_t>(output_batch_strides_a[i]) * g_idx;
// Dispatch problem A on the batch-offset pointers (argument tail on
// missing lines).
168 GridwiseElementwiseFunctorA::Run(in_grid_desc_tuple_a,
169 out_grid_desc_tuple_a,
170 p_in_global_with_offset_tuple,
171 p_out_global_with_offset_tuple,
// else-branch (problem B): same scheme, with block ids re-based by
// subtracting a_grid_size and the remaining grid split by batch_count_b.
178 const index_t num_blocks_per_batch =
179 __builtin_amdgcn_readfirstlane((
get_grid_size() - a_grid_size) / batch_count_b);
181 __builtin_amdgcn_readfirstlane((block_id - a_grid_size) / num_blocks_per_batch);
183 InBDataTypePointerTuple p_in_global_with_offset_tuple;
184 OutBDataTypePointerTuple p_out_global_with_offset_tuple;
186 static_for<0, InBDataTypePointerTuple::Size(), 1>{}([&](
auto i) {
187 p_in_global_with_offset_tuple(i) =
188 p_in_global_tuple_b.At(i) +
189 type_convert<long_index_t>(input_batch_strides_b[i]) * g_idx;
192 static_for<0, OutBDataTypePointerTuple::Size(), 1>{}([&](
auto i) {
193 p_out_global_with_offset_tuple(i) =
194 p_out_global_tuple_b.At(i) +
195 type_convert<long_index_t>(output_batch_strides_b[i]) * g_idx;
198 GridwiseElementwiseFunctorB::Run(in_grid_desc_tuple_b,
199 out_grid_desc_tuple_b,
200 p_in_global_with_offset_tuple,
201 p_out_global_with_offset_tuple,
// The B functor receives the re-based block id so its tile map indexes
// from zero.
204 block_id - a_grid_size);
// NOTE(review): garbled doc-extraction fragment — missing from this view:
// the NumInputs/NumOutputs non-type template parameter lines, the kernel
// signature opener (`kernel_batched_elementwise(`, original line 221 per
// the Doxygen index), the `batch_count` parameter line, the `g_idx`
// computation, the lambda closers `});`, and the tail of the final
// Run(...) call. Comments only; code left byte-identical.
//
// Batched single-problem kernel: splits the grid into `batch_count`
// batches, offsets each global pointer by its per-tensor batch stride for
// the block's batch index, then dispatches GridwiseElementwiseFunctor::Run.
208 template <
typename GridwiseElementwiseFunctor,
209 typename InGridDescTuple,
210 typename OutGridDescTuple,
211 typename InDataTypePointerTuple,
212 typename OutDataTypePointerTuple,
213 typename Block2TileMap,
214 typename ElementwiseOperation,
218 #if CK_USE_LAUNCH_BOUNDS
222 const OutGridDescTuple out_grid_desc_tuple,
223 const InDataTypePointerTuple p_in_global_tuple,
224 const OutDataTypePointerTuple p_out_global_tuple,
225 const Block2TileMap block_2_tile_map,
226 const ElementwiseOperation elementwise_op,
// Per-batch element offsets for each input/output tensor; widened to
// long_index_t below before multiplying by the batch index.
228 const std::array<index_t, NumInputs> input_batch_strides,
229 const std::array<index_t, NumOutputs> output_batch_strides)
// Compile-time guard: tuple sizes must agree with the declared counts.
231 static_assert(InGridDescTuple::Size() == NumInputs &&
232 InDataTypePointerTuple::Size() == NumInputs);
233 static_assert(OutGridDescTuple::Size() == NumOutputs &&
234 OutDataTypePointerTuple::Size() == NumOutputs);
// Wave-uniform blocks-per-batch; the g_idx derivation line is lost in the
// extraction (presumably block_id / num_blocks_per_batch, as in the dual
// batched kernel above — confirm against the full header).
236 const index_t num_blocks_per_batch =
237 __builtin_amdgcn_readfirstlane(
get_grid_size() / batch_count);
240 InDataTypePointerTuple p_in_global_with_offset_tuple;
241 OutDataTypePointerTuple p_out_global_with_offset_tuple;
243 static_for<0, InDataTypePointerTuple::Size(), 1>{}([&](
auto i) {
244 p_in_global_with_offset_tuple(i) =
245 p_in_global_tuple.At(i) + type_convert<long_index_t>(input_batch_strides[i]) * g_idx;
248 static_for<0, OutDataTypePointerTuple::Size(), 1>{}([&](
auto i) {
249 p_out_global_with_offset_tuple(i) =
250 p_out_global_tuple.At(i) + type_convert<long_index_t>(output_batch_strides[i]) * g_idx;
// Dispatch on the batch-offset pointer tuples (argument tail on missing
// lines).
253 GridwiseElementwiseFunctor::Run(in_grid_desc_tuple,
255 p_in_global_with_offset_tuple,
256 p_out_global_with_offset_tuple,
// NOTE(review): garbled doc-extraction fragment of the GridwiseElementwise
// struct (definition at gridwise_elementwise_2d.hpp:278 per the Doxygen
// index). Missing from this view: the struct header, the block/tile-size
// non-type template parameters (M0PerBlock, M1PerBlock, SrcVectorDim,
// DstVectorDim, NumInput, NumOutput, ...), the generate_tuple wrappers
// around the buffer/offset builders, the ThreadGroupTensorSliceTransfer
// instantiation tail, and the closing braces. Comments only; code left
// byte-identical.
//
// Block-level 2D element-wise transfer: each workgroup resolves its
// (M0, M1) tile via the block-to-tile map, wraps the raw global pointers
// in dynamic buffers, and performs a tiled global->global copy through a
// thread-group tensor-slice transfer while applying `elementwise_op`.
261 template <
typename InGridDescTuple,
262 typename OutGridDescTuple,
263 typename InDataTypePointerTuple,
264 typename OutDataTypePointerTuple,
265 typename Block2TileMap,
266 typename ElementwiseOperation,
272 typename ThreadClusterArrangeOrder,
273 typename InScalarPerVectorSeq,
274 typename OutScalarPerVectorSeq,
// Compile-time consistency checks between tensor counts and the per-tensor
// vector-width sequences.
282 static_assert(
NumInput == InScalarPerVectorSeq::Size() &&
283 NumOutput == OutScalarPerVectorSeq::Size() &&
285 "Tuple size is inconsistent with the number of in/out!");
290 static_assert((SrcVectorDim ==
I0 || SrcVectorDim ==
I1) &&
291 (DstVectorDim ==
I0 || DstVectorDim ==
I1),
292 "Vector dim must be equal to 0 or 1.");
// Entry point invoked by the kernel wrappers above; block_id defaults to
// get_block_1d_id() per the Doxygen signature.
296 __device__
static void Run(
const InGridDescTuple& in_grid_desc_tuple,
297 const OutGridDescTuple& out_grid_desc_tuple,
298 const InDataTypePointerTuple& p_in_global_tuple,
299 const OutDataTypePointerTuple& p_out_global_tuple,
300 const Block2TileMap& block_2_tile_map,
301 const ElementwiseOperation& elementwise_op,
// Per-tensor pointee types extracted from the pointer tuples (inside
// generate_tuple lambdas whose wrappers are lost in the extraction).
307 using DataTypePointer =
remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
316 using DataTypePointer =
remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
// Wrap each raw global pointer in a dynamic buffer sized from its grid
// descriptor.
325 return make_dynamic_buffer<AddressSpaceEnum::Global>(
326 p_in_global_tuple[I], in_grid_desc_tuple[I].GetElementSpaceSize());
332 return make_dynamic_buffer<AddressSpaceEnum::Global>(
333 p_out_global_tuple[I], out_grid_desc_tuple[I].GetElementSpaceSize());
// Map this block to its 2D tile and derive wave-uniform grid coordinates
// of the tile origin.
337 const auto block_work_idx =
340 const index_t m0_block_data_idx_on_grid =
341 __builtin_amdgcn_readfirstlane(block_work_idx[
I0] * M0PerBlock);
342 const index_t m1_block_data_idx_on_grid =
343 __builtin_amdgcn_readfirstlane(block_work_idx[
I1] * M1PerBlock);
// Same tile-origin index is used for both the input and output side.
346 return make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid);
351 return make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid);
// Access order follows the vector dimension so vectorized loads/stores run
// along the contiguous axis.
361 using SrcDimAccessOrder =
std::conditional_t<SrcVectorDim == I1, Sequence<0, 1>,
Sequence<1, 0>>;
363 using DstDimAccessOrder =
std::conditional_t<DstVectorDim == I1, Sequence<0, 1>,
Sequence<1, 0>>;
// Thread-group transfer configuration (instantiation tail on missing
// lines).
366 using ThreadClusterLengths =
371 ElementwiseOperation,
374 ThreadClusterLengths,
375 ThreadClusterArrangeOrder,
384 InScalarPerVectorSeq,
385 OutScalarPerVectorSeq,
390 input_thread_grid_offset,
392 output_thread_grid_offset,
// Single fused global-to-global pass over the tile.
394 global_to_global_transfer.Run(
395 in_grid_desc_tuple, in_global_buf_tuple, out_grid_desc_tuple, out_global_buf_tuple,
I0);
#define CK_MIN_BLOCK_PER_CU
Definition: ck.hpp:30
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:29
__global__ void kernel_batched_elementwise(const InGridDescTuple in_grid_desc_tuple, const OutGridDescTuple out_grid_desc_tuple, const InDataTypePointerTuple p_in_global_tuple, const OutDataTypePointerTuple p_out_global_tuple, const Block2TileMap block_2_tile_map, const ElementwiseOperation elementwise_op, const index_t batch_count, const std::array< index_t, NumInputs > input_batch_strides, const std::array< index_t, NumOutputs > output_batch_strides)
Definition: gridwise_elementwise_2d.hpp:221
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__device__ index_t get_grid_size()
Definition: get_id.hpp:60
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition: sequence.hpp:928
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
typename remove_pointer< T >::type remove_pointer_t
Definition: type.hpp:300
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:58
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
int32_t index_t
Definition: ck.hpp:298
__global__ void kernel_elementwise_dual(const InAGridDescTuple in_grid_desc_tuple_a, const InBGridDescTuple in_grid_desc_tuple_b, const OutAGridDescTuple out_grid_desc_tuple_a, const OutBGridDescTuple out_grid_desc_tuple_b, const InADataTypePointerTuple p_in_global_tuple_a, const InBDataTypePointerTuple p_in_global_tuple_b, const OutADataTypePointerTuple p_out_global_tuple_a, const OutBDataTypePointerTuple p_out_global_tuple_b, const Block2TileMapA block_2_tile_map_a, const Block2TileMapB block_2_tile_map_b, const ElementwiseOperation elementwise_op, const index_t a_grid_size)
Definition: gridwise_elementwise_2d.hpp:61
__global__ void kernel_elementwise_batched_dual(const InAGridDescTuple in_grid_desc_tuple_a, const InBGridDescTuple in_grid_desc_tuple_b, const OutAGridDescTuple out_grid_desc_tuple_a, const OutBGridDescTuple out_grid_desc_tuple_b, const InADataTypePointerTuple p_in_global_tuple_a, const InBDataTypePointerTuple p_in_global_tuple_b, const OutADataTypePointerTuple p_out_global_tuple_a, const OutBDataTypePointerTuple p_out_global_tuple_b, const Block2TileMapA block_2_tile_map_a, const Block2TileMapB block_2_tile_map_b, const ElementwiseOperation elementwise_op, const index_t a_grid_size, const index_t batch_count_a, const index_t batch_count_b, const std::array< index_t, NumInputsA > input_batch_strides_a, const std::array< index_t, NumInputsB > input_batch_strides_b, const std::array< index_t, NumOutputsA > output_batch_strides_a, const std::array< index_t, NumOutputsB > output_batch_strides_b)
Definition: gridwise_elementwise_2d.hpp:117
typename remove_cv< T >::type remove_cv_t
Definition: type.hpp:295
__global__ void kernel_elementwise(const InGridDescTuple in_grid_desc_tuple, const OutGridDescTuple out_grid_desc_tuple, const InDataTypePointerTuple p_in_global_tuple, const OutDataTypePointerTuple p_out_global_tuple, const Block2TileMap block_2_tile_map, const ElementwiseOperation elementwise_op)
Definition: gridwise_elementwise_2d.hpp:29
Definition: gridwise_elementwise_2d.hpp:278
static constexpr index_t NumInput
Definition: gridwise_elementwise_2d.hpp:279
static constexpr auto I1
Definition: gridwise_elementwise_2d.hpp:288
static __device__ void Run(const InGridDescTuple &in_grid_desc_tuple, const OutGridDescTuple &out_grid_desc_tuple, const InDataTypePointerTuple &p_in_global_tuple, const OutDataTypePointerTuple &p_out_global_tuple, const Block2TileMap &block_2_tile_map, const ElementwiseOperation &elementwise_op, const index_t block_id=get_block_1d_id())
Definition: gridwise_elementwise_2d.hpp:296
static constexpr auto I0
Definition: gridwise_elementwise_2d.hpp:287
static constexpr index_t NumOutput
Definition: gridwise_elementwise_2d.hpp:280
Definition: sequence.hpp:43
Definition: thread_group.hpp:12
Blockwise data transfer.
Definition: thread_group_tensor_slice_transfer_v4r2.hpp:45
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33
Definition: unary_element_wise_operation.hpp:334