template <typename InOutElementFunc,
          typename... InOutDstrTensors,
          typename = std::enable_if_t<std::conjunction_v<
              std::negation<std::is_same<std::remove_const_t<InOutDstrTensors>, null_tensor>>...>>>
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc& inout_element_func,
                                           InOutDstrTensors&... inout_dstr_tensors)
{
    // all tensors are assumed to share the same thread-buffer size; take it from the first
    constexpr index_t thread_buffer_size =
        __type_pack_element<0, InOutDstrTensors...>::get_thread_buffer_size();

    static_for<0, thread_buffer_size, 1>{}(
        [&](auto i) { inout_element_func(inout_dstr_tensors.get_thread_buffer().at(i)...); });
}
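// Usage sketch (hypothetical tile names): update one or several distributed
// tensors in place; all arguments must share the same thread-buffer size.
//
//   tile_elementwise_inout([](auto& a) { a *= 2.f; }, a_tile);
//   tile_elementwise_inout([](auto& a, auto& b) { a += b; }, a_tile, b_tile);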
template <typename InElementFunc,
          typename... InTensor,
          typename = std::enable_if_t<
              std::conjunction_v<std::negation<std::is_same<InTensor, null_tensor>>...>>>
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func,
                                        const InTensor&... in_dstr_tensors)
{
    // the output element type is whatever the functor returns for the input element types
    using OutDataType = decltype(in_element_func(typename InTensor::DataType{}...));

    // all inputs are assumed to share one tile distribution; take it from the first
    constexpr auto in_tile_dstr = __type_pack_element<0, InTensor...>::get_tile_distribution();

    constexpr index_t thread_buffer_size =
        __type_pack_element<0, InTensor...>::get_thread_buffer_size();

    auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);

    static_for<0, thread_buffer_size, 1>{}([&](auto i) {
        out_dstr_tensor.get_thread_buffer()(i) =
            in_element_func(in_dstr_tensors.get_thread_buffer()[i]...);
    });

    return out_dstr_tensor;
}
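// Usage sketch (hypothetical tile names): build a new tensor from same-shaped
// inputs; the output element type is deduced from the functor's return type.
//
//   auto c_tile = tile_elementwise_in(
//       [](const auto& a, const auto& b) { return a * b; }, a_tile, b_tile);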
template <typename DstrTensors, typename T>
CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, const T& value)
{
    tile_elementwise_inout(
        [&value](auto& x) {
            x = type_convert<typename DstrTensors::DataType, remove_cvref_t<T>>(value);
        },
        dstr_tensor);
}
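// Usage sketch (hypothetical tile name): whatever T the caller passes, the value
// is converted to the tensor's DataType once per element.
//
//   set_tile(acc_tile, 1.0f);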
// prefer to set the tensor one dword at a time, in case the compiler
// handles sub-dword stores poorly
template <typename DstrTensors, index_t v, bool skip_subdword_opt = false>
CK_TILE_DEVICE void
set_tile(DstrTensors& dstr_tensor, number<v>, bool_constant<skip_subdword_opt> = {})
{
    using elem_type             = typename DstrTensors::DataType;
    constexpr index_t elem_size = sizeof(elem_type);

    constexpr index_t tensor_bytes = DstrTensors::get_thread_buffer_size() * elem_size;

    // zero-fill path: write 4 bytes per store
    if constexpr(v == 0 && tensor_bytes % 4 == 0 && !skip_subdword_opt)
    {
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
        auto& buffer = dstr_tensor.get_thread_buffer();

        static_for<0, tensor_bytes / 4, 1>{}([&](auto i_write) {
            if constexpr(elem_size == 1)
            {
                // 4 elements per dword write
                constexpr auto values = ext_vector_t<elem_type, 4>{0, 0, 0, 0};

                buffer[i_write * 4 + 0] = values.x;
                buffer[i_write * 4 + 1] = values.y;
                buffer[i_write * 4 + 2] = values.z;
                buffer[i_write * 4 + 3] = values.w;
            }
            else if constexpr(elem_size == 2)
            {
                // 2 elements per dword write
                constexpr auto values = ext_vector_t<elem_type, 2>{0, 0};

                buffer[i_write * 2 + 0] = values.x;
                buffer[i_write * 2 + 1] = values.y;
            }
            else if constexpr(elem_size == 4)
            {
                // 1 element per dword write
                constexpr elem_type value = 0;

                buffer[i_write] = value;
            }
            else
            {
                static_assert(false, "type not supported");
            }
        });
#else
        // reinterpret the thread buffer as an array of dwords and fill it directly
        using dvec_t = array<index_t, tensor_bytes / 4>;
        auto& tensor = reinterpret_cast<dvec_t&>(dstr_tensor.get_thread_buffer());
        for(auto i = 0; i < tensor.size(); i++)
            tensor[i] = v;
#endif
    }
    else
    {
        // generic path: element-wise convert-and-assign
        tile_elementwise_inout([](auto& x) { x = type_convert<elem_type>(v); }, dstr_tensor);
    }
}
template <typename DstrTensors>
CK_TILE_DEVICE void clear_tile(DstrTensors& dstr_tensor)
{
    set_tile(dstr_tensor, number<0>{});
}
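// Usage sketch (hypothetical tile name): the number<v> overload picks its path at
// compile time; v == 0 takes the dword-wide zero-fill above, any other v falls
// back to the element-wise path.
//
//   set_tile(c_tile, number<0>{}); // dword-wide zero fill
//   set_tile(c_tile, number<1>{}); // element-wise convert-and-assign
//   clear_tile(c_tile);            // same as set_tile(c_tile, number<0>{})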
namespace impl {

template <typename OutDataType, typename InTensor>
CK_TILE_DEVICE auto cast_tile_pk_fp8_fp32(const InTensor& in_dstr_tensors)
{
#if defined(__gfx94__)
    // convert with the packed (_pk_) builtins, four floats per iteration
    constexpr auto in_tile_dstr = InTensor::get_tile_distribution();

    constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
    static_assert(thread_buffer_size % 4 == 0);
    constexpr index_t thread_buffer_size_pk = thread_buffer_size / 4;

    auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wuninitialized"
    // the builtin takes an "old" destination operand; a deliberately uninitialized
    // value avoids an extra v_mov, so the warning is silenced on purpose
    int dummy_old;
    static_for<0, thread_buffer_size_pk, 1>{}([&](auto i) {
        uint32_t x = __builtin_amdgcn_cvt_pk_fp8_f32(
            in_dstr_tensors.get_thread_buffer()[number<4 * i + 0>{}],
            in_dstr_tensors.get_thread_buffer()[number<4 * i + 1>{}],
            dummy_old,
            false); // false: write WORD0

        uint32_t y = __builtin_amdgcn_cvt_pk_fp8_f32(
            in_dstr_tensors.get_thread_buffer()[number<4 * i + 2>{}],
            in_dstr_tensors.get_thread_buffer()[number<4 * i + 3>{}],
            dummy_old,
            false); // false: write WORD0

        constexpr int32_t m0 = 0x05040100;

        using vec_t = array<OutDataType, 4>;
        vec_t d     = bit_cast<vec_t>(__builtin_amdgcn_perm(y, x, m0));
        out_dstr_tensor.get_thread_buffer().template set_as<vec_t>(number<i>{}, d);
    });
#pragma clang diagnostic pop

    return out_dstr_tensor;
#else
    // other targets: fall back to a plain element-wise convert
    return tile_elementwise_in(type_convert<OutDataType, typename InTensor::DataType>,
                               in_dstr_tensors);
#endif
}
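// Worked example of the byte permute above (my reading of V_PERM_B32: selector byte
// values 0..3 pick bytes of src1, values 4..7 pick bytes of src0): with
// perm(y, x, 0x05040100), selectors {0x00, 0x01} take bytes 0..1 of x and
// {0x04, 0x05} take bytes 0..1 of y, so the packed dword is {x0, x1, y0, y1},
// i.e. the four converted fp8 values in source order.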
template <typename OutDataType, typename InTensor>
CK_TILE_DEVICE auto cast_tile_pk_fp16_fp32(const InTensor& in_dstr_tensors)
{
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)
    // convert with the packed builtin, two floats per iteration
    constexpr auto in_tile_dstr = InTensor::get_tile_distribution();

    constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
    static_assert(thread_buffer_size % 2 == 0);
    constexpr index_t thread_buffer_size_pk = thread_buffer_size / 2;

    auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);

    // note: cvt_pkrtz rounds toward zero rather than to nearest-even
    for(index_t i = 0; i < thread_buffer_size_pk; i++)
    {
        auto o = __builtin_amdgcn_cvt_pkrtz(in_dstr_tensors.get_thread_buffer()[2 * i + 0],
                                            in_dstr_tensors.get_thread_buffer()[2 * i + 1]);

        out_dstr_tensor.get_thread_buffer().at(2 * i + 0) = o.x;
        out_dstr_tensor.get_thread_buffer().at(2 * i + 1) = o.y;
    }

    return out_dstr_tensor;
#else
    // other targets: fall back to a plain element-wise convert
    return tile_elementwise_in(type_convert<OutDataType, typename InTensor::DataType>,
                               in_dstr_tensors);
#endif
}
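// Rounding note with a worked value: fp16 spacing just above 1.0 is 2^-10 =
// 0.0009765625. For the input 1.0005f (above the midpoint 1.00048828125),
// round-to-nearest would give 1.0009765625, but the pkrtz truncation here gives 1.0.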
#if CK_TILE_USE_SUBDWORD_TILE_CAST
// this function assumes either the src or dst data type (or both) is narrower than
// one dword; values are packed into whole dwords to avoid the compiler's default
// sub-dword handling, which can be buggy
template <typename OutDataType, typename InTensor>
CK_TILE_DEVICE auto cast_tile_opt_subdword(const InTensor& in_dstr_tensors)
{
    constexpr auto in_tile_dstr = InTensor::get_tile_distribution();

    auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);

    using i_type = remove_cvref_t<typename InTensor::DataType>;
    using o_type = remove_cvref_t<OutDataType>;

    constexpr index_t i_elem_bytes = sizeof(i_type);
    constexpr index_t o_elem_bytes = sizeof(o_type);
    static_assert(i_elem_bytes < 4 || o_elem_bytes < 4);

    // number of elements per bulk, chosen so one bulk spans at least one dword
    constexpr index_t bulk_size =
        (i_elem_bytes >= o_elem_bytes) ? (4 / o_elem_bytes) : (4 / i_elem_bytes);
    static_assert(bulk_size != 0);

    // one bulk of output elements, sized to match the data[] member below
    using o_bulk_type = ext_vector_t<o_type, bulk_size>;

    constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();

    constexpr index_t iters = thread_buffer_size / bulk_size;
    constexpr index_t rems  = thread_buffer_size % bulk_size;

    // convert bulk_size elements at a time, then store them with one packed write
    static_for<0, iters, 1>{}([&](auto i) {
        union
        {
            o_bulk_type bulk;
            o_type data[bulk_size];
        } o_bulk{};

        static_for<0, bulk_size, 1>{}([&o_bulk, &in_dstr_tensors, &i](auto ib) {
            o_bulk.data[ib.value] = static_cast<o_type>(
                in_dstr_tensors.get_thread_buffer()
                    .template get_as<i_type>()[number<bulk_size * i.value + ib.value>{}]);
        });

        out_dstr_tensor.get_thread_buffer().template set_as<o_bulk_type>(i, o_bulk.bulk);
    });

    // handle the leftover elements one by one
    static_for<0, rems, 1>{}([&](auto r) {
        constexpr auto idx = number<iters * bulk_size + r>{};
        out_dstr_tensor.get_thread_buffer().at(idx) =
            static_cast<o_type>(in_dstr_tensors.get_thread_buffer().at(idx));
    });

    return out_dstr_tensor;
}
#endif

} // namespace impl
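// Worked sizing example: casting fp16 (2 bytes) to an 8-bit type gives
// bulk_size = 4 / 1 = 4, so each bulk iteration converts four elements and stores
// them with one packed write; a thread buffer of 10 elements then runs 2 bulk
// iterations plus 2 remainder elements.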
template <typename DstType, typename SrcTensor>
CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor)
{
    if constexpr((std::is_same_v<DstType, fp8_t> || std::is_same_v<DstType, bf8_t>) &&
                 std::is_same_v<typename SrcTensor::DataType, float> &&
                 (SrcTensor::get_thread_buffer_size() % 4 == 0))
    {
        return impl::cast_tile_pk_fp8_fp32<DstType, SrcTensor>(src_tensor);
    }
#if CK_TILE_USE_PK_FP16_TILE_CAST
    else if constexpr(std::is_same_v<DstType, fp16_t> &&
                      std::is_same_v<typename SrcTensor::DataType, float> &&
                      (SrcTensor::get_thread_buffer_size() % 2 == 0))
    {
        return impl::cast_tile_pk_fp16_fp32<DstType, SrcTensor>(src_tensor);
    }
#endif
#if CK_TILE_USE_SUBDWORD_TILE_CAST
    else if constexpr(sizeof(DstType) < 4 || sizeof(typename SrcTensor::DataType) < 4)
    {
        return impl::cast_tile_opt_subdword<DstType, SrcTensor>(src_tensor);
    }
#endif
    else
    {
        return tile_elementwise_in(type_convert<DstType, typename SrcTensor::DataType>,
                                   src_tensor);
    }
}
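// Usage sketch (hypothetical tile name): down-convert an fp32 accumulator for the
// epilogue store; the fastest eligible branch above is chosen at compile time.
//
//   auto c_fp16 = cast_tile<fp16_t>(c_acc_fp32);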
// if any argument is a null_tensor, the in/out elementwise op is a no-op
template <typename InOutElementFunc,
          typename... MaybeNullTensor,
          typename = std::enable_if_t<
              std::disjunction_v<std::is_same<remove_cvref_t<MaybeNullTensor>, null_tensor>...>>>
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc&, MaybeNullTensor&&...)
{
}
// if any argument is a null_tensor, the input elementwise op yields a null_tensor
template <typename InElementFunc,
          typename... MaybeNullTensor,
          typename = std::enable_if_t<
              std::disjunction_v<std::is_same<remove_cvref_t<MaybeNullTensor>, null_tensor>...>>>
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc&, MaybeNullTensor&&...)
{
    return null_tensor{};
}
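// Usage sketch (hypothetical tile name): generic code can pass null_tensor for an
// absent optional operand, and the call collapses to a no-op via the overloads above.
//
//   tile_elementwise_inout([](auto& a, auto& b) { a += b; }, a_tile, null_tensor{});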