/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp Source File
cshuffle_epilogue.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
10 
11 #include <optional>
12 #include <type_traits>
13 
14 namespace ck_tile {
15 template <typename AsDataType_,
16  typename BsDataType_,
17  typename DsDataType_,
18  typename AccDataType_,
19  typename ODataType_,
20  typename DsLayout_,
21  typename ELayout_,
22  typename CDElementwise_,
23  index_t kM_,
24  index_t kN_,
25  index_t MWave_,
26  index_t NWave_,
27  index_t MPerXdl_,
28  index_t NPerXdl_,
29  index_t KPerXdl_,
30  bool isCTransposed_,
31  memory_operation_enum MemoryOperation_,
32  index_t kNumWaveGroups_ = 1,
33  bool FixedVectorSize_ = false,
34  index_t VectorSizeC_ = 1,
35  bool TiledMMAPermuteN_ = false>
37 {
46  static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size();
47  static constexpr index_t kMPerBlock = kM_;
48  static constexpr index_t kNPerBlock = kN_;
49  static constexpr index_t MWave = MWave_;
50  static constexpr index_t NWave = NWave_;
51  static constexpr index_t MPerXdl = MPerXdl_;
52  static constexpr index_t NPerXdl = NPerXdl_;
53  static constexpr index_t KPerXdl = KPerXdl_;
54  static constexpr index_t isCTransposed = isCTransposed_;
55  static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
56  static constexpr bool FixedVectorSize = FixedVectorSize_;
57  static constexpr index_t VectorSizeC = VectorSizeC_;
58  static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_;
59  static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
60  static constexpr index_t NumDTensor = DsDataType::size();
61 
62  static_assert(NumDTensor == DsLayout::size(),
63  "The size of DsDataType and DsLayout should be the same");
64 };
65 
66 template <typename Problem_, typename Policy_ = void>
68 {
76 
79 
83 
87 
90 
91  using ATypeToUse =
92  std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
93  // Used for weight-only quantization kernel, B would be dequantized to the same data type as A
94  using BTypeToUse =
95  std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
98  static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation;
99  static constexpr index_t kBlockSize = Problem::kBlockSize;
100  static constexpr index_t kMPerBlock = Problem::kMPerBlock;
101  static constexpr index_t kNPerBlock = Problem::kNPerBlock;
102  static constexpr index_t MWave = Problem::MWave;
103  static constexpr index_t NWave = Problem::NWave;
104  static constexpr index_t MPerXdl = Problem::MPerXdl;
105  static constexpr index_t NPerXdl = Problem::NPerXdl;
106  static constexpr index_t KPerXdl = Problem::KPerXdl;
107  static constexpr index_t isCTransposed = Problem::isCTransposed;
108  static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
109  static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
110  static constexpr index_t VectorSizeC = Problem::VectorSizeC;
111  static constexpr index_t MPerIteration = MPerXdl * MWave;
112  static constexpr index_t NPerIteration = NPerXdl * NWave;
113  static constexpr index_t NumDTensor = Problem::NumDTensor;
114  static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave);
115  static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave);
116 
117  static_assert(NumDTensor == DsLayout::size(),
118  "The size of DsDataType and DsLayout should be the same");
130  {
131  if constexpr(FixedVectorSize)
132  {
133  return VectorSizeC;
134  }
135  constexpr index_t max_vector_size = 16;
136  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
137  {
138  return std::min(static_cast<int>(NPerIteration),
139  static_cast<int>(max_vector_size / sizeof(ODataType)));
140  }
141  else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
142  {
143  return std::min(static_cast<int>(MPerIteration),
144  static_cast<int>(max_vector_size / sizeof(ODataType)));
145  }
146  else
147  {
148  static_assert(false, "Unsupported ELayout!");
149  }
150  }
151 
157  template <index_t I>
159  {
160  constexpr index_t max_vector_size = 16;
161  using DiDataType = remove_cvref_t<std::tuple_element_t<index.value, DsDataType>>;
162  using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
163  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
164  {
165  return std::min(static_cast<int>(NPerIteration),
166  static_cast<int>(max_vector_size / sizeof(DiDataType)));
167  }
168  else if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::ColumnMajor>)
169  {
170  return std::min(static_cast<int>(MPerIteration),
171  static_cast<int>(max_vector_size / sizeof(DiDataType)));
172  }
173  else
174  {
175  static_assert(false, "Unsupported DLayout!");
176  }
177  return max_vector_size / sizeof(DiDataType);
178  }
/**
 * @brief (M, N) pair giving how many XDL tiles per wave are grouped into one shuffle.
 *
 * elem_per_thread is MPerXdl * NPerXdl / get_warp_size(). If that already reaches
 * GetVectorSizeC(), the pair is (1, 1). Otherwise GetVectorSizeC() / elem_per_thread
 * tiles are grouped: along M when ELayout is RowMajor (capped by
 * kMPerBlock / (MPerXdl * MWave)), along N for every other ELayout (capped by
 * kNPerBlock / (NPerXdl * NWave)). The static_asserts reject block sizes that the
 * grouping cannot divide evenly.
 */
187  static constexpr auto shuffle_tile_tuple = [] {
188  constexpr index_t elem_per_thread = MPerXdl * NPerXdl / get_warp_size();
// One thread already holds at least a full vector: no grouping needed.
189  if constexpr(elem_per_thread >= GetVectorSizeC())
190  {
191  return std::make_tuple(1, 1);
192  }
193  else
194  {
// Number of XDL tiles needed to fill one vector of GetVectorSizeC() elements.
195  constexpr index_t num_xdl_shuffles = GetVectorSizeC() / elem_per_thread;
196  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
197  {
198  static_assert((kMPerBlock % (MPerXdl * MWave) == 0) &&
199  (kMPerBlock % num_xdl_shuffles == 0),
200  "kMPerBlock must be divisible by MPerXdl*MWave and "
201  "num_xdl_shuffles for CShuffleEpilogue");
202  return std::make_tuple(min(num_xdl_shuffles, kMPerBlock / (MPerXdl * MWave)), 1);
203  }
204  else
205  {
206  static_assert((kNPerBlock % (NPerXdl * NWave) == 0) &&
207  (kNPerBlock % num_xdl_shuffles == 0),
208  "kNPerBlock must be divisible by NPerXdl*NWave and "
209  "num_xdl_shuffles for CShuffleEpilogue");
210  return std::make_tuple(1, min(num_xdl_shuffles, kNPerBlock / (NPerXdl * NWave)));
211  }
212  }
213  }();
214  static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple);
215  static constexpr index_t NumNXdlPerWavePerShuffle = std::get<1>(shuffle_tile_tuple);
216 
/**
 * @brief (M, N) extent covered by one shuffle iteration.
 *
 * m_val / n_val scale the per-iteration XDL extent by the number of XDL tiles grouped
 * per shuffle. When both divide the block tile evenly the pair (m_val, n_val) is used.
 * NOTE(review): listing line 221 (the value returned when they do not divide evenly)
 * is missing from this extract - confirm against the header.
 */
217  static constexpr auto MNPerIterationShuffle = [] {
218  constexpr index_t m_val = MPerXdl * MWave * NumMXdlPerWavePerShuffle;
219  constexpr index_t n_val = NPerXdl * NWave * NumNXdlPerWavePerShuffle;
220  if constexpr(kMPerBlock % m_val != 0 || kNPerBlock % n_val != 0)
222  else
223  return std::make_tuple(m_val, n_val);
224  }();
225  static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle);
226  static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle);
227 
229  BTypeToUse,
230  AccDataType,
231  MPerXdl,
232  NPerXdl,
233  KPerXdl,
234  isCTransposed>;
235 
236  using CWarpDstr = typename WG::CWarpDstr;
237  using CWarpTensor = typename WG::CWarpTensor;
238  using CWarpDstrEncoding = typename WG::CWarpDstrEncoding;
242 
243  template <typename Problem>
245  {
246  // N is contiguous dimension
247  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
248  {
252  }
253  // M is contiguous dimension
254  else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
255  {
259  }
260  else
261  {
262  static_assert(false, "Unsupported ELayout!");
263  }
264  }
265 
267  {
268  constexpr auto block_outer_dstr_encoding =
275  sequence<0, 0>>{};
276  constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
277  block_outer_dstr_encoding, typename CWarpDstr::DstrEncode{});
278 
279  return block_dstr_encoding;
280  }
281 
283  {
285  }
286 
/**
 * @brief Multiply lds_tile in place by the M and N scales, if any.
 *
 * The scale types select one of three compile-time paths:
 *  - both are EmptyScale: nothing is done;
 *  - both are AccDataType: every element is multiplied by the two scalars;
 *  - anything else: both arguments are passed to load_tile() and afterwards advanced
 *    with move_tile_window(), so they are assumed to be tile windows.
 *
 * @tparam iAccess Index of the current access; the scale windows are not advanced when
 *                 it equals SFC::get_num_of_access() - 1.
 * @param lds_tile       Tile that is scaled in place.
 * @param scale_m_window M scale (EmptyScale, AccDataType scalar, or window).
 * @param scale_n_window N scale (EmptyScale, AccDataType scalar, or window).
 */
287  template <index_t iAccess, typename LdsTile, typename ScaleM, typename ScaleN>
288  CK_TILE_DEVICE void
289  scale_tile(LdsTile& lds_tile, ScaleM& scale_m_window, ScaleN& scale_n_window)
290  {
291  // Check if scales are EmptyScale first (no scaling needed)
292  if constexpr(std::is_same_v<ScaleM, EmptyScale> && std::is_same_v<ScaleN, EmptyScale>)
293  {
294  // No scaling needed - this is a no-op
295  }
296  // Check if scales are scalar AccDataType
297  else if constexpr(std::is_same_v<ScaleM, AccDataType> &&
298  std::is_same_v<ScaleN, AccDataType>)
299  {
300  // Handle scalar scales: the "windows" are plain values in this branch
301  const AccDataType scale_m = scale_m_window;
302  const AccDataType scale_n = scale_n_window;
303  tile_elementwise_inout([&](auto& element) { element = element * scale_m * scale_n; },
304  lds_tile);
305  }
306  // Otherwise, assume they are tile windows that can be loaded
// (a mix of one EmptyScale/scalar and one window also lands here)
307  else
308  {
309  // Load tiles
310  const auto scale_m_tile = load_tile(scale_m_window);
311  const auto scale_n_tile = load_tile(scale_n_window);
312 
313  // Compute element-wise product in-place i.e. lds_tile = lds_tile * scale_m * scale_n
// NOTE(review): listing line 314, which names the function receiving the arguments
// below, is missing from this extract.
315  element_wise::MultiDMultiply{}, lds_tile, lds_tile, scale_m_tile, scale_n_tile);
316 
317  // Move scale windows to the next access, except after the last one
318  constexpr index_t num_access = SFC::get_num_of_access();
319  if constexpr(iAccess != num_access - 1)
320  {
321  constexpr auto step = SFC::get_forward_step(number<iAccess>{});
322 
// Both scale windows advance by the same two step components.
323  move_tile_window(scale_m_window, {step.at(number<0>{}), step.at(number<1>{})});
324  move_tile_window(scale_n_window, {step.at(number<0>{}), step.at(number<1>{})});
325  }
326  }
327  }
328 
/**
 * @brief Copy the part of o_acc_tile that belongs to access iAccess into lds_tile's
 *        thread buffer.
 *
 * The start index of the access comes from SFC::get_index(); dividing its two components
 * by MPerIterationShuffle / NPerIterationShuffle yields mIter / nIter.
 * NOTE(review): listing lines 341, 342 and 344 (parts of the slice origin and slice
 * length arguments) are missing from this extract, so how mIter / nIter enter the
 * slice cannot be read here - confirm against the header.
 *
 * @tparam iAccess    Index of the current access.
 * @param o_acc_tile  Source accumulator tile (read only).
 * @param lds_tile    Destination tile whose thread buffer is overwritten.
 */
329  template <index_t iAccess, typename OAccTile, typename LdsTile>
330  CK_TILE_DEVICE void slice_acc_tile(const OAccTile& o_acc_tile, LdsTile& lds_tile)
331  {
332  constexpr auto idx_y_start = SFC::get_index(number<iAccess>{});
333 
334  constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
335  constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
// Y lengths of the warp-level C distribution, and an all-zero index of the same rank.
336  constexpr auto c_warp_y_lengths =
337  to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
338  constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
339 
340  lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
343  c_warp_y_index_zeros),
345  c_warp_y_lengths));
346  }
347 
/**
 * @brief Convert lds_tile to ODataType and store the result through in_lds_window.
 *
 * @param lds_tile       Tile to convert; it is not modified.
 * @param in_lds_window  Destination window for store_tile(). The caller in this file
 *                       passes a window over a tensor view created with
 *                       address_space_enum::lds.
 */
348  template <typename LdsTile, typename InLdsWindow>
349  CK_TILE_DEVICE void cast_lds_tile(LdsTile& lds_tile, InLdsWindow& in_lds_window)
350  {
351  const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile);
352 
353  store_tile(in_lds_window, c_warptile_in_tensor_casted);
354  }
355 
/**
 * @brief Load one tile per D window and apply Problem::CDElementwise to c_out_tensor
 *        together with those tiles.
 *
 * @param d_dram_windows  Indexable collection of NumDTensor windows; each is read with
 *                        load_tile().
 * @param c_out_tensor    Tensor handed to the element-wise functor; it is passed by
 *                        non-const reference and appears twice in the argument tuple.
 */
356  template <typename DramWindows, typename COutTensor>
357  CK_TILE_DEVICE void apply_d_tensors(DramWindows& d_dram_windows, COutTensor& c_out_tensor)
358  {
// One loaded tile per D tensor.
359  const auto ds_tensor = generate_tuple(
360  [&](auto idx) { return load_tile(d_dram_windows[idx]); }, number<NumDTensor>{});
361 
// Tuple of references: (c_out_tensor, c_out_tensor, D0, D1, ...).
// NOTE(review): presumably the first entry is the functor's output and the second its
// C input - confirm against the CDElementwise calling convention.
362  const auto c_ds_tiles = concat_tuple_of_reference(
363  tie(c_out_tensor, c_out_tensor),
364  generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; },
365  number<NumDTensor>{}));
366 
367  tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
368  }
369 
/**
 * @brief Write c_out_tensor through out_dram_window.
 *
 * Uses store_tile() when MemoryOperation is memory_operation_enum::set and
 * update_tile() for every other MemoryOperation value.
 *
 * @param out_dram_window  Destination window.
 * @param c_out_tensor     Tensor to write (read only).
 */
370  template <typename OutDramWindow, typename COutTensor>
371  CK_TILE_DEVICE void store_to_dram(OutDramWindow& out_dram_window,
372  const COutTensor& c_out_tensor)
373  {
374  if constexpr(MemoryOperation == memory_operation_enum::set)
375  {
376  store_tile(out_dram_window, c_out_tensor);
377  }
378  else
379  {
380  update_tile(out_dram_window, c_out_tensor);
381  }
382  }
383 
/**
 * @brief Advance the output window and every D window to the next access.
 *
 * Nothing is moved when iAccess is the last access
 * (SFC::get_num_of_access() - 1). Otherwise all windows are moved by the two
 * components of SFC::get_forward_step(iAccess).
 *
 * @tparam iAccess         Index of the access that has just been processed.
 * @param out_dram_window  Output window, moved in place.
 * @param d_dram_windows   Indexable collection of NumDTensor windows, moved in place.
 */
387  template <index_t iAccess, typename OutDramWindow, typename DDramWindows>
388  CK_TILE_DEVICE void move_windows(OutDramWindow& out_dram_window, DDramWindows& d_dram_windows)
389  {
390  constexpr index_t num_access = SFC::get_num_of_access();
391  if constexpr(iAccess != num_access - 1)
392  {
393  constexpr auto step = SFC::get_forward_step(number<iAccess>{});
394 
395  // move the output dram window
396  move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
397 
398  // move windows for each of the D matrices (inputs for element-wise)
399  static_for<0, NumDTensor, 1>{}([&](auto idx) {
400  move_tile_window(d_dram_windows[idx], {step.at(number<0>{}), step.at(number<1>{})});
401  });
402  }
403  }
404 
405  // TODO: Check if there would be nicer ways to overload rather than with EmptyScale or nullptr_t
// Empty tag type. It is the default for the ScaleM / ScaleN template parameters of
// operator(), and scale_tile() treats a pair of EmptyScale arguments as "no scaling".
406  struct EmptyScale
407  {
408  };
409 
410  template <typename, typename = void>
412  {
413  using DataType = float;
414  };
415 
// Partial specialization chosen when T has a nested type named DataType
// (detected through std::void_t); it forwards that nested type.
416  template <typename T>
417  struct ScaleDataType<T, std::void_t<typename T::DataType>>
418  {
419  using DataType = typename T::DataType;
420  };
421 
/**
 * @brief Epilogue overload that is enabled when TiledMMAPermuteN is true.
 *
 * For each of the MRepeat iterations it slices the accumulators into shuffle_acc,
 * optionally multiplies them by the scales, converts them to ODataType into
 * c_out_tensor, stores (or updates) through out_dram_window, and then advances the
 * windows by MPerXdl * MWave in their first coordinate.
 *
 * @param out_dram_window  Output window; written and moved in place.
 * @param o_acc_tile       Accumulator tile, sliced once per M repeat.
 * @param ds_dram_windows  D windows; re-windowed with the local tile distribution.
 * @param (void*)          Unnamed here (commented name: p_smem); not used by this overload.
 * @param scale_m          Optional M scale: EmptyScale, an AccDataType scalar, or an
 *                         object accepted by make_tile_window().
 * @param scale_n          Optional N scale, same alternatives as scale_m.
 *
 * NOTE(review): d_dram_windows is only moved in this overload; no visible line loads a
 * D tile or applies CDElementwise. Confirm that ignoring the D tensors is intended.
 */
422  template <typename ODramWindow,
423  typename OAccTile,
424  typename DsDramWindows,
425  typename ScaleM = EmptyScale,
426  typename ScaleN = EmptyScale,
427  int EnablePermuateN_ = TiledMMAPermuteN,
428  std::enable_if_t<EnablePermuateN_, int> = 0>
429  CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
430  const OAccTile& o_acc_tile,
431  const DsDramWindows& ds_dram_windows,
432  void* /*p_smem*/,
433  const ScaleM& scale_m = {},
434  const ScaleN& scale_n = {})
435  {
// Factors of the M dimension (kM0 * kM1 * kM2) and of the N dimension (kN0 * kN1 * kN2)
// used by the distribution encoding below; kM2 is fixed to 4.
436  constexpr int kM0 = MWave;
437  constexpr int kM2 = 4;
438  constexpr int kM1 = MPerXdl / kM2;
439 
440  constexpr int kN0 = NWave;
441  constexpr int kN1 = NPerXdl;
442  constexpr int kN2 = NRepeat;
443 
444  using IntrThreadShuffleEncode =
445  tile_distribution_encoding<sequence<>,
446  tuple<sequence<kM0, kM1, kM2>, sequence<kN0, kN1, kN2>>,
447  tuple<sequence<1, 2>, sequence<1, 2>>,
448  tuple<sequence<0, 0>, sequence<1, 1>>,
449  sequence<1, 2>,
450  sequence<2, 2>>;
451  constexpr auto dram_tile_distribution =
452  make_static_tile_distribution(IntrThreadShuffleEncode{});
453 
// Re-window every D tensor with the distribution defined above.
454  auto d_dram_windows = generate_tuple(
455  [&](auto idx) {
456  return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
457  },
458  number<NumDTensor>{});
459 
460  constexpr auto c_warp_y_lengths =
461  to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
462  constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
463 
// shuffle_acc receives the sliced accumulators; c_out_tensor the converted output.
464  auto shuffle_acc = make_static_distributed_tensor<AccDataType>(dram_tile_distribution);
465  auto c_out_tensor = make_static_distributed_tensor<ODataType>(dram_tile_distribution);
466 
467  // Optional scales (must share the same distribution to match per-thread indexing)
468  constexpr bool has_scales =
// NOTE(review): listing line 469 (the initializer of has_scales) is missing from this
// extract.
470  constexpr bool has_scalar_scales =
471  std::is_same_v<ScaleM, AccDataType> && std::is_same_v<ScaleN, AccDataType>;
472 
473  // Tiles to hold row/col scales when present
474  using SMType = typename ScaleDataType<ScaleM>::DataType;
475  using SNType = typename ScaleDataType<ScaleN>::DataType;
476 
477  auto sm_tile = make_static_distributed_tensor<SMType>(dram_tile_distribution);
478  auto sn_tile = make_static_distributed_tensor<SNType>(dram_tile_distribution);
479 
480  // Build windows only if non-scalar scales are provided
481  auto scale_m_window = [&]() {
482  if constexpr(has_scales && !has_scalar_scales)
483  {
484  return make_tile_window(scale_m, dram_tile_distribution);
485  }
486  else
487  {
488  return EmptyScale{};
489  }
490  }();
491  auto scale_n_window = [&]() {
492  if constexpr(has_scales && !has_scalar_scales)
493  {
494  return make_tile_window(scale_n, dram_tile_distribution);
495  }
496  else
497  {
498  return EmptyScale{};
499  }
500  }();
501 
502  static_for<0, MRepeat, 1>{}([&](auto mIter) {
503  // Slice accumulators for this M repeat into the permuted layout
504  shuffle_acc.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
505  merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
506  merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
507 
508  // If non-scalar scales provided, load them with identical distribution
// NOTE(review): the scale windows are not moved anywhere in this loop, so every M
// repeat loads from the same window position - confirm this is intended.
509  if constexpr(has_scales && !has_scalar_scales)
510  {
511  sm_tile = load_tile(scale_m_window); // row scales in permuted layout
512  sn_tile = load_tile(scale_n_window); // col scales in permuted layout
513  }
514 
515  // For every N repeat, write four scaled and converted values into c_out_tensor
516  static_for<0, NRepeat, 1>{}([&](auto n_idx) {
517  // source indices in shuffle_acc: (n_idx * product(Y) + row)
518  const index_t base = n_idx * c_warp_y_lengths.product();
519 
520  // local lambda to fuse scale (if present) and convert
521  auto emit = [&](index_t out_idx, index_t src_row) {
522  AccDataType v = shuffle_acc.get_thread_buffer()[base + src_row];
523 
524  if constexpr(has_scalar_scales)
525  {
526  v = static_cast<AccDataType>(v * scale_m * scale_n);
527  }
528  else if constexpr(has_scales && !has_scalar_scales)
529  {
530  // same linear index mapping on the permuted distribution
// (scale values are converted to float, not to AccDataType, before multiplying)
531  const auto s_m = static_cast<float>(sm_tile.get_thread_buffer()[out_idx]);
532  const auto s_n = static_cast<float>(sn_tile.get_thread_buffer()[out_idx]);
533  v = static_cast<AccDataType>(v * s_m * s_n);
534  }
535 
536  c_out_tensor.get_thread_buffer()[out_idx] = type_convert<ODataType>(v);
537  };
538 
539  // Source rows 0..3 are written to out_idx = n_idx + row * NRepeat
540  emit(n_idx + 0 * NRepeat, 0);
541  emit(n_idx + 1 * NRepeat, 1);
542  emit(n_idx + 2 * NRepeat, 2);
543  emit(n_idx + 3 * NRepeat, 3);
544  });
545 
546  // store/update
547  if constexpr(MemoryOperation == memory_operation_enum::set)
548  {
549  store_tile(out_dram_window, c_out_tensor);
550  }
551  else
552  {
553  update_tile(out_dram_window, c_out_tensor);
554  }
555 
556  // advance output (and any D-tensors) by one MPerXdl*MWave chunk
557  move_tile_window(out_dram_window, {number<MPerXdl * MWave>{}, number<0>{}});
558  static_for<0, NumDTensor, 1>{}([&](auto idx) {
559  move_tile_window(d_dram_windows[idx], {number<MPerXdl * MWave>{}, number<0>{}});
560  });
561  });
562  }
563 
/**
 * @brief Epilogue overload that is enabled when TiledMMAPermuteN is false; it routes
 *        the accumulators through the buffer at p_smem before writing them out.
 *
 * For every access of SFC it: synchronizes, slices the accumulators into lds_tile,
 * optionally scales them, casts and stores them through in_lds_window, synchronizes
 * again, reloads the tile through out_lds_window with dram_tile_distribution, applies
 * the D tensors, stores through out_dram_window and advances the windows.
 * A static_assert restricts this overload to a RowMajor ELayout.
 *
 * @param out_dram_window  Output window; written and moved in place.
 * @param o_acc_tile       Accumulator tile, sliced once per access.
 * @param ds_dram_windows  D windows; re-windowed with dram_tile_distribution.
 * @param p_smem           Buffer reinterpreted as ODataType* and viewed with
 *                         address_space_enum::lds.
 * @param scale_m          Optional M scale: EmptyScale, an AccDataType scalar, or an
 *                         object accepted by make_tile_window().
 * @param scale_n          Optional N scale, same alternatives as scale_m.
 */
564  template <typename ODramWindow,
565  typename OAccTile,
566  typename DsDramWindows,
567  typename ScaleM = EmptyScale,
568  typename ScaleN = EmptyScale,
569  int EnablePermuateN_ = TiledMMAPermuteN,
570  std::enable_if_t<!EnablePermuateN_, int> = 0>
571  CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
572  const OAccTile& o_acc_tile,
573  const DsDramWindows& ds_dram_windows,
574  void* p_smem,
575  const ScaleM& scale_m = {},
576  const ScaleN& scale_n = {})
577  {
578  constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
579 
// Per-access staging tile in accumulator precision.
580  auto lds_tile = make_static_distributed_tensor<AccDataType>(LdsTileDistr);
581 
582  constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
583  auto o_lds_block = make_tensor_view<address_space_enum::lds>(
584  static_cast<ODataType*>(p_smem), lds_block_desc);
585 
// Two windows over the same view: in_lds_window carries LdsTileDistr and is used for
// the store; out_lds_window has no distribution and is re-windowed for the load.
586  auto in_lds_window = make_tile_window(
587  o_lds_block,
588  make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
589  {0, 0},
590  LdsTileDistr);
591 
592  auto out_lds_window = make_tile_window(
593  o_lds_block,
594  make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
595  {0, 0});
596 
597  constexpr index_t num_access = SFC::get_num_of_access();
598 
599  static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
600  "Currently, the CShuffle Epilogue only supports the Row Major Output layout");
601 
// NOTE(review): listing lines 604, 605 and 607 (three of the template arguments) are
// missing from this extract.
602  using TileEncodingPattern =
603  tile_distribution_encoding_pattern_2d<kBlockSize,
606  GetVectorSizeC(),
608  Problem::kNumWaveGroups>;
609  constexpr auto dram_tile_distribution =
610  TileEncodingPattern::make_2d_static_tile_distribution();
611 
// Re-window every D tensor with the output distribution.
612  auto d_dram_windows = generate_tuple(
613  [&](auto idx) {
614  return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
615  },
616  number<NumDTensor>{});
617 
// has_scales requires BOTH scales to be non-empty; has_scalar_scales requires both to
// be AccDataType.
618  constexpr bool has_scales =
619  !std::is_same_v<ScaleM, EmptyScale> && !std::is_same_v<ScaleN, EmptyScale>;
620  constexpr bool has_scalar_scales =
621  std::is_same_v<ScaleM, AccDataType> && std::is_same_v<ScaleN, AccDataType>;
// Scalars are forwarded as values, other scales become windows with lds_tile's
// distribution, and absent scales become EmptyScale.
622  auto scale_m_window = [&]() {
623  if constexpr(has_scalar_scales)
624  {
625  return scale_m;
626  }
627  else if constexpr(has_scales)
628  {
629  return make_tile_window(scale_m, lds_tile.get_tile_distribution());
630  }
631  else
632  {
633  return EmptyScale{};
634  }
635  }();
636  auto scale_n_window = [&]() {
637  if constexpr(has_scalar_scales)
638  {
639  return scale_n;
640  }
641  else if constexpr(has_scales)
642  {
643  return make_tile_window(scale_n, lds_tile.get_tile_distribution());
644  }
645  else
646  {
647  return EmptyScale{};
648  }
649  }();
650 
651  static_for<0, num_access, 1>{}([&](auto iAccess) {
// Synchronize before the buffer behind in_lds_window is overwritten.
652  block_sync_lds();
653  slice_acc_tile<iAccess>(o_acc_tile, lds_tile);
654 
655  if constexpr(has_scales)
656  {
657  scale_tile<iAccess>(lds_tile, scale_m_window, scale_n_window);
658  }
659 
660  cast_lds_tile(lds_tile, in_lds_window);
// Synchronize between the store above and the load below.
661  block_sync_lds();
662 
663  auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
664 
665  apply_d_tensors(d_dram_windows, c_out_tensor);
666  store_to_dram(out_dram_window, c_out_tensor);
667  move_windows<iAccess>(out_dram_window, d_dram_windows);
668  });
669  }
670 };
671 } // namespace ck_tile
CK_TILE_DEVICE void block_sync_lds()
Definition: arch.hpp:192
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
constexpr CK_TILE_HOST_DEVICE auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition: tile_distribution_encoding.hpp:457
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_descriptor.hpp:274
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition: tile_elementwise.hpp:23
constexpr tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:376
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto generate_tie(F &&f, number< N >)
Definition: tuple.hpp:435
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
CK_TILE_DEVICE auto tile_elementwise_inout_unpack(const InElementFunc &in_element_func, const Tuple &t, std::index_sequence< I... >)
Template function that "unpacks" a tuple and applies an element-wise operation.
Definition: tile_elementwise.hpp:71
@ thread_raked
Thread raked pattern.
constexpr CK_TILE_HOST_DEVICE auto to_sequence(tuple< number< Is >... >)
Definition: sequence.hpp:1055
constexpr CK_TILE_HOST_DEVICE auto merge_sequences(Seqs...)
Definition: sequence.hpp:826
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition: type_traits.hpp:67
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:95
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE auto concat_tuple_of_reference(const tuple< X &... > &tx, const tuple< Y &... > &ty)
Definition: tuple.hpp:443
CK_TILE_DEVICE void update_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: update_tile.hpp:22
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:23
constexpr CK_TILE_HOST_DEVICE T min(T x)
Definition: math.hpp:210
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:22
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:480
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition: warp_gemm_dispatcher.hpp:184
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition: sequence.hpp:1026
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1350
Definition: cshuffle_epilogue.hpp:407
typename T::DataType DataType
Definition: cshuffle_epilogue.hpp:419
Definition: cshuffle_epilogue.hpp:412
float DataType
Definition: cshuffle_epilogue.hpp:413
Definition: cshuffle_epilogue.hpp:68
static constexpr index_t kBlockSize
Definition: cshuffle_epilogue.hpp:99
CK_TILE_DEVICE void scale_tile(LdsTile &lds_tile, ScaleM &scale_m_window, ScaleN &scale_n_window)
Definition: cshuffle_epilogue.hpp:289
std::conditional_t< ADataTypeIsTuple, remove_cvref_t< AsDataType >, remove_cvref_t< tuple< AsDataType > >> AsDataTypeTuple
Definition: cshuffle_epilogue.hpp:82
static constexpr index_t NRepeat
Definition: cshuffle_epilogue.hpp:115
CK_TILE_DEVICE void slice_acc_tile(const OAccTile &o_acc_tile, LdsTile &lds_tile)
Definition: cshuffle_epilogue.hpp:330
static constexpr CK_TILE_HOST_DEVICE auto MakeLdsBlockDescriptor()
Definition: cshuffle_epilogue.hpp:244
static constexpr index_t MRepeat
Definition: cshuffle_epilogue.hpp:114
typename WG::CWarpTensor CWarpTensor
Definition: cshuffle_epilogue.hpp:237
typename WG::CWarpDstrEncoding CWarpDstrEncoding
Definition: cshuffle_epilogue.hpp:238
remove_cvref_t< typename Problem::AsDataType > AsDataType
Definition: cshuffle_epilogue.hpp:70
remove_cvref_t< Problem_ > Problem
Definition: cshuffle_epilogue.hpp:69
static constexpr index_t MPerXdl
Definition: cshuffle_epilogue.hpp:104
static constexpr bool FixedVectorSize
Definition: cshuffle_epilogue.hpp:108
static constexpr CK_TILE_HOST_DEVICE index_t GetVectorSizeD(number< I > index)
Get the vector store size for Di tensor.
Definition: cshuffle_epilogue.hpp:158
remove_cvref_t< typename Problem::ODataType > ODataType
Definition: cshuffle_epilogue.hpp:73
CK_TILE_DEVICE void store_to_dram(OutDramWindow &out_dram_window, const COutTensor &c_out_tensor)
Definition: cshuffle_epilogue.hpp:371
static constexpr bool ADataTypeIsTuple
Definition: cshuffle_epilogue.hpp:77
static constexpr index_t kNPerBlock
Definition: cshuffle_epilogue.hpp:101
remove_cvref_t< typename Problem::ELayout > ELayout
Definition: cshuffle_epilogue.hpp:96
static constexpr memory_operation_enum MemoryOperation
Definition: cshuffle_epilogue.hpp:98
static constexpr bool TiledMMAPermuteN
Definition: cshuffle_epilogue.hpp:109
static constexpr bool BDataTypeIsTuple
Definition: cshuffle_epilogue.hpp:78
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsDataTypeTuple > > BDataType
Definition: cshuffle_epilogue.hpp:89
remove_cvref_t< typename Problem::DsLayout > DsLayout
Definition: cshuffle_epilogue.hpp:75
static constexpr CK_TILE_DEVICE auto MakeLdsDistributionEncode()
Definition: cshuffle_epilogue.hpp:266
static constexpr index_t MPerIteration
Definition: cshuffle_epilogue.hpp:111
static constexpr auto MNPerIterationShuffle
Definition: cshuffle_epilogue.hpp:217
static constexpr index_t isCTransposed
Definition: cshuffle_epilogue.hpp:107
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: cshuffle_epilogue.hpp:282
CK_TILE_DEVICE void apply_d_tensors(DramWindows &d_dram_windows, COutTensor &c_out_tensor)
Definition: cshuffle_epilogue.hpp:357
static constexpr index_t MWave
Definition: cshuffle_epilogue.hpp:102
static constexpr CK_TILE_HOST_DEVICE index_t GetVectorSizeC()
Get the vector store size for C tensor.
Definition: cshuffle_epilogue.hpp:129
static constexpr index_t VectorSizeC
Definition: cshuffle_epilogue.hpp:110
remove_cvref_t< typename Problem::DsDataType > DsDataType
Definition: cshuffle_epilogue.hpp:74
std::conditional_t< std::is_same_v< BDataType, pk_int4_t >, ADataType, BDataType > BTypeToUse
Definition: cshuffle_epilogue.hpp:95
CK_TILE_DEVICE void move_windows(OutDramWindow &out_dram_window, DDramWindows &d_dram_windows)
Move both the output and D tensors windows for the next access.
Definition: cshuffle_epilogue.hpp:388
remove_cvref_t< typename Problem::CDElementwise > CDElementwise
Definition: cshuffle_epilogue.hpp:97
static constexpr index_t NPerIterationShuffle
Definition: cshuffle_epilogue.hpp:226
remove_cvref_t< typename Problem::AccDataType > AccDataType
Definition: cshuffle_epilogue.hpp:72
CK_TILE_DEVICE auto operator()(ODramWindow &out_dram_window, const OAccTile &o_acc_tile, const DsDramWindows &ds_dram_windows, void *, const ScaleM &scale_m={}, const ScaleN &scale_n={})
Definition: cshuffle_epilogue.hpp:429
static constexpr index_t NumDTensor
Definition: cshuffle_epilogue.hpp:113
static constexpr index_t KPerXdl
Definition: cshuffle_epilogue.hpp:106
CK_TILE_DEVICE auto operator()(ODramWindow &out_dram_window, const OAccTile &o_acc_tile, const DsDramWindows &ds_dram_windows, void *p_smem, const ScaleM &scale_m={}, const ScaleN &scale_n={})
Definition: cshuffle_epilogue.hpp:571
static constexpr index_t NumMXdlPerWavePerShuffle
Definition: cshuffle_epilogue.hpp:214
std::conditional_t< BDataTypeIsTuple, remove_cvref_t< BsDataType >, remove_cvref_t< tuple< BsDataType > >> BsDataTypeTuple
Definition: cshuffle_epilogue.hpp:86
static constexpr index_t NumNXdlPerWavePerShuffle
Definition: cshuffle_epilogue.hpp:215
remove_cvref_t< typename Problem::BsDataType > BsDataType
Definition: cshuffle_epilogue.hpp:71
WarpGemmDispatcher< ATypeToUse, BTypeToUse, AccDataType, MPerXdl, NPerXdl, KPerXdl, isCTransposed > WG
Definition: cshuffle_epilogue.hpp:234
static constexpr index_t MPerIterationShuffle
Definition: cshuffle_epilogue.hpp:225
CK_TILE_DEVICE void cast_lds_tile(LdsTile &lds_tile, InLdsWindow &in_lds_window)
Definition: cshuffle_epilogue.hpp:349
static constexpr auto shuffle_tile_tuple
Shuffle tile configuration parameters.
Definition: cshuffle_epilogue.hpp:187
static constexpr index_t NWave
Definition: cshuffle_epilogue.hpp:103
static constexpr index_t kMPerBlock
Definition: cshuffle_epilogue.hpp:100
static constexpr index_t NPerIteration
Definition: cshuffle_epilogue.hpp:112
static constexpr index_t NPerXdl
Definition: cshuffle_epilogue.hpp:105
typename WG::CWarpDstr CWarpDstr
Definition: cshuffle_epilogue.hpp:236
std::conditional_t< std::is_same_v< ADataType, pk_int4_t >, BDataType, ADataType > ATypeToUse
Definition: cshuffle_epilogue.hpp:92
remove_cvref_t< std::tuple_element_t< number< 0 >{}, AsDataTypeTuple > > ADataType
Definition: cshuffle_epilogue.hpp:88
Definition: cshuffle_epilogue.hpp:37
remove_cvref_t< AccDataType_ > AccDataType
Definition: cshuffle_epilogue.hpp:40
static constexpr index_t isCTransposed
Definition: cshuffle_epilogue.hpp:54
static constexpr index_t MPerXdl
Definition: cshuffle_epilogue.hpp:51
remove_cvref_t< CDElementwise_ > CDElementwise
Definition: cshuffle_epilogue.hpp:45
static constexpr bool TiledMMAPermuteN
Definition: cshuffle_epilogue.hpp:58
static constexpr index_t KPerXdl
Definition: cshuffle_epilogue.hpp:53
remove_cvref_t< AsDataType_ > AsDataType
Definition: cshuffle_epilogue.hpp:38
static constexpr index_t VectorSizeC
Definition: cshuffle_epilogue.hpp:57
static constexpr index_t kMPerBlock
Definition: cshuffle_epilogue.hpp:47
static constexpr index_t MWave
Definition: cshuffle_epilogue.hpp:49
static constexpr bool FixedVectorSize
Definition: cshuffle_epilogue.hpp:56
static constexpr index_t NumDTensor
Definition: cshuffle_epilogue.hpp:60
remove_cvref_t< ODataType_ > ODataType
Definition: cshuffle_epilogue.hpp:41
static constexpr index_t NPerXdl
Definition: cshuffle_epilogue.hpp:52
static constexpr index_t kNPerBlock
Definition: cshuffle_epilogue.hpp:48
remove_cvref_t< ELayout_ > ELayout
Definition: cshuffle_epilogue.hpp:44
remove_cvref_t< DsDataType_ > DsDataType
Definition: cshuffle_epilogue.hpp:42
static constexpr index_t NWave
Definition: cshuffle_epilogue.hpp:50
static constexpr index_t kBlockSize
Definition: cshuffle_epilogue.hpp:46
static constexpr memory_operation_enum MemoryOperation
Definition: cshuffle_epilogue.hpp:55
remove_cvref_t< DsLayout_ > DsLayout
Definition: cshuffle_epilogue.hpp:43
remove_cvref_t< BsDataType_ > BsDataType
Definition: cshuffle_epilogue.hpp:39
static constexpr index_t kNumWaveGroups
Definition: cshuffle_epilogue.hpp:59
Definition: integral_constant.hpp:13
static constexpr value_type value
Definition: integral_constant.hpp:16
Definition: unary_element_wise_operation.hpp:484
Definition: sequence.hpp:49
Definition: space_filling_curve.hpp:20
static constexpr CK_TILE_HOST_DEVICE auto get_forward_step(number< AccessIdx1d >)
Definition: space_filling_curve.hpp:70
static constexpr CK_TILE_HOST_DEVICE auto get_index(number< AccessIdx1d >)
Definition: space_filling_curve.hpp:158
static constexpr CK_TILE_HOST_DEVICE index_t get_num_of_access()
Definition: space_filling_curve.hpp:46
Definition: functional.hpp:43
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192