/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/epilogue/chainer/common_epilogue_ops.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/epilogue/chainer/common_epilogue_ops.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/epilogue/chainer/common_epilogue_ops.hpp Source File
common_epilogue_ops.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 
20 namespace ck_tile {
21 
27 {
28  template <typename OutWindow,
29  typename AccTile,
30  typename AuxWindows,
31  typename IAccess,
32  typename Context,
33  typename ScaleA,
34  typename ScaleB>
35  CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow& out_window,
36  [[maybe_unused]] const AccTile& acc_tile,
37  [[maybe_unused]] const AuxWindows& aux_windows,
38  [[maybe_unused]] void* p_smem,
39  [[maybe_unused]] IAccess iAccess,
40  Context& context,
41  const ScaleA& scale_a,
42  const ScaleB& scale_b)
43  {
44  tile_elementwise_inout([&](auto& elem) { elem = elem * scale_a * scale_b; },
45  context.working_tile);
46  }
47 };
48 
56 template <typename DataType>
58 {
59  template <typename OutWindow,
60  typename AccTile,
61  typename AuxWindows,
62  typename IAccess,
63  typename Context>
64  CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow& out_window,
65  [[maybe_unused]] const AccTile& acc_tile,
66  [[maybe_unused]] const AuxWindows& aux_windows,
67  [[maybe_unused]] void* p_smem,
68  [[maybe_unused]] IAccess iAccess,
69  Context& context)
70  {
71  const auto casted_tile = cast_tile<DataType>(context.working_tile);
72  store_tile(context.lds_write_window, casted_tile);
73  }
74 };
75 
83 template <typename TileEncodingPattern>
85 {
86  template <typename OutWindow,
87  typename AccTile,
88  typename AuxWindows,
89  typename IAccess,
90  typename Context>
91  CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow& out_window,
92  [[maybe_unused]] const AccTile& acc_tile,
93  [[maybe_unused]] const AuxWindows& aux_windows,
94  [[maybe_unused]] void* p_smem,
95  [[maybe_unused]] IAccess iAccess,
96  Context& context)
97  {
98  constexpr auto tile_distribution = TileEncodingPattern::make_2d_static_tile_distribution();
100  context.out_tile = load_tile(make_tile_window(context.lds_read_window, tile_distribution));
101  }
102 };
103 
112 template <typename Elementwise, index_t NumAux>
114 {
115  template <typename OutWindow,
116  typename AccTile,
117  typename AuxWindows,
118  typename IAccess,
119  typename Context>
120  CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow& out_window,
121  [[maybe_unused]] const AccTile& acc_tile,
122  [[maybe_unused]] const AuxWindows& aux_windows,
123  [[maybe_unused]] void* p_smem,
124  [[maybe_unused]] IAccess iAccess,
125  Context& context)
126  {
127  const auto aux_tiles = generate_tuple(
128  [&](auto idx) { return load_tile(context.aux_windows[idx]); }, number<NumAux>{});
129 
130  const auto tiles = concat_tuple_of_reference(
131  tie(context.out_tile, context.out_tile),
132  generate_tie([&](auto idx) -> const auto& { return aux_tiles[idx]; },
133  number<NumAux>{}));
134 
135  tile_elementwise_inout_unpack(Elementwise{}, tiles);
136  }
137 };
138 
145 template <memory_operation_enum MemOp>
146 struct StoreOp
147 {
148  template <typename OutWindow,
149  typename AccTile,
150  typename AuxWindows,
151  typename IAccess,
152  typename Context>
153  CK_TILE_DEVICE void operator()(OutWindow& out_window,
154  [[maybe_unused]] const AccTile& acc_tile,
155  [[maybe_unused]] const AuxWindows& aux_windows,
156  [[maybe_unused]] void* p_smem,
157  [[maybe_unused]] IAccess iAccess,
158  Context& context)
159  {
160  if constexpr(MemOp == memory_operation_enum::set)
161  {
162  store_tile(out_window, context.out_tile);
163  }
164  else
165  {
166  update_tile(out_window, context.out_tile);
167  }
168  }
169 };
170 
178 template <typename SFC, index_t NumAux>
180 {
181  template <typename OutWindow,
182  typename AccTile,
183  typename AuxWindows,
184  typename IAccess,
185  typename Context>
186  CK_TILE_DEVICE void operator()(OutWindow& out_window,
187  [[maybe_unused]] const AccTile& acc_tile,
188  [[maybe_unused]] const AuxWindows& aux_windows,
189  [[maybe_unused]] void* p_smem,
190  IAccess iAccess,
191  Context& context)
192  {
193  constexpr index_t num_access = SFC::get_num_of_access();
194  if constexpr(iAccess != num_access - 1)
195  {
196  constexpr auto step = SFC::get_forward_step(iAccess);
197 
198  move_tile_window(out_window, {step.at(number<0>{}), step.at(number<1>{})});
199 
200  static_for<0, NumAux, 1>{}([&](auto idx) {
201  move_tile_window(context.aux_windows[idx],
202  {step.at(number<0>{}), step.at(number<1>{})});
203  });
204  }
205  }
206 };
207 
208 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:45
Definition: cluster_descriptor.hpp:13
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition: tile_elementwise.hpp:23
constexpr tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:376
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto generate_tie(F &&f, number< N >)
Definition: tuple.hpp:435
CK_TILE_DEVICE auto tile_elementwise_inout_unpack(const InElementFunc &in_element_func, const Tuple &t, std::index_sequence< I... >)
Template function that "unpacks" a tuple and applies an element-wise operation.
Definition: tile_elementwise.hpp:71
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:95
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto concat_tuple_of_reference(const tuple< X &... > &tx, const tuple< Y &... > &ty)
Definition: tuple.hpp:443
CK_TILE_DEVICE void update_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: update_tile.hpp:22
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:24
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:36
__device__ void block_sync_lds()
Definition: synchronization.hpp:16
Cast working tile and store to LDS.
Definition: common_epilogue_ops.hpp:58
CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow &out_window, [[maybe_unused]] const AccTile &acc_tile, [[maybe_unused]] const AuxWindows &aux_windows, [[maybe_unused]] void *p_smem, [[maybe_unused]] IAccess iAccess, Context &context)
Definition: common_epilogue_ops.hpp:64
Apply elementwise operation with auxiliary tensors.
Definition: common_epilogue_ops.hpp:114
CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow &out_window, [[maybe_unused]] const AccTile &acc_tile, [[maybe_unused]] const AuxWindows &aux_windows, [[maybe_unused]] void *p_smem, [[maybe_unused]] IAccess iAccess, Context &context)
Definition: common_epilogue_ops.hpp:120
Load output tile from LDS with synchronization.
Definition: common_epilogue_ops.hpp:85
CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow &out_window, [[maybe_unused]] const AccTile &acc_tile, [[maybe_unused]] const AuxWindows &aux_windows, [[maybe_unused]] void *p_smem, [[maybe_unused]] IAccess iAccess, Context &context)
Definition: common_epilogue_ops.hpp:91
Move output and auxiliary windows by step from space-filling curve.
Definition: common_epilogue_ops.hpp:180
CK_TILE_DEVICE void operator()(OutWindow &out_window, [[maybe_unused]] const AccTile &acc_tile, [[maybe_unused]] const AuxWindows &aux_windows, [[maybe_unused]] void *p_smem, IAccess iAccess, Context &context)
Definition: common_epilogue_ops.hpp:186
Scale working tile by scalar values.
Definition: common_epilogue_ops.hpp:27
CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow &out_window, [[maybe_unused]] const AccTile &acc_tile, [[maybe_unused]] const AuxWindows &aux_windows, [[maybe_unused]] void *p_smem, [[maybe_unused]] IAccess iAccess, Context &context, const ScaleA &scale_a, const ScaleB &scale_b)
Definition: common_epilogue_ops.hpp:35
Store output tile to global memory.
Definition: common_epilogue_ops.hpp:147
CK_TILE_DEVICE void operator()(OutWindow &out_window, [[maybe_unused]] const AccTile &acc_tile, [[maybe_unused]] const AuxWindows &aux_windows, [[maybe_unused]] void *p_smem, [[maybe_unused]] IAccess iAccess, Context &context)
Definition: common_epilogue_ops.hpp:153
Definition: integral_constant.hpp:13
Definition: functional.hpp:43
Definition: tile_distribution.hpp:70