18 template <RotaryEmbeddingEnum>
24 static constexpr
const char* name =
"";
29 static constexpr
const char* name =
"inter";
34 static constexpr
const char* name =
"half";
37 template <RotaryEmbeddingEnum RotaryEnum,
typename ComputeDataType =
float>
40 template <
typename DistributedTensor,
41 typename OtherDramBlockWindow,
42 typename RotaryCosDramBlockWindow,
43 typename RotarySinDramBlockWindow>
45 OtherDramBlockWindow other_window,
46 RotaryCosDramBlockWindow rotary_cos_window,
47 RotarySinDramBlockWindow rotary_sin_window,
55 auto rotary_cos_tile =
load_tile(rotary_cos_window);
56 auto rotary_sin_tile =
load_tile(rotary_sin_window);
58 if(thread_end <= rotary_dim)
60 constexpr
index_t thread_buffer_size = decltype(tile.thread_buf_)::size();
62 const auto left = type_convert<ComputeDataType>(tile.thread_buf_[idx]);
63 const auto right = type_convert<ComputeDataType>(tile.thread_buf_[idx + 1]);
66 type_convert<ComputeDataType>(rotary_cos_tile.thread_buf_[idx / 2]);
68 type_convert<ComputeDataType>(rotary_sin_tile.thread_buf_[idx / 2]);
70 tile.thread_buf_[idx] = type_convert<DataType>(left *
cos - right *
sin);
71 tile.thread_buf_[idx + 1] = type_convert<DataType>(right *
cos + left *
sin);
77 if(thread_end <= rotary_dim)
79 const bool is_left = (thread_end <= (rotary_dim / 2));
81 move_tile_window(other_window, {0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)});
82 auto other_tile =
load_tile(other_window);
85 auto rotary_cos_tile =
load_tile(rotary_cos_window);
88 auto rotary_sin_tile =
load_tile(rotary_sin_window);
90 constexpr
index_t thread_buffer_size = decltype(tile.thread_buf_)::size();
92 const auto curr = type_convert<ComputeDataType>(tile.thread_buf_[idx]);
93 const auto other = type_convert<ComputeDataType>(other_tile.thread_buf_[idx]);
96 type_convert<ComputeDataType>(rotary_cos_tile.thread_buf_[idx]);
98 type_convert<ComputeDataType>(rotary_sin_tile.thread_buf_[idx]);
100 tile.thread_buf_[idx] =
101 type_convert<DataType>(curr *
cos + other * (is_left ? -
sin :
sin));
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
RotaryEmbeddingEnum
Definition: block_rotary_embedding.hpp:12
CK_TILE_HOST T cos(T x)
Definition: math.hpp:752
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
CK_TILE_HOST T sin(T x)
Definition: math.hpp:698
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:95
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:22
Definition: block_rotary_embedding.hpp:39
static CK_TILE_HOST_DEVICE void apply(DistributedTensor &tile, OtherDramBlockWindow other_window, RotaryCosDramBlockWindow rotary_cos_window, RotarySinDramBlockWindow rotary_sin_window, index_t rotary_dim, index_t thread_end)
Definition: block_rotary_embedding.hpp:44
Definition: block_rotary_embedding.hpp:19
Definition: functional.hpp:43