template <index_t K1, typename TileLayout>
__device__ constexpr auto GetBlockDescriptor()
{
    using TileLayoutShape      = typename TileLayout::LayoutShape;
    using TileLayoutDescriptor = typename TileLayout::LayoutUnrolledDescriptorType;

    // Split the K dimension (dim 1) into K0PerBlock x K1 and keep dim 0 as-is.
    constexpr auto K0PerBlock = Number<size<1>(TileLayoutShape{})>{} / Number<K1>{};
    constexpr auto Dim0       = Number<size<0>(TileLayoutShape{})>{};

    constexpr auto a_block_desc_k0_m_k1 = transform_tensor_descriptor( // call partly elided in the listing
        TileLayoutDescriptor{},
        /* ... transforms and dimension mappings elided ... */);

    return a_block_desc_k0_m_k1;
}
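For example, assuming dim 0 of the tile is the M (or N) extent and dim 1 is K, as the descriptor's name suggests, a 2D tile layout of shape (128, 64) with K1 = 8 gives K0PerBlock = 64 / 8 = 8 and Dim0 = 128, so the returned descriptor has lengths (K0, M, K1) = (8, 128, 8).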
 
template <typename DataType,
          /* ATileLayout, BTileLayout, GemmTraits, BlockSize and the A/B/C tensor
             type parameters are elided in this listing */>
__device__ void blockwise_gemm_xdl(const ATensorType& a_local_tile_tensor,
                                   const BTensorType& b_local_tile_tensor,
                                   CTensorType& c_reg_tensor)
{
    constexpr auto I3 = Number<3>{};

    // A and B tiles must live in LDS, the C accumulator in VGPRs, and the
    // element types must match DataType.
    static_assert(ATensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds);
    static_assert(BTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds);
    static_assert(CTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Vgpr);
    static_assert(is_same_v<DataType, typename ATensorType::TensorElementType>);
    static_assert(is_same_v<DataType, typename BTensorType::TensorElementType>);

    // Integer inputs accumulate in int32_t, floating-point inputs in float.
    constexpr bool is_integer =
        is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>;
    using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;

    static_assert(typename ATileLayout::LayoutShape{}.Size() ==
                  typename BTileLayout::LayoutShape{}.Size());
    constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3;

    // A 3D tile layout is already (K0, M/N, K1); otherwise build the block
    // descriptor from the 2D layout. The selector lines are elided in the listing.
    using ABlockDesc_K0_M_K1_Type =
        conditional_t<is_3d_desc,
                      typename ATileLayout::LayoutUnrolledDescriptorType,
                      decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>())>;
    using BBlockDesc_K0_N_K1_Type =
        conditional_t<is_3d_desc,
                      typename BTileLayout::LayoutUnrolledDescriptorType,
                      decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>())>;

    BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                        /* ... */
                                                        ABlockDesc_K0_M_K1_Type,
                                                        BBlockDesc_K0_N_K1_Type,
                                                        /* ... */
                                                        GemmTraits::MXdlPerWave,
                                                        GemmTraits::NXdlPerWave,
                                                        /* ... */>
        blockwise_gemm_xdl_op{};

    blockwise_gemm_xdl_op.Run(
        a_local_tile_tensor.GetBuffer(), b_local_tile_tensor.GetBuffer(), c_reg_tensor.GetBuffer());
}
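The static_asserts above spell out the operand contract: both input tiles are LDS-backed tensors holding DataType elements, while the accumulator is a VGPR register tensor. A minimal sketch of conforming operands follows; the pointer and layout names (a_lds_ptr, a_tile_layout, and so on) are assumptions, not part of this header.

// Sketch only: pointer and layout names are assumptions, not taken from this header.
auto a_lds_tile = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(a_lds_ptr, a_tile_layout);
auto b_lds_tile = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(b_lds_ptr, b_tile_layout);
// The accumulator would come from make_blockwise_gemm_xdl_c_vgpr() further down in this
// header, which returns a register tensor in the Vgpr address space, satisfying the third static_assert.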
 
template <typename DataType,
          typename ATileLayout,
          typename BTileLayout,
          /* GemmTraits, BlockSize and further parameters elided in the listing */
          typename CTensorType>
__host__ __device__ constexpr auto
make_blockwise_gemm_xdl_c_local_partition(CTensorType& c_local_tile_tensor)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};
    constexpr auto I5 = Number<5>{};
    constexpr auto I6 = Number<6>{};
    constexpr auto I7 = Number<7>{};

    static_assert(typename ATileLayout::LayoutShape{}.Size() ==
                  typename BTileLayout::LayoutShape{}.Size());

    constexpr bool is_integer =
        is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>;
    using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;

    constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3;
    using ABlockDesc_K0_M_K1_Type =
        conditional_t<is_3d_desc, // selector elided in the listing
                      typename ATileLayout::LayoutUnrolledDescriptorType,
                      decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>())>;
    using BBlockDesc_K0_N_K1_Type =
        conditional_t<is_3d_desc,
                      typename BTileLayout::LayoutUnrolledDescriptorType,
                      decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>())>;

    using BlockwiseGemmXdlops =
        BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                            /* ... */
                                                            ABlockDesc_K0_M_K1_Type,
                                                            BBlockDesc_K0_N_K1_Type,
                                                            /* ... */
                                                            GemmTraits::MXdlPerWave,
                                                            GemmTraits::NXdlPerWave,
                                                            /* ... */>;

    // Per-block C descriptor and its eight dimension lengths.
    constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
        BlockwiseGemmXdlops::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
    constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
    constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
    constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
    constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
    constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
    constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
    constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
    constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);

    // Origin of this thread's C sub-tile within the block, then within the grid.
    const auto c_thread_mtx_on_block =
        BlockwiseGemmXdlops::CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

    const index_t m_thread_data_on_grid =
        c_local_tile_tensor.GetMultiIdxOffsets()[I0] + c_thread_mtx_on_block[I0];
    const index_t n_thread_data_on_grid =
        c_local_tile_tensor.GetMultiIdxOffsets()[I1] + c_thread_mtx_on_block[I1];

    // Decompose the grid M/N indices into (M0, M1, M2, M3, M4) and (N0, N1, N2)
    // coordinates; the adaptor definitions are elided in the listing.
    /* ... m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor definition elided ... */
    const auto m_thread_data_on_grid_idx =
        m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
            make_multi_index(m_thread_data_on_grid)); // argument assumed
    const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = /* ... elided ... */;
    const auto n_thread_data_on_grid_idx =
        n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
            make_multi_index(n_thread_data_on_grid)); // argument assumed

    // Per-thread partition shape: a single element along M1, N1, M3 and N2.
    const auto partition_shape = make_tuple(M0, N0, I1, I1, M2, I1, M4, I1);

    const auto partition_desc = BlockwiseGemmXdlops::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(
        layout(c_local_tile_tensor).GetUnrolledDescriptor());

    const auto lower_upper_dims =
        generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<8>{});

    // Slice every dimension of the 8D descriptor down to this thread's range.
    // The enclosing call and the first argument of each slice are partly elided
    // in the listing; the forms below are assumed.
    const auto sliced_desc = transform_tensor_descriptor(
        partition_desc,
        make_tuple(make_slice_transform(partition_shape.At(Number<0>{}),
                                        m_thread_data_on_grid_idx[I0],
                                        partition_shape.At(Number<0>{}) + m_thread_data_on_grid_idx[I0]),
                   make_slice_transform(partition_shape.At(Number<1>{}),
                                        n_thread_data_on_grid_idx[I0],
                                        partition_shape.At(Number<1>{}) + n_thread_data_on_grid_idx[I0]),
                   make_slice_transform(partition_shape.At(Number<2>{}),
                                        m_thread_data_on_grid_idx[I1],
                                        partition_shape.At(Number<2>{}) + m_thread_data_on_grid_idx[I1]),
                   make_slice_transform(partition_shape.At(Number<3>{}),
                                        n_thread_data_on_grid_idx[I1],
                                        partition_shape.At(Number<3>{}) + n_thread_data_on_grid_idx[I1]),
                   make_slice_transform(partition_shape.At(Number<4>{}),
                                        m_thread_data_on_grid_idx[I2],
                                        partition_shape.At(Number<4>{}) + m_thread_data_on_grid_idx[I2]),
                   make_slice_transform(partition_shape.At(Number<5>{}),
                                        m_thread_data_on_grid_idx[I3],
                                        partition_shape.At(Number<5>{}) + m_thread_data_on_grid_idx[I3]),
                   make_slice_transform(partition_shape.At(Number<6>{}),
                                        m_thread_data_on_grid_idx[I4],
                                        partition_shape.At(Number<6>{}) + m_thread_data_on_grid_idx[I4]),
                   make_slice_transform(partition_shape.At(Number<7>{}),
                                        n_thread_data_on_grid_idx[I2],
                                        partition_shape.At(Number<7>{}) + n_thread_data_on_grid_idx[I2])),
        lower_upper_dims,
        lower_upper_dims);

    // Wrap the sliced descriptor in a Layout and return a tensor view over the
    // original pointer (the Layout template arguments are elided in the listing).
    const auto partition_layout =
        Layout<remove_reference_t<decltype(partition_shape)>, decltype(sliced_desc)>(
            partition_shape, sliced_desc);
    auto partition_tensor = make_tensor<CTensorType::TensorBufferAddressSpace>(
        c_local_tile_tensor.GetPointer(), partition_layout);
    return partition_tensor;
}
 
template <typename DataType,
          typename ATileLayout,
          typename BTileLayout,
          /* GemmTraits and BlockSize parameters elided in the listing */>
__host__ __device__ constexpr auto make_blockwise_gemm_xdl_c_vgpr()
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};
    constexpr auto I5 = Number<5>{};
    constexpr auto I6 = Number<6>{};
    constexpr auto I7 = Number<7>{};

    static_assert(typename ATileLayout::LayoutShape{}.Size() ==
                  typename BTileLayout::LayoutShape{}.Size());

    constexpr bool is_integer =
        is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>;
    using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;

    constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3;
    using ABlockDesc_K0_M_K1_Type =
        conditional_t<is_3d_desc, // selector elided in the listing
                      typename ATileLayout::LayoutUnrolledDescriptorType,
                      decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>())>;
    using BBlockDesc_K0_N_K1_Type =
        conditional_t<is_3d_desc,
                      typename BTileLayout::LayoutUnrolledDescriptorType,
                      decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>())>;

    using BlockwiseGemmXdlops =
        BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                            /* ... */
                                                            ABlockDesc_K0_M_K1_Type,
                                                            BBlockDesc_K0_N_K1_Type,
                                                            /* ... */
                                                            GemmTraits::MXdlPerWave,
                                                            GemmTraits::NXdlPerWave,
                                                            /* ... */>;

    // Per-thread C descriptor and its 8D shape.
    constexpr auto vgpr_desc = BlockwiseGemmXdlops::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
    const auto vgpr_shape    = make_tuple(vgpr_desc.GetLengths()[I0],
                                          vgpr_desc.GetLengths()[I1],
                                          vgpr_desc.GetLengths()[I2],
                                          vgpr_desc.GetLengths()[I3],
                                          vgpr_desc.GetLengths()[I4],
                                          vgpr_desc.GetLengths()[I5],
                                          vgpr_desc.GetLengths()[I6],
                                          vgpr_desc.GetLengths()[I7]);
    const auto vgpr_layout = // variable name and Layout arguments assumed; elided in the listing
        Layout<remove_reference_t<decltype(vgpr_shape)>, decltype(vgpr_desc)>(
            vgpr_shape, vgpr_desc);

    // Each thread owns one XDLOPS accumulator tile, stored as a single vector register type.
    constexpr index_t ScalarPerVector = BlockwiseGemmXdlops::xdlops_gemm.GetRegSizePerXdlops();
    using VgprVectorType = typename vector_type<GemmAccDataType, ScalarPerVector>::type;
    return ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, VgprVectorType>(
        vgpr_layout); // argument assumed; elided in the listing
}
 
__device__ void blockwise_gemm_xdl(const ATensorType& a_local_tile_tensor, const BTensorType& b_local_tile_tensor, CTensorType& c_reg_tensor)
    Perform a blockwise GEMM with XDL instructions on A and B tiles stored in LDS; the result is accumulated into a VGPR register tensor. Defined in gemm.hpp, line 92.

__host__ __device__ constexpr auto make_blockwise_gemm_xdl_c_local_partition(CTensorType& c_local_tile_tensor)
    Create the per-thread local partition of the C tensor. Defined in gemm.hpp, line 177.

__host__ __device__ constexpr auto make_blockwise_gemm_xdl_c_vgpr()
    Create the per-thread VGPR register tensor that holds the C accumulator. Defined in gemm.hpp, line 335.
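Taken together, the three entry points above compose at the tail of a block-level GEMM roughly as sketched below. This is a hedged illustration rather than code from this header: the tile-tensor names, the explicit template-argument order, and the use of the wrapper's copy utility are assumptions.

// Hedged sketch (names and template-argument order assumed).
// 1. Allocate the per-thread accumulator in VGPRs.
auto c_vgpr = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr<DataType, ATileLayout, BTileLayout,
                                                          GemmTraits, BlockSize>();
// 2. Accumulate the current A/B LDS tiles into the VGPR accumulator.
ck::wrapper::blockwise_gemm_xdl<DataType, ATileLayout, BTileLayout, GemmTraits, BlockSize>(
    a_lds_tile, b_lds_tile, c_vgpr);
// 3. Carve out this thread's slice of the C block in global memory and write the result back.
auto c_partition = ck::wrapper::make_blockwise_gemm_xdl_c_local_partition<
    DataType, ATileLayout, BTileLayout, GemmTraits, BlockSize>(c_global_block_tile);
ck::wrapper::copy(c_vgpr, c_partition);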