/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/library/utility/validation_common.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/library/utility/validation_common.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/library/utility/validation_common.hpp Source File
validation_common.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <stdexcept>
7 #include <string>
8 #include <type_traits>
9 #include "ck/ck.hpp"
10 #include "ck/utility/type.hpp"
12 
13 namespace ck {
14 namespace utils {
15 
16 template <typename Layout>
17 inline void
18 validate_gemm_stride(int M, int N, int stride, const std::string& stride_name = "Stride")
19 {
20  if(ck::is_same_v<Layout, ck::tensor_layout::gemm::ColumnMajor>)
21  {
22  if(stride < M)
23  {
24  throw std::runtime_error(
25  "Error: For ColumnMajor layout, " + stride_name + " (" + std::to_string(stride) +
26  ") must be greater than or equal to dim (" + std::to_string(M) + ")");
27  }
28  }
29  else // RowMajor
30  {
31  if(stride < N)
32  {
33  throw std::runtime_error(
34  "Error: For RowMajor layout, " + stride_name + " (" + std::to_string(stride) +
35  ") must be greater than or equal to dim (" + std::to_string(N) + ")");
36  }
37  }
38 }
39 
40 // Convenience functions for common GEMM patterns
41 template <typename ALayout, typename BLayout, typename CLayout>
42 inline void validate_gemm_strides_abc(int M, int N, int K, int StrideA, int StrideB, int StrideC)
43 {
44  validate_gemm_stride<ALayout>(M, K, StrideA, "StrideA");
45  validate_gemm_stride<BLayout>(K, N, StrideB, "StrideB");
46  validate_gemm_stride<CLayout>(M, N, StrideC, "StrideC");
47 }
48 
49 } // namespace utils
50 } // namespace ck
void validate_gemm_stride(int M, int N, int stride, const std::string &stride_name="Stride")
Definition: validation_common.hpp:18
void validate_gemm_strides_abc(int M, int N, int K, int StrideA, int StrideB, int StrideC)
Definition: validation_common.hpp:42
Definition: ck.hpp:267