13 template <
typename GridwiseElementwise1dFunctor,
14 typename InGrid1dDescTuple,
15 typename OutGrid1dDescTuple,
16 typename InDataTypePointerTuple,
17 typename OutDataTypePointerTuple,
18 typename ElementwiseOperation,
19 typename UnaryOperation,
22 const OutGrid1dDescTuple out_grid_1d_desc_tuple,
23 const InDataTypePointerTuple p_in_global_tuple,
24 const OutDataTypePointerTuple p_out_global_tuple,
25 const ElementwiseOperation elementwise_op,
26 const UnaryOperation unary_op,
29 GridwiseElementwise1dFunctor::Run(in_grid_1d_desc_tuple,
30 out_grid_1d_desc_tuple,
38 template <
typename InGrid1dDescTuple,
39 typename OutGrid1dDescTuple,
40 typename InDataTypePointerTuple,
41 typename OutDataTypePointerTuple,
42 typename ElementwiseOperation,
43 typename UnaryOperation,
46 typename InScalarPerVectorSeq,
47 typename OutScalarPerVectorSeq>
53 static_assert(
NumInput == InScalarPerVectorSeq::Size() &&
54 NumOutput == OutScalarPerVectorSeq::Size() &&
55 NumInput == InGrid1dDescTuple::Size() &&
57 "Tuple size is inconsistent with the number of in/out!");
66 __device__
static void Run(
const InGrid1dDescTuple in_grid_1d_desc_tuple,
67 const OutGrid1dDescTuple out_grid_1d_desc_tuple,
68 const InDataTypePointerTuple p_in_global_tuple,
69 const OutDataTypePointerTuple p_out_global_tuple,
70 const ElementwiseOperation elementwise_op,
71 const UnaryOperation unary_op,
78 using DataTypePointer =
remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
87 using DataTypePointer =
remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
96 static_assert(in_grid_1d_desc_tuple[I].GetNumOfDimension() == 1);
98 return make_dynamic_buffer<AddressSpaceEnum::Global>(
99 p_in_global_tuple[I], in_grid_1d_desc_tuple[I].GetElementSpaceSize());
105 static_assert(out_grid_1d_desc_tuple[I].GetNumOfDimension() == 1);
107 return make_dynamic_buffer<AddressSpaceEnum::Global>(
108 p_out_global_tuple[I], out_grid_1d_desc_tuple[I].GetElementSpaceSize());
112 const auto thread_global_offset =
make_multi_index(thread_global_id * MPerThread);
116 const auto M = in_grid_1d_desc_tuple[
I0].GetLength(
I0);
117 const index_t loop_step = blockPerGrid * blockSize * MPerThread;
122 using DataTypePointer =
remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
127 decltype(in_grid_1d_desc_tuple[I]),
132 InScalarPerVectorSeq::At(
135 false>{in_grid_1d_desc_tuple[I],
136 thread_global_offset};
142 using DataTypePointer =
remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
148 decltype(out_grid_1d_desc_tuple[I]),
153 OutScalarPerVectorSeq::At(I),
157 out_grid_1d_desc_tuple[I], thread_global_offset,
PassThroughOp{});
161 index_t num_iter = M / (loop_step);
165 in_global_load_tuple(I).Run(in_grid_1d_desc_tuple[I],
166 in_global_buf_tuple[I],
169 in_thread_buf_tuple(I));
171 in_global_load_tuple(I).MoveSrcSliceWindow(in_grid_1d_desc_tuple[I],
179 [&](
auto I) ->
auto& {
return in_thread_buf_tuple(I)(iM); },
185 [&](
auto I) ->
auto& {
return out_thread_buf_tuple(I)(iM); },
188 unpack2(unary_op, uop_data_refs, uop_data_refs);
192 [&](
auto I) ->
auto& {
return in_thread_buf_tuple(I)(iM); },
197 [&](
auto I) ->
auto& {
return in_thread_buf_tuple(I)(iM); },
200 unpack2(scale_op, sop_out_data_refs, sop_in_data_refs);
204 [&](
auto I) ->
const auto& {
return in_thread_buf_tuple(I)(iM); },
207 unpack2(elementwise_op, out_data_refs, in_data_refs);
213 out_thread_buf_tuple[I],
214 out_grid_1d_desc_tuple[I],
215 out_global_buf_tuple(I));
217 out_global_store_tuple(I).MoveDstSliceWindow(out_grid_1d_desc_tuple[I],
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto unpack2(F &&f, X &&x, Y &&y)
Definition: functional4.hpp:55
__device__ index_t get_grid_size()
Definition: get_id.hpp:60
__host__ constexpr __device__ auto generate_tie(F &&f, Number< N >)
Definition: tuple_helper.hpp:34
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
typename remove_pointer< T >::type remove_pointer_t
Definition: type.hpp:300
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__device__ index_t get_block_size()
Definition: get_id.hpp:62
__global__ void kernel_elementwise_1d(const InGrid1dDescTuple in_grid_1d_desc_tuple, const OutGrid1dDescTuple out_grid_1d_desc_tuple, const InDataTypePointerTuple p_in_global_tuple, const OutDataTypePointerTuple p_out_global_tuple, const ElementwiseOperation elementwise_op, const UnaryOperation unary_op, const Scale scale_op)
Definition: gridwise_elementwise_1d_scale.hpp:21
__device__ index_t get_thread_global_1d_id()
Definition: get_id.hpp:54
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
int32_t index_t
Definition: ck.hpp:298
typename remove_cv< T >::type remove_cv_t
Definition: type.hpp:295
Definition: gridwise_elementwise_1d_scale.hpp:49
tensor_operation::element_wise::PassThrough PassThroughOp
Definition: gridwise_elementwise_1d_scale.hpp:64
static constexpr index_t NumOutput
Definition: gridwise_elementwise_1d_scale.hpp:51
static __device__ void Run(const InGrid1dDescTuple in_grid_1d_desc_tuple, const OutGrid1dDescTuple out_grid_1d_desc_tuple, const InDataTypePointerTuple p_in_global_tuple, const OutDataTypePointerTuple p_out_global_tuple, const ElementwiseOperation elementwise_op, const UnaryOperation unary_op, const Scale scale_op)
Definition: gridwise_elementwise_1d_scale.hpp:66
static constexpr auto thread_buffer_desc_m
Definition: gridwise_elementwise_1d_scale.hpp:61
static constexpr auto I0
Definition: gridwise_elementwise_1d_scale.hpp:59
static constexpr index_t NumInput
Definition: gridwise_elementwise_1d_scale.hpp:50
Definition: sequence.hpp:43
Definition: static_buffer.hpp:16
Definition: threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition: threadwise_tensor_slice_transfer.hpp:234
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33
Definition: unary_element_wise_operation.hpp:334