6 #include <hip/hip_runtime.h>
13 namespace tensor_operation {
20 hipDeviceProp_t dev_prop;
25 num_cu_ = dev_prop.multiProcessorCount;
33 const int max_capacity = max_occupancy * device_properties.
num_cu_;
36 const auto optimal_split =
40 k_batch = optimal_split;
45 std::cout <<
"[SPLIT-K AUTODEDUCE] Max active thread blocks per CU for GEMM kernel: "
46 << max_occupancy << std::endl;
47 std::cout <<
"[SPLIT-K AUTODEDUCE] Output grid size: " << grid_size << std::endl;
48 std::cout <<
"[SPLIT-K AUTODEDUCE] Optimal split-k value " << k_batch << std::endl;
53 template <ck::index_t NDimSpatial>
56 const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths)
63 constexpr
index_t spatial_offset = 3;
64 const index_t DoHoWo = std::accumulate(begin(a_g_n_k_wos_lengths) + spatial_offset,
65 end(a_g_n_k_wos_lengths),
68 const auto gemmK = a_g_n_k_wos_lengths[I1] * DoHoWo;
71 const auto gemmM = e_g_k_c_xs_lengths[I1];
75 const index_t XYZ = std::accumulate(begin(e_g_k_c_xs_lengths) + spatial_offset,
76 end(e_g_k_c_xs_lengths),
79 const auto gemmN = e_g_k_c_xs_lengths[I2] * XYZ;
83 template <ck::index_t MPerBlock, ck::index_t NPerBlock>
void hip_check_error(hipError_t x)
Definition: hip_check_error.hpp:10
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
__host__ T floor(T x)
Definition: math_v2.hpp:367
auto get_bwd_weight_gemm_sizes(const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_lengths)
Definition: split_k_utils.hpp:55
ck::index_t get_best_occupancy_k_batch_value(int max_occupancy, ck::index_t grid_size)
Definition: split_k_utils.hpp:30
ck::index_t calculate_mn_grid_size(ck::index_t gemmM, ck::index_t gemmN)
Definition: split_k_utils.hpp:84
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:140
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:298
Definition: integral_constant.hpp:20
Definition: split_k_utils.hpp:17
DeviceProperties()
Definition: split_k_utils.hpp:18
int num_cu_
Definition: split_k_utils.hpp:26
#define CK_ENV(name)
Definition: env.hpp:129