/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_multiblock.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_multiblock.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_multiblock.hpp Source File
device_multiple_reduce_multiblock.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
11 
17 
19 
20 namespace ck {
21 namespace tensor_operation {
22 namespace device {
23 
// Device operator that performs NumReduction simultaneous reductions over the
// same input tensor.  Template parameters select the data types, the reduction
// and elementwise operations, the thread/block tiling sizes and the
// vectorization widths used by the grid-wise kernel.
// NOTE(review): this listing was extracted from generated documentation and is
// missing original source lines -- restore from the real header before use.
24 template <index_t NumReduction,
25  typename InDataType,
26  typename AccDataType,
27  typename OutDataTypeTuple,
28  index_t Rank,
29  index_t NumReduceDim,
30  typename ReduceOperation,
31  typename InElementwiseOperationTuple,
32  typename AccElementwiseOperationTuple,
33  InMemoryDataOperationEnum OutMemoryDataOperation,
34  bool PropagateNan,
35  index_t BlockSize,
36  index_t MThreadClusterSize,
37  index_t KThreadClusterSize,
38  index_t MThreadSliceSize,
39  index_t KThreadSliceSize,
40  index_t InSrcVectorDim,
41  index_t InSrcVectorSize,
42  typename OutDstVectorSizeSeq>
// NOTE(review): original line 43 (the struct declaration, presumably
// "struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,"
// per the index at the bottom of this page) is missing from this extraction;
// the next lines are the tail of the base-class template argument list.
44  NumReduceDim,
45  NumReduction,
46  InElementwiseOperationTuple,
47  AccElementwiseOperationTuple>
48 {
// Compile-time sanity checks on the chosen configuration.
49  static_assert(Rank <= 6, "Bigger Rank size is not supported!");
50  static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
51  "Invalid thread cluster size assignments!");
52 
// The vector load width must evenly divide the per-thread slice along the
// chosen vector dimension (0 = invariant/M axis, 1 = reduce/K axis).
53  static_assert((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
54  (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0),
55  "Invalid thread slice sizes and/or vector sizes configuration, please check!");
56 
// One output data type, one elementwise op pair and one store vector width is
// required per reduction.
57  static_assert(NumReduction == OutDataTypeTuple::Size() &&
58  NumReduction == InElementwiseOperationTuple::Size() &&
59  NumReduction == AccElementwiseOperationTuple::Size() &&
60  NumReduction == OutDstVectorSizeSeq::Size(),
61  "All tuple should have the same size as the number of Reductions!");
62 
63  static_assert(sequence_all_of(OutDstVectorSizeSeq{},
64  [](auto vectorSize) {
65  return (MThreadSliceSize % vectorSize == 0);
66  }),
67  "The OutDstVectorSize should completely divide the MThreadSliceSize!");
69  static constexpr bool CheckDataTypeTuple()
70  {
71  bool flag = true;
72 
73  static_for<0, NumReduction, 1>{}([&](auto I) {
74  using OutDataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
75  flag =
76  flag && ck::reduce::InMemoryDataOperationSupportedOnDataType<OutMemoryDataOperation,
77  OutDataType>::value;
78  });
79 
80  return flag;
81  };
82 
// Enforce the per-type compatibility check above.
83  static_assert(CheckDataTypeTuple(),
84  "The OutDataType must support the specified OutMemoryDataOperation!");
85 
// Dimension bookkeeping: after shuffling, the leading NumInvariantDim input
// dimensions are kept and the trailing NumReduceDim dimensions are reduced.
86  static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
87 
88  static constexpr index_t NumInputDim = Rank;
// When every dimension is reduced, the output is treated as one-dimensional.
89  static constexpr index_t NumOutputDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
90  static constexpr bool reduceAllDim = (NumInvariantDim == 0);
91 
92  // So far, only AtomicAdd is considered, other Atomic Operation like AtomicMax can be added
93  // later
94  static constexpr bool use_multiblock =
95  (OutMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd);
96 
97  static_assert(
98  ReduceOperation::IsCompatibleInMemoryDataOperation(OutMemoryDataOperation),
99  "The reduction accumulation operation must be compatible with the OutMemoryDataOperation!");
100 
// Tile extents processed by one workgroup along the invariant (M) and the
// reduced (K) axes.
101  static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
102  static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
103 
// Builds a tuple holding one typed null pointer per reduction; it is never
// dereferenced and exists only so OutDataTypePointerTuple can be derived via
// decltype (see the index: line 115 declares that alias).
// NOTE(review): the signature line (original 104,
// "static auto GenerateOutDataTypePointerTuple()") and the trailing
// Number<NumReduction>{} argument (112) are missing from this extraction.
105  {
106  return generate_tuple(
107  [&](auto I) {
108  using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
109 
110  return static_cast<DataType*>(nullptr);
111  },
113  };
114 
116 
// Build the padded 2-D (M, K) view of the input tensor: invariant dimensions
// are merged into M and reduced dimensions into K, then both axes are
// right-padded so that M is a multiple of M_BlockTileSize and K equals
// reduceSizePerBlock * blkGroupSize.
// NOTE(review): several interior lines (135-136, 139-142, 147, 154, 159,
// 175-176) -- mostly the lower/upper dimension-index arguments of the
// transform_tensor_descriptor calls -- are missing from this extraction.
117  static auto MakeSrc2dDescriptor(const std::array<index_t, NumInputDim>& inLengths,
118  const std::array<index_t, NumInputDim>& inStrides,
119  int blkGroupSize,
120  int numBlockTileIteration)
121  {
122  const auto tupleSrcLengths =
123  generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInputDim>{});
124  const auto tupleSrcStrides =
125  generate_tuple([&](auto I) { return inStrides[I]; }, Number<NumInputDim>{});
126 
127  const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
128 
129  const auto in_grid_desc_m_k = [&]() {
130  if constexpr(reduceAllDim)
131  {
// All dimensions are reduced: flatten to 1-D, then unmerge into (1, K) so the
// descriptor still has an M axis of length one.
132  const auto one_dim_inDesc = transform_tensor_descriptor(
133  inDesc,
134  make_tuple(make_merge_transform(tupleSrcLengths)),
137 
138  return transform_tensor_descriptor(one_dim_inDesc,
140  1, one_dim_inDesc.GetLength(Number<0>{})))),
143  }
144  else
145  {
// Split dimensions into invariant (leading) and reduced (trailing) groups and
// merge each group into a single axis.
146  using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
148 
149  const auto reduceDimLengths = generate_tuple(
150  [&](auto I) { return inLengths[NumInvariantDim + I]; }, Number<NumReduceDim>{});
151  const auto invariantDimLengths =
152  generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInvariantDim>{});
153 
155  inDesc,
156  make_tuple(make_merge_transform(invariantDimLengths),
157  make_merge_transform(reduceDimLengths)),
158  make_tuple(InvariantDims{}, ReduceDims{}),
160  }
161  }();
162 
163  const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
164  const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
165 
// Pad both axes up to the tiling grid so every block sees full tiles.
166  const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration;
167  const auto inPad_M =
168  math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
169  const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;
170 
171  auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
172  in_grid_desc_m_k,
173  make_tuple(make_right_pad_transform(invariantLength, inPad_M),
174  make_right_pad_transform(reduceLength, inPad_K)),
177 
178  return (in_grid_desc_m_k_padded);
179  };
180 
// Build the padded 1-D (M) output descriptor: all output dimensions are merged
// into one axis which is then right-padded to a multiple of M_BlockTileSize.
// NOTE(review): lines 194-195 and 205-206 (the dimension-index arguments of
// the two transform_tensor_descriptor calls) are missing from this extraction.
181  static auto MakeDst1dDescriptor(const std::array<index_t, NumOutputDim>& outLengths,
182  const std::array<index_t, NumOutputDim>& outStrides)
183  {
184  const auto tupleDstLengths =
185  generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumOutputDim>{});
186  const auto tupleDstStrides =
187  generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumOutputDim>{});
188 
189  auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
190 
191  auto out_grid_desc_m = transform_tensor_descriptor(
192  outDesc,
193  make_tuple(make_merge_transform(tupleDstLengths)),
196 
197  const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
198 
199  const auto outPad =
200  math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
201 
202  auto out_grid_desc_m_padded = transform_tensor_descriptor(
203  out_grid_desc_m,
204  make_tuple(make_right_pad_transform(invariantLength, outPad)),
207  return (out_grid_desc_m_padded);
208  };
209 
// Builds a tuple of default (empty-extent) output descriptors, one per
// reduction; used only to derive the OutGridDesc_M_Tuple type via decltype.
// NOTE(review): the signature line (original 210,
// "static auto GenerateOutGrid1dDescTuple()") and the Number<NumReduction>{}
// argument (218) are missing from this extraction.
211  {
212  return generate_tuple(
213  [&](auto I) {
214  (void)I;
215  return MakeDst1dDescriptor(std::array<index_t, NumOutputDim>{},
216  std::array<index_t, NumOutputDim>{});
217  },
219  };
220 
// NOTE(review): original line 221 ("using InGridDesc_M_K = decltype(
// MakeSrc2dDescriptor(" per the index below) is missing; this is the tail of
// that alias declaration.
222  std::array<index_t, NumInputDim>{}, std::array<index_t, NumInputDim>{}, 1, 1));
224 
// Like MakeDst1dDescriptor, but pads the merged output axis to a multiple of
// BlockSize instead of M_BlockTileSize; this descriptor is consumed by the
// buffer-initialization kernel (kernel_multiple_buffer_set_value) used before
// atomic accumulation.
// NOTE(review): lines 238-239 and 247-249 (transform index arguments) are
// missing from this extraction.
225  static auto MakeDst1dDescriptorForBufferSet(const std::array<index_t, NumOutputDim>& outLengths,
226  const std::array<index_t, NumOutputDim>& outStrides)
227  {
228  const auto tupleDstLengths =
229  generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumOutputDim>{});
230  const auto tupleDstStrides =
231  generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumOutputDim>{});
232 
233  auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
234 
235  auto out_grid_desc_m = transform_tensor_descriptor(
236  outDesc,
237  make_tuple(make_merge_transform(tupleDstLengths)),
240 
241  const auto length = out_grid_desc_m.GetLength(Number<0>{});
242 
243  const auto pad = math::integer_least_multiple(length, BlockSize) - length;
244 
245  auto out_grid_desc_m_padded =
246  transform_tensor_descriptor(out_grid_desc_m,
250  return (out_grid_desc_m_padded);
251  };
252 
// Builds a tuple of default buffer-set descriptors, one per reduction; used
// only to derive the OutGridDesc_M_Tuple_2 type via decltype.
// NOTE(review): the signature line (original 253,
// "static auto GenerateOutGrid1dDescTuple_2()") and the Number<NumReduction>{}
// argument (261) are missing from this extraction.
254  {
255  return generate_tuple(
256  [&](auto I) {
257  (void)I;
258  return MakeDst1dDescriptorForBufferSet(std::array<index_t, NumOutputDim>{},
259  std::array<index_t, NumOutputDim>{});
260  },
262  };
263 
265 
// Runtime problem description: tensor lengths/strides, scaling factors,
// device buffers and the derived launch configuration (grid sizes, block
// grouping along the reduced axis).
// NOTE(review): many interior lines of the constructor (295, 299, 302, 304,
// 332-339, 341-353) are missing from this extraction -- the generate_tuple
// headers, gridSize computations and descriptor assignments among them.
266  struct Argument : public BaseArgument
267  {
268  Argument(const std::array<index_t, NumInputDim>& inLengths,
269  const std::array<index_t, NumInputDim>& inStrides,
270  const std::array<index_t, NumOutputDim>& outLengths,
271  const std::array<std::array<index_t, NumOutputDim>, NumReduction>& outStridesArray,
272  const std::array<int, NumReduceDim>& reduceDims,
273  const std::array<double, NumReduction>& alphas,
274  const std::array<double, NumReduction>& betas,
275  const void* in_dev,
276  const std::array<void*, NumReduction>& out_dev_buffers,
277  const InElementwiseOperationTuple in_elementwise_op_tuple,
278  const AccElementwiseOperationTuple acc_elementwise_op_tuple)
279  : outLengths_{outLengths},
280  outStridesArray_{outStridesArray},
281  in_elementwise_op_tuple_{in_elementwise_op_tuple},
282  acc_elementwise_op_tuple_{acc_elementwise_op_tuple}
283  {
// Reorder input dims so the reduced ones come last (matches the merged (M, K)
// layout produced by MakeSrc2dDescriptor).
284  inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
285  inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
286 
// Per-reduction scale factors, converted to the accumulation type.
287  for(size_t i = 0; i < NumReduction; i++)
288  {
289  alpha_values_(i) = static_cast<AccDataType>(alphas[i]);
290  beta_values_(i) = static_cast<AccDataType>(betas[i]);
291  };
292 
293  in_dev_ = static_cast<const InDataType*>(in_dev);
294 
// Cast each type-erased output buffer to its per-reduction pointer type.
296  [&](auto iR) {
297  using OutDataTypePointer =
298  remove_cvref_t<decltype(OutDataTypePointerTuple{}[iR])>;
300  return static_cast<OutDataType*>(out_dev_buffers[iR]);
301  },
303 
305  get_2d_lengths<Rank, NumReduceDim>(inLengths_);
306 
307  if constexpr(use_multiblock)
308  {
// Grow the per-block iteration count until the number of cooperating blocks
// along K drops into an acceptable range.
309 
310  int iterations = 1;
311  while(true)
312  {
313  int testBlkGroupSize =
314  (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
315  (K_BlockTileSize * iterations);
316 
317  // we want the blkGroupSize to be no more than 128
318  if(testBlkGroupSize <= 128)
319  break;
320 
321  iterations++;
322  };
323 
324  blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
325  (K_BlockTileSize * iterations);
326 
327  numBlockTileIteration = iterations;
328  }
329  else
330  {
// Single-pass (blockwise) path: one block covers the whole reduced axis.
331  blkGroupSize = 1;
334  };
335 
338 
// Build the per-reduction output descriptors for the main kernel and for the
// buffer-initialization kernel.
340  [&](auto I) { return MakeDst1dDescriptor(outLengths, outStridesArray[I]); },
342 
344  [&](auto I) {
345  return MakeDst1dDescriptorForBufferSet(outLengths, outStridesArray[I]);
346  },
348 
351 
352  gridSize_pre =
354  }
355 
// Shuffled input geometry (reduced dims last).
356  std::array<index_t, NumInputDim> inLengths_;
357  std::array<index_t, NumInputDim> inStrides_;
358 
359  std::array<index_t, NumOutputDim> outLengths_;
360  std::array<std::array<index_t, NumOutputDim>, NumReduction> outStridesArray_;
361 
364 
365  const InDataType* in_dev_;
367 
371 
372  InElementwiseOperationTuple in_elementwise_op_tuple_;
373  AccElementwiseOperationTuple acc_elementwise_op_tuple_;
374 
377 
// Launch configuration for the main kernel and for the pre-initialization
// kernel (used only on the multiblock/AtomicAdd path).
380  size_t gridSize;
381 
382  size_t gridSize_pre;
383  };
384 
// Launches the reduction: on the multiblock (AtomicAdd) path, first runs a
// buffer-set kernel that fills the outputs with the operation's identity
// value, then launches the main grid-wise multi-reduction kernel.
// NOTE(review): several interior lines of Run() (390-398 partially, 414-417,
// 431-437, 444, 455-459) are missing from this extraction -- among them the
// GridwiseReduction template header, kernel_pre declaration and some kernel
// arguments.
385  struct Invoker : public BaseInvoker
386  {
387  float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
388  {
389  using GridwiseMultipleReduce =
391  InDataType,
393  AccDataType,
396  ReduceOperation,
397  InElementwiseOperationTuple,
398  AccElementwiseOperationTuple,
399  OutMemoryDataOperation,
400  PropagateNan,
401  BlockSize,
402  MThreadClusterSize,
403  KThreadClusterSize,
404  MThreadSliceSize,
405  KThreadSliceSize,
406  InSrcVectorDim,
407  InSrcVectorSize,
408  OutDstVectorSizeSeq>;
409 
410  const auto kernel_main =
411  kernel_multiple_reduce_multiblock<GridwiseMultipleReduce,
412  NumReduction,
413  InDataType,
415  AccDataType,
418  InElementwiseOperationTuple,
419  AccElementwiseOperationTuple>;
420 
421  float avg_time = 0;
422 
423  if constexpr(use_multiblock)
424  {
// AtomicAdd accumulates into the output, so the buffers must first be set to
// the operation's identity value by a separate kernel.
425  auto identity_values = generate_tuple(
426  [&](auto iR) {
427  using OutDataType = remove_cvref_t<decltype(OutDataTypeTuple{}[iR])>;
428  return ck::reduce::GetIdentityValueForInMemoryDataOperation<OutDataType>(
429  OutMemoryDataOperation);
430  },
432 
434  NumReduction,
435  BlockSize,
437  OutDataTypeTuple>;
438 
439  avg_time += launch_and_time_kernel(stream_config,
440  kernel_pre,
441  dim3(arg.gridSize_pre),
442  dim3(BlockSize),
443  0,
445  arg.out_dev_buffers_,
446  identity_values);
447  };
448 
449  avg_time += launch_and_time_kernel(stream_config,
450  kernel_main,
451  dim3(arg.gridSize),
452  dim3(BlockSize),
453  0,
454  arg.in_grid_desc_m_k,
458  arg.blkGroupSize,
460  arg.alpha_values_,
461  arg.in_dev_,
462  arg.beta_values_,
463  arg.out_dev_buffers_);
464 
// Accumulated (timed) duration of all launched kernels.
465  return (avg_time);
466  };
467 
// Type-erased overload required by the BaseInvoker interface.
468  float Run(const BaseArgument* p_arg,
469  const StreamConfig& stream_config = StreamConfig{}) override
470  {
471  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
472  };
473  };
474 
475  bool IsSupportedArgument(const BaseArgument* p_arg) override
476  {
477  const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
478 
479  if constexpr(use_multiblock)
480  {
481  for(size_t i = 0; i < pArg->beta_values_.Size(); i++)
482  if(pArg->beta_values_[i] != 0.0f)
483  return (false);
484  };
485 
486  if constexpr(InSrcVectorDim == 0)
487  {
488  if constexpr(NumInvariantDim == 0)
489  {
490  return (false);
491  }
492  else
493  {
494  if(pArg->inStrides_[NumInvariantDim - 1] != 1 && InSrcVectorSize != 1)
495  return (false);
496 
497  if(pArg->inLengths_[NumInvariantDim - 1] % InSrcVectorSize != 0)
498  return (false);
499  };
500  }
501  else
502  {
503  if(pArg->inStrides_[Rank - 1] != 1 && InSrcVectorSize != 1)
504  return (false);
505 
506  if(pArg->inLengths_[Rank - 1] % InSrcVectorSize != 0)
507  return (false);
508  };
509  // To improve
510  bool valid = true;
511  static_for<0, NumReduction, 1>{}([&](auto I) {
512  if(pArg->outStridesArray_[I.value][NumOutputDim - 1] != 1 &&
513  OutDstVectorSizeSeq::At(I) != 1)
514  valid = false;
515 
516  if(pArg->outLengths_[NumOutputDim - 1] % OutDstVectorSizeSeq::At(I) != 0)
517  valid = false;
518  });
519 
520  if(!valid)
521  return (false);
522 
523  if constexpr(use_multiblock)
524  {
525  // blkGroupSize of 1 should be handled by Blockwise path using
526  // InMemoryDataOperationEnum::Set
527  if(pArg->blkGroupSize == 1)
528  return (false);
529 
530  // This is very strong restriction, but needed to avoid some failure
531  if(pArg->outLengths_[NumOutputDim - 1] % M_BlockTileSize != 0)
532  return (false);
533  }
534  else
535  {
536  // cases with very small reduce_total_length should be handled by ThreadWise kernel
537  if(pArg->reduce_total_length / KThreadSliceSize < 2)
538  return (false);
539  };
540 
541  return (true);
542  };
543 
544  std::unique_ptr<BaseArgument> MakeArgumentPointer(
545  const std::array<index_t, NumInputDim> inLengths,
546  const std::array<index_t, NumInputDim> inStrides,
547  const std::array<index_t, NumOutputDim> outLengths,
548  const std::array<std::array<index_t, NumOutputDim>, NumReduction> outStridesArray,
549  const std::array<int, NumReduceDim> reduceDims,
550  const std::array<double, NumReduction> alphas,
551  const std::array<double, NumReduction> betas,
552  const void* in_dev,
553  const std::array<void*, NumReduction> out_dev_buffers,
554  const InElementwiseOperationTuple in_elementwise_op_tuple,
555  const AccElementwiseOperationTuple acc_elementwise_op_tuple) override
556  {
557  return std::make_unique<Argument>(inLengths,
558  inStrides,
559  outLengths,
560  outStridesArray,
561  reduceDims,
562  alphas,
563  betas,
564  in_dev,
565  out_dev_buffers,
566  in_elementwise_op_tuple,
567  acc_elementwise_op_tuple);
568  };
569 
570  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
571  {
572  return std::make_unique<Invoker>();
573  };
574 
575  std::string GetTypeString() const override
576  {
577  auto str = std::stringstream();
578 
579  // clang-format off
580  str << (OutMemoryDataOperation == InMemoryDataOperationEnum::Set? "DeviceMultipleReduceBlockWise<" : "DeviceMultipleReduceMultiBlock<") << BlockSize << ",";
581  str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
582  str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
583  str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << ",";
584  str << "OutDstVectorSize";
585  static_for<0, OutDstVectorSizeSeq::Size(), 1>{}([&](auto I) {str << "_" << OutDstVectorSizeSeq::At(I); });
586  str << ">";
587  // clang-format on
588 
589  return str.str();
590  }
591 };
592 
593 } // namespace device
594 } // namespace tensor_operation
595 } // namespace ck
auto pad(ck::index_t mpb, ck::index_t npb, ck::index_t kpb, ck::tensor_operation::device::GemmSpecialization gemm, CDesc_MRaw_NRaw conv)
Definition: helper.hpp:70
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:14
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
Definition: ck.hpp:267
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
InMemoryDataOperationEnum
Definition: ck.hpp:276
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__global__ void kernel_multiple_buffer_set_value(const Grid1dBufferDescTuple grid_1d_buffer_desc_tuple, DataTypePointerTuple p_global_tuple, DataTypeTuple value_tuple)
Definition: gridwise_set_multiple_buffer_value.hpp:17
int64_t long_index_t
Definition: ck.hpp:299
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
__host__ constexpr __device__ bool sequence_all_of(Seq, F f)
Definition: sequence.hpp:912
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
__global__ void kernel_multiple_reduce_multiblock(const InGridDesc_M_K in_grid_desc_m_k, const OutGridDesc_M_Tuple out_grid_desc_m_tuple, const InElementwiseOperationTuple in_elementwise_op_tuple, const AccElementwiseOperationTuple acc_elementwise_op_tuple, index_t block_group_size, index_t num_k_block_tile_iteration, Array< AccDataType, NumReduction > alpha_values, const InDataType *const __restrict__ p_in_value_global, Array< AccDataType, NumReduction > beta_values, OutDataTypePointerTuple p_out_value_global_tuple)
Definition: gridwise_2d_multiple_reduction_multiblock.hpp:26
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:298
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1350
Definition: stream_config.hpp:10
__host__ static constexpr __device__ index_t Size()
Definition: array.hpp:20
Definition: gridwise_2d_multiple_reduction_multiblock.hpp:69
Definition: sequence.hpp:43
typename conditional< kHasContent, type0, type1 >::type type
Definition: sequence.hpp:271
Definition: integral_constant.hpp:20
Definition: reduction_operator.hpp:485
Definition: functional2.hpp:33
Definition: device_base.hpp:51
Definition: device_base.hpp:62
Definition: device_multiple_reduce.hpp:25
Definition: device_multiple_reduce_multiblock.hpp:267
OutGridDesc_M_Tuple_2 out_grid_desc_m_tuple_2
Definition: device_multiple_reduce_multiblock.hpp:370
std::array< index_t, NumInputDim > inLengths_
Definition: device_multiple_reduce_multiblock.hpp:356
InGridDesc_M_K in_grid_desc_m_k
Definition: device_multiple_reduce_multiblock.hpp:368
long_index_t invariant_total_length
Definition: device_multiple_reduce_multiblock.hpp:375
Array< AccDataType, NumReduction > beta_values_
Definition: device_multiple_reduce_multiblock.hpp:363
long_index_t reduce_total_length
Definition: device_multiple_reduce_multiblock.hpp:376
size_t gridSize_pre
Definition: device_multiple_reduce_multiblock.hpp:382
int blkGroupSize
Definition: device_multiple_reduce_multiblock.hpp:378
int numBlockTileIteration
Definition: device_multiple_reduce_multiblock.hpp:379
const InDataType * in_dev_
Definition: device_multiple_reduce_multiblock.hpp:365
size_t gridSize
Definition: device_multiple_reduce_multiblock.hpp:380
OutGridDesc_M_Tuple out_grid_desc_m_tuple
Definition: device_multiple_reduce_multiblock.hpp:369
OutDataTypePointerTuple out_dev_buffers_
Definition: device_multiple_reduce_multiblock.hpp:366
Argument(const std::array< index_t, NumInputDim > &inLengths, const std::array< index_t, NumInputDim > &inStrides, const std::array< index_t, NumOutputDim > &outLengths, const std::array< std::array< index_t, NumOutputDim >, NumReduction > &outStridesArray, const std::array< int, NumReduceDim > &reduceDims, const std::array< double, NumReduction > &alphas, const std::array< double, NumReduction > &betas, const void *in_dev, const std::array< void *, NumReduction > &out_dev_buffers, const InElementwiseOperationTuple in_elementwise_op_tuple, const AccElementwiseOperationTuple acc_elementwise_op_tuple)
Definition: device_multiple_reduce_multiblock.hpp:268
std::array< std::array< index_t, NumOutputDim >, NumReduction > outStridesArray_
Definition: device_multiple_reduce_multiblock.hpp:360
std::array< index_t, NumOutputDim > outLengths_
Definition: device_multiple_reduce_multiblock.hpp:359
std::array< index_t, NumInputDim > inStrides_
Definition: device_multiple_reduce_multiblock.hpp:357
InElementwiseOperationTuple in_elementwise_op_tuple_
Definition: device_multiple_reduce_multiblock.hpp:372
Array< AccDataType, NumReduction > alpha_values_
Definition: device_multiple_reduce_multiblock.hpp:362
AccElementwiseOperationTuple acc_elementwise_op_tuple_
Definition: device_multiple_reduce_multiblock.hpp:373
Definition: device_multiple_reduce_multiblock.hpp:386
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_multiple_reduce_multiblock.hpp:387
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_multiple_reduce_multiblock.hpp:468
Definition: device_multiple_reduce_multiblock.hpp:48
static constexpr index_t NumInvariantDim
Definition: device_multiple_reduce_multiblock.hpp:86
decltype(GenerateOutGrid1dDescTuple_2()) OutGridDesc_M_Tuple_2
Definition: device_multiple_reduce_multiblock.hpp:264
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_multiple_reduce_multiblock.hpp:475
decltype(MakeSrc2dDescriptor(std::array< index_t, NumInputDim >{}, std::array< index_t, NumInputDim >{}, 1, 1)) InGridDesc_M_K
Definition: device_multiple_reduce_multiblock.hpp:222
static constexpr bool use_multiblock
Definition: device_multiple_reduce_multiblock.hpp:94
static auto GenerateOutGrid1dDescTuple()
Definition: device_multiple_reduce_multiblock.hpp:210
std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::array< index_t, NumInputDim > inLengths, const std::array< index_t, NumInputDim > inStrides, const std::array< index_t, NumOutputDim > outLengths, const std::array< std::array< index_t, NumOutputDim >, NumReduction > outStridesArray, const std::array< int, NumReduceDim > reduceDims, const std::array< double, NumReduction > alphas, const std::array< double, NumReduction > betas, const void *in_dev, const std::array< void *, NumReduction > out_dev_buffers, const InElementwiseOperationTuple in_elementwise_op_tuple, const AccElementwiseOperationTuple acc_elementwise_op_tuple) override
Definition: device_multiple_reduce_multiblock.hpp:544
static auto GenerateOutGrid1dDescTuple_2()
Definition: device_multiple_reduce_multiblock.hpp:253
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_multiple_reduce_multiblock.hpp:570
static constexpr bool CheckDataTypeTuple()
Definition: device_multiple_reduce_multiblock.hpp:69
static auto MakeDst1dDescriptor(const std::array< index_t, NumOutputDim > &outLengths, const std::array< index_t, NumOutputDim > &outStrides)
Definition: device_multiple_reduce_multiblock.hpp:181
std::string GetTypeString() const override
Definition: device_multiple_reduce_multiblock.hpp:575
decltype(GenerateOutGrid1dDescTuple()) OutGridDesc_M_Tuple
Definition: device_multiple_reduce_multiblock.hpp:223
static constexpr index_t K_BlockTileSize
Definition: device_multiple_reduce_multiblock.hpp:102
static constexpr bool reduceAllDim
Definition: device_multiple_reduce_multiblock.hpp:90
static auto MakeSrc2dDescriptor(const std::array< index_t, NumInputDim > &inLengths, const std::array< index_t, NumInputDim > &inStrides, int blkGroupSize, int numBlockTileIteration)
Definition: device_multiple_reduce_multiblock.hpp:117
static auto GenerateOutDataTypePointerTuple()
Definition: device_multiple_reduce_multiblock.hpp:104
static constexpr index_t NumInputDim
Definition: device_multiple_reduce_multiblock.hpp:88
static constexpr index_t M_BlockTileSize
Definition: device_multiple_reduce_multiblock.hpp:101
decltype(GenerateOutDataTypePointerTuple()) OutDataTypePointerTuple
Definition: device_multiple_reduce_multiblock.hpp:115
static constexpr index_t NumOutputDim
Definition: device_multiple_reduce_multiblock.hpp:89
static auto MakeDst1dDescriptorForBufferSet(const std::array< index_t, NumOutputDim > &outLengths, const std::array< index_t, NumOutputDim > &outStrides)
Definition: device_multiple_reduce_multiblock.hpp:225