Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp Source File

gridwise_tensor_rearrange.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <typename InputGridDesc,
          typename InputDataType,
          typename OutputGridDesc,
          typename OutputDataType,
          typename Block2ETileMap,
          typename ComputePtrOffsetOfStridedBatch,
          typename GridwiseTensorRearrangeKernel>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_tensor_rearrange(const InputGridDesc in_grid_desc,
                                const InputDataType* __restrict__ p_in_global,
                                const OutputGridDesc out_grid_desc,
                                OutputDataType* __restrict__ p_out_global,
                                const index_t batch_count,
                                const Block2ETileMap block_2_tile_map,
                                const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
{
#if(defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || \
    defined(__gfx103__) || defined(__gfx11__) || defined(__gfx12__))
    GridwiseTensorRearrangeKernel::Run(in_grid_desc,
                                       p_in_global,
                                       out_grid_desc,
                                       p_out_global,
                                       batch_count,
                                       block_2_tile_map,
                                       compute_ptr_offset_of_batch);
#else
    ignore = in_grid_desc;
    ignore = p_in_global;
    ignore = out_grid_desc;
    ignore = p_out_global;
    ignore = batch_count;
    ignore = block_2_tile_map;
    ignore = compute_ptr_offset_of_batch;
#endif
}

template <typename InputGridDesc,
          typename InputDataType,
          typename OutputGridDesc,
          typename OutputDataType,
          index_t BlockSize,
          index_t MPerBlock,
          index_t KPerBlock,
          typename ThreadClusterLengths,
          index_t ScalarPerVector,
          InMemoryDataOperationEnum DstInMemOp,
          typename Block2ETileMap,
          typename ComputePtrOffsetOfStridedBatch>
struct GridwiseTensorRearrange
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    __device__ static void Run(const InputGridDesc& in_grid_desc,
                               const InputDataType* __restrict__ p_in_global,
                               const OutputGridDesc& out_grid_desc,
                               OutputDataType* __restrict__ p_out_global,
                               const index_t batch_count,
                               const Block2ETileMap& block_2_tile_map,
                               const ComputePtrOffsetOfStridedBatch& compute_ptr_offset_of_batch)
    {
        const auto block_work_idx =
            block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);

        const index_t k_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * KPerBlock);

        auto copy_global_to_global =
            ThreadGroupTensorSliceTransfer_v7<ThisThreadBlock,
                                              Tuple<InputDataType>,
                                              Tuple<OutputDataType>,
                                              decltype(tie(in_grid_desc)),
                                              decltype(tie(out_grid_desc)),
                                              tensor_operation::element_wise::PassThrough,
                                              Sequence<static_cast<index_t>(DstInMemOp)>,
                                              Sequence<MPerBlock, KPerBlock>, // slice lengths
                                              ThreadClusterLengths,
                                              Sequence<0, 1>, // thread cluster arrange order
                                              Sequence<0, 1>, // dim access order
                                              I1,             // vector dim
                                              ScalarPerVector,
                                              Sequence<true>,
                                              Sequence<true>>{
                in_grid_desc,
                make_tuple(make_multi_index(m_block_data_idx_on_grid, k_block_data_idx_on_grid)),
                out_grid_desc,
                make_tuple(make_multi_index(m_block_data_idx_on_grid, k_block_data_idx_on_grid)),
                tensor_operation::element_wise::PassThrough{}};

        const index_t num_blocks_per_batch =
            __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
        const index_t g_idx =
            __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

        // Global Memory
        const index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
            static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
        const index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
            static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));

        const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_global + a_batch_offset, in_grid_desc.GetElementSpaceSize());
        auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_global + c_batch_offset, out_grid_desc.GetElementSpaceSize());

        copy_global_to_global.Run(
            tie(in_grid_desc), tie(in_global_buf), tie(out_grid_desc), tie(out_global_buf));
    }

    __host__ static constexpr bool CheckValidity(const InputGridDesc& in_grid_desc,
                                                 const OutputGridDesc& out_grid_desc)
    {
        if(in_grid_desc.GetLength(I0) % MPerBlock != 0 ||
           in_grid_desc.GetLength(I1) % KPerBlock != 0)
            return false;
        if(out_grid_desc.GetLength(I0) % MPerBlock != 0 ||
           out_grid_desc.GetLength(I1) % KPerBlock != 0)
            return false;
        return true;
    }
};

} // namespace ck
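
As a point of reference, the sketch below shows one way the kernel and the gridwise struct above could be driven from the host side. It is a minimal, illustrative sketch only: SimpleBlock2TileMap and SimpleBatchPtrOffset are hypothetical stand-ins for the tile-map and pointer-offset functors that Composable Kernel's device operations normally supply, and the data type, tile sizes, and thread-cluster shape are assumptions chosen only so that the divisibility checks work out.

// Hypothetical host-side usage sketch (not part of the header above).
#include <hip/hip_runtime.h>

#include "ck/ck.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp"

// Hypothetical tile map: folds the global 1-D block id into one batch's (m-tile, k-tile) grid.
struct SimpleBlock2TileMap
{
    ck::index_t num_m_tiles_;
    ck::index_t num_k_tiles_;

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx) const
    {
        const ck::index_t id = idx[ck::Number<0>{}] % (num_m_tiles_ * num_k_tiles_);
        return ck::make_multi_index(id / num_k_tiles_, id % num_k_tiles_);
    }
};

// Hypothetical per-batch pointer offsets: batch index times a fixed batch stride.
struct SimpleBatchPtrOffset
{
    ck::long_index_t in_batch_stride_;
    ck::long_index_t out_batch_stride_;

    __host__ __device__ constexpr ck::long_index_t GetAPtrOffset(ck::index_t g) const
    {
        return g * in_batch_stride_;
    }
    __host__ __device__ constexpr ck::long_index_t GetCPtrOffset(ck::index_t g) const
    {
        return g * out_batch_stride_;
    }
};

int main()
{
    using ck::index_t;

    constexpr index_t M = 256, K = 128, BatchCount = 4;
    constexpr index_t BlockSize = 256, MPerBlock = 128, KPerBlock = 64, ScalarPerVector = 8;

    // Packed row-major [M, K] view of each batch of the input and output (assumed layout).
    const auto in_desc  = ck::make_naive_tensor_descriptor_packed(ck::make_tuple(M, K));
    const auto out_desc = ck::make_naive_tensor_descriptor_packed(ck::make_tuple(M, K));

    using Gridwise = ck::GridwiseTensorRearrange<decltype(in_desc),
                                                 ck::half_t,
                                                 decltype(out_desc),
                                                 ck::half_t,
                                                 BlockSize,
                                                 MPerBlock,
                                                 KPerBlock,
                                                 ck::Sequence<32, 8>, // ThreadClusterLengths (assumed)
                                                 ScalarPerVector,
                                                 ck::InMemoryDataOperationEnum::Set,
                                                 SimpleBlock2TileMap,
                                                 SimpleBatchPtrOffset>;

    if(!Gridwise::CheckValidity(in_desc, out_desc))
        return 1; // MPerBlock/KPerBlock must evenly divide the descriptor lengths

    ck::half_t* p_in  = nullptr;
    ck::half_t* p_out = nullptr;
    (void)hipMalloc(reinterpret_cast<void**>(&p_in), sizeof(ck::half_t) * M * K * BatchCount);
    (void)hipMalloc(reinterpret_cast<void**>(&p_out), sizeof(ck::half_t) * M * K * BatchCount);

    const index_t num_m_tiles = M / MPerBlock;
    const index_t num_k_tiles = K / KPerBlock;
    const index_t grid_size   = num_m_tiles * num_k_tiles * BatchCount;

    const SimpleBlock2TileMap  tile_map{num_m_tiles, num_k_tiles};
    const SimpleBatchPtrOffset ptr_offset{static_cast<ck::long_index_t>(M) * K,
                                          static_cast<ck::long_index_t>(M) * K};

    ck::kernel_tensor_rearrange<decltype(in_desc),
                                ck::half_t,
                                decltype(out_desc),
                                ck::half_t,
                                SimpleBlock2TileMap,
                                SimpleBatchPtrOffset,
                                Gridwise>
        <<<grid_size, BlockSize>>>(in_desc, p_in, out_desc, p_out, BatchCount, tile_map, ptr_offset);

    (void)hipDeviceSynchronize();
    (void)hipFree(p_in);
    (void)hipFree(p_out);
    return 0;
}

The grid is sized as (tiles per batch) x BatchCount, which matches the kernel's own num_blocks_per_batch = grid_size / batch_count computation, and the tile map wraps the global block id back into a single batch's tile grid while the per-batch base addresses come from GetAPtrOffset / GetCPtrOffset.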