22 template <
typename Shape,
typename UnrolledDescriptorType>
28 static constexpr
auto I0 = Number<0>{};
29 static constexpr
auto I1 = Number<1>{};
37 template <
typename... Ts>
38 __host__ __device__ constexpr
static auto
39 GenerateDefaultIdxsTuple([[maybe_unused]]
const Tuple<Ts...>&
shape)
43 if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
54 Number<Tuple<Ts...>::Size()>{});
66 template <
typename Idx,
typename... Ts>
67 __host__ __device__ constexpr
static auto
68 GenerateLowerDim([[maybe_unused]]
const Tuple<Ts...>&
shape)
77 using LowerDimsSequence =
79 return LowerDimsSequence::Reverse();
90 using PreviousSeqT = decltype(GenerateLowerDim<Number<Idx::value - 1>>(Tuple<Ts...>{}));
91 const auto next_seq_val = PreviousSeqT::At(I0) + 1;
96 using LowerDimsSequence =
97 typename arithmetic_sequence_gen<next_seq_val, next_seq_val + merge_nelems, 1>::
99 return LowerDimsSequence::Reverse();
103 return Sequence<next_seq_val>{};
119 template <
typename... ShapeDims,
typename... IdxDims>
120 __host__ __device__ constexpr
static auto AlignShapeToIdx(
const Tuple<ShapeDims...>&
shape,
121 const Tuple<IdxDims...>& idx)
145 Number<Tuple<IdxDims...>::Size()>{});
148 return AlignShapeToIdx(UnrollNestedTuple<0, 1>(aligned_shape),
149 UnrollNestedTuple<0, 1>(idx));
160 template <
typename... ShapeDims,
typename DescriptorToMerge>
161 __host__ __device__ constexpr
static auto MakeMerge1d(
const Tuple<ShapeDims...>&
shape,
162 const DescriptorToMerge& desc)
167 using MergeElemsSequence =
typename arithmetic_sequence_gen<0, merge_elems.Size(), 1>::type;
168 const auto lower_dims =
make_tuple(MergeElemsSequence::Reverse());
169 const auto upper_dims =
make_tuple(Sequence<0>{});
171 if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
201 template <
typename... ShapeDims,
typename... IdxDims,
typename DescriptorToMerge>
202 __host__ __device__ constexpr
static auto
203 CreateMergedDescriptor(
const Tuple<ShapeDims...>&
shape,
204 [[maybe_unused]]
const Tuple<IdxDims...>& idxs,
205 DescriptorToMerge& desc)
217 if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
235 "Wrong Idx for layout()");
240 Number<Tuple<ShapeDims...>::Size()>{});
242 const auto lower_dims =
244 Number<Tuple<ShapeDims...>::Size()>{});
245 const auto upper_dims =
generate_tuple([&](
auto i) {
return Sequence<i.value>{}; },
246 Number<Tuple<ShapeDims...>::Size()>{});
251 using Descriptor1dType =
252 remove_cvref_t<decltype(MakeMerge1d(Shape{}, UnrolledDescriptorType{}))>;
253 using DefaultIdxsTupleType =
remove_cvref_t<decltype(GenerateDefaultIdxsTuple(Shape{}))>;
268 template <
typename... ShapeDims,
typename... IdxDims>
269 __host__ __device__ constexpr
static auto
271 const Tuple<IdxDims...>& idxs,
272 const UnrolledDescriptorType& naive_descriptor)
274 if constexpr(Tuple<IdxDims...>::Size() == I1)
277 return MakeMerge1d(
shape, naive_descriptor);
285 static_assert(Tuple<ShapeDims...>::Size() == Tuple<IdxDims...>::Size(),
286 "Idx rank and Shape rank must be the same (except 1d).");
288 const auto aligned_shape = AlignShapeToIdx(
shape, idxs);
290 return CreateMergedDescriptor(aligned_shape,
UnrollNestedTuple(idxs), naive_descriptor);
295 Shape{}, DefaultIdxsTupleType{}, UnrolledDescriptorType{}))>;
299 return unrolled_descriptor_.GetElementSpaceSize();
311 const UnrolledDescriptorType& unnested_descriptor)
312 : unrolled_descriptor_(unnested_descriptor), shape_(
shape)
315 if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
317 descriptor_1d_ = MakeMerge1d(shape_, unrolled_descriptor_);
318 merged_nests_descriptor_ =
319 TransformDesc(shape_, DefaultIdxsTupleType{}, unrolled_descriptor_);
329 template <
typename Idxs>
332 static_assert(remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime(),
333 "Compiletime operator used on runtime layout.");
334 using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, UnrolledDescriptorType{}));
336 return TransformedDesc{}.CalculateOffset(UnrolledIdx{});
345 template <
typename... Ts>
348 if constexpr(!
IsNestedTuple(Tuple<Ts...>{}) && Tuple<Ts...>::Size() == 1)
351 return descriptor_1d_.CalculateOffset(Idx);
353 else if constexpr(!
IsNestedTuple(Tuple<Ts...>{}) && Tuple<Ts...>::Size() == Shape::Size())
361 const auto transformed_desc = TransformDesc(shape_, Idx, unrolled_descriptor_);
372 template <index_t IDim>
375 const auto elem = shape_.At(Number<IDim>{});
379 return TupleReduce<I0.value, unrolled_element.Size()>(
380 [](
auto x,
auto y) {
return x * y; }, unrolled_element);
396 return TupleReduce<I0.value, unrolled_shape.Size()>([](
auto x,
auto y) {
return x * y; },
405 __host__ __device__ constexpr
const Shape&
GetShape()
const {
return shape_; }
424 return GenerateDefaultIdxsTuple(shape_);
436 __host__ __device__ constexpr
const MergedNestsDescriptorType&
439 return merged_nests_descriptor_;
451 return descriptor_1d_;
463 return unrolled_descriptor_;
470 UnrolledDescriptorType unrolled_descriptor_;
472 Descriptor1dType descriptor_1d_;
474 MergedNestsDescriptorType merged_nests_descriptor_;
__host__ constexpr __device__ const auto & shape(const LayoutType &layout)
Get Layout shape.
Definition: layout_utils.hpp:431
__host__ constexpr __device__ auto TupleReduce(F &&f, const Tuple< Ts... > &tuple)
Definition: tuple_helper.hpp:161
__host__ constexpr __device__ auto IsNestedTuple(const Tuple< Ts... > &)
Definition: tuple_helper.hpp:180
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
__host__ constexpr __device__ auto make_merge_transform_v1_carry_check(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:66
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
__host__ constexpr __device__ auto TupleReverse(const Tuple< Ts... > &tuple)
Definition: tuple_helper.hpp:149
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto UnrollNestedTuple(const Tuple<> &element)
Definition: tuple_helper.hpp:120
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition: is_detected.hpp:34
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
int32_t index_t
Definition: ck.hpp:298
decltype(ck::declval< T & >().IsTuple()) is_tuple
Definition: tuple_helper.hpp:176
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
integral_constant< index_t, N > Number
Definition: number.hpp:12
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1350
Layout wrapper that performs the tensor descriptor logic.
Definition: layout.hpp:24
__host__ __device__ Layout()=delete
Shape LayoutShape
Definition: layout.hpp:257
__host__ constexpr __device__ auto GetElementSpaceSize() const
Definition: layout.hpp:297
__host__ constexpr static __device__ auto TransformDesc(const Tuple< ShapeDims... > &shape, const Tuple< IdxDims... > &idxs, const UnrolledDescriptorType &naive_descriptor)
Transform descriptor to align to passed indexes.
Definition: layout.hpp:270
__host__ constexpr __device__ auto GetLength() const
Length getter (product if tuple).
Definition: layout.hpp:373
remove_cvref_t< decltype(TransformDesc(Shape{}, DefaultIdxsTupleType{}, UnrolledDescriptorType{}))> MergedNestsDescriptorType
Definition: layout.hpp:295
__host__ constexpr __device__ const Shape & GetShape() const
Shape getter.
Definition: layout.hpp:405
__host__ constexpr __device__ const Descriptor1dType & Get1DDescriptor() const
Get descriptor with all dimensions merged (1D). Example, shape: ((2, 2), 2) Descriptor lengths: (...
Definition: layout.hpp:449
__host__ constexpr __device__ auto GetLengths() const
Layout size getter (product of shape).
Definition: layout.hpp:393
UnrolledDescriptorType LayoutUnrolledDescriptorType
Definition: layout.hpp:258
__host__ constexpr __device__ auto GetDefaultStartIdxs() const
Get default start idx (tuple filled with 0s of the same size as Shape).
Definition: layout.hpp:422
__host__ constexpr __device__ const MergedNestsDescriptorType & GetMergedNestingDescriptor() const
Get descriptor with all nested dimensions merged. Example, shape: ((2, 2), 2) Descriptor lengths: (4,...
Definition: layout.hpp:437
__host__ constexpr __device__ const UnrolledDescriptorType & GetUnrolledDescriptor() const
Get unnested descriptor (with unrolled dims). Example, shape: ((2, 2), 2) Descriptor lengths: (2,...
Definition: layout.hpp:461
__host__ constexpr __device__ auto GetDefaultLengthsTuple() const
Get default lengths (tuple filled with Shape length elements).
Definition: layout.hpp:412
__host__ __device__ index_t operator()(const Tuple< Ts... > &Idx) const
Returns real offset to element for indexes passed at runtime.
Definition: layout.hpp:346
__host__ constexpr __device__ Layout(const Shape &shape, const UnrolledDescriptorType &unnested_descriptor)
Layout constructor.
Definition: layout.hpp:310
__host__ constexpr __device__ index_t operator()() const
Returns real offset to element for indexes given at compile time (as a template parameter); requires a layout known at compile time.
Definition: layout.hpp:330
typename conditional< kHasContent, type0, type1 >::type type
Definition: sequence.hpp:271