Composable Kernel: include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
10 
11 namespace ck {
12 
13 // Tensor Shape
14 // dy, x = [M, K], gamma = [1, K], x_mean, inv_std = [M, 1]
15 
16 // Flow:
17 // def normalization_backward_x(dy, x, gamma, x_mean, inv_std, reduce_axis, reduce_size):
18 // ds = np.sum(dy * gamma * x, axis=reduce_axis, keepdims=True)
19 // db = np.sum(dy * gamma, axis=reduce_axis, keepdims=True)
20 // b = (db * x_mean - ds) * inv_std ** (3) / reduce_size
21 // c = -b * x_mean - db * inv_std / reduce_size
22 // dx = inv_std * dy * gamma + b * x + c
23 // return dx
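// Derivation of b and c (a sketch, assuming the usual layernorm backward formula):
// with g = dy * gamma, x_hat = (x - x_mean) * inv_std and N = reduce_size,
//   dx = inv_std * (g - sum(g) / N - x_hat * sum(g * x_hat) / N)
// Substituting sum(g) = db and sum(g * x_hat) = inv_std * (ds - x_mean * db), then
// collecting the terms in x, gives dx = inv_std * dy * gamma + b * x + c with the b and c
// defined above, which is the form this kernel evaluates.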
24 
25 template <typename DYDataType,
26  typename XDataType,
27  typename GammaDataType,
28  typename MeanInvStdDataType,
29  typename ComputeDataType,
30  typename DXDataType,
31  typename GridDesc_M_K,
32  index_t BlockSize,
33  index_t MThreadClusterSize,
34  index_t KThreadClusterSize,
35  index_t MThreadSliceSize,
36  index_t KThreadSliceSize,
37  index_t DYSrcVectorDim,
38  index_t DYSrcVectorSize,
39  index_t XSrcVectorDim,
40  index_t XSrcVectorSize,
41  index_t GammaSrcVectorDim,
42  index_t GammaSrcVectorSize,
43  index_t MeanInvStdSrcVectorDim,
44  index_t MeanInvStdSrcVectorSize,
45  index_t DXDstVectorDim,
46  index_t DXDstVectorSize,
47  bool SweepOnce>
49 {
50  // Requiring ThreadSliceSize == VectorSize along the vector dim (rather than only ThreadSliceSize % VectorSize == 0) keeps per-thread accesses fully vectorized and coalesced; a looser check can give poor performance
51  static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize == DYSrcVectorSize) ||
52  (DYSrcVectorDim == 1 && KThreadSliceSize == DYSrcVectorSize)),
53  "Invalid thread slice sizes and/or dy vector sizes configuration, please check!");
54 
55  static_assert(((XSrcVectorDim == 0 && MThreadSliceSize == XSrcVectorSize) ||
56  (XSrcVectorDim == 1 && KThreadSliceSize == XSrcVectorSize)),
57  "Invalid thread slice sizes and/or x vector sizes configuration, please check!");
58 
59  static_assert(
60  ((GammaSrcVectorDim == 0 && MThreadSliceSize == GammaSrcVectorSize) ||
61  (GammaSrcVectorDim == 1 && KThreadSliceSize == GammaSrcVectorSize)),
62  "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!");
63 
64  static_assert(
65  ((MeanInvStdSrcVectorDim == 0 && MThreadSliceSize == MeanInvStdSrcVectorSize) ||
66  (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize == MeanInvStdSrcVectorSize)),
67  "Invalid thread slice sizes and/or mean/inv_std vector sizes configuration, please check!");
68 
69  static_assert(((DXDstVectorDim == 0 && MThreadSliceSize == DXDstVectorSize) ||
70  (DXDstVectorDim == 1 && KThreadSliceSize == DXDstVectorSize)),
71  "Invalid thread slice sizes and/or dx vector sizes configuration, please check!");
72 
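// For example (illustrative values, not a tuned configuration): MThreadSliceSize = 1,
// KThreadSliceSize = 8, DYSrcVectorDim = 1 and DYSrcVectorSize = 8 satisfy the checks
// above; each thread then issues one 8-wide vector load of dy along K per tile.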
73  using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
74 
75  using DYThreadBufferDimAccessOrder =
76  typename conditional<DYSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
77  using XThreadBufferDimAccessOrder =
78  typename conditional<XSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
79  using GammaThreadBufferDimAccessOrder =
80  typename conditional<GammaSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
81  using MeanInvStdThreadBufferDimAccessOrder =
82  typename conditional<MeanInvStdSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
83  using DXThreadBufferDimAccessOrder =
84  typename conditional<DXDstVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
85 
86  using ThreadClusterArrangeOrder = DYThreadBufferDimAccessOrder;
87  static constexpr auto thread_cluster_desc =
88  make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
89 
90  using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
91 
92  static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
93  make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
94 
95  static constexpr auto thread_buffer_desc_m =
96  make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
97 
98  using PassThroughOp = tensor_operation::element_wise::PassThrough;
99 
100  using BlockwiseSumReduce = PartitionedBlockwiseReduction<ComputeDataType,
101  BlockSize,
102  ThreadClusterLengths_M_K,
103  ThreadClusterArrangeOrder,
104  reduce::Add,
105  true>;
106 
107  static constexpr auto I0 = Number<0>{};
108  static constexpr auto I1 = Number<1>{};
109  static constexpr auto I2 = Number<2>{};
110 
111  static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
112  static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
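// Example (hypothetical configuration): BlockSize = 256 with MThreadClusterSize = 8,
// KThreadClusterSize = 32, MThreadSliceSize = 1 and KThreadSliceSize = 8 gives
// M_BlockTileSize = 8 and K_BlockTileSize = 256, i.e. each workgroup covers an
// 8 x 256 tile of the [M, K] problem per K iteration.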
113 
114  __device__ static void Run(const GridDesc_M_K& dy_grid_desc_m_k,
115  const GridDesc_M_K& x_grid_desc_m_k,
116  const GridDesc_M_K& gamma_grid_desc_m_k,
117  const GridDesc_M_K& mean_grid_desc_m_k,
118  const GridDesc_M_K& inv_std_grid_desc_m_k,
119  const GridDesc_M_K& dx_grid_desc_m_k,
120  index_t num_k_block_tile_iteration,
121  const DYDataType* const __restrict__ p_dy_global,
122  const XDataType* const __restrict__ p_x_global,
123  const GammaDataType* const __restrict__ p_gamma_global,
124  const MeanInvStdDataType* const __restrict__ p_mean_global,
125  const MeanInvStdDataType* const __restrict__ p_inv_std_global,
126  DXDataType* const __restrict__ p_dx_global)
127  {
128  // LDS
129  __shared__ ComputeDataType p_reduce_work_buffer[BlockSize];
130 
131  auto reduce_work_buf =
132  make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
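// One LDS element per thread: the blockwise reduction stages each thread's partial
// sum here before combining them, so BlockSize entries of ComputeDataType suffice.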
133 
134  // Global
135  const auto dy_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
136  p_dy_global, dy_grid_desc_m_k.GetElementSpaceSize());
137 
138  const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
139  p_x_global, x_grid_desc_m_k.GetElementSpaceSize());
140 
141  auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
142  p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());
143 
144  const auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
145  p_mean_global, mean_grid_desc_m_k.GetElementSpaceSize());
146 
147  const auto inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
148  p_inv_std_global, inv_std_grid_desc_m_k.GetElementSpaceSize());
149 
150  auto dx_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
151  p_dx_global, dx_grid_desc_m_k.GetElementSpaceSize());
152 
153  // VGPR
154  auto dy_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
155  ComputeDataType,
156  MThreadSliceSize * KThreadSliceSize,
157  true>{};
158 
159  auto x_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
160  ComputeDataType,
161  MThreadSliceSize * KThreadSliceSize,
162  true>{};
163 
164  auto gamma_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
165  ComputeDataType,
166  MThreadSliceSize * KThreadSliceSize,
167  true>{};
168 
169  auto mean_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
170  ComputeDataType,
171  MThreadSliceSize * KThreadSliceSize,
172  true>{};
173 
174  auto inv_std_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
175  ComputeDataType,
176  MThreadSliceSize * KThreadSliceSize,
177  true>{};
178 
179  auto dx_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
180  ComputeDataType,
181  MThreadSliceSize * KThreadSliceSize,
182  true>{};
183 
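// Each buffer above holds this thread's MThreadSliceSize x KThreadSliceSize tile of
// the corresponding tensor in ComputeDataType; ds and db below need only one
// accumulator per M element of the slice because they are reduced over K.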
184  auto ds_thread_buf =
185  StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>{};
186 
187  auto db_thread_buf =
188  StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>{};
189 
190  // thread id
191  const index_t thread_local_id = get_thread_local_1d_id();
192  const index_t block_global_id = get_block_1d_id();
193 
194  const auto thread_cluster_idx =
195  thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
196 
197  const auto thread_m_cluster_id = thread_cluster_idx[I0];
198  const auto thread_k_cluster_id = thread_cluster_idx[I1];
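// thread_cluster_desc maps the flat thread id to an (m, k) position inside the
// MThreadClusterSize x KThreadClusterSize cluster; e.g. for a hypothetical 8 x 32
// cluster arranged with K fastest, thread 37 lands at (m, k) = (1, 5).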
199 
200  // IO
201  auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2<DYDataType,
202  ComputeDataType,
203  GridDesc_M_K,
204  decltype(thread_buffer_desc_m_k),
205  ThreadBufferLengths_M_K,
206  DYThreadBufferDimAccessOrder,
207  DYSrcVectorDim,
208  DYSrcVectorSize,
209  1,
210  false>(
211  dy_grid_desc_m_k,
212  make_multi_index(block_global_id * M_BlockTileSize +
213  thread_m_cluster_id * MThreadSliceSize,
214  thread_k_cluster_id * KThreadSliceSize));
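// Slice origin of this thread in the grid:
//   row = block_global_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize
//   col = thread_k_cluster_id * KThreadSliceSize
// The same origin is reused below for x, gamma, mean, inv_std and dx.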
215 
216  auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
217  ComputeDataType,
218  GridDesc_M_K,
219  decltype(thread_buffer_desc_m_k),
220  ThreadBufferLengths_M_K,
221  XThreadBufferDimAccessOrder,
222  XSrcVectorDim,
223  XSrcVectorSize,
224  1,
225  false>(
226  x_grid_desc_m_k,
227  make_multi_index(block_global_id * M_BlockTileSize +
228  thread_m_cluster_id * MThreadSliceSize,
229  thread_k_cluster_id * KThreadSliceSize));
230 
231  auto threadwise_gamma_load =
232  ThreadwiseTensorSliceTransfer_v2<GammaDataType,
233  ComputeDataType,
234  GridDesc_M_K,
235  decltype(thread_buffer_desc_m_k),
236  ThreadBufferLengths_M_K,
237  GammaThreadBufferDimAccessOrder,
238  GammaSrcVectorDim,
239  GammaSrcVectorSize,
240  1,
241  false>(
242  gamma_grid_desc_m_k,
243  make_multi_index(block_global_id * M_BlockTileSize +
244  thread_m_cluster_id * MThreadSliceSize,
245  thread_k_cluster_id * KThreadSliceSize));
246 
247  auto threadwise_mean_load =
248  ThreadwiseTensorSliceTransfer_v2<MeanInvStdDataType,
249  ComputeDataType,
250  GridDesc_M_K,
251  decltype(thread_buffer_desc_m_k),
252  ThreadBufferLengths_M_K,
253  MeanInvStdThreadBufferDimAccessOrder,
254  MeanInvStdSrcVectorDim,
255  MeanInvStdSrcVectorSize,
256  1,
257  false>(
258  mean_grid_desc_m_k,
259  make_multi_index(block_global_id * M_BlockTileSize +
260  thread_m_cluster_id * MThreadSliceSize,
261  thread_k_cluster_id * KThreadSliceSize));
262 
263  auto threadwise_inv_std_load =
264  ThreadwiseTensorSliceTransfer_v2<MeanInvStdDataType,
265  ComputeDataType,
266  GridDesc_M_K,
267  decltype(thread_buffer_desc_m_k),
268  ThreadBufferLengths_M_K,
269  MeanInvStdThreadBufferDimAccessOrder,
270  MeanInvStdSrcVectorDim,
271  MeanInvStdSrcVectorSize,
272  1,
273  false>(
274  inv_std_grid_desc_m_k,
275  make_multi_index(block_global_id * M_BlockTileSize +
276  thread_m_cluster_id * MThreadSliceSize,
277  thread_k_cluster_id * KThreadSliceSize));
278 
279  auto threadwise_dx_store =
280  ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
281  DXDataType,
282  decltype(thread_buffer_desc_m_k),
283  GridDesc_M_K,
284  PassThroughOp,
285  ThreadBufferLengths_M_K,
286  DXThreadBufferDimAccessOrder,
287  DXDstVectorDim,
288  DXDstVectorSize,
289  InMemoryDataOperationEnum::Set,
290  1,
291  false>(
292  dx_grid_desc_m_k,
293  make_multi_index(block_global_id * M_BlockTileSize +
294  thread_m_cluster_id * MThreadSliceSize,
295  thread_k_cluster_id * KThreadSliceSize),
296  PassThroughOp{});
297 
298  ComputeDataType reduce_size = type_convert<ComputeDataType>(
299  dy_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]);
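// reduce_size is the length of the normalized (K) dimension, read back from the dy
// grid descriptor (this assumes the usual M_K descriptor layout built by the host);
// it provides the 1/N factor used in b and c.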
300 
301  static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
302  ds_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
303  db_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
304  });
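// ds(i) accumulates sum_k(dy * gamma * x) and db(i) accumulates sum_k(dy * gamma)
// for the i-th row of the thread slice, first per thread and then across the block.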
305 
306  // Two pipelines: sweep once and sweep twice
307  // Sweep once: for small K, when one block tile covers the whole reduction dimension
308  // (KThreadClusterSize * KThreadSliceSize >= K), dy, x and gamma are read only once, without a loop
309  if constexpr(SweepOnce)
310  {
311  threadwise_dy_load.Run(dy_grid_desc_m_k,
312  dy_global_val_buf,
313  thread_buffer_desc_m_k,
314  make_tuple(I0, I0),
315  dy_thread_buf);
316 
317  threadwise_x_load.Run(x_grid_desc_m_k,
318  x_global_val_buf,
319  thread_buffer_desc_m_k,
320  make_tuple(I0, I0),
321  x_thread_buf);
322 
323  threadwise_gamma_load.Run(gamma_grid_desc_m_k,
324  gamma_global_val_buf,
325  thread_buffer_desc_m_k,
326  make_tuple(I0, I0),
327  gamma_thread_buf);
328 
329  threadwise_mean_load.Run(mean_grid_desc_m_k,
330  mean_global_val_buf,
331  thread_buffer_desc_m_k,
332  make_tuple(I0, I0),
333  mean_thread_buf);
334 
335  threadwise_inv_std_load.Run(inv_std_grid_desc_m_k,
336  inv_std_global_val_buf,
337  thread_buffer_desc_m_k,
338  make_tuple(I0, I0),
339  inv_std_thread_buf);
340 
341  static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
342  constexpr auto offset_m =
343  Number<thread_buffer_desc_m.CalculateOffset(make_tuple(iM))>{};
344 
345  static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
346  constexpr auto offset_m_k =
347  Number<thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK))>{};
348 
349  ds_thread_buf(offset_m) += dy_thread_buf[offset_m_k] *
350  gamma_thread_buf[offset_m_k] *
351  x_thread_buf[offset_m_k];
352 
353  db_thread_buf(offset_m) +=
354  dy_thread_buf[offset_m_k] * gamma_thread_buf[offset_m_k];
355  });
356  });
357 
358  static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
359  if constexpr(I > 0)
360  block_sync_lds();
361 
362  BlockwiseSumReduce::Reduce(reduce_work_buf, ds_thread_buf(I));
363  block_sync_lds();
364  BlockwiseSumReduce::Reduce(reduce_work_buf, db_thread_buf(I));
365  });
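// Both reductions share the same LDS buffer, hence the block_sync_lds between them
// and before reusing the buffer for the next row; afterwards ds and db hold the
// row-wide sums needed by every thread in the dx formula.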
366 
367  static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
368  constexpr auto offset_m =
369  Number<thread_buffer_desc_m.CalculateOffset(make_tuple(iM))>{};
370 
371  static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
372  constexpr auto offset_m_k =
373  Number<thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK))>{};
374 
375  // b = (db * x_mean - ds) * rstd ** (3) / reduce_size
376  // c = -b * x_mean - db * rstd / reduce_size
377  // dx = rstd * dy * gamma + b * x + c
378 
379  ComputeDataType b = db_thread_buf[offset_m] * mean_thread_buf[offset_m_k] -
380  ds_thread_buf[offset_m];
381 
382  b *= inv_std_thread_buf[offset_m_k] * inv_std_thread_buf[offset_m_k] *
383  inv_std_thread_buf[offset_m_k] / reduce_size;
384 
385  ComputeDataType c = -b * mean_thread_buf(offset_m_k);
386 
387  c -= db_thread_buf[offset_m] * inv_std_thread_buf[offset_m_k] / reduce_size;
388 
389  dx_thread_buf(offset_m_k) = dy_thread_buf[offset_m_k] *
390  gamma_thread_buf[offset_m_k] *
391  inv_std_thread_buf[offset_m_k] +
392  b * x_thread_buf[offset_m_k] + c;
393  });
394  });
395 
396  threadwise_dx_store.Run(thread_buffer_desc_m_k,
397  make_tuple(I0, I0),
398  dx_thread_buf,
399  dx_grid_desc_m_k,
400  dx_global_val_buf);
401 
402  } // end of sweep once
403  else // Sweep Twice pipeline
404  {
405  constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
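// Each iteration advances the dy/x/gamma read windows by one K block tile, so after
// num_k_block_tile_iteration steps the whole K extent of the row has been visited.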
406 
407  for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
408  {
409  threadwise_dy_load.Run(dy_grid_desc_m_k,
410  dy_global_val_buf,
411  thread_buffer_desc_m_k,
412  make_tuple(I0, I0),
413  dy_thread_buf);
414 
415  threadwise_x_load.Run(x_grid_desc_m_k,
416  x_global_val_buf,
417  thread_buffer_desc_m_k,
418  make_tuple(I0, I0),
419  x_thread_buf);
420 
421  threadwise_gamma_load.Run(gamma_grid_desc_m_k,
422  gamma_global_val_buf,
423  thread_buffer_desc_m_k,
424  make_tuple(I0, I0),
425  gamma_thread_buf);
426 
427  threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_fwd_step_m_k);
428  threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
429  threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
430  thread_copy_fwd_step_m_k);
431 
432  static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
433  constexpr auto offset_m =
434  Number<thread_buffer_desc_m.CalculateOffset(make_tuple(iM))>{};
435 
436  static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
437  constexpr auto offset_m_k =
438  Number<thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK))>{};
439 
440  ds_thread_buf(offset_m) += dy_thread_buf[offset_m_k] *
441  gamma_thread_buf[offset_m_k] *
442  x_thread_buf[offset_m_k];
443 
444  db_thread_buf(offset_m) +=
445  dy_thread_buf[offset_m_k] * gamma_thread_buf[offset_m_k];
446  });
447  });
448  } // end of first sweep
449 
450  static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
451  if constexpr(I > 0)
452  block_sync_lds();
453 
454  BlockwiseSumReduce::Reduce(reduce_work_buf, ds_thread_buf(I));
455  block_sync_lds();
456  BlockwiseSumReduce::Reduce(reduce_work_buf, db_thread_buf(I));
457  });
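// Same blockwise reduction as in the sweep-once path; after this, ds and db contain
// the complete sums over all K tiles, ready for the second (dx) sweep below.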
458 
459  // second sweep reads the K tiles in reverse order, so the dy, x and gamma tiles touched last by the first sweep are likely still in cache
460  constexpr auto thread_copy_bwd_step_m_k = make_multi_index(0, -K_BlockTileSize);
461  auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;
462 
463  // dy/x/gamma windows sit one tile past the end after the first sweep; step back once to the last tile
464  threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k);
465  threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
466  threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_bwd_step_m_k);
467 
468  // mean/inv_std/dx windows are still at the start; jump forward to the last tile
469  threadwise_mean_load.MoveSrcSliceWindow(mean_grid_desc_m_k, thread_copy_tail_m_k);
470  threadwise_inv_std_load.MoveSrcSliceWindow(inv_std_grid_desc_m_k, thread_copy_tail_m_k);
471  threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_tail_m_k);
472 
473  for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
474  {
475  threadwise_dy_load.Run(dy_grid_desc_m_k,
476  dy_global_val_buf,
477  thread_buffer_desc_m_k,
478  make_tuple(I0, I0),
479  dy_thread_buf);
480 
481  threadwise_x_load.Run(x_grid_desc_m_k,
482  x_global_val_buf,
483  thread_buffer_desc_m_k,
484  make_tuple(I0, I0),
485  x_thread_buf);
486 
487  threadwise_gamma_load.Run(gamma_grid_desc_m_k,
488  gamma_global_val_buf,
489  thread_buffer_desc_m_k,
490  make_tuple(I0, I0),
491  gamma_thread_buf);
492 
493  threadwise_mean_load.Run(mean_grid_desc_m_k,
494  mean_global_val_buf,
495  thread_buffer_desc_m_k,
496  make_tuple(I0, I0),
497  mean_thread_buf);
498 
499  threadwise_inv_std_load.Run(inv_std_grid_desc_m_k,
500  inv_std_global_val_buf,
501  thread_buffer_desc_m_k,
502  make_tuple(I0, I0),
503  inv_std_thread_buf);
504 
505  static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
506  constexpr auto offset_m =
507  Number<thread_buffer_desc_m.CalculateOffset(make_tuple(iM))>{};
508 
509  static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
510  constexpr auto offset_m_k =
511  Number<thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK))>{};
512 
513  // b = (db * x_mean - ds) * rstd ** (3) / reduce_size
514  // c = -b * x_mean - db * rstd / reduce_size
515  // dx = rstd * dy * gamma + b * x + c
516 
517  ComputeDataType b = db_thread_buf[offset_m] * mean_thread_buf[offset_m_k] -
518  ds_thread_buf[offset_m];
519 
520  b *= inv_std_thread_buf[offset_m_k] * inv_std_thread_buf[offset_m_k] *
521  inv_std_thread_buf[offset_m_k] / reduce_size;
522 
523  ComputeDataType c = -b * mean_thread_buf(offset_m_k);
524 
525  c -= db_thread_buf[offset_m] * inv_std_thread_buf[offset_m_k] / reduce_size;
526 
527  dx_thread_buf(offset_m_k) = dy_thread_buf[offset_m_k] *
528  gamma_thread_buf[offset_m_k] *
529  inv_std_thread_buf[offset_m_k] +
530  b * x_thread_buf[offset_m_k] + c;
531  });
532  });
533 
534  threadwise_dx_store.Run(thread_buffer_desc_m_k,
535  make_tuple(I0, I0),
536  dx_thread_buf,
537  dx_grid_desc_m_k,
538  dx_global_val_buf);
539 
540  threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k);
541  threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
542  threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
543  thread_copy_bwd_step_m_k);
544  threadwise_mean_load.MoveSrcSliceWindow(mean_grid_desc_m_k,
545  thread_copy_bwd_step_m_k);
546  threadwise_inv_std_load.MoveSrcSliceWindow(inv_std_grid_desc_m_k,
547  thread_copy_bwd_step_m_k);
548  threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_bwd_step_m_k);
549  }
550  }
551  }
552 };
553 
554 } // namespace ck