/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp Source File
device_batchnorm_backward.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
#include <array>
#include <memory>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
11 
12 namespace ck {
13 namespace tensor_operation {
14 namespace device {
15 
16 template <typename XDataType,
17  typename DxDataType,
18  typename DyDataType,
19  typename AccDataType,
20  typename ScaleDataType,
21  typename DscaleDbiasDataType,
22  typename MeanVarDataType,
23  typename DyElementwiseOp,
24  index_t Rank,
25  index_t NumBatchNormReduceDim>
27 {
28  static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
29 
30  virtual std::unique_ptr<BaseArgument>
31  MakeArgumentPointer(const std::array<index_t, Rank> xyLengths,
32  const std::array<index_t, Rank> xStrides,
33  const std::array<index_t, Rank> dyStrides,
34  const std::array<index_t, Rank> dxStrides,
35  const std::array<int, NumBatchNormReduceDim> reduceDims,
36  const std::array<ck::index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
37  const std::array<ck::index_t, NumInvariantDim> bnScaleStrides,
38  const std::array<ck::index_t, NumInvariantDim> bnDscaleDbiasStrides,
39  const std::array<ck::index_t, NumInvariantDim> bnMeanVarStrides,
40  const void* p_x,
41  const void* p_dy,
42  const void* p_scale,
43  const void* p_savedMean,
44  const void* p_savedInvVar,
45  double epsilon,
46  const DyElementwiseOp dy_elementwise_op,
47  void* p_dx,
48  void* p_dscale,
49  void* p_dbias) = 0;
50 
51  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
52 };
53 
54 template <typename XDataType,
55  typename DxDataType,
56  typename DyDataType,
57  typename AccDataType,
58  typename ScaleDataType,
59  typename DscaleDbiasDataType,
60  typename MeanVarDataType,
61  typename DyElementwiseOp,
62  index_t Rank,
63  index_t NumBatchNormReduceDim>
64 using DeviceBatchNormBwdPtr = std::unique_ptr<DeviceBatchNormBwd<XDataType,
65  DxDataType,
66  DyDataType,
67  AccDataType,
68  ScaleDataType,
69  DscaleDbiasDataType,
70  MeanVarDataType,
71  DyElementwiseOp,
72  Rank,
73  NumBatchNormReduceDim>>;
74 
75 } // namespace device
76 } // namespace tensor_operation
77 } // namespace ck
std::unique_ptr< DeviceBatchNormBwd< XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumBatchNormReduceDim > > DeviceBatchNormBwdPtr
Definition: device_batchnorm_backward.hpp:73
Definition: ck.hpp:267
int32_t index_t
Definition: ck.hpp:298
Definition: device_base.hpp:77
Definition: device_batchnorm_backward.hpp:27
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::array< index_t, Rank > xyLengths, const std::array< index_t, Rank > xStrides, const std::array< index_t, Rank > dyStrides, const std::array< index_t, Rank > dxStrides, const std::array< int, NumBatchNormReduceDim > reduceDims, const std::array< ck::index_t, NumInvariantDim > bnScaleBiasMeanVarLengths, const std::array< ck::index_t, NumInvariantDim > bnScaleStrides, const std::array< ck::index_t, NumInvariantDim > bnDscaleDbiasStrides, const std::array< ck::index_t, NumInvariantDim > bnMeanVarStrides, const void *p_x, const void *p_dy, const void *p_scale, const void *p_savedMean, const void *p_savedInvVar, double epsilon, const DyElementwiseOp dy_elementwise_op, void *p_dx, void *p_dscale, void *p_dbias)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
static constexpr index_t NumInvariantDim
Definition: device_batchnorm_backward.hpp:28