namespace detail {

// Create a (K0PerBlock, MPerBlock/NPerBlock, K1) block descriptor from a 2D
// (M, K) or (N, K) tile layout by unmerging the K dimension into (K0, K1).
template <index_t K1, typename TileLayout>
__device__ constexpr auto GetBlockDescriptor()
{
    using TileLayoutShape      = typename TileLayout::LayoutShape;
    using TileLayoutDescriptor = typename TileLayout::LayoutUnrolledDescriptorType;

    constexpr auto K0PerBlock = Number<size<1>(TileLayoutShape{})>{} / Number<K1>{};
    // MPerBlock or NPerBlock
    constexpr auto Dim0 = Number<size<0>(TileLayoutShape{})>{};

    constexpr auto a_block_desc_k0_m_k1 = transform_tensor_descriptor(
        TileLayoutDescriptor{},
        make_tuple(make_unmerge_transform(make_tuple(K0PerBlock, Number<K1>{})),
                   make_pass_through_transform(Dim0)),
        make_tuple(Sequence<1>{}, Sequence<0>{}),
        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    return a_block_desc_k0_m_k1;
}

} // namespace detail
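For intuition, here is a minimal sketch (not part of the header) of what the helper produces, assuming a hypothetical 2D (M, K) tile layout with compile-time shape (128, 32) and K1 = 8: the K dimension is unmerged into (K0PerBlock, K1), so the resulting block descriptor has lengths (4, 128, 8).

// Illustrative sketch only. ExampleATileLayout is a hypothetical 2D (M, K) tile
// layout whose compile-time LayoutShape is (128, 32); it is not defined in this header.
template <typename ExampleATileLayout>
__device__ constexpr void check_block_descriptor_shape()
{
    constexpr auto desc = detail::GetBlockDescriptor<8 /*K1*/, ExampleATileLayout>();
    static_assert(desc.GetLength(Number<0>{}) == 32 / 8); // K0PerBlock = KPerBlock / K1
    static_assert(desc.GetLength(Number<1>{}) == 128);    // MPerBlock passed through
    static_assert(desc.GetLength(Number<2>{}) == 8);      // K1
}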
/**
 * \brief Perform blockwise gemm xdl on tensors stored in lds. Result will be
 * stored in Vgpr register.
 */
template <typename DataType,
          typename GemmTraits,
          index_t BlockSize,
          typename ATensorType,
          typename BTensorType,
          typename CTensorType>
__device__ void blockwise_gemm_xdl(const ATensorType& a_local_tile_tensor,
                                   const BTensorType& b_local_tile_tensor,
                                   CTensorType& c_reg_tensor)
{
    constexpr auto I3 = Number<3>{};

    static_assert(ATensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds);
    static_assert(BTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds);
    static_assert(CTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Vgpr);
    static_assert(is_same_v<DataType, typename ATensorType::TensorElementType>);
    static_assert(is_same_v<DataType, typename BTensorType::TensorElementType>);

    // Integer inputs accumulate in int32_t, floating-point inputs in float.
    constexpr bool is_integer =
        is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>;
    using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;

    // Tile layouts are taken from the LDS tensors' layouts.
    using ATileLayout = remove_cvref_t<decltype(layout(a_local_tile_tensor))>;
    using BTileLayout = remove_cvref_t<decltype(layout(b_local_tile_tensor))>;

    // Tiles are either already 3D (K0, M/N, K1) or 2D (M/N, K) and still need
    // the K dimension unmerged.
    static_assert(typename ATileLayout::LayoutShape{}.Size() ==
                  typename BTileLayout::LayoutShape{}.Size());
    constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3;
    using ABlockDesc_K0_M_K1_Type =
        std::conditional_t<is_3d_desc,
                           typename ATileLayout::LayoutUnrolledDescriptorType,
                           decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>())>;
    using BBlockDesc_K0_N_K1_Type =
        std::conditional_t<is_3d_desc,
                           typename BTileLayout::LayoutUnrolledDescriptorType,
                           decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>())>;

    BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                        DataType,
                                                        DataType,
                                                        GemmAccDataType,
                                                        ABlockDesc_K0_M_K1_Type,
                                                        BBlockDesc_K0_N_K1_Type,
                                                        GemmTraits::MPerXDL,
                                                        GemmTraits::NPerXDL,
                                                        GemmTraits::MXdlPerWave,
                                                        GemmTraits::NXdlPerWave,
                                                        GemmTraits::K1>
        blockwise_gemm_xdl_op{};

    blockwise_gemm_xdl_op.Run(
        a_local_tile_tensor.GetBuffer(), b_local_tile_tensor.GetBuffer(), c_reg_tensor.GetBuffer());
}
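A hedged usage sketch of the main loop follows. The tensor names, the explicit template argument order <DataType, GemmTraits, BlockSize>, and the surrounding load/sync steps are assumptions for illustration, not taken from this header.

// Sketch of a GEMM main loop built on blockwise_gemm_xdl. The A and B tiles live
// in LDS, the accumulator lives in VGPRs, and the call repeats once per K tile.
template <typename DataType,
          typename GemmTraits,
          index_t BlockSize,
          typename ALdsTensor,
          typename BLdsTensor,
          typename CVgprTensor>
__device__ void gemm_main_loop_sketch(ALdsTensor& a_lds_tile,
                                      BLdsTensor& b_lds_tile,
                                      CVgprTensor& c_vgpr,
                                      index_t num_k_tiles)
{
    for(index_t k_tile = 0; k_tile < num_k_tiles; ++k_tile)
    {
        // ... load the k_tile-th A and B tiles from global memory into LDS ...
        __syncthreads(); // make the freshly written LDS tiles visible to all threads
        ck::wrapper::blockwise_gemm_xdl<DataType, GemmTraits, BlockSize>(
            a_lds_tile, b_lds_tile, c_vgpr); // accumulate into the VGPR tensor
        __syncthreads(); // keep the LDS tiles intact until every thread has used them
    }
}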
/**
 * \brief Create local partition per thread for C tensor.
 */
template <typename DataType,
          typename ATileLayout,
          typename BTileLayout,
          typename GemmTraits,
          index_t BlockSize,
          typename CTensorType>
__host__ __device__ constexpr auto
make_blockwise_gemm_xdl_c_local_partition(CTensorType& c_local_tile_tensor)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};
    constexpr auto I5 = Number<5>{};
    constexpr auto I6 = Number<6>{};
    constexpr auto I7 = Number<7>{};

    static_assert(typename ATileLayout::LayoutShape{}.Size() ==
                  typename BTileLayout::LayoutShape{}.Size());

    constexpr bool is_integer =
        is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>;
    using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;

    constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3;
    using ABlockDesc_K0_M_K1_Type =
        std::conditional_t<is_3d_desc,
                           typename ATileLayout::LayoutUnrolledDescriptorType,
                           decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>())>;
    using BBlockDesc_K0_N_K1_Type =
        std::conditional_t<is_3d_desc,
                           typename BTileLayout::LayoutUnrolledDescriptorType,
                           decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>())>;

    using BlockwiseGemmXdlops =
        BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                            DataType,
                                                            DataType,
                                                            GemmAccDataType,
                                                            ABlockDesc_K0_M_K1_Type,
                                                            BBlockDesc_K0_N_K1_Type,
                                                            GemmTraits::MPerXDL,
                                                            GemmTraits::NPerXDL,
                                                            GemmTraits::MXdlPerWave,
                                                            GemmTraits::NXdlPerWave,
                                                            GemmTraits::K1>;

    // Lengths of the 8D (M0, N0, M1, N1, M2, M3, M4, N2) C block descriptor.
    constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
        BlockwiseGemmXdlops::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

    constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
    constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
    constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
    constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
    constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
    constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
    constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
    constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);

    // Origin of this thread's C sub-tile within the block and on the grid.
    const auto c_thread_mtx_on_block =
        BlockwiseGemmXdlops::CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
    const index_t m_thread_data_on_grid =
        c_local_tile_tensor.GetMultiIdxOffsets()[I0] + c_thread_mtx_on_block[I0];
    const index_t n_thread_data_on_grid =
        c_local_tile_tensor.GetMultiIdxOffsets()[I1] + c_thread_mtx_on_block[I1];

    // Decompose the linear M/N grid offsets into the 8D coordinates.
    const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = make_single_stage_tensor_adaptor(
        make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
        make_tuple(Sequence<0, 1, 2, 3, 4>{}),
        make_tuple(Sequence<0>{}));
    const auto m_thread_data_on_grid_idx =
        m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
            make_multi_index(m_thread_data_on_grid));

    const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
        make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
        make_tuple(Sequence<0, 1, 2>{}),
        make_tuple(Sequence<0>{}));
    const auto n_thread_data_on_grid_idx =
        n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
            make_multi_index(n_thread_data_on_grid));

    // Per-thread partition shape and the 8D descriptor of the C grid tile.
    const auto partition_shape = make_tuple(M0, N0, I1, I1, M2, I1, M4, I1);
    const auto partition_desc = BlockwiseGemmXdlops::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(
        layout(c_local_tile_tensor).GetUnrolledDescriptor());

    const auto lower_upper_dims =
        generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<8>{});

    // Slice each of the 8 dimensions down to this thread's sub-tile.
    const auto sliced_desc = transform_tensor_descriptor(
        partition_desc,
        make_tuple(
            make_slice_transform(partition_shape.At(Number<0>{}),
                                 m_thread_data_on_grid_idx[I0],
                                 partition_shape.At(Number<0>{}) + m_thread_data_on_grid_idx[I0]),
            make_slice_transform(partition_shape.At(Number<1>{}),
                                 n_thread_data_on_grid_idx[I0],
                                 partition_shape.At(Number<1>{}) + n_thread_data_on_grid_idx[I0]),
            make_slice_transform(partition_shape.At(Number<2>{}),
                                 m_thread_data_on_grid_idx[I1],
                                 partition_shape.At(Number<2>{}) + m_thread_data_on_grid_idx[I1]),
            make_slice_transform(partition_shape.At(Number<3>{}),
                                 n_thread_data_on_grid_idx[I1],
                                 partition_shape.At(Number<3>{}) + n_thread_data_on_grid_idx[I1]),
            make_slice_transform(partition_shape.At(Number<4>{}),
                                 m_thread_data_on_grid_idx[I2],
                                 partition_shape.At(Number<4>{}) + m_thread_data_on_grid_idx[I2]),
            make_slice_transform(partition_shape.At(Number<5>{}),
                                 m_thread_data_on_grid_idx[I3],
                                 partition_shape.At(Number<5>{}) + m_thread_data_on_grid_idx[I3]),
            make_slice_transform(partition_shape.At(Number<6>{}),
                                 m_thread_data_on_grid_idx[I4],
                                 partition_shape.At(Number<6>{}) + m_thread_data_on_grid_idx[I4]),
            make_slice_transform(partition_shape.At(Number<7>{}),
                                 n_thread_data_on_grid_idx[I2],
                                 partition_shape.At(Number<7>{}) + n_thread_data_on_grid_idx[I2])),
        lower_upper_dims,
        lower_upper_dims);

    // Wrap the sliced descriptor in a layout and return a tensor view over it.
    const auto partition_layout =
        Layout<remove_reference_t<decltype(partition_shape)>, decltype(sliced_desc)>(
            partition_shape, sliced_desc);
    auto partition_tensor = make_tensor<CTensorType::TensorBufferAddressSpace>(
        c_local_tile_tensor.GetPointer(), partition_layout);
    return partition_tensor;
}
/**
 * \brief Create Vgpr register tensor for the blockwise gemm xdl C results.
 */
template <typename DataType,
          typename ATileLayout,
          typename BTileLayout,
          typename GemmTraits,
          index_t BlockSize>
__host__ __device__ constexpr auto make_blockwise_gemm_xdl_c_vgpr()
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};
    constexpr auto I5 = Number<5>{};
    constexpr auto I6 = Number<6>{};
    constexpr auto I7 = Number<7>{};

    static_assert(typename ATileLayout::LayoutShape{}.Size() ==
                  typename BTileLayout::LayoutShape{}.Size());

    constexpr bool is_integer =
        is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>;
    using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;

    constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3;
    using ABlockDesc_K0_M_K1_Type =
        std::conditional_t<is_3d_desc,
                           typename ATileLayout::LayoutUnrolledDescriptorType,
                           decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>())>;
    using BBlockDesc_K0_N_K1_Type =
        std::conditional_t<is_3d_desc,
                           typename BTileLayout::LayoutUnrolledDescriptorType,
                           decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>())>;

    using BlockwiseGemmXdlops =
        BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                            DataType,
                                                            DataType,
                                                            GemmAccDataType,
                                                            ABlockDesc_K0_M_K1_Type,
                                                            BBlockDesc_K0_N_K1_Type,
                                                            GemmTraits::MPerXDL,
                                                            GemmTraits::NPerXDL,
                                                            GemmTraits::MXdlPerWave,
                                                            GemmTraits::NXdlPerWave,
                                                            GemmTraits::K1>;

    // 8D (M0, N0, M1, N1, M2, M3, M4, N2) per-thread C descriptor and its shape.
    constexpr auto vgpr_desc = BlockwiseGemmXdlops::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
    const auto vgpr_shape = make_tuple(vgpr_desc.GetLengths()[I0],
                                       vgpr_desc.GetLengths()[I1],
                                       vgpr_desc.GetLengths()[I2],
                                       vgpr_desc.GetLengths()[I3],
                                       vgpr_desc.GetLengths()[I4],
                                       vgpr_desc.GetLengths()[I5],
                                       vgpr_desc.GetLengths()[I6],
                                       vgpr_desc.GetLengths()[I7]);
    const auto vgpr_layout = Layout<remove_reference_t<decltype(vgpr_shape)>, decltype(vgpr_desc)>(
        vgpr_shape, vgpr_desc);

    // Each thread holds ScalarPerVector accumulator elements per xdlops instruction.
    constexpr index_t ScalarPerVector = BlockwiseGemmXdlops::xdlops_gemm.GetRegSizePerXdlops();
    using VgprVectorType = typename vector_type<GemmAccDataType, ScalarPerVector>::type;
    return ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, VgprVectorType>(
        vgpr_layout);
}
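The two helpers above are typically paired: the VGPR tensor receives the xdlops accumulation and the local partition is the per-thread view of the C tile the results are written to. A sketch under stated assumptions (the explicit template argument order and ck::wrapper::copy as the store utility are assumed, not confirmed by this listing):

// Sketch of a GEMM epilogue: allocate the VGPR accumulator, run the main loop
// (elided), then copy the per-thread results into the C grid tile partition.
template <typename DataType,
          typename ATileLayout,
          typename BTileLayout,
          typename GemmTraits,
          index_t BlockSize,
          typename CGlobalTileTensor>
__device__ void gemm_epilogue_sketch(CGlobalTileTensor& c_global_tile)
{
    // Per-thread VGPR accumulator matching the 8D xdlops C layout.
    auto c_vgpr = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr<DataType,
                                                              ATileLayout,
                                                              BTileLayout,
                                                              GemmTraits,
                                                              BlockSize>();
    // ... main loop accumulates into c_vgpr via blockwise_gemm_xdl ...

    // Per-thread view of the C grid tile owned by this thread.
    auto c_partition =
        ck::wrapper::make_blockwise_gemm_xdl_c_local_partition<DataType,
                                                               ATileLayout,
                                                               BTileLayout,
                                                               GemmTraits,
                                                               BlockSize>(c_global_tile);
    // Write the accumulated results out to global memory (copy utility assumed).
    ck::wrapper::copy(c_vgpr, c_partition);
}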