/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp Source File
device_contraction_utils.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <cassert>
7 #include <sstream>
8 #include <vector>
9 
10 #include "ck/ck.hpp"
11 
12 namespace ck {
13 namespace tensor_operation {
14 namespace device {
15 
32 template <index_t NumDim1, index_t NumDim2>
33 auto CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<index_t>& strides)
34 {
35  if(lengths.size() != NumDim1 + NumDim2)
36  {
37  std::ostringstream err;
38  err << "Incorrect number of lengths in " << "device_contraction_utils.hpp" << ":"
39  << __LINE__ << ", in function: " << __func__;
40  throw std::runtime_error(err.str());
41  }
42  if(strides.size() != NumDim1 + NumDim2)
43  {
44  std::ostringstream err;
45  err << "Incorrect number of strides in " << "device_contraction_utils.hpp" << ":"
46  << __LINE__ << ", in function: " << __func__;
47  throw std::runtime_error(err.str());
48  }
49 
50  // Determine the beginning and end idx of the group representing the FCD.
51  index_t begin_idx, end_idx, continous_dim, consecutive_stride = 1;
52  if(strides[NumDim1 - 1] == 1 && strides[NumDim1 + NumDim2 - 1] == 1)
53  {
54  // MZ or KZ are ones
55  bool dims1_are_ones = true;
56  for(index_t dim_idx = 0; dim_idx < NumDim1; dim_idx++)
57  {
58  if(lengths[dim_idx] != 1)
59  {
60  dims1_are_ones = false;
61  }
62  }
63 
64  if(dims1_are_ones)
65  {
66  begin_idx = NumDim1;
67  end_idx = NumDim1 + NumDim2 - 1;
68  continous_dim = 1;
69  }
70  else
71  {
72  begin_idx = 0;
73  end_idx = NumDim1 - 1;
74  continous_dim = 0;
75  }
76  }
77  else if(strides[NumDim1 - 1] == 1)
78  {
79  begin_idx = 0;
80  end_idx = NumDim1 - 1;
81  continous_dim = 0;
82  }
83  else if(strides[NumDim1 + NumDim2 - 1] == 1)
84  {
85  begin_idx = NumDim1;
86  end_idx = NumDim1 + NumDim2 - 1;
87  continous_dim = 1;
88  }
89  else
90  {
91  // The dimension consecutive in memory is not the last dimension of any group, so only
92  // one element can be read/written at once.
93  consecutive_stride = 1;
94  continous_dim = 0;
95  return make_tuple(continous_dim, consecutive_stride);
96  }
97 
98  for(index_t dim_idx = end_idx; dim_idx >= begin_idx; --dim_idx)
99  {
100  if(strides[dim_idx] == consecutive_stride)
101  {
102  consecutive_stride *= lengths[dim_idx];
103  }
104  else
105  {
106  break;
107  }
108  }
109  const index_t max_subsequent_elems = consecutive_stride;
110  return make_tuple(continous_dim, max_subsequent_elems);
111 }
112 
113 } // namespace device
114 } // namespace tensor_operation
115 } // namespace ck
auto CalculateMaxRead(const std::vector< index_t > &lengths, const std::vector< index_t > &strides)
Definition: device_contraction_utils.hpp:33
Definition: ck.hpp:267
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:298