/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp Source File
default_2d_epilogue.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 namespace ck_tile {
11 
12 // this epilogue just stores out an M*N matrix, row major
// NOTE(review): the layout actually written is whatever the destination DRAM window describes;
// DefaultGemm2DEpilogue::GetVectorSizeC also handles a ColumnMajor C layout.
13 
// Problem descriptor consumed by Default2DEpilogue.
//   AccDataType_     : accumulator data type (exposed as AccDataType).
//   ODataType_       : output element type; the epilogue cast_tile<ODataType>()s the tile
//                      before every store.
//   kPadM_ / kPadN_  : padding flags; Default2DEpilogue takes the store_tile_raw path when
//                      UseRawStore && (kPadM || kPadN).
//   UseRawStore_     : opt into the *_raw store variants (default: true).
14 template <typename AccDataType_,
15  typename ODataType_,
16  bool kPadM_,
17  bool kPadN_,
18  bool UseRawStore_ = true>
20 {
23  static constexpr bool kPadM = kPadM_;
24  static constexpr bool kPadN = kPadN_;
25  static constexpr bool UseRawStore = UseRawStore_;
// The plain 2D problem has no auxiliary D tensors; Default2DEpilogue only fuses D tiles
// when Problem::NumDTensor >= 1.
26  static constexpr index_t NumDTensor = 0;
27 };
28 
// GEMM flavour of the epilogue problem. On top of Default2DEpilogueProblem (AccDataType,
// ODataType, padding flags, UseRawStore) it carries:
//   - the A/B/D data types and the D/C layouts,
//   - CDElementwise: the operator that combines the accumulator with the D tiles,
//   - the block tile extents kM_ x kN_,
//   - the warp-gemm (XDL) tile sizes and whether the warp-gemm C distribution is transposed
//     (forwarded as the TransposeC argument of WarpGemmDispatcher by DefaultGemm2DEpilogue).
29 template <typename AsDataType_,
30  typename BsDataType_,
31  typename DsDataType_,
32  typename AccDataType_,
33  typename ODataType_,
34  typename DsLayout_,
35  typename CLayout_,
36  typename CDElementwise_,
37  index_t kM_,
38  index_t kN_,
39  bool kPadM_,
40  bool kPadN_,
41  index_t kMPerXdl_,
42  index_t kNPerXdl_,
43  index_t kKPerXdl_,
44  bool isCTransposed_,
45  bool UseRawStore_ = true>
47  : public Default2DEpilogueProblem<AccDataType_, ODataType_, kPadM_, kPadN_, UseRawStore_>
48 {
// Block tile extents; Default2DEpilogue::operator() uses them to size the D-tensor load window.
55  static constexpr index_t kMPerBlock = kM_;
56  static constexpr index_t kNPerBlock = kN_;
// Warp-gemm (XDL) tile sizes used to select the warp gemm in DefaultGemm2DEpilogue.
57  static constexpr index_t kMPerXdl = kMPerXdl_;
58  static constexpr index_t kNPerXdl = kNPerXdl_;
59  static constexpr index_t kKPerXdl = kKPerXdl_;
// NOTE(review): declared index_t although the template parameter is bool; every visible use
// treats it as a boolean (if constexpr / TransposeC), so `bool` would state the intent better.
60  static constexpr index_t isCTransposed = isCTransposed_;
61 
// Hides the base-class NumDTensor (0) with the actual number of D tensors.
62  static constexpr index_t NumDTensor = DsDataType::size();
63 
// Every D tensor must have a matching layout entry.
64  static_assert(NumDTensor == DsLayout::size(),
65  "The size of DsDataType and DsLayout should be the same");
66 };
67 
// Default2DEpilogue: writes an accumulator tile, cast to ODataType, into a DRAM tile window.
// When D windows are supplied and Problem::NumDTensor >= 1, the D tiles are first loaded with
// the accumulator's tile distribution and combined with it through Problem::CDElementwise.
68 template <typename Problem_, typename Policy_ = void>
70 {
74  static constexpr bool kPadM = Problem::kPadM;
75  static constexpr bool kPadN = Problem::kPadN;
76  static constexpr bool UseRawStore = Problem::UseRawStore;
77 
// This epilogue reports a shared-memory requirement of 0.
78  CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
79 
80  // TODO: this function assumes the store-out vector size is the same as the OAccTile last dimension size
81  // how do we fix this?
// Stores o_acc_tile into o_dram_window_tmp.
//   o_dram_window_tmp : destination window; passed by non-const reference to the store calls.
//   o_acc_tile        : accumulator tile; its tile distribution is also used to load D tiles.
//   ds_dram_windows   : overloaded meaning, resolved at compile time:
//                       - std::nullptr_t: no D tensors, o_acc_tile is stored directly;
//                       - D tile windows indexable with number<i>: fused through
//                         Problem::CDElementwise when Problem::NumDTensor >= 1;
//                       - a value convertible to the type of get_partition_index(...):
//                         treated as a partition index and forwarded to the store call.
//   The trailing unnamed void* parameter is ignored.
82  template <typename ODramWindowTmp, typename OAccTile, typename DsDramWindows>
83  CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
84  const OAccTile& o_acc_tile,
85  const DsDramWindows& ds_dram_windows,
86  void* = nullptr) const
87  {
// True when the third argument is really a partition index rather than D windows.
88  constexpr bool is_partition_index =
89  std::is_convertible_v<decltype(ds_dram_windows),
90  decltype(get_partition_index(
91  o_acc_tile.get_tile_distribution()))>;
92 
// Casts a tile to ODataType and writes it: the *_raw variants when
// UseRawStore && (kPadM || kPadN), the regular variants otherwise. Because of the
// `if constexpr(true)` placeholders below, only the store_tile* ("set") branches are ever
// taken; the update_tile* branches are currently unreachable.
93  const auto storeOrUpdateTile = [&](const auto& o_tile) {
94  // TODO: this is ugly
95  if constexpr(UseRawStore && (kPadM || kPadN))
96  {
97  // FIXME?
98  // if constexpr(decltype(o_dram_window_tmp.get_bottom_tensor_view())::DstInMemOp ==
99  // memory_operation_enum::set)
100  if constexpr(true)
101  {
102  if constexpr(is_partition_index)
103  {
104  store_tile_raw(o_dram_window_tmp,
105  cast_tile<ODataType>(o_tile),
106  /*partition_index=*/ds_dram_windows);
107  }
108  else
109  {
110  store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
111  }
112  }
113  else
114  {
// NOTE(review): unlike the non-raw path, this (unreachable) branch does not forward
// the partition index when is_partition_index is true.
115  update_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
116  }
118  }
119  else
120  {
121  // FIXME?
122  // if constexpr(decltype(o_dram_window_tmp.get_bottom_tensor_view())::DstInMemOp ==
123  // memory_operation_enum::set)
124  if constexpr(true)
125  {
126  if constexpr(is_partition_index)
127  {
128  store_tile(o_dram_window_tmp,
129  cast_tile<ODataType>(o_tile),
130  /*partition_index=*/ds_dram_windows);
131  }
132  else
133  {
134  store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
135  }
136  }
137  else
138  {
139  if constexpr(is_partition_index)
140  {
141  update_tile(o_dram_window_tmp,
142  cast_tile<ODataType>(o_tile),
143  /*partition_index=*/ds_dram_windows);
144  }
145  else
146  {
147  update_tile(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
148  }
149  }
150  }
151  };
152 
// Fuse D tensors only when real D windows (neither nullptr nor a partition index) are
// supplied and the problem declares at least one D tensor.
153  if constexpr(!std::is_same_v<DsDramWindows, std::nullptr_t> && !is_partition_index &&
154  Problem::NumDTensor >= 1)
155  {
// Output buffer for the fused result: the tile type load_tile would return for a
// kMPerBlock x kNPerBlock window over D0's tensor view with the accumulator distribution.
// NOTE(review): its element type therefore presumably follows D0 rather than
// AccDataType/ODataType; conversion to ODataType only happens in storeOrUpdateTile.
156  using elementwise_result_t = decltype(load_tile(
157  make_tile_window(ds_dram_windows[number<0>{}].get_bottom_tensor_view(),
158  make_tuple(Problem::kMPerBlock, Problem::kNPerBlock),
159  ds_dram_windows[number<0>{}].get_window_origin(),
160  o_acc_tile.get_tile_distribution())));
161 
162  elementwise_result_t elementwise_result;
163 
// Load every D tile using the accumulator's tile distribution.
164  const auto d_tensor_tuple = generate_tuple(
165  [&](auto idx) {
166  const auto d_tile_window =
167  make_tile_window(ds_dram_windows[idx], o_acc_tile.get_tile_distribution());
168  return load_tile(d_tile_window);
169  },
171 
// Operand order handed to CDElementwise: (elementwise_result, o_acc_tile, d0, d1, ...).
// elementwise_result is what gets stored afterwards, so it is presumably the output slot.
172  const auto c_d_tuple = concat_tuple_of_reference(
173  tie(elementwise_result, o_acc_tile),
174  generate_tie([&](auto idx) -> const auto& { return d_tensor_tuple[idx]; },
176 
177  tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_d_tuple);
178 
179  storeOrUpdateTile(elementwise_result);
180  }
181  else
182  {
183  storeOrUpdateTile(o_acc_tile);
184  }
185  }
186 };
187 
// GEMM epilogue: inherits Default2DEpilogue::operator() and adds GetVectorSizeC/GetVectorSizeD.
// These derive the store vector width from the C distribution (WG::CWarpDstr) of the warp gemm
// selected by WarpGemmDispatcher from the problem's XDL sizes and transpose flag.
188 template <typename Problem_, typename Policy_ = void>
189 struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
190 {
198 
202 
206 
209  // Used for weight-only quantization kernels: B is dequantized to the same data type as A,
// so when BDataType is pk_int4_t the warp gemm is instantiated with ADataType for B.
210  using BTypeToUse =
211  std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
212 
// Warp-gemm (XDL) configuration copied from the problem; these select WG.
217  static constexpr index_t kMPerXdl = Problem::kMPerXdl;
218  static constexpr index_t kNPerXdl = Problem::kNPerXdl;
219  static constexpr index_t kKPerXdl = Problem::kKPerXdl;
// NOTE(review): index_t although it holds a bool flag; every visible use treats it as a boolean.
220  static constexpr index_t isCTransposed = Problem::isCTransposed;
221 
223  BTypeToUse,
224  AccDataType,
225  kMPerXdl,
226  kNPerXdl,
227  kKPerXdl,
228  isCTransposed>;
229 
// C-tile distribution of a single warp gemm; GetVectorSizeC reads its Y lengths.
230  using CWarpDstr = typename WG::CWarpDstr;
231 
232  CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
233  {
234  // N is contiguous dimension
235  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
236  {
237  if constexpr(isCTransposed)
238  {
239  // In this case each thread has multiple consecutive elements in
240  // N dimension, however consecutive threads' elements have stride.
241  constexpr index_t NDimY = CWarpDstr::NDimY;
242  constexpr auto c_warp_y_lengths =
243  CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
244  static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
245  c_warp_y_lengths.get(number<NDimY - 1>{}));
246  return c_warp_y_lengths.get(number<NDimY - 1>{});
247  }
248  else
249  {
250  // In this case each thread has just a single item in Ndim
251  return (WG::WarpGemmAttribute::Impl::kCNLane *
252  WG::WarpGemmAttribute::Impl::kBNBlock) /
253  WG::kN;
254  }
255  }
256  // M is contiguous dimension
257  else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
258  {
259  if constexpr(isCTransposed)
260  {
261  // In this case each thread has just a single item in Mdim
262  return (WG::WarpGemmAttribute::Impl::kCNLane *
263  WG::WarpGemmAttribute::Impl::kAMBlock) /
264  WG::kN;
265  }
266  else
267  {
268  // In this case each thread has multiple consecutive elements in
269  // M dimension, however consecutive threads' elements have stride.
270  constexpr index_t NDimY = CWarpDstr::NDimY;
271  constexpr auto c_warp_y_lengths =
272  CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
273  static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
274  c_warp_y_lengths.get(number<NDimY - 1>{}));
275  return c_warp_y_lengths.get(number<NDimY - 1>{});
276  }
277  }
278  else
279  {
280  static_assert(false, "Unsupported CLayout!");
281  }
282  }
283 
284  template <index_t I>
285  CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeD([[maybe_unused]] number<I> index)
286  {
287  return GetVectorSizeC();
288  }
289 };
290 
291 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
Definition: cluster_descriptor.hpp:13
constexpr tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:376
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto generate_tie(F &&f, number< N >)
Definition: tuple.hpp:435
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
typename impl::warp_gemm_dispatcher::Dispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition: warp_gemm_dispatcher.hpp:176
CK_TILE_DEVICE auto tile_elementwise_inout_unpack(const InElementFunc &in_element_func, const Tuple &t, std::index_sequence< I... >)
Template function that "unpacks" a tuple and applies an element-wise operation.
Definition: tile_elementwise.hpp:71
CK_TILE_DEVICE void buffer_store_fence(index_t cnt=0)
Definition: amd_buffer_addressing.hpp:1064
CK_TILE_DEVICE void store_tile_raw(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:72
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
Definition: tile_distribution.hpp:21
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition: type_traits.hpp:67
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE auto concat_tuple_of_reference(const tuple< X &... > &tx, const tuple< Y &... > &ty)
Definition: tuple.hpp:443
CK_TILE_DEVICE void update_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: update_tile.hpp:22
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:24
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:36
CK_TILE_DEVICE void update_tile_raw(tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Definition: update_tile.hpp:68
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
struct Default2DEpilogue
Definition: default_2d_epilogue.hpp:70
remove_cvref_t< typename Problem::ODataType > ODataType
Definition: default_2d_epilogue.hpp:73
CK_TILE_DEVICE auto operator()(ODramWindowTmp &o_dram_window_tmp, const OAccTile &o_acc_tile, const DsDramWindows &ds_dram_windows, void *=nullptr) const
Definition: default_2d_epilogue.hpp:83
static constexpr bool kPadN
Definition: default_2d_epilogue.hpp:75
remove_cvref_t< Problem_ > Problem
Definition: default_2d_epilogue.hpp:71
static constexpr bool kPadM
Definition: default_2d_epilogue.hpp:74
remove_cvref_t< typename Problem::AccDataType > AccDataType
Definition: default_2d_epilogue.hpp:72
static constexpr bool UseRawStore
Definition: default_2d_epilogue.hpp:76
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: default_2d_epilogue.hpp:78
struct Default2DEpilogueProblem
Definition: default_2d_epilogue.hpp:20
remove_cvref_t< ODataType_ > ODataType
Definition: default_2d_epilogue.hpp:22
static constexpr index_t NumDTensor
Definition: default_2d_epilogue.hpp:26
static constexpr bool kPadM
Definition: default_2d_epilogue.hpp:23
remove_cvref_t< AccDataType_ > AccDataType
Definition: default_2d_epilogue.hpp:21
static constexpr bool kPadN
Definition: default_2d_epilogue.hpp:24
static constexpr bool UseRawStore
Definition: default_2d_epilogue.hpp:25
struct DefaultGemm2DEpilogue
Definition: default_2d_epilogue.hpp:190
remove_cvref_t< typename Problem::BsDataType > BsDataType
Definition: default_2d_epilogue.hpp:193
remove_cvref_t< typename Problem::AccDataType > AccDataType
Definition: default_2d_epilogue.hpp:194
static constexpr index_t kMPerXdl
Definition: default_2d_epilogue.hpp:217
static constexpr bool ADataTypeIsTuple
Definition: default_2d_epilogue.hpp:196
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsDataTypeTuple > > BDataType
Definition: default_2d_epilogue.hpp:208
static constexpr index_t kNPerXdl
Definition: default_2d_epilogue.hpp:218
std::conditional_t< BDataTypeIsTuple, remove_cvref_t< BsDataType >, remove_cvref_t< tuple< BsDataType > >> BsDataTypeTuple
Definition: default_2d_epilogue.hpp:205
static constexpr index_t kKPerXdl
Definition: default_2d_epilogue.hpp:219
remove_cvref_t< typename Problem::CDElementwise > CDElementwise
Definition: default_2d_epilogue.hpp:215
static constexpr CK_TILE_HOST_DEVICE auto GetVectorSizeC()
Definition: default_2d_epilogue.hpp:232
static constexpr index_t isCTransposed
Definition: default_2d_epilogue.hpp:220
remove_cvref_t< typename Problem::CLayout > CLayout
Definition: default_2d_epilogue.hpp:216
remove_cvref_t< typename Problem::ODataType > ODataType
Definition: default_2d_epilogue.hpp:195
remove_cvref_t< typename Problem::DsDataType > DsDataType
Definition: default_2d_epilogue.hpp:213
remove_cvref_t< typename Problem::AsDataType > AsDataType
Definition: default_2d_epilogue.hpp:192
std::conditional_t< std::is_same_v< BDataType, pk_int4_t >, ADataType, BDataType > BTypeToUse
Definition: default_2d_epilogue.hpp:211
static constexpr CK_TILE_HOST_DEVICE auto GetVectorSizeD([[maybe_unused]] number< I > index)
Definition: default_2d_epilogue.hpp:285
WarpGemmDispatcher< ADataType, BTypeToUse, AccDataType, kMPerXdl, kNPerXdl, kKPerXdl, isCTransposed > WG
Definition: default_2d_epilogue.hpp:228
std::conditional_t< ADataTypeIsTuple, remove_cvref_t< AsDataType >, remove_cvref_t< tuple< AsDataType > >> AsDataTypeTuple
Definition: default_2d_epilogue.hpp:201
static constexpr bool BDataTypeIsTuple
Definition: default_2d_epilogue.hpp:197
remove_cvref_t< std::tuple_element_t< number< 0 >{}, AsDataTypeTuple > > ADataType
Definition: default_2d_epilogue.hpp:207
typename WG::CWarpDstr CWarpDstr
Definition: default_2d_epilogue.hpp:230
remove_cvref_t< typename Problem::DsLayout > DsLayout
Definition: default_2d_epilogue.hpp:214
remove_cvref_t< Problem_ > Problem
Definition: default_2d_epilogue.hpp:191
Definition: default_2d_epilogue.hpp:48
remove_cvref_t< AsDataType_ > AsDataType
Definition: default_2d_epilogue.hpp:49
static constexpr index_t NumDTensor
Definition: default_2d_epilogue.hpp:62
remove_cvref_t< BsDataType_ > BsDataType
Definition: default_2d_epilogue.hpp:50
static constexpr index_t kMPerXdl
Definition: default_2d_epilogue.hpp:57
static constexpr index_t kMPerBlock
Definition: default_2d_epilogue.hpp:55
static constexpr index_t kKPerXdl
Definition: default_2d_epilogue.hpp:59
remove_cvref_t< DsLayout_ > DsLayout
Definition: default_2d_epilogue.hpp:54
remove_cvref_t< CLayout_ > CLayout
Definition: default_2d_epilogue.hpp:51
static constexpr index_t isCTransposed
Definition: default_2d_epilogue.hpp:60
remove_cvref_t< DsDataType_ > DsDataType
Definition: default_2d_epilogue.hpp:52
static constexpr index_t kNPerBlock
Definition: default_2d_epilogue.hpp:56
remove_cvref_t< CDElementwise_ > CDElementwise
Definition: default_2d_epilogue.hpp:53
static constexpr index_t kNPerXdl
Definition: default_2d_epilogue.hpp:58
Definition: integral_constant.hpp:13