/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/masking_specialization.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/masking_specialization.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/masking_specialization.hpp Source File
masking_specialization.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 namespace ck {
7 namespace tensor_operation {
8 namespace device {
9 
11 {
14 };
15 
16 #ifndef __HIPCC_RTC__
18 {
19  switch(s)
20  {
21  case MaskingSpecialization::MaskDisabled: return "MaskDisabled";
22  case MaskingSpecialization::MaskOutUpperTriangle: return "MaskOutUpperTriangle";
23  default: return "Unrecognized specialization!";
24  }
25 }
26 #endif
27 
29 {
// Unconditionally returns false; both index arguments are ignored.
// (The stray ';' that followed the closing brace of this member function
// has been removed: it was an empty declaration that only produced
// extra-semicolon warnings.)
30  __host__ __device__ constexpr bool operator()(index_t /*m*/, index_t /*n*/) const
31  {
32  return false;
33  }
34 
// Always returns false: all four arguments are ignored, so this predicate
// never reports a tile as skippable.
35  __host__ __device__ constexpr bool
36  IsTileSkippable(index_t /*m*/, index_t /*n*/, index_t /*m_tile*/, index_t /*n_tile*/) const
37  {
38  return false;
39  }
40 };
41 
43 {
// Returns true exactly when m is strictly less than n (equivalently, n > m).
44  __host__ __device__ constexpr bool operator()(index_t m, index_t n) const { return m < n; }
45 
// Evaluates this predicate at the single point (m + m_tile - 1, n) and
// returns that result; the fourth argument (n_tile) is not used.
// NOTE(review): presumably (m, n) is the tile's first row/column and m_tile
// its row count, making the tested point the tile's last row at its first
// column - confirm against callers.
46  __host__ __device__ constexpr bool
47  IsTileSkippable(index_t m, index_t n, index_t m_tile, index_t /*n_tile*/) const
48  {
49  return (*this)(m + m_tile - 1, n);
50  }
51 };
52 
53 // Tracks the elements of C0 that must be set to -inf.
54 // Note: values in the M padding do not need to be reset, because they are never stored out.
55 template <typename MaskOutPredicate>
57 {
// Constructor: stores NRaw in NRaw_ and initializes predicate_ from a
// brace-initialized MaskOutPredicate temporary.
// NOTE(review): NRaw is presumably the unpadded extent along N (it is the
// bound used by IsNOutOfBound) - confirm against callers.
58  __host__ __device__ constexpr C0MatrixMask_impl(index_t NRaw)
59  : NRaw_(NRaw), predicate_(MaskOutPredicate{})
60  {
61  }
62 
// Returns true when n is at or beyond the stored bound NRaw_ (n >= NRaw_).
63  __host__ __device__ constexpr bool IsNOutOfBound(/*index_t m, */ index_t n) const
64  {
65  return NRaw_ <= n;
66  }
67 
// Returns true when the predicate flags (m, n), or when n is at or beyond
// the stored bound NRaw_. The predicate is evaluated first; the bound
// comparison is only reached if the predicate returns false.
68  __host__ __device__ constexpr bool IsMaskedElement(index_t m, index_t n) const
69  {
70  return predicate_(m, n) || (n >= NRaw_);
71  }
72 
// Forwards all four arguments unchanged to predicate_.IsTileSkippable and
// returns its result. The stored bound NRaw_ is not consulted here.
73  __host__ __device__ constexpr bool
74  IsTileSkippable(index_t m, index_t n, index_t m_tile, index_t n_tile) const
75  {
76  return predicate_.IsTileSkippable(m, n, m_tile, n_tile);
77  }
78 
79  private:
80  // index_t MRaw_;
81  index_t NRaw_;
82  MaskOutPredicate predicate_;
83 };
84 
85 } // namespace device
86 } // namespace tensor_operation
87 } // namespace ck
std::string getMaskingSpecializationString(const MaskingSpecialization &s)
Definition: masking_specialization.hpp:17
MaskingSpecialization
Definition: masking_specialization.hpp:11
Definition: ck.hpp:267
int32_t index_t
Definition: ck.hpp:298
Definition: masking_specialization.hpp:57
__host__ constexpr __device__ C0MatrixMask_impl(index_t NRaw)
Definition: masking_specialization.hpp:58
__host__ constexpr __device__ bool IsTileSkippable(index_t m, index_t n, index_t m_tile, index_t n_tile) const
Definition: masking_specialization.hpp:74
__host__ constexpr __device__ bool IsNOutOfBound(index_t n) const
Definition: masking_specialization.hpp:63
__host__ constexpr __device__ bool IsMaskedElement(index_t m, index_t n) const
Definition: masking_specialization.hpp:68
Definition: masking_specialization.hpp:29
__host__ constexpr __device__ bool IsTileSkippable(index_t, index_t, index_t, index_t) const
Definition: masking_specialization.hpp:36
__host__ constexpr __device__ bool operator()(index_t, index_t) const
Definition: masking_specialization.hpp:30
Definition: masking_specialization.hpp:43
__host__ constexpr __device__ bool IsTileSkippable(index_t m, index_t n, index_t m_tile, index_t) const
Definition: masking_specialization.hpp:47
__host__ constexpr __device__ bool operator()(index_t m, index_t n) const
Definition: masking_specialization.hpp:44