/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp Source File
gridwise_normalization_selector.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
8 
9 namespace ck {
// Device entry point for normalization: a thin __global__ wrapper that forwards
// every argument, unchanged and in the same order, to GridwiseReduction::Run.
//
// Template parameters:
//   GridwiseReduction      - type providing the Run(...) member invoked below.
//                            It does not appear in the parameter list, so it
//                            cannot be deduced and must be given explicitly.
//   XDataType, GammaDataType, BetaDataType
//                          - element types of the pointer-to-const arguments
//                            p_x_global / p_gamma_global / p_beta_global.
//   YDataType              - element type of the pointer-to-non-const p_y_global.
//   SaveMeanInvStdDataType - element type of p_save_mean_global and
//                            p_save_inv_std_global.
//   ComputeDataType        - type of epsilon.
//   YElementwiseOperation  - type of y_elementwise_op.
//   GridDesc_M_K           - one descriptor type shared by x, gamma, beta and y.
//   GridDesc_M             - one descriptor type shared by both save_* descriptors.
//
// NOTE(review): what each argument means (tensor layout, the role of
// num_k_block_tile_iteration, how epsilon is applied, whether the save_* pointers
// may be null) is decided inside GridwiseReduction::Run and is not visible here.
// Names suggest x/gamma/beta are inputs and y/save_mean/save_inv_std are
// outputs -- confirm against the gridwise implementation.
10 template <typename GridwiseReduction,
11  typename XDataType,
12  typename GammaDataType,
13  typename BetaDataType,
14  typename YDataType,
15  typename SaveMeanInvStdDataType,
16  typename ComputeDataType,
17  typename YElementwiseOperation,
18  typename GridDesc_M_K,
19  typename GridDesc_M>
20 __global__ void
21 kernel_normalization(const GridDesc_M_K x_grid_desc_m_k,
22  const GridDesc_M_K gamma_grid_desc_m_k,
23  const GridDesc_M_K beta_grid_desc_m_k,
24  const GridDesc_M_K y_grid_desc_m_k,
25  const GridDesc_M save_mean_grid_desc_m,
26  const GridDesc_M save_inv_std_grid_desc_m,
27  index_t num_k_block_tile_iteration,
28  ComputeDataType epsilon,
29  const XDataType* const __restrict__ p_x_global,
30  const GammaDataType* const __restrict__ p_gamma_global,
31  const BetaDataType* const __restrict__ p_beta_global,
32  YDataType* const __restrict__ p_y_global,
33  SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
34  SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
35  const YElementwiseOperation y_elementwise_op)
36 {
// All work is delegated; the kernel body adds no logic of its own.
37  GridwiseReduction::Run(x_grid_desc_m_k,
38  gamma_grid_desc_m_k,
39  beta_grid_desc_m_k,
40  y_grid_desc_m_k,
41  save_mean_grid_desc_m,
42  save_inv_std_grid_desc_m,
43  num_k_block_tile_iteration,
44  epsilon,
45  p_x_global,
46  p_gamma_global,
47  p_beta_global,
48  p_y_global,
49  p_save_mean_global,
50  p_save_inv_std_global,
51  y_elementwise_op);
// NOTE(review): the ';' after the closing brace is an empty declaration; it is
// harmless but redundant after a function definition.
52 };
53 
// Chooses which kernel_normalization instantiation to use and returns it.
//
// Selection happens on two axes:
//   UseWelford  (template parameter, resolved by `if constexpr`) - picks the
//               "Welford" alias pair when true, the "Naive" pair when false.
//   isSweepOnce (runtime argument) - picks the "SweepOnce" alias when true,
//               the "Generic" alias when false.
//
// Returns: the selected kernel_normalization<...> specialization. The return
// type is deduced (`auto`) from a conditional expression whose two arms are
// function names, so the result is a function pointer. Both arms differ only
// in the first template argument (GridwiseReduction), which is not part of
// kernel_normalization's parameter list, so they share one function type.
//
// The remaining template parameters (data types, descriptor types, BlockSize,
// thread cluster/slice sizes, vector dims/sizes) are passed through to the
// gridwise aliases and/or kernel_normalization exactly as received.
//
// NOTE(review): this rendering of the file is missing listing lines 81, 106,
// 131 and 156 -- the line after each `using ... =` that should name the class
// template and open its argument list. As shown, the aliases are incomplete;
// presumably the Naive pair comes from gridwise_normalization_naive_variance.hpp
// and the Welford pair from gridwise_normalization_welford_variance.hpp, with
// XDataType as the first argument -- confirm against the original header.
54 template <typename XDataType,
55  typename GammaDataType,
56  typename BetaDataType,
57  typename YDataType,
58  typename SaveMeanInvStdDataType,
59  typename ComputeDataType,
60  typename YElementwiseOperation,
61  typename GridDesc_M_K,
62  typename GridDesc_M,
63  index_t BlockSize,
64  index_t MThreadClusterSize,
65  index_t KThreadClusterSize,
66  index_t MThreadSliceSize,
67  index_t KThreadSliceSize,
68  index_t XSrcVectorDim,
69  index_t XSrcVectorSize,
70  index_t GammaSrcVectorDim,
71  index_t GammaSrcVectorSize,
72  index_t BetaSrcVectorDim,
73  index_t BetaSrcVectorSize,
74  index_t YDstVectorDim,
75  index_t YDstVectorSize,
76  index_t SaveMeanInvStdDstVectorSize,
77  bool UseWelford>
78 auto NormalizationKernelSelector(bool isSweepOnce)
79 {
// Naive variant, "Generic" flavour: final template argument is false.
// (Template-id head line absent from this listing -- see the note above.)
80  using GridwiseNormalizationGenericNaive =
82  GammaDataType,
83  BetaDataType,
84  YDataType,
85  SaveMeanInvStdDataType,
86  ComputeDataType,
87  YElementwiseOperation,
88  GridDesc_M_K,
89  GridDesc_M,
90  BlockSize,
91  MThreadClusterSize,
92  KThreadClusterSize,
93  MThreadSliceSize,
94  KThreadSliceSize,
95  XSrcVectorDim,
96  XSrcVectorSize,
97  GammaSrcVectorDim,
98  GammaSrcVectorSize,
99  BetaSrcVectorDim,
100  BetaSrcVectorSize,
101  YDstVectorDim,
102  YDstVectorSize,
103  SaveMeanInvStdDstVectorSize,
104  false>;
// Naive variant, "SweepOnce" flavour: same arguments, final argument is true.
105  using GridwiseNormalizationSweepOnceNaive =
107  GammaDataType,
108  BetaDataType,
109  YDataType,
110  SaveMeanInvStdDataType,
111  ComputeDataType,
112  YElementwiseOperation,
113  GridDesc_M_K,
114  GridDesc_M,
115  BlockSize,
116  MThreadClusterSize,
117  KThreadClusterSize,
118  MThreadSliceSize,
119  KThreadSliceSize,
120  XSrcVectorDim,
121  XSrcVectorSize,
122  GammaSrcVectorDim,
123  GammaSrcVectorSize,
124  BetaSrcVectorDim,
125  BetaSrcVectorSize,
126  YDstVectorDim,
127  YDstVectorSize,
128  SaveMeanInvStdDstVectorSize,
129  true>;
// Welford variant, "Generic" flavour: final template argument is false.
130  using GridwiseNormalizationGenericWelford =
132  GammaDataType,
133  BetaDataType,
134  YDataType,
135  SaveMeanInvStdDataType,
136  ComputeDataType,
137  YElementwiseOperation,
138  GridDesc_M_K,
139  GridDesc_M,
140  BlockSize,
141  MThreadClusterSize,
142  KThreadClusterSize,
143  MThreadSliceSize,
144  KThreadSliceSize,
145  XSrcVectorDim,
146  XSrcVectorSize,
147  GammaSrcVectorDim,
148  GammaSrcVectorSize,
149  BetaSrcVectorDim,
150  BetaSrcVectorSize,
151  YDstVectorDim,
152  YDstVectorSize,
153  SaveMeanInvStdDstVectorSize,
154  false>;
// Welford variant, "SweepOnce" flavour: same arguments, final argument is true.
155  using GridwiseNormalizationSweepOnceWelford =
157  GammaDataType,
158  BetaDataType,
159  YDataType,
160  SaveMeanInvStdDataType,
161  ComputeDataType,
162  YElementwiseOperation,
163  GridDesc_M_K,
164  GridDesc_M,
165  BlockSize,
166  MThreadClusterSize,
167  KThreadClusterSize,
168  MThreadSliceSize,
169  KThreadSliceSize,
170  XSrcVectorDim,
171  XSrcVectorSize,
172  GammaSrcVectorDim,
173  GammaSrcVectorSize,
174  BetaSrcVectorDim,
175  BetaSrcVectorSize,
176  YDstVectorDim,
177  YDstVectorSize,
178  SaveMeanInvStdDstVectorSize,
179  true>;
180 
// Compile-time split on UseWelford: only the taken branch is instantiated, so
// only two of the four kernel specializations are referenced per instantiation
// of this selector.
181  if constexpr(UseWelford)
182  {
// Runtime split on isSweepOnce between the two Welford-based kernels.
183  return isSweepOnce ? kernel_normalization<GridwiseNormalizationSweepOnceWelford,
184  XDataType,
185  GammaDataType,
186  BetaDataType,
187  YDataType,
188  SaveMeanInvStdDataType,
189  ComputeDataType,
190  YElementwiseOperation,
191  GridDesc_M_K,
192  GridDesc_M>
193  : kernel_normalization<GridwiseNormalizationGenericWelford,
194  XDataType,
195  GammaDataType,
196  BetaDataType,
197  YDataType,
198  SaveMeanInvStdDataType,
199  ComputeDataType,
200  YElementwiseOperation,
201  GridDesc_M_K,
202  GridDesc_M>;
203  }
204  else
205  {
// Runtime split on isSweepOnce between the two Naive-based kernels.
206  return isSweepOnce ? kernel_normalization<GridwiseNormalizationSweepOnceNaive,
207  XDataType,
208  GammaDataType,
209  BetaDataType,
210  YDataType,
211  SaveMeanInvStdDataType,
212  ComputeDataType,
213  YElementwiseOperation,
214  GridDesc_M_K,
215  GridDesc_M>
216  : kernel_normalization<GridwiseNormalizationGenericNaive,
217  XDataType,
218  GammaDataType,
219  BetaDataType,
220  YDataType,
221  SaveMeanInvStdDataType,
222  ComputeDataType,
223  YElementwiseOperation,
224  GridDesc_M_K,
225  GridDesc_M>;
226  }
227 }
228 
229 } // namespace ck
Definition: ck.hpp:267
__global__ void kernel_normalization(const GridDesc_M_K x_grid_desc_m_k, const GridDesc_M_K gamma_grid_desc_m_k, const GridDesc_M_K beta_grid_desc_m_k, const GridDesc_M_K y_grid_desc_m_k, const GridDesc_M save_mean_grid_desc_m, const GridDesc_M save_inv_std_grid_desc_m, index_t num_k_block_tile_iteration, ComputeDataType epsilon, const XDataType *const __restrict__ p_x_global, const GammaDataType *const __restrict__ p_gamma_global, const BetaDataType *const __restrict__ p_beta_global, YDataType *const __restrict__ p_y_global, SaveMeanInvStdDataType *const __restrict__ p_save_mean_global, SaveMeanInvStdDataType *const __restrict__ p_save_inv_std_global, const YElementwiseOperation y_elementwise_op)
Definition: gridwise_normalization_selector.hpp:21
int32_t index_t
Definition: ck.hpp:298
auto NormalizationKernelSelector(bool isSweepOnce)
Definition: gridwise_normalization_selector.hpp:78
Definition: gridwise_normalization_naive_variance.hpp:42
Definition: gridwise_normalization_welford_variance.hpp:40