/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp Source File
block_fmha_bwd_pipeline_problem.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 
8 namespace ck_tile {
9 
10 template <typename QDataType_,
11  typename KDataType_,
12  typename VDataType_,
13  typename GemmDataType_,
14  typename LSEDataType_,
15  typename AccDataType_,
16  typename DDataType_,
17  typename BiasDataType_,
18  typename RandValOutputDataType_,
19  typename ODataType_,
20  typename OGradDataType_,
21  typename QGradDataType_,
22  typename KGradDataType_,
23  typename VGradDataType_,
24  typename BiasGradDataType_,
25  typename BlockFmhaShape_,
26  bool kIsGroupMode_,
27  bool kIsDeterministic_,
28  typename FmhaMask_,
29  typename FmhaDropout_,
30  bool kUseTrLoad_,
31  typename Traits_>
33 {
53 
54  static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
55  static constexpr bool kIsGroupMode = kIsGroupMode_;
56  static constexpr bool kIsDeterministic = kIsDeterministic_;
57  static constexpr bool kUseTrLoad = kUseTrLoad_;
58 
59  // attributes from traits
60  static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
61  static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
62  static constexpr auto BiasEnum = Traits::BiasEnum;
63  static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad;
64  static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
65  static_assert(!Traits::kPadSeqLenQ, "BlockFmhaBwdPipelineProblem does not need kPadSeqLenQ");
66  static_assert(!Traits::kPadSeqLenK, "BlockFmhaBwdPipelineProblem does not need kPadSeqLenQ");
67 };
68 
69 template <typename ODataType_,
70  typename OGradDataType_,
71  typename DDataType_,
72  index_t kBlockSize_,
73  index_t kVHeaddim_,
74  bool kIsGroupMode_,
75  typename Traits_>
77 {
82 
83  static_assert(0 < kBlockSize_ && kBlockSize_ % get_warp_size() == 0,
84  "kBlockSize should be divisible by get_warp_size()");
85 
86  static constexpr index_t kBlockSize = kBlockSize_;
87  static constexpr index_t kVHeaddim = kVHeaddim_;
88  static constexpr bool kIsGroupMode = kIsGroupMode_;
89 
90  // attributes from traits
91  static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
92  static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
93  static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
94 };
95 
96 template <typename AccDataType_,
97  typename QGradDataType_,
98  index_t kBlockSize_,
99  index_t kM0_,
100  index_t kN0_,
101  index_t kQKHeaddim_,
102  bool kIsGroupMode_,
103  bool kIsDeterministic_,
104  typename Traits_>
106 {
110 
111  static_assert(0 < kBlockSize_ && kBlockSize_ % get_warp_size() == 0,
112  "kBlockSize should be divisible by get_warp_size()");
113 
114  static constexpr index_t kBlockSize = kBlockSize_;
115  static constexpr index_t kM0 = kM0_;
116  static constexpr index_t kN0 = kN0_;
117  static constexpr index_t kQKHeaddim = kQKHeaddim_;
118  static constexpr bool kIsGroupMode = kIsGroupMode_;
119  static constexpr bool kIsDeterministic = kIsDeterministic_;
120 
121  // attributes from traits
122  static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
123  static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
124  static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
125 };
126 
127 } // namespace ck_tile
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
__host__ constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:42
Definition: block_fmha_bwd_pipeline_problem.hpp:106
static constexpr index_t kBlockPerCu
Definition: block_fmha_bwd_pipeline_problem.hpp:124
static constexpr index_t kM0
Definition: block_fmha_bwd_pipeline_problem.hpp:115
static constexpr index_t kQKHeaddim
Definition: block_fmha_bwd_pipeline_problem.hpp:117
static constexpr bool kIsGroupMode
Definition: block_fmha_bwd_pipeline_problem.hpp:118
static constexpr bool kPadHeadDimQ
Definition: block_fmha_bwd_pipeline_problem.hpp:123
remove_cvref_t< QGradDataType_ > QGradDataType
Definition: block_fmha_bwd_pipeline_problem.hpp:108
static constexpr index_t kBlockSize
Definition: block_fmha_bwd_pipeline_problem.hpp:114
remove_cvref_t< AccDataType_ > AccDataType
Definition: block_fmha_bwd_pipeline_problem.hpp:107
static constexpr bool kPadSeqLenQ
Definition: block_fmha_bwd_pipeline_problem.hpp:122
static constexpr bool kIsDeterministic
Definition: block_fmha_bwd_pipeline_problem.hpp:119
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_bwd_pipeline_problem.hpp:109
static constexpr index_t kN0
Definition: block_fmha_bwd_pipeline_problem.hpp:116
Definition: block_fmha_bwd_pipeline_problem.hpp:77
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_bwd_pipeline_problem.hpp:78
remove_cvref_t< OGradDataType_ > OGradDataType
Definition: block_fmha_bwd_pipeline_problem.hpp:79
static constexpr bool kPadHeadDimV
Definition: block_fmha_bwd_pipeline_problem.hpp:92
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_bwd_pipeline_problem.hpp:81
static constexpr index_t kVHeaddim
Definition: block_fmha_bwd_pipeline_problem.hpp:87
static constexpr index_t kBlockPerCu
Definition: block_fmha_bwd_pipeline_problem.hpp:93
static constexpr bool kPadSeqLenQ
Definition: block_fmha_bwd_pipeline_problem.hpp:91
static constexpr bool kIsGroupMode
Definition: block_fmha_bwd_pipeline_problem.hpp:88
remove_cvref_t< DDataType_ > DDataType
Definition: block_fmha_bwd_pipeline_problem.hpp:80
static constexpr index_t kBlockSize
Definition: block_fmha_bwd_pipeline_problem.hpp:86
Definition: block_fmha_bwd_pipeline_problem.hpp:33
remove_cvref_t< BiasGradDataType_ > BiasGradDataType
Definition: block_fmha_bwd_pipeline_problem.hpp:48
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_bwd_pipeline_problem.hpp:52
remove_cvref_t< FmhaMask_ > FmhaMask
Definition: block_fmha_bwd_pipeline_problem.hpp:50
remove_cvref_t< GemmDataType_ > GemmDataType
Definition: block_fmha_bwd_pipeline_problem.hpp:37
remove_cvref_t< KGradDataType_ > KGradDataType
Definition: block_fmha_bwd_pipeline_problem.hpp:46
remove_cvref_t< DDataType_ > DDataType
Definition: block_fmha_bwd_pipeline_problem.hpp:40
remove_cvref_t< BiasDataType_ > BiasDataType
Definition: block_fmha_bwd_pipeline_problem.hpp:41
remove_cvref_t< FmhaDropout_ > FmhaDropout
Definition: block_fmha_bwd_pipeline_problem.hpp:51
static constexpr auto BiasEnum
Definition: block_fmha_bwd_pipeline_problem.hpp:62
remove_cvref_t< QGradDataType_ > QGradDataType
Definition: block_fmha_bwd_pipeline_problem.hpp:45
static constexpr bool kIsGroupMode
Definition: block_fmha_bwd_pipeline_problem.hpp:55
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_bwd_pipeline_problem.hpp:34
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_bwd_pipeline_problem.hpp:35
static constexpr bool kHasBiasGrad
Definition: block_fmha_bwd_pipeline_problem.hpp:63
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_bwd_pipeline_problem.hpp:38
static constexpr index_t kBlockSize
Definition: block_fmha_bwd_pipeline_problem.hpp:54
static constexpr bool kPadHeadDimQ
Definition: block_fmha_bwd_pipeline_problem.hpp:60
static constexpr bool kUseTrLoad
Definition: block_fmha_bwd_pipeline_problem.hpp:57
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_bwd_pipeline_problem.hpp:36
remove_cvref_t< OGradDataType_ > OGradDataType
Definition: block_fmha_bwd_pipeline_problem.hpp:44
remove_cvref_t< AccDataType_ > AccDataType
Definition: block_fmha_bwd_pipeline_problem.hpp:39
remove_cvref_t< BlockFmhaShape_ > BlockFmhaShape
Definition: block_fmha_bwd_pipeline_problem.hpp:49
static constexpr index_t kBlockPerCu
Definition: block_fmha_bwd_pipeline_problem.hpp:64
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_bwd_pipeline_problem.hpp:43
remove_cvref_t< VGradDataType_ > VGradDataType
Definition: block_fmha_bwd_pipeline_problem.hpp:47
static constexpr bool kPadHeadDimV
Definition: block_fmha_bwd_pipeline_problem.hpp:61
remove_cvref_t< RandValOutputDataType_ > RandValOutputDataType
Definition: block_fmha_bwd_pipeline_problem.hpp:42
static constexpr bool kIsDeterministic
Definition: block_fmha_bwd_pipeline_problem.hpp:56