include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp Source File

include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp Source File

Composable Kernel: include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp Source File
cshuffle_epilogue.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 namespace ck_tile {
11 
// Compile-time parameter bundle: every non-type template argument below is re-exported as a
// static constexpr member so other code can read it as Problem::kXxx.
// NOTE(review): the line that declares and names this struct (listing line 24) and listing
// lines 26-28 are absent from this extract. The cross-reference index at the end of the page
// lists AccDataType, ODataType and CLayout aliases (remove_cvref_t of the matching template
// arguments) at lines 26-28.
12 template <typename AccDataType_,
13  typename ODataType_,
14  typename CLayout_,
15  index_t kBlockSize_,
16  index_t kM_,
17  index_t kN_,
18  index_t kMWave_,
19  index_t kNWave_,
20  index_t kMPerXdl_,
21  index_t kNPerXdl_,
22  index_t kKPerXdl_,
23  bool isCTransposed_>
25 {
29  static constexpr index_t kBlockSize = kBlockSize_;
// kM_ / kN_ are exposed under the longer names kMPerBlock / kNPerBlock.
30  static constexpr index_t kMPerBlock = kM_;
31  static constexpr index_t kNPerBlock = kN_;
32  static constexpr index_t kMWave = kMWave_;
33  static constexpr index_t kNWave = kNWave_;
34  static constexpr index_t kMPerXdl = kMPerXdl_;
35  static constexpr index_t kNPerXdl = kNPerXdl_;
36  static constexpr index_t kKPerXdl = kKPerXdl_;
// NOTE(review): isCTransposed_ is a bool template parameter but is stored as index_t here,
// so the flag is exposed as 0/1 rather than false/true - confirm this is intentional.
37  static constexpr index_t isCTransposed = isCTransposed_;
38 };
39 
40 template <typename Problem_, typename Policy_ = void>
42 {
// Compile-time constants copied one-for-one from Problem so the rest of the struct can use
// them unqualified.
47  static constexpr index_t kBlockSize = Problem::kBlockSize;
48  static constexpr index_t kMPerBlock = Problem::kMPerBlock;
49  static constexpr index_t kNPerBlock = Problem::kNPerBlock;
50  static constexpr index_t kMWave = Problem::kMWave;
51  static constexpr index_t kNWave = Problem::kNWave;
52  static constexpr index_t kMPerXdl = Problem::kMPerXdl;
53  static constexpr index_t kNPerXdl = Problem::kNPerXdl;
54  static constexpr index_t kKPerXdl = Problem::kKPerXdl;
// NOTE(review): held as index_t although it is forwarded as the last WarpGemmMfmaDispatcher
// argument, which the page's index names TransposeC - confirm whether bool was intended.
55  static constexpr index_t isCTransposed = Problem::isCTransposed;
// Derived sizes: per-XDL tile size times the warp count along the same dimension.
// operator() divides each access start index by these same products to obtain mIter / nIter.
56  static constexpr index_t kMPerIteration = kMPerXdl * kMWave;
57  static constexpr index_t kNPerIteration = kNPerXdl * kNWave;
58 
// Fragment of the WG alias declaration: listing lines 59, 61 and 65 are missing from this
// extract. The cross-reference index at the end of the page gives the complete form as
// WarpGemmMfmaDispatcher<ODataType, ODataType, AccDataType, kMPerXdl, kNPerXdl, kKPerXdl,
// isCTransposed>.
60  ODataType,
62  kMPerXdl,
63  kNPerXdl,
64  kKPerXdl,
66 
// Warp-level C distribution and tensor types taken from WG. operator() declares its staging
// tensor as CWarpTensor and reads the per-warp Y lengths from CWarpDstr.
67  using CWarpDstr = typename WG::CWarpDstr;
68  using CWarpTensor = typename WG::CWarpTensor;
69 
// Get the vector store size for C tensor: the number of ODataType elements that fit in one
// MaxVectorStoreSize-byte store.
// ODataType here is this function's own template parameter and must be supplied explicitly;
// operator() calls it as GetVectorSizeC<ODataType>().
// Returns MaxVectorStoreSize / sizeof(ODataType) with integer division. The return type is
// deduced, and because sizeof yields std::size_t the result is a std::size_t, not an index_t.
80  template <typename ODataType>
81  CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
82  {
// Upper bound in bytes (it is divided by sizeof below). A type wider than 16 bytes would
// make the quotient 0.
83  constexpr index_t MaxVectorStoreSize = 16;
84  return MaxVectorStoreSize / sizeof(ODataType);
85  }
86 
// Builds the LDS block descriptor, selecting the branch by CLayout at compile time.
// NOTE(review): the signature (listing line 88) and both branch bodies (lines 93-95 and
// 100-102) are missing from this extract. The cross-reference index gives the signature as
// `static constexpr CK_TILE_HOST_DEVICE auto MakeLdsBlockDescriptor()`. What each branch
// returns cannot be seen here.
87  template <typename Problem>
89  {
90  // N is contiguous dimension
91  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
92  {
96  }
97  // M is contiguous dimension
98  else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
99  {
103  }
104  else
105  {
// Any layout other than RowMajor / ColumnMajor is rejected at compile time.
// NOTE(review): a literal `false` does not depend on any template parameter. Before CWG2518
// was adopted, compilers were allowed to reject this even in a discarded branch - confirm
// the supported toolchains accept it, or make the condition dependent.
106  static_assert(false, "Unsupported CLayout!");
107  }
108  }
109 
111  {
// The signature (listing line 110) is missing from this extract; the cross-reference index
// gives it as `static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()`.
// Returns a size in bytes: the size of a (kMWave * kMPerXdl) x (kNWave * kNPerXdl) array of
// ODataType.
112  return kMWave * kNWave * kMPerXdl * kNPerXdl * sizeof(ODataType);
113  }
114 
// Writes the accumulator tile to the output window by staging each piece through LDS.
//   out_dram_window - destination window; advanced in place between accesses.
//   o_acc_tile      - source accumulator tile; only read.
//   p_smem          - scratch memory, reinterpreted as ODataType* and viewed as LDS.
// NOTE(review): several listing lines are missing from this extract (117, 130, 134, 137-139,
// 143-145, 147). Line 117 presumably declares the out_memory_data_op template parameter used
// below, and lines 137-139 presumably define SFC - confirm against the real header.
115  template <typename ODramWindow,
116  typename OAccTile,
118  CK_TILE_DEVICE auto
119  operator()(ODramWindow& out_dram_window, const OAccTile& o_acc_tile, void* p_smem)
120  {
121 
// Split the linear warp id into (M, N) warp coordinates with N varying fastest:
// iMWarp is the quotient by kNWave and iNWarp is the remainder.
122  const index_t iMWarp = get_warp_id() / kNWave;
123  const index_t iNWarp = get_warp_id() - iMWarp * kNWave;
124 
// View p_smem as an LDS tensor of ODataType laid out by MakeLdsBlockDescriptor.
125  constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
126  auto o_lds_block = make_tensor_view<address_space_enum::lds>(
127  static_cast<ODataType*>(p_smem), lds_block_desc);
// Window used for writing: its origin is offset per warp by (kMPerXdl * iMWarp,
// kNPerXdl * iNWarp). The window-lengths argument (line 130) is missing from this extract.
128  auto in_lds_window =
129  make_tile_window(o_lds_block,
131  {number<kMPerXdl>{} * iMWarp, number<kNPerXdl>{} * iNWarp});
// Window used for reading back: origin {0, 0}. The window-lengths argument (line 134) is
// missing from this extract.
132  auto out_lds_window =
133  make_tile_window(o_lds_block,
135  {0, 0});
136 
// Number of loop iterations, taken from SFC (its definition is not visible here).
140  constexpr index_t num_access = SFC::get_num_of_access();
141 
// Distribution applied when reading the staged data back from LDS. Only part of the
// TileEncodingPattern alias survives in this extract; one of its arguments is the vector size
// from GetVectorSizeC<ODataType>().
142  using TileEncodingPattern =
146  GetVectorSizeC<ODataType>(),
148  constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
149 
// Y-dimension lengths of one warp tile and an all-zero index of the same rank, used to slice
// the accumulator below.
150  constexpr auto c_warp_y_lengths =
151  to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
152  constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
153 
// Staging tensor reused by every iteration.
154  CWarpTensor c_warp_in_tensor;
// iAccess is usable in constant expressions inside the lambda (see the constexpr uses below).
155  static_for<0, num_access, 1>{}([&](auto iAccess) {
156  constexpr auto idx_y_start = SFC::get_index(iAccess);
157 
// Convert the access start index into iteration indices by dividing by the extent covered
// per iteration (the same products stored in kMPerIteration / kNPerIteration).
158  constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (kMPerXdl * kMWave)>{};
159  constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (kNPerXdl * kNWave)>{};
160 
// Copy the (mIter, nIter) slice of the accumulator, one warp tile in extent, into the staging
// tensor.
161  c_warp_in_tensor.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
162  merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
163  merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
164 
// Convert to the output element type before it is written to LDS.
165  const auto c_warp_in_tensor_casted = cast_tile<ODataType>(c_warp_in_tensor);
166 
// The LDS write is bracketed by two syncs. Presumably the first keeps this write from
// overtaking the previous iteration's LDS read, and the second makes the write visible
// before the read below - confirm against block_sync_lds semantics.
167  block_sync_lds();
168  store_tile(in_lds_window, c_warp_in_tensor_casted);
169  block_sync_lds();
170 
// Read the staged data back from LDS using dram_tile_distribution.
171  const auto c_out_tensor =
172  load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
173 
// memory_operation_enum::set selects store_tile; any other value selects update_tile.
174  if constexpr(out_memory_data_op == memory_operation_enum::set)
175  {
176  store_tile(out_dram_window, c_out_tensor);
177  }
178  else
179  {
180  update_tile(out_dram_window, c_out_tensor);
181  }
// Advance the output window by the SFC forward step, except after the final access.
182  if constexpr(iAccess != num_access - 1)
183  {
184  constexpr auto step = SFC::get_forward_step(iAccess);
185  move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
186  }
187  });
188  }
189 };
190 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_descriptor.hpp:255
memory_operation_enum
Definition: arch.hpp:44
CK_TILE_DEVICE void block_sync_lds()
Definition: arch.hpp:80
int32_t index_t
Definition: integer.hpp:9
typename impl::WarpGemmMfmaDispatcher< AType, BType, CType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA >::Type WarpGemmMfmaDispatcher
Definition: warp_gemm_dispatcher.hpp:81
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:27
@ thread_raked
Thread raked pattern.
CK_TILE_DEVICE index_t get_warp_id()
Definition: arch.hpp:71
constexpr CK_TILE_HOST_DEVICE auto to_sequence(tuple< number< Is >... >)
Definition: sequence.hpp:1046
constexpr CK_TILE_HOST_DEVICE auto merge_sequences(Seqs...)
Definition: sequence.hpp:817
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:72
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:92
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
CK_TILE_DEVICE void update_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: update_tile.hpp:22
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:23
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition: sequence.hpp:1017
Definition: cshuffle_epilogue.hpp:42
static constexpr index_t kBlockSize
Definition: cshuffle_epilogue.hpp:47
static constexpr CK_TILE_HOST_DEVICE auto MakeLdsBlockDescriptor()
Definition: cshuffle_epilogue.hpp:88
typename WG::CWarpTensor CWarpTensor
Definition: cshuffle_epilogue.hpp:68
WarpGemmMfmaDispatcher< ODataType, ODataType, AccDataType, kMPerXdl, kNPerXdl, kKPerXdl, isCTransposed > WG
Definition: cshuffle_epilogue.hpp:65
static constexpr index_t kKPerXdl
Definition: cshuffle_epilogue.hpp:54
remove_cvref_t< Problem_ > Problem
Definition: cshuffle_epilogue.hpp:43
remove_cvref_t< typename Problem::ODataType > ODataType
Definition: cshuffle_epilogue.hpp:45
static constexpr index_t kNPerBlock
Definition: cshuffle_epilogue.hpp:49
static constexpr index_t kMPerIteration
Definition: cshuffle_epilogue.hpp:56
static constexpr index_t kMPerXdl
Definition: cshuffle_epilogue.hpp:52
remove_cvref_t< typename Problem::CLayout > CLayout
Definition: cshuffle_epilogue.hpp:46
static constexpr index_t isCTransposed
Definition: cshuffle_epilogue.hpp:55
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: cshuffle_epilogue.hpp:110
static constexpr index_t kNWave
Definition: cshuffle_epilogue.hpp:51
static constexpr CK_TILE_HOST_DEVICE auto GetVectorSizeC()
Get the vector store size for C tensor.
Definition: cshuffle_epilogue.hpp:81
remove_cvref_t< typename Problem::AccDataType > AccDataType
Definition: cshuffle_epilogue.hpp:44
static constexpr index_t kNPerIteration
Definition: cshuffle_epilogue.hpp:57
static constexpr index_t kMPerBlock
Definition: cshuffle_epilogue.hpp:48
typename WG::CWarpDstr CWarpDstr
Definition: cshuffle_epilogue.hpp:67
CK_TILE_DEVICE auto operator()(ODramWindow &out_dram_window, const OAccTile &o_acc_tile, void *p_smem)
Definition: cshuffle_epilogue.hpp:119
static constexpr index_t kMWave
Definition: cshuffle_epilogue.hpp:50
static constexpr index_t kNPerXdl
Definition: cshuffle_epilogue.hpp:53
Definition: cshuffle_epilogue.hpp:25
static constexpr index_t kNWave
Definition: cshuffle_epilogue.hpp:33
static constexpr index_t isCTransposed
Definition: cshuffle_epilogue.hpp:37
static constexpr index_t kMPerBlock
Definition: cshuffle_epilogue.hpp:30
static constexpr index_t kMPerXdl
Definition: cshuffle_epilogue.hpp:34
remove_cvref_t< CLayout_ > CLayout
Definition: cshuffle_epilogue.hpp:28
static constexpr index_t kNPerXdl
Definition: cshuffle_epilogue.hpp:35
remove_cvref_t< AccDataType_ > AccDataType
Definition: cshuffle_epilogue.hpp:26
static constexpr index_t kKPerXdl
Definition: cshuffle_epilogue.hpp:36
static constexpr index_t kBlockSize
Definition: cshuffle_epilogue.hpp:29
remove_cvref_t< ODataType_ > ODataType
Definition: cshuffle_epilogue.hpp:27
static constexpr index_t kMWave
Definition: cshuffle_epilogue.hpp:32
static constexpr index_t kNPerBlock
Definition: cshuffle_epilogue.hpp:31
Class creating 2D static tile distribution with different load/store patterns.
Definition: static_encoding_pattern.hpp:61
Definition: integral_constant.hpp:13
Definition: sequence.hpp:52
Definition: space_filling_curve.hpp:20
Definition: functional.hpp:43