/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp Source File
image_to_column_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
8 
9 namespace ck_tile {
10 
11 template <typename Problem_>
13 {
14  static constexpr auto I0 = number<0>{};
15  static constexpr auto I1 = number<1>{};
16  static constexpr auto I2 = number<2>{};
17  static constexpr auto I3 = number<3>{};
18  static constexpr auto I4 = number<4>{};
19 
21 
24 
25  static constexpr index_t NDimSpatial = Problem::NDimSpatial;
26 
27  static constexpr index_t AligmentIn = Problem::AligmentIn;
28  static constexpr index_t AligmentOut = Problem::AligmentOut;
29 
30  static_assert(NDimSpatial == 2, "Not supported.");
31 
32  static constexpr index_t kMPerBlock = Problem::BlockShape::kMPerBlock;
33  static constexpr index_t kKPerBlock = Problem::BlockShape::kKPerBlock;
34  static constexpr index_t kBlockSize = Problem::BlockShape::kBlockSize;
35 
// Kernel argument bundle. MakeKargs builds it by aggregate initialization and
// the device-side members of this kernel read it through `const Kargs&`.
36  struct Kargs
37  {
38  const void* p_in; // input buffer; ConvTensorRearrange reads it as const InDataType*
39  void* p_out; // output buffer; ConvTensorRearrange writes it as OutDataType*
40 
41  const long_index_t G; // NOTE(review): presumably the group count; not read by the kernel code visible here - confirm
42  const long_index_t N; // first length of the input descriptor and a factor of the M dimension
43  const long_index_t C; // last length of the input descriptor and a factor of the K dimension
44 
54  };
55 
// Host-side factory for Kargs.
//
// Every parameter is forwarded unchanged, in declaration order, into a Kargs
// aggregate; no validation or conversion is performed here.
//
// p_in / p_out            : type-erased buffers, later cast to InDataType /
//                           OutDataType by ConvTensorRearrange.
// G, N, C                 : stored as given. N and C feed the tensor
//                           descriptor and the M/K computation.
// *_spatial_lengths       : NDimSpatial entries each.
// image_g_n_c_wis_strides : NDimSpatial + 3 entries; entry 0 is used as the
//                           per-block-z offset into the input buffer.
// gemm_g_m_k_strides      : 3 entries; entry 0 is used as the per-block-z
//                           offset into the output buffer.
// conv_filter_strides, conv_filter_dilations, input_left_pads,
// input_right_pads        : NDimSpatial entries each.
//
// Returns the populated Kargs by value.
56  CK_TILE_HOST static constexpr Kargs
57  MakeKargs(const void* p_in,
58  void* p_out,
59  const long_index_t G,
60  const long_index_t N,
61  const long_index_t C,
62  const array<long_index_t, NDimSpatial> input_spatial_lengths,
63  const array<long_index_t, NDimSpatial> filter_spatial_lengths,
64  const array<long_index_t, NDimSpatial> output_spatial_lengths,
65  const array<long_index_t, NDimSpatial + 3> image_g_n_c_wis_strides,
66  const array<long_index_t, 3> gemm_g_m_k_strides,
67  const array<long_index_t, NDimSpatial> conv_filter_strides,
68  const array<long_index_t, NDimSpatial> conv_filter_dilations,
69  const array<long_index_t, NDimSpatial> input_left_pads,
70  const array<long_index_t, NDimSpatial> input_right_pads)
71  {
    // Brace-initialize in member order; the argument order below must stay
    // in sync with the member order of Kargs.
72  return Kargs{p_in,
73  p_out,
74  G,
75  N,
76  C,
77  input_spatial_lengths,
78  filter_spatial_lengths,
79  output_spatial_lengths,
80  image_g_n_c_wis_strides,
81  gemm_g_m_k_strides,
82  conv_filter_strides,
83  conv_filter_dilations,
84  input_left_pads,
85  input_right_pads};
86  }
87 
88  CK_TILE_HOST static constexpr auto GridSize(index_t GemmM, index_t GemmK, index_t Batch)
89  {
90  return dim3(
92  }
93 
// Host-side accessor for the block size: returns Problem::BlockShape::kBlockSize,
// the same value exposed by the kBlockSize constant of this kernel.
94  CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::kBlockSize; }
95 
96  CK_TILE_DEVICE auto MakeImageMKDesc(const Kargs& kargs) const
97  {
98  static_assert(NDimSpatial == 2, "Not supported.");
99 
100  const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
101  make_tuple(
102  kargs.N, kargs.input_spatial_lengths[I0], kargs.input_spatial_lengths[I1], kargs.C),
106  kargs.image_g_n_c_wis_strides[I2]),
108  I1);
109 
110  const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
111  in_n_hi_wi_c_desc,
114  kargs.input_left_pads[I0],
115  kargs.input_right_pads[I0]),
117  kargs.input_left_pads[I1],
118  kargs.input_right_pads[I1]),
122 
123  const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor(
124  in_n_hip_wip_c_desc,
125  make_tuple(
136 
138  in_n_y_ho_x_wo_c_desc,
139  make_tuple(
141  kargs.N, kargs.output_spatial_lengths[I0], kargs.output_spatial_lengths[I1])),
143  kargs.filter_spatial_lengths[I0], kargs.filter_spatial_lengths[I1], kargs.C))),
146  }
147 
// Computes the extents of the 2-D (M, K) view of the rearranged tensor.
//
//   M = N * output_spatial_lengths[0] * output_spatial_lengths[1]
//   K = C * filter_spatial_lengths[0] * filter_spatial_lengths[1]
//
// Returns a tuple (M, K) of index_t. Only the 2-D spatial case is supported,
// which the static_assert enforces.
148  CK_TILE_DEVICE auto CalculateMKDims(const Kargs& kargs) const
149  {
150  static_assert(NDimSpatial == 2, "Not supported.");
     // The spatial product is narrowed to index_t first, then multiplied by the
     // long_index_t N, and the result is narrowed again on assignment.
     // NOTE(review): index_t is 32-bit, so an M above its range would be
     // truncated silently - confirm callers keep sizes within range.
151  const index_t M = kargs.N * static_cast<index_t>(kargs.output_spatial_lengths[I0] *
152  kargs.output_spatial_lengths[I1]);
     // Same narrowing pattern as M, with C and the filter window lengths.
153  const index_t K = kargs.C * static_cast<index_t>(kargs.filter_spatial_lengths[I0] *
154  kargs.filter_spatial_lengths[I1]);
155  return make_tuple(M, K);
156  }
157 
159  {
160  using P = typename Problem::BlockShape;
161  // P: {kMWarpPerBlock * kKWarpPerBlock, kMThreadPerWarp * kKThreadPerWarp}
162  // Y: {kMPerThread, kKPerThread}
165  sequence<1>,
171  sequence<2, 2>>{});
172  }
173 
// Device-side body of the kernel: loads one tile from the (M, K) view of the
// input image and stores it at the same (M, K) position of the output.
//
// Block mapping: blockIdx.x selects the tile along M (step kMPerBlock),
// blockIdx.y the tile along K (step kKPerBlock), and blockIdx.z the slice
// addressed through stride entry 0 of the input and output stride arrays.
174  CK_TILE_DEVICE void ConvTensorRearrange(const Kargs& kargs) const
175  {
176  const auto [M, K] = CalculateMKDims(kargs);
177 
     // Tile origin and batch index of this block.
     // NOTE(review): readfirstlane presumably keeps these block-uniform values
     // in scalar registers - confirm against the AMDGPU builtin documentation.
178  const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kMPerBlock);
179  const index_t iK = __builtin_amdgcn_readfirstlane(blockIdx.y * kKPerBlock);
180  const index_t iBatch = __builtin_amdgcn_readfirstlane(blockIdx.z);
181 
     // Element offsets (not bytes): they are added to the typed pointers below.
182  const auto in_offset = iBatch * kargs.image_g_n_c_wis_strides[I0];
183  const auto out_offset = iBatch * kargs.gemm_g_m_k_strides[I0];
184 
     // Input: global-memory view using the descriptor built by MakeImageMKDesc.
185  const auto image_m_k = make_tensor_view<address_space_enum::global>(
186  static_cast<const InDataType*>(kargs.p_in) + in_offset, MakeImageMKDesc(kargs));
     // Output: naive global-memory view with lengths (M, K).
187  const auto gemm_m_k = make_naive_tensor_view<address_space_enum::global>(
188  static_cast<OutDataType*>(kargs.p_out) + out_offset,
189  make_tuple(M, K),
192  I1);
193 
     // Both views go through pad_tensor_view before the tile windows are made.
194  const auto image_m_k_padded =
195  pad_tensor_view(image_m_k,
198  const auto gemm_m_k_padded =
199  pad_tensor_view(gemm_m_k,
202 
     // One distribution is shared by the load and store windows.
203  constexpr auto dstr = MakeBlockTileDistribution();
204 
     // Source and destination windows share the origin {iM, iK}.
205  const auto image_tile =
206  make_tile_window(image_m_k_padded,
208  {iM, iK},
209  dstr);
210 
211  auto gemm_tile = make_tile_window(gemm_m_k_padded,
213  {iM, iK},
214  dstr);
215 
216  // load the input tile from global memory
217  const auto loaded_tile = load_tile(image_tile);
218  // store the loaded tile to the output in global memory
219  store_tile(gemm_tile, loaded_tile);
220  }
221 
// Kernel entry point: forwards the argument bundle to ConvTensorRearrange.
// NOTE(review): kargs is taken by non-const reference although the callee only
// needs `const Kargs&` - confirm whether the launcher depends on this signature.
222  CK_TILE_DEVICE void operator()(Kargs& kargs) const { ConvTensorRearrange(kargs); }
223 };
224 
225 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_descriptor.hpp:268
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1615
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1558
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, bool_constant< SkipIsValidCheck >=bool_constant< false >{})
Definition: coordinate_transform.hpp:1565
int64_t long_index_t
Definition: integer.hpp:11
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition: tensor_descriptor.hpp:197
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:23
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:22
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:480
constexpr CK_TILE_HOST_DEVICE auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition: coordinate_transform.hpp:1594
Definition: image_to_column_kernel.hpp:37
const array< long_index_t, NDimSpatial > input_right_pads
Definition: image_to_column_kernel.hpp:53
const array< long_index_t, NDimSpatial > conv_filter_strides
Definition: image_to_column_kernel.hpp:50
const array< long_index_t, NDimSpatial > conv_filter_dilations
Definition: image_to_column_kernel.hpp:51
const long_index_t C
Definition: image_to_column_kernel.hpp:43
const long_index_t N
Definition: image_to_column_kernel.hpp:42
const long_index_t G
Definition: image_to_column_kernel.hpp:41
const array< long_index_t, NDimSpatial+3 > image_g_n_c_wis_strides
Definition: image_to_column_kernel.hpp:48
const void * p_in
Definition: image_to_column_kernel.hpp:38
const array< long_index_t, NDimSpatial > filter_spatial_lengths
Definition: image_to_column_kernel.hpp:46
void * p_out
Definition: image_to_column_kernel.hpp:39
const array< long_index_t, NDimSpatial > input_left_pads
Definition: image_to_column_kernel.hpp:52
const array< long_index_t, 3 > gemm_g_m_k_strides
Definition: image_to_column_kernel.hpp:49
const array< long_index_t, NDimSpatial > input_spatial_lengths
Definition: image_to_column_kernel.hpp:45
const array< long_index_t, NDimSpatial > output_spatial_lengths
Definition: image_to_column_kernel.hpp:47
Definition: image_to_column_kernel.hpp:13
CK_TILE_DEVICE auto MakeImageMKDesc(const Kargs &kargs) const
Definition: image_to_column_kernel.hpp:96
static constexpr CK_TILE_HOST Kargs MakeKargs(const void *p_in, void *p_out, const long_index_t G, const long_index_t N, const long_index_t C, const array< long_index_t, NDimSpatial > input_spatial_lengths, const array< long_index_t, NDimSpatial > filter_spatial_lengths, const array< long_index_t, NDimSpatial > output_spatial_lengths, const array< long_index_t, NDimSpatial+3 > image_g_n_c_wis_strides, const array< long_index_t, 3 > gemm_g_m_k_strides, const array< long_index_t, NDimSpatial > conv_filter_strides, const array< long_index_t, NDimSpatial > conv_filter_dilations, const array< long_index_t, NDimSpatial > input_left_pads, const array< long_index_t, NDimSpatial > input_right_pads)
Definition: image_to_column_kernel.hpp:57
static constexpr auto I2
Definition: image_to_column_kernel.hpp:16
CK_TILE_DEVICE void ConvTensorRearrange(const Kargs &kargs) const
Definition: image_to_column_kernel.hpp:174
static constexpr index_t kBlockSize
Definition: image_to_column_kernel.hpp:34
static constexpr auto I4
Definition: image_to_column_kernel.hpp:18
static constexpr auto I1
Definition: image_to_column_kernel.hpp:15
static constexpr CK_TILE_DEVICE auto MakeBlockTileDistribution()
Definition: image_to_column_kernel.hpp:158
remove_cvref_t< typename Problem::InDataType > InDataType
Definition: image_to_column_kernel.hpp:22
remove_cvref_t< Problem_ > Problem
Definition: image_to_column_kernel.hpp:20
static constexpr index_t AligmentOut
Definition: image_to_column_kernel.hpp:28
static constexpr CK_TILE_HOST auto BlockSize()
Definition: image_to_column_kernel.hpp:94
static constexpr index_t AligmentIn
Definition: image_to_column_kernel.hpp:27
static constexpr CK_TILE_HOST auto GridSize(index_t GemmM, index_t GemmK, index_t Batch)
Definition: image_to_column_kernel.hpp:88
CK_TILE_DEVICE void operator()(Kargs &kargs) const
Definition: image_to_column_kernel.hpp:222
static constexpr auto I3
Definition: image_to_column_kernel.hpp:17
static constexpr index_t kKPerBlock
Definition: image_to_column_kernel.hpp:33
remove_cvref_t< typename Problem::OutDataType > OutDataType
Definition: image_to_column_kernel.hpp:23
CK_TILE_DEVICE auto CalculateMKDims(const Kargs &kargs) const
Definition: image_to_column_kernel.hpp:148
static constexpr index_t kMPerBlock
Definition: image_to_column_kernel.hpp:32
static constexpr index_t NDimSpatial
Definition: image_to_column_kernel.hpp:25
static constexpr auto I0
Definition: image_to_column_kernel.hpp:14
A fixed-size array container similar to std::array with additional utilities.
Definition: array.hpp:43
Definition: integral_constant.hpp:13
Definition: sequence.hpp:49
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192