13 namespace tensor_operation {
// CalculateMaxRead
//
// Inspects a tensor described by `lengths` and `strides`. Its dimensions are
// treated as two consecutive groups: the first NumDim1 entries, then the
// following NumDim2 entries. The function picks a group whose last dimension
// has stride 1. Starting from that dimension and moving outward through the
// group, it counts how many elements lie back-to-back in memory.
//
// Template parameters:
//   NumDim1 - number of dimensions in the first group.
//   NumDim2 - number of dimensions in the second group.
//   NOTE(review): strides[NumDim1 - 1] is read unconditionally below. This
//   appears to assume NumDim1 >= 1; with NumDim1 == 0 the index would be out
//   of range. Confirm that callers never instantiate with 0.
//
// Parameters:
//   lengths - extent of each dimension; must hold exactly NumDim1 + NumDim2
//             entries.
//   strides - memory stride of each dimension; must hold exactly
//             NumDim1 + NumDim2 entries.
//
// Returns:
//   make_tuple(continous_dim, max_subsequent_elems).
//   - max_subsequent_elems is the running product of lengths over the
//     selected group's trailing dimensions whose stride equals that product.
//     It is the number of subsequent elements that are contiguous in memory.
//     On the early-return path it is 1.
//   - continous_dim is not assigned in any line visible in this listing.
//     Presumably it is the length of the selected unit-stride dimension
//     (TODO confirm against the full source).
//
// Throws:
//   std::runtime_error if lengths.size() or strides.size() differs from
//   NumDim1 + NumDim2. The message names the line and the function.
32 template <index_t NumDim1, index_t NumDim2>
33 auto CalculateMaxRead(
const std::vector<index_t>& lengths,
const std::vector<index_t>& strides)
// Validate that one length is supplied per dimension across both groups.
35 if(lengths.size() != NumDim1 + NumDim2)
37 std::ostringstream err;
38 err <<
"Incorrect number of lengths in " <<
"device_contraction_utils.hpp" <<
":"
39 << __LINE__ <<
", in function: " << __func__;
40 throw std::runtime_error(err.str());
// Validate that one stride is supplied per dimension across both groups.
42 if(strides.size() != NumDim1 + NumDim2)
44 std::ostringstream err;
45 err <<
"Incorrect number of strides in " <<
"device_contraction_utils.hpp" <<
":"
46 << __LINE__ <<
", in function: " << __func__;
47 throw std::runtime_error(err.str());
// [begin_idx, end_idx] bounds the dimension range of the group that is
// scanned. consecutive_stride accumulates the contiguous element count.
// NOTE(review): only consecutive_stride is initialized here. The assignments
// to begin_idx and continous_dim are not visible in this listing. Make sure
// every path that reaches the scan loop or a return sets them; otherwise they
// are read uninitialized.
51 index_t begin_idx, end_idx, continous_dim, consecutive_stride = 1;
// Case 1: both groups end in a unit-stride dimension. Check whether every
// dimension of the first group has length 1, to decide which group to scan.
52 if(strides[NumDim1 - 1] == 1 && strides[NumDim1 + NumDim2 - 1] == 1)
55 bool dims1_are_ones =
true;
56 for(
index_t dim_idx = 0; dim_idx < NumDim1; dim_idx++)
58 if(lengths[dim_idx] != 1)
60 dims1_are_ones =
false;
// The branch choosing between these two end_idx values is not visible here.
// Presumably the second group is scanned when dims1_are_ones holds, and the
// first group otherwise (TODO confirm).
67 end_idx = NumDim1 + NumDim2 - 1;
73 end_idx = NumDim1 - 1;
// Case 2: only the first group ends in a unit-stride dimension; scan it.
77 else if(strides[NumDim1 - 1] == 1)
80 end_idx = NumDim1 - 1;
// Case 3: only the second group ends in a unit-stride dimension; scan it.
83 else if(strides[NumDim1 + NumDim2 - 1] == 1)
86 end_idx = NumDim1 + NumDim2 - 1;
// Remaining case (presumably neither group ends in a unit-stride dimension):
// report a single contiguous element and return early.
93 consecutive_stride = 1;
95 return make_tuple(continous_dim, consecutive_stride);
// Walk from the innermost dimension of the selected group outward. A
// dimension extends the contiguous run when its stride equals the number of
// elements accumulated so far. Any early exit on a mismatch is not visible in
// this listing. index_t is a signed int32_t, so counting down past 0 ends the
// loop instead of wrapping around.
98 for(
index_t dim_idx = end_idx; dim_idx >= begin_idx; --dim_idx)
100 if(strides[dim_idx] == consecutive_stride)
102 consecutive_stride *= lengths[dim_idx];
// consecutive_stride now holds the number of subsequent elements that are
// contiguous in memory.
109 const index_t max_subsequent_elems = consecutive_stride;
110 return make_tuple(continous_dim, max_subsequent_elems);
auto CalculateMaxRead(const std::vector< index_t > &lengths, const std::vector< index_t > &strides)
Definition: device_contraction_utils.hpp:33
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:298