template <typename InputGridDesc,
          typename InputDataType,
          typename OutputGridDesc,
          typename OutputDataType,
          typename Block2ETileMap,
          typename ComputePtrOffsetOfStridedBatch,
          typename GridwiseTensorRearrangeKernel>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_tensor_rearrange(const InputGridDesc in_grid_desc,
                                const InputDataType* __restrict__ p_in_global,
                                const OutputGridDesc out_grid_desc,
                                OutputDataType* __restrict__ p_out_global,
                                const index_t batch_count,
                                const Block2ETileMap block_2_tile_map,
                                const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
{
#if(defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || \
    defined(__gfx103__) || defined(__gfx11__) || defined(__gfx12__))
    GridwiseTensorRearrangeKernel::Run(in_grid_desc,
                                       p_in_global,
                                       out_grid_desc,
                                       p_out_global,
                                       batch_count,
                                       block_2_tile_map,
                                       compute_ptr_offset_of_batch);
#else
    ignore = in_grid_desc;
    ignore = p_in_global;
    ignore = out_grid_desc;
    ignore = p_out_global;
    ignore = batch_count;
    ignore = block_2_tile_map;
    ignore = compute_ptr_offset_of_batch;
#endif
}
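For illustration only, a minimal host-side launch sketch of the wrapper above, assuming the caller has already constructed the grid descriptors, the block-to-tile map, and the batch pointer-offset helper. InDesc, OutDesc, TileMap, BatchOffset, GridwiseKernel, grid_size, block_size and the device pointers are hypothetical placeholder names, not part of this header.

// Hedged sketch (requires <hip/hip_runtime.h>); all names below are placeholders.
const auto kernel = kernel_tensor_rearrange<InDesc,
                                            InputDataType,
                                            OutDesc,
                                            OutputDataType,
                                            TileMap,
                                            BatchOffset,
                                            GridwiseKernel>;

hipLaunchKernelGGL(kernel,
                   dim3(grid_size),  // e.g. tiles per batch * batch_count
                   dim3(block_size), // BlockSize used to instantiate GridwiseKernel
                   0,                // dynamic shared memory bytes
                   nullptr,          // stream
                   in_grid_desc,
                   p_in,
                   out_grid_desc,
                   p_out,
                   batch_count,
                   block_2_tile_map,
                   compute_ptr_offset_of_batch);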
template <typename InputGridDesc,
          typename InputDataType,
          typename OutputGridDesc,
          typename OutputDataType,
          index_t BlockSize,
          index_t MPerBlock,
          index_t KPerBlock,
          typename ThreadClusterLengths,
          index_t ScalarPerVector,
          InMemoryDataOperationEnum DstInMemOp,
          typename Block2ETileMap,
          typename ComputePtrOffsetOfStridedBatch>
struct GridwiseTensorRearrange
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    __device__ static void Run(const InputGridDesc& in_grid_desc,
                               const InputDataType* __restrict__ p_in_global,
                               const OutputGridDesc& out_grid_desc,
                               OutputDataType* __restrict__ p_out_global,
                               const index_t batch_count,
                               const Block2ETileMap& block_2_tile_map,
                               const ComputePtrOffsetOfStridedBatch& compute_ptr_offset_of_batch)
    {
        // Map this block's 1-D id to an (M, K) tile coordinate of the tensor.
        const auto block_work_idx =
            block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);

        const index_t k_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * KPerBlock);
        // Block-scoped transfer that copies one MPerBlock x KPerBlock tile from
        // the input layout to the output layout with an element-wise pass-through.
        auto copy_global_to_global =
            ThreadGroupTensorSliceTransfer_v7<ThisThreadBlock,
                                              Tuple<InputDataType>,
                                              Tuple<OutputDataType>,
                                              decltype(tie(in_grid_desc)),
                                              decltype(tie(out_grid_desc)),
                                              tensor_operation::element_wise::PassThrough,
                                              Sequence<static_cast<index_t>(DstInMemOp)>,
                                              Sequence<MPerBlock, KPerBlock>,
                                              ThreadClusterLengths,
                                              Sequence<0, 1>,
                                              Sequence<0, 1>,
                                              I1,
                                              ScalarPerVector,
                                              Sequence<true>,
                                              Sequence<true>>{
                tie(in_grid_desc),
                make_tuple(make_multi_index(m_block_data_idx_on_grid, k_block_data_idx_on_grid)),
                tie(out_grid_desc),
                make_tuple(make_multi_index(m_block_data_idx_on_grid, k_block_data_idx_on_grid)),
                tensor_operation::element_wise::PassThrough{}};
        // Blocks are split evenly across batches; derive this block's batch index
        // and the per-batch pointer offsets (uniform across the wavefront).
        const index_t num_blocks_per_batch =
            __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
        const index_t g_idx =
            __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

        const index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
            static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
        const index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
            static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
        // Wrap the batch-shifted raw pointers in dynamic global-memory buffers.
        const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_global + a_batch_offset, in_grid_desc.GetElementSpaceSize());
        auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_global + c_batch_offset, out_grid_desc.GetElementSpaceSize());

        copy_global_to_global.Run(
            tie(in_grid_desc), tie(in_global_buf), tie(out_grid_desc), tie(out_global_buf));
    }
    // Both descriptors must tile evenly into MPerBlock x KPerBlock blocks.
    __host__ static constexpr bool CheckValidity(const InputGridDesc& in_grid_desc,
                                                 const OutputGridDesc& out_grid_desc)
    {
        if(in_grid_desc.GetLength(I0) % MPerBlock != 0 ||
           in_grid_desc.GetLength(I1) % KPerBlock != 0)
            return false;
        if(out_grid_desc.GetLength(I0) % MPerBlock != 0 ||
           out_grid_desc.GetLength(I1) % KPerBlock != 0)
            return false;
        return true;
    }
};
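For illustration, a hedged sketch of how a host-side caller might gate the launch on CheckValidity. GridwiseKernel stands for a concrete GridwiseTensorRearrange instantiation, and the two descriptors are assumed to be built by the calling device operation.

// Hedged sketch (requires <stdexcept>); GridwiseKernel and the descriptors are
// assumed to exist in the caller.
if(!GridwiseKernel::CheckValidity(in_grid_desc, out_grid_desc))
{
    throw std::runtime_error("GridwiseTensorRearrange: descriptor lengths are not "
                             "divisible by MPerBlock / KPerBlock");
}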