Composable Kernel: include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp Source File

gridwise_normalization_splitk_2nd.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/utility/math.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <typename MeanVarDataType,
          typename XDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename YDataType,
          typename SaveMeanInvStdDataType,
          typename ComputeDataType,
          typename YElementwiseOperation,
          typename MeanVarGridDesc_M_KBlock,
          typename CountGridDesc_M_KBlock,
          typename XYGammaBetaGridDesc_M_K,
          typename SaveMeanInvStdGridDesc_M,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t XSrcVectorDim,
          index_t XSrcVectorSize,
          index_t GammaSrcVectorDim,
          index_t GammaSrcVectorSize,
          index_t BetaSrcVectorDim,
          index_t BetaSrcVectorSize,
          index_t YDstVectorDim,
          index_t YDstVectorSize,
          index_t SaveMeanInvStdDstVectorSize>
struct GridwiseNormalizationSplitK2nd
{
    static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
                      (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static_assert((YDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) ||
                      (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
                  "Invalid thread slice sizes and/or save mean and inverse std vector sizes "
                  "configuration, please check!");

    static_assert(XSrcVectorSize == YDstVectorSize);
    static_assert(XSrcVectorSize == GammaSrcVectorSize);
    static_assert(XSrcVectorSize == BetaSrcVectorSize);

    static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0);

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

    using ThreadBufferDimAccessOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    using ThreadClusterArrangeOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, XSrcVectorSize>;
    static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));

    using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
    static constexpr auto thread_buffer_desc_m =
        make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));

    using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
    static constexpr auto thread_buffer_desc_m_1 =
        make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));

    using ThreadWelfordSrcDesc_M_1 = decltype(thread_buffer_desc_m_1);
    using ThreadWelfordDstDesc_M =
        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

    using ThreadwiseWelford =
        ThreadwiseWelfordMerge<ComputeDataType, ThreadWelfordSrcDesc_M_1, ThreadWelfordDstDesc_M>;

    using BlockwiseWelford = BlockwiseWelford<ComputeDataType,
                                              BlockSize,
                                              ThreadClusterLengths_M_K,
                                              ThreadClusterArrangeOrder>;

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    static constexpr index_t M_BlockTileSize     = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize     = KThreadClusterSize * KThreadSliceSize;
    static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize;

    static constexpr auto ThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};

    __device__ static void Run(const MeanVarGridDesc_M_KBlock& mean_var_grid_desc_m_kblock,
                               const CountGridDesc_M_KBlock& count_grid_desc_m_kblock,
                               const XYGammaBetaGridDesc_M_K& x_grid_desc_m_k,
                               const XYGammaBetaGridDesc_M_K& gamma_grid_desc_m_k,
                               const XYGammaBetaGridDesc_M_K& beta_grid_desc_m_k,
                               const XYGammaBetaGridDesc_M_K& y_grid_desc_m_k,
                               const SaveMeanInvStdGridDesc_M& save_mean_grid_desc_m,
                               const SaveMeanInvStdGridDesc_M& save_inv_std_grid_desc_m,
                               index_t num_k_mean_var_count_iteration,
                               index_t num_k_block_tile_iteration,
                               index_t k_grid_size,
                               ComputeDataType epsilon,
                               const MeanVarDataType* const p_mean_global,
                               const MeanVarDataType* const p_variance_global,
                               const int32_t* const p_welford_count_global,
                               const XDataType* const __restrict__ p_x_global,
                               const GammaDataType* const __restrict__ p_gamma_global,
                               const BetaDataType* const __restrict__ p_beta_global,
                               YDataType* const __restrict__ p_y_global,
                               SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
                               SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
                               const YElementwiseOperation y_elementwise_op)
    {
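        // Second stage of the split-K normalization:
        //   step1: merge the per-partition Welford results (mean/variance/count) along K
        //   step2: optionally store the merged mean and inverse std for the backward pass
        //   step3: normalize x with the merged statistics, apply gamma/beta, and write y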
        // Thread/Block id
        const index_t thread_local_id    = get_thread_local_1d_id();
        const index_t block_global_id    = get_block_1d_id();
        const index_t block_m_cluster_id = block_global_id / k_grid_size;
        const index_t block_k_cluster_id = block_global_id % k_grid_size;
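        // The 1D grid is decomposed as block_global_id = block_m_cluster_id * k_grid_size +
        // block_k_cluster_id, i.e. k_grid_size workgroups cooperate on each M block tile.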
        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        // Global Memory
        const auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_mean_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize());

        const auto var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_variance_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize());

        const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_welford_count_global, count_grid_desc_m_kblock.GetElementSpaceSize());

        const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_x_global, x_grid_desc_m_k.GetElementSpaceSize());

        const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());

        const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());

        auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_y_global, y_grid_desc_m_k.GetElementSpaceSize());

        auto save_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize());

        auto save_inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize());

        // VGPR
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize * 1, true>
            in_mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize * 1, true>
            in_var_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize * 1, true>
            in_welford_count_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
            mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
            var_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
            welford_count_thread_buf;
        auto& inv_std_thread_buf = var_thread_buf;

        auto x_thread_buf = generate_tuple(
            [&](auto) {
                return StaticBuffer<AddressSpaceEnum::Vgpr,
                                    ComputeDataType,
                                    MThreadSliceSize * XSrcVectorSize,
                                    true>{};
            },
            Number<ThreadBufferNumber>{});

        auto gamma_thread_buf = generate_tuple(
            [&](auto) {
                return StaticBuffer<AddressSpaceEnum::Vgpr,
                                    ComputeDataType,
                                    MThreadSliceSize * GammaSrcVectorSize,
                                    true>{};
            },
            Number<ThreadBufferNumber>{});

        auto& beta_thread_buf = gamma_thread_buf;
        auto& y_thread_buf    = x_thread_buf;

        // IO
        auto threadwise_mean_var_load_m_kblock =
            ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
                                             ComputeDataType,
                                             MeanVarGridDesc_M_KBlock,
                                             decltype(thread_buffer_desc_m_1),
                                             ThreadBufferLengths_M_1,
                                             Sequence<0, 1>,
                                             1,
                                             1,
                                             1,
                                             true>(
                mean_var_grid_desc_m_kblock,
                make_multi_index(block_m_cluster_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id));

        auto threadwise_count_load_m_kblock =
            ThreadwiseTensorSliceTransfer_v2<int32_t,
                                             int32_t,
                                             CountGridDesc_M_KBlock,
                                             decltype(thread_buffer_desc_m_1),
                                             ThreadBufferLengths_M_1,
                                             Sequence<0, 1>,
                                             1,
                                             1,
                                             1,
                                             true>(
                count_grid_desc_m_kblock,
                make_multi_index(block_m_cluster_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id));

        auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
                                                                  ComputeDataType,
                                                                  XYGammaBetaGridDesc_M_K,
                                                                  decltype(thread_buffer_desc_m_k),
                                                                  ThreadBufferLengths_M_K,
                                                                  ThreadBufferDimAccessOrder,
                                                                  XSrcVectorDim,
                                                                  XSrcVectorSize,
                                                                  1,
                                                                  true>(
            x_grid_desc_m_k,
            make_multi_index(block_m_cluster_id * M_BlockTileSize +
                                 thread_m_cluster_id * MThreadSliceSize,
                             block_k_cluster_id * K_BlockTileSize * num_k_block_tile_iteration +
                                 thread_k_cluster_id * XSrcVectorSize));

        auto threadwise_gamma_load =
            ThreadwiseTensorSliceTransfer_v2<GammaDataType,
                                             ComputeDataType,
                                             XYGammaBetaGridDesc_M_K,
                                             decltype(thread_buffer_desc_m_k),
                                             ThreadBufferLengths_M_K,
                                             ThreadBufferDimAccessOrder,
                                             GammaSrcVectorDim,
                                             GammaSrcVectorSize,
                                             1,
                                             true>(
                gamma_grid_desc_m_k,
                make_multi_index(block_m_cluster_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 block_k_cluster_id * K_BlockTileSize * num_k_block_tile_iteration +
                                     thread_k_cluster_id * GammaSrcVectorSize));

        auto threadwise_beta_load =
            ThreadwiseTensorSliceTransfer_v2<BetaDataType,
                                             ComputeDataType,
                                             XYGammaBetaGridDesc_M_K,
                                             decltype(thread_buffer_desc_m_k),
                                             ThreadBufferLengths_M_K,
                                             ThreadBufferDimAccessOrder,
                                             BetaSrcVectorDim,
                                             BetaSrcVectorSize,
                                             1,
                                             true>(
                beta_grid_desc_m_k,
                make_multi_index(block_m_cluster_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 block_k_cluster_id * K_BlockTileSize * num_k_block_tile_iteration +
                                     thread_k_cluster_id * BetaSrcVectorSize));

        auto threadwise_y_store =
            ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
                                               YDataType,
                                               decltype(thread_buffer_desc_m_k),
                                               XYGammaBetaGridDesc_M_K,
                                               YElementwiseOperation,
                                               ThreadBufferLengths_M_K,
                                               ThreadBufferDimAccessOrder,
                                               YDstVectorDim,
                                               YDstVectorSize,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                y_grid_desc_m_k,
                make_multi_index(block_m_cluster_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 block_k_cluster_id * K_BlockTileSize * num_k_block_tile_iteration +
                                     thread_k_cluster_id * YDstVectorSize),
                y_elementwise_op);

        auto threadwise_mean_store =
            ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
                                               SaveMeanInvStdDataType,
                                               decltype(thread_buffer_desc_m),
                                               SaveMeanInvStdGridDesc_M,
                                               PassThroughOp,
                                               ThreadBufferLengths_M,
                                               Sequence<0>,                 // DimAccessOrder
                                               0,                           // SrcVectorDim
                                               SaveMeanInvStdDstVectorSize, // ScalarPerVector
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                save_mean_grid_desc_m,
                make_multi_index(block_m_cluster_id * M_BlockTileSize +
                                 thread_m_cluster_id * MThreadSliceSize),
                PassThroughOp{});

        auto threadwise_inv_std_store =
            ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
                                               SaveMeanInvStdDataType,
                                               decltype(thread_buffer_desc_m),
                                               SaveMeanInvStdGridDesc_M,
                                               PassThroughOp,
                                               ThreadBufferLengths_M,
                                               Sequence<0>,                 // DimAccessOrder
                                               0,                           // SrcVectorDim
                                               SaveMeanInvStdDstVectorSize, // ScalarPerVector
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                save_inv_std_grid_desc_m,
                make_multi_index(block_m_cluster_id * M_BlockTileSize +
                                 thread_m_cluster_id * MThreadSliceSize),
                PassThroughOp{});

        // step1: Merge mean and variance
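        // Conceptually this is the standard parallel Welford merge: two partial results
        // (m_a, v_a, n_a) and (m_b, v_b, n_b) combine as
        //   n = n_a + n_b
        //   m = m_a + (m_b - m_a) * n_b / n
        //   v = (n_a * v_a + n_b * v_b + (m_b - m_a)^2 * n_a * n_b / n) / n
        // ThreadwiseWelford accumulates the per-thread partials below and BlockwiseWelford
        // then reduces across the thread cluster.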
        constexpr auto mean_var_count_thread_copy_step_I0_k =
            make_multi_index(I0, KThreadClusterSize);

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            mean_thread_buf(I)          = type_convert<ComputeDataType>(0.0f);
            var_thread_buf(I)           = type_convert<ComputeDataType>(0.0f);
            welford_count_thread_buf(I) = 0;
        });

        for(index_t k = 0; k < num_k_mean_var_count_iteration; ++k)
        {
            threadwise_mean_var_load_m_kblock.Run(mean_var_grid_desc_m_kblock,
                                                  mean_global_val_buf,
                                                  thread_buffer_desc_m_1,
                                                  make_tuple(I0, I0),
                                                  in_mean_thread_buf);

            threadwise_mean_var_load_m_kblock.Run(mean_var_grid_desc_m_kblock,
                                                  var_global_val_buf,
                                                  thread_buffer_desc_m_1,
                                                  make_tuple(I0, I0),
                                                  in_var_thread_buf);

            threadwise_count_load_m_kblock.Run(count_grid_desc_m_kblock,
                                               welford_count_global_val_buf,
                                               thread_buffer_desc_m_1,
                                               make_tuple(I0, I0),
                                               in_welford_count_thread_buf);

            ThreadwiseWelford::Run(in_mean_thread_buf,
                                   in_var_thread_buf,
                                   in_welford_count_thread_buf,
                                   mean_thread_buf,
                                   var_thread_buf,
                                   welford_count_thread_buf);

            threadwise_mean_var_load_m_kblock.MoveSrcSliceWindow(
                mean_var_grid_desc_m_kblock, mean_var_count_thread_copy_step_I0_k);
            threadwise_count_load_m_kblock.MoveSrcSliceWindow(count_grid_desc_m_kblock,
                                                              mean_var_count_thread_copy_step_I0_k);
        }

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if constexpr(I > 0)
                block_sync_lds();

            BlockwiseWelford::Run(
                mean_thread_buf(I), var_thread_buf(I), welford_count_thread_buf(I));

            inv_std_thread_buf(I) =
                type_convert<ComputeDataType>(1.0f) / ck::math::sqrt(var_thread_buf(I) + epsilon);
        });

        // step2: save mean and inverse std for backward (optional)
        if(block_k_cluster_id == 0 && thread_k_cluster_id == 0)
        {
            if(p_save_mean_global != nullptr)
            {
                threadwise_mean_store.Run(thread_buffer_desc_m,
                                          make_tuple(I0),
                                          mean_thread_buf,
                                          save_mean_grid_desc_m,
                                          save_mean_global_val_buf);
            }
            if(p_save_inv_std_global != nullptr)
            {
                threadwise_inv_std_store.Run(thread_buffer_desc_m,
                                             make_tuple(I0),
                                             inv_std_thread_buf,
                                             save_inv_std_grid_desc_m,
                                             save_inv_std_global_val_buf);
            }
        }

        // step3: normalization
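        // Each row of x is normalized with the merged statistics:
        //   y = (x - mean) * inv_std * gamma + beta
        // and threadwise_y_store applies the user-supplied YElementwiseOperation on the
        // store path before writing to global memory.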
        constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);

        for(index_t k = 0; k < num_k_block_tile_iteration; ++k)
        {
            static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
                threadwise_x_load.Run(x_grid_desc_m_k,
                                      x_global_val_buf,
                                      thread_buffer_desc_m_k,
                                      make_tuple(I0, I0),
                                      x_thread_buf(i));
                threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
            });

            static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
                threadwise_gamma_load.Run(gamma_grid_desc_m_k,
                                          gamma_global_val_buf,
                                          thread_buffer_desc_m_k,
                                          make_tuple(I0, I0),
                                          gamma_thread_buf(i));

                threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
                                                         thread_copy_fwd_step_m_k);
            });

            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
                    static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
                        constexpr auto offset_m_k =
                            thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));

                        // normalize
                        y_thread_buf(iK0)(Number<offset_m_k>{}) =
                            (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
                            inv_std_thread_buf(iM);

                        // gamma
                        y_thread_buf(iK0)(Number<offset_m_k>{}) =
                            y_thread_buf(iK0)(Number<offset_m_k>{}) *
                            gamma_thread_buf(iK0)(Number<offset_m_k>{});
                    });
                });
            });

            static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
                threadwise_beta_load.Run(beta_grid_desc_m_k,
                                         beta_global_val_buf,
                                         thread_buffer_desc_m_k,
                                         make_tuple(I0, I0),
                                         beta_thread_buf(i));
                threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
                                                        thread_copy_fwd_step_m_k);
            });

            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
                    static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
                        constexpr auto offset_m_k =
                            thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));

                        // beta
                        y_thread_buf(iK0)(Number<offset_m_k>{}) =
                            y_thread_buf(iK0)(Number<offset_m_k>{}) +
                            beta_thread_buf(iK0)(Number<offset_m_k>{});
                    });
                });
            });

            static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
                threadwise_y_store.Run(thread_buffer_desc_m_k,
                                       make_tuple(I0, I0),
                                       y_thread_buf(i),
                                       y_grid_desc_m_k,
                                       y_global_val_buf);
                threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_fwd_step_m_k);
            });
        } // end for (normalization)
    }
};

} // namespace ck
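For reference, the merge performed in step1 above can be reproduced on the host with the standard parallel Welford combination. The sketch below is illustrative only and is not part of this header: WelfordResult and welford_merge are hypothetical names, and the code simply restates the combination formulas applied to the per-partition (mean, variance, count) triples.

// Host-side reference sketch (assumed names, not CK API): merge two partial Welford results.
struct WelfordResult
{
    double mean;
    double variance; // population variance of the partition, i.e. M2 / count
    int count;
};

inline WelfordResult welford_merge(const WelfordResult& a, const WelfordResult& b)
{
    if(a.count == 0) { return b; }
    if(b.count == 0) { return a; }

    const int n        = a.count + b.count;
    const double delta = b.mean - a.mean;

    WelfordResult r;
    r.count = n;
    r.mean  = a.mean + delta * b.count / n;
    // Combine the sums of squared deviations (M2 = variance * count), then renormalize by n.
    r.variance = (a.variance * a.count + b.variance * b.count +
                  delta * delta * a.count * b.count / n) /
                 n;
    return r;
}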