// Generic (non-tuple) overload of HasSlice. Its template header and body are
// not visible in this extract.
// NOTE(review): presumably reports whether a single index element is a slice;
// confirm against the full source.
26 __host__ __device__ constexpr
bool HasSlice(T&&)
// HasSlice overload for a Tuple of indices: true if HasSlice is true for any
// element type in Ts (fold over ||). An empty tuple yields false.
// Each element is value-initialised (Ts{}) only to drive overload resolution,
// so every type in Ts must be default-constructible.
30 template <
typename... Ts>
31 __host__ __device__ constexpr
bool HasSlice(Tuple<Ts...>&&)
33 return (HasSlice(Ts{}) || ...);
// Computes the shape produced by indexing `shape` with `idxs`, where some
// elements of `idxs` may be slices. Much of the body is missing from this
// extract. The visible parts are:
//  - a per-position step (num_i = Number<i>) run over Tuple<Ts...>::Size()
//    positions;
//  - an `if constexpr` that separates positions whose index type contains no
//    slice (that branch's body is not visible) from the rest;
//  - for slices, as far as visible, the new extent comes from
//    idxs.At(i).range(dim), with dim = size(shape.At(i));
//  - the collected `new_shape` is passed through UnrollNestedTuple<0, 1>
//    before it is returned.
// NOTE(review): HasSlice is evaluated on a value-initialised element type, so
// every index type in Ts must be default-constructible.
43 template <
typename... Ts,
typename SlicedShape>
44 __host__ __device__ constexpr
auto GetSlicedShape(
const Tuple<Ts...>& idxs,
45 const SlicedShape&
shape)
50 constexpr
auto num_i = Number<i>{};
// Position i holds no slice; handled by a branch not visible here.
53 if constexpr(!detail::HasSlice(
tuple_element_t<i.value, Tuple<Ts...>>{}))
// Extent of the slice at position i (presumably the slice branch).
67 const auto& dim = size(
shape.At(num_i));
68 const auto val = idxs.At(num_i).range(dim);
77 Number<Tuple<Ts...>::Size()>{});
79 return UnrollNestedTuple<0, 1>(new_shape);
// Decomposes the scalar index `idx` (taken by value) over the dimensions of a
// possibly nested `shape`. Only fragments are visible:
//  - `unrolled_shape` is defined in lines not shown (presumably derived from
//    `shape`);
//  - one element is generated per dimension of unrolled_shape, and the
//    coordinate for dimension i is idx % dim_i.
// NOTE(review): the function name and the make_freeze_transform helper suggest
// two steps that are not visible here: each coordinate becomes a freeze
// transform, and idx is divided by dim between dimensions. Confirm both.
89 template <
typename T,
typename Shape>
90 __host__ __device__ constexpr
auto GenerateMultipleFreeze(T idx,
const Shape&
shape)
96 const auto dim = unrolled_shape.At(Number<i>{});
// Coordinate of idx along dimension i.
97 const auto dim_idx = idx % dim;
101 Number<decltype(unrolled_shape)::Size()>{});
// Builds the tuple of coordinate transforms that implement indexing `shape`
// with `idx`, a tuple that may contain slices. One entry is generated per
// position of idx (Tuple<Ts...>::Size() positions).
// Three per-position paths are visible; the conditions that choose between
// them are not:
//  - recurse into GenerateSliceTransforms with the i-th sub-index and
//    sub-shape (presumably for nested tuple indices);
//  - read the slice's from_ and range(dim), with dim = size<i>(shape)
//    (presumably to build a slice transform);
//  - delegate to GenerateMultipleFreeze with the i-th index and sub-shape
//    (presumably for plain integer indices).
111 template <
typename... Ts,
typename Shape>
112 __host__ __device__ constexpr
auto GenerateSliceTransforms(
const Tuple<Ts...>& idx,
118 constexpr
auto num_i = Number<i>{};
// Nested index: recurse on the sub-index and sub-shape.
121 return GenerateSliceTransforms(idx.At(num_i),
shape.At(num_i));
// Slice: start offset and extent of the sliced range in dimension i.
126 const auto from = idx.At(num_i).from_;
127 const auto dim = size<num_i>(
shape);
128 const auto range = idx.At(num_i).range(dim);
// Scalar index over a possibly nested dimension.
134 return GenerateMultipleFreeze(idx.At(num_i),
shape.At(num_i));
137 Number<Tuple<Ts...>::Size()>{});
// Template header of an overload whose declaration and body are not visible in
// this extract.
// NOTE(review): presumably the GetSequenceVal overload for freeze transforms
// (parameterised on LowerIndex). Confirm against the full source.
142 template <index_t i,
typename LowerIndex>
// Returns Sequence<i>: the transform occupies upper (output) dimension i.
// The declaration line is not visible here. The template parameters
// (LowLength, SliceBegin, SliceEnd) match those of make_slice_transform, so
// this is presumably the slice-transform overload of GetSequenceVal.
149 template <index_t i,
typename LowLength,
typename SliceBegin,
typename SliceEnd>
152 return Sequence<i>{};
// Terminating overload of GenerateUpperDims for an empty transform tuple.
// Its template header and body are not visible in this extract.
156 __host__ __device__ constexpr
auto GenerateUpperDims(
const Tuple<>&)
// Assigns upper (output) dimension ids to a tuple of transforms, starting at
// id i. For the first transform it asks GetSequenceVal<i> for its id sequence:
//  - an empty Sequence<> means the transform contributes no upper dimension,
//    so the remaining transforms are processed with the same i;
//  - otherwise the remaining transforms continue from i + 1.
// The tail is walked via TupleSlice<1, num_transforms> (presumably ending at
// the Tuple<> overload). How current_elem and next_tuple are combined is not
// visible here.
// NOTE(review): presumably they are concatenated; confirm.
161 template <
index_t i,
typename... Transforms>
162 __host__ __device__ constexpr
auto GenerateUpperDims(
const Tuple<Transforms...>& transforms)
164 constexpr
auto num_transforms = Tuple<Transforms...>::Size();
166 const auto current_elem = GetSequenceVal<i>(transforms.At(Number<0>{}));
// decltype of a `const auto` variable keeps the const, hence `const Sequence<>`.
167 if constexpr(
is_same_v<decltype(current_elem),
const Sequence<>>)
// No upper dimension produced: keep the same id for the rest.
169 const auto next_tuple = GenerateUpperDims<i>(TupleSlice<1, num_transforms>(transforms));
// Upper dimension i consumed: continue numbering from i + 1.
175 const auto next_tuple = GenerateUpperDims<i + 1>(TupleSlice<1, num_transforms>(transforms));
// Builds the pieces needed to slice the unrolled descriptor `flatten_desc`
// with `idx`:
//  - transforms: from GenerateSliceTransforms(idx, shape);
//  - lower_dims: Sequence<0>, Sequence<1>, ..., one per old dimension
//    (old_shape_dims and the `shape` parameter are in lines not visible here);
//  - upper_dims: derived from the *type* of the transforms tuple only.
// The rest of the body is not visible in this extract (presumably a
// transform_tensor_descriptor call on flatten_desc).
180 template <
typename... Ts,
typename Shape,
typename UnrolledDescriptor>
181 __host__ __device__ constexpr
auto GenerateSlicedDescriptor(
const Tuple<Ts...>& idx,
183 const UnrolledDescriptor& flatten_desc)
187 const auto transforms = GenerateSliceTransforms(idx,
shape);
188 using TransformsTupleType = decltype(transforms);
190 const auto lower_dims =
191 generate_tuple([&](
auto i) {
return Sequence<i.value>{}; }, Number<old_shape_dims>{});
// GenerateUpperDims appears only inside decltype (unevaluated), and its result
// type is then value-initialised. Both TransformsTupleType and that result type
// must therefore be default-constructible.
192 const auto upper_dims = decltype(GenerateUpperDims<0>(TransformsTupleType{})){};
// Fragment of the Tensor class template; its opening lines are not visible.
//  - TensorElementType: ElementType itself when it is a scalar type, otherwise
//    scalar_type<remove_const_t<ElementType>>::type.
//  - IsDynamicBuffer: true unless BufferAddressSpace is Sgpr or Vgpr. Only
//    those register address spaces select the static buffer type.
//  - Judging by their position, the first static_assert belongs to the
//    pointer-taking constructor and rejects Sgpr/Vgpr. The second belongs to
//    the layout-only constructor and requires Sgpr/Vgpr.
// NOTE(review): both asserts use the message "Wrong BufferAddressSpace for
// register.". For the pointer constructor that wording is misleading; consider
// a message that states the non-register requirement.
210 typename ElementType,
212 typename UnrolledDescriptorType>
221 typename scalar_type<std::remove_const_t<ElementType>>::type>;
224 static constexpr
bool IsDynamicBuffer = !(BufferAddressSpace == MemoryTypeEnum ::Sgpr ||
225 BufferAddressSpace == MemoryTypeEnum ::Vgpr);
235 static_assert(IsDynamicBuffer,
"Wrong BufferAddressSpace for register.");
243 static_assert(!IsDynamicBuffer,
"Wrong BufferAddressSpace for register.");
// Slicing access, enabled only when idx contains at least one slice.
// Returns a new tensor over the same data pointer (buffer_.p_data_), with a
// layout built from the sliced shape and the sliced descriptor. Register
// (Sgpr/Vgpr) tensors are rejected at compile time.
// NOTE(review): `base_offset_ -= new_layout(make_tuple(Number<0>{}))` modifies
// *this, not the returned tensor. The returned tensor is built only from
// buffer_.p_data_ and new_layout. As a result:
//  - slicing the same tensor repeatedly accumulates this subtraction on *this;
//  - any existing base_offset_ of *this is not visibly carried into the result.
// Confirm this is intended.
257 template <
typename... Ts,
enable_if_t<detail::HasSlice(Tuple<Ts...>{}),
bool> =
false>
258 __host__ __device__
auto operator[](
const Tuple<Ts...>& idx)
260 static_assert(IsDynamicBuffer,
"Register slice is not supported");
261 const auto&
shape = layout_.GetShape();
262 auto new_shape = detail::GetSlicedShape(idx,
shape);
// Slice the unrolled (flattened) descriptor rather than the nested shape.
264 const auto& flatten_desc = layout_.GetUnrolledDescriptor();
265 auto new_desc = detail::GenerateSlicedDescriptor(idx,
shape, flatten_desc);
266 const auto new_layout =
269 base_offset_ -= new_layout(
make_tuple(Number<0>{}));
270 return make_tensor<BufferAddressSpace>(buffer_.p_data_, new_layout);
// Call-syntax alias for slicing with a tuple index; forwards to operator[].
273 template <
typename... Ts,
enable_if_t<detail::HasSlice(Tuple<Ts...>{}),
bool> =
false>
274 __host__ __device__
auto operator()(
const Tuple<Ts...>& idx)
276 return this->operator[](idx);
// Variadic call-syntax slicing, enabled when any of idxs is a slice.
// The body is not visible in this extract.
// NOTE(review): presumably the indices are packed into a tuple and forwarded
// to the tuple overload; confirm.
279 template <
typename... Idxs,
enable_if_t<detail::HasSlice(Tuple<Idxs...>{}),
bool> =
false>
280 __host__ __device__
auto operator()(Idxs... idxs)
// Element access with a non-slice index. The declaration line is not visible;
// presumably this is the operator[] overload that the tuple operator() alias
// forwards to.
//  - Dynamic buffer: the linear offset layout_(idx) + base_offset_ is
//    computed at run time and used to index buffer_.
//  - Static (register) buffer: the offset is computed at compile time from a
//    default-constructed Layout<Shape, UnrolledDescriptorType>, either for
//    Tuple<Ts...> or for MultiIndex<Shape::Size()>. The condition choosing
//    between the two is not visible. The element is then fetched with a
//    compile-time Number index.
// NOTE(review): the static path adds `base_offset` (no trailing underscore),
// which is defined in lines not visible here. Confirm it mirrors base_offset_.
291 template <
typename... Ts,
enable_if_t<!detail::HasSlice(Tuple<Ts...>{}),
bool> =
false>
294 if constexpr(IsDynamicBuffer)
296 const index_t offset = layout_(idx) + base_offset_;
297 return buffer_[offset];
303 UnrolledDescriptorType{}}.template operator()<Tuple<Ts...>>();
307 UnrolledDescriptorType{}}.template operator()<
MultiIndex<Shape::Size()>>();
308 return buffer_[Number<index_offset + base_offset>{}];
// Call-syntax alias for non-slice tuple access; forwards to operator[].
// The declaration line is not visible in this extract.
312 template <
typename... Ts,
enable_if_t<!detail::HasSlice(Tuple<Ts...>{}),
bool> =
false>
315 return this->operator[](idx);
// Template header of the variadic non-slice accessor, enabled when none of
// Idxs is a slice. The declaration and body are not visible in this extract.
318 template <
typename... Idxs,
enable_if_t<!detail::HasSlice(Tuple<Idxs...>{}),
bool> =
false>
// Non-slice element access reached through buffer_'s call operator instead of
// its subscript. The declaration line is not visible; presumably this is the
// mutable (reference-returning) counterpart. Confirm against the buffer types.
//  - Dynamic buffer: run-time offset layout_(idx) + base_offset_.
//  - Static (register) buffer: compile-time offset from a default-constructed
//    Layout<Shape, UnrolledDescriptorType>, for Tuple<Ts...> or for
//    MultiIndex<Shape::Size()>. The selecting condition is not visible.
// NOTE(review): the static path uses `base_offset` (no trailing underscore),
// defined in lines not visible here.
330 template <
typename... Ts,
enable_if_t<!detail::HasSlice(Tuple<Ts...>{}),
bool> =
false>
333 if constexpr(IsDynamicBuffer)
335 const index_t offset = layout_(idx) + base_offset_;
336 return buffer_(offset);
342 UnrolledDescriptorType{}}.template operator()<Tuple<Ts...>>();
346 UnrolledDescriptorType{}}.template operator()<
MultiIndex<Shape::Size()>>();
347 return buffer_(Number<index_offset + base_offset>{});
// Call-syntax alias for non-slice tuple access; forwards to operator[].
// The declaration line is not visible in this extract.
351 template <
typename... Ts,
enable_if_t<!detail::HasSlice(Tuple<Ts...>{}),
bool> =
false>
354 return this->operator[](idx);
// Template header of a variadic non-slice accessor, enabled when none of Idxs
// is a slice. The declaration and body are not visible in this extract.
357 template <
typename... Idxs,
enable_if_t<!detail::HasSlice(Tuple<Idxs...>{}),
bool> =
false>
// Presumably the body of Tensor::GetMergedNestingDescriptor (declaration not
// visible): delegates to the layout, which yields a descriptor with all
// nested dimensions merged.
370 return layout_.GetMergedNestingDescriptor();
// Access to the underlying buffer object.
// The non-const overload returns a mutable reference. In the const overload
// buffer_ is const, so `auto&` deduces a const reference.
380 __host__ __device__ constexpr
auto&
GetBuffer() {
return buffer_; }
381 __host__ __device__ constexpr
auto&
GetBuffer()
const {
return buffer_; }
// Applies a multi-index offset to the tensor (declaration line not visible).
// The given multi-index is stored in multi_idx_offset_, and its linear offset
// under layout_ is added to base_offset_.
// NOTE(review): the stored multi-index is overwritten, but base_offset_
// accumulates. After several calls, multi_idx_offset_ reflects only the last
// call while base_offset_ reflects all of them. Confirm this is intended.
395 template <
typename MultiIdxOffsets>
398 multi_idx_offset_ = multi_idx_offset;
399 base_offset_ += layout_(multi_idx_offset);
// Buffer type selection.
//  - Non-register address spaces use DynamicBuffer.
//  - Register address spaces (Sgpr/Vgpr) use StaticBuffer or
//    StaticBufferTupleOfVector. The condition choosing between these two, and
//    their remaining template arguments, are not visible in this extract.
//  - `Buffer` picks DynamicBufferType or StaticBufferType via IsDynamicBuffer.
405 using DynamicBufferType = DynamicBuffer<BufferAddressSpace,
411 StaticBuffer<BufferAddressSpace,
415 StaticBufferTupleOfVector<BufferAddressSpace,
418 scalar_type<std::remove_const_t<ElementType>>::vector_size,
419 scalar_type<std::remove_const_t<ElementType>>::vector_size,
422 using Buffer = std::conditional_t<IsDynamicBuffer, DynamicBufferType, StaticBufferType>;
__host__ constexpr __device__ const auto & shape(const LayoutType &layout)
Get Layout shape.
Definition: layout_utils.hpp:431
__host__ constexpr __device__ auto concat_tuple(const Tuple< X... > &tx, const Tuple< Y... > &ty)
Definition: tuple_helper.hpp:52
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
__host__ constexpr __device__ auto UnrollNestedTuple(const Tuple<> &element)
Definition: tuple_helper.hpp:120
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:98
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
__host__ constexpr __device__ auto make_slice_transform(const LowLength &low_length, const SliceBegin &slice_begin, const SliceEnd &slice_end)
Definition: multi_index_transform_helper.hpp:110
constexpr bool is_same_v
Definition: type.hpp:283
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition: is_detected.hpp:34
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:298
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:27
__host__ constexpr __device__ auto make_zero_multi_index()
Definition: array_multi_index.hpp:21
decltype(ck::declval< T & >().IsTuple()) is_tuple
Definition: tuple_helper.hpp:176
Array< index_t, N > MultiIndex
Definition: array_multi_index.hpp:12
__host__ constexpr __device__ auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition: dynamic_buffer.hpp:461
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
integral_constant< index_t, N > Number
Definition: number.hpp:12
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1350
const GenericPointer< typename T::ValueType > & pointer
Definition: pointer.h:1249
Layout wrapper that performs the tensor descriptor logic.
Definition: layout.hpp:24
__host__ constexpr __device__ auto GetElementSpaceSize() const
Definition: layout.hpp:297
Tensor wrapper that performs static and dynamic buffer logic. The tensor is based on a descriptor sto...
Definition: host_tensor.hpp:277
__host__ constexpr __device__ auto & GetBuffer() const
Definition: tensor.hpp:381
__host__ __device__ TensorElementType * GetPointer() const
Get pointer to the data.
Definition: tensor.hpp:378
__host__ constexpr __device__ const Layout< Shape, UnrolledDescriptorType > & GetLayout() const
Definition: tensor.hpp:246
__host__ constexpr __device__ void SetMultiIdxOffset(const MultiIdxOffsets multi_idx_offset)
Apply a multi-index offset to the tensor.
Definition: tensor.hpp:396
__host__ constexpr __device__ auto & GetMultiIdxOffsets() const
Get the multi-index offset applied to the data.
Definition: tensor.hpp:388
std::conditional_t< is_scalar_type< ElementType >::value, ElementType, typename scalar_type< std::remove_const_t< ElementType > >::type > TensorElementType
Definition: tensor.hpp:221
__host__ constexpr __device__ auto GetMergedNestingDescriptor()
Get descriptor with all nested dimensions merged.
Definition: tensor.hpp:368
__host__ constexpr __device__ Tensor(const Layout< Shape, UnrolledDescriptorType > &layout)
Definition: tensor.hpp:238
__host__ __device__ Tensor()=delete
__host__ constexpr __device__ Tensor(ElementType *pointer, const Layout< Shape, UnrolledDescriptorType > &layout)
Definition: tensor.hpp:228
decltype(Layout< Shape, UnrolledDescriptorType >{ Shape{}, UnrolledDescriptorType{}}.GetElementSpaceSize()) ElementSpaceSize
Definition: tensor.hpp:217
__host__ constexpr __device__ auto & GetBuffer()
Definition: tensor.hpp:380
Definition: multi_index_transform.hpp:1558
Definition: multi_index_transform.hpp:1776
static constexpr bool value
Definition: data_type.hpp:218
__host__ constexpr __device__ const auto & layout(const Tensor< BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType > &tensor)
Get Tensor Layout.
Definition: tensor_utils.hpp:162
AddressSpaceEnum MemoryTypeEnum
Memory type, allowed members:
Definition: tensor_utils.hpp:30