/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/epilogue/chainer/cshuffle_epilogue_chainer_ops.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/epilogue/chainer/cshuffle_epilogue_chainer_ops.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/epilogue/chainer/cshuffle_epilogue_chainer_ops.hpp Source File
cshuffle_epilogue_chainer_ops.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
10 
11 #include <optional>
12 
13 namespace ck_tile {
14 
15 //------------------------------------------------------------------------------
16 // CShuffle-specific epilogue operations
17 // These operations are specific to the CShuffle epilogue because they depend on its unique shuffle layout (LDS working tile and space-filling-curve access order).
18 //------------------------------------------------------------------------------
19 
// Slice the accumulator tile for one CShuffle epilogue access.
//
// For access index `iAccess`, queries `SFC::get_index` for the starting (m, n) position,
// converts it into shuffle-iteration indices (m_iter / n_iter) by dividing by
// MPerIterShuffle / NPerIterShuffle, and assigns a y-sliced part of `acc_tile`'s
// per-thread data to `context.working_tile`'s thread buffer.
//
// Template parameters:
//   SFC                      - access-order descriptor (space-filling curve); provides
//                              get_index(iAccess).
//   CWarpDstr                - warp-level C distribution; its ys-to-d lengths give the
//                              per-warp extent of the slice.
//   NumMXdlPerWavePerShuffle,
//   NumNXdlPerWavePerShuffle - not referenced on the lines visible in this listing;
//                              presumably used in the slice-origin/length arguments
//                              (NOTE(review): confirm against the full header).
//   MPerIterShuffle,
//   NPerIterShuffle          - M / N extent covered by one shuffle iteration.
32 template <typename SFC,
33  typename CWarpDstr,
34  index_t NumMXdlPerWavePerShuffle,
35  index_t NumNXdlPerWavePerShuffle,
36  index_t MPerIterShuffle,
37  index_t NPerIterShuffle>
39 {
 // Parameters:
 //   out_window, aux_windows, p_smem - unused here ([[maybe_unused]]); presumably accepted
 //                                     so every chained op shares one call signature.
 //   acc_tile - accumulator tile the slice is read from.
 //   iAccess  - access index; it is used in constant expressions, so it must be a
 //              compile-time index type (e.g. number<I>).
 //   context  - epilogue context; its working_tile receives the slice.
40  template <typename OutWindow,
41  typename AccTile,
42  typename AuxWindows,
43  typename IAccess,
44  typename Context>
45  CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow& out_window,
46  const AccTile& acc_tile,
47  [[maybe_unused]] const AuxWindows& aux_windows,
48  [[maybe_unused]] void* p_smem,
49  IAccess iAccess,
50  Context& context)
51  {
 // Starting (m, n) index of this access and the shuffle iteration it falls in.
52  constexpr auto idx_start = SFC::get_index(iAccess);
53  constexpr auto m_iter = number<idx_start.at(number<0>{}) / MPerIterShuffle>{};
54  constexpr auto n_iter = number<idx_start.at(number<1>{}) / NPerIterShuffle>{};
55 
 // Per-warp y lengths and an all-zero y origin (NDimY zeros) for the warp-level part.
56  constexpr auto warp_y_lengths =
57  to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
58  constexpr auto warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
59 
 // NOTE(review): source lines 61, 62 and 64 of this call are missing from the listing;
 // the slice origin/lengths presumably combine m_iter/n_iter with the warp values above
 // (e.g. via merge_sequences) — confirm against the full header.
60  context.working_tile.get_thread_buffer() = acc_tile.get_y_sliced_thread_data(
63  warp_y_index_zeros),
65  warp_y_lengths));
66  }
67 };
68 
// Scale the working tile using row and column scale tensors (CShuffle-specific).
//
// Builds tile windows over `scale_row_tensor` and `scale_col_tensor` that reuse the
// working tile's distribution, loads both, and applies an element-wise operation to which
// `context.working_tile` is passed twice (presumably as output and input) together with
// the two scale tiles. The functor argument on source line 104 is not visible in this
// listing; it is presumably tile_elementwise_inout with a multiply-by-scales functor.
// For every access except the last, the scale windows are then advanced by
// `SFC::get_forward_step`.
//
// NOTE(review): scale_row_window / scale_col_window are locals that are recreated on every
// call and discarded on return, so the trailing move_tile_window calls have no observable
// effect. Each call builds its windows from scale_*_tensor as passed in. Unless the caller
// positions those per access, every access may read the same scale slice — confirm
// against the caller.
77 template <typename SFC>
79 {
 // Parameters:
 //   out_window, acc_tile, aux_windows, p_smem - unused here ([[maybe_unused]]).
 //   iAccess          - access index; compared in `if constexpr`, so it must be a
 //                      compile-time index type.
 //   context          - epilogue context whose working_tile is scaled in place.
 //   scale_row_tensor - source for the row scale tile.
 //   scale_col_tensor - source for the column scale tile.
80  template <typename OutWindow,
81  typename AccTile,
82  typename AuxWindows,
83  typename IAccess,
84  typename Context,
85  typename ScaleRowTensor,
86  typename ScaleColTensor>
87  CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow& out_window,
88  [[maybe_unused]] const AccTile& acc_tile,
89  [[maybe_unused]] const AuxWindows& aux_windows,
90  [[maybe_unused]] void* p_smem,
91  IAccess iAccess,
92  Context& context,
93  const ScaleRowTensor& scale_row_tensor,
94  const ScaleColTensor& scale_col_tensor)
95  {
 // Scale windows share the working tile's distribution so the loaded scale tiles line
 // up element-for-element with context.working_tile.
96  auto scale_row_window =
97  make_tile_window(scale_row_tensor, context.working_tile.get_tile_distribution());
98  auto scale_col_window =
99  make_tile_window(scale_col_tensor, context.working_tile.get_tile_distribution());
100 
101  const auto scale_row_tile = load_tile(scale_row_window);
102  const auto scale_col_tile = load_tile(scale_col_window);
103 
105  context.working_tile,
106  context.working_tile,
107  scale_row_tile,
108  scale_col_tile);
109 
 // Advance the (local) scale windows to the next access position; see the NOTE above —
 // the windows go out of scope immediately afterwards.
110  constexpr index_t num_access = SFC::get_num_of_access();
111  if constexpr(iAccess != num_access - 1)
112  {
113  constexpr auto step = SFC::get_forward_step(number<iAccess>{});
114  move_tile_window(scale_row_window, {step.at(number<0>{}), step.at(number<1>{})});
115  move_tile_window(scale_col_window, {step.at(number<0>{}), step.at(number<1>{})});
116  }
117  }
118 };
119 
120 //------------------------------------------------------------------------------
121 // CShuffle problem and base operation definitions
122 //------------------------------------------------------------------------------
123 
// Problem configuration for the CShuffle epilogue chainer operations.
//
// Captures, as compile-time members:
//   - data types and layouts;
//   - the C/D element-wise functor;
//   - block, wave and XDL tile sizes;
//   - vector-store options.
// The type aliases on source lines 150-157 are not visible in this listing.
126 template <typename AsDataType_,
127  typename BsDataType_,
128  typename DsDataType_,
129  typename AccDataType_,
130  typename ODataType_,
131  typename DsLayout_,
132  typename ELayout_,
133  typename CDElementwise_,
134  index_t kM_,
135  index_t kN_,
136  index_t MWave_,
137  index_t NWave_,
138  index_t MPerXdl_,
139  index_t NPerXdl_,
140  index_t KPerXdl_,
141  bool isCTransposed_,
142  memory_operation_enum MemoryOperation_,
143  index_t kNumWaveGroups_ = 1,
144  bool FixedVectorSize_ = false,
145  index_t VectorSizeC_ = 1,
146  bool TiledMMAPermuteN_ = false,
147  index_t BlockedXDLN_PerWarp_ = 1>
149 {
 // Threads per block: MWave * NWave warps of get_warp_size() lanes each.
158  static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size();
 // Block tile extents in M and N.
159  static constexpr index_t kMPerBlock = kM_;
160  static constexpr index_t kNPerBlock = kN_;
161  static constexpr index_t MWave = MWave_;
162  static constexpr index_t NWave = NWave_;
163  static constexpr index_t MPerXdl = MPerXdl_;
164  static constexpr index_t NPerXdl = NPerXdl_;
165  static constexpr index_t KPerXdl = KPerXdl_;
 // NOTE(review): isCTransposed_ is a bool template parameter but is stored as index_t;
 // consider `bool` for consistency with FixedVectorSize / TiledMMAPermuteN.
166  static constexpr index_t isCTransposed = isCTransposed_;
167  static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
 // When FixedVectorSize is true, consumers use VectorSizeC instead of deriving a size.
168  static constexpr bool FixedVectorSize = FixedVectorSize_;
169  static constexpr index_t VectorSizeC = VectorSizeC_;
170  static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_;
171  static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_;
172  static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
 // Number of auxiliary D tensors; must match the number of D layouts.
173  static constexpr index_t NumDTensor = DsDataType::size();
174 
175  static_assert(NumDTensor == DsLayout::size(),
176  "The size of DsDataType and DsLayout should be the same");
177 };
178 
// Base stage of the CShuffle epilogue chainer.
//
// From `Problem_` it derives the compile-time configuration of the epilogue:
//   - data types and tile sizes;
//   - vector store sizes;
//   - shuffle tiling;
//   - the warp-GEMM C distribution;
//   - the LDS layout.
// Its operator() builds the context object that the chained operations work on.
// `Policy_` is not referenced on the lines visible in this listing.
179 template <typename Problem_, typename Policy_ = void>
181 {
189 
192 
196 
200 
203 
 // If A is packed int4 (pk_int4_t), use B's data type for the warp GEMM instead.
204  using ATypeToUse =
205  std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
206  // Used for weight-only quantization kernel, B would be dequantized to the same data type as A
207  using BTypeToUse =
208  std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
 // Compile-time configuration mirrored from Problem.
211  static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation;
212  static constexpr index_t kBlockSize = Problem::kBlockSize;
213  static constexpr index_t kMPerBlock = Problem::kMPerBlock;
214  static constexpr index_t kNPerBlock = Problem::kNPerBlock;
215  static constexpr index_t MWave = Problem::MWave;
216  static constexpr index_t NWave = Problem::NWave;
217  static constexpr index_t MPerXdl = Problem::MPerXdl;
218  static constexpr index_t NPerXdl = Problem::NPerXdl;
219  static constexpr index_t KPerXdl = Problem::KPerXdl;
220  static constexpr index_t isCTransposed = Problem::isCTransposed;
221  static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
222  static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
223  static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp;
224  static constexpr index_t VectorSizeC = Problem::VectorSizeC;
 // Rows / columns covered by one XDL iteration across all waves of the block.
225  static constexpr index_t MPerIteration = MPerXdl * MWave;
226  static constexpr index_t NPerIteration = NPerXdl * NWave;
227  static constexpr index_t NumDTensor = Problem::NumDTensor;
228 
229  static_assert(NumDTensor == DsLayout::size(),
230  "The size of DsDataType and DsLayout should be the same");
231 
 // GetVectorSizeC(): vector store size (in elements) for the C/E output tensor.
 // Returns VectorSizeC when FixedVectorSize is set. Otherwise it returns the smaller of:
 //   - the per-iteration extent along the contiguous dimension (NPerIteration for
 //     RowMajor, MPerIteration for ColumnMajor);
 //   - 16 bytes / sizeof(ODataType).
 // NOTE(review): there is no `else` after the FixedVectorSize branch, so the layout branches
 // are still instantiated when FixedVectorSize is true. An unsupported ELayout therefore
 // still hits the static_assert.
 // NOTE(review): `static_assert(false, ...)` in a discarded branch is rejected by compilers
 // that do not implement P2593 (C++23). A dependent-false condition is the portable form.
243  {
244  if constexpr(FixedVectorSize)
245  {
246  return VectorSizeC;
247  }
 // Maximum vector width in bytes; divided by the element size below.
248  constexpr index_t max_vector_size = 16;
249  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
250  {
251  return std::min(static_cast<int>(NPerIteration),
252  static_cast<int>(max_vector_size / sizeof(ODataType)));
253  }
254  else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
255  {
256  return std::min(static_cast<int>(MPerIteration),
257  static_cast<int>(max_vector_size / sizeof(ODataType)));
258  }
259  else
260  {
261  static_assert(false, "Unsupported ELayout!");
262  }
263  }
264 
 // GetVectorSizeD(index): vector store size (in elements) for D tensor `index`.
 // It is chosen from that tensor's entry in DsLayout and DsDataType, the same way as
 // GetVectorSizeC but without the FixedVectorSize override.
 // NOTE(review): same static_assert(false) portability concern as GetVectorSizeC.
270  template <index_t I>
272  {
273  constexpr index_t max_vector_size = 16;
274  using DiDataType = remove_cvref_t<std::tuple_element_t<index.value, DsDataType>>;
275  using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
276  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
277  {
278  return std::min(static_cast<int>(NPerIteration),
279  static_cast<int>(max_vector_size / sizeof(DiDataType)));
280  }
281  else if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::ColumnMajor>)
282  {
283  return std::min(static_cast<int>(MPerIteration),
284  static_cast<int>(max_vector_size / sizeof(DiDataType)));
285  }
286  else
287  {
288  static_assert(false, "Unsupported DLayout!");
289  }
290  }
 // Shuffle tile configuration: XDL tiles per wave handled in one shuffle, as
 // (M count, N count).
 //   - elem_per_thread = MPerXdl * NPerXdl / warp_size.
 //   - If elem_per_thread already reaches GetVectorSizeC(), one tile per shuffle is used.
 //   - Otherwise num_xdl_shuffles = GetVectorSizeC() / elem_per_thread tiles are grouped
 //     along M (RowMajor) or N (other layouts).
 //   - The count is clamped to the number of XDL iterations in the block tile.
 // NOTE(review): the division truncates if GetVectorSizeC() is not a multiple of
 // elem_per_thread; no static_assert visible here guards that.
299  static constexpr auto shuffle_tile_tuple = [] {
300  constexpr index_t elem_per_thread = MPerXdl * NPerXdl / get_warp_size();
301  if constexpr(elem_per_thread >= GetVectorSizeC())
302  {
303  return std::make_tuple(1, 1);
304  }
305  else
306  {
307  constexpr index_t num_xdl_shuffles = GetVectorSizeC() / elem_per_thread;
308  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
309  {
310  static_assert((kMPerBlock % (MPerXdl * MWave) == 0) &&
311  (kMPerBlock % num_xdl_shuffles == 0),
312  "kMPerBlock must be divisible by MPerXdl*MWave and "
313  "num_xdl_shuffles for CShuffleEpilogueStageBase");
314  return std::make_tuple(min(num_xdl_shuffles, kMPerBlock / (MPerXdl * MWave)), 1);
315  }
316  else
317  {
318  static_assert((kNPerBlock % (NPerXdl * NWave) == 0) &&
319  (kNPerBlock % num_xdl_shuffles == 0),
320  "kNPerBlock must be divisible by NPerXdl*NWave and "
321  "num_xdl_shuffles for CShuffleEpilogueStageBase");
322  return std::make_tuple(1, min(num_xdl_shuffles, kNPerBlock / (NPerXdl * NWave)));
323  }
324  }
325  }();
326  static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple);
329 
 // M / N extent covered by one shuffle iteration across all waves. If the block tile is
 // not divisible by these extents, the branch on source line 334 (not visible in this
 // listing) is taken instead. NOTE(review): confirm what that branch does.
330  static constexpr auto MNPerIterationShuffle = [] {
331  constexpr index_t m_val = MPerXdl * MWave * NumMXdlPerWavePerShuffle;
332  constexpr index_t n_val = NPerXdl * NWave * NumNXdlPerWavePerShuffle;
333  if constexpr(kMPerBlock % m_val != 0 || kNPerBlock % n_val != 0)
335  else
336  return std::make_tuple(m_val, n_val);
337  }();
338  static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle);
339  static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle);
 // Warp-level GEMM (WarpGemmDispatcher; its first line, source line 341, is not visible).
 // Its C distribution and tensor types define the per-warp accumulator layout used below.
340 
342  BTypeToUse,
343  AccDataType,
344  MPerXdl,
345  NPerXdl,
346  KPerXdl,
347  isCTransposed>;
348 
349  using CWarpDstr = typename WG::CWarpDstr;
350  using CWarpTensor = typename WG::CWarpTensor;
351  using CWarpDstrEncoding = typename WG::CWarpDstrEncoding;
355 
 // MakeLdsBlockDescriptor<Problem>(): descriptor of the LDS block used for the shuffle.
 // The output's contiguous dimension is chosen from ELayout. The descriptor construction
 // on source lines 362-364 / 369-371 is not visible in this listing.
356  template <typename Problem>
358  {
359  // N is contiguous dimension
360  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
361  {
365  }
366  // M is contiguous dimension
367  else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
368  {
372  }
373  else
374  {
375  static_assert(false, "Unsupported ELayout!");
376  }
377  }
 // MakeLdsDistributionEncode(): tile distribution encoding, presumably for the working
 // tile / LDS write window.
 //   - It builds a block-level outer encoding.
 //   - When BlockedXDLN_PerWarp != 1 the outer encoding is a blocked layout with
 //     RakedXDLN_PerWarp = NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp.
 //   - CWarpDstr's encoding is embedded inside the outer encoding.
 // NOTE(review): no divisibility check of NumNXdlPerWavePerShuffle by BlockedXDLN_PerWarp
 // is visible here.
378 
380  {
381  constexpr auto block_outer_dstr_encoding = [] {
382  if constexpr(BlockedXDLN_PerWarp == 1)
383  {
390  sequence<0, 0>>{};
391  }
392  else
393  {
394  constexpr int RakedXDLN_PerWarp = NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp;
395  // BlockedLayout
397  sequence<>,
404  }
405  }();
406  constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
407  block_outer_dstr_encoding, typename CWarpDstr::DstrEncode{});
408 
409  return block_dstr_encoding;
410  }
 // GetSmemSize(): shared-memory size required by this stage. Its signature and body
 // (source lines 412 and 414) are not visible in this listing.
411 
413  {
415  }
 // Source lines 417-423 (partly missing here) declare a type parameterized by
 // GetVectorSizeC() and Problem::kNumWaveGroups. It is presumably the TileEncodingPattern
 // that operator() uses to build the DRAM-side tile distribution; confirm.
416 
421  GetVectorSizeC(),
423  Problem::kNumWaveGroups>;
424 
 // Context passed between the chained epilogue operations. It holds:
 //   - the working tile;
 //   - the LDS view and its write/read windows;
 //   - the auxiliary (D tensor) windows;
 //   - the output tile.
440  template <typename WorkingTileType,
441  typename LdsBlockType,
442  typename LdsWriteWindowType,
443  typename LdsReadWindowType,
444  typename AuxWindowsType,
445  typename OutTileType>
447  {
448  WorkingTileType working_tile; // Working tile for shuffle operations
449  LdsBlockType lds_block; // LDS block view
450  LdsWriteWindowType lds_write_window; // Window for writing to LDS
451  LdsReadWindowType lds_read_window; // Window for reading from LDS
452  AuxWindowsType aux_windows; // Auxiliary tensor windows (D tensors)
453  OutTileType out_tile; // Output tile
454  };
455 
 // Build the epilogue context. It creates:
 //   - the working tile, a static distributed tensor of AccDataType;
 //   - an LDS tensor view over `p_smem`, interpreted as ODataType;
 //   - LDS write/read windows at origin {0, 0};
 //   - tile windows over `ds_windows` using the DRAM tile distribution;
 //   - an output tile loaded through the LDS read window.
 // These are copied into a CShuffleContext, which is returned.
 //
 // NOTE(review): `out_tile` is loaded from LDS before this function writes anything to it.
 // Its contents are presumably not meaningful yet; it appears to be loaded only to obtain a
 // correctly typed/distributed tile. Confirm.
 // NOTE(review): the static_assert limits this stage to RowMajor output, even though
 // MakeLdsBlockDescriptor also has a ColumnMajor branch.
 // NOTE(review): `context` is default-constructed and then assigned field by field.
 // CShuffleContext is a plain aggregate, so `ContextType context{...}` would avoid
 // requiring default-constructible view/window types.
456  template <typename OutDramWindow, typename AccTile, typename DsDramWindows>
457  CK_TILE_DEVICE auto operator()([[maybe_unused]] OutDramWindow& out_window,
458  [[maybe_unused]] const AccTile& acc_tile,
459  const DsDramWindows& ds_windows,
460  void* p_smem)
461  {
462  static_assert(
463  std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
464  "Currently, the CShuffleEpilogueStageBase only supports the Row Major Output layout");
465 
466  constexpr auto working_tile_distr =
468  auto working_tile = make_static_distributed_tensor<AccDataType>(working_tile_distr);
469 
 // LDS view over the caller-provided shared memory buffer.
470  constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
471  auto lds_block = make_tensor_view<address_space_enum::lds>(static_cast<ODataType*>(p_smem),
472  lds_block_desc);
473 
474  auto lds_write_window = make_tile_window(
475  lds_block,
477  {0, 0},
478  working_tile_distr);
479 
480  auto lds_read_window = make_tile_window(
481  lds_block,
483  {0, 0});
484 
 // DRAM-side distribution shared by the D-tensor windows and the output tile.
485  constexpr auto dram_tile_distribution =
486  TileEncodingPattern::make_2d_static_tile_distribution();
487  auto aux_windows = generate_tuple(
488  [&](auto idx) { return make_tile_window(ds_windows[idx], dram_tile_distribution); },
490 
491  auto out_tile = load_tile(make_tile_window(lds_read_window, dram_tile_distribution));
492 
493  using ContextType = CShuffleContext<decltype(working_tile),
494  decltype(lds_block),
495  decltype(lds_write_window),
496  decltype(lds_read_window),
497  decltype(aux_windows),
498  decltype(out_tile)>;
499 
500  ContextType context;
501  context.working_tile = working_tile;
502  context.lds_block = lds_block;
503  context.lds_write_window = lds_write_window;
504  context.lds_read_window = lds_read_window;
505  context.aux_windows = aux_windows;
506  context.out_tile = out_tile;
507 
508  return context;
509  }
510 };
511 
512 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
constexpr CK_TILE_HOST_DEVICE auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition: tile_distribution_encoding.hpp:457
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_descriptor.hpp:274
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition: tile_elementwise.hpp:23
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
typename impl::warp_gemm_dispatcher::Dispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition: warp_gemm_dispatcher.hpp:176
@ thread_raked
Thread raked pattern.
constexpr CK_TILE_HOST_DEVICE auto to_sequence(tuple< number< Is >... >)
Definition: sequence.hpp:1066
constexpr CK_TILE_HOST_DEVICE auto merge_sequences(Seqs...)
Definition: sequence.hpp:837
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition: type_traits.hpp:67
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:95
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE T min(T x)
Definition: math.hpp:206
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:157
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:36
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:495
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition: sequence.hpp:1037
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
Context structure for CShuffle epilogue operations.
Definition: cshuffle_epilogue_chainer_ops.hpp:447
OutTileType out_tile
Definition: cshuffle_epilogue_chainer_ops.hpp:453
LdsBlockType lds_block
Definition: cshuffle_epilogue_chainer_ops.hpp:449
LdsWriteWindowType lds_write_window
Definition: cshuffle_epilogue_chainer_ops.hpp:450
LdsReadWindowType lds_read_window
Definition: cshuffle_epilogue_chainer_ops.hpp:451
WorkingTileType working_tile
Definition: cshuffle_epilogue_chainer_ops.hpp:448
AuxWindowsType aux_windows
Definition: cshuffle_epilogue_chainer_ops.hpp:452
Definition: cshuffle_epilogue_chainer_ops.hpp:181
static constexpr index_t NumDTensor
Definition: cshuffle_epilogue_chainer_ops.hpp:227
static constexpr index_t kBlockSize
Definition: cshuffle_epilogue_chainer_ops.hpp:212
static constexpr index_t kNPerBlock
Definition: cshuffle_epilogue_chainer_ops.hpp:214
remove_cvref_t< typename Problem::ODataType > ODataType
Definition: cshuffle_epilogue_chainer_ops.hpp:186
std::conditional_t< ADataTypeIsTuple, remove_cvref_t< AsDataType >, remove_cvref_t< tuple< AsDataType > >> AsDataTypeTuple
Definition: cshuffle_epilogue_chainer_ops.hpp:195
std::conditional_t< BDataTypeIsTuple, remove_cvref_t< BsDataType >, remove_cvref_t< tuple< BsDataType > >> BsDataTypeTuple
Definition: cshuffle_epilogue_chainer_ops.hpp:199
static constexpr index_t BlockedXDLN_PerWarp
Definition: cshuffle_epilogue_chainer_ops.hpp:223
static constexpr index_t NumMXdlPerWavePerShuffle
Definition: cshuffle_epilogue_chainer_ops.hpp:326
static constexpr index_t NWave
Definition: cshuffle_epilogue_chainer_ops.hpp:216
CK_TILE_DEVICE auto operator()([[maybe_unused]] OutDramWindow &out_window, [[maybe_unused]] const AccTile &acc_tile, const DsDramWindows &ds_windows, void *p_smem)
Definition: cshuffle_epilogue_chainer_ops.hpp:457
static constexpr index_t MWave
Definition: cshuffle_epilogue_chainer_ops.hpp:215
static constexpr CK_TILE_HOST_DEVICE auto MakeLdsBlockDescriptor()
Definition: cshuffle_epilogue_chainer_ops.hpp:357
static constexpr auto shuffle_tile_tuple
Shuffle tile configuration parameters.
Definition: cshuffle_epilogue_chainer_ops.hpp:299
static constexpr index_t isCTransposed
Definition: cshuffle_epilogue_chainer_ops.hpp:220
remove_cvref_t< typename Problem::ELayout > ELayout
Definition: cshuffle_epilogue_chainer_ops.hpp:209
std::conditional_t< std::is_same_v< BDataType, pk_int4_t >, ADataType, BDataType > BTypeToUse
Definition: cshuffle_epilogue_chainer_ops.hpp:208
remove_cvref_t< typename Problem::AccDataType > AccDataType
Definition: cshuffle_epilogue_chainer_ops.hpp:185
static constexpr bool FixedVectorSize
Definition: cshuffle_epilogue_chainer_ops.hpp:221
static constexpr auto MNPerIterationShuffle
Definition: cshuffle_epilogue_chainer_ops.hpp:330
static constexpr CK_TILE_DEVICE auto MakeLdsDistributionEncode()
Definition: cshuffle_epilogue_chainer_ops.hpp:379
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: cshuffle_epilogue_chainer_ops.hpp:412
static constexpr index_t VectorSizeC
Definition: cshuffle_epilogue_chainer_ops.hpp:224
remove_cvref_t< std::tuple_element_t< number< 0 >{}, AsDataTypeTuple > > ADataType
Definition: cshuffle_epilogue_chainer_ops.hpp:201
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsDataTypeTuple > > BDataType
Definition: cshuffle_epilogue_chainer_ops.hpp:202
static constexpr index_t NPerXdl
Definition: cshuffle_epilogue_chainer_ops.hpp:218
static constexpr bool ADataTypeIsTuple
Definition: cshuffle_epilogue_chainer_ops.hpp:190
static constexpr memory_operation_enum MemoryOperation
Definition: cshuffle_epilogue_chainer_ops.hpp:211
static constexpr index_t MPerIterationShuffle
Definition: cshuffle_epilogue_chainer_ops.hpp:338
remove_cvref_t< typename Problem::DsDataType > DsDataType
Definition: cshuffle_epilogue_chainer_ops.hpp:187
typename WG::CWarpDstr CWarpDstr
Definition: cshuffle_epilogue_chainer_ops.hpp:349
static constexpr index_t NumNXdlPerWavePerShuffle
Definition: cshuffle_epilogue_chainer_ops.hpp:327
static constexpr index_t MPerIteration
Definition: cshuffle_epilogue_chainer_ops.hpp:225
remove_cvref_t< typename Problem::AsDataType > AsDataType
Definition: cshuffle_epilogue_chainer_ops.hpp:183
remove_cvref_t< Problem_ > Problem
Definition: cshuffle_epilogue_chainer_ops.hpp:182
std::conditional_t< std::is_same_v< ADataType, pk_int4_t >, BDataType, ADataType > ATypeToUse
Definition: cshuffle_epilogue_chainer_ops.hpp:205
remove_cvref_t< typename Problem::CDElementwise > CDElementwise
Definition: cshuffle_epilogue_chainer_ops.hpp:210
static constexpr index_t MPerXdl
Definition: cshuffle_epilogue_chainer_ops.hpp:217
static constexpr index_t KPerXdl
Definition: cshuffle_epilogue_chainer_ops.hpp:219
static constexpr bool TiledMMAPermuteN
Definition: cshuffle_epilogue_chainer_ops.hpp:222
remove_cvref_t< typename Problem::DsLayout > DsLayout
Definition: cshuffle_epilogue_chainer_ops.hpp:188
static constexpr index_t NPerIterationShuffle
Definition: cshuffle_epilogue_chainer_ops.hpp:339
static constexpr index_t kMPerBlock
Definition: cshuffle_epilogue_chainer_ops.hpp:213
WarpGemmDispatcher< ATypeToUse, BTypeToUse, AccDataType, MPerXdl, NPerXdl, KPerXdl, isCTransposed > WG
Definition: cshuffle_epilogue_chainer_ops.hpp:347
typename WG::CWarpDstrEncoding CWarpDstrEncoding
Definition: cshuffle_epilogue_chainer_ops.hpp:351
static constexpr index_t NPerIteration
Definition: cshuffle_epilogue_chainer_ops.hpp:226
static constexpr bool BDataTypeIsTuple
Definition: cshuffle_epilogue_chainer_ops.hpp:191
typename WG::CWarpTensor CWarpTensor
Definition: cshuffle_epilogue_chainer_ops.hpp:350
remove_cvref_t< typename Problem::BsDataType > BsDataType
Definition: cshuffle_epilogue_chainer_ops.hpp:184
static constexpr CK_TILE_HOST_DEVICE index_t GetVectorSizeD(number< I > index)
Get the vector store size for Di tensor.
Definition: cshuffle_epilogue_chainer_ops.hpp:271
static constexpr CK_TILE_HOST_DEVICE index_t GetVectorSizeC()
Get the vector store size for C tensor.
Definition: cshuffle_epilogue_chainer_ops.hpp:242
Problem configuration for CShuffle epilogue chainer operations.
Definition: cshuffle_epilogue_chainer_ops.hpp:149
static constexpr index_t VectorSizeC
Definition: cshuffle_epilogue_chainer_ops.hpp:169
static constexpr index_t isCTransposed
Definition: cshuffle_epilogue_chainer_ops.hpp:166
remove_cvref_t< ODataType_ > ODataType
Definition: cshuffle_epilogue_chainer_ops.hpp:153
static constexpr index_t MWave
Definition: cshuffle_epilogue_chainer_ops.hpp:161
remove_cvref_t< AsDataType_ > AsDataType
Definition: cshuffle_epilogue_chainer_ops.hpp:150
remove_cvref_t< CDElementwise_ > CDElementwise
Definition: cshuffle_epilogue_chainer_ops.hpp:157
static constexpr index_t MPerXdl
Definition: cshuffle_epilogue_chainer_ops.hpp:163
static constexpr index_t KPerXdl
Definition: cshuffle_epilogue_chainer_ops.hpp:165
static constexpr index_t NWave
Definition: cshuffle_epilogue_chainer_ops.hpp:162
static constexpr memory_operation_enum MemoryOperation
Definition: cshuffle_epilogue_chainer_ops.hpp:167
remove_cvref_t< AccDataType_ > AccDataType
Definition: cshuffle_epilogue_chainer_ops.hpp:152
static constexpr index_t BlockedXDLN_PerWarp
Definition: cshuffle_epilogue_chainer_ops.hpp:170
static constexpr index_t kNumWaveGroups
Definition: cshuffle_epilogue_chainer_ops.hpp:172
static constexpr bool FixedVectorSize
Definition: cshuffle_epilogue_chainer_ops.hpp:168
static constexpr index_t NPerXdl
Definition: cshuffle_epilogue_chainer_ops.hpp:164
static constexpr index_t kNPerBlock
Definition: cshuffle_epilogue_chainer_ops.hpp:160
remove_cvref_t< DsLayout_ > DsLayout
Definition: cshuffle_epilogue_chainer_ops.hpp:155
remove_cvref_t< BsDataType_ > BsDataType
Definition: cshuffle_epilogue_chainer_ops.hpp:151
static constexpr index_t kBlockSize
Definition: cshuffle_epilogue_chainer_ops.hpp:158
remove_cvref_t< DsDataType_ > DsDataType
Definition: cshuffle_epilogue_chainer_ops.hpp:154
static constexpr index_t kMPerBlock
Definition: cshuffle_epilogue_chainer_ops.hpp:159
static constexpr index_t NumDTensor
Definition: cshuffle_epilogue_chainer_ops.hpp:173
static constexpr bool TiledMMAPermuteN
Definition: cshuffle_epilogue_chainer_ops.hpp:171
remove_cvref_t< ELayout_ > ELayout
Definition: cshuffle_epilogue_chainer_ops.hpp:156
Scale working tile using tensor windows (CShuffle-specific)
Definition: cshuffle_epilogue_chainer_ops.hpp:79
CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow &out_window, [[maybe_unused]] const AccTile &acc_tile, [[maybe_unused]] const AuxWindows &aux_windows, [[maybe_unused]] void *p_smem, IAccess iAccess, Context &context, const ScaleRowTensor &scale_row_tensor, const ScaleColTensor &scale_col_tensor)
Definition: cshuffle_epilogue_chainer_ops.hpp:87
Slice accumulator tile for CShuffle epilogue.
Definition: cshuffle_epilogue_chainer_ops.hpp:39
CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow &out_window, const AccTile &acc_tile, [[maybe_unused]] const AuxWindows &aux_windows, [[maybe_unused]] void *p_smem, IAccess iAccess, Context &context)
Definition: cshuffle_epilogue_chainer_ops.hpp:45
Definition: integral_constant.hpp:13
static constexpr value_type value
Definition: integral_constant.hpp:16
Definition: unary_element_wise_operation.hpp:509
Definition: sequence.hpp:49
Definition: space_filling_curve.hpp:20
Class creating 2D static tile distribution with different load/store patterns.
Definition: static_encoding_pattern.hpp:130
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192