/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise_multi_d.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise_multi_d.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise_multi_d.hpp Source File
device_reduce_threadwise_multi_d.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 #include <array>
9 
14 
16 
17 namespace ck {
18 namespace tensor_operation {
19 namespace device {
20 
21 template <typename InDataType,
22  typename DsDataType,
23  typename AccDataType,
24  typename OutDataType,
25  index_t Rank,
26  index_t NumReduceDim,
27  typename ReduceOperation,
28  typename InElementwiseOperation,
29  typename OutElementwiseOperation,
30  index_t BlockSize,
31  index_t MThreadSliceSize,
32  index_t KThreadSliceSize,
33  index_t InSrcVectorDim,
34  index_t InSrcVectorSize,
35  index_t OutDstVectorSize,
36  typename DsVectorSizeSequence>
38  DsDataType,
39  AccDataType,
40  OutDataType,
41  Rank,
42  NumReduceDim,
43  ReduceOperation,
44  InElementwiseOperation,
45  OutElementwiseOperation>
46 
47 {
48  static_assert(Rank <= 12, "Bigger Rank size is not supported!");
49 
    // The vectorized input side must hold a whole number of vectors per thread:
    // InSrcVectorDim 0 is checked against MThreadSliceSize, 1 against KThreadSliceSize.
    // The M slice must additionally be a multiple of OutDstVectorSize.
50  static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
51  (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
52  (MThreadSliceSize % OutDstVectorSize == 0),
53  "Invalid thread slice sizes and/or vector sizes configuration, please check!");
54 
    // NOTE: listing line 55 is not reproduced here; the reference index at the end of this
    // page records it as the IndexDataType (int32_t) alias.
56 
    // Number of dimensions that are not reduced.
57  static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
58 
    // Number of entries in DsDataType.
59  static constexpr index_t NumDTensor = DsDataType::Size();
60 
61  static constexpr index_t NumSrcDim = Rank;
    // The destination keeps at least one dimension, even when every input dimension is
    // reduced (reduceAllDim).
62  static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
63  static constexpr bool reduceAllDim = (NumInvariantDim == 0);
64 
    // Tile extents used for padding: the M tile scales with BlockSize, while the K tile is
    // exactly one KThreadSliceSize wide (factor 1).
65  static constexpr index_t M_BlockTileSize = BlockSize * MThreadSliceSize;
66  static constexpr index_t K_BlockTileSize = 1 * KThreadSliceSize;
67 
68  static auto MakeSrc2dDescriptor(const std::array<index_t, Rank>& inLengths,
69  const std::array<index_t, Rank>& inStrides)
70  {
    // Builds the 2-D (M, K) view of the Rank-dimensional input and right-pads both
    // extents up to the block tile sizes. Dimension 0 of the result is treated as the
    // invariant length and dimension 1 as the reduce length (see the GetLength calls below).
    // NOTE: several lines of this function are missing from this listing (gaps in the
    // gutter numbers); comments describe only what is visible.

    // Repack the run-time arrays as tuples indexed by compile-time Number<>.
71  const auto tupleSrcLengths =
72  generate_tuple([&](auto I) { return inLengths[I]; }, Number<Rank>{});
73  const auto tupleSrcStrides =
74  generate_tuple([&](auto I) { return inStrides[I]; }, Number<Rank>{});
75 
76  const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
77 
78  const auto in_grid_desc_m_k = [&]() {
79  if constexpr(reduceAllDim)
80  {
    // All dimensions are reduced: merge every source dimension into one, then apply a
    // second transform to obtain two dimensions. The visible fragment passes
    // (1, merged length) - presumably an unmerge into M = 1, K = total; TODO confirm
    // against the full source.
81  const auto one_dim_inDesc = transform_tensor_descriptor(
82  inDesc,
83  make_tuple(make_merge_transform(tupleSrcLengths)),
86 
87  return transform_tensor_descriptor(one_dim_inDesc,
89  1, one_dim_inDesc.GetLength(Number<0>{})))),
92  }
93  else
94  {
    // Partial reduction: the first NumInvariantDim entries of inLengths are merged into
    // one dimension and the remaining NumReduceDim entries into the other.
    // ReduceDims (used below) is declared on a line missing from this listing.
95  using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
97 
98  const auto reduceDimLengths = generate_tuple(
99  [&](auto I) { return inLengths[NumInvariantDim + I]; }, Number<NumReduceDim>{});
100  const auto invariantDimLengths =
101  generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInvariantDim>{});
102 
104  inDesc,
105  make_tuple(make_merge_transform(invariantDimLengths),
106  make_merge_transform(reduceDimLengths)),
107  make_tuple(InvariantDims{}, ReduceDims{}),
109  }
110  }();
111 
112  const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
113  const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
114 
    // Padding per dimension: difference between the length and the value returned by
    // integer_least_multiple for the matching tile size (presumably the smallest tile
    // multiple that is >= the length - TODO confirm in math.hpp).
115  const auto inPad_M =
116  math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
117  const auto inPad_K =
118  math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength;
119 
120  auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
121  in_grid_desc_m_k,
122  make_tuple(make_right_pad_transform(invariantLength, inPad_M),
123  make_right_pad_transform(reduceLength, inPad_K)),
126 
127  return (in_grid_desc_m_k_padded);
128  };
129 
130  static auto MakeDst1dDescriptor(const std::array<index_t, NumDstDim>& outLengths,
131  const std::array<index_t, NumDstDim>& outStrides)
132  {
     // Builds the 1-D (M) view of the NumDstDim-dimensional output: all destination
     // dimensions are merged into one, which is then right-padded using M_BlockTileSize.
     // NOTE: some argument lines are missing from this listing (gaps in the gutter numbers).

     // Repack the run-time arrays as tuples indexed by compile-time Number<>.
133  const auto tupleDstLengths =
134  generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumDstDim>{});
135  const auto tupleDstStrides =
136  generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumDstDim>{});
137 
138  auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
139 
140  auto out_grid_desc_m = transform_tensor_descriptor(
141  outDesc,
142  make_tuple(make_merge_transform(tupleDstLengths)),
145 
146  const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
147 
     // Same padding rule as the M dimension of the input descriptor.
148  const auto outPad =
149  math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
150 
151  auto out_grid_desc_m_padded = transform_tensor_descriptor(
152  out_grid_desc_m,
153  make_tuple(make_right_pad_transform(invariantLength, outPad)),
156  return (out_grid_desc_m_padded);
157  };
158 
159  static auto
160  MakeDsDescriptor(const std::array<std::array<index_t, NumDstDim>, NumDTensor> DsLengths,
161  std::array<std::array<index_t, NumDstDim>, NumDTensor> DsStrides)
162  {
     // Returns a tuple produced by generate_tuple from a per-index lambda.
     // NOTE: the lambda's call line (165) and the tuple-size argument (168) are missing from
     // this listing. Presumably each entry is MakeDst1dDescriptor(DsLengths[i], DsStrides[i])
     // for i in [0, NumDTensor) - TODO confirm against the full source.
     // Both nested arrays are taken by value, so each call copies them.
163  return generate_tuple(
164  [&](auto i) {
166  DsStrides[i]);
167  },
169  }
170 
171  using InGridDesc_M_K = decltype(MakeSrc2dDescriptor({}, {}));
172  using OutGridDesc_M = decltype(MakeDst1dDescriptor({}, {}));
173  using DsGridDesc_M = decltype(MakeDsDescriptor({}, {}));
     // The three aliases above name the descriptor types returned by the factory functions;
     // the calls sit inside decltype, so only their return types are used.
174 
     // Remaining template arguments of the GridwiseReduce alias. Its opening lines (175-176)
     // and a few argument lines are missing from this listing; the reference index at the end
     // of this page gives the full form, an instantiation of
     // GridwiseReduction_mk_to_m_threadwise_multi_d.
177  DsDataType,
178  OutDataType,
179  AccDataType,
181  DsGridDesc_M,
183  ReduceOperation,
184  InElementwiseOperation,
185  OutElementwiseOperation,
187  BlockSize,
188  MThreadSliceSize,
189  KThreadSliceSize,
190  InSrcVectorDim,
191  InSrcVectorSize,
192  OutDstVectorSize,
193  DsVectorSizeSequence>;
194 
196 
197  struct Argument : public BaseArgument
198  {
     // Host-side bundle of everything the invoker passes to the kernel launch: shapes,
     // strides, device pointers, element-wise operators and the launch grid size.
     // NOTE: several statements and member declarations are missing from this listing
     // (gaps in the gutter numbers); see the reference index at the end of this page.

     // All array parameters are taken by value and copied into the members below.
     // ds_dev carries one untyped device pointer per D tensor.
199  Argument(const std::array<index_t, Rank> inLengths,
200  const std::array<index_t, Rank> inStrides,
201  const std::array<std::array<index_t, NumDstDim>, NumDTensor> DsLengths,
202  const std::array<std::array<index_t, NumDstDim>, NumDTensor> DsStrides,
203  const std::array<index_t, NumDstDim> outLengths,
204  const std::array<index_t, NumDstDim> outStrides,
205  const std::array<int, NumReduceDim> reduceDims,
206  const InDataType* in_dev,
207  const std::array<const void*, NumDTensor> ds_dev,
208  OutDataType* out_dev,
209  const InElementwiseOperation in_elementwise_op,
210  const OutElementwiseOperation out_elementwise_op)
211  : DsLengths_{DsLengths},
212  DsStrides_{DsStrides},
213  outLengths_{outLengths},
214  outStrides_{outStrides},
215  in_dev_{in_dev},
216  out_dev_{out_dev},
217  in_elementwise_op_{in_elementwise_op},
218  out_elementwise_op_{out_elementwise_op}
219  {
     // Input lengths/strides are stored after shuffle_tensor_dimensions has been applied
     // with reduceDims. Presumably this moves the reduce dimensions to the end, since the
     // rest of the class reads them from the tail - TODO confirm in the helper.
220  inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
221  inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
222 
     // Line 223 (the left-hand side receiving get_2d_lengths' result) is missing here.
224  get_2d_lengths<Rank, NumReduceDim>(inLengths_);
225 
     // The bodies of both branches (lines 227 and 229) are missing from this listing.
226  if constexpr(NumInvariantDim == 0)
228  else
230 
231  reduce_lowest_length = inLengths_[Rank - 1];
232 
234 
237 
     // Give each untyped D pointer the element type found at the same index of DsDataType.
238  static_for<0, NumDTensor, 1>{}([&](auto i) {
239  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
240  p_ds_grid_(i) = static_cast<const DDataType*>(ds_dev[i]);
241  });
242 
243  ds_grid_desc_m_ = MakeDsDescriptor(DsLengths, DsStrides);
244  }
245 
     // Input shape and strides as returned by shuffle_tensor_dimensions (not the caller's order).
246  std::array<index_t, Rank> inLengths_;
247  std::array<index_t, Rank> inStrides_;
248 
249  std::array<std::array<index_t, NumDstDim>, NumDTensor> DsLengths_;
250  std::array<std::array<index_t, NumDstDim>, NumDTensor> DsStrides_;
251 
252  std::array<index_t, NumDstDim> outLengths_;
253  std::array<index_t, NumDstDim> outStrides_;
254 
255  const InDataType* in_dev_;
256  OutDataType* out_dev_;
257 
     // Missing from this listing (declared on the skipped lines, per the reference index):
     // p_ds_grid_, ds_grid_desc_m_, invariant_lowest_length, reduce_lowest_length,
     // invariant_total_length, reduce_total_length and numBlockTileIteration.
259 
260  InElementwiseOperation in_elementwise_op_;
261  OutElementwiseOperation out_elementwise_op_;
262 
264 
269 
     // Passed to the launch as dim3(arg.gridSize) by Invoker::Run.
271  size_t gridSize;
272  };
273 
274  struct Invoker : public BaseInvoker
275  {
     // Launches the reduction kernel for an Argument and returns the value reported by
     // launch_and_time_kernel.
     // NOTE: several lines of Run are missing from this listing (gaps in the gutter numbers).
276  float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
277  {
     // The initializers of both descriptors (lines 279 and 281) are missing from this listing.
278  const auto in_grid_desc_m_k =
280  const auto out_grid_desc_m =
282 
283  float avg_time = 0;
284 
     // Template arguments of the kernel selected for launch; the line that declares `kernel`
     // (285) is missing. The reference index lists kernel_reduce_threadwise_multi_d, which
     // is presumably what is instantiated here - TODO confirm.
286  InDataType,
287  OutDataType,
288  AccDataType,
290  DsGridDesc_M,
292  InElementwiseOperation,
293  OutElementwiseOperation,
294  DsGridPointer>;
295 
     // Grid of arg.gridSize blocks, BlockSize threads per block, 0 bytes for the lds_byte
     // parameter; the remaining arguments are forwarded to the kernel. Line 305 (between
     // the input element-wise op and the input pointer) is missing from this listing.
296  avg_time = launch_and_time_kernel(stream_config,
297  kernel,
298  dim3(arg.gridSize),
299  dim3(BlockSize),
300  0,
301  in_grid_desc_m_k,
302  arg.ds_grid_desc_m_,
303  out_grid_desc_m,
304  arg.in_elementwise_op_,
306  arg.in_dev_,
307  arg.p_ds_grid_,
308  arg.out_dev_);
309 
310  return (avg_time);
311  };
312 
     // Type-erased entry point: downcasts p_arg and forwards to the overload above.
     // NOTE(review): the dynamic_cast result is dereferenced unchecked, so a null pointer or
     // an argument of another type is undefined behaviour here.
313  float Run(const BaseArgument* p_arg,
314  const StreamConfig& stream_config = StreamConfig{}) override
315  {
316  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
317  };
318  };
319 
320  bool IsSupportedArgument(const BaseArgument* p_arg) override
321  {
322      // Returns true only when p_arg is an Argument of this instance whose layout fits
323      // the compile-time vector and slice configuration.
324      const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
325      // A null or foreign argument makes the cast yield nullptr: reject, never dereference.
326      if(pArg == nullptr)
327          return (false);
328 
329      if constexpr(InSrcVectorDim == 0)
330      {
331          // Vectorizing along the invariant side needs at least one invariant dimension.
332          if constexpr(NumInvariantDim == 0)
333          {
334              return (false);
335          }
336          else
337          {
338              // Innermost invariant dimension: unit stride, whole number of vectors.
339              if(pArg->inStrides_[NumInvariantDim - 1] != 1 ||
340                 pArg->invariant_lowest_length % InSrcVectorSize != 0)
341                  return (false);
342          }
343      }
344      else
345      {
346          // Innermost reduce dimension: unit stride, whole number of vectors.
347          if(pArg->inStrides_[Rank - 1] != 1 || pArg->reduce_lowest_length % InSrcVectorSize != 0)
348              return (false);
349      }
350 
351      // To improve
352      if(pArg->invariant_lowest_length % OutDstVectorSize != 0)
353          return (false);
354 
355      std::cerr << "reduce_total_length = " << pArg->reduce_total_length
356                << " KThreadSliceSize = " << KThreadSliceSize << '\n';
357 
358      // cases with big reduce_total_length should be handled by Blockwise kernel
359      return (pArg->reduce_total_length / KThreadSliceSize < 32);
360  }
361 
362  std::unique_ptr<BaseArgument>
363  MakeArgumentPointer(const std::array<index_t, Rank> inLengths,
364                      const std::array<index_t, Rank> inStrides,
365                      const std::array<std::array<index_t, NumDstDim>, NumDTensor> DsLengths,
366                      const std::array<std::array<index_t, NumDstDim>, NumDTensor> DsStrides,
367                      const std::array<index_t, NumDstDim> outLengths,
368                      const std::array<index_t, NumDstDim> outStrides,
369                      const std::array<int, NumReduceDim> reduceDims,
370                      const void* in_dev,
371                      const std::array<const void*, NumDTensor> ds_dev,
372                      void* out_dev,
373                      const InElementwiseOperation in_elementwise_op,
374                      const OutElementwiseOperation out_elementwise_op) override
375  {
376      // Builds the argument object for this instance: the untyped input and output
377      // buffers are given their element types, everything else is forwarded unchanged.
378      const auto* typed_in = static_cast<const InDataType*>(in_dev);
379      auto* typed_out      = static_cast<OutDataType*>(out_dev);
380 
381      return std::make_unique<Argument>(inLengths, inStrides, DsLengths, DsStrides,
382                                        outLengths, outStrides, reduceDims,
383                                        typed_in, ds_dev, typed_out,
384                                        in_elementwise_op, out_elementwise_op);
385  }
389 
390  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
391  { // Creates a new Invoker and hands it back through its BaseInvoker interface.
392      return std::unique_ptr<BaseInvoker>{std::make_unique<Invoker>()};
393  }
394 
395  std::string GetTypeString() const override
396  {
397      // Renders the instance name from its compile-time tuning parameters.
398      std::stringstream name;
399 
400      // clang-format off
401      name << "DeviceReduceThreadWiseMultiD<" << BlockSize << ","
402           << "M_C" << BlockSize << "_S" << MThreadSliceSize << ","
403           << "K_C" << 1 << "_S" << KThreadSliceSize << ","
404           << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
405      // clang-format on
406      return name.str();
407  }
408 };
409 
410 } // namespace device
411 } // namespace tensor_operation
412 } // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:14
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
Definition: ck.hpp:267
__global__ void kernel_reduce_threadwise_multi_d(const InGridDesc_M_K in_grid_desc_m_k, const DsGridDesc_M ds_grid_desc_m, const OutGridDesc_M out_grid_desc_m, const InElementwiseOperation in_elementwise_op, const OutElementwiseOperation out_elementwise_op, const InDataType *const __restrict__ p_in_value_global, const DsGridPointer p_ds_value_global, OutDataType *const __restrict__ p_out_value_global)
Definition: gridwise_2d_reduction_threadwise_multi_d.hpp:28
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
int64_t long_index_t
Definition: ck.hpp:299
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:298
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
signed int int32_t
Definition: stdint.h:123
Definition: stream_config.hpp:10
Definition: gridwise_2d_reduction_threadwise_multi_d.hpp:66
decltype(MakeDsGridPointer()) DsGridPointer
Definition: gridwise_2d_reduction_threadwise_multi_d.hpp:98
Definition: sequence.hpp:43
typename conditional< kHasContent, type0, type1 >::type type
Definition: sequence.hpp:271
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33
Definition: device_base.hpp:51
Definition: device_base.hpp:62
Definition: device_reduce_multi_d.hpp:26
Definition: device_reduce_threadwise_multi_d.hpp:198
long_index_t invariant_total_length
Definition: device_reduce_threadwise_multi_d.hpp:267
OutElementwiseOperation out_elementwise_op_
Definition: device_reduce_threadwise_multi_d.hpp:261
size_t gridSize
Definition: device_reduce_threadwise_multi_d.hpp:271
std::array< index_t, NumDstDim > outLengths_
Definition: device_reduce_threadwise_multi_d.hpp:252
const InDataType * in_dev_
Definition: device_reduce_threadwise_multi_d.hpp:255
std::array< std::array< index_t, NumDstDim >, NumDTensor > DsStrides_
Definition: device_reduce_threadwise_multi_d.hpp:250
DsGridPointer p_ds_grid_
Definition: device_reduce_threadwise_multi_d.hpp:258
index_t invariant_lowest_length
Definition: device_reduce_threadwise_multi_d.hpp:265
std::array< index_t, Rank > inStrides_
Definition: device_reduce_threadwise_multi_d.hpp:247
index_t reduce_lowest_length
Definition: device_reduce_threadwise_multi_d.hpp:266
int numBlockTileIteration
Definition: device_reduce_threadwise_multi_d.hpp:270
std::array< index_t, Rank > inLengths_
Definition: device_reduce_threadwise_multi_d.hpp:246
InElementwiseOperation in_elementwise_op_
Definition: device_reduce_threadwise_multi_d.hpp:260
OutDataType * out_dev_
Definition: device_reduce_threadwise_multi_d.hpp:256
std::array< std::array< index_t, NumDstDim >, NumDTensor > DsLengths_
Definition: device_reduce_threadwise_multi_d.hpp:249
long_index_t reduce_total_length
Definition: device_reduce_threadwise_multi_d.hpp:268
std::array< index_t, NumDstDim > outStrides_
Definition: device_reduce_threadwise_multi_d.hpp:253
DsGridDesc_M ds_grid_desc_m_
Definition: device_reduce_threadwise_multi_d.hpp:263
Argument(const std::array< index_t, Rank > inLengths, const std::array< index_t, Rank > inStrides, const std::array< std::array< index_t, NumDstDim >, NumDTensor > DsLengths, const std::array< std::array< index_t, NumDstDim >, NumDTensor > DsStrides, const std::array< index_t, NumDstDim > outLengths, const std::array< index_t, NumDstDim > outStrides, const std::array< int, NumReduceDim > reduceDims, const InDataType *in_dev, const std::array< const void *, NumDTensor > ds_dev, OutDataType *out_dev, const InElementwiseOperation in_elementwise_op, const OutElementwiseOperation out_elementwise_op)
Definition: device_reduce_threadwise_multi_d.hpp:199
Definition: device_reduce_threadwise_multi_d.hpp:275
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_reduce_threadwise_multi_d.hpp:276
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_reduce_threadwise_multi_d.hpp:313
Definition: device_reduce_threadwise_multi_d.hpp:47
static auto MakeSrc2dDescriptor(const std::array< index_t, Rank > &inLengths, const std::array< index_t, Rank > &inStrides)
Definition: device_reduce_threadwise_multi_d.hpp:68
std::string GetTypeString() const override
Definition: device_reduce_threadwise_multi_d.hpp:395
int32_t IndexDataType
Definition: device_reduce_threadwise_multi_d.hpp:55
static constexpr index_t K_BlockTileSize
Definition: device_reduce_threadwise_multi_d.hpp:66
decltype(MakeDsDescriptor({}, {})) DsGridDesc_M
Definition: device_reduce_threadwise_multi_d.hpp:173
static constexpr index_t NumInvariantDim
Definition: device_reduce_threadwise_multi_d.hpp:57
std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::array< index_t, Rank > inLengths, const std::array< index_t, Rank > inStrides, const std::array< std::array< index_t, NumDstDim >, NumDTensor > DsLengths, const std::array< std::array< index_t, NumDstDim >, NumDTensor > DsStrides, const std::array< index_t, NumDstDim > outLengths, const std::array< index_t, NumDstDim > outStrides, const std::array< int, NumReduceDim > reduceDims, const void *in_dev, const std::array< const void *, NumDTensor > ds_dev, void *out_dev, const InElementwiseOperation in_elementwise_op, const OutElementwiseOperation out_elementwise_op) override
Definition: device_reduce_threadwise_multi_d.hpp:363
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_reduce_threadwise_multi_d.hpp:320
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_reduce_threadwise_multi_d.hpp:390
static constexpr index_t NumSrcDim
Definition: device_reduce_threadwise_multi_d.hpp:61
static constexpr index_t NumDstDim
Definition: device_reduce_threadwise_multi_d.hpp:62
static auto MakeDst1dDescriptor(const std::array< index_t, NumDstDim > &outLengths, const std::array< index_t, NumDstDim > &outStrides)
Definition: device_reduce_threadwise_multi_d.hpp:130
decltype(MakeDst1dDescriptor({}, {})) OutGridDesc_M
Definition: device_reduce_threadwise_multi_d.hpp:172
decltype(MakeSrc2dDescriptor({}, {})) InGridDesc_M_K
Definition: device_reduce_threadwise_multi_d.hpp:171
GridwiseReduction_mk_to_m_threadwise_multi_d< InDataType, DsDataType, OutDataType, AccDataType, InGridDesc_M_K, DsGridDesc_M, OutGridDesc_M, ReduceOperation, InElementwiseOperation, OutElementwiseOperation, InMemoryDataOperationEnum::Set, BlockSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize, DsVectorSizeSequence > GridwiseReduce
Definition: device_reduce_threadwise_multi_d.hpp:193
static auto MakeDsDescriptor(const std::array< std::array< index_t, NumDstDim >, NumDTensor > DsLengths, std::array< std::array< index_t, NumDstDim >, NumDTensor > DsStrides)
Definition: device_reduce_threadwise_multi_d.hpp:160
static constexpr bool reduceAllDim
Definition: device_reduce_threadwise_multi_d.hpp:63
static constexpr index_t NumDTensor
Definition: device_reduce_threadwise_multi_d.hpp:59
static constexpr index_t M_BlockTileSize
Definition: device_reduce_threadwise_multi_d.hpp:65
typename GridwiseReduce::DsGridPointer DsGridPointer
Definition: device_reduce_threadwise_multi_d.hpp:195