10 template <
typename DsDataType,
13 typename CShuffleDataType,
20 index_t CShuffleMRepeatPerShuffle,
21 index_t CShuffleNRepeatPerShuffle,
22 typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
23 typename CDEShuffleBlockTransferScalarPerVectors,
24 typename CDEElementwiseOperation,
25 typename ThisThreadBlock,
26 typename BlockwiseGemmPipe>
38 CShuffleMRepeatPerShuffle,
39 CShuffleNRepeatPerShuffle,
40 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
41 CDEShuffleBlockTransferScalarPerVectors,
42 CDEElementwiseOperation,
57 CShuffleMRepeatPerShuffle,
58 CShuffleNRepeatPerShuffle,
59 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
60 CDEShuffleBlockTransferScalarPerVectors,
61 CDEElementwiseOperation,
73 typename DsGridPointer,
74 typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
75 typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
76 __device__
static void Run(CThreadBuf& c_thread_buf,
77 DsGridPointer p_ds_grid,
80 const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
81 ds_grid_desc_mblock_mperblock_nblock_nperblock,
82 const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
83 e_grid_desc_mblock_mperblock_nblock_nperblock,
84 CDEElementwiseOperation& cde_element_op,
90 return make_dynamic_buffer<AddressSpaceEnum::Global>(
92 ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
96 auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
97 p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
100 constexpr
auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
102 GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
105 constexpr
auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
108 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
109 static_cast<CShuffleDataType*
>(p_shared),
110 c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
111 .GetElementSpaceSize());
124 c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
129 tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
131 {
return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
135 auto cde_shuffle_block_copy_lds_and_global =
136 Base::template GetLDSToVmemEpilogueDescriptor<EGlobalMemoryDataOperation, EDataType>(
138 e_grid_desc_mblock_mperblock_nblock_nperblock,
145 tie(c_shuffle_block_buf),
147 {
return ds_grid_buf[i]; },
150 constexpr
index_t num_access = sfc_c_vgpr.GetNumOfAccess();
152 static_assert(num_access == sfc_cde_global.GetNumOfAccess(),
"wrong!");
160 c_thread_copy_vgpr_to_lds.Run(
161 c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
162 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
164 c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
165 c_shuffle_block_buf);
172 cde_shuffle_block_copy_lds_and_global.Run(
175 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
178 if constexpr(access_id < num_access - 1)
180 constexpr
auto cde_global_step = sfc_cde_global.GetForwardStep(access_id);
183 cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow(
184 c_ds_desc_refs, i +
I1, cde_global_step);
188 cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow(
189 tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step);
__host__ constexpr __device__ auto generate_tie(F &&f, Number< N >)
Definition: tuple_helper.hpp:34
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
InMemoryDataOperationEnum
Definition: ck.hpp:277
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
__host__ constexpr __device__ auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition: tuple_helper.hpp:42
int32_t index_t
Definition: ck.hpp:299
__device__ void block_sync_lds()
Definition: synchronization.hpp:16
Definition: epilogue_cshuffle_v3_wmma_base.hpp:29
static constexpr index_t NumDTensor
Definition: epilogue_cshuffle_v3_wmma_base.hpp:38
static constexpr __device__ auto GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
Definition: epilogue_cshuffle_v3_wmma_base.hpp:63
static constexpr auto I1
Definition: epilogue_cshuffle_v3_wmma_base.hpp:31
static __device__ auto GetVgprToLDSEpilogueDescriptor()
Definition: epilogue_cshuffle_v3_wmma_base.hpp:118
static constexpr __device__ auto GetCShuffleLDSDescriptor()
Definition: epilogue_cshuffle_v3_wmma_base.hpp:78
Definition: epilogue_cshuffle_v3_wmma.hpp:45
static constexpr __device__ auto GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
Definition: epilogue_cshuffle_v3_wmma_base.hpp:63
static constexpr auto I1
Definition: epilogue_cshuffle_v3_wmma_base.hpp:31
static __device__ auto GetVgprToLDSEpilogueDescriptor()
Definition: epilogue_cshuffle_v3_wmma_base.hpp:118
static constexpr __device__ auto GetCShuffleLDSDescriptor()
Definition: epilogue_cshuffle_v3_wmma_base.hpp:78
static __device__ void Run(CThreadBuf &c_thread_buf, DsGridPointer p_ds_grid, EDataType *p_e_grid, void *p_shared, const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &e_grid_desc_mblock_mperblock_nblock_nperblock, CDEElementwiseOperation &cde_element_op, const index_t &block_m_id, const index_t &block_n_id)
Definition: epilogue_cshuffle_v3_wmma.hpp:76
Definition: tensor_space_filling_curve.hpp:20
Definition: thread_group.hpp:12
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33