/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_description/multi_index_transform_helper.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_description/multi_index_transform_helper.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_description/multi_index_transform_helper.hpp Source File
multi_index_transform_helper.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
8 
9 namespace ck {
10 
11 template <typename LowLength>
12 __host__ __device__ constexpr auto make_pass_through_transform(const LowLength& low_length)
13 {
14  return PassThrough<LowLength>{low_length};
15 }
16 
17 template <typename LowLength, typename LeftPad, typename RightPad, bool SkipIsValidCheck = false>
18 __host__ __device__ constexpr auto
19 make_pad_transform(const LowLength& low_length,
20  const LeftPad& left_pad,
21  const RightPad& right_pad,
23 {
24  return Pad<LowLength, LeftPad, RightPad, SkipIsValidCheck>{low_length, left_pad, right_pad};
25 }
26 
27 template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
28 __host__ __device__ constexpr auto make_left_pad_transform(
29  const LowLength& low_length,
30  const LeftPadLength& left_pad,
32 {
33  return LeftPad<LowLength, LeftPadLength, SkipIsValidCheck>{low_length, left_pad};
34 }
35 
36 template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
37 __host__ __device__ constexpr auto make_right_pad_transform(
38  const LowLength& low_length,
39  const RightPadLength& right_pad,
41 {
42  return RightPad<LowLength, RightPadLength, SkipIsValidCheck>{low_length, right_pad};
43 }
44 
45 template <typename UpLengths,
46  typename Coefficients,
47  typename enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
48 __host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_lengths,
49  const Coefficients& coefficients)
50 {
51  return Embed<UpLengths, Coefficients>{up_lengths, coefficients};
52 }
53 
54 template <typename LowLengths>
55 __host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths)
56 {
57 #if CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
58  return make_merge_transform_v2_magic_division(low_lengths);
59 #else
60  return make_merge_transform_v1_carry_check(low_lengths);
61 #endif
62 }
63 
64 template <typename LowLengths>
65 __host__ __device__ constexpr auto
66 make_merge_transform_v1_carry_check(const LowLengths& low_lengths)
67 {
68  return Merge_v1_carry_check<LowLengths>{low_lengths};
69 }
70 
71 template <typename LowLengths>
72 __host__ __device__ constexpr auto
73 make_merge_transform_v2_magic_division(const LowLengths& low_lengths)
74 {
75 #if 1
76  return Merge_v2_magic_division<LowLengths>{low_lengths};
77 #else
78  return Merge_v2r2_magic_division<LowLengths>{low_lengths};
79 #endif
80 }
81 
82 template <typename LowLengths>
83 __host__ __device__ constexpr auto
84 make_merge_transform_v3_division_mod(const LowLengths& low_lengths)
85 {
86  return Merge_v3_division_mod<LowLengths>{low_lengths};
87 }
88 
89 template <typename UpLengths, bool Use24BitIntegerCalculation = false>
90 __host__ __device__ constexpr auto make_unmerge_transform(
91  const UpLengths& up_lengths,
93 {
94  return UnMerge<UpLengths, Use24BitIntegerCalculation>{up_lengths};
95 }
96 
97 __host__ __device__ constexpr auto make_conv_bwd_data_out_transform(index_t N,
98  index_t Ho,
99  index_t Wo,
100  index_t K,
101  [[maybe_unused]] index_t YDot,
102  index_t XDot,
103  index_t HTilde,
104  index_t WTilde,
105  index_t ConvDilationH,
106  index_t ConvDilationW,
107  index_t HTildeSlice,
108  index_t WTildeSlice,
109  index_t YDotSlice,
110  index_t XDotSlice,
111  index_t IHTildeSliceBegin,
112  index_t IWTildeSliceBegin,
113  index_t GcdStrideDilationH,
114  index_t GcdStrideDilationW,
115  index_t K0,
116  index_t K1,
117  index_t MPerBlock,
118  index_t GemmKPerBlock)
119 {
120  // Calculate padding
121  const auto MRaw = N * HTildeSlice * WTildeSlice;
122  const auto MPadded = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
123  const auto MPad = MPadded - MRaw;
124 
125  const auto KRaw = YDotSlice * XDotSlice * K;
126  const auto KPadded = math::integer_divide_ceil(KRaw, GemmKPerBlock) * GemmKPerBlock;
127  const auto KPad = KPadded - KRaw;
128 
130  Ho,
131  Wo,
132  K,
133  XDot,
134  HTilde,
135  WTilde,
136  WTildeSlice,
137  HTildeSlice * WTildeSlice,
138  IHTildeSliceBegin,
139  IWTildeSliceBegin,
140  -ConvDilationH / GcdStrideDilationH,
141  -ConvDilationW / GcdStrideDilationW,
142  XDotSlice * K,
143  K0,
144  MPadded,
145  K1,
146  MPad,
147  KPad};
148 }
149 
150 template <typename LowerIndex>
151 __host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_idx)
152 {
153  return Freeze<LowerIndex>{low_idx};
154 }
155 
156 template <typename UpperIndex>
157 __host__ __device__ constexpr auto make_insert_transform(const UpperIndex& up_idx)
158 {
159  return Insert<UpperIndex>{up_idx};
160 }
161 
162 template <typename LowLength, typename SliceBegin, typename SliceEnd>
163 __host__ __device__ constexpr auto make_slice_transform(const LowLength& low_length,
164  const SliceBegin& slice_begin,
165  const SliceEnd& slice_end)
166 {
167  return Slice<LowLength, SliceBegin, SliceEnd>{low_length, slice_begin, slice_end};
168 }
169 
170 template <typename VectorSize, typename UpLength>
171 __host__ __device__ constexpr auto make_vectorize_transform(const VectorSize& vector_size,
172  const UpLength& up_length)
173 {
174  return Vectorize<VectorSize, UpLength>{vector_size, up_length};
175 }
176 
177 template <typename Modulus, typename UpLength>
178 __host__ __device__ constexpr auto make_modulo_transform(const Modulus& modulus,
179  const UpLength& up_length)
180 {
181  return Modulo<Modulus, UpLength>{modulus, up_length};
182 }
183 
184 template <typename LowLengths>
185 __host__ __device__ constexpr auto make_xor_with_modulo_transform(const LowLengths& low_lengths)
186 {
187  return Xor<LowLengths, true /*ApplyModulo*/>{low_lengths};
188 }
189 
190 template <typename LowLengths>
191 __host__ __device__ constexpr auto make_xor_transform(const LowLengths& low_lengths)
192 {
193  return Xor<LowLengths, false /*ApplyModulo*/>{low_lengths};
194 }
195 } // namespace ck
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
Definition: ck.hpp:268
__host__ constexpr __device__ auto make_left_pad_transform(const LowLength &low_length, const LeftPadLength &left_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:28
__host__ constexpr __device__ auto make_xor_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:191
__host__ constexpr __device__ auto make_merge_transform_v2_magic_division(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:73
__host__ constexpr __device__ auto make_merge_transform_v1_carry_check(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:66
__host__ constexpr __device__ auto make_vectorize_transform(const VectorSize &vector_size, const UpLength &up_length)
Definition: multi_index_transform_helper.hpp:171
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:84
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:151
__host__ constexpr __device__ auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition: multi_index_transform_helper.hpp:48
__host__ constexpr __device__ auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:185
__host__ constexpr __device__ auto make_slice_transform(const LowLength &low_length, const SliceBegin &slice_begin, const SliceEnd &slice_end)
Definition: multi_index_transform_helper.hpp:163
std::enable_if< B, T > enable_if
Definition: enable_if.hpp:24
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_insert_transform(const UpperIndex &up_idx)
Definition: multi_index_transform_helper.hpp:157
__host__ constexpr __device__ auto make_modulo_transform(const Modulus &modulus, const UpLength &up_length)
Definition: multi_index_transform_helper.hpp:178
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:299
__host__ constexpr __device__ auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:19
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
__host__ constexpr __device__ auto make_conv_bwd_data_out_transform(index_t N, index_t Ho, index_t Wo, index_t K, [[maybe_unused]] index_t YDot, index_t XDot, index_t HTilde, index_t WTilde, index_t ConvDilationH, index_t ConvDilationW, index_t HTildeSlice, index_t WTildeSlice, index_t YDotSlice, index_t XDotSlice, index_t IHTildeSliceBegin, index_t IWTildeSliceBegin, index_t GcdStrideDilationH, index_t GcdStrideDilationW, index_t K0, index_t K1, index_t MPerBlock, index_t GemmKPerBlock)
Definition: multi_index_transform_helper.hpp:97
Transformation struct for convolution backward data output indices to GEMM indices.
Definition: multi_index_transform.hpp:1565
Definition: multi_index_transform.hpp:385
Definition: multi_index_transform.hpp:1750
Definition: multi_index_transform.hpp:1816
Definition: multi_index_transform.hpp:196
Definition: multi_index_transform.hpp:481
Definition: multi_index_transform.hpp:1036
Definition: multi_index_transform.hpp:1188
Definition: multi_index_transform.hpp:1338
Definition: multi_index_transform.hpp:2065
Definition: multi_index_transform.hpp:13
Definition: multi_index_transform.hpp:284
Definition: multi_index_transform.hpp:1968
Definition: multi_index_transform.hpp:1882
Definition: multi_index_transform.hpp:2149
Definition: integral_constant.hpp:20