universal_gemm_kernel.hpp Source File
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <string>
8 
9 #include "ck_tile/core.hpp"
10 #include "ck_tile/ops/common.hpp"
11 #include "ck_tile/host/concat.hpp"
16 
17 namespace ck_tile {
18 
30 template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
31 struct UniversalGemmHostArgs
32 {
33  CK_TILE_HOST UniversalGemmHostArgs(const std::array<const void*, NumATensor>& as_ptr_,
34  const std::array<const void*, NumBTensor>& bs_ptr_,
35  const std::array<const void*, NumDTensor>& ds_ptr_,
36  void* e_ptr_,
37  index_t k_batch_,
38  index_t M_,
39  index_t N_,
40  index_t K_,
41  const std::array<index_t, NumATensor>& stride_As_,
42  const std::array<index_t, NumBTensor>& stride_Bs_,
43  const std::array<index_t, NumDTensor>& stride_Ds_,
44  index_t stride_E_)
45  : as_ptr(as_ptr_),
46  bs_ptr(bs_ptr_),
47  ds_ptr(ds_ptr_),
48  e_ptr(e_ptr_),
49  M(M_),
50  N(N_),
51  K(K_),
52  stride_As(stride_As_),
53  stride_Bs(stride_Bs_),
54  stride_Ds(stride_Ds_),
55  stride_E(stride_E_),
56  k_batch(k_batch_)
57  {
58  }
59 
60  const std::array<const void*, NumATensor> as_ptr;
61  const std::array<const void*, NumBTensor> bs_ptr;
62  const std::array<const void*, NumDTensor> ds_ptr;
63  union
64  {
65  void* e_ptr;
66  void* c_ptr;
67  };
68  index_t M;
69  index_t N;
70  index_t K;
71  const std::array<index_t, NumATensor> stride_As;
72  const std::array<index_t, NumBTensor> stride_Bs;
73  const std::array<index_t, NumDTensor> stride_Ds;
74  union
75  {
76  index_t stride_E;
77  index_t stride_C;
78  };
79 
80  index_t k_batch;
81 };
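
For orientation, here is a minimal host-side sketch of filling these arguments for a single-A/single-B GEMM with no D tensors. The include path, buffer pointers, sizes and strides are made-up placeholders, not values taken from this file.

// Illustrative sketch only; assumes NumATensor = NumBTensor = 1 and NumDTensor = 0.
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"

// Hypothetical device buffers; in real code these come from hipMalloc'd allocations.
void make_host_args_example(const void* a_dev, const void* b_dev, void* e_dev)
{
    ck_tile::UniversalGemmHostArgs<1, 1, 0> host_args({a_dev},   // as_ptr (A: M x K)
                                                      {b_dev},   // bs_ptr (B: K x N)
                                                      {},        // ds_ptr (no D tensors)
                                                      e_dev,     // e_ptr  (E: M x N)
                                                      /*k_batch=*/1,
                                                      /*M=*/3840,
                                                      /*N=*/4096,
                                                      /*K=*/2048,
                                                      {/*stride_A=*/2048},
                                                      {/*stride_B=*/2048},
                                                      {},        // stride_Ds
                                                      /*stride_E=*/4096);
    (void)host_args; // would normally be passed to UniversalGemmKernel::MakeKernelArgs
}
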
82 
83 /// The GEMM kernel device arguments.
84 template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
85 struct UniversalGemmKernelArgs
86 {
87  /// The As input tensor's pointer to device memory.
88  const std::array<const void*, NumATensor> as_ptr;
89  /// The Bs input tensor's pointer to device memory.
90  const std::array<const void*, NumBTensor> bs_ptr;
91  /// The Ds input tensor's pointer to device memory.
92  const std::array<const void*, NumDTensor> ds_ptr;
93  /// The E output tensor's pointer to device memory.
94  void* e_ptr;
95  /// GEMM's M dimension size.
96  index_t M;
97  /// GEMM's N dimension size.
98  index_t N;
99  /// GEMM's K dimension size.
100  index_t K;
101  /// The distance between consecutive elements of non-contiguous dimension
102  /// (in memory) of As tensor.
103  std::array<index_t, NumATensor> stride_As;
104  /// The distance between consecutive elements of non-contiguous dimension
105  /// (in memory) of Bs tensor.
106  std::array<index_t, NumBTensor> stride_Bs;
107  /// The distance between consecutive elements of non-contiguous dimension
108  /// (in memory) of Ds tensor.
109  std::array<index_t, NumDTensor> stride_Ds;
110  /// The distance between consecutive elements of non-contiguous dimension
111  /// (in memory) of E tensor.
112  index_t stride_E;
113  index_t k_batch;
114 };
115 
152 template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
153 struct UniversalGemmKernel
154 {
155  using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
156  using GemmPipeline     = remove_cvref_t<GemmPipeline_>;
157  using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
158 
159  static constexpr bool ADataTypeIsTuple =
161  static constexpr bool BDataTypeIsTuple =
163  static constexpr bool DDataTypeIsTuple =
165  static constexpr bool ALayoutIsTuple =
167  static constexpr bool BLayoutIsTuple =
169  static constexpr bool DLayoutIsTuple =
171 
172  using AsLayout = std::conditional_t<ALayoutIsTuple,
173                                      remove_cvref_t<typename GemmPipeline::ALayout>,
174                                      remove_cvref_t<tuple<typename GemmPipeline::ALayout>>>;
175  using BsLayout = std::conditional_t<BLayoutIsTuple,
176                                      remove_cvref_t<typename GemmPipeline::BLayout>,
177                                      remove_cvref_t<tuple<typename GemmPipeline::BLayout>>>;
178 
179  using DsLayout = std::conditional_t<DLayoutIsTuple,
180                                      remove_cvref_t<typename EpiloguePipeline::DsLayout>,
181                                      remove_cvref_t<tuple<typename EpiloguePipeline::DsLayout>>>;
182 
183  using AsDataType = std::conditional_t<ADataTypeIsTuple,
184                                        remove_cvref_t<typename GemmPipeline::ADataType>,
185                                        remove_cvref_t<tuple<typename GemmPipeline::ADataType>>>;
186 
187  using BsDataType = std::conditional_t<BDataTypeIsTuple,
188                                        remove_cvref_t<typename GemmPipeline::BDataType>,
189                                        remove_cvref_t<tuple<typename GemmPipeline::BDataType>>>;
190 
191  using DsDataType =
192      std::conditional_t<DDataTypeIsTuple,
193                         remove_cvref_t<typename EpiloguePipeline::DsDataType>,
194                         remove_cvref_t<tuple<typename EpiloguePipeline::DsDataType>>>;
195 
196  using ELayout   = remove_cvref_t<typename GemmPipeline::CLayout>;
197  using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
198 
199  static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
200 
201  // Get the persistent kernel if the pipeline has it available
203  {
204  template <typename T>
205  using has_persistent_type = decltype(T::UsePersistentKernel);
206 
207  static constexpr bool value = []() {
208  if constexpr(is_detected<has_persistent_type, GemmPipeline>::value)
209  return GemmPipeline::UsePersistentKernel;
210  else
211  return false;
212  }();
213  };
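
In other words, a pipeline opts into the persistent code path simply by exposing a `UsePersistentKernel` member. A minimal sketch of such an opt-in on a hypothetical pipeline type (not one defined in this file):

// Hypothetical pipeline fragment: the detector above only looks for this member.
struct MyPipeline /* : usual ck_tile pipeline definition */
{
    static constexpr bool UsePersistentKernel = true; // picked up via has_persistent_type
    // ...remaining pipeline interface (BlockSize, GetSmemSize, operator(), ...)...
};
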
215 
216  // Check if TilePartitioner has GetOutputOffset method with kargs and k_id
218  {
219  template <typename T, typename KernelArgs>
220  using has_get_output_offset_t =
221      decltype(T::GetOutputOffset(std::declval<KernelArgs>(), std::declval<index_t>()));
222 
223  static constexpr bool value = []() {
225  return true;
226  else
227  return false;
228  }();
229  };
230  static constexpr bool has_tile_partitioner_output_offset =
232 
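
Similarly, a tile partitioner may optionally provide a static `GetOutputOffset(kargs, k_id)`; when it does, the kernel shifts the output pointer per split-K batch (see the `operator()` bodies below). A hypothetical partitioner fragment satisfying the detector, with a made-up offset policy:

// Hypothetical partitioner fragment; only the member checked by the detector is shown.
struct MyTilePartitioner
{
    template <typename KernelArgs>
    static ck_tile::index_t GetOutputOffset(const KernelArgs& kargs, ck_tile::index_t k_id)
    {
        // e.g. each split-K batch writes its own M x N slice of a workspace buffer
        return k_id * kargs.M * kargs.N;
    }
    // ...GridSize / GetOutputTileIndex / MPerBlock / NPerBlock etc. ...
};
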
233  static constexpr auto I0 = number<0>();
234  static constexpr auto I1 = number<1>();
235  static constexpr auto I2 = number<2>();
236  static constexpr auto I3 = number<3>{};
237 
238  static constexpr index_t NumATensor = AsDataType::size();
239  static constexpr index_t NumBTensor = BsDataType::size();
240  static constexpr index_t NumDTensor = DsDataType::size();
241 
242  using ADataType = remove_cvref_t<std::tuple_element_t<I0, AsDataType>>;
243  using BDataType = remove_cvref_t<std::tuple_element_t<I0, BsDataType>>;
244 
245  static_assert(AsLayout::size() == AsDataType::size(),
246  "The size of AsLayout and AsDataType should be the same");
247 
248  static_assert(BsLayout::size() == BsDataType::size(),
249  "The size of BsLayout and BsDataType should be the same");
250 
251  static_assert(DsLayout::size() == DsDataType::size(),
252  "The size of DsLayout and DsDataType should be the same");
253 
254  using KernelArgs =
255  UniversalGemmKernelArgs<AsLayout::size(), BsLayout::size(), DsLayout::size()>;
256 
257  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
258  {
259  // clang-format off
260  return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>(), GemmPipeline::GetName());
261  // clang-format on
262  }
263 
264  CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
265  {
266  return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
267  }
268 
275  CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
276  {
278  const auto kernel = kentry<1, Kernel, KernelArgs>;
279  int occupancy;
280  hip_check_error(
281      hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize().x, 0));
282 
283  const int grid_size = get_available_compute_units(s) * occupancy;
284  return dim3(grid_size, 1, 1);
285  }
286 
287  CK_TILE_HOST static auto BlockSize()
288  {
289  if(ck_tile::is_wave32())
290  {
291  return dim3(kBlockSize / 2);
292  }
293  else
294  {
295  return dim3(kBlockSize);
296  }
297  }
298 
299  CK_TILE_HOST static constexpr KernelArgs
300  MakeKernelArgs(const UniversalGemmHostArgs<NumATensor, NumBTensor, NumDTensor>& hostArgs)
301  {
302  return KernelArgs{hostArgs.as_ptr,
303  hostArgs.bs_ptr,
304  hostArgs.ds_ptr,
305  hostArgs.e_ptr,
306  hostArgs.M,
307  hostArgs.N,
308  hostArgs.K,
309  hostArgs.stride_As,
310  hostArgs.stride_Bs,
311  hostArgs.stride_Ds,
312  hostArgs.stride_E,
313  hostArgs.k_batch};
314  }
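
Putting the host-side helpers together, a rough sketch of preparing and validating a launch. `Partitioner`, `Pipeline` and `Epilogue` stand for concrete ck_tile policy types, and the launch call itself (done via ck_tile's launch utilities, which live outside this header) is only indicated in a comment.

// Sketch only, under the assumptions stated above.
#include <stdexcept>
#include <hip/hip_runtime.h>

template <typename Partitioner, typename Pipeline, typename Epilogue, typename HostArgs>
dim3 prepare_launch(const HostArgs& host_args)
{
    using Kernel = ck_tile::UniversalGemmKernel<Partitioner, Pipeline, Epilogue>;

    const auto kargs = Kernel::MakeKernelArgs(host_args); // host args -> device-side args
    if(!Kernel::IsSupportedArgument(kargs))               // divisibility / padding checks (below)
        throw std::runtime_error("unsupported GEMM problem for this kernel configuration");

    const dim3 blocks = Kernel::BlockSize(); // accounts for wave32 vs wave64
    const dim3 grid   = Kernel::GridSize(host_args.M, host_args.N, host_args.k_batch);
    // ...launch the kernel with grid, blocks and kargs...
    return grid;
}
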
315 
316  CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
317  {
318  return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
319  }
320 
321  struct SplitKBatchOffset
322  {
323  __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
324  {
325  constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
326  const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1);
327  const index_t KRead = __builtin_amdgcn_readfirstlane((kargs.K + K_t - 1) / K_t * K1);
328 
329  static_for<0, NumATensor, 1>{}([&](auto index) {
330  using AiLayout = remove_cvref_t<std::tuple_element_t<index.value, AsLayout>>;
331  if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, AiLayout>)
332  {
333  as_k_split_offset[index] = __builtin_amdgcn_readfirstlane(k_id * KRead);
334  }
335  else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, AiLayout>)
336  {
337  as_k_split_offset[index] =
338  __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_As[index]);
339  }
340  });
341 
342  static_for<0, NumBTensor, 1>{}([&](auto index) {
343  using BiLayout = remove_cvref_t<std::tuple_element_t<index.value, BsLayout>>;
344  if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BiLayout>)
345  {
346  bs_k_split_offset[index] =
347  __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_Bs[index]);
348  }
349  else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BiLayout>)
350  {
351  bs_k_split_offset[index] = __builtin_amdgcn_readfirstlane(k_id * KRead);
352  }
353  });
354 
355  if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
356  {
357  splitted_k = __builtin_amdgcn_readfirstlane(KRead);
358  }
359  else
360  {
361  splitted_k = __builtin_amdgcn_readfirstlane(kargs.K - KRead * (kargs.k_batch - 1));
362  }
363  }
364 
365  std::array<index_t, NumATensor> as_k_split_offset;
366  std::array<index_t, NumBTensor> bs_k_split_offset;
367  index_t splitted_k;
368  };
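
To make the split-K bookkeeping concrete, a small host-side sketch reproducing the arithmetic above with example numbers; K = 1000, k_batch = 4 and a warp K-tile K1 = 16 are assumptions for illustration only.

// Illustration of the SplitKBatchOffset math with made-up numbers (not from this file).
#include <cassert>

int main()
{
    const int K = 1000, k_batch = 4, K1 = 16;    // K1: WarpTile K extent
    const int K_t   = k_batch * K1;              // 64
    const int KRead = (K + K_t - 1) / K_t * K1;  // ceil(1000/64) * 16 = 256

    // Batches 0..k_batch-2 each process KRead elements of K; the last one takes the rest.
    for(int k_id = 0; k_id < k_batch; ++k_id)
    {
        const int splitted_k = (k_id < k_batch - 1) ? KRead : K - KRead * (k_batch - 1);
        assert(splitted_k > 0); // 256, 256, 256, 232
        // For row-major A, this batch's pointer offset advances by k_id * KRead along K.
    }
    return 0;
}
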
369 
370  CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs)
371  {
372  if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
374  {
375  if(kargs.k_batch != 1)
376  {
377  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
378  {
379  CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
380  }
381  return false;
382  }
383  }
384 
385  const auto vectorSizeA = is_wave32() ? GemmPipeline::template GetVectorSizeA<true>()
386  : GemmPipeline::template GetVectorSizeA<false>();
387  bool AsTesnorIsValid = {true};
388  static_for<0, NumATensor, 1>{}([&](auto index) {
389  using AiLayout = remove_cvref_t<std::tuple_element_t<index.value, AsLayout>>;
390  if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
391  {
392  if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
393  GemmPipeline::kPadK == false)
394  {
395  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
396  {
397  CK_TILE_ERROR(
398      "Can't support K that is not a multiple of k_batch * KPerBlock "
399      "without padding!");
400  }
401  AsTesnorIsValid = false;
402  }
403  if(kargs.K % vectorSizeA != 0)
404  {
405  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
406  {
407  CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
408  }
409  AsTesnorIsValid = false;
410  }
411  }
412  else
413  {
414  if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
415  {
416  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
417  {
418  CK_TILE_ERROR(
419      "Can't support M that is not a multiple of MPerBlock without padding!");
420  }
421  AsTesnorIsValid = false;
422  }
423  if(kargs.M % vectorSizeA != 0)
424  {
425  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
426  {
427  CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
428  }
429  AsTesnorIsValid = false;
430  }
431  }
432  });
433 
434  bool BsTesnorIsValid = {true};
435  const auto vectorSizeB = is_wave32() ? GemmPipeline::template GetVectorSizeB<true>()
436  : GemmPipeline::template GetVectorSizeB<false>();
437  static_for<0, NumBTensor, 1>{}([&](auto index) {
438  using BiLayout = remove_cvref_t<std::tuple_element_t<index.value, BsLayout>>;
439  if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::RowMajor>)
440  {
441  if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
442  {
443  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
444  {
445  CK_TILE_ERROR(
446      "Can't support N that is not a multiple of NPerBlock without padding!");
447  }
448  BsTesnorIsValid = false;
449  }
450  if(kargs.N % vectorSizeB != 0)
451  {
452  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
453  {
454  CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
455  }
456  BsTesnorIsValid = false;
457  }
458  }
459  else
460  {
461  if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
462  GemmPipeline::kPadK == false)
463  {
464  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
465  {
466  CK_TILE_ERROR(
467      "Can't support K that is not a multiple of k_batch * KPerBlock "
468      "without padding!");
469  }
470  BsTesnorIsValid = false;
471  }
472  if(kargs.K % vectorSizeB != 0)
473  {
474  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
475  {
476  CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!");
477  }
478  BsTesnorIsValid = false;
479  }
480  }
481  });
482 
483  bool DTesnorIsValid = {true};
484  static_for<0, NumDTensor, 1>{}([&](auto index) {
485  using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
486  if(std::is_same_v<DiLayout, ELayout> == false)
487  {
488  DTesnorIsValid = false;
489  }
490  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
491  {
492  if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
493  {
494  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
495  {
496  CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of "
497  "NPerBlock without padding!");
498  }
499  DTesnorIsValid = false;
500  }
501  if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0)
502  {
503  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
504  {
505  CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!");
506  }
507  DTesnorIsValid = false;
508  }
509  }
510  else
511  {
512  if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
513  {
514  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
515  {
516  CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of "
517  "MPerBlock without padding!");
518  }
519  DTesnorIsValid = false;
520  }
521  if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0)
522  {
523  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
524  {
525  CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!");
526  }
527  DTesnorIsValid = false;
528  }
529  }
530  });
531 
532  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
533  {
534  if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
535  {
536  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
537  {
538  CK_TILE_ERROR(
539      "Can't support N that is not a multiple of NPerBlock without padding!");
540  }
541  return false;
542  }
543  if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
544  {
545  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
546  {
547  CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!");
548  }
549  return false;
550  }
551  }
552  else
553  {
554  if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
555  {
556  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
557  {
558  CK_TILE_ERROR(
559      "Can't support M that is not a multiple of MPerBlock without padding!");
560  }
561  return false;
562  }
563  if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
564  {
565  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
566  {
567  CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!");
568  }
569  return false;
570  }
571  }
572  return AsTesnorIsValid && BsTesnorIsValid && DTesnorIsValid;
573  }
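
For intuition, here is how the no-padding divisibility checks above play out for hypothetical tile sizes MPerBlock = 256, NPerBlock = 128, KPerBlock = 64 (these numbers are assumptions, not values defined in this file):

// Example of the no-padding divisibility checks with assumed tile sizes.
constexpr int MPerBlock = 256, NPerBlock = 128, KPerBlock = 64;

constexpr bool m_ok = (1024 % MPerBlock) == 0;                   // true:  1024 is 4 tiles of 256
constexpr bool n_ok = (1000 % NPerBlock) == 0;                   // false: would need kPadN = true
constexpr bool k_ok = (2048 % (KPerBlock * 2 /*k_batch*/)) == 0; // true:  2048 % 128 == 0

static_assert(m_ok && !n_ok && k_ok, "illustration of the checks above");
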
574 
575  template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
576  CK_TILE_DEVICE static auto
577  MakeGemmTensorViews(const std::array<const ADataType*, NumATensor>& as_ptr,
578  const std::array<const BDataType*, NumBTensor>& bs_ptr,
579  const std::array<const void*, NumDTensor>& ds_ptr,
580  EDataType* e_ptr,
581  const KernelArgs& kargs,
582  const SplitKBatchOffset& splitk_batch_offset)
583  {
584  static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
585 
586  const auto& as_tensor_view = generate_tuple(
587  [&](auto i) {
588  using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
589  using AiDataType = remove_cvref_t<std::tuple_element_t<i.value, AsDataType>>;
590  if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
591  {
592  return make_naive_tensor_view<address_space_enum::global>(
593  static_cast<const AiDataType*>(as_ptr[i]),
594  make_tuple(kargs.M, splitk_batch_offset.splitted_k),
595  make_tuple(kargs.stride_As[i], 1),
596  number<GemmPipeline::GetVectorSizeA()>{},
597  number<1>{});
598  }
599  else
600  {
601  return make_naive_tensor_view<address_space_enum::global>(
602  static_cast<const AiDataType*>(as_ptr[i]),
603  make_tuple(splitk_batch_offset.splitted_k, kargs.M),
604  make_tuple(kargs.stride_As[i], 1),
605  number<GemmPipeline::GetVectorSizeA()>{},
606  number<1>{});
607  }
608  },
609  number<NumATensor>{});
610 
611  const auto& bs_tensor_view = generate_tuple(
612  [&](auto i) {
613  using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
614  using BiDataType = remove_cvref_t<std::tuple_element_t<i.value, BsDataType>>;
615  if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::RowMajor>)
616  {
617  if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
618  {
619  constexpr index_t K1 = GemmPipeline::GetSmemPackB();
620  const index_t K0 = splitk_batch_offset.splitted_k / K1;
621  constexpr index_t VectorSizeB =
622  std::min(K1, GemmPipeline::GetVectorSizeB());
623  const auto b_k0_n_k1_desc =
625  make_tuple(kargs.N * K1, K1, I1),
627  number<1>{});
628  const auto b_n_k_desc = transform_tensor_descriptor(
629  b_k0_n_k1_desc,
634  return make_tensor_view<address_space_enum::global>(
635  static_cast<const BiDataType*>(bs_ptr[i]), b_n_k_desc);
636  }
637  else
638  {
639  return make_naive_tensor_view<address_space_enum::global>(
640  bs_ptr[i],
641  make_tuple(splitk_batch_offset.splitted_k, kargs.N),
642  make_tuple(kargs.stride_Bs[i], 1),
643  number<GemmPipeline::GetVectorSizeB()>{},
644  number<1>{});
645  }
646  }
647  else
648  {
649  if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
650  {
651  constexpr index_t K1 = GemmPipeline::GetSmemPackB();
652  const index_t K0 = splitk_batch_offset.splitted_k / K1;
653  constexpr index_t VectorSizeB =
654  std::min(K1, GemmPipeline::GetVectorSizeB());
655  const auto b_k0_n_k1_desc =
657  make_tuple(kargs.N * K1, K1, I1),
659  number<1>{});
660  const auto b_n_k_desc = transform_tensor_descriptor(
661  b_k0_n_k1_desc,
666  return make_tensor_view<address_space_enum::global>(
667  static_cast<const BiDataType*>(bs_ptr[i]), b_n_k_desc);
668  }
669  else
670  {
671  if constexpr(GemmPipeline::Preshuffle)
672  {
673  index_t kFlatK =
674  GemmPipeline::BlockGemmShape::flatKPerWarp *
675  (splitk_batch_offset.splitted_k /
676  TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}));
677  index_t kFlatN = kargs.N * kargs.K / kFlatK;
678 
679  return make_naive_tensor_view<address_space_enum::global>(
680  bs_ptr[i],
681  make_tuple(kFlatN, kFlatK),
682  make_tuple(kFlatK, 1),
683  number<GemmPipeline::GetVectorSizeB()>{},
684  number<1>{});
685  }
686  else
687  {
688  return make_naive_tensor_view<address_space_enum::global>(
689  bs_ptr[i],
690  make_tuple(kargs.N, splitk_batch_offset.splitted_k),
691  make_tuple(kargs.stride_Bs[i], 1),
692  number<GemmPipeline::GetVectorSizeB()>{},
693  number<1>{});
694  }
695  }
696  }
697  },
698  number<NumBTensor>{});
699 
700  const auto& ds_tensor_view = generate_tuple(
701  [&](auto i) {
702  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
703  using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
704  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
705  {
706  return make_naive_tensor_view<address_space_enum::global>(
707  static_cast<const DDataType_*>(ds_ptr[i]),
708  make_tuple(kargs.M, kargs.N),
709  make_tuple(kargs.stride_Ds[i], 1),
710  number<EpiloguePipeline::GetVectorSizeD(i)>{},
711  number<1>{});
712  }
713  else
714  {
715  return make_naive_tensor_view<address_space_enum::global>(
716  static_cast<const DDataType_*>(ds_ptr[i]),
717  make_tuple(kargs.N, kargs.M),
718  make_tuple(kargs.stride_Ds[i], 1),
719  number<EpiloguePipeline::GetVectorSizeD(i)>{},
720  number<1>{});
721  }
722  },
723  number<NumDTensor>{});
724 
725  // TODO: enable vector write for C in ColMajor
726  const auto& e_tensor_view = [&]() {
727  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
728  {
729  return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
730  e_ptr,
731  make_tuple(kargs.M, kargs.N), // arguments not matching with flatmm.
732  make_tuple(kargs.stride_E, 1),
733  number<EpiloguePipeline::GetVectorSizeC()>{},
734  number<1>{});
735  }
736  else
737  {
738  return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
739  e_ptr,
740  make_tuple(kargs.M, kargs.N),
741  make_tuple(1, kargs.stride_E),
742  number<1>{},
743  number<1>{});
744  }
745  }();
746 
747  return make_tuple(as_tensor_view, bs_tensor_view, ds_tensor_view, e_tensor_view);
748  }
749 
750  template <typename TensorView>
751  CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
752  {
753  const auto& as_pad_view = generate_tuple(
754  [&](auto i) {
755  const auto& a_tensor_view = views.at(I0);
756  using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
757  if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
758  {
759  return pad_tensor_view(a_tensor_view[i],
763  }
764  else
765  {
766  return pad_tensor_view(a_tensor_view[i],
770  }
771  },
773 
774  const auto& b_flat_pad_view = views.at(I1);
775 
776  const auto& bs_pad_view = generate_tuple(
777  [&](auto i) {
778  const auto& b_tensor_view = views.at(I1);
779  using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
780  if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::ColumnMajor>)
781  {
782  return pad_tensor_view(b_tensor_view[i],
786  }
787  else
788  {
789  return pad_tensor_view(b_tensor_view[i],
793  }
794  },
796 
797  const auto& ds_pad_view = generate_tuple(
798  [&](auto i) {
799  const auto& d_tensor_view = views.at(I2);
800  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
801  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
802  {
803  return pad_tensor_view(d_tensor_view[i],
807  }
808  else
809  {
810  return pad_tensor_view(d_tensor_view[i],
814  }
815  },
817 
818  // TODO vector write in for C in ColMajor
819  const auto& e_pad_view = [&]() {
820  const auto& e_tensor_view = views.at(I3);
821  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
822  {
823  return pad_tensor_view(e_tensor_view,
827  }
828  else
829  {
830  return pad_tensor_view(e_tensor_view,
834  }
835  }();
836 
837  if constexpr(GemmPipeline::Preshuffle)
838  {
839  // For flatmm, we need to use the flat B tensor view
840  return make_tuple(as_pad_view, b_flat_pad_view, ds_pad_view, e_pad_view);
841  }
842  else
843  {
844  return make_tuple(as_pad_view, bs_pad_view, ds_pad_view, e_pad_view);
845  }
846  }
847 
848  template <typename PadView>
849  CK_TILE_DEVICE static auto
850  MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
851  {
852  const auto& as_pad_view = views.at(I0);
853  const auto& bs_pad_view = views.at(I1);
854  const auto& ds_pad_view = views.at(I2);
855  const auto& e_pad_view = views.at(I3);
856 
857  const auto& as_block_window = generate_tuple(
858  [&](auto i) {
859  using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
860  if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
861  {
862  return make_tile_window(as_pad_view[i],
865  {i_m, 0});
866  }
867  else
868  {
869  return make_tile_window(as_pad_view[i],
872  {0, i_m});
873  }
874  },
876 
877  const auto& bs_block_window = generate_tuple(
878  [&](auto i) {
879  using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
880  if constexpr(GemmPipeline::Preshuffle)
881  {
882  return make_tile_window(
883  bs_pad_view[i],
886  {static_cast<int>(i_n / GemmPipeline::BlockGemmShape::WarpTile::at(I1)),
887  0});
888  }
889  else
890  {
891  if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::ColumnMajor>)
892  {
893  return make_tile_window(bs_pad_view[i],
896  {i_n, 0});
897  }
898  else
899  {
900  return make_tile_window(bs_pad_view[i],
903  {0, i_n});
904  }
905  }
906  },
908 
909  const auto ds_block_window = generate_tuple(
910  [&](auto i) {
911  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
912  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
913  {
914  return make_tile_window(ds_pad_view[i],
917  {i_m, i_n});
918  }
919  else
920  {
921  return make_tile_window(ds_pad_view[i],
924  {i_n, i_m});
925  }
926  },
928 
929  auto e_block_window = make_tile_window(
930  e_pad_view,
932  {i_m, i_n});
933 
934  return make_tuple(as_block_window, bs_block_window, ds_block_window, e_block_window);
935  }
936 
951  template <bool UseDefaultScheduler = true>
952  CK_TILE_DEVICE static void RunGemm(const std::array<const ADataType*, NumATensor>& as_ptr,
953  const std::array<const BDataType*, NumBTensor>& bs_ptr,
954  const std::array<const void*, NumDTensor>& ds_ptr,
955  EDataType* e_ptr,
956  void* smem_ptr_0,
957  const KernelArgs& kargs,
958  const SplitKBatchOffset& splitk_batch_offset,
959  const index_t block_idx_m,
960  const index_t block_idx_n)
961  {
962  // Create Gemm tensor views, pad views and tile windows
963  const auto& gemm_tensor_views_tuple =
964  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
965  as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
966 
967  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
968  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
969 
970  const index_t num_loop = __builtin_amdgcn_readfirstlane(
971  TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
972 
973  // Run GEMM cooperatively by whole workgroup.
974  const auto& as_block_window = gemm_tile_windows.at(I0);
975  const auto& bs_block_window = gemm_tile_windows.at(I1);
976  const auto& ds_block_window = gemm_tile_windows.at(I2);
977 
978  const auto& c_block_tile =
979  GemmPipeline{}(as_block_window[I0], bs_block_window[I0], num_loop, smem_ptr_0);
980 
981  if(UseDefaultScheduler || (get_warp_id() == 0))
982  {
983  // Run Epilogue Pipeline
984  auto& c_block_window = gemm_tile_windows.at(I3);
985 
986  EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
987  }
988  }
989 
1007  CK_TILE_DEVICE static void RunGemm2LDS(const std::array<const ADataType*, NumATensor>& as_ptr,
1008  const std::array<const BDataType*, NumBTensor>& bs_ptr,
1009  const std::array<const void*, NumDTensor>& ds_ptr,
1010  EDataType* e_ptr,
1011  void* __restrict__ smem_ptr_0,
1012  void* __restrict__ smem_ptr_1,
1013  const KernelArgs& kargs,
1014  const SplitKBatchOffset& splitk_batch_offset,
1015  const index_t block_idx_m,
1016  const index_t block_idx_n)
1017  {
1018  // Create Gemm tensor views, pad views and tile windows
1019  const auto& gemm_tensor_views_tuple =
1020  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
1021  as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
1022 
1023  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
1024  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
1025 
1026  const index_t num_loop = __builtin_amdgcn_readfirstlane(
1027  TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
1028 
1029  // Run GEMM cooperatively by whole workgroup.
1030  const auto& as_block_window = gemm_tile_windows.at(I0);
1031  const auto& bs_block_window = gemm_tile_windows.at(I1);
1032  const auto& ds_block_window = gemm_tile_windows.at(I2);
1033 
1034  const auto& c_block_tile = GemmPipeline{}(
1035  as_block_window[I0], bs_block_window[I0], num_loop, smem_ptr_0, smem_ptr_1);
1036 
1037  // Run Epilogue Pipeline
1038  auto& c_block_window = gemm_tile_windows.at(I3);
1039 
1040  EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
1041  }
1042 
1043  // Non-persistent kernel entry point
1044  template <bool U = !PersistentKernel, typename = std::enable_if_t<U>>
1045  CK_TILE_DEVICE void operator()(KernelArgs kargs) const
1046  {
1047  const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x);
1048  const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
1049  const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
1050  const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
1051 
1052  const SplitKBatchOffset splitk_batch_offset(kargs);
1053 
1054  // options
1055  std::array<const ADataType*, NumATensor> as_ptr;
1056  static_for<0, NumATensor, 1>{}([&](auto i) {
1057  as_ptr[i] = static_cast<const ADataType*>(kargs.as_ptr[i]) +
1058  splitk_batch_offset.as_k_split_offset[i];
1059  });
1060 
1061  std::array<const BDataType*, NumBTensor> bs_ptr;
1062  static_for<0, NumBTensor, 1>{}([&](auto i) {
1063  bs_ptr[i] = static_cast<const BDataType*>(kargs.bs_ptr[i]) +
1064  splitk_batch_offset.bs_k_split_offset[i];
1065  });
1066 
1067  // Calculate output offset from tile partitioner and apply to output pointer
1068  EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
1069  if constexpr(has_tile_partitioner_output_offset)
1070  {
1071  const index_t output_offset = TilePartitioner::GetOutputOffset(kargs, blockIdx.z);
1072  e_ptr += output_offset;
1073  }
1074 
1075  // allocate LDS
1076  __shared__ char smem_ptr_0[GetSmemSize()];
1077 
1078  if constexpr(GemmPipeline::DoubleSmemBuffer == true)
1079  {
1080  __shared__ char smem_ptr_1[GetSmemSize()];
1081  if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
1082  EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
1084  {
1085  RunGemm2LDS(as_ptr,
1086  bs_ptr,
1087  kargs.ds_ptr,
1088  e_ptr,
1089  smem_ptr_0,
1090  smem_ptr_1,
1091  kargs,
1092  splitk_batch_offset,
1093  i_m,
1094  i_n);
1095  }
1096  }
1097  else
1098  {
1099  if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
1100  EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
1102  {
1103  constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1);
1104  RunGemm<scheduler_type>(as_ptr,
1105  bs_ptr,
1106  kargs.ds_ptr,
1107  e_ptr,
1108  smem_ptr_0,
1109  kargs,
1110  splitk_batch_offset,
1111  i_m,
1112  i_n);
1113  }
1114  }
1115  }
1116 
1117  // Persistent kernel entry point
1118  template <bool U = PersistentKernel, typename = std::enable_if_t<U>, typename = void>
1119  CK_TILE_DEVICE void operator()(KernelArgs kargs) const
1120  {
1121  const auto grid_size = __builtin_amdgcn_readfirstlane(get_grid_size());
1122  const auto num_tiles =
1123  __builtin_amdgcn_readfirstlane(TilePartitioner::GridSize(kargs.M, kargs.N));
1124  const auto num_work = __builtin_amdgcn_readfirstlane(num_tiles * kargs.k_batch);
1125  auto block_id = __builtin_amdgcn_readfirstlane(get_block_id());
1126 
1127  while(block_id < num_work)
1128  {
1129  // Get the tile index for this block
1130  const auto tile_idx = __builtin_amdgcn_readfirstlane(block_id % num_tiles);
1131  const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(tile_idx);
1132  const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
1133  const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
1134 
1135  // Get the SplitK offset for this block
1136  const auto k_batch = __builtin_amdgcn_readfirstlane(block_id / num_tiles);
1137  const SplitKBatchOffset splitk_batch_offset(kargs, k_batch);
1138 
1139  std::array<const ADataType*, NumATensor> as_ptr;
1140  static_for<0, NumATensor, 1>{}([&](auto i) {
1141  as_ptr[i] = static_cast<const ADataType*>(kargs.as_ptr[i]) +
1142  splitk_batch_offset.as_k_split_offset[i];
1143  });
1144 
1145  std::array<const BDataType*, NumBTensor> bs_ptr;
1146  static_for<0, NumBTensor, 1>{}([&](auto i) {
1147  bs_ptr[i] = static_cast<const BDataType*>(kargs.bs_ptr[i]) +
1148  splitk_batch_offset.bs_k_split_offset[i];
1149  });
1150 
1151  // Calculate output offset from tile partitioner and apply to output pointer
1152  EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
1153  if constexpr(has_tile_partitioner_output_offset)
1154  {
1155  const index_t output_offset = TilePartitioner::GetOutputOffset(kargs, k_batch);
1156  e_ptr += output_offset;
1157  }
1158 
1159  // allocate LDS
1160  __shared__ char smem_ptr_0[GetSmemSize()];
1161  // Run the GEMM
1162  if constexpr(GemmPipeline::DoubleSmemBuffer == true)
1163  {
1164  __shared__ char smem_ptr_1[GetSmemSize()];
1165  if constexpr(!(EpiloguePipeline::MemoryOperation ==
1167  EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
1169  {
1170  RunGemm2LDS(as_ptr,
1171  bs_ptr,
1172  kargs.ds_ptr,
1173  e_ptr,
1174  smem_ptr_0,
1175  smem_ptr_1,
1176  kargs,
1177  splitk_batch_offset,
1178  i_m,
1179  i_n);
1180  }
1181  }
1182  else
1183  {
1184  if constexpr(!(EpiloguePipeline::MemoryOperation ==
1186  EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
1188  {
1189  RunGemm(as_ptr,
1190  bs_ptr,
1191  kargs.ds_ptr,
1192  e_ptr,
1193  smem_ptr_0,
1194  kargs,
1195  splitk_batch_offset,
1196  i_m,
1197  i_n);
1198  }
1199  }
1200  // Advance to the next work item
1201  block_id += grid_size;
1202  if(block_id >= num_work)
1203  {
1204  break;
1205  }
1206  }
1207  }
1208 };
1209 } // namespace ck_tile
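
Finally, the persistent entry point above distributes `num_tiles * k_batch` work items over a fixed grid in a grid-stride loop. A small host-side sketch of the same index mapping, with made-up grid, tile-count and k_batch values:

// Host-side illustration of the persistent kernel's work distribution (example numbers).
#include <cstdio>

int main()
{
    const int grid_size = 304;  // e.g. compute units * occupancy, cf. MaxOccupancyGridSize()
    const int num_tiles = 120;  // what TilePartitioner::GridSize(M, N) would report
    const int k_batch   = 4;
    const int num_work  = num_tiles * k_batch; // 480 (output tile, split-K batch) work items

    // Physical block 17 keeps striding by grid_size until all work items are covered.
    for(int block_id = 17; block_id < num_work; block_id += grid_size)
    {
        const int tile_idx = block_id % num_tiles; // which (iM, iN) output tile
        const int k_id     = block_id / num_tiles; // which split-K batch
        std::printf("block 17 processes tile %d for k-batch %d\n", tile_idx, k_id);
    }
    return 0;
}
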