/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp Source File
epilogue_cshuffle_v3_wmma_base.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
8 
9 namespace ck {
10 
11 template <typename DsDataType,
12  typename EDataType,
13  typename AccDataType,
14  typename CShuffleDataType,
15  index_t MPerBlock,
16  index_t NPerBlock,
17  index_t MPerWmma,
18  index_t NPerWmma,
19  index_t MRepeat,
20  index_t NRepeat,
21  index_t CShuffleMRepeatPerShuffle,
22  index_t CShuffleNRepeatPerShuffle,
23  typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
24  typename CDEShuffleBlockTransferScalarPerVectors,
25  typename CDEElementwiseOperation,
26  typename ThisThreadBlock,
27  typename BlockwiseGemmPipe>
29 {
// Compile-time index constants Number<0> .. Number<6>. Below they are used to
// index tuples/sequences and to select descriptor dimensions in GetLength().
30  static constexpr auto I0 = Number<0>{};
31  static constexpr auto I1 = Number<1>{};
32  static constexpr auto I2 = Number<2>{};
33  static constexpr auto I3 = Number<3>{};
34  static constexpr auto I4 = Number<4>{};
35  static constexpr auto I5 = Number<5>{};
36  static constexpr auto I6 = Number<6>{};
37 
// Number of entries reported by DsDataType::Size(); used below as the count of
// per-D start indices generated for the LDS-to-global transfer.
38  static constexpr index_t NumDTensor = DsDataType::Size();
// Element 0 of CDEShuffleBlockTransferScalarPerVectors. It is passed below as
// the DstScalarPerVector argument of the LDS-to-global transfer.
39  static constexpr auto EShuffleBlockTransferScalarPerVector =
40  CDEShuffleBlockTransferScalarPerVectors{}[I0];
41 
45  Sequence<CShuffleMRepeatPerShuffle,
46  1,
47  1,
48  CShuffleNRepeatPerShuffle,
49  1,
50  1,
51  BlockwiseGemmPipe::MAccVgprs>>;
52 
56  Sequence<1,
57  CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma,
58  1,
59  CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma>>;
60 
61  // Caution: "repeat" here means the shuffle repeat.
62  __device__ static constexpr auto
64  {
    // NOTE(review): this listing omits the function-name line (63) and the
    // descriptor-construction lines (69, 71, 73); only the visible statements
    // are described here.
    //
    // MWaves / NWaves: block extent divided by (repeat count * per-WMMA extent)
    // along M and N respectively. Integer division - presumably the template
    // arguments are expected to divide evenly; TODO confirm.
65  constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
66  constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
67 
    // Four-dimensional descriptor (per its name: MShRepeat, MPerShRepeat,
    // NShRepeat, NPerShRepeat). The two visible tuple entries are I1; the
    // other entries are on lines missing from this listing.
68  constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
70  make_tuple(I1,
72  I1,
74 
75  return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
76  }
77 
78  __device__ static constexpr auto GetCShuffleLDSDescriptor()
79  {
    // Compile-time helper whose result is used by GetVgprToLDSEpilogueDescriptor()
    // both as a type (via decltype) and as a constructor argument.
    // NOTE(review): the lines that open the descriptor transform and the return
    // statement (101-104, 109-110, 114-115) are missing from this listing, so
    // the exact transform chain is not described here.
80  // C mapping in single block
81  constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
82  BlockwiseGemmPipe::
83  GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
84 
    // Read individual dimension lengths from the 7-D block descriptor. Going by
    // the descriptor's name, the dimension order is
    // (MRepeat, MWave, MSubGroup, NRepeat, NWave, NThreadPerSubGroup, MAccVgprs),
    // so I1/I2/I4/I5/I6 select the lengths named below.
85  constexpr auto MWave =
86  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
87  .GetLength(I1);
88  constexpr auto MSubGroup =
89  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
90  .GetLength(I2);
91  constexpr auto NWave =
92  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
93  .GetLength(I4);
94  constexpr auto NThreadPerSubGroup =
95  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
96  .GetLength(I5);
97  constexpr auto MAccVgprs =
98  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
99  .GetLength(I6);
100 
    // M-side length group: (CShuffleMRepeatPerShuffle, MWave, MSubGroup, MAccVgprs).
    // N-side length group: (CShuffleNRepeatPerShuffle, NWave, NThreadPerSubGroup).
    // NOTE(review): presumably each group is the up-lengths of an unmerge
    // transform - the enclosing calls are not visible; confirm against the header.
105  Number<CShuffleMRepeatPerShuffle>{}, // MRepeat per shuffle repeat
106  MWave, // MWave
107  MSubGroup, // MSubGroup * MAccVgprs = MPerWmma
108  MAccVgprs)),
111  Number<CShuffleNRepeatPerShuffle>{}, // NRepeat per shuffle repeat
112  NWave, // NWave
113  NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma
116  }
117 
118  __device__ static auto GetVgprToLDSEpilogueDescriptor()
119  {
    // Computes this thread's origin inside the block's C tile and uses it to
    // construct a transfer object whose destination descriptor is
    // GetCShuffleLDSDescriptor().
    // NOTE(review): the transfer's type name (line 169), the adaptor
    // construction calls (150, 152-153, 160, 162-163) and the return statement
    // (196) are missing from this listing; judging by the function name the
    // transfer moves accumulator values from VGPRs into LDS - confirm against
    // the real header.
120  // C mapping in single block
121  constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
122  BlockwiseGemmPipe::
123  GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
124 
    // Dimension lengths of the 7-D block descriptor, selected by index
    // (I1, I2, I4, I5, I6) in the order given by the descriptor's name.
125  constexpr auto MWave =
126  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
127  .GetLength(I1);
128  constexpr auto MSubGroup =
129  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
130  .GetLength(I2);
131  constexpr auto NWave =
132  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
133  .GetLength(I4);
134  constexpr auto NThreadPerSubGroup =
135  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
136  .GetLength(I5);
137  constexpr auto MAccVgprs =
138  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
139  .GetLength(I6);
140 
141  // calculate origin of thread output tensor on global memory
142  // blockwise GEMM c matrix starting index
143  const auto c_thread_mtx_on_block =
144  BlockwiseGemmPipe::CalculateCThreadOriginDataIndex(I0, I0);
145 
    // Component I0 of the origin is taken as the M coordinate, I1 as the N coordinate.
146  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
147  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
148 
    // Adaptor built around a merge transform over
    // (MRepeat, MWave, MSubGroup, MAccVgprs); CalculateBottomIndex() is then
    // applied to the scalar M coordinate to obtain a multi-component index.
149  const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor =
151  make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))),
154 
155  const auto m_thread_data_on_block_idx =
156  m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor
157  .CalculateBottomIndex(make_multi_index(m_thread_data_on_block));
158 
    // Same construction for N, merging (NRepeat, NWave, NThreadPerSubGroup).
159  const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor =
161  make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))),
164 
165  const auto n_thread_data_on_block_idx =
166  n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex(
167  make_multi_index(n_thread_data_on_block));
168 
    // Template arguments of the transfer type (its name is on a missing line):
    // two element types (AccDataType, CShuffleDataType), the thread C
    // descriptor type from BlockwiseGemmPipe, the LDS descriptor type, and a
    // 7-entry length sequence with CShuffleMRepeatPerShuffle,
    // CShuffleNRepeatPerShuffle and MAccVgprs in slots 0, 3 and 6 and I1
    // elsewhere. NOTE(review): the meaning of the trailing literals
    // (6, 1, 1, true) cannot be determined from this listing.
170  AccDataType,
171  CShuffleDataType,
172  decltype(BlockwiseGemmPipe::
173  GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()),
174  decltype(GetCShuffleLDSDescriptor()),
176  Sequence<CShuffleMRepeatPerShuffle,
177  I1,
178  I1,
179  CShuffleNRepeatPerShuffle,
180  I1,
181  I1,
182  MAccVgprs>,
184  6,
185  1,
187  1,
    // Constructor arguments: the LDS descriptor, then an origin assembled from
    // components I1..I3 of the M index and I1..I2 of the N index; a literal 0
    // sits between the M and N wave/subgroup entries.
188  true>{GetCShuffleLDSDescriptor(),
190  m_thread_data_on_block_idx[I1],
191  m_thread_data_on_block_idx[I2],
192  0,
193  n_thread_data_on_block_idx[I1],
194  n_thread_data_on_block_idx[I2],
195  m_thread_data_on_block_idx[I3]),
197  }
198 
// Constructs the blockwise transfer described by the comments in the body:
// it loads C from LDS and D from global memory, applies the elementwise
// operation and stores E to global memory.
//
// Template parameters:
//   EGlobalMemoryDataOperation - cast to index_t and passed as the single
//                                entry of the DstInMemOps sequence.
//   InterDataType              - wrapped in Tuple<> as the last template
//                                argument of the transfer.
//   CDsDescRefs                - type of c_ds_desc_refs (source descriptors).
//   EGridDesc                  - type of the E grid descriptor.
// Parameters:
//   c_ds_desc_refs  - forwarded as the first constructor argument.
//   e_grid_desc_mblock_mperblock_nblock_nperblock - wrapped with tie() and
//                     used as the destination descriptor.
//   cde_element_op  - elementwise operation forwarded to the transfer.
//   block_m_id, block_n_id - used in slots 0 and 2 of the 4-D start indices
//                     for the D sources and the E destination.
// NOTE(review): the line naming the transfer type and the return (218) and
// lines 221 and 239-241 are missing from this listing.
199  template <InMemoryDataOperationEnum EGlobalMemoryDataOperation,
200  typename InterDataType,
201  typename CDsDescRefs,
202  typename EGridDesc>
203  __device__ static auto
204  GetLDSToVmemEpilogueDescriptor(CDsDescRefs& c_ds_desc_refs,
205  EGridDesc& e_grid_desc_mblock_mperblock_nblock_nperblock,
206  CDEElementwiseOperation& cde_element_op,
207  const index_t& block_m_id,
208  const index_t& block_n_id)
209  {
210  // tuple of starting index of C/Ds blockwise copy
    // First entry is (0, 0, 0, 0); it is followed by NumDTensor copies of
    // (block_m_id, 0, block_n_id, 0).
211  const auto idx_c_ds_block_begin = container_concat(
212  make_tuple(make_multi_index(0, 0, 0, 0)),
213  generate_tuple([&](auto) { return make_multi_index(block_m_id, 0, block_n_id, 0); },
214  Number<NumDTensor>{}));
215 
216  // blockwise copy which loads C from LDS, D from global, applies elementwise
217  // operation and stores result E to global
219  ThisThreadBlock, // ThreadGroup
    // Source element types: CShuffleDataType followed by the entries of DsDataType.
220  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
222  CDsDescRefs,
223  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
224  CDEElementwiseOperation, // ElementwiseOperation,
225  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // DstInMemOps,
    // Slice handled per call: 1 x (CShuffleMRepeatPerShuffle * MWaves * MPerWmma)
    // x 1 x (CShuffleNRepeatPerShuffle * NWaves * NPerWmma).
226  Sequence<1,
227  CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma,
228  1,
229  CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves *
230  NPerWmma>, // BlockSliceLengths,
231  CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
232  Sequence<0, 1, 2, 3>, // ThreadClusterArrangeOrder,
233  Sequence<0, 1, 2, 3>, // SrcDimAccessOrder,
234  Sequence<0, 1, 2, 3>, // DstDimAccessOrder,
    // Dimension 3 (the last of the four) is the vector dimension on both sides.
235  3, // SrcVectorDim,
236  3, // DstVectorDim,
237  CDEShuffleBlockTransferScalarPerVectors, // SrcScalarPerVectors
238  EShuffleBlockTransferScalarPerVector, // DstScalarPerVector
242  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
243  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
244  1,
    // Constructor arguments: source descriptors and their start indices, the
    // destination descriptor and its start index, then the elementwise op.
245  Tuple<InterDataType>>{c_ds_desc_refs,
246  idx_c_ds_block_begin,
247  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
248  make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
249  cde_element_op};
250  }
251 };
252 
253 } // namespace ck
ck_tile::element_wise::PassThrough PassThrough
Definition: grouped_convolution_utils.hpp:47
Definition: ck.hpp:268
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition: sequence.hpp:928
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
InMemoryDataOperationEnum
Definition: ck.hpp:277
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:151
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
__host__ constexpr __device__ auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:320
typename sequence_merge< Sx, Sy >::type sequence_merge_t
Definition: sequence.hpp:925
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:299
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
Definition: epilogue_cshuffle_v3_wmma_base.hpp:29
static constexpr auto I6
Definition: epilogue_cshuffle_v3_wmma_base.hpp:36
static constexpr auto I2
Definition: epilogue_cshuffle_v3_wmma_base.hpp:32
static constexpr auto I4
Definition: epilogue_cshuffle_v3_wmma_base.hpp:34
static constexpr auto I5
Definition: epilogue_cshuffle_v3_wmma_base.hpp:35
static constexpr index_t NumDTensor
Definition: epilogue_cshuffle_v3_wmma_base.hpp:38
static constexpr auto I0
Definition: epilogue_cshuffle_v3_wmma_base.hpp:30
static constexpr auto I3
Definition: epilogue_cshuffle_v3_wmma_base.hpp:33
static __device__ auto GetLDSToVmemEpilogueDescriptor(CDsDescRefs &c_ds_desc_refs, EGridDesc &e_grid_desc_mblock_mperblock_nblock_nperblock, CDEElementwiseOperation &cde_element_op, const index_t &block_m_id, const index_t &block_n_id)
Definition: epilogue_cshuffle_v3_wmma_base.hpp:204
static constexpr __device__ auto GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
Definition: epilogue_cshuffle_v3_wmma_base.hpp:63
static constexpr auto I1
Definition: epilogue_cshuffle_v3_wmma_base.hpp:31
static __device__ auto GetVgprToLDSEpilogueDescriptor()
Definition: epilogue_cshuffle_v3_wmma_base.hpp:118
static constexpr __device__ auto GetCShuffleLDSDescriptor()
Definition: epilogue_cshuffle_v3_wmma_base.hpp:78
static constexpr auto EShuffleBlockTransferScalarPerVector
Definition: epilogue_cshuffle_v3_wmma_base.hpp:39
Definition: sequence.hpp:43
Definition: tensor_space_filling_curve.hpp:20
Definition: thread_group.hpp:12
Definition: thread_group_tensor_slice_transfer_v7r3.hpp:48
Definition: threadwise_tensor_slice_transfer.hpp:39
Definition: tuple.hpp:117
Definition: integral_constant.hpp:20
Definition: unary_element_wise_operation.hpp:340