/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp Source File
device_normalization_bwd_data.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
#include <iostream>
#include <memory>
#include <vector>

#include "ck/tensor_operation/gpu/device/device_base.hpp"

11 namespace ck {
12 namespace tensor_operation {
13 namespace device {
14 template <typename DYDataType,
15  typename XDataType,
16  typename GammaDataType,
17  typename MeanInvStdDataType,
18  typename DXDataType,
19  index_t Rank,
20  index_t NumReduceDim>
22 {
23  virtual std::unique_ptr<BaseArgument>
24  MakeArgumentPointer(const std::vector<index_t> lengths,
25  const std::vector<index_t> dyStrides,
26  const std::vector<index_t> xStrides,
27  const std::vector<index_t> gammaStrides,
28  const std::vector<index_t> meanStrides,
29  const std::vector<index_t> invStdStrides,
30  const std::vector<index_t> dxStrides,
31  const std::vector<index_t> reduceDims,
32  const void* p_dy,
33  const void* p_x,
34  const void* p_gamma,
35  const void* p_mean,
36  const void* p_invStd,
37  void* p_dx) = 0;
38 
39  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
40 };
41 
42 template <typename DYDataType,
43  typename XDataType,
44  typename GammaDataType,
45  typename MeanInvStdDataType,
46  typename DXDataType,
47  index_t Rank,
48  index_t NumReduceDim>
49 using DeviceNormalizationBwdDataPtr = std::unique_ptr<DeviceNormalizationBwdData<DYDataType,
50  XDataType,
51  GammaDataType,
52  MeanInvStdDataType,
53  DXDataType,
54  Rank,
55  NumReduceDim>>;
56 
57 } // namespace device
58 } // namespace tensor_operation
59 } // namespace ck
std::unique_ptr< DeviceNormalizationBwdData< DYDataType, XDataType, GammaDataType, MeanInvStdDataType, DXDataType, Rank, NumReduceDim > > DeviceNormalizationBwdDataPtr
Definition: device_normalization_bwd_data.hpp:55
Definition: ck.hpp:267
int32_t index_t
Definition: ck.hpp:298
Definition: device_base.hpp:77
DeviceNormalizationBwdData
Definition: device_normalization_bwd_data.hpp:22
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::vector< index_t > lengths, const std::vector< index_t > dyStrides, const std::vector< index_t > xStrides, const std::vector< index_t > gammaStrides, const std::vector< index_t > meanStrides, const std::vector< index_t > invStdStrides, const std::vector< index_t > dxStrides, const std::vector< index_t > reduceDims, const void *p_dy, const void *p_x, const void *p_gamma, const void *p_mean, const void *p_invStd, void *p_dx)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0