/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp Source File
block_fmha_bwd_convert_dq.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 
9 namespace ck_tile {
10 
11 template <typename Problem, typename Policy = BlockFmhaBwdPipelineDefaultPolicy>
13 {
16 
17  static constexpr index_t kM0 = Problem::kM0;
18  static constexpr index_t kN0 = Problem::kN0;
19 
20  static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
21  static constexpr index_t kBlockSize = Problem::kBlockSize;
22  static constexpr index_t kQKHeaddim = Problem::kQKHeaddim;
23 
24  static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
25  static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
26  static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
27  static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
28 
29  static constexpr index_t kAlignmentQGradAcc =
30  kPadHeadDimQ ? 1 : Policy::template GetAlignmentPostQGradAcc<Problem>();
31  static constexpr index_t kAlignmentQGrad =
32  kPadHeadDimQ ? 1 : Policy::template GetAlignmentPostQGrad<Problem>();
33 
34  CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }
35 
36  // Convert only
37  template <typename QGradAccDramBlockWindowTmp, typename QGradDramBlockWindowTmp>
39  operator()(const QGradAccDramBlockWindowTmp& dq_acc_dram_block_window_tmp,
40  QGradDramBlockWindowTmp& dq_dram_block_window_tmp) const
41  {
42  static_assert(
47  "wrong!");
48 
49  static_assert(kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], "wrong!");
50 
51  auto dq_acc_dram_window =
52  make_tile_window(dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
53  dq_acc_dram_block_window_tmp.get_window_lengths(),
54  dq_acc_dram_block_window_tmp.get_window_origin(),
55  Policy::template MakePostQGradDramTileDistribution<Problem>());
56 
57  auto dq_acc = load_tile(dq_acc_dram_window);
58  const auto dq = cast_tile<QGradDataType>(dq_acc);
59 
60  store_tile(dq_dram_block_window_tmp, dq);
61  }
62 
63  // Reduce + Convert
64  template <typename QGradAccDramBlockWindowTmp, typename QGradDramBlockWindowTmp>
66  operator()(const QGradAccDramBlockWindowTmp& dq_acc_dram_block_window_tmp,
67  QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
68  index_t nsplits) const
69  {
70  static_assert(
75  "wrong!");
76 
77  static_assert(kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], "wrong!");
78 
79  auto dq_acc_dram_window =
80  make_tile_window(dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
81  dq_acc_dram_block_window_tmp.get_window_lengths(),
82  dq_acc_dram_block_window_tmp.get_window_origin(),
83  Policy::template MakePostQGradAccDramTileDistribution<Problem>());
84 
85  auto dq_acc = decltype(load_tile(dq_acc_dram_window)){};
86  clear_tile(dq_acc);
87 
88  constexpr auto dq_acc_spans = decltype(dq_acc)::get_distributed_spans();
89  index_t i_total_loops = 0;
90  auto dq_acc_buf = load_tile(dq_acc_dram_window);
91  move_tile_window(dq_acc_dram_window, {1, 0, 0});
92 
93  do
94  {
95  sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
96  sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
97  sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
98  constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
99  dq_acc(n_i_j_idx) += dq_acc_buf(n_i_j_idx);
100  });
101  });
102  });
103 
104  dq_acc_buf = load_tile(dq_acc_dram_window);
105  move_tile_window(dq_acc_dram_window, {1, 0, 0});
106 
107  i_total_loops += 1;
108  } while(i_total_loops < (nsplits - 1));
109 
110  sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
111  sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
112  sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
113  constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
114  dq_acc(n_i_j_idx) += dq_acc_buf(n_i_j_idx);
115  });
116  });
117  });
118 
119  // declare dq
120  constexpr auto dq_converted_dstr =
121  Policy::template MakePostQGradAccDramTileDistribution<Problem>();
122  auto dq_converted = make_static_distributed_tensor<QGradDataType>(dq_converted_dstr);
123 
124  sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
125  sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
126  sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
127  constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
128  dq_converted(n_i_j_idx) = type_convert<QGradDataType>(dq_acc[n_i_j_idx]);
129  });
130  });
131  });
132 
133  constexpr auto dq_dstr = Policy::template MakePostQGradDramTileDistribution<Problem>();
134  auto dq = make_static_distributed_tensor<QGradDataType>(dq_dstr);
135  dq.get_thread_buffer() = dq_converted.get_thread_buffer();
136 
137  store_tile(dq_dram_block_window_tmp, dq);
138  }
139 };
140 
141 } // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition: sweep_tile.hpp:20
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:95
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:23
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition: tile_elementwise.hpp:177
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:22
constexpr bool is_same_v
Definition: type.hpp:283
Definition: block_fmha_bwd_convert_dq.hpp:13
static constexpr bool kPadSeqLenQ
Definition: block_fmha_bwd_convert_dq.hpp:25
static constexpr bool kIsDeterministic
Definition: block_fmha_bwd_convert_dq.hpp:27
remove_cvref_t< typename Problem::AccDataType > AccDataType
Definition: block_fmha_bwd_convert_dq.hpp:14
static constexpr index_t kAlignmentQGradAcc
Definition: block_fmha_bwd_convert_dq.hpp:29
remove_cvref_t< typename Problem::QGradDataType > QGradDataType
Definition: block_fmha_bwd_convert_dq.hpp:15
static constexpr index_t kBlockSize
Definition: block_fmha_bwd_convert_dq.hpp:21
static constexpr index_t kM0
Definition: block_fmha_bwd_convert_dq.hpp:17
static constexpr bool kPadHeadDimQ
Definition: block_fmha_bwd_convert_dq.hpp:26
static constexpr index_t kBlockPerCu
Definition: block_fmha_bwd_convert_dq.hpp:20
static constexpr index_t kAlignmentQGrad
Definition: block_fmha_bwd_convert_dq.hpp:31
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: block_fmha_bwd_convert_dq.hpp:34
static constexpr index_t kN0
Definition: block_fmha_bwd_convert_dq.hpp:18
CK_TILE_HOST_DEVICE void operator()(const QGradAccDramBlockWindowTmp &dq_acc_dram_block_window_tmp, QGradDramBlockWindowTmp &dq_dram_block_window_tmp, index_t nsplits) const
Definition: block_fmha_bwd_convert_dq.hpp:66
static constexpr index_t kQKHeaddim
Definition: block_fmha_bwd_convert_dq.hpp:22
CK_TILE_HOST_DEVICE void operator()(const QGradAccDramBlockWindowTmp &dq_acc_dram_block_window_tmp, QGradDramBlockWindowTmp &dq_dram_block_window_tmp) const
Definition: block_fmha_bwd_convert_dq.hpp:39
static constexpr bool kIsGroupMode
Definition: block_fmha_bwd_convert_dq.hpp:24
Definition: integral_constant.hpp:13