include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp Source File

include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp Source File

Composable Kernel: include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp Source File
block_fmha_bwd_pipeline_problem.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 
8 namespace ck_tile {
9 
10 template <typename QDataType_,
11  typename KDataType_,
12  typename VDataType_,
13  typename GemmDataType_,
14  typename LSEDataType_,
15  typename AccDataType_,
16  typename DDataType_,
17  typename BiasDataType_,
18  typename RandValOutputDataType_,
19  typename ODataType_,
20  typename OGradDataType_,
21  typename QGradDataType_,
22  typename KGradDataType_,
23  typename VGradDataType_,
24  typename BiasGradDataType_,
25  typename BlockFmhaShape_,
26  bool kIsGroupMode_,
27  bool kIsDeterministic_,
28  typename FmhaMask_,
29  typename FmhaDropout_,
30  typename Traits_>
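31 struct BlockFmhaBwdPipelineProblem
32 {
33  using QDataType = remove_cvref_t<QDataType_>;
34  using KDataType = remove_cvref_t<KDataType_>;
35  using VDataType = remove_cvref_t<VDataType_>;
36  using GemmDataType = remove_cvref_t<GemmDataType_>;
37  using LSEDataType = remove_cvref_t<LSEDataType_>;
38  using AccDataType = remove_cvref_t<AccDataType_>;
39  using DDataType = remove_cvref_t<DDataType_>;
40  using BiasDataType = remove_cvref_t<BiasDataType_>;
41  using RandValOutputDataType = remove_cvref_t<RandValOutputDataType_>;
42  using ODataType = remove_cvref_t<ODataType_>;
43  using OGradDataType = remove_cvref_t<OGradDataType_>;
44  using QGradDataType = remove_cvref_t<QGradDataType_>;
45  using KGradDataType = remove_cvref_t<KGradDataType_>;
46  using VGradDataType = remove_cvref_t<VGradDataType_>;
47  using BiasGradDataType = remove_cvref_t<BiasGradDataType_>;
48  using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
49  using FmhaMask = remove_cvref_t<FmhaMask_>;
50  using FmhaDropout = remove_cvref_t<FmhaDropout_>;
51  using Traits = remove_cvref_t<Traits_>;
52 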
53  static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
54  static constexpr bool kIsGroupMode = kIsGroupMode_;
55  static constexpr bool kIsDeterministic = kIsDeterministic_;
56 
57  // attributes from traits
58  static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
59  static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
60  static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
61  static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
62  static constexpr auto BiasEnum = Traits::BiasEnum;
63  static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad;
64  static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
65 };
66 
67 template <typename ODataType_,
68  typename OGradDataType_,
69  typename DDataType_,
70  index_t kBlockSize_,
71  index_t kVHeaddim_,
72  bool kIsGroupMode_,
73  typename Traits_>
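74 struct BlockFmhaBwdOGradDotOPipelineProblem
75 {
76  using ODataType = remove_cvref_t<ODataType_>;
77  using OGradDataType = remove_cvref_t<OGradDataType_>;
78  using DDataType = remove_cvref_t<DDataType_>;
79  using Traits = remove_cvref_t<Traits_>;
80 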
81  static_assert(0 < kBlockSize_ && kBlockSize_ % get_warp_size() == 0,
82  "kBlockSize should be divisible by get_warp_size()");
83 
84  static constexpr index_t kBlockSize = kBlockSize_;
85  static constexpr index_t kVHeaddim = kVHeaddim_;
86  static constexpr bool kIsGroupMode = kIsGroupMode_;
87 
88  // attributes from traits
89  static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
90  static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
91  static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
92 };
93 
94 template <typename AccDataType_,
95  typename QGradDataType_,
96  index_t kBlockSize_,
97  index_t kM0_,
98  index_t kN0_,
99  index_t kQKHeaddim_,
100  bool kIsGroupMode_,
101  bool kIsDeterministic_,
102  typename Traits_>
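103 struct BlockFmhaBwdConvertQGradPipelineProblem
104 {
105  using AccDataType = remove_cvref_t<AccDataType_>;
106  using QGradDataType = remove_cvref_t<QGradDataType_>;
107  using Traits = remove_cvref_t<Traits_>;
108 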
109  static_assert(0 < kBlockSize_ && kBlockSize_ % get_warp_size() == 0,
110  "kBlockSize should be divisible by get_warp_size()");
111 
112  static constexpr index_t kBlockSize = kBlockSize_;
113  static constexpr index_t kM0 = kM0_;
114  static constexpr index_t kN0 = kN0_;
115  static constexpr index_t kQKHeaddim = kQKHeaddim_;
116  static constexpr bool kIsGroupMode = kIsGroupMode_;
117  static constexpr bool kIsDeterministic = kIsDeterministic_;
118 
119  // attributes from traits
120  static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
121  static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
122  static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
123 };
124 
125 } // namespace ck_tile
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE index_t get_warp_size()
Definition: arch.hpp:51
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
Definition: block_fmha_bwd_pipeline_problem.hpp:104
static constexpr index_t kBlockPerCu
Definition: block_fmha_bwd_pipeline_problem.hpp:122
static constexpr index_t kM0
Definition: block_fmha_bwd_pipeline_problem.hpp:113
static constexpr index_t kQKHeaddim
Definition: block_fmha_bwd_pipeline_problem.hpp:115
static constexpr bool kIsGroupMode
Definition: block_fmha_bwd_pipeline_problem.hpp:116
static constexpr bool kPadHeadDimQ
Definition: block_fmha_bwd_pipeline_problem.hpp:121
remove_cvref_t< QGradDataType_ > QGradDataType
Definition: block_fmha_bwd_pipeline_problem.hpp:106
static constexpr index_t kBlockSize
Definition: block_fmha_bwd_pipeline_problem.hpp:112
remove_cvref_t< AccDataType_ > AccDataType
Definition: block_fmha_bwd_pipeline_problem.hpp:105
static constexpr bool kPadSeqLenQ
Definition: block_fmha_bwd_pipeline_problem.hpp:120
static constexpr bool kIsDeterministic
Definition: block_fmha_bwd_pipeline_problem.hpp:117
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_bwd_pipeline_problem.hpp:107
static constexpr index_t kN0
Definition: block_fmha_bwd_pipeline_problem.hpp:114
Definition: block_fmha_bwd_pipeline_problem.hpp:75
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_bwd_pipeline_problem.hpp:76
remove_cvref_t< OGradDataType_ > OGradDataType
Definition: block_fmha_bwd_pipeline_problem.hpp:77
static constexpr bool kPadHeadDimV
Definition: block_fmha_bwd_pipeline_problem.hpp:90
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_bwd_pipeline_problem.hpp:79
static constexpr index_t kVHeaddim
Definition: block_fmha_bwd_pipeline_problem.hpp:85
static constexpr index_t kBlockPerCu
Definition: block_fmha_bwd_pipeline_problem.hpp:91
static constexpr bool kPadSeqLenQ
Definition: block_fmha_bwd_pipeline_problem.hpp:89
static constexpr bool kIsGroupMode
Definition: block_fmha_bwd_pipeline_problem.hpp:86
remove_cvref_t< DDataType_ > DDataType
Definition: block_fmha_bwd_pipeline_problem.hpp:78
static constexpr index_t kBlockSize
Definition: block_fmha_bwd_pipeline_problem.hpp:84
Definition: block_fmha_bwd_pipeline_problem.hpp:32
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_bwd_pipeline_problem.hpp:35
remove_cvref_t< VGradDataType_ > VGradDataType
Definition: block_fmha_bwd_pipeline_problem.hpp:46
static constexpr index_t kBlockSize
Definition: block_fmha_bwd_pipeline_problem.hpp:53
remove_cvref_t< OGradDataType_ > OGradDataType
Definition: block_fmha_bwd_pipeline_problem.hpp:43
static constexpr bool kIsDeterministic
Definition: block_fmha_bwd_pipeline_problem.hpp:55
remove_cvref_t< BiasGradDataType_ > BiasGradDataType
Definition: block_fmha_bwd_pipeline_problem.hpp:47
remove_cvref_t< FmhaMask_ > FmhaMask
Definition: block_fmha_bwd_pipeline_problem.hpp:49
static constexpr bool kPadHeadDimQ
Definition: block_fmha_bwd_pipeline_problem.hpp:60
static constexpr bool kPadSeqLenQ
Definition: block_fmha_bwd_pipeline_problem.hpp:58
remove_cvref_t< BlockFmhaShape_ > BlockFmhaShape
Definition: block_fmha_bwd_pipeline_problem.hpp:48
remove_cvref_t< GemmDataType_ > GemmDataType
Definition: block_fmha_bwd_pipeline_problem.hpp:36
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_bwd_pipeline_problem.hpp:42
static constexpr bool kHasBiasGrad
Definition: block_fmha_bwd_pipeline_problem.hpp:63
static constexpr bool kPadHeadDimV
Definition: block_fmha_bwd_pipeline_problem.hpp:61
remove_cvref_t< AccDataType_ > AccDataType
Definition: block_fmha_bwd_pipeline_problem.hpp:38
remove_cvref_t< DDataType_ > DDataType
Definition: block_fmha_bwd_pipeline_problem.hpp:39
remove_cvref_t< RandValOutputDataType_ > RandValOutputDataType
Definition: block_fmha_bwd_pipeline_problem.hpp:41
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_bwd_pipeline_problem.hpp:34
static constexpr bool kIsGroupMode
Definition: block_fmha_bwd_pipeline_problem.hpp:54
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_bwd_pipeline_problem.hpp:51
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_bwd_pipeline_problem.hpp:33
static constexpr auto BiasEnum
Definition: block_fmha_bwd_pipeline_problem.hpp:62
static constexpr bool kPadSeqLenK
Definition: block_fmha_bwd_pipeline_problem.hpp:59
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_bwd_pipeline_problem.hpp:37
static constexpr index_t kBlockPerCu
Definition: block_fmha_bwd_pipeline_problem.hpp:64
remove_cvref_t< QGradDataType_ > QGradDataType
Definition: block_fmha_bwd_pipeline_problem.hpp:44
remove_cvref_t< BiasDataType_ > BiasDataType
Definition: block_fmha_bwd_pipeline_problem.hpp:40
remove_cvref_t< FmhaDropout_ > FmhaDropout
Definition: block_fmha_bwd_pipeline_problem.hpp:50
remove_cvref_t< KGradDataType_ > KGradDataType
Definition: block_fmha_bwd_pipeline_problem.hpp:45