split_k_offset_utils.hpp
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include <numeric>

namespace ck {
namespace tensor_operation {
namespace device {

// Check whether a tensor descriptor has a compact layout.
// Compact means: GetElementSpaceSize() equals the product of all dimension lengths.
// Non-compact descriptors have complex transform pipelines that may not support the split-k hack.
template <typename Descriptor>
bool IsDescriptorCompact(const Descriptor& desc)
{
    // Calculate the product of all dimension lengths
    long_index_t dims_product = 1;
    constexpr index_t num_dims = Descriptor::GetNumOfDimension();

    // Use compile-time iteration to multiply all dimension lengths
    static_for<0, num_dims, 1>{}(
        [&](auto i) { dims_product *= static_cast<long_index_t>(desc.GetLength(i)); });

    return desc.GetElementSpaceSize() == dims_product;
}

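// Illustrative example (not part of the original header): a packed 4-D descriptor with
// lengths {2, 4, 4, 8} occupies 2 * 4 * 4 * 8 = 256 elements, so GetElementSpaceSize()
// returns 256 and the descriptor is compact; any stride padding makes the element space
// larger than the length product and the check returns false.
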
// Determine split-k hack eligibility for a descriptor pair.
// This checks all the conditions required for safely using the split-k offset hack.
// (The struct name is elided in this listing; "SplitKOffsetHackChecker" below is a placeholder.)
template <index_t NDimSpatial, typename InLayout, typename WeiLayout, typename OutLayout>
struct SplitKOffsetHackChecker
{
    template <typename ADescriptor, typename BDescriptor>
    static bool
    Check(const ADescriptor& a_desc,
          const BDescriptor& b_desc,
          index_t k_batch,
          index_t Conv_N,
          const std::array<index_t, NDimSpatial>& output_spatial_lengths,
          index_t k_block_size) // K0PerBlock*K1 for v1, K0PerBlock for v3, KPerBlock for two-stage
    {
        // Only enable the hack if k_batch > 1
        if(k_batch <= 1)
        {
            return false;
        }

        // Calculate the product of the output spatial lengths
        const index_t output_spatial_acum = std::accumulate(output_spatial_lengths.begin(),
                                                            output_spatial_lengths.end(),
                                                            index_t{1},
                                                            std::multiplies<index_t>());

        // Check the divisibility and layout requirements
        const bool is_k_not_paded = (Conv_N * output_spatial_acum) % (k_block_size * k_batch) == 0;

        const bool can_divide_n_spatial_by_k_batch = (Conv_N * output_spatial_acum) % k_batch == 0;

        const bool can_divide_n_by_k_batch = Conv_N % k_batch == 0;

        const bool is_correct_layout =
            is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();

        const bool is_a_stride_divisible = a_desc.GetElementSpaceSize() % k_batch == 0;

        const bool is_b_stride_divisible = b_desc.GetElementSpaceSize() % k_batch == 0;

        // Check descriptor compactness
        const bool is_a_compact = IsDescriptorCompact(a_desc);
        const bool is_b_compact = IsDescriptorCompact(b_desc);

        // Require BOTH A and B to be eligible for the hack to avoid a KBatch dimension mismatch.
        // The gridwise kernel's CheckValidity requires A.KBatch == B.KBatch, so the hack must be
        // applied uniformly to both tensors to keep the kernel applicable.
        const bool eligible = can_divide_n_spatial_by_k_batch && can_divide_n_by_k_batch &&
                              is_k_not_paded && is_correct_layout && is_a_stride_divisible &&
                              is_b_stride_divisible && is_a_compact && is_b_compact;

        return eligible;
    }
};

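// Worked example (illustrative values, not from the original header): with Conv_N = 16,
// output_spatial_lengths = {28, 28}, k_batch = 4, and k_block_size = 32,
// output_spatial_acum = 784 and Conv_N * output_spatial_acum = 12544.
// 12544 % (32 * 4) == 0, 12544 % 4 == 0, and 16 % 4 == 0, so the divisibility checks pass;
// eligibility then hinges on the layout, stride-divisibility, and compactness checks.
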
// Helper function to dispatch split-K hack for standard kernel (single LDS)
// Reduces code duplication in device layer implementations
template <typename GridwiseGemm,
          typename AGridDesc_AK0_M_K1,
          typename BGridDesc_BK0_N_K1,
          typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          TailNumber TailNum,
          typename ADataType,
          typename BDataType,
          typename CDataType>
__device__ void DispatchSplitKHack(const ADataType* p_a_grid,
                                   const BDataType* p_b_grid,
                                   CDataType* p_c_grid,
                                   void* p_shared,
                                   const typename GridwiseGemm::Argument& karg,
                                   const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
                                   const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
                                   const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                                       c_grid_desc_mblock_mperblock_nblock_nperblock,
                                   index_t k_id,
                                   index_t k_batch,
                                   bool split_k_offset_hack)
{
    if(split_k_offset_hack)
    {
        GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
                                   BGridDesc_BK0_N_K1,
                                   CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                                   HasMainKBlockLoop,
                                   CGlobalMemoryDataOperation,
                                   TailNum,
                                   true>(p_a_grid,
                                         p_b_grid,
                                         p_c_grid,
                                         p_shared,
                                         karg,
                                         a_grid_desc_ak0_m_ak1,
                                         b_grid_desc_bk0_n_bk1,
                                         c_grid_desc_mblock_mperblock_nblock_nperblock,
                                         k_id,
                                         k_batch);
    }
    else
    {
        GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
                                   BGridDesc_BK0_N_K1,
                                   CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                                   HasMainKBlockLoop,
                                   CGlobalMemoryDataOperation,
                                   TailNum,
                                   false>(p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          p_shared,
                                          karg,
                                          a_grid_desc_ak0_m_ak1,
                                          b_grid_desc_bk0_n_bk1,
                                          c_grid_desc_mblock_mperblock_nblock_nperblock,
                                          k_id,
                                          k_batch);
    }
}

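// Illustrative call site (a sketch, not from the original header; the surrounding kernel and
// the karg member names are hypothetical): a device-layer __global__ kernel would forward its
// pointers, descriptors, and the runtime split_k_offset_hack flag, e.g.
//   DispatchSplitKHack<GridwiseGemm,
//                      decltype(a_grid_desc_ak0_m_ak1),
//                      decltype(b_grid_desc_bk0_n_bk1),
//                      decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
//                      HasMainKBlockLoop,
//                      InMemoryDataOperationEnum::Set,
//                      TailNumber::Full>(p_a_grid,
//                                        p_b_grid,
//                                        p_c_grid,
//                                        p_shared,
//                                        karg,
//                                        a_grid_desc_ak0_m_ak1,
//                                        b_grid_desc_bk0_n_bk1,
//                                        c_grid_desc_mblock_mperblock_nblock_nperblock,
//                                        k_id,
//                                        karg.KBatch,
//                                        split_k_offset_hack);
// ADataType, BDataType, and CDataType are deduced from the pointer arguments.
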
// Helper function to dispatch split-K hack for 2lds kernel
// Reduces code duplication in device layer implementations
template <typename GridwiseGemm,
          typename AGridDesc_AK0_M_K1,
          typename BGridDesc_BK0_N_K1,
          typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          TailNumber TailNum,
          typename ADataType,
          typename BDataType,
          typename CDataType>
__device__ void DispatchSplitKHack_2Lds(const ADataType* p_a_grid,
                                        const BDataType* p_b_grid,
                                        CDataType* p_c_grid,
                                        void* p_shared_0,
                                        void* p_shared_1,
                                        const typename GridwiseGemm::Argument& karg,
                                        const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
                                        const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
                                        const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                                            c_grid_desc_mblock_mperblock_nblock_nperblock,
                                        index_t k_id,
                                        index_t k_batch,
                                        bool split_k_offset_hack)
{
    if(split_k_offset_hack)
    {
        GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
                                        BGridDesc_BK0_N_K1,
                                        CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                                        HasMainKBlockLoop,
                                        CGlobalMemoryDataOperation,
                                        TailNum,
                                        true>(p_a_grid,
                                              p_b_grid,
                                              p_c_grid,
                                              p_shared_0,
                                              p_shared_1,
                                              karg,
                                              a_grid_desc_ak0_m_ak1,
                                              b_grid_desc_bk0_n_bk1,
                                              c_grid_desc_mblock_mperblock_nblock_nperblock,
                                              k_id,
                                              k_batch);
    }
    else
    {
        GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
                                        BGridDesc_BK0_N_K1,
                                        CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                                        HasMainKBlockLoop,
                                        CGlobalMemoryDataOperation,
                                        TailNum,
                                        false>(p_a_grid,
                                               p_b_grid,
                                               p_c_grid,
                                               p_shared_0,
                                               p_shared_1,
                                               karg,
                                               a_grid_desc_ak0_m_ak1,
                                               b_grid_desc_bk0_n_bk1,
                                               c_grid_desc_mblock_mperblock_nblock_nperblock,
                                               k_id,
                                               k_batch);
    }
}

} // namespace device
} // namespace tensor_operation
} // namespace ck