32 template <
typename... Ts,
typename... Ls>
33 __host__ __device__ constexpr
auto CalculateLocalPartitionShape(
const Tuple<Ts...>&
shape,
34 const Tuple<Ls...>& thread_lengths)
36 static_assert(Tuple<Ts...>::Size() == Tuple<Ls...>::Size(),
"Wrong thread_lengths shape.");
39 constexpr
auto num_i = Number<i>{};
40 const auto slice_len =
44 Number<Tuple<Ls...>::Size()>{});
56 template <
typename MultiIndex,
typename ProjectionTuple>
57 __host__ __device__ constexpr
auto
58 ApplyProjection([[maybe_unused]]
const MultiIndex& base_tuple,
59 [[maybe_unused]]
const ProjectionTuple& projection)
61 if constexpr(
is_same_v<ProjectionTuple, Tuple<>>)
69 const auto i_num =
Number<i.value>{};
72 is_same_v<tuple_element_t<i_num, ProjectionTuple>, Number<1>>);
73 if constexpr(
is_detected<is_slice, tuple_element_t<i_num, ProjectionTuple>>::
value)
86 return UnrollNestedTuple<0, 1>(base_tuple_after_projection);
99 template <
typename... Ts,
typename... Ps>
100 __host__ __device__ constexpr
auto CalculateShapeWithProjection(
const Tuple<Ts...>&
shape,
101 const Tuple<Ps...>& projection)
107 return size<i>(projection).to_;
114 detail::ApplyProjection(TupleSlice<0, i>(Tuple<Ts...>{}),
115 TupleSlice<0, i>(Tuple<Ps...>{}))
117 return size<shape_i>(
shape);
120 Number<Tuple<Ps...>::Size()>{});
130 template <
typename... Ts,
typename... Ls,
typename... Ps>
131 __host__ __device__ constexpr
auto CalculateGridSize(
const Tuple<Ts...>&
shape,
132 const Tuple<Ls...>& tile_shape)
136 Number<Tuple<Ls...>::Size()>{});
147 template <
typename ThreadIdxs,
typename PartitionLengthsSeq,
typename OldOffsetIdxs>
148 __host__ __device__ constexpr
auto
149 CalculateOffsetMultiIdxs(
const ThreadIdxs& thread_idxs,
150 const PartitionLengthsSeq& partition_lengths_seq,
151 const OldOffsetIdxs& old_offset_idxs)
153 return thread_idxs * partition_lengths_seq + old_offset_idxs;
162 template <
typename BlockIdxs>
163 __host__ __device__ constexpr
auto GetDimsToPartition([[maybe_unused]]
const BlockIdxs& block_idxs)
176 Number<BlockIdxs::Size()>{});
178 return UnrollNestedTuple<0, 1>(dims_to_partition);
187 template <
typename BlockIdxs>
188 __host__ __device__ constexpr
auto ReplaceSlicesWithZeros(
const BlockIdxs& block_idxs)
194 return block_idxs.At(i);
201 Number<BlockIdxs::Size()>{});
210 template <
typename TileShape>
211 __host__ __device__ constexpr
auto
212 GenerateDefaultProjection([[maybe_unused]]
const TileShape tile_shape)
224 template <
typename ThreadShape,
typename ThreadUnrolledDesc>
225 __host__ __device__ constexpr
auto CalculateThreadMultiIdx(
229 static_assert(ThreadUnrolledDesc::GetNumOfTransform() == 1,
230 "Thread layout should not be transformed.");
231 constexpr
auto embed_transform = ThreadUnrolledDesc{}.GetTransforms().At(Number<0>{});
232 constexpr
auto shape = ThreadShape{};
233 constexpr
auto strides = embed_transform.coefficients_;
237 constexpr
auto num_i = Number<i>{};
238 return (thread_id / strides.At(num_i)) %
shape.At(num_i);
240 Number<ThreadShape::Size()>{});
258 template <
typename TensorType,
259 typename ThreadShape,
260 typename ThreadUnrolledDesc,
261 typename ProjectionTuple>
262 __host__ __device__ constexpr
auto
266 const ProjectionTuple& projection)
270 const auto& tensor_shape =
shape(tensor);
272 constexpr
auto projected_thread_lengths =
273 detail::ApplyProjection(ThreadShape{}, ProjectionTuple{});
274 constexpr
auto partition_shape =
275 detail::CalculateLocalPartitionShape(decltype(tensor_shape){}, projected_thread_lengths);
276 constexpr
auto partition_shape_seq =
278 Number<decltype(partition_shape)::Size()>{});
280 const auto thread_idxs = detail::CalculateThreadMultiIdx(thread_layout, thread_id);
282 const auto projected_thread_idxs = detail::ApplyProjection(thread_idxs, projection);
283 const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
284 projected_thread_idxs, partition_shape_seq, tensor.GetMultiIdxOffsets());
286 auto& unrolled_desc =
layout(tensor).GetUnrolledDescriptor();
291 offset_multi_idxs.At(i),
292 partition_shape.At(i) + offset_multi_idxs.At(i));
295 const auto lower_upper_dims =
301 const auto partition_layout =
303 partition_shape, sliced_desc);
304 auto partition_tensor =
305 make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), partition_layout);
307 return partition_tensor;
319 template <
typename TensorType,
typename ThreadShape,
typename ThreadUnrolledDesc>
320 __host__ __device__ constexpr
auto
325 const auto projection = detail::GenerateDefaultProjection(ThreadShape{});
346 template <
typename TensorType,
347 typename BlockShapeTuple,
349 typename ProjectionTuple>
351 const BlockShapeTuple& tile_shape,
352 const BlockIdxs& block_idxs,
353 const ProjectionTuple& projection)
358 constexpr
auto I0 = Number<0>{};
359 constexpr
auto I1 = Number<1>{};
360 constexpr
auto I2 = Number<2>{};
362 auto& aligned_desc =
layout(tensor).GetMergedNestingDescriptor();
364 constexpr
auto projected_tile_shape =
365 detail::ApplyProjection(BlockShapeTuple{}, ProjectionTuple{});
367 constexpr
auto dims_to_partition = detail::GetDimsToPartition(BlockIdxs{});
368 const auto parsed_block_idxs = detail::ReplaceSlicesWithZeros(block_idxs);
369 if constexpr(decltype(dims_to_partition)::Size() == I2)
371 const auto shape_with_projection_dims =
372 detail::CalculateShapeWithProjection(
shape(tensor), projection);
374 const auto M = shape_with_projection_dims.At(dims_to_partition.At(I0));
375 const auto N = shape_with_projection_dims.At(dims_to_partition.At(I1));
376 constexpr
auto MPerBlock = BlockShapeTuple{}.At(dims_to_partition.At(I0));
377 constexpr
auto NPerBlock = BlockShapeTuple{}.At(dims_to_partition.At(I1));
380 const auto grid_size = detail::CalculateGridSize(shape_with_projection_dims, tile_shape);
382 const index_t block_id_1d = block_lengths_desc.CalculateOffset(parsed_block_idxs);
384 const auto block_2_tile_map =
385 BlockToCTileMap_M00_N0_M01Adapt<MPerBlock,
388 const auto block_work_idx =
390 const index_t m_block_data_idx_on_grid =
391 __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
392 const index_t n_block_data_idx_on_grid =
393 __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
397 if constexpr(i == dims_to_partition.At(I0))
399 return m_block_data_idx_on_grid;
401 else if constexpr(i == dims_to_partition.At(I1))
403 return n_block_data_idx_on_grid;
410 Number<BlockShapeTuple::Size()>{});
411 const auto projected_offset_multi_idxs =
412 detail::ApplyProjection(offset_multi_idxs, projection);
414 const auto tile_layout =
416 projected_tile_shape, aligned_desc);
418 make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), tile_layout);
420 tile_tensor.SetMultiIdxOffset(
to_multi_index(projected_offset_multi_idxs));
427 using ProjectedTileShapeTuple = decltype(projected_tile_shape);
428 constexpr
auto projected_tile_shape_seq =
430 Number<ProjectedTileShapeTuple::Size()>{});
432 const auto projected_block_idxs =
433 to_multi_index(detail::ApplyProjection(parsed_block_idxs, projection));
434 const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
435 projected_block_idxs, projected_tile_shape_seq, tensor.GetMultiIdxOffsets());
437 const auto tile_layout =
439 projected_tile_shape, aligned_desc);
441 make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), tile_layout);
460 template <
typename TensorType,
typename BlockShapeTuple,
typename BlockIdxs>
462 const BlockShapeTuple& tile_shape,
463 const BlockIdxs& block_idxs)
465 const auto projection = detail::GenerateDefaultProjection(BlockShapeTuple{});
__host__ constexpr __device__ const auto & shape(const LayoutType &layout)
Get Layout shape.
Definition: layout_utils.hpp:431
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
typename remove_reference< T >::type remove_reference_t
Definition: type.hpp:292
__host__ constexpr __device__ auto IsNestedTuple(const Tuple< Ts... > &)
Definition: tuple_helper.hpp:180
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto to_multi_index(const T &x)
Definition: array_multi_index.hpp:28
__host__ constexpr __device__ auto generate_sequence_v2(F &&f, Number< N >)
Definition: sequence_helper.hpp:25
__host__ constexpr __device__ auto make_slice_transform(const LowLength &low_length, const SliceBegin &slice_begin, const SliceEnd &slice_end)
Definition: multi_index_transform_helper.hpp:110
constexpr bool is_same_v
Definition: type.hpp:283
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition: is_detected.hpp:34
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
int32_t index_t
Definition: ck.hpp:298
Array< index_t, N > MultiIndex
Definition: array_multi_index.hpp:12
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
integral_constant< index_t, N > Number
Definition: number.hpp:12
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1350
Layout wrapper that performs the tensor descriptor logic.
Definition: layout.hpp:24
__host__ static constexpr __device__ index_t Size()
Definition: array.hpp:20
__host__ constexpr __device__ auto make_local_partition(TensorType &tensor, [[maybe_unused]] const Layout< ThreadShape, ThreadUnrolledDesc > &thread_layout, const index_t thread_id, const ProjectionTuple &projection)
Create a local partition for a thread (currently only packed partitions are supported).
Definition: tensor_partition.hpp:263
__host__ constexpr __device__ auto make_local_tile(const TensorType &tensor, const BlockShapeTuple &tile_shape, const BlockIdxs &block_idxs, const ProjectionTuple &projection)
Create a local tile for a thread block (currently only packed tiles are supported).
Definition: tensor_partition.hpp:350
__host__ constexpr __device__ const auto & layout(const Tensor< BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType > &tensor)
Get Tensor Layout.
Definition: tensor_utils.hpp:162