/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp Source File
elementwise_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
#include <cstdint>
#include <limits>
10 namespace ck_tile {
11 
12 template <typename Problem_, typename Policy_>
14 {
17 
22 
23  static constexpr index_t kBlockSize = Problem::BlockShape::kBlockSize;
24  CK_TILE_HOST static constexpr auto BlockSize()
25  {
26  return is_wave32() ? kBlockSize / 2 : kBlockSize;
27  }
28 
// Device-side body of the element-wise kernel.
// Each block works on the tile that starts at iM = get_block_id() * S::kBlockM
// of a view built with make_merge_transform(lens) (presumably all dimensions
// merged into one; the surrounding transform call is partly omitted below).
// For that tile it loads one register tile per input tensor, converts every
// input element with type_convert<ComputeDataType>, applies
// ElementWiseOperation, and stores the resulting tile into p_y.
//   lens           - per-dimension extents, used for every input and the output
//   input_strides  - strides used for every input tensor (one set shared by all)
//   output_strides - strides of the output tensor
//   input_tensors  - tuple of input data (presumably global pointers, one per
//                    XDataType entry; confirm against callers)
//   p_y            - output buffer of YDataType
// NOTE(review): the template parameter pack `XDataType` reuses the name of the
// class-level `XDataType` alias and hides it inside this function.
// NOTE(review): this listing omits source lines 51, 53-56, 60, 89, 91-94 and
// 97, so the pad_tensor_view / make_tile_window calls below are incomplete here.
29  template <typename... XDataType, typename Dims>
30  CK_TILE_DEVICE void operator()(const Dims lens,
31  const Dims input_strides,
32  const Dims output_strides,
33  const tuple<XDataType...>& input_tensors,
34  YDataType* p_y) const
35  {
36  using S = typename Problem::BlockShape;
37 
38  // Setup block-level coordinates and transforms
39  const index_t iM = get_block_id() * S::kBlockM;
40  const auto merge_transform = make_merge_transform(lens);
41 
42  // Load all input tiles into registers.
43  // The lambda structure here is intended to minimize the lifetime
44  // of intermediate objects (views, windows) used for loading.
45  const auto x_tiles = ck_tile::generate_tuple(
46  [&](auto i) {
// Every input uses the same lens and input_strides; only the data pointer differs.
47  const auto tensor_view = make_naive_tensor_view<address_space_enum::global>(
48  input_tensors.get(i), lens, input_strides, number<S::kVectorM>{}, number<1>{});
49 
50  const auto transformed_tensor = pad_tensor_view(
52  ck_tile::make_tuple(merge_transform),
57 
58  const auto x_window =
59  make_tile_window(transformed_tensor,
61  {iM},
62  Policy::template MakeXBlockTileDistribution<Problem>());
63 
64  return load_tile(x_window);
65  },
66  number<sizeof...(XDataType)>{});
67 
68  // Setup output tile in registers.
// y_tile reuses the distribution of input tile 0; all input tiles are loaded
// with the same MakeXBlockTileDistribution<Problem>() policy.
69  const auto& x_tile0 = x_tiles.get(number<0>{});
70  auto y_tile = make_static_distributed_tensor<YDataType>(x_tile0.get_tile_distribution());
71 
72  // Perform element-wise computation.
// Only span 0 is swept, i.e. the distributed tile is handled as one-dimensional.
73  const auto spans = x_tile0.get_distributed_spans();
74  sweep_tile_span(spans[number<0>{}], [&](auto idx) {
75  const auto tile_idx = make_tuple(idx);
76  apply(
77  [&](auto&&... tiles) {
// The operation writes directly into y_tile(tile_idx); each input value is
// first converted to ComputeDataType.
78  ElementWiseOperation{}(y_tile(tile_idx),
79  type_convert<ComputeDataType>(tiles[tile_idx])...);
80  },
81  x_tiles);
82  });
83 
84  // Setup output window and store the result tile.
// NOTE(review): unlike the input views, no number<1>{} argument follows
// number<S::kVectorM>{} here; confirm whether that asymmetry is intended.
85  const auto y_m_n = make_naive_tensor_view<address_space_enum::global>(
86  p_y, lens, output_strides, number<S::kVectorM>{});
87 
88  const auto transformed_y_m_n = pad_tensor_view(
90  ck_tile::make_tuple(merge_transform),
95 
96  auto y_window = make_tile_window(transformed_y_m_n,
98  {iM},
99  y_tile.get_tile_distribution());
100 
// NOTE(review): y_tile was created with element type YDataType, so
// cast_tile<YDataType> appears to be a same-type conversion.
101  store_tile(y_window, cast_tile<YDataType>(y_tile));
102  }
103 
104  template <typename... Ints>
106  {
107  int total_elements = 1;
108  const auto kVectorM = Problem_::BlockShape::kVectorM;
109 
110  apply([&](auto&&... args) { ((total_elements *= args), ...); }, input_sizes);
111 
112  if((total_elements % kVectorM) != 0)
113  {
114  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
115  {
116  CK_TILE_ERROR("Conditions not met: total number of input elements (",
117  total_elements,
118  ") should be multiple of the vectorization size (",
119  kVectorM,
120  ")");
121  }
122  return false;
123  }
124 
125  return true;
126  }
127 };
128 
129 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
Definition: cluster_descriptor.hpp:13
constexpr decltype(auto) apply(F &&f, Tuple &&t)
Definition: tuple.hpp:526
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:156
void CK_TILE_ERROR(Args &&... args) noexcept
Definition: env.hpp:12
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_view.hpp:511
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1615
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition: sweep_tile.hpp:20
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:23
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:22
typename __make_integer_seq< impl::__integer_sequence, index_t, N >::seq_type make_index_sequence
Definition: sequence.hpp:231
Definition: elementwise_kernel.hpp:14
ck_tile::remove_cvref_t< Problem_ > Problem
Definition: elementwise_kernel.hpp:15
ck_tile::remove_cvref_t< typename Problem::XDataType > XDataType
Definition: elementwise_kernel.hpp:18
ck_tile::remove_cvref_t< Policy_ > Policy
Definition: elementwise_kernel.hpp:16
ck_tile::remove_cvref_t< typename Problem::ElementWiseOperation > ElementWiseOperation
Definition: elementwise_kernel.hpp:21
static constexpr CK_TILE_HOST auto BlockSize()
Definition: elementwise_kernel.hpp:24
CK_TILE_DEVICE void operator()(const Dims lens, const Dims input_strides, const Dims output_strides, const tuple< XDataType... > &input_tensors, YDataType *p_y) const
Definition: elementwise_kernel.hpp:30
static constexpr index_t kBlockSize
Definition: elementwise_kernel.hpp:23
static CK_TILE_HOST bool IsSupportedArgument(const ck_tile::tuple< Ints... > &input_sizes)
Definition: elementwise_kernel.hpp:105
ck_tile::remove_cvref_t< typename Problem::YDataType > YDataType
Definition: elementwise_kernel.hpp:20
ck_tile::remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition: elementwise_kernel.hpp:19
Definition: integral_constant.hpp:13
Definition: sequence.hpp:49
Definition: tensor_view.hpp:41
Definition: tuple.hpp:192
#define CK_TILE_ENV(name)
Definition: env.hpp:145