epilogue_cshuffle_v3_reduce_wmma.hpp
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

// NOTE: the original include lines did not survive this page; the two below are inferred from
// the symbols used in this file (EpilogueCShuffleBase, ThreadwiseReduction) and may not match
// the original list exactly.
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp"
#include "ck/utility/reduction_functions_threadwise.hpp"

namespace ck {

// Bundles the reduction-specific compile-time configuration consumed by the epilogue below.
// NOTE: the original struct name did not survive this page; "EpilogueReduceTrait" is a
// reconstructed placeholder matching the ReduceTrait parameter used below.
template <typename ReduceAccDataType,
          typename ReducePtrsGlobal,
          typename D0ElementwiseOperation,
          typename ReduceOperations,
          typename ReduceInElementwiseOperations,
          typename ReduceAccElementwiseOperations,
          typename ReduceGlobalMemoryDataOperation,
          typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
          index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
          index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock>
struct EpilogueReduceTrait
{
    using ReduceAccDataType_               = ReduceAccDataType;
    using ReducePtrsGlobal_                = ReducePtrsGlobal;
    using D0ElementwiseOperation_          = D0ElementwiseOperation;
    using ReduceOperations_                = ReduceOperations;
    using ReduceInElementwiseOperations_   = ReduceInElementwiseOperations;
    using ReduceAccElementwiseOperations_  = ReduceAccElementwiseOperations;
    using ReduceGlobalMemoryDataOperation_ = ReduceGlobalMemoryDataOperation;
    using CReduceThreadClusterLengths_MPerBlock_NPerBlock_ =
        CReduceThreadClusterLengths_MPerBlock_NPerBlock;
    static constexpr index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock_ =
        CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock;
    static constexpr index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock_ =
        CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock;
};
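
// The trailing-underscore members re-export the template arguments so the epilogue below can
// receive the whole reduction configuration through the single ReduceTrait parameter (see the
// wiring sketch at the end of this file).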

template <typename DsDataType,
          typename EDataType,
          typename AccDataType,
          typename CShuffleDataType,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t MPerWmma,
          index_t NPerWmma,
          index_t MRepeat,
          index_t NRepeat,
          index_t CShuffleMRepeatPerShuffle,
          index_t CShuffleNRepeatPerShuffle,
          typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          typename CDEShuffleBlockTransferScalarPerVectors,
          typename CDEElementwiseOperation,
          typename ThisThreadBlock,
          typename BlockwiseGemmPipe,
          tensor_operation::device::GemmSpecialization GemmSpec, // reconstructed from uses below
          index_t BlockSize,
          typename ReduceTrait>
struct EpilogueReduceCShuffle
    : EpilogueCShuffleBase<DsDataType,
                           EDataType,
                           AccDataType,
                           CShuffleDataType,
                           MPerBlock,
                           NPerBlock,
                           MPerWmma,
                           NPerWmma,
                           MRepeat,
                           NRepeat,
                           CShuffleMRepeatPerShuffle,
                           CShuffleNRepeatPerShuffle,
                           CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                           CDEShuffleBlockTransferScalarPerVectors,
                           CDEElementwiseOperation,
                           ThisThreadBlock,
                           BlockwiseGemmPipe>
{
    using Base = EpilogueCShuffleBase<DsDataType,
                                      EDataType,
                                      AccDataType,
                                      CShuffleDataType,
                                      MPerBlock,
                                      NPerBlock,
                                      MPerWmma,
                                      NPerWmma,
                                      MRepeat,
                                      NRepeat,
                                      CShuffleMRepeatPerShuffle,
                                      CShuffleNRepeatPerShuffle,
                                      CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                                      CDEShuffleBlockTransferScalarPerVectors,
                                      CDEElementwiseOperation,
                                      ThisThreadBlock,
                                      BlockwiseGemmPipe>;

    // Base-class helpers called unqualified below (using-declarations reconstructed; the
    // original lines did not survive this page).
    using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat;
    using Base::GetCShuffleLDSDescriptor;
    using Base::GetVgprToLDSEpilogueDescriptor;

    using Base::I0;
    using Base::I1;
    using Base::I3;
    using Base::NumDTensor;

    // assume Reduce is a packed (contiguous) tensor
    __device__ static auto MakeReduceGridDescriptor_M(index_t MRaw)
    {
        using tensor_operation::device::GemmSpecialization;

        const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));

        const auto M    = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
        const auto MPad = M - MRaw;

        if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                     GemmSpec == GemmSpecialization::MNPadding ||
                     GemmSpec == GemmSpecialization::MKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad M
            return transform_tensor_descriptor(d_grid_desc_mraw,
                                               make_tuple(make_right_pad_transform(MRaw, MPad)),
                                               make_tuple(Sequence<0>{}),
                                               make_tuple(Sequence<0>{}));
        }
        else
        {
            // do not pad M
            return d_grid_desc_mraw;
        }
    }
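    // Example for the padding above: with MRaw = 1000 and MPerBlock = 128,
    // M = ceil(1000 / 128) * 128 = 1024 and MPad = 24; the last 24 rows are padding.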

    using ReduceGridDesc_M = decltype(MakeReduceGridDescriptor_M(1));

    __device__ static constexpr auto
    MakeReduceGridDescriptor_MBlock_MPerBlock(const ReduceGridDesc_M& d_grid_desc_m)
    {
        const auto M      = d_grid_desc_m.GetLength(I0);
        const auto MBlock = M / MPerBlock;

        const auto reduce_grid_desc_mblock_mperblock = transform_tensor_descriptor(
            d_grid_desc_m,
            make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{}))),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0, 1>{}));

        return reduce_grid_desc_mblock_mperblock;
    }
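    // Example for the unmerge above: M = 1024 with MPerBlock = 128 gives
    // (MBlock, MPerBlock) = (8, 128), so workgroup b addresses rows [b * 128, (b + 1) * 128).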

    __device__ EpilogueReduceCShuffle(
        typename ReduceTrait::ReducePtrsGlobal_ p_reduces_grid_,
        const typename ReduceTrait::ReduceInElementwiseOperations_ reduce_in_element_ops_,
        const typename ReduceTrait::ReduceAccElementwiseOperations_ reduce_out_element_ops_,
        const index_t MRaw_,
        const typename ReduceTrait::D0ElementwiseOperation_ d0_element_op_)
        : p_reduces_grid(p_reduces_grid_),
          reduce_in_element_ops(reduce_in_element_ops_),
          reduce_out_element_ops(reduce_out_element_ops_),
          MRaw(MRaw_),
          d0_element_op{d0_element_op_},
          reduce_grid_desc_m{MakeReduceGridDescriptor_M(MRaw_)}
    {
    }

    template <InMemoryDataOperationEnum EGlobalMemoryDataOperation,
              typename CThreadBuf,
              typename DsGridPointer,
              typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
              typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
    __device__ void Run(CThreadBuf& c_thread_buf,
                        DsGridPointer p_ds_grid,
                        EDataType* p_e_grid,
                        void* p_shared,
                        const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                            ds_grid_desc_mblock_mperblock_nblock_nperblock,
                        const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                            e_grid_desc_mblock_mperblock_nblock_nperblock,
                        CDEElementwiseOperation& cde_element_op,
                        const index_t& block_m_id,
                        const index_t& block_n_id)
    {
        // HACK: this forces m/n_block_data_idx_on_grid into SGPRs
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);

        auto reduce_grid_desc_mblock_mperblock =
            MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m);

        const auto ds_grid_buf = generate_tuple(
            [&](auto i) {
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_ds_grid[i],
                    ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
            },
            Number<NumDTensor>{});

        auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        // C mapping within a single thread
        constexpr auto
            c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
                BlockwiseGemmPipe::
                    GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();

        // LDS buffer
        constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
            GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();

        auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<CShuffleDataType*>(p_shared),
            c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
                .GetElementSpaceSize());
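        // E.g. one shuffle slab of 32 x 128 fp32 CShuffleDataType elements occupies
        // 32 * 128 * 4 B = 16 KiB of the LDS behind p_shared (illustrative sizes only).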

        // thread transfer from VGPR to LDS
        auto c_thread_copy_vgpr_to_lds = GetVgprToLDSEpilogueDescriptor();

        // space-filling curve over the VGPR tile
        constexpr auto sfc_c_vgpr = typename Base::SpaceFillingCurveVgpr{};

        // space-filling curve over global memory (vmem)
        constexpr auto sfc_cde_global = typename Base::SpaceFillingCurveVmem{};

        // block descriptor
        constexpr auto
            c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
                GetCShuffleLDSDescriptor();

        // LDS c_reduce_block_desc_mperblock_nperblock: freeze the two repeat dimensions and
        // view the shuffle slab as a 2D (MPerBlock-slice, NPerBlock-slice) tensor
        constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor(
            c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
            make_tuple(
                make_freeze_transform(I0),
                make_pass_through_transform(
                    c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetLength(
                        I1)),
                make_freeze_transform(I0),
                make_pass_through_transform(
                    c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetLength(
                        I3))),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{}));

        static_assert(
            ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I0) *
                    ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I1) ==
                BlockSize,
            "wrong! the CReduce thread cluster must contain exactly BlockSize threads");

        static_assert(
            (CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma) %
                    ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I0) ==
                0 &&
                (CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma) %
                    ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I1) ==
                0,
            "wrong! the shuffle slab must be divisible by the CReduce thread cluster");

        constexpr index_t mreduce_per_thread =
            (CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma) /
            ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I0);

        constexpr index_t nreduce_per_thread =
            (CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma) /
            ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I1);
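        // E.g. CShuffleMRepeatPerShuffle * MWaves * MPerWmma = 1 * 2 * 16 = 32 rows per slab
        // split over a cluster of 8 threads in M gives mreduce_per_thread = 32 / 8 = 4
        // (illustrative numbers; the actual values come from ReduceTrait).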

        static constexpr index_t NumReduce = ReduceTrait::ReducePtrsGlobal_::Size();

        constexpr auto c_reduce_thread_lengths_mperblock_nperblock =
            Sequence<mreduce_per_thread, nreduce_per_thread>{};

        // VGPR c_reduce_thread_desc_mperblock_nperblock
        constexpr auto c_reduce_thread_desc_mperblock_nperblock =
            make_naive_tensor_descriptor_packed(
                make_tuple(Number<mreduce_per_thread>{}, Number<nreduce_per_thread>{}));

        // VGPR reduce_thread_desc_mperblock
        constexpr auto reduce_thread_desc_mperblock =
            make_naive_tensor_descriptor_packed(make_tuple(Number<mreduce_per_thread>{}));

        // VGPR reduce_thread_desc_mblock_mperblock
        constexpr auto reduce_thread_desc_mblock_mperblock =
            make_naive_tensor_descriptor_packed(make_tuple(I1, Number<mreduce_per_thread>{}));

        auto c_reduce_thread_buf =
            make_static_buffer<AddressSpaceEnum::Vgpr, typename ReduceTrait::ReduceAccDataType_>(
                c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());

        // reduce: threadwise copy from LDS to VGPR
        constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor(
            typename ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_{},
            Sequence<1, 0>{});

        const auto c_reduce_thread_cluster_idx = c_reduce_thread_cluster_desc.CalculateBottomIndex(
            make_multi_index(get_thread_local_1d_id()));

        const auto c_reduce_thread_data_idx_begin =
            c_reduce_thread_cluster_idx * c_reduce_thread_lengths_mperblock_nperblock;
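        // Each thread owns an mreduce_per_thread x nreduce_per_thread sub-tile: e.g. with
        // per-thread lengths (4, 8), the thread at cluster index (1, 1) starts reading at
        // (m, n) = (4, 8) within the shuffle slab.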

        auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
            CShuffleDataType,
            typename ReduceTrait::ReduceAccDataType_,
            decltype(c_reduce_block_desc_mperblock_nperblock),
            decltype(c_reduce_thread_desc_mperblock_nperblock),
            decltype(c_reduce_thread_lengths_mperblock_nperblock),
            Sequence<0, 1>, // DimAccessOrder
            1,              // SrcVectorDim
            ReduceTrait::CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock_,
            1,
            true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin};

        auto reduce_tuple_thread_copy_vgpr_to_global = generate_tuple(
            [&](auto I) {
                auto p_reduce_grid         = p_reduces_grid[I];
                auto reduce_acc_element_op = reduce_out_element_ops[I];

                return ThreadwiseTensorSliceTransfer_v1r3<
                    typename ReduceTrait::ReduceAccDataType_,
                    remove_pointer_t<decltype(p_reduce_grid)>,
                    decltype(reduce_thread_desc_mblock_mperblock),
                    decltype(reduce_grid_desc_mblock_mperblock),
                    decltype(reduce_acc_element_op),
                    Sequence<1, mreduce_per_thread>, // SliceLengths
                    Sequence<0, 1>,                  // DimAccessOrder
                    1,                               // DstVectorDim
                    ReduceTrait::CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock_,
                    ReduceTrait::ReduceGlobalMemoryDataOperation_::At(I),
                    1,
                    false>{reduce_grid_desc_mblock_mperblock,
                           make_multi_index(block_m_id,                          // mblock
                                            c_reduce_thread_data_idx_begin[I0]), // mperblock
                           reduce_acc_element_op};
            },
            Number<NumReduce>{});

        // multiple Ds
        constexpr auto d_reduce_thread_desc_mblock_mperblock_nblock_nperblock =
            make_naive_tensor_descriptor_packed(
                make_tuple(I1, Number<mreduce_per_thread>{}, I1, Number<nreduce_per_thread>{}));

        constexpr auto ds_reduce_thread_desc_mblock_mperblock_nblock_nperblock = generate_tuple(
            [&](auto) { return d_reduce_thread_desc_mblock_mperblock_nblock_nperblock; },
            Number<NumDTensor>{});

        constexpr auto ds_thread_buf_size =
            d_reduce_thread_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

        auto c01_thread_buf =
            make_static_buffer<AddressSpaceEnum::Vgpr, typename ReduceTrait::ReduceAccDataType_>(
                ds_thread_buf_size);

        auto ds_thread_copy_global_to_vgpr = generate_tuple(
            [&](auto I) {
                return ThreadwiseTensorSliceTransfer_v2<
                    remove_cvref_t<tuple_element_t<I.value, DsDataType>>,
                    typename ReduceTrait::ReduceAccDataType_,
                    decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I]),
                    remove_cvref_t<
                        decltype(ds_reduce_thread_desc_mblock_mperblock_nblock_nperblock[I])>,
                    Sequence<1, mreduce_per_thread, 1, nreduce_per_thread>, // SliceLengths
                    Sequence<0, 1, 2, 3>,                                   // DimAccessOrder
                    3,                                                      // SrcVectorDim
                    ReduceTrait::CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock_,
                    1,
                    true>(ds_grid_desc_mblock_mperblock_nblock_nperblock[I],
                          make_multi_index(
                              I0,
                              m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0],
                              I0,
                              n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]));
            },
            Number<NumDTensor>{});

        constexpr auto c_reduce_thread_desc_mblock_mperblock_nblock_nperblock =
            make_naive_tensor_descriptor_packed(
                make_tuple(I1, Number<mreduce_per_thread>{}, I1, Number<nreduce_per_thread>{}));

        // write E from VGPR to global memory (vmem)
        auto c_reduce_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
            typename ReduceTrait::ReduceAccDataType_,
            EDataType,
            decltype(c_reduce_thread_desc_mblock_mperblock_nblock_nperblock),
            decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
            tensor_operation::element_wise::PassThrough, // ElementwiseOperation (reconstructed;
                                                         // cde_element_op is applied in VGPR in
                                                         // the loop below)
            Sequence<1, mreduce_per_thread, 1, nreduce_per_thread>, // SliceLengths
            Sequence<0, 1, 2, 3>,                                   // DimAccessOrder
            3,                                                      // DstVectorDim
            ReduceTrait::CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock_,
            EGlobalMemoryDataOperation,
            1,
            true>{e_grid_desc_mblock_mperblock_nblock_nperblock,
                  make_multi_index(I0,
                                   m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0],
                                   I0,
                                   n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]),
                  tensor_operation::element_wise::PassThrough{}};

        constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

        static_assert(num_access == sfc_cde_global.GetNumOfAccess(),
                      "wrong! VGPR and vmem space-filling curves disagree on the access count");
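        // Both curves step through CShuffleMRepeatPerShuffle x CShuffleNRepeatPerShuffle
        // sub-tiles of the MRepeat x NRepeat accumulator; e.g. MRepeat = NRepeat = 4 shuffled
        // 2 x 2 at a time gives num_access = (4 / 2) * (4 / 2) = 4 iterations below.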

        // CShuffle and store
        static_for<0, num_access, 1>{}([&](auto access_id) {
            // make sure it's safe to write to LDS
            block_sync_lds();

            // each thread writes its data from VGPR to LDS
            c_thread_copy_vgpr_to_lds.Run(
                c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
                sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                c_thread_buf,
                c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
                c_shuffle_block_buf);

            // make sure it's safe to read from LDS
            block_sync_lds();
            {
                c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock,
                                                     c_shuffle_block_buf,
                                                     c_reduce_thread_desc_mperblock_nperblock,
                                                     make_tuple(I0, I0),
                                                     c_reduce_thread_buf);

                // Note: multiple Ds currently supports only Bias + Add; it needs to be
                // generalized for other operations (not needed yet).
                if constexpr(NumDTensor > 0)
                {
                    auto& d0_thread_copy_global_to_vgpr = ds_thread_copy_global_to_vgpr(I0);
                    // d0 / d1 operations
                    d0_thread_copy_global_to_vgpr.Run(
                        ds_grid_desc_mblock_mperblock_nblock_nperblock[I0],
                        ds_grid_buf[I0],
                        ds_reduce_thread_desc_mblock_mperblock_nblock_nperblock[I0],
                        make_tuple(I0, I0, I0, I0),
                        c01_thread_buf);

                    // c = activation(c + bias)
                    static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
                        [&](auto i) {
                            typename ReduceTrait::ReduceAccDataType_ out;
                            cde_element_op(out, c_reduce_thread_buf(i) + c01_thread_buf(i));
                            c_reduce_thread_buf(i) = out;
                        });

                    auto& d1_thread_copy_global_to_vgpr = ds_thread_copy_global_to_vgpr(I1);

                    d1_thread_copy_global_to_vgpr.Run(
                        ds_grid_desc_mblock_mperblock_nblock_nperblock[I1],
                        ds_grid_buf[I1],
                        ds_reduce_thread_desc_mblock_mperblock_nblock_nperblock[I1],
                        make_tuple(I0, I0, I0, I0),
                        c01_thread_buf);

                    // c = c + d0_element_op(d1)
                    static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
                        [&](auto i) {
                            d0_element_op(c01_thread_buf(i), c01_thread_buf(i));
                            c_reduce_thread_buf(i) += c01_thread_buf(i);
                        });
                }
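
                // E.g. with cde_element_op = Relu and d0_element_op = PassThrough the two passes
                // above compute E = Relu(C + D0) + D1, i.e. bias + activation followed by a
                // residual add (illustrative ops).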

                // write E
                c_reduce_thread_copy_vgpr_to_global.Run(
                    c_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
                    make_tuple(I0, I0, I0, I0),
                    c_reduce_thread_buf,
                    e_grid_desc_mblock_mperblock_nblock_nperblock,
                    e_grid_buf);

                // reduction
                static_for<0, NumReduce, 1>{}([&](auto In) {
                    auto& p_reduce_grid = p_reduces_grid[In];

                    auto reduce_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                        p_reduce_grid, reduce_grid_desc_mblock_mperblock.GetElementSpaceSize());

                    auto reduce_thread_buf =
                        make_static_buffer<AddressSpaceEnum::Vgpr,
                                           typename ReduceTrait::ReduceAccDataType_>(
                            reduce_thread_desc_mperblock.GetElementSpaceSize());

                    auto& reduce_in_element_op = reduce_in_element_ops[In];

                    auto& reduce_thread_copy_vgpr_to_global =
                        reduce_tuple_thread_copy_vgpr_to_global(In);

                    using ReduceOperation =
                        remove_cvref_t<decltype(typename ReduceTrait::ReduceOperations_{}[In])>;
                    using ThreadwiseReduce =
                        ThreadwiseReduction<typename ReduceTrait::ReduceAccDataType_,
                                            decltype(c_reduce_thread_desc_mperblock_nperblock),
                                            decltype(reduce_thread_desc_mperblock),
                                            ReduceOperation,
                                            false>;

                    // global write: GEMM shuffle + reduction
                    const auto reduce_identityVal = ReduceOperation::template GetIdentityValue<
                        typename ReduceTrait::ReduceAccDataType_>();

                    static_for<0, mreduce_per_thread, 1>{}(
                        [&](auto I) { reduce_thread_buf(I) = reduce_identityVal; });

                    // reduce in VGPR
                    static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
                        static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
                            constexpr auto offset =
                                Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
                                    make_tuple(im, in))>{};

                            reduce_in_element_op(c_reduce_thread_buf(offset),
                                                 c_reduce_thread_buf(offset));
                        });
                    });

                    ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf);

                    // copy from VGPR to global
                    reduce_thread_copy_vgpr_to_global.Run(reduce_thread_desc_mblock_mperblock,
                                                          make_tuple(I0, I0),
                                                          reduce_thread_buf,
                                                          reduce_grid_desc_mblock_mperblock,
                                                          reduce_grid_buf);
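
                    // E.g. ReduceOperation = Add with reduce_in_element_op = Square and an
                    // AtomicAdd ReduceGlobalMemoryDataOperation accumulates per-row sums of C^2
                    // across all N-blocks, the mean/mean-square style use case (illustrative).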

                    if constexpr(access_id < num_access - 1)
                    {
                        constexpr auto c_global_step = sfc_cde_global.GetForwardStep(access_id);
                        reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
                            reduce_grid_desc_mblock_mperblock,
                            make_tuple(c_global_step[I0], c_global_step[I1]));
                    }
                });
            }

            if constexpr(access_id < num_access - 1)
            {
                constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id);
                // move on Ds
                static_for<0, NumDTensor, 1>{}([&](auto I) {
                    auto& d_thread_copy_global_to_vgpr = ds_thread_copy_global_to_vgpr(I);
                    d_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
                        ds_grid_desc_mblock_mperblock_nblock_nperblock[I], cde_global_step);
                });

                // move on E
                c_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
                    e_grid_desc_mblock_mperblock_nblock_nperblock, cde_global_step);
            }
        });
    }

    typename ReduceTrait::ReducePtrsGlobal_ p_reduces_grid;
    typename ReduceTrait::ReduceInElementwiseOperations_ reduce_in_element_ops;
    typename ReduceTrait::ReduceAccElementwiseOperations_ reduce_out_element_ops;
    index_t MRaw;
    typename ReduceTrait::D0ElementwiseOperation_ d0_element_op;
    ReduceGridDesc_M reduce_grid_desc_m;
};

} // namespace ck
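
// A minimal wiring sketch, assuming a gridwise GEMM host (everything below other than
// EpilogueReduceTrait / EpilogueReduceCShuffle is hypothetical and elided):
//
//   using Trait = ck::EpilogueReduceTrait<
//       float,                 // ReduceAccDataType
//       ck::Tuple<float*>,     // ReducePtrsGlobal: one reduction output tensor
//       D0Op,                  // D0ElementwiseOperation
//       ck::Tuple<AddReduce>,  // ReduceOperations
//       ck::Tuple<InOp>,       // ReduceInElementwiseOperations
//       ck::Tuple<AccOp>,      // ReduceAccElementwiseOperations
//       ReduceMemOps,          // sequence exposing At(i), e.g. AtomicAdd per reduction
//       ck::Sequence<8, 16>,   // CReduceThreadClusterLengths (M x N)
//       8,                     // Lds2VGpr scalars per vector
//       1>;                    // Vgpr2Global scalars per vector
//
//   // in the kernel, after the WMMA mainloop has produced c_thread_buf:
//   auto epilogue = ck::EpilogueReduceCShuffle</* gemm config ..., */ Trait>{
//       p_reduces_grid, reduce_in_ops, reduce_acc_ops, MRaw, d0_op};
//   epilogue.template Run<ck::InMemoryDataOperationEnum::Set>(c_thread_buf, p_ds_grid, p_e_grid,
//       p_shared, ds_desc, e_desc, cde_op, block_m_id, block_n_id);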