/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/block/block_rotary_embedding.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/block/block_rotary_embedding.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/block/block_rotary_embedding.hpp Source File
block_rotary_embedding.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <string>
7 
8 namespace ck_tile {
9 
10 // This class is used for codegen pattern matching
// Selects which dimensions are paired together when a rotary embedding is applied.
12 {
// neither rotation branch of apply() matches this value, so the tile is left unchanged
13  NONE = 0,
14  INTERLEAVED = 1, // combine dimensions 0 & 1, 2 & 3, etc
15  HALF_ROTATED = 2, // combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1, etc
16 };
17 
18 template <RotaryEmbeddingEnum>
20 
21 template <>
23 {
24  static constexpr const char* name = "";
25 };
26 template <>
28 {
29  static constexpr const char* name = "inter";
30 };
31 template <>
33 {
34  static constexpr const char* name = "half";
35 };
36 
37 template <RotaryEmbeddingEnum RotaryEnum, typename ComputeDataType = float>
39 {
// Applies a rotary embedding to `tile` in place. The variant is chosen at compile time
// by RotaryEnum:
//   INTERLEAVED  - thread-buffer elements are rotated as adjacent pairs (idx, idx + 1);
//                  each pair uses cos/sin element idx / 2.
//   HALF_ROTATED - every element is combined with the same-index element of a tile read
//                  from `other_window` after that window is shifted by +/- rotary_dim / 2.
//   anything else (e.g. NONE) - no branch is instantiated and `tile` is not modified.
// Arithmetic is done in ComputeDataType and converted back to the tile's DataType.
//
// Parameters:
//   tile              - distributed tensor whose thread_buf_ is updated in place.
//   other_window      - only used by HALF_ROTATED, to fetch the partner elements.
//   rotary_cos_window - window the cos factors are loaded from.
//   rotary_sin_window - window the sin factors are loaded from.
//   rotary_dim        - extent of the rotated range; rotation is skipped unless
//                       thread_end <= rotary_dim.
//   thread_end        - NOTE(review): presumably the (exclusive) end of the column range
//                       covered by the calling thread - confirm against callers.
// All three windows are taken by value, so the move_tile_window calls below only move
// local copies; the caller's windows are unaffected.
40  template <typename DistributedTensor,
41  typename OtherDramBlockWindow,
42  typename RotaryCosDramBlockWindow,
43  typename RotarySinDramBlockWindow>
44  CK_TILE_HOST_DEVICE static void apply(DistributedTensor& tile,
45  OtherDramBlockWindow other_window,
46  RotaryCosDramBlockWindow rotary_cos_window,
47  RotarySinDramBlockWindow rotary_sin_window,
48  index_t rotary_dim,
49  index_t thread_end)
50  {
// element type stored in `tile`; results are converted back to it before being written
51  using DataType = typename remove_cvref_t<DistributedTensor>::DataType;
52 
53  if constexpr(RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED)
54  {
// NOTE(review): unlike the HALF_ROTATED branch, cos/sin are loaded before the
// thread_end check, i.e. also when no rotation follows - confirm this is intended.
55  auto rotary_cos_tile = load_tile(rotary_cos_window);
56  auto rotary_sin_tile = load_tile(rotary_sin_window);
57 
58  if(thread_end <= rotary_dim)
59  {
60  constexpr index_t thread_buffer_size = decltype(tile.thread_buf_)::size();
// step of 2: idx is the first element of each pair, idx + 1 the second
// (assumes thread_buffer_size is even, otherwise idx + 1 would run past the buffer)
61  static_for<0, thread_buffer_size, 2>{}([&](auto idx) {
62  const auto left = type_convert<ComputeDataType>(tile.thread_buf_[idx]);
63  const auto right = type_convert<ComputeDataType>(tile.thread_buf_[idx + 1]);
64 
// one cos/sin value per pair, hence the idx / 2 index
65  const auto cos =
66  type_convert<ComputeDataType>(rotary_cos_tile.thread_buf_[idx / 2]);
67  const auto sin =
68  type_convert<ComputeDataType>(rotary_sin_tile.thread_buf_[idx / 2]);
69 
// 2D rotation of (left, right):
//   left'  = left * cos - right * sin
//   right' = right * cos + left * sin
70  tile.thread_buf_[idx] = type_convert<DataType>(left * cos - right * sin);
71  tile.thread_buf_[idx + 1] = type_convert<DataType>(right * cos + left * sin);
72  });
73  }
74  }
75  else if constexpr(RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
76  {
77  if(thread_end <= rotary_dim)
78  {
// true when thread_end lies within the first half of the rotated range
79  const bool is_left = (thread_end <= (rotary_dim / 2));
80 
// partner elements sit rotary_dim / 2 away along the second window coordinate:
// forward for the first half, backward for the second half
81  move_tile_window(other_window, {0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)});
82  auto other_tile = load_tile(other_window);
83 
// the second half steps the cos/sin windows back by rotary_dim / 2, so it reads the
// same cos/sin entries as its partner in the first half
84  move_tile_window(rotary_cos_window, {0, is_left ? 0 : -(rotary_dim / 2)});
85  auto rotary_cos_tile = load_tile(rotary_cos_window);
86 
87  move_tile_window(rotary_sin_window, {0, is_left ? 0 : -(rotary_dim / 2)});
88  auto rotary_sin_tile = load_tile(rotary_sin_window);
89 
90  constexpr index_t thread_buffer_size = decltype(tile.thread_buf_)::size();
// element-wise: tile, other_tile, cos and sin are all indexed with the same idx
91  static_for<0, thread_buffer_size, 1>{}([&](auto idx) {
92  const auto curr = type_convert<ComputeDataType>(tile.thread_buf_[idx]);
93  const auto other = type_convert<ComputeDataType>(other_tile.thread_buf_[idx]);
94 
95  const auto cos =
96  type_convert<ComputeDataType>(rotary_cos_tile.thread_buf_[idx]);
97  const auto sin =
98  type_convert<ComputeDataType>(rotary_sin_tile.thread_buf_[idx]);
99 
// first half:  curr * cos - other * sin
// second half: curr * cos + other * sin
100  tile.thread_buf_[idx] =
101  type_convert<DataType>(curr * cos + other * (is_left ? -sin : sin));
102  });
103  }
104  }
105  }
106 };
107 
108 } // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
RotaryEmbeddingEnum
Definition: block_rotary_embedding.hpp:12
CK_TILE_HOST T cos(T x)
Definition: math.hpp:752
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
CK_TILE_HOST T sin(T x)
Definition: math.hpp:698
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:95
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:22
Definition: block_rotary_embedding.hpp:39
static CK_TILE_HOST_DEVICE void apply(DistributedTensor &tile, OtherDramBlockWindow other_window, RotaryCosDramBlockWindow rotary_cos_window, RotarySinDramBlockWindow rotary_sin_window, index_t rotary_dim, index_t thread_end)
Definition: block_rotary_embedding.hpp:44
Definition: block_rotary_embedding.hpp:19
Definition: functional.hpp:43