/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp Source File
epilogue_cshuffle_v3_wmma.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 
// NOTE(review): this is a rendered source listing. The integer at the start of each
// line is the listing's own line number, not C++ code. Wherever that numbering skips,
// a line of the original header is not visible here; each gap is flagged below.
8 namespace ck {
9 
// CShuffle epilogue: a class template that derives from EpilogueCShuffleBase using the
// same template arguments it receives. Its static Run() stages each thread's C data
// through LDS, combines it with the D tensors via the elementwise operation, and stores
// the result E to global memory.
// NOTE(review): the listing jumps from line 26 to 28, so the line that declares the
// struct name is not visible in this listing - confirm against the original header.
10 template <typename DsDataType,
11  typename EDataType,
12  typename AccDataType,
13  typename CShuffleDataType,
14  index_t MPerBlock,
15  index_t NPerBlock,
16  index_t MPerWmma,
17  index_t NPerWmma,
18  index_t MRepeat,
19  index_t NRepeat,
20  index_t CShuffleMRepeatPerShuffle,
21  index_t CShuffleNRepeatPerShuffle,
22  typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
23  typename CDEShuffleBlockTransferScalarPerVectors,
24  typename CDEElementwiseOperation,
25  typename ThisThreadBlock,
26  typename BlockwiseGemmPipe>
28  : EpilogueCShuffleBase<DsDataType,
29  EDataType,
30  AccDataType,
31  CShuffleDataType,
32  MPerBlock,
33  NPerBlock,
34  MPerWmma,
35  NPerWmma,
36  MRepeat,
37  NRepeat,
38  CShuffleMRepeatPerShuffle,
39  CShuffleNRepeatPerShuffle,
40  CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
41  CDEShuffleBlockTransferScalarPerVectors,
42  CDEElementwiseOperation,
43  ThisThreadBlock,
44  BlockwiseGemmPipe>
45 {
 // NOTE(review): listing lines 46 and 62 are missing. The argument list below is
 // presumably the tail of a `using Base = EpilogueCShuffleBase<...>` alias (Base:: is
 // used further down), and ThisThreadBlock is presumably the dropped argument between
 // CDEElementwiseOperation and BlockwiseGemmPipe, as in the base-class list above.
47  DsDataType,
48  EDataType,
49  AccDataType,
50  CShuffleDataType,
51  MPerBlock,
52  NPerBlock,
53  MPerWmma,
54  NPerWmma,
55  MRepeat,
56  NRepeat,
57  CShuffleMRepeatPerShuffle,
58  CShuffleNRepeatPerShuffle,
59  CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
60  CDEShuffleBlockTransferScalarPerVectors,
61  CDEElementwiseOperation,
63  BlockwiseGemmPipe>;
64 
 // NOTE(review): listing lines 65-67 are missing; they presumably hold further
 // using-declarations for the Base helpers that Run() calls unqualified
 // (e.g. GetVgprToLDSEpilogueDescriptor) - confirm against the original header.
68  using Base::I1;
69  using Base::NumDTensor;
70 
 // Writes one block's output: copies the per-thread C data to LDS, then reads it back
 // together with the D tensors from global memory, applies the elementwise operation
 // and stores E to global memory. The work is split into num_access steps.
 //
 // Template parameter:
 //   EGlobalMemoryDataOperation - forwarded to Base::GetLDSToVmemEpilogueDescriptor.
 // Parameters:
 //   c_thread_buf   - per-thread C data; source of the VGPR-to-LDS copy.
 //   p_ds_grid      - indexable collection of D tensor pointers (one per D tensor).
 //   p_e_grid       - pointer to the E output in global memory.
 //   p_shared       - LDS scratch; reinterpreted as CShuffleDataType*.
 //   ds_grid_desc_mblock_mperblock_nblock_nperblock - descriptors of the D tensors.
 //   e_grid_desc_mblock_mperblock_nblock_nperblock  - descriptor of the E tensor.
 //   cde_element_op - elementwise operation, handed to the LDS-to-global copy object.
 //   block_m_id, block_n_id - forwarded to Base::GetLDSToVmemEpilogueDescriptor;
 //                    presumably this block's tile coordinates - TODO confirm.
 //
 // Every loop iteration calls block_sync_lds() twice.
71  template <InMemoryDataOperationEnum EGlobalMemoryDataOperation,
72  typename CThreadBuf,
73  typename DsGridPointer,
74  typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
75  typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
76  __device__ static void Run(CThreadBuf& c_thread_buf,
77  DsGridPointer p_ds_grid,
78  EDataType* p_e_grid,
79  void* p_shared,
80  const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
81  ds_grid_desc_mblock_mperblock_nblock_nperblock,
82  const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
83  e_grid_desc_mblock_mperblock_nblock_nperblock,
84  CDEElementwiseOperation& cde_element_op,
85  const index_t& block_m_id,
86  const index_t& block_n_id)
87  {
 // One global-memory buffer view per D tensor, sized by its descriptor.
88  const auto ds_grid_buf = generate_tuple(
89  [&](auto i) {
90  return make_dynamic_buffer<AddressSpaceEnum::Global>(
91  p_ds_grid[i],
92  ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
93  },
 // NOTE(review): listing line 94 is missing; it presumably closes generate_tuple
 // with a Number<NumDTensor>{} count, like the generate_tie calls below - confirm.
95 
 // Global-memory buffer view for the E output.
96  auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
97  p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
98 
99  // C mapping in single thread.
100  constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
101  BlockwiseGemmPipe::
102  GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
103 
104  // LDS buffer
 // NOTE(review): listing line 106 (the initializer) is missing. The page's
 // cross-reference list names Base's
 // GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(),
 // which is the likely right-hand side - confirm.
105  constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
107 
108  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
109  static_cast<CShuffleDataType*>(p_shared),
110  c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
111  .GetElementSpaceSize());
112 
113  // Thread transfer Vgpr to LDS
114  auto c_thread_copy_vgpr_to_lds = GetVgprToLDSEpilogueDescriptor();
115 
116  // Space Filling Curve Vgpr
117  constexpr auto sfc_c_vgpr = typename Base::SpaceFillingCurveVgpr{};
118 
119  // Space Filling Curve Vmem
120  constexpr auto sfc_cde_global = typename Base::SpaceFillingCurveVmem{};
121 
122  // Block descriptor
 // NOTE(review): listing line 125 (the initializer) is missing. The page's
 // cross-reference list names Base's GetCShuffleLDSDescriptor(), which is the
 // likely right-hand side - confirm.
123  constexpr auto
124  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
126 
127  // tuple of reference to C/Ds tensor descriptors
 // Entry 0 is the LDS shuffle descriptor; entries 1..NumDTensor are the D descriptors.
128  const auto c_ds_desc_refs = concat_tuple_of_reference(
129  tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
130  generate_tie([&](auto i) -> const auto& // return type should be reference
131  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
132  Number<NumDTensor>{}));
133 
134  // Thread transfer LDS to Vmem
135  auto cde_shuffle_block_copy_lds_and_global =
136  Base::template GetLDSToVmemEpilogueDescriptor<EGlobalMemoryDataOperation, EDataType>(
137  c_ds_desc_refs,
138  e_grid_desc_mblock_mperblock_nblock_nperblock,
139  cde_element_op,
140  block_m_id,
141  block_n_id);
142 
143  // tuple of reference to C/Ds tensor buffers
 // Same ordering as c_ds_desc_refs: the LDS shuffle buffer first, then the D buffers.
144  const auto c_ds_buf_refs = concat_tuple_of_reference(
145  tie(c_shuffle_block_buf),
146  generate_tie([&](auto i) -> const auto& // return type should be reference
147  { return ds_grid_buf[i]; },
148  Number<NumDTensor>{}));
149 
150  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
151 
 // Both space-filling curves are stepped by the same access_id below, so they
 // must have the same number of accesses.
152  static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!");
153 
154  // CShuffle and Store
155  static_for<0, num_access, 1>{}([&](auto access_id) {
156  // make sure it's safe to write to LDS
157  block_sync_lds();
158 
159  // each thread write its data from VGPR to LDS
160  c_thread_copy_vgpr_to_lds.Run(
161  c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
162  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
163  c_thread_buf,
164  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
165  c_shuffle_block_buf);
166 
167  // make sure it's safe to read from LDS
168  block_sync_lds();
169 
170  // each block loads its C data from LDS, D from global, applies elementwise
171  // operation and stores result E to global
172  cde_shuffle_block_copy_lds_and_global.Run(
173  c_ds_desc_refs,
174  c_ds_buf_refs,
175  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
176  tie(e_grid_buf));
177 
 // No window move is needed after the final access.
178  if constexpr(access_id < num_access - 1)
179  {
180  constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id);
181  // move on Ds
 // Only sources i + I1 are advanced; source 0 (the LDS shuffle descriptor,
 // first entry of c_ds_desc_refs) is not moved here. I1 is presumably the
 // compile-time constant 1 - confirm in the base class.
182  static_for<0, NumDTensor, 1>{}([&](auto i) {
183  cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow(
184  c_ds_desc_refs, i + I1, cde_global_step);
185  });
186 
187  // move on E
188  cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow(
189  tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step);
190  }
191  });
192  }
193 };
194 
195 } // namespace ck
Definition: ck.hpp:268
__host__ constexpr __device__ auto generate_tie(F &&f, Number< N >)
Definition: tuple_helper.hpp:34
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
InMemoryDataOperationEnum
Definition: ck.hpp:277
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
__host__ constexpr __device__ auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition: tuple_helper.hpp:42
int32_t index_t
Definition: ck.hpp:299
__device__ void block_sync_lds()
Definition: synchronization.hpp:16
Definition: epilogue_cshuffle_v3_wmma_base.hpp:29
static constexpr index_t NumDTensor
Definition: epilogue_cshuffle_v3_wmma_base.hpp:38
static constexpr __device__ auto GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
Definition: epilogue_cshuffle_v3_wmma_base.hpp:63
static constexpr auto I1
Definition: epilogue_cshuffle_v3_wmma_base.hpp:31
static __device__ auto GetVgprToLDSEpilogueDescriptor()
Definition: epilogue_cshuffle_v3_wmma_base.hpp:118
static constexpr __device__ auto GetCShuffleLDSDescriptor()
Definition: epilogue_cshuffle_v3_wmma_base.hpp:78
Definition: epilogue_cshuffle_v3_wmma.hpp:45
static constexpr __device__ auto GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
Definition: epilogue_cshuffle_v3_wmma_base.hpp:63
static constexpr auto I1
Definition: epilogue_cshuffle_v3_wmma_base.hpp:31
static __device__ auto GetVgprToLDSEpilogueDescriptor()
Definition: epilogue_cshuffle_v3_wmma_base.hpp:118
static constexpr __device__ auto GetCShuffleLDSDescriptor()
Definition: epilogue_cshuffle_v3_wmma_base.hpp:78
static __device__ void Run(CThreadBuf &c_thread_buf, DsGridPointer p_ds_grid, EDataType *p_e_grid, void *p_shared, const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &e_grid_desc_mblock_mperblock_nblock_nperblock, CDEElementwiseOperation &cde_element_op, const index_t &block_m_id, const index_t &block_n_id)
Definition: epilogue_cshuffle_v3_wmma.hpp:76
Definition: tensor_space_filling_curve.hpp:20
Definition: thread_group.hpp:12
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33