11 template <
typename ReduceAccDataType,
12 typename ReducePtrsGlobal,
13 typename D0ElementwiseOperation,
14 typename ReduceOperations,
15 typename ReduceInElementwiseOperations,
16 typename ReduceAccElementwiseOperations,
17 typename ReduceGlobalMemoryDataOperation,
18 typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
19 index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
20 index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock>
31 CReduceThreadClusterLengths_MPerBlock_NPerBlock;
33 CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock;
35 CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock;
38 template <
typename DsDataType,
41 typename CShuffleDataType,
48 index_t CShuffleMRepeatPerShuffle,
49 index_t CShuffleNRepeatPerShuffle,
50 typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
51 typename CDEShuffleBlockTransferScalarPerVectors,
52 typename CDEElementwiseOperation,
54 typename BlockwiseGemmPipe,
69 CShuffleMRepeatPerShuffle,
70 CShuffleNRepeatPerShuffle,
71 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
72 CDEShuffleBlockTransferScalarPerVectors,
73 CDEElementwiseOperation,
88 CShuffleMRepeatPerShuffle,
89 CShuffleNRepeatPerShuffle,
90 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
91 CDEShuffleBlockTransferScalarPerVectors,
92 CDEElementwiseOperation,
112 const auto MPad = M -
MRaw;
114 if constexpr(GemmSpec == GemmSpecialization::MPadding ||
115 GemmSpec == GemmSpecialization::MNPadding ||
116 GemmSpec == GemmSpecialization::MKPadding ||
117 GemmSpec == GemmSpecialization::MNKPadding)
128 return d_grid_desc_mraw;
134 __device__
static constexpr
auto
137 const auto M = d_grid_desc_m.GetLength(
I0);
138 const auto MBlock = M / MPerBlock;
146 return reduce_grid_desc_mblock_mperblock;
150 typename ReduceTrait::ReducePtrsGlobal_ p_reduces_grid_,
151 const typename ReduceTrait::ReduceInElementwiseOperations_ reduce_in_element_ops_,
152 const typename ReduceTrait::ReduceAccElementwiseOperations_ reduce_out_element_ops_,
154 const typename ReduceTrait::D0ElementwiseOperation_ d0_element_op_)
166 typename DsGridPointer,
167 typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
168 typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
169 __device__
void Run(CThreadBuf& c_thread_buf,
170 DsGridPointer p_ds_grid,
173 const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
174 ds_grid_desc_mblock_mperblock_nblock_nperblock,
175 const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
176 e_grid_desc_mblock_mperblock_nblock_nperblock,
177 CDEElementwiseOperation& cde_element_op,
182 const index_t m_block_data_idx_on_grid =
183 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
185 const index_t n_block_data_idx_on_grid =
186 __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
188 auto reduce_grid_desc_mblock_mperblock =
193 return make_dynamic_buffer<AddressSpaceEnum::Global>(
195 ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
199 auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
200 p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
203 constexpr
auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
205 GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
208 constexpr
auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
211 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
212 static_cast<CShuffleDataType*
>(p_shared),
213 c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
214 .GetElementSpaceSize());
227 c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
232 c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
236 c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetLength(
240 c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetLength(
246 ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(
I0) *
247 ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(
I1) ==
252 (CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma) %
253 ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(
I0) ==
255 (CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma) %
256 ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(
I1) ==
260 constexpr
index_t mreduce_per_thread =
261 (CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma) /
262 ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(
I0);
264 constexpr
index_t nreduce_per_thread =
265 (CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma) /
266 ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(
I1);
268 static constexpr
index_t NumReduce = ReduceTrait::ReducePtrsGlobal_::Size();
270 constexpr
auto c_reduce_thread_lengths_mperblock_nperblock =
274 constexpr
auto c_reduce_thread_desc_mperblock_nperblock =
279 constexpr
auto reduce_thread_desc_mperblock =
283 constexpr
auto reduce_thread_desc_mblock_mperblock =
286 auto c_reduce_thread_buf =
287 make_static_buffer<AddressSpaceEnum::Vgpr, typename ReduceTrait::ReduceAccDataType_>(
288 c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());
292 typename ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_{},
295 const auto c_reduce_thread_cluster_idx = c_reduce_thread_cluster_desc.CalculateBottomIndex(
298 const auto c_reduce_thread_data_idx_begin =
299 c_reduce_thread_cluster_idx * c_reduce_thread_lengths_mperblock_nperblock;
303 typename ReduceTrait::ReduceAccDataType_,
304 decltype(c_reduce_block_desc_mperblock_nperblock),
305 decltype(c_reduce_thread_desc_mperblock_nperblock),
306 decltype(c_reduce_thread_lengths_mperblock_nperblock),
309 ReduceTrait::CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock_,
311 true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin};
319 typename ReduceTrait::ReduceAccDataType_,
321 decltype(reduce_thread_desc_mblock_mperblock),
322 decltype(reduce_grid_desc_mblock_mperblock),
323 decltype(reduce_acc_element_op),
327 ReduceTrait::CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock_,
328 ReduceTrait::ReduceGlobalMemoryDataOperation_::At(I),
330 false>{reduce_grid_desc_mblock_mperblock,
332 c_reduce_thread_data_idx_begin[
I0]),
333 reduce_acc_element_op};
338 constexpr
auto d_reduce_thread_desc_mblock_mperblock_nblock_nperblock =
342 constexpr
auto ds_reduce_thread_desc_mblock_mperblock_nblock_nperblock =
generate_tuple(
343 [&](
auto) {
return d_reduce_thread_desc_mblock_mperblock_nblock_nperblock; },
346 constexpr
auto ds_thread_buf_size =
347 d_reduce_thread_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
349 auto c01_thread_buf =
350 make_static_buffer<AddressSpaceEnum::Vgpr, typename ReduceTrait::ReduceAccDataType_>(
357 typename ReduceTrait::ReduceAccDataType_,
358 decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I]),
360 decltype(ds_reduce_thread_desc_mblock_mperblock_nblock_nperblock[I])>,
364 ReduceTrait::CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock_,
366 true>(ds_grid_desc_mblock_mperblock_nblock_nperblock[I],
369 m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[
I0],
371 n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[
I1]));
375 constexpr
auto c_reduce_thread_desc_mblock_mperblock_nblock_nperblock =
381 typename ReduceTrait::ReduceAccDataType_,
383 decltype(c_reduce_thread_desc_mblock_mperblock_nblock_nperblock),
384 decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
389 ReduceTrait::CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock_,
390 EGlobalMemoryDataOperation,
392 true>{e_grid_desc_mblock_mperblock_nblock_nperblock,
394 m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[
I0],
396 n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[
I1]),
399 constexpr
index_t num_access = sfc_c_vgpr.GetNumOfAccess();
401 static_assert(num_access == sfc_cde_global.GetNumOfAccess(),
"wrong!");
409 c_thread_copy_vgpr_to_lds.Run(
410 c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
411 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
413 c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
414 c_shuffle_block_buf);
419 c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock,
421 c_reduce_thread_desc_mperblock_nperblock,
423 c_reduce_thread_buf);
429 auto& d0_thread_copy_global_to_vgpr = ds_thread_copy_global_to_vgpr(
I0);
431 d0_thread_copy_global_to_vgpr.Run(
432 ds_grid_desc_mblock_mperblock_nblock_nperblock[
I0],
434 ds_reduce_thread_desc_mblock_mperblock_nblock_nperblock[
I0],
439 static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
441 typename ReduceTrait::ReduceAccDataType_ out;
442 cde_element_op(out, c_reduce_thread_buf(i) + c01_thread_buf(i));
443 c_reduce_thread_buf(i) = out;
446 auto& d1_thread_copy_global_to_vgpr = ds_thread_copy_global_to_vgpr(
I1);
448 d1_thread_copy_global_to_vgpr.Run(
449 ds_grid_desc_mblock_mperblock_nblock_nperblock[
I1],
451 ds_reduce_thread_desc_mblock_mperblock_nblock_nperblock[
I1],
456 static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
459 c_reduce_thread_buf(i) += c01_thread_buf(i);
464 c_reduce_thread_copy_vgpr_to_global.Run(
465 c_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
468 e_grid_desc_mblock_mperblock_nblock_nperblock,
475 auto reduce_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
476 p_reduce_grid, reduce_grid_desc_mblock_mperblock.GetElementSpaceSize());
478 auto reduce_thread_buf =
480 typename ReduceTrait::ReduceAccDataType_>(
481 reduce_thread_desc_mperblock.GetElementSpaceSize());
485 auto& reduce_thread_copy_vgpr_to_global =
486 reduce_tuple_thread_copy_vgpr_to_global(In);
488 using ReduceOperation =
489 remove_cvref_t<decltype(
typename ReduceTrait::ReduceOperations_{}[In])>;
490 using ThreadwiseReduce =
492 decltype(c_reduce_thread_desc_mperblock_nperblock),
493 decltype(reduce_thread_desc_mperblock),
498 const auto reduce_identityVal = ReduceOperation::template GetIdentityValue<
499 typename ReduceTrait::ReduceAccDataType_>();
502 [&](
auto I) { reduce_thread_buf(I) = reduce_identityVal; });
507 constexpr
auto offset =
508 Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
511 reduce_in_element_op(c_reduce_thread_buf(offset),
512 c_reduce_thread_buf(offset));
516 ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf);
519 reduce_thread_copy_vgpr_to_global.Run(reduce_thread_desc_mblock_mperblock,
522 reduce_grid_desc_mblock_mperblock,
525 if constexpr(access_id < num_access - 1)
527 constexpr
auto c_global_step = sfc_cde_global.GetForwardStep(access_id);
528 reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
529 reduce_grid_desc_mblock_mperblock,
535 if constexpr(access_id < num_access - 1)
537 constexpr
auto cde_global_step = sfc_cde_global.GetForwardStep(access_id);
540 auto& d_thread_copy_global_to_vgpr = ds_thread_copy_global_to_vgpr(I);
541 d_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
542 ds_grid_desc_mblock_mperblock_nblock_nperblock[I], cde_global_step);
546 c_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
547 e_grid_desc_mblock_mperblock_nblock_nperblock, cde_global_step);
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
GemmSpecialization
Definition: gemm_specialization.hpp:11
ck_tile::element_wise::PassThrough PassThrough
Definition: grouped_convolution_utils.hpp:54
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
InMemoryDataOperationEnum
Definition: ck.hpp:279
typename remove_pointer< T >::type remove_pointer_t
Definition: type.hpp:300
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:151
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_static_buffer(Number< N >)
Definition: static_buffer.hpp:186
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:301
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:41
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
__device__ void block_sync_lds()
Definition: synchronization.hpp:16
__host__ constexpr __device__ auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition: cluster_descriptor.hpp:13
Definition: epilogue_cshuffle_v3_wmma_base.hpp:30
static constexpr index_t NumDTensor
Definition: epilogue_cshuffle_v3_wmma_base.hpp:39
static constexpr auto I0
Definition: epilogue_cshuffle_v3_wmma_base.hpp:31
static constexpr auto I3
Definition: epilogue_cshuffle_v3_wmma_base.hpp:34
static constexpr __device__ auto GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
Definition: epilogue_cshuffle_v3_wmma_base.hpp:64
static constexpr auto I1
Definition: epilogue_cshuffle_v3_wmma_base.hpp:32
static __device__ auto GetVgprToLDSEpilogueDescriptor()
Definition: epilogue_cshuffle_v3_wmma_base.hpp:119
static constexpr __device__ auto GetCShuffleLDSDescriptor()
Definition: epilogue_cshuffle_v3_wmma_base.hpp:79
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:76
ReduceTrait::ReduceAccElementwiseOperations_ reduce_out_element_ops
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:554
index_t MRaw
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:555
static constexpr __device__ auto MakeReduceGridDescriptor_MBlock_MPerBlock(const ReduceGridDesc_M &d_grid_desc_m)
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:135
ReduceTrait::ReducePtrsGlobal_ p_reduces_grid
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:552
ReduceGridDesc_M reduce_grid_desc_m
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:557
__device__ EpilogueReduceCShuffle(typename ReduceTrait::ReducePtrsGlobal_ p_reduces_grid_, const typename ReduceTrait::ReduceInElementwiseOperations_ reduce_in_element_ops_, const typename ReduceTrait::ReduceAccElementwiseOperations_ reduce_out_element_ops_, const index_t MRaw_, const typename ReduceTrait::D0ElementwiseOperation_ d0_element_op_)
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:149
static constexpr index_t NumDTensor
Definition: epilogue_cshuffle_v3_wmma_base.hpp:39
__device__ void Run(CThreadBuf &c_thread_buf, DsGridPointer p_ds_grid, EDataType *p_e_grid, void *p_shared, const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &e_grid_desc_mblock_mperblock_nblock_nperblock, CDEElementwiseOperation &cde_element_op, const index_t &block_m_id, const index_t &block_n_id)
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:169
static constexpr auto I0
Definition: epilogue_cshuffle_v3_wmma_base.hpp:31
decltype(MakeReduceGridDescriptor_M(1)) ReduceGridDesc_M
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:132
static constexpr auto I3
Definition: epilogue_cshuffle_v3_wmma_base.hpp:34
static constexpr __device__ auto GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
Definition: epilogue_cshuffle_v3_wmma_base.hpp:64
ReduceTrait::D0ElementwiseOperation_ d0_element_op
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:556
ReduceTrait::ReduceInElementwiseOperations_ reduce_in_element_ops
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:553
static __device__ auto MakeReduceGridDescriptor_M(index_t MRaw)
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:105
static constexpr auto I1
Definition: epilogue_cshuffle_v3_wmma_base.hpp:32
static __device__ auto GetVgprToLDSEpilogueDescriptor()
Definition: epilogue_cshuffle_v3_wmma_base.hpp:119
static constexpr __device__ auto GetCShuffleLDSDescriptor()
Definition: epilogue_cshuffle_v3_wmma_base.hpp:79
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:22
static constexpr index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock_
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:34
ReduceOperations ReduceOperations_
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:26
ReduceInElementwiseOperations ReduceInElementwiseOperations_
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:27
static constexpr index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock_
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:32
ReducePtrsGlobal ReducePtrsGlobal_
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:24
ReduceGlobalMemoryDataOperation ReduceGlobalMemoryDataOperation_
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:29
D0ElementwiseOperation D0ElementwiseOperation_
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:25
ReduceAccDataType ReduceAccDataType_
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:23
ReduceAccElementwiseOperations ReduceAccElementwiseOperations_
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:28
CReduceThreadClusterLengths_MPerBlock_NPerBlock CReduceThreadClusterLengths_MPerBlock_NPerBlock_
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:31
Definition: sequence.hpp:43
Definition: tensor_space_filling_curve.hpp:20
Definition: thread_group.hpp:12
Definition: reduction_functions_threadwise.hpp:23
Definition: threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition: threadwise_tensor_slice_transfer.hpp:234
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33
Definition: unary_element_wise_operation.hpp:340