Composable Kernel: include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <string>
8 
9 #include "ck_tile/core.hpp"
10 #include "ck_tile/ops/common.hpp"
12 
13 namespace ck_tile {
14 
15 template <index_t NumDTensor = 0>
16 struct FlatmmHostArgs
17 {
18  CK_TILE_HOST FlatmmHostArgs() = default;
19  CK_TILE_HOST FlatmmHostArgs(const void* a_ptr_,
20  const void* b_ptr_,
21  const std::array<const void*, NumDTensor>& ds_ptr_,
22  void* e_ptr_,
23  index_t k_batch_,
24  index_t M_,
25  index_t N_,
26  index_t K_,
27  index_t stride_A_,
28  index_t stride_B_,
29  const std::array<index_t, NumDTensor>& stride_Ds_,
30  index_t stride_E_)
31  : a_ptr(a_ptr_),
32  b_ptr(b_ptr_),
33  ds_ptr(ds_ptr_),
34  e_ptr(e_ptr_),
35  M(M_),
36  N(N_),
37  K(K_),
38  stride_A(stride_A_),
39  stride_B(stride_B_),
40  stride_Ds(stride_Ds_),
41  stride_E(stride_E_),
42  k_batch(k_batch_)
43  {
44  }
45 
46  const void* a_ptr;
47  const void* b_ptr;
48  const std::array<const void*, NumDTensor> ds_ptr;
49  union
50  {
51  void* e_ptr;
52  void* c_ptr;
53  };
54  index_t M;
55  index_t N;
56  index_t K;
57  index_t stride_A;
58  index_t stride_B;
59  const std::array<index_t, NumDTensor> stride_Ds;
60  union
61  {
62  index_t stride_E;
63  index_t stride_C;
64  };
65 
66  index_t k_batch;
67 };
68 
69 template <index_t NumDTensor = 0>
70 struct FlatmmKernelArgs
71 {
72  const void* a_ptr;
73  // const void* b_shuffle_ptr;
74  const void* b_ptr;
75  const std::array<const void*, NumDTensor> ds_ptr;
76  void* e_ptr;
77  index_t M;
78  index_t N;
79  index_t K;
80  index_t stride_A;
81  index_t stride_B;
82  std::array<index_t, NumDTensor> stride_Ds;
83  index_t stride_E;
84  index_t k_batch;
85 };
86 
87 template <typename TilePartitioner_, typename FlatmmPipeline_, typename EpiloguePipeline_>
88 struct FlatmmKernel
89 {
90  using TilePartitioner = remove_cvref_t<TilePartitioner_>;
91  using FlatmmPipeline = remove_cvref_t<FlatmmPipeline_>;
92 
93  using BlockGemmShape = remove_cvref_t<typename FlatmmPipeline::BlockGemmShape>;
94  using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
95  using ALayout = remove_cvref_t<typename FlatmmPipeline::ALayout>;
96  using BLayout = remove_cvref_t<typename FlatmmPipeline::BLayout>;
97  using ELayout = remove_cvref_t<typename FlatmmPipeline::CLayout>;
98  using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
99  using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
100  static constexpr index_t kBlockSize = FlatmmPipeline::BlockSize;
101 
102  using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
103  using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
104  // Below type is actually the accumulation data type - the output of the block GEMM.
105  using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
106 
107  static constexpr index_t NumDTensor = DsDataType::size();
108 
109  static constexpr auto I0 = number<0>();
110  static constexpr auto I1 = number<1>();
111  static constexpr auto I2 = number<2>();
112  static constexpr auto I3 = number<3>();
113 
114  static_assert(DsLayout::size() == DsDataType::size(),
115  "The size of DsLayout and DsDataType should be the same");
116  using KernelArgs = FlatmmKernelArgs<DsLayout::size()>;
117 
118  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
119  {
120  // clang-format off
121  return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
122  // clang-format on
123  }
124 
125  CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
126  {
127  return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
128  }
129 
130  CK_TILE_HOST static constexpr auto BlockSize()
131  {
132  return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
133  }
134 
135  CK_TILE_HOST static constexpr KernelArgs
136  MakeKernelArgs(const FlatmmHostArgs<NumDTensor>& hostArgs)
137 {
138  return KernelArgs{hostArgs.a_ptr,
139  hostArgs.b_ptr,
140  hostArgs.ds_ptr,
141  hostArgs.e_ptr,
142  hostArgs.M,
143  hostArgs.N,
144  hostArgs.K,
145  hostArgs.stride_A,
146  hostArgs.stride_B,
147  hostArgs.stride_Ds,
148  hostArgs.stride_E,
149  hostArgs.k_batch};
150  }
151 
152  CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
153 {
154  return max(FlatmmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
155  }
156 
157  struct SplitKBatchOffset
158 {
159  __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
160  {
161  constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
162  const index_t K_t = kargs.k_batch * K1;
163  const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
164 
165  if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
166  {
167  a_k_split_offset = k_id * KRead;
168  }
169  else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
170  {
171  a_k_split_offset = k_id * KRead * kargs.stride_A;
172  }
173 
174  if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
175  {
176  b_k_split_offset = k_id * KRead * kargs.stride_B;
177  }
178  else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
179  {
180  b_k_split_offset = k_id * KRead;
181  }
182 
183  if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
184  {
185  splitted_k = KRead;
186  }
187  else
188  {
189  splitted_k = kargs.K - KRead * (kargs.k_batch - 1);
190  }
191  }
192 
193  index_t a_k_split_offset;
194  index_t b_k_split_offset;
195  index_t splitted_k;
196  };
197 
198  CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs)
199  {
200  if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
202  {
203  if(kargs.k_batch != 1)
204  {
205  std::cerr << "Conditions not met for k_batch > 1!" << std::endl;
206  return false;
207  }
208  }
209 
210  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
211  {
212  if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false)
213  {
214  std::cerr << "Can't support K that is not a multiple of KPerBlock"
215  " without padding!"
216  << std::endl;
217  return false;
218  }
219  if(kargs.K % FlatmmPipeline::GetVectorSizeA() != 0)
220  {
221  std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl;
222  return false;
223  }
224  }
225  else
226  {
227  if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
228  {
229  std::cerr << "Can't support M that is not a multiple of MPerBlock"
230  " without padding!"
231  << std::endl;
232  return false;
233  }
234  if(kargs.M % FlatmmPipeline::GetVectorSizeA() != 0)
235  {
236  std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl;
237  return false;
238  }
239  }
240 
241  if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
242  {
243  if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
244  {
245  std::cerr << "Can't support N that is not a multiple of NPerBlock"
246  " without padding!"
247  << std::endl;
248  return false;
249  }
250  if(kargs.N % FlatmmPipeline::GetVectorSizeB() != 0)
251  {
252  std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl;
253  return false;
254  }
255  }
256  else
257  {
258  if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false)
259  {
260  std::cerr << "Can't support K that is not a multiple of KPerBlock"
261  " without padding!"
262  << std::endl;
263  return false;
264  }
265  if(kargs.K % FlatmmPipeline::GetVectorSizeB() != 0)
266  {
267  std::cerr << "K is not a multiple of vector load size for B tensor!" << std::endl;
268  return false;
269  }
270  }
271 
272  bool DTensorIsValid = true;
273  static_for<0, NumDTensor, 1>{}([&](auto index) {
274  using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
275  if(std::is_same_v<DiLayout, ELayout> == false)
276  {
277  DTensorIsValid = false;
278  }
279  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
280  {
281  if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
282  {
283  CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of "
284  "NPerBlock without padding!");
285  DTensorIsValid = false;
286  }
287  if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0)
288  {
289  CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!");
290  DTensorIsValid = false;
291  }
292  }
293  else
294  {
295  if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
296  {
297  CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of "
298  "MPerBlock without padding!");
299 
300  DTensorIsValid = false;
301  }
302  if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0)
303  {
304  CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!");
305  DTensorIsValid = false;
306  }
307  }
308  });
309 
310  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
311  {
312  if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
313  {
314  std::cerr << "Can't support N that is not a multiple of NPerBlock"
315  " without padding!"
316  << std::endl;
317  return false;
318  }
319  if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
320  {
321  std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
322  return false;
323  }
324  }
325  else
326  {
327  if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
328  {
329  std::cerr << "Can't support M that is not a multiple of MPerBlock"
330  " without padding!"
331  << std::endl;
332  return false;
333  }
334  if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
335  {
336  std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
337  return false;
338  }
339  }
340  return DTensorIsValid;
341  }
342 
343  template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
344  CK_TILE_DEVICE static auto
345  MakeGemmTensorViews(const ADataType* a_ptr,
346  const BDataType* b_flat_ptr,
347  const std::array<const void*, NumDTensor>& ds_ptr,
348  EDataType* e_ptr,
349  const KernelArgs& kargs,
350  const SplitKBatchOffset& splitk_batch_offset)
351  {
352  const auto& a_tensor_view = [&]() {
353  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
354  {
355  return make_naive_tensor_view<address_space_enum::global>(
356  a_ptr,
357  make_tuple(kargs.M, splitk_batch_offset.splitted_k),
358  make_tuple(kargs.stride_A, 1),
359  number<FlatmmPipeline::GetVectorSizeA()>{},
360  number<1>{});
361  }
362  else
363  {
364  return make_naive_tensor_view<address_space_enum::global>(
365  a_ptr,
366  make_tuple(splitk_batch_offset.splitted_k, kargs.M),
367  make_tuple(kargs.stride_A, 1),
368  number<FlatmmPipeline::GetVectorSizeA()>{},
369  number<1>{});
370  }
371  }();
372 
373  index_t kFlatK = FlatmmPipeline::flatKPerWarp * (splitk_batch_offset.splitted_k /
374  BlockGemmShape::WarpTile::at(number<2>{}));
375  index_t kFlatN = kargs.N * kargs.K / kFlatK;
376  const auto& b_flat_tensor_view = [&]() {
377  return make_naive_tensor_view<address_space_enum::global>(
378  b_flat_ptr,
379  make_tuple(kFlatN, kFlatK),
380  make_tuple(kFlatK, 1),
381  number<FlatmmPipeline::GetVectorSizeB()>{},
382  number<1>{});
383  }();
384 
385  const auto& ds_tensor_view = generate_tuple(
386  [&](auto i) {
387  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
388  using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
389  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
390  {
391  return make_naive_tensor_view<address_space_enum::global>(
392  static_cast<const DDataType_*>(ds_ptr[i]),
393  make_tuple(kargs.M, kargs.N),
394  make_tuple(kargs.stride_Ds[i], 1),
395  number<EpiloguePipeline::GetVectorSizeD(i)>{},
396  number<1>{});
397  }
398  else
399  {
400  return make_naive_tensor_view<address_space_enum::global>(
401  static_cast<const DDataType_*>(ds_ptr[i]),
402  make_tuple(kargs.N, kargs.M),
403  make_tuple(kargs.stride_Ds[i], 1),
404  number<EpiloguePipeline::GetVectorSizeD(i)>{},
405  number<1>{});
406  }
407  },
408  number<NumDTensor>{});
409 
410  // TODO: enable vector write for C in ColMajor
411  const auto& e_tensor_view = [&]() {
412  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
413  {
414  return make_naive_tensor_view<address_space_enum::global>(
415  e_ptr,
416  make_tuple(kargs.M, kargs.N),
417  make_tuple(kargs.stride_E, 1),
418  number<EpiloguePipeline::GetVectorSizeC()>{},
419  number<1>{});
420  }
421  else
422  {
423  return make_naive_tensor_view<address_space_enum::global>(
424  e_ptr,
425  make_tuple(kargs.N, kargs.M),
426  make_tuple(kargs.stride_E, 1),
427  number<1>{},
428  number<1>{});
429  }
430  }();
431 
432  return make_tuple(a_tensor_view, b_flat_tensor_view, ds_tensor_view, e_tensor_view);
433  }
434 
435  template <typename TensorView>
436  CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
437  {
438  const auto& a_pad_view = [&]() {
439  const auto& a_tensor_view = views.at(I0);
440  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
441  {
442  return pad_tensor_view(a_tensor_view,
446  }
447  else
448  {
449  return pad_tensor_view(a_tensor_view,
453  }
454  }();
455 
456  const auto& b_flat_tensor_view = views.at(I1);
457 
458  const auto& ds_pad_view = generate_tuple(
459  [&](auto i) {
460  const auto& d_tensor_view = views.at(I2);
461  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
462  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
463  {
464  return pad_tensor_view(d_tensor_view[i],
468  }
469  else
470  {
471  return pad_tensor_view(d_tensor_view[i],
475  }
476  },
477  number<NumDTensor>{});
478 
479 // TODO: enable vector write for C in ColMajor
480  const auto& e_pad_view = [&]() {
481  const auto& e_tensor_view = views.at(I3);
482  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
483  {
484  return pad_tensor_view(e_tensor_view,
488  }
489  else
490  {
491  return pad_tensor_view(e_tensor_view,
495  }
496  }();
497 
498  return make_tuple(a_pad_view, b_flat_tensor_view, ds_pad_view, e_pad_view);
499  }
500 
501  template <typename PadView>
502  CK_TILE_DEVICE static auto
503  MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
504  {
505  const auto& a_pad_view = views.at(I0);
506  const auto& b_flat_pad_view = views.at(I1);
507  const auto& ds_pad_view = views.at(I2);
508  const auto& e_pad_view = views.at(I3);
509 
510  const auto& a_block_window = [&]() {
511  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
512  {
513  return make_tile_window(a_pad_view,
516  {i_m, 0});
517  }
518  else
519  {
520  return make_tile_window(a_pad_view,
523  {0, i_m});
524  }
525  }();
526 
527  const auto& b_flat_block_window =
528  make_tile_window(b_flat_pad_view,
531  {static_cast<int>(i_n / BlockGemmShape::WarpTile::at(I1)), 0});
532 
533  const auto ds_block_window = generate_tuple(
534  [&](auto i) {
535  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
536  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
537  {
538  return make_tile_window(ds_pad_view[i],
541  {i_m, i_n});
542  }
543  else
544  {
545  return make_tile_window(ds_pad_view[i],
548  {i_n, i_m});
549  }
550  },
551  number<NumDTensor>{});
552 
553  auto e_block_window = make_tile_window(
554  e_pad_view,
556  {i_m, i_n});
557 
558  return make_tuple(a_block_window, b_flat_block_window, ds_block_window, e_block_window);
559  }
560 
561  template <bool UseDefaultScheduler = true>
562  CK_TILE_DEVICE static void RunFlatmm(const ADataType* a_ptr,
563  const BDataType* b_flat_ptr,
564  const std::array<const void*, NumDTensor>& ds_ptr,
565  EDataType* e_ptr,
566  void* smem_ptr,
567  const KernelArgs& kargs,
568  const SplitKBatchOffset& splitk_batch_offset,
569  const index_t block_idx_m,
570  const index_t block_idx_n)
571  {
572  // Create Gemm tensor views, pad views and tile windows
573  const auto& gemm_tensor_views_tuple =
574  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
575  a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
576  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
577  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
578 
579  const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
580 
581  // Run GEMM cooperatively by whole workgroup.
582  const auto& a_block_window = gemm_tile_windows.at(I0);
583  const auto& b_flat_block_window = gemm_tile_windows.at(I1);
584  const auto& d_block_window = gemm_tile_windows.at(I2);
585  const auto& c_block_tile = FlatmmPipeline{}.template operator()(
586  a_block_window, b_flat_block_window, num_loop, smem_ptr);
587  if(UseDefaultScheduler || (get_warp_id() == 0))
588  {
589  // Run Epilogue Pipeline
590  auto& c_block_window = gemm_tile_windows.at(I3);
591 
592  EpiloguePipeline{}.template
593  operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
594  c_block_window, c_block_tile, d_block_window, smem_ptr);
595  }
596  }
597 
598  CK_TILE_DEVICE void operator()(KernelArgs kargs) const
599 {
600  const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
601  const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
602  const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
603 
604  const SplitKBatchOffset splitk_batch_offset(kargs);
605  // options
606  const ADataType* a_ptr =
607  static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
608  const BDataType* b_flat_ptr =
609  static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
610  EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
611 
612  // allocate LDS
613  __shared__ char smem_ptr[GetSmemSize()];
614 
615  if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
616  EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
618  {
619  constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
620  RunFlatmm<scheduler_type>(a_ptr,
621  b_flat_ptr,
622  kargs.ds_ptr,
623  e_ptr,
624  smem_ptr,
625  kargs,
626  splitk_batch_offset,
627  i_m,
628  i_n);
629  }
630  }
631 };
632 
633 } // namespace ck_tile
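
For orientation, the sketch below shows how the pieces above are typically wired together on the host: fill a FlatmmHostArgs, convert it with MakeKernelArgs, check IsSupportedArgument, and size the launch with GridSize/BlockSize. Everything outside those calls is an assumption made for illustration: the concrete FlatmmKernel instantiation is passed in as a template parameter, the strides assume row-major A and E with no D tensors, and the actual kernel launch is only indicated in a comment because the launch helper is not part of this header.

// Minimal host-side sketch (assumptions noted above). Kernel is any concrete
// ck_tile::FlatmmKernel<TilePartitioner, FlatmmPipeline, EpiloguePipeline>;
// this header and a HIP runtime (for dim3) are assumed to be included already.
template <typename Kernel>
bool prepare_flatmm_launch(const void* a_dev,      // A matrix, row-major (assumed)
                           const void* b_flat_dev, // pre-shuffled ("flat") B matrix
                           void* e_dev,            // E output, row-major (assumed)
                           ck_tile::index_t M,
                           ck_tile::index_t N,
                           ck_tile::index_t K,
                           ck_tile::index_t k_batch)
{
    // No D tensors in this sketch, hence FlatmmHostArgs<0> and empty arrays.
    ck_tile::FlatmmHostArgs<0> host_args(a_dev,
                                         b_flat_dev,
                                         {},  // ds_ptr
                                         e_dev,
                                         k_batch,
                                         M,
                                         N,
                                         K,
                                         K,   // stride_A for row-major A
                                         N,   // stride_B (illustrative only)
                                         {},  // stride_Ds
                                         N);  // stride_E for row-major E

    const auto kargs = Kernel::MakeKernelArgs(host_args);
    if(!Kernel::IsSupportedArgument(kargs))
    {
        return false; // caller should fall back to another configuration
    }

    const dim3 grid  = Kernel::GridSize(M, N, k_batch);
    const dim3 block = Kernel::BlockSize();
    (void)grid;
    (void)block;
    // The launch itself would go through the project's usual launch helper,
    // e.g. something along the lines of (not this header's API):
    //   launch_kernel(stream_cfg, make_kernel(Kernel{}, grid, block, 0, kargs));
    return true;
}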
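A small worked example of the SplitKBatchOffset arithmetic may also help. The constants below (warp-tile K of 16, k_batch of 4, K of 1000) are assumptions, not taken from any particular pipeline configuration; the snippet simply mirrors the integer arithmetic of the constructor for the row-major-A case on the host.

#include <cstdio>

// Host-side replay of the SplitKBatchOffset arithmetic for one assumed configuration.
// K1 stands in for TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}).
int main()
{
    const int K = 1000, k_batch = 4, K1 = 16; // all three values are assumptions

    const int K_t   = k_batch * K1;             // 4 * 16 = 64
    const int KRead = (K + K_t - 1) / K_t * K1; // ceil(1000 / 64) * 16 = 256

    for(int k_id = 0; k_id < k_batch; ++k_id)
    {
        // Row-major A: the k-th split starts KRead elements further along K.
        const int a_k_split_offset = k_id * KRead;
        // Every split but the last covers KRead elements; the last takes the remainder.
        const int splitted_k = (k_id < k_batch - 1) ? KRead : K - KRead * (k_batch - 1);
        std::printf("k_id=%d offset=%d splitted_k=%d\n", k_id, a_k_split_offset, splitted_k);
    }
    // Prints offsets 0, 256, 512, 768 and splitted_k 256, 256, 256, 232.
    return 0;
}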