12 namespace tensor_operation {
18 template <
typename Descriptor>
23 constexpr
index_t num_dims = Descriptor::GetNumOfDimension();
27 [&](
auto i) { dims_product *=
static_cast<long_index_t>(desc.GetLength(i)); });
29 return desc.GetElementSpaceSize() == dims_product;
34 template <index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout>
37 template <
typename ADescriptor,
typename BDescriptor>
39 Check(
const ADescriptor& a_desc,
40 const BDescriptor& b_desc,
43 const std::array<index_t, NDimSpatial>& output_spatial_lengths,
53 const index_t output_spatial_acum = std::accumulate(output_spatial_lengths.begin(),
54 output_spatial_lengths.end(),
56 std::multiplies<index_t>());
59 const bool is_k_not_paded = (Conv_N * output_spatial_acum) % (k_block_size * k_batch) == 0;
61 const bool can_divide_n_spatial_by_k_batch = (Conv_N * output_spatial_acum) % k_batch == 0;
63 const bool can_divide_n_by_k_batch = Conv_N % k_batch == 0;
65 const bool is_correct_layout =
66 is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
68 const bool is_a_stride_divisible = a_desc.GetElementSpaceSize() % k_batch == 0;
70 const bool is_b_stride_divisible = b_desc.GetElementSpaceSize() % k_batch == 0;
79 const bool eligible = can_divide_n_spatial_by_k_batch && can_divide_n_by_k_batch &&
80 is_k_not_paded && is_correct_layout && is_a_stride_divisible &&
81 is_b_stride_divisible && is_a_compact && is_b_compact;
89 template <
typename GridwiseGemm,
90 typename AGridDesc_AK0_M_K1,
91 typename BGridDesc_BK0_N_K1,
92 typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
93 bool HasMainKBlockLoop,
100 const BDataType* p_b_grid,
103 const typename GridwiseGemm::Argument& karg,
104 const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
105 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
106 const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
107 c_grid_desc_mblock_mperblock_nblock_nperblock,
110 bool split_k_offset_hack)
112 if(split_k_offset_hack)
114 GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
116 CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
118 CGlobalMemoryDataOperation,
125 a_grid_desc_ak0_m_ak1,
126 b_grid_desc_bk0_n_bk1,
127 c_grid_desc_mblock_mperblock_nblock_nperblock,
133 GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
135 CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
137 CGlobalMemoryDataOperation,
144 a_grid_desc_ak0_m_ak1,
145 b_grid_desc_bk0_n_bk1,
146 c_grid_desc_mblock_mperblock_nblock_nperblock,
154 template <
typename GridwiseGemm,
155 typename AGridDesc_AK0_M_K1,
156 typename BGridDesc_BK0_N_K1,
157 typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
158 bool HasMainKBlockLoop,
165 const BDataType* p_b_grid,
169 const typename GridwiseGemm::Argument& karg,
170 const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
171 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
172 const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
173 c_grid_desc_mblock_mperblock_nblock_nperblock,
176 bool split_k_offset_hack)
178 if(split_k_offset_hack)
180 GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
182 CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
184 CGlobalMemoryDataOperation,
192 a_grid_desc_ak0_m_ak1,
193 b_grid_desc_bk0_n_bk1,
194 c_grid_desc_mblock_mperblock_nblock_nperblock,
200 GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
202 CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
204 CGlobalMemoryDataOperation,
212 a_grid_desc_ak0_m_ak1,
213 b_grid_desc_bk0_n_bk1,
214 c_grid_desc_mblock_mperblock_nblock_nperblock,
__device__ void DispatchSplitKHack(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, void *p_shared, const typename GridwiseGemm::Argument &karg, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock, index_t k_id, index_t k_batch, bool split_k_offset_hack)
Definition: split_k_offset_utils.hpp:99
__device__ void DispatchSplitKHack_2Lds(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, void *p_shared_0, void *p_shared_1, const typename GridwiseGemm::Argument &karg, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock, index_t k_id, index_t k_batch, bool split_k_offset_hack)
Definition: split_k_offset_utils.hpp:164
bool IsDescriptorCompact(const Descriptor &desc)
Definition: split_k_offset_utils.hpp:19
InMemoryDataOperationEnum
Definition: ck.hpp:279
int64_t long_index_t
Definition: ck.hpp:302
TailNumber
Tail number enumeration for pipeline buffering.
Definition: scheduler_enum.hpp:49
int32_t index_t
Definition: ck.hpp:301
Definition: functional2.hpp:33
Definition: split_k_offset_utils.hpp:36
static bool Check(const ADescriptor &a_desc, const BDescriptor &b_desc, index_t k_batch, index_t Conv_N, const std::array< index_t, NDimSpatial > &output_spatial_lengths, index_t k_block_size)
Definition: split_k_offset_utils.hpp:39