/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/utility/gemm_validation.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/utility/gemm_validation.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/utility/gemm_validation.hpp Source File
gemm_validation.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <string>
7 #include <stdexcept>
9 
10 namespace ck_tile {
11 
/// @brief Check that a matrix stride (leading dimension) is large enough for its layout.
///
/// @param Layout      Layout tag: "C" for column-major, "R" for row-major. Any other
///                    tag is accepted without performing a check.
/// @param M           Extent the stride is compared against for a column-major ("C") matrix.
/// @param N           Extent the stride is compared against for a row-major ("R") matrix.
/// @param stride      Stride value to validate.
/// @param stride_name Name of the stride, used only in the error message (e.g. "Stride_A").
/// @throws std::runtime_error if Layout is "C" and stride < M, or Layout is "R" and stride < N.
///
/// Layout is taken by const reference: the function only compares it, so copying the
/// string on every call (as a by-value parameter would) is unnecessary.
inline void
validate_stride(const std::string& Layout, int M, int N, int stride, const std::string& stride_name)
{
    if(Layout == "C" && stride < M)
    {
        throw std::runtime_error("For ColumnMajor layout, " + stride_name + "(" +
                                 std::to_string(stride) + ") must be greater or equal to dim " +
                                 std::to_string(M));
    }
    if(Layout == "R" && stride < N)
    {
        throw std::runtime_error("For RowMajor layout, " + stride_name + "(" +
                                 std::to_string(stride) + ") must be greater or equal to dim " +
                                 std::to_string(N));
    }
}
28 
29 inline void validate_gemm_stride(std::string a_layout,
30  std::string b_layout,
31  std::string c_layout,
32  int M,
33  int N,
34  int K,
35  int Stride_A,
36  int Stride_B,
37  int Stride_C)
38 {
39  // set default stride
40  if(Stride_A <= 0)
41  Stride_A = (a_layout == "R") ? K : M;
42  if(Stride_B <= 0)
43  Stride_B = (b_layout == "R") ? N : K;
44  if(Stride_C <= 0)
45  Stride_C = (c_layout == "R") ? N : M;
46 
47  validate_stride(a_layout, M, K, Stride_A, "Stride_A");
48  validate_stride(b_layout, K, N, Stride_B, "Stride_B");
49  validate_stride(c_layout, M, N, Stride_C, "Stride_C");
50 }
51 } // namespace ck_tile
Definition: cluster_descriptor.hpp:13
void validate_stride(std::string Layout, int M, int N, int stride, const std::string &stride_name)
Definition: gemm_validation.hpp:13
void validate_gemm_stride(std::string a_layout, std::string b_layout, std::string c_layout, int M, int N, int K, int Stride_A, int Stride_B, int Stride_C)
Definition: gemm_validation.hpp:29
Layout wrapper that performs the tensor descriptor logic.
Definition: layout.hpp:24