/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp Source File
reference_batched_rotary_position_embedding.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 
9 #include <cassert>
10 #include <thread>
11 
12 namespace ck_tile {
13 
14 template <typename DataType, typename ComputeDataType = float>
16  const HostTensor<DataType>& cos_sd,
17  const HostTensor<DataType>& sin_sd,
18  bool interleaved,
19  HostTensor<DataType>& output_bsd,
20  bool use_1_row_sin_cos = false)
21 {
22  assert(cos_sd.get_num_of_dimension() == 2 && sin_sd.get_num_of_dimension() == 2);
23  assert(cos_sd.get_length(0) == sin_sd.get_length(0) &&
24  cos_sd.get_length(1) == sin_sd.get_length(1));
25 
26  const index_t rotary_dim = cos_sd.get_length(1) * 2;
27  assert(static_cast<std::size_t>(rotary_dim) <= input_bsd.get_length(2));
28 
29  output_bsd.ForEach([&](auto& self, auto i) {
30  const index_t i_d = i[2];
31  if(rotary_dim <= i_d)
32  {
33  self(i) = input_bsd(i);
34  return;
35  }
36  assert(i_d < rotary_dim);
37 
38  const index_t i_s = i[1];
39  const index_t i_s_cos_sin = (use_1_row_sin_cos ? 0 : i_s);
40 
41  const ComputeDataType cos = type_convert<ComputeDataType>(
42  interleaved ? cos_sd(i_s_cos_sin, i_d / 2)
43  : cos_sd(i_s_cos_sin, i_d % cos_sd.get_length(1)));
44  const ComputeDataType sin = type_convert<ComputeDataType>(
45  interleaved ? sin_sd(i_s_cos_sin, i_d / 2)
46  : sin_sd(i_s_cos_sin, i_d % sin_sd.get_length(1)));
47 
48  const ComputeDataType half_rotated_input = [&] {
49  const index_t i_b = i[0];
50 
51  if(interleaved)
52  {
53  const bool is_even = (i_d % 2 == 0);
54  const index_t pos = i_d + (is_even ? 1 : -1);
55  const ComputeDataType sign = (is_even ? -1 : 1);
56  return sign * type_convert<ComputeDataType>(input_bsd(i_b, i_s, pos));
57  }
58  else
59  {
60  const index_t half_rdim = (rotary_dim / 2);
61  const index_t pos = (i_d + half_rdim) % rotary_dim;
62  const ComputeDataType sign = (pos < half_rdim ? 1 : -1);
63  return sign * type_convert<ComputeDataType>(input_bsd(i_b, i_s, pos));
64  }
65  }();
66  ComputeDataType result =
67  type_convert<ComputeDataType>(input_bsd(i)) * cos + half_rotated_input * sin;
68 
69  self(i) = type_convert<DataType>(result);
70  });
71 }
72 
73 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:40
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST T cos(T x)
Definition: math.hpp:752
int32_t index_t
Definition: integer.hpp:9
CK_TILE_HOST T sin(T x)
Definition: math.hpp:698
CK_TILE_HOST void reference_batched_rotary_position_embedding(const HostTensor< DataType > &input_bsd, const HostTensor< DataType > &cos_sd, const HostTensor< DataType > &sin_sd, bool interleaved, HostTensor< DataType > &output_bsd, bool use_1_row_sin_cos=false)
Definition: reference_batched_rotary_position_embedding.hpp:15
Definition: host_tensor.hpp:336
void ForEach(F &&f)
Definition: host_tensor.hpp:437
std::size_t get_num_of_dimension() const
Definition: host_tensor.hpp:396
std::size_t get_length(std::size_t dim) const
Definition: host_tensor.hpp:388