Composable Kernel: include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp Source File
gridwise_normalization_naive_variance.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 
13 
14 namespace ck {
15 
16 // Y = Normalization(X, Beta, Gamma)
17 template <typename XDataType,
18  typename GammaDataType,
19  typename BetaDataType,
20  typename YDataType,
21  typename SaveMeanInvStdDataType,
22  typename ComputeDataType,
23  typename YElementwiseOperation,
24  typename GridDesc_M_K,
25  typename GridDesc_M,
26  index_t BlockSize,
27  index_t MThreadClusterSize,
28  index_t KThreadClusterSize,
29  index_t MThreadSliceSize,
30  index_t KThreadSliceSize,
31  index_t XSrcVectorDim,
32  index_t XSrcVectorSize,
33  index_t GammaSrcVectorDim,
34  index_t GammaSrcVectorSize,
35  index_t BetaSrcVectorDim,
36  index_t BetaSrcVectorSize,
37  index_t YDstVectorDim,
38  index_t YDstVectorSize,
39  index_t SaveMeanInvStdDstVectorSize,
40  bool SweepOnce>
41 struct GridwiseNormalizationNaiveVariance_mk_to_mk
42 {
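// Work decomposition (inferred from the template parameters and the constants below): a
// workgroup of BlockSize threads is arranged as an MThreadClusterSize x KThreadClusterSize
// cluster, and each thread owns an MThreadSliceSize x KThreadSliceSize slice, so one block
// covers M_BlockTileSize rows and K_BlockTileSize reduction columns per tile. When
// SweepOnce is true the whole K range is covered by a single tile and X is read from
// global memory only once; otherwise X is streamed twice (a statistics pass followed by a
// normalization pass).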
43  static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
44  (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
45  "Invalid thread slice sizes and/or vector sizes configuration, please check!");
46 
47  static_assert((YDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) ||
48  (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
49  "Invalid thread slice sizes and/or vector sizes configuration, please check!");
50 
51  static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
52  "Invalid thread slice sizes and/or save mean and inverse std vector sizes "
53  "configuration, please check!");
54 
55  static_assert(XSrcVectorSize == YDstVectorSize);
56  static_assert(XSrcVectorSize == GammaSrcVectorSize);
57  static_assert(XSrcVectorSize == BetaSrcVectorSize);
58 
59  static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0);
60 
 61  using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
 62 
 63  using ThreadBufferDimAccessOrder =
 64      typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
 65 
 66  using ThreadClusterArrangeOrder =
 67      typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
 68 
 69  static constexpr auto thread_cluster_desc =
 70      make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
 71 
 72  using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, XSrcVectorSize>;
 73  static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
 74      make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));
 75 
 76  using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
 77  static constexpr auto thread_buffer_desc_m =
 78      make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
 79 
 80  using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
 81      make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{})));
 82  using ThreadReduceDstDesc_M =
 83      decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
 84 
 85  using BlockwiseSumReduce = PartitionedBlockwiseReduction<ComputeDataType,
 86                                                           BlockSize,
 87                                                           ThreadClusterLengths_M_K,
 88                                                           ThreadClusterArrangeOrder,
 89                                                           reduce::Add,
 90                                                           true>;
 91 
 92  using ThreadwiseSumReduce = ThreadwiseReduction<ComputeDataType,
 93                                                  ThreadReduceSrcDesc_M_K,
 94                                                  ThreadReduceDstDesc_M,
 95                                                  reduce::Add,
 96                                                  true>;
 97 
 98  using PassThroughOp = tensor_operation::element_wise::PassThrough;
 99 
100  static constexpr auto I0 = Number<0>{};
101  static constexpr auto I1 = Number<1>{};
102  static constexpr auto I2 = Number<2>{};
103 
104  static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
105  static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
106  static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize;
107 
108  static constexpr auto ThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};
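// Each thread covers its KThreadSliceSize elements as ThreadBufferNumber vectors of
// XSrcVectorSize values; one vector step across the whole cluster spans
// K_BlockTileStepSize columns, and ThreadBufferNumber such steps advance one full
// K_BlockTileSize tile.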
109 
110  __device__ static void Run(const GridDesc_M_K& x_grid_desc_m_k,
111  const GridDesc_M_K& gamma_grid_desc_m_k,
112  const GridDesc_M_K& beta_grid_desc_m_k,
113  const GridDesc_M_K& y_grid_desc_m_k,
114  const GridDesc_M& save_mean_grid_desc_m,
115  const GridDesc_M& save_inv_std_grid_desc_m,
116  index_t num_k_block_tile_iteration,
117  ComputeDataType epsilon,
118  const XDataType* const __restrict__ p_x_global,
119  const GammaDataType* const __restrict__ p_gamma_global,
120  const BetaDataType* const __restrict__ p_beta_global,
121  YDataType* const __restrict__ p_y_global,
122  SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
123  SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
124  const YElementwiseOperation y_elementwise_op)
125  {
126  // LDS
127  __shared__ ComputeDataType p_reduce_work_buffer[BlockSize];
128 
129  auto reduce_work_buf =
130  make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
131 
132  auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
133  p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
134 
135  auto save_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
136  p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize());
137 
138  auto save_inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
139  p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize());
140 
141  auto x_thread_buf = generate_tuple(
142  [&](auto) {
143  return StaticBuffer<AddressSpaceEnum::Vgpr,
144  ComputeDataType,
145  MThreadSliceSize * XSrcVectorSize,
146  true>{};
147  },
148  Number<ThreadBufferNumber>{});
149 
150  auto gamma_thread_buf = generate_tuple(
151  [&](auto) {
152  return StaticBuffer<AddressSpaceEnum::Vgpr,
153  ComputeDataType,
154  MThreadSliceSize * GammaSrcVectorSize,
155  true>{};
156  },
157  Number<ThreadBufferNumber>{});
158 
159  auto& beta_thread_buf = gamma_thread_buf;
160 
161  auto y_thread_buf = generate_tuple(
162  [&](auto) {
163  return StaticBuffer<AddressSpaceEnum::Vgpr,
164  ComputeDataType,
165  MThreadSliceSize * YDstVectorSize,
166  true>{};
167  },
168  Number<ThreadBufferNumber>{});
169 
170  auto& x_square_thread_buf = y_thread_buf;
171 
172  StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
173  mean_thread_buf;
174  StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
175  mean_square_thread_buf;
176  StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>&
177  var_thread_buf = mean_square_thread_buf;
178  StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>&
179  inv_std_thread_buf = mean_square_thread_buf;
180 
181  const index_t thread_local_id = get_thread_local_1d_id();
182  const index_t block_global_id = get_block_1d_id();
183 
184  const auto thread_cluster_idx =
185  thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
186 
187  const auto thread_m_cluster_id = thread_cluster_idx[I0];
188  const auto thread_k_cluster_id = thread_cluster_idx[I1];
189 
190  auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
191  ComputeDataType,
192  GridDesc_M_K,
193  decltype(thread_buffer_desc_m_k),
194  ThreadBufferLengths_M_K,
195  ThreadBufferDimAccessOrder,
196  XSrcVectorDim,
197  XSrcVectorSize,
198  1,
199  true>(
200  x_grid_desc_m_k,
201  make_multi_index(block_global_id * M_BlockTileSize +
202  thread_m_cluster_id * MThreadSliceSize,
203  thread_k_cluster_id * XSrcVectorSize));
204 
205  auto threadwise_gamma_load =
206  ThreadwiseTensorSliceTransfer_v2<GammaDataType,
207  ComputeDataType,
208  GridDesc_M_K,
209  decltype(thread_buffer_desc_m_k),
210  ThreadBufferLengths_M_K,
211  ThreadBufferDimAccessOrder,
212  GammaSrcVectorDim,
213  GammaSrcVectorSize,
214  1,
215  true>(
216  gamma_grid_desc_m_k,
217  make_multi_index(block_global_id * M_BlockTileSize +
218  thread_m_cluster_id * MThreadSliceSize,
219  thread_k_cluster_id * GammaSrcVectorSize));
220 
221  auto threadwise_beta_load =
222  ThreadwiseTensorSliceTransfer_v2<BetaDataType,
223  ComputeDataType,
224  GridDesc_M_K,
225  decltype(thread_buffer_desc_m_k),
226  ThreadBufferLengths_M_K,
227  ThreadBufferDimAccessOrder,
228  BetaSrcVectorDim,
229  BetaSrcVectorSize,
230  1,
231  true>(
232  beta_grid_desc_m_k,
233  make_multi_index(block_global_id * M_BlockTileSize +
234  thread_m_cluster_id * MThreadSliceSize,
235  thread_k_cluster_id * BetaSrcVectorSize));
236 
237  auto threadwise_y_store =
238  ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
239  YDataType,
240  decltype(thread_buffer_desc_m_k),
241  GridDesc_M_K,
242  YElementwiseOperation,
243  ThreadBufferLengths_M_K,
244  ThreadBufferDimAccessOrder,
245  YDstVectorDim,
246  YDstVectorSize,
247  InMemoryDataOperationEnum::Set,
248  1,
249  true>(
250  y_grid_desc_m_k,
251  make_multi_index(block_global_id * M_BlockTileSize +
252  thread_m_cluster_id * MThreadSliceSize,
253  thread_k_cluster_id * YDstVectorSize),
254  y_elementwise_op);
255 
256  auto threadwise_mean_store =
257  ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
258  SaveMeanInvStdDataType,
259  decltype(thread_buffer_desc_m),
260  GridDesc_M,
261  PassThroughOp,
262  ThreadBufferLengths_M,
263  Sequence<0>, // DimAccessOrder
264  0, // SrcVectorDim
265  SaveMeanInvStdDstVectorSize, // ScalarPerVector
266  InMemoryDataOperationEnum::Set,
267  1,
268  true>(
269  save_mean_grid_desc_m,
270  make_multi_index(block_global_id * M_BlockTileSize +
271  thread_m_cluster_id * MThreadSliceSize),
272  PassThroughOp{});
273 
274  auto threadwise_inv_std_store =
275  ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
276  SaveMeanInvStdDataType,
277  decltype(thread_buffer_desc_m),
278  GridDesc_M,
279  PassThroughOp,
280  ThreadBufferLengths_M,
281  Sequence<0>, // DimAccessOrder
282  0, // SrcVectorDim
283  SaveMeanInvStdDstVectorSize, // ScalarPerVector
284  InMemoryDataOperationEnum::Set,
285  1,
286  true>(
287  save_inv_std_grid_desc_m,
288  make_multi_index(block_global_id * M_BlockTileSize +
289  thread_m_cluster_id * MThreadSliceSize),
290  PassThroughOp{});
291 
292  constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
293  constexpr auto thread_copy_bwd_step_m_k =
294  make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);
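// The forward step advances the load/store windows by one vector chunk along K; the
// backward step rewinds a full K tile between the two sweeps. With SweepOnce it is zero,
// since no second pass over global memory is needed.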
295 
296  const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
297  p_x_global, x_grid_desc_m_k.GetElementSpaceSize());
298 
299  const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
300  p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());
301 
302  const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
303  p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());
304 
305  // E(x), E[x^2], var(x)
306  // FIXME: Should not hack the transform from deviceOP
307  ComputeDataType reduce_length = type_convert<ComputeDataType>(
308  x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]);
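// reduce_length is the length of the reduced (K) dimension; as the FIXME notes, it is
// recovered from the grid descriptor's transforms rather than passed in explicitly.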
309 
310  static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
311  mean_thread_buf(I) = reduce::Add::template GetIdentityValue<ComputeDataType>();
312  mean_square_thread_buf(I) = reduce::Add::template GetIdentityValue<ComputeDataType>();
313  });
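// The two accumulators start from the additive identity (zero) and gather per-row sums
// of x and x^2; after the blockwise reduction they are divided by reduce_length to
// become E[x] and E[x^2].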
314 
315  // Separate sweep once and sweep twice pipeline
316  if constexpr(SweepOnce)
317  {
318  static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
319  threadwise_x_load.Run(x_grid_desc_m_k,
320  x_global_val_buf,
321  thread_buffer_desc_m_k,
322  make_tuple(I0, I0),
323  x_thread_buf(i));
324 
325  threadwise_gamma_load.Run(gamma_grid_desc_m_k,
326  gamma_global_val_buf,
327  thread_buffer_desc_m_k,
328  make_tuple(I0, I0),
329  gamma_thread_buf(i));
330 
331  static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
332  static_for<0, XSrcVectorSize, 1>{}([&](auto iK) {
333  constexpr auto offset_m_k =
334  thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
335  x_square_thread_buf(i)(Number<offset_m_k>{}) =
336  x_thread_buf(i)(Number<offset_m_k>{}) *
337  x_thread_buf(i)(Number<offset_m_k>{});
338  });
339  });
340 
341  ThreadwiseSumReduce::Reduce(x_thread_buf[i], mean_thread_buf);
342  ThreadwiseSumReduce::Reduce(x_square_thread_buf[i], mean_square_thread_buf);
343 
344  if constexpr(i != ThreadBufferNumber - 1)
345  {
346  threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
347  threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
348  thread_copy_fwd_step_m_k);
349  }
350  });
351 
352  static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
353  if constexpr(I > 0)
354  block_sync_lds();
355 
356  BlockwiseSumReduce::Reduce(reduce_work_buf, mean_thread_buf(I));
357  mean_thread_buf(I) = mean_thread_buf(I) / reduce_length;
358 
359  block_sync_lds();
360 
361  BlockwiseSumReduce::Reduce(reduce_work_buf, mean_square_thread_buf(I));
362  mean_square_thread_buf(I) = mean_square_thread_buf(I) / reduce_length;
363 
364  // var(x) = E[x^2] - E[x]^2
365  var_thread_buf(I) =
366  mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
367 
368  inv_std_thread_buf(I) = type_convert<ComputeDataType>(1.0f) /
369  ck::math::sqrt(var_thread_buf(I) + epsilon);
370  });
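// Per row: var(x) = E[x^2] - E[x]^2 and inv_std = 1 / sqrt(var(x) + epsilon). This
// "naive" variance needs only two running sums, at the cost of subtracting two nearly
// equal quantities, which can be less numerically robust than a Welford-style running
// update.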
371 
372  // save mean and inverse std for backward (optional)
373  if(thread_k_cluster_id == 0)
374  {
375  if(p_save_mean_global != nullptr)
376  {
377  threadwise_mean_store.Run(thread_buffer_desc_m,
378  make_tuple(I0),
379  mean_thread_buf,
380  save_mean_grid_desc_m,
381  save_mean_global_val_buf);
382  }
383  if(p_save_inv_std_global != nullptr)
384  {
385  threadwise_inv_std_store.Run(thread_buffer_desc_m,
386  make_tuple(I0),
387  inv_std_thread_buf,
388  save_inv_std_grid_desc_m,
389  save_inv_std_global_val_buf);
390  }
391  }
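// The guarded block above writes the optional saved mean / inverse standard deviation
// (skipped when the pointers are null); the thread_k_cluster_id == 0 check ensures each
// row's statistics are stored exactly once rather than by every thread column along K.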
392 
393  // normalization
394  static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
395  static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
396  static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
397  constexpr auto offset_m_k =
398  thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
399 
400  // normalize
401  y_thread_buf(iK0)(Number<offset_m_k>{}) =
402  (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
403  inv_std_thread_buf(iM);
404 
405  // gamma & beta
406  y_thread_buf(iK0)(Number<offset_m_k>{}) =
407  y_thread_buf(iK0)(Number<offset_m_k>{}) *
408  gamma_thread_buf(iK0)(Number<offset_m_k>{});
409  });
410  });
411  });
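// At this point y holds (x - E[x]) * inv_std * gamma. Beta is loaded next, reusing the
// gamma thread buffer (beta_thread_buf is a reference to it), and added in the following
// loop to complete y = gamma * (x - mean) / sqrt(var + epsilon) + beta.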
412 
413  static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
414  threadwise_beta_load.Run(beta_grid_desc_m_k,
415  beta_global_val_buf,
416  thread_buffer_desc_m_k,
417  make_tuple(I0, I0),
418  beta_thread_buf(i));
419 
420  if constexpr(i != ThreadBufferNumber - 1)
421  threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
422  thread_copy_fwd_step_m_k);
423  });
424 
425  static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
426  static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
427  static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
428  constexpr auto offset_m_k =
429  thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
430 
431  // beta
432  y_thread_buf(iK0)(Number<offset_m_k>{}) =
433  y_thread_buf(iK0)(Number<offset_m_k>{}) +
434  beta_thread_buf(iK0)(Number<offset_m_k>{});
435  });
436  });
437  });
438 
439  static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
440  threadwise_y_store.Run(thread_buffer_desc_m_k,
441  make_tuple(I0, I0),
442  y_thread_buf(i),
443  y_grid_desc_m_k,
444  y_global_val_buf);
445 
446  if constexpr(i != ThreadBufferNumber - 1)
447  threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k,
448  thread_copy_fwd_step_m_k);
449  });
450  } // end of sweep once
451  else
452  {
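// Sweep-twice pipeline: the K range spans several tiles, so the first loop below streams
// X to accumulate sum(x) and sum(x^2), and a second pass further down re-reads X together
// with gamma and beta to apply the normalization.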
453  for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
454  {
455  static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
456  threadwise_x_load.Run(x_grid_desc_m_k,
457  x_global_val_buf,
458  thread_buffer_desc_m_k,
459  make_tuple(I0, I0),
460  x_thread_buf(i));
461  threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
462 
463  static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
464  static_for<0, XSrcVectorSize, 1>{}([&](auto iK) {
465  constexpr auto offset_m_k =
466  thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
467  x_square_thread_buf(i)(Number<offset_m_k>{}) =
468  x_thread_buf(i)(Number<offset_m_k>{}) *
469  x_thread_buf(i)(Number<offset_m_k>{});
470  });
471  });
472 
473  ThreadwiseSumReduce::Reduce(x_thread_buf[i], mean_thread_buf);
474  ThreadwiseSumReduce::Reduce(x_square_thread_buf[i], mean_square_thread_buf);
475  });
476  }
477 
478  static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
479  if constexpr(I > 0)
480  block_sync_lds();
481 
482  BlockwiseSumReduce::Reduce(reduce_work_buf, mean_thread_buf(I));
483  mean_thread_buf(I) = mean_thread_buf(I) / reduce_length;
484 
485  block_sync_lds();
486 
487  BlockwiseSumReduce::Reduce(reduce_work_buf, mean_square_thread_buf(I));
488  mean_square_thread_buf(I) = mean_square_thread_buf(I) / reduce_length;
489 
490  // var(x) = E[x^2] - E[x]^2
491  var_thread_buf(I) =
492  mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
493 
494  inv_std_thread_buf(I) = 1 / ck::math::sqrt(var_thread_buf(I) + epsilon);
495  });
496 
497  if(thread_k_cluster_id == 0)
498  {
499  if(p_save_mean_global != nullptr)
500  {
501  threadwise_mean_store.Run(thread_buffer_desc_m,
502  make_tuple(I0),
503  mean_thread_buf,
504  save_mean_grid_desc_m,
505  save_mean_global_val_buf);
506  }
507  if(p_save_inv_std_global != nullptr)
508  {
509  threadwise_inv_std_store.Run(thread_buffer_desc_m,
510  make_tuple(I0),
511  inv_std_thread_buf,
512  save_inv_std_grid_desc_m,
513  save_inv_std_global_val_buf);
514  }
515  }
516 
517  auto thread_copy_tail_m_k =
518  (num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k;
519 
520  threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
521  threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k);
522  threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k);
523  threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);
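// The four moves above reposition the sliding windows for the second sweep: after the
// statistics loop the X window sits one tile past the end of the row, so one backward
// step parks it on the last K tile, while the gamma, beta and y windows jump forward to
// that same tile. The second sweep then walks the tiles back-to-front (note the
// 2 * backward step at the bottom of the loop), revisiting the most recently read X data
// first.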
524 
525  for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
526  {
527  static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
528  threadwise_x_load.Run(x_grid_desc_m_k,
529  x_global_val_buf,
530  thread_buffer_desc_m_k,
531  make_tuple(I0, I0),
532  x_thread_buf(i));
533  threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
534  });
535 
536  static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
537  threadwise_gamma_load.Run(gamma_grid_desc_m_k,
538  gamma_global_val_buf,
539  thread_buffer_desc_m_k,
540  make_tuple(I0, I0),
541  gamma_thread_buf(i));
542 
543  threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
544  thread_copy_fwd_step_m_k);
545  });
546 
547  static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
548  static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
549  static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
550  constexpr auto offset_m_k =
551  thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
552 
553  // normalize
554  y_thread_buf(iK0)(Number<offset_m_k>{}) =
555  (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
556  inv_std_thread_buf(iM);
557 
558  // gamma
559  y_thread_buf(iK0)(Number<offset_m_k>{}) =
560  y_thread_buf(iK0)(Number<offset_m_k>{}) *
561  gamma_thread_buf(iK0)(Number<offset_m_k>{});
562  });
563  });
564  });
565 
566  static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
567  threadwise_beta_load.Run(beta_grid_desc_m_k,
568  beta_global_val_buf,
569  thread_buffer_desc_m_k,
570  make_tuple(I0, I0),
571  beta_thread_buf(i));
572  threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
573  thread_copy_fwd_step_m_k);
574  });
575 
576  static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
577  static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
578  static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
579  constexpr auto offset_m_k =
580  thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
581 
582  // beta
583  y_thread_buf(iK0)(Number<offset_m_k>{}) =
584  y_thread_buf(iK0)(Number<offset_m_k>{}) +
585  beta_thread_buf(iK0)(Number<offset_m_k>{});
586  });
587  });
588  });
589 
590  static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
591  threadwise_y_store.Run(thread_buffer_desc_m_k,
592  make_tuple(I0, I0),
593  y_thread_buf(i),
594  y_grid_desc_m_k,
595  y_global_val_buf);
596  threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k,
597  thread_copy_fwd_step_m_k);
598  });
599 
600  threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
601  threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
602  2 * thread_copy_bwd_step_m_k);
603  threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
604  2 * thread_copy_bwd_step_m_k);
605  threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k,
606  2 * thread_copy_bwd_step_m_k);
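// Within one iteration each window moved forward by ThreadBufferNumber vector steps (one
// full K tile), so the two-tile step back above nets exactly one tile backward per
// outer-loop iteration.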
607  }
608  } // end of sweep twice
609  }
610 };
611 
612 } // namespace ck
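As a reading aid, the per-row arithmetic the kernel above implements can be sketched on the host as follows. This is a minimal, illustrative reference only: normalize_row and its types are not part of Composable Kernel, the trailing YElementwiseOperation applied during the store is omitted, and the kernel itself performs the same math tile by tile with vectorized loads and blockwise reductions.

#include <cmath>
#include <cstddef>
#include <vector>

// Host-side reference for one row (one M index): mean = E[x], var = E[x^2] - E[x]^2,
// y = gamma * (x - mean) / sqrt(var + epsilon) + beta. Names here are illustrative only.
inline void normalize_row(const std::vector<float>& x,
                          const std::vector<float>& gamma,
                          const std::vector<float>& beta,
                          float epsilon,
                          std::vector<float>& y,
                          float& save_mean,
                          float& save_inv_std)
{
    const std::size_t k = x.size();
    y.resize(k);

    // First pass: accumulate sum(x) and sum(x^2).
    double sum = 0.0, sum_sq = 0.0;
    for(std::size_t i = 0; i < k; ++i)
    {
        sum += x[i];
        sum_sq += x[i] * x[i];
    }

    const double mean     = sum / static_cast<double>(k);                 // E[x]
    const double variance = sum_sq / static_cast<double>(k) - mean * mean; // E[x^2] - E[x]^2
    const double inv_std  = 1.0 / std::sqrt(variance + epsilon);

    save_mean    = static_cast<float>(mean);    // optionally saved for the backward pass
    save_inv_std = static_cast<float>(inv_std);

    // Second pass: normalize and apply the affine gamma/beta transform.
    for(std::size_t i = 0; i < k; ++i)
        y[i] = static_cast<float>(gamma[i] * (x[i] - mean) * inv_std + beta[i]);
}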