template <typename DYDataType,
          typename XDataType,
          typename GammaDataType,
          typename MeanInvStdDataType,
          typename ComputeDataType,
          typename DXDataType,
          typename GridDesc_M_K,
          /* ... block size, thread-cluster/slice sizes and vector-dim parameters elided ... */
          index_t MeanInvStdSrcVectorSize
          /* ... remaining vector-size parameters and bool SweepOnce elided ... */>
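    // The vectorized-access parameters must match the thread slice sizes: a tensor accessed
    // with vector loads/stores along dim 0 (M) needs its vector size to equal MThreadSliceSize,
    // and along dim 1 (K) to equal KThreadSliceSize.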
    static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize == DYSrcVectorSize) ||
                   (DYSrcVectorDim == 1 && KThreadSliceSize == DYSrcVectorSize)),
                  "Invalid thread slice sizes and/or dy vector sizes configuration, please check!");

    static_assert(((XSrcVectorDim == 0 && MThreadSliceSize == XSrcVectorSize) ||
                   (XSrcVectorDim == 1 && KThreadSliceSize == XSrcVectorSize)),
                  "Invalid thread slice sizes and/or x vector sizes configuration, please check!");

    static_assert(
        ((GammaSrcVectorDim == 0 && MThreadSliceSize == GammaSrcVectorSize) ||
         (GammaSrcVectorDim == 1 && KThreadSliceSize == GammaSrcVectorSize)),
        "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!");

    static_assert(
        ((MeanInvStdSrcVectorDim == 0 && MThreadSliceSize == MeanInvStdSrcVectorSize) ||
         (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize == MeanInvStdSrcVectorSize)),
        "Invalid thread slice sizes and/or mean/inv_std vector sizes configuration, please check!");

    static_assert(((DXDstVectorDim == 0 && MThreadSliceSize == DXDstVectorSize) ||
                   (DXDstVectorDim == 1 && KThreadSliceSize == DXDstVectorSize)),
                  "Invalid thread slice sizes and/or dx vector sizes configuration, please check!");
    __device__ static void Run(const GridDesc_M_K& dy_grid_desc_m_k,
                               const GridDesc_M_K& x_grid_desc_m_k,
                               const GridDesc_M_K& gamma_grid_desc_m_k,
                               const GridDesc_M_K& mean_grid_desc_m_k,
                               const GridDesc_M_K& inv_std_grid_desc_m_k,
                               const GridDesc_M_K& dx_grid_desc_m_k,
                               index_t num_k_block_tile_iteration,
                               const DYDataType* const __restrict__ p_dy_global,
                               const XDataType* const __restrict__ p_x_global,
                               const GammaDataType* const __restrict__ p_gamma_global,
                               const MeanInvStdDataType* const __restrict__ p_mean_global,
                               const MeanInvStdDataType* const __restrict__ p_inv_std_global,
                               DXDataType* const __restrict__ p_dx_global)
    {
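        // LDS workspace used by the blockwise reduction of the per-thread partial sums.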
        __shared__ ComputeDataType p_reduce_work_buffer[BlockSize];

        auto reduce_work_buf =
            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
        const auto dy_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_dy_global, dy_grid_desc_m_k.GetElementSpaceSize());

        const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_x_global, x_grid_desc_m_k.GetElementSpaceSize());

        auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());

        const auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_mean_global, mean_grid_desc_m_k.GetElementSpaceSize());

        const auto inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_inv_std_global, inv_std_grid_desc_m_k.GetElementSpaceSize());

        auto dx_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_dx_global, dx_grid_desc_m_k.GetElementSpaceSize());
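        // Per-thread register buffers: one MThreadSliceSize x KThreadSliceSize tile per tensor,
        // plus (not shown here) the length-MThreadSliceSize ds/db accumulators used below.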
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize * KThreadSliceSize, true> dy_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize * KThreadSliceSize, true> x_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize * KThreadSliceSize, true> gamma_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize * KThreadSliceSize, true> mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize * KThreadSliceSize, true> inv_std_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize * KThreadSliceSize, true> dx_thread_buf;
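        // Map the flat thread id onto its (m, k) coordinate inside the block's thread cluster.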
        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];
        // Threadwise transfer objects (template arguments and the block-level part of each
        // slice origin are elided); every thread covers the window starting at
        // (thread_m_cluster_id * MThreadSliceSize, thread_k_cluster_id * KThreadSliceSize)
        // within its block tile.
        auto threadwise_dy_load      = /* ThreadwiseTensorSliceTransfer for dy loads */;
        auto threadwise_x_load       = /* ThreadwiseTensorSliceTransfer for x loads */;
        auto threadwise_gamma_load   = /* ThreadwiseTensorSliceTransfer for gamma loads */;
        auto threadwise_mean_load    = /* mean loads, vector dim MeanInvStdSrcVectorDim,
                                          vector size MeanInvStdSrcVectorSize */;
        auto threadwise_inv_std_load = /* inv_std loads over inv_std_grid_desc_m_k */;
        auto threadwise_dx_store     = /* ThreadwiseTensorSliceTransfer for the dx store */;
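        // Length of the reduced (K) dimension, read from the dy descriptor; this is the N in
        // the 1/N factors of the gradient formula below.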
        ComputeDataType reduce_size = type_convert<ComputeDataType>(
            dy_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]);
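        // Zero the per-row accumulators: ds(m) will hold sum_k(dy * gamma * x) and db(m) will
        // hold sum_k(dy * gamma) for this thread's slice.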
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            ds_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
            db_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
        });
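        // SweepOnce == true means the whole K extent fits in a single block tile, so the loads,
        // the ds/db accumulation and the dx computation all happen in one pass.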
        if constexpr(SweepOnce)
        {
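            // Load this thread's dy/x/gamma/mean/inv_std tile into registers.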
            threadwise_dy_load.Run(dy_grid_desc_m_k, dy_global_val_buf,
                                   thread_buffer_desc_m_k, make_tuple(I0, I0), dy_thread_buf);
            threadwise_x_load.Run(x_grid_desc_m_k, x_global_val_buf,
                                  thread_buffer_desc_m_k, make_tuple(I0, I0), x_thread_buf);
            threadwise_gamma_load.Run(gamma_grid_desc_m_k, gamma_global_val_buf,
                                      thread_buffer_desc_m_k, make_tuple(I0, I0), gamma_thread_buf);
            threadwise_mean_load.Run(mean_grid_desc_m_k, mean_global_val_buf,
                                     thread_buffer_desc_m_k, make_tuple(I0, I0), mean_thread_buf);
            threadwise_inv_std_load.Run(inv_std_grid_desc_m_k, inv_std_global_val_buf,
                                        thread_buffer_desc_m_k, make_tuple(I0, I0), inv_std_thread_buf);
            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                constexpr auto offset_m = thread_buffer_desc_m.CalculateOffset(make_tuple(iM));

                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                    constexpr auto offset_m_k =
                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));

                    ds_thread_buf(offset_m) += dy_thread_buf[offset_m_k] *
                                               gamma_thread_buf[offset_m_k] *
                                               x_thread_buf[offset_m_k];

                    db_thread_buf(offset_m) +=
                        dy_thread_buf[offset_m_k] * gamma_thread_buf[offset_m_k];
                });
            });
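            // ds/db now hold this thread's partial sums; they are reduced across the block's K
            // threads through the LDS workspace (not shown) before dx is evaluated.

            // Per-element data gradient, with N = reduce_size and rstd = inv_std:
            //   b  = (db * mean - ds) * rstd^3 / N
            //   c  = -b * mean - db * rstd / N
            //   dx = dy * gamma * rstd + b * x + c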
            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                constexpr auto offset_m = thread_buffer_desc_m.CalculateOffset(make_tuple(iM));

                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                    constexpr auto offset_m_k =
                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));

                    ComputeDataType b = db_thread_buf[offset_m] * mean_thread_buf[offset_m_k] -
                                        ds_thread_buf[offset_m];
                    b *= inv_std_thread_buf[offset_m_k] * inv_std_thread_buf[offset_m_k] *
                         inv_std_thread_buf[offset_m_k] / reduce_size;

                    ComputeDataType c = -b * mean_thread_buf[offset_m_k];
                    c -= db_thread_buf[offset_m] * inv_std_thread_buf[offset_m_k] / reduce_size;

                    dx_thread_buf(offset_m_k) = dy_thread_buf[offset_m_k] *
                                                    gamma_thread_buf[offset_m_k] *
                                                    inv_std_thread_buf[offset_m_k] +
                                                b * x_thread_buf[offset_m_k] + c;
                });
            });

            /* ... dx_thread_buf is written to dx_global_val_buf via threadwise_dx_store ... */
        }
        else
        {
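            // SweepOnce == false: K spans num_k_block_tile_iteration tiles. The first sweep
            // below accumulates ds/db tile by tile; after a blockwise reduction the tiles are
            // revisited to compute dx.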
            for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
            {
                threadwise_dy_load.Run(dy_grid_desc_m_k, dy_global_val_buf,
                                       thread_buffer_desc_m_k, make_tuple(I0, I0), dy_thread_buf);
                threadwise_x_load.Run(x_grid_desc_m_k, x_global_val_buf,
                                      thread_buffer_desc_m_k, make_tuple(I0, I0), x_thread_buf);
                threadwise_gamma_load.Run(gamma_grid_desc_m_k, gamma_global_val_buf,
                                          thread_buffer_desc_m_k, make_tuple(I0, I0), gamma_thread_buf);
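                // Advance the dy/x/gamma read windows to the next K tile.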
                threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_fwd_step_m_k);
                threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
                threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
                                                         thread_copy_fwd_step_m_k);
                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    constexpr auto offset_m = thread_buffer_desc_m.CalculateOffset(make_tuple(iM));

                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                        constexpr auto offset_m_k =
                            thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));

                        ds_thread_buf(offset_m) += dy_thread_buf[offset_m_k] *
                                                   gamma_thread_buf[offset_m_k] *
                                                   x_thread_buf[offset_m_k];

                        db_thread_buf(offset_m) +=
                            dy_thread_buf[offset_m_k] * gamma_thread_buf[offset_m_k];
                    });
                });
            }
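            // ds/db are then reduced across the block's K threads through the LDS workspace
            // (block_sync_lds + blockwise Reduce, not shown), yielding the full row sums.

            // Second sweep, walking the K tiles in reverse: dy/x/gamma windows step back one
            // tile from their current position, while the mean/inv_std/dx windows jump forward
            // to the last tile; every window then moves backward each iteration.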
            auto thread_copy_tail_m_k =
                (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;

            threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k);
            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
            threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
                                                     thread_copy_bwd_step_m_k);
            threadwise_mean_load.MoveSrcSliceWindow(mean_grid_desc_m_k, thread_copy_tail_m_k);
            threadwise_inv_std_load.MoveSrcSliceWindow(inv_std_grid_desc_m_k,
                                                       thread_copy_tail_m_k);
            threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_tail_m_k);

            for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
            {
                threadwise_dy_load.Run(dy_grid_desc_m_k, dy_global_val_buf,
                                       thread_buffer_desc_m_k, make_tuple(I0, I0), dy_thread_buf);
                threadwise_x_load.Run(x_grid_desc_m_k, x_global_val_buf,
                                      thread_buffer_desc_m_k, make_tuple(I0, I0), x_thread_buf);
                threadwise_gamma_load.Run(gamma_grid_desc_m_k, gamma_global_val_buf,
                                          thread_buffer_desc_m_k, make_tuple(I0, I0), gamma_thread_buf);
                threadwise_mean_load.Run(mean_grid_desc_m_k, mean_global_val_buf,
                                         thread_buffer_desc_m_k, make_tuple(I0, I0), mean_thread_buf);
                threadwise_inv_std_load.Run(inv_std_grid_desc_m_k, inv_std_global_val_buf,
                                            thread_buffer_desc_m_k, make_tuple(I0, I0), inv_std_thread_buf);
                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    constexpr auto offset_m = thread_buffer_desc_m.CalculateOffset(make_tuple(iM));

                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                        constexpr auto offset_m_k =
                            thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));

                        ComputeDataType b =
                            db_thread_buf[offset_m] * mean_thread_buf[offset_m_k] -
                            ds_thread_buf[offset_m];
                        b *= inv_std_thread_buf[offset_m_k] * inv_std_thread_buf[offset_m_k] *
                             inv_std_thread_buf[offset_m_k] / reduce_size;

                        ComputeDataType c = -b * mean_thread_buf[offset_m_k];
                        c -= db_thread_buf[offset_m] * inv_std_thread_buf[offset_m_k] / reduce_size;

                        dx_thread_buf(offset_m_k) = dy_thread_buf[offset_m_k] *
                                                        gamma_thread_buf[offset_m_k] *
                                                        inv_std_thread_buf[offset_m_k] +
                                                    b * x_thread_buf[offset_m_k] + c;
                    });
                });
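                // The dx tile is written to dx_global_val_buf via threadwise_dx_store (not
                // shown), then every window steps back one K tile.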
                threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k);
                threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
                threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_bwd_step_m_k);
                threadwise_mean_load.MoveSrcSliceWindow(mean_grid_desc_m_k, thread_copy_bwd_step_m_k);
                threadwise_inv_std_load.MoveSrcSliceWindow(inv_std_grid_desc_m_k, thread_copy_bwd_step_m_k);
                threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_bwd_step_m_k);
            }
        }
    }