6 #include <hip/hip_runtime.h> 
   13 namespace tensor_operation {
 
   20         hipDeviceProp_t dev_prop;
 
   25         num_cu_ = dev_prop.multiProcessorCount;
 
   33     const int max_capacity = max_occupancy * device_properties.
num_cu_;
 
   36     const auto optimal_split =
 
   40         k_batch = optimal_split;
 
   45         std::cout << 
"[SPLIT-K AUTODEDUCE] Max active thread blocks per CU for GEMM kernel:  " 
   46                   << max_occupancy << std::endl;
 
   47         std::cout << 
"[SPLIT-K AUTODEDUCE] Output grid size:  " << grid_size << std::endl;
 
   48         std::cout << 
"[SPLIT-K AUTODEDUCE] Optimal split-k value " << k_batch << std::endl;
 
   53 template <ck::index_t NDimSpatial>
 
   56                           const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths)
 
   63     constexpr 
index_t spatial_offset = 3;
 
   64     const index_t DoHoWo             = std::accumulate(begin(a_g_n_k_wos_lengths) + spatial_offset,
 
   65                                            end(a_g_n_k_wos_lengths),
 
   68     const auto gemmK                 = a_g_n_k_wos_lengths[I1] * DoHoWo;
 
   71     const auto gemmM = e_g_k_c_xs_lengths[I1];
 
   75     const index_t XYZ = std::accumulate(begin(e_g_k_c_xs_lengths) + spatial_offset,
 
   76                                         end(e_g_k_c_xs_lengths),
 
   79     const auto gemmN  = e_g_k_c_xs_lengths[I2] * XYZ;
 
   83 template <ck::index_t MPerBlock, ck::index_t NPerBlock>
 
void hip_check_error(hipError_t x)
Definition: hip_check_error.hpp:10
 
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
 
__host__ T floor(T x)
Definition: math_v2.hpp:367
 
auto get_bwd_weight_gemm_sizes(const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_lengths)
Definition: split_k_utils.hpp:55
 
ck::index_t get_best_occupancy_k_batch_value(int max_occupancy, ck::index_t grid_size)
Definition: split_k_utils.hpp:30
 
ck::index_t calculate_mn_grid_size(ck::index_t gemmM, ck::index_t gemmN)
Definition: split_k_utils.hpp:84
 
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:140
 
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
 
int32_t index_t
Definition: ck.hpp:299
 
Definition: integral_constant.hpp:20
 
Definition: split_k_utils.hpp:17
 
DeviceProperties()
Definition: split_k_utils.hpp:18
 
int num_cu_
Definition: split_k_utils.hpp:26
 
#define CK_ENV(name)
Definition: env.hpp:129