29 template <
typename Shape,
typename UnrolledDescriptorType>
33 using is_tuple = decltype(std::declval<T&>().IsTuple());
43 template <
typename... Ts>
44 __host__ __device__ constexpr
static auto
45 GenerateColumnMajorPackedStrides(
const Tuple<Ts...>&
shape)
50 if constexpr(i.value == 0)
56 return TupleReduce<Number<0>{}.value, i.value>([](
auto x,
auto y) {
return x * y; },
60 Number<decltype(unrolled_shape)::Size()>{});
70 template <
typename LayoutShape,
typename LayoutStr
ides>
71 __host__ __device__ constexpr
auto MakeUnrolledDescriptor(
const LayoutShape&
shape,
72 const LayoutStrides& strides)
75 if constexpr(
is_same_v<LayoutStrides, Tuple<>>)
78 const auto unrolled_strides = GenerateColumnMajorPackedStrides(unrolled_shape);
79 static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
80 "Size of strides and shape are not consistent.");
86 static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
87 "Size of strides and shape are not consistent.");
104 template <
typename Shape,
typename Str
ides>
107 using UnrolledDescriptorType = decltype(detail::MakeUnrolledDescriptor(Shape{}, Strides{}));
109 detail::MakeUnrolledDescriptor(
shape, strides));
119 template <
typename Shape>
122 using UnrolledDescriptorType = decltype(detail::MakeUnrolledDescriptor(Shape{}, Tuple<>{}));
124 detail::MakeUnrolledDescriptor(
shape, Tuple<>{}));
135 template <
typename T>
136 __host__ __device__ T constexpr
get(
const T& dim)
148 template <
index_t idx,
typename... Dims>
149 __host__ __device__ constexpr
auto get(
const Tuple<Dims...>& tuple)
151 return tuple.At(Number<idx>{});
161 template <index_t
idx,
typename Shape,
typename UnrolledDesc>
165 const auto new_shape = get<idx>(
shape);
167 "Shape of sub layout must be tuple");
177 if constexpr(i < shape_offset || i >= shape_offset + new_shape_dims)
187 Number<old_shape_dims>{});
189 const auto lower_dims =
190 generate_tuple([&](
auto i) {
return Sequence<i.value>{}; }, Number<old_shape_dims>{});
193 if constexpr(i < shape_offset || i >= shape_offset + new_shape_dims)
198 return Sequence<i.value - shape_offset>{};
201 Number<old_shape_dims>{});
203 const auto& flatten_desc =
layout.GetUnrolledDescriptor();
216 __host__ __device__ constexpr
auto get(
const T& elem)
218 return get<Idxs...>(get<Idx>(elem));
229 template <
typename T>
230 __host__ __device__ T constexpr
size(
const T& dim)
242 template <index_t
idx,
typename Shape,
typename UnrolledDescriptorType>
245 return layout.template GetLength<idx>();
254 template <
typename... ShapeDims>
255 __host__ __device__ constexpr
auto size(
const Tuple<ShapeDims...>&
shape)
258 return TupleReduce<0, unrolled_shape.Size()>([](
auto x,
auto y) {
return x * y; },
268 template <
typename Shape,
typename UnrolledDescriptorType>
271 return layout.GetLengths();
281 template <
index_t idx,
typename... Ts>
282 __host__ __device__ constexpr
auto size(
const Tuple<Ts...>& tuple)
284 return size(tuple.At(Number<idx>{}));
296 __host__ __device__ constexpr
auto size(
const T& elem)
298 return size(get<Idx, Idxs...>(elem));
308 template <
typename Shape,
typename UnrolledDescriptorType>
309 __host__ __device__ constexpr
auto
312 return Shape::Size();
322 template <
typename... Dims>
323 __host__ __device__ constexpr
auto rank([[maybe_unused]]
const Tuple<Dims...>& tuple)
325 return Tuple<Dims...>::Size();
335 template <index_t IDim>
336 __host__ __device__ constexpr
index_t rank([[maybe_unused]]
const Number<IDim>& dim)
348 __host__ __device__ constexpr
index_t rank([[maybe_unused]]
const index_t& dim) {
return 1; }
357 template <
index_t... Idxs,
typename T>
358 __host__ __device__ constexpr
auto rank(
const T& elem)
360 return rank(get<Idxs...>(elem));
370 template <
typename Shape,
typename UnrolledDescriptorType>
383 template <
typename... Dims>
384 __host__ __device__ constexpr
auto depth(
const Tuple<Dims...>& tuple)
396 template <index_t IDim>
397 __host__ __device__ constexpr
index_t depth([[maybe_unused]]
const Number<IDim>& dim)
409 __host__ __device__ constexpr
index_t depth([[maybe_unused]]
const index_t& dim) {
return 0; }
418 template <
index_t... Idxs,
typename T>
419 __host__ __device__ constexpr
auto depth(
const T& elem)
421 return depth(get<Idxs...>(elem));
430 template <
typename LayoutType>
431 __host__ __device__ constexpr
const auto&
shape(
const LayoutType&
layout)
445 template <
typename Shape,
typename UnrolledDesc,
typename TileLengths>
447 const TileLengths& tile_lengths)
449 auto& unrolled_desc =
layout.GetUnrolledDescriptor();
451 constexpr
auto do_pads_seq =
458 [&](
auto i) {
return padded_desc.GetLength(Number<i>{}); },
Number<TileLengths::Size()>{});
473 template <index_t Idx,
typename Shape,
typename UnrolledDesc,
typename NewLengths,
typename NewIdxs>
475 const NewLengths& new_lengths,
476 [[maybe_unused]]
const NewIdxs& new_indexes)
479 auto& unrolled_desc =
layout.GetUnrolledDescriptor();
480 constexpr
auto dims = Shape::Size();
484 if constexpr(i == Idx)
495 constexpr
auto lower_dims =
496 generate_tuple([&](
auto i) {
return Sequence<i.value>{}; }, Number<dims>{});
507 return Sequence<index>{};
512 const auto unmerged_desc =
514 const auto unmerged_shape =
515 generate_tuple([&](
auto i) {
return unmerged_desc.GetLength(Number<i>{}); },
516 Number<decltype(unmerged_desc)::GetNumOfVisibleDimension()>{});
auto pad(ck::index_t mpb, ck::index_t npb, ck::index_t kpb, ck::tensor_operation::device::GemmSpecialization gemm, CDesc_MRaw_NRaw conv)
Definition: helper.hpp:70
__host__ constexpr __device__ auto rank([[maybe_unused]] const Layout< Shape, UnrolledDescriptorType > &layout)
Get layout rank (num elements in shape).
Definition: layout_utils.hpp:310
__host__ constexpr __device__ auto depth(const Layout< Shape, UnrolledDescriptorType > &layout)
Get depth of the layout shape (return 0 if scalar).
Definition: layout_utils.hpp:371
__host__ constexpr __device__ auto get(const Tuple< Dims... > &tuple)
Get element from tuple (Shape/Strides/Idxs).
Definition: layout_utils.hpp:149
__host__ constexpr __device__ auto size(const Layout< Shape, UnrolledDescriptorType > &layout)
Length get (product if tuple).
Definition: layout_utils.hpp:243
__host__ constexpr __device__ auto unmerge(const Layout< Shape, UnrolledDesc > &layout, const NewLengths &new_lengths, [[maybe_unused]] const NewIdxs &new_indexes)
Unmerge selected dim in layout.
Definition: layout_utils.hpp:474
__host__ constexpr __device__ const auto & shape(const LayoutType &layout)
Get Layout shape.
Definition: layout_utils.hpp:431
__host__ constexpr __device__ auto make_layout(const Shape &shape, const Strides &strides)
Make layout function.
Definition: layout_utils.hpp:105
__host__ constexpr __device__ auto PadTensorDescriptor(const TensorDesc &desc, const TileLengths &tile_lengths, DoPads)
Definition: matrix_padder.hpp:19
__host__ constexpr __device__ auto TupleReduce(F &&f, const Tuple< Ts... > &tuple)
Definition: tuple_helper.hpp:161
__host__ constexpr __device__ auto to_sequence(Tuple< Number< Is >... >)
Definition: sequence_helper.hpp:32
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
__host__ constexpr __device__ auto UnrollNestedTuple(const Tuple<> &element)
Definition: tuple_helper.hpp:120
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:98
__host__ constexpr __device__ auto TupleDepth(const T &)
Definition: tuple_helper.hpp:188
__host__ constexpr __device__ auto generate_sequence_v2(F &&f, Number< N >)
Definition: sequence_helper.hpp:25
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
constexpr bool is_same_v
Definition: type.hpp:283
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition: is_detected.hpp:34
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:298
decltype(ck::declval< T & >().IsTuple()) is_tuple
Definition: tuple_helper.hpp:176
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
integral_constant< index_t, N > Number
Definition: number.hpp:12
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1350
Layout wrapper that performs the tensor descriptor logic.
Definition: layout.hpp:24
__host__ constexpr __device__ const auto & layout(const Tensor< BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType > &tensor)
Get Tensor Layout.
Definition: tensor_utils.hpp:162