/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp Source File
split_k_utils.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 #include <numeric>
6 #include <hip/hip_runtime.h>
7 #include "ck/utility/env.hpp"
8 #include "ck/utility/number.hpp"
10 #include "ck/ck.hpp"
11 
12 namespace ck {
13 namespace tensor_operation {
14 namespace device {
15 
17 {
19  {
20  hipDeviceProp_t dev_prop;
21  hipDevice_t dev;
22  hip_check_error(hipGetDevice(&dev));
23  hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
24 
25  num_cu_ = dev_prop.multiProcessorCount;
26  };
27  int num_cu_;
28 };
29 
30 inline ck::index_t get_best_occupancy_k_batch_value(int max_occupancy, ck::index_t grid_size)
31 {
32  static DeviceProperties device_properties;
33  const int max_capacity = max_occupancy * device_properties.num_cu_;
34 
35  ck::index_t k_batch = 1;
36  const auto optimal_split =
37  static_cast<ck::index_t>(std::floor((1.0 * max_capacity) / grid_size));
38  if(optimal_split > 1)
39  {
40  k_batch = optimal_split;
41  }
42 
43  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
44  {
45  std::cout << "[SPLIT-K AUTODEDUCE] Max active thread blocks per CU for GEMM kernel: "
46  << max_occupancy << std::endl;
47  std::cout << "[SPLIT-K AUTODEDUCE] Output grid size: " << grid_size << std::endl;
48  std::cout << "[SPLIT-K AUTODEDUCE] Optimal split-k value " << k_batch << std::endl;
49  }
50  return k_batch;
51 }
52 
53 template <ck::index_t NDimSpatial>
54 inline auto
55 get_bwd_weight_gemm_sizes(const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths,
56  const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths)
57 {
58  static constexpr auto I1 = Number<1>{};
59  static constexpr auto I2 = Number<2>{};
60 
61  // The input array has elements in the order: G, N, K, Do, Ho, Wo
62  // GemmK = N * Do * Ho * Wo for the BWD weight pass.
63  constexpr index_t spatial_offset = 3;
64  const index_t DoHoWo = std::accumulate(begin(a_g_n_k_wos_lengths) + spatial_offset,
65  end(a_g_n_k_wos_lengths),
66  index_t{1},
67  std::multiplies<>{});
68  const auto gemmK = a_g_n_k_wos_lengths[I1] * DoHoWo;
69 
70  // The GEMM M dimension is the number of output channels.
71  const auto gemmM = e_g_k_c_xs_lengths[I1];
72 
73  // The output array has elements in the order: G, K, C, X, Y, Z
74  // GemmN = C * X * Y * Z for the BWD weight pass.
75  const index_t XYZ = std::accumulate(begin(e_g_k_c_xs_lengths) + spatial_offset,
76  end(e_g_k_c_xs_lengths),
77  index_t{1},
78  std::multiplies<>{});
79  const auto gemmN = e_g_k_c_xs_lengths[I2] * XYZ;
80  return std::make_tuple(gemmM, gemmN, gemmK);
81 }
82 
83 template <ck::index_t MPerBlock, ck::index_t NPerBlock>
85 {
86  const auto M0 = math::integer_divide_ceil(gemmM, MPerBlock);
87  const auto N0 = math::integer_divide_ceil(gemmN, NPerBlock);
88  return M0 * N0;
89 }
90 
91 } // namespace device
92 } // namespace tensor_operation
93 } // namespace ck
void hip_check_error(hipError_t x)
Definition: hip_check_error.hpp:10
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
__host__ T floor(T x)
Definition: math_v2.hpp:367
auto get_bwd_weight_gemm_sizes(const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_lengths)
Definition: split_k_utils.hpp:55
ck::index_t get_best_occupancy_k_batch_value(int max_occupancy, ck::index_t grid_size)
Definition: split_k_utils.hpp:30
ck::index_t calculate_mn_grid_size(ck::index_t gemmM, ck::index_t gemmN)
Definition: split_k_utils.hpp:84
Definition: ck.hpp:267
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:140
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:298
Definition: integral_constant.hpp:20
Definition: split_k_utils.hpp:17
DeviceProperties()
Definition: split_k_utils.hpp:18
int num_cu_
Definition: split_k_utils.hpp:26
#define CK_ENV(name)
Definition: env.hpp:129