Composable Kernel: include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp Source File

universal_gemm_kernel.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <string>
8 
9 #include "ck_tile/core.hpp"
10 #include "ck_tile/ops/common.hpp"
11 #include "ck_tile/host/concat.hpp"
16 
17 namespace ck_tile {
18 
30 template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
31 struct UniversalGemmHostArgs
32 {
33  CK_TILE_HOST UniversalGemmHostArgs(const std::array<const void*, NumATensor>& as_ptr_,
34  const std::array<const void*, NumBTensor>& bs_ptr_,
35  const std::array<const void*, NumDTensor>& ds_ptr_,
36  void* e_ptr_,
37  index_t k_batch_,
38  index_t M_,
39  index_t N_,
40  index_t K_,
41  const std::array<index_t, NumATensor>& stride_As_,
42  const std::array<index_t, NumBTensor>& stride_Bs_,
43  const std::array<index_t, NumDTensor>& stride_Ds_,
44  index_t stride_E_)
45  : as_ptr(as_ptr_),
46  bs_ptr(bs_ptr_),
47  ds_ptr(ds_ptr_),
48  e_ptr(e_ptr_),
49  M(M_),
50  N(N_),
51  K(K_),
52  stride_As(stride_As_),
53  stride_Bs(stride_Bs_),
54  stride_Ds(stride_Ds_),
55  stride_E(stride_E_),
56  k_batch(k_batch_)
57  {
58  }
59 
60  const std::array<const void*, NumATensor> as_ptr;
61  const std::array<const void*, NumBTensor> bs_ptr;
62  const std::array<const void*, NumDTensor> ds_ptr;
63  union
64  {
65  void* e_ptr;
66  void* c_ptr;
67  };
68  index_t M;
69  index_t N;
70  index_t K;
71  const std::array<index_t, NumATensor> stride_As;
72  const std::array<index_t, NumBTensor> stride_Bs;
73  const std::array<index_t, NumDTensor> stride_Ds;
74  union
75  {
76  index_t stride_E;
77  index_t stride_C;
78  };
79 
80  index_t k_batch;
81 };
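As a host-side usage sketch (not part of this header): the snippet below shows how these arguments might be filled for a single-A/single-B GEMM with no D tensors. The device pointers, problem sizes and strides are illustrative assumptions only.

// Illustrative only: a_dev, b_dev and e_dev are assumed device allocations.
ck_tile::UniversalGemmHostArgs<1, 1, 0> host_args({a_dev},        // as_ptr
                                                  {b_dev},        // bs_ptr
                                                  {},             // ds_ptr (no D tensors)
                                                  e_dev,          // e_ptr
                                                  /*k_batch_*/ 1,
                                                  /*M_*/ 3840,
                                                  /*N_*/ 4096,
                                                  /*K_*/ 4096,
                                                  {/*stride_A*/ 4096},
                                                  {/*stride_B*/ 4096},
                                                  {},             // stride_Ds
                                                  /*stride_E_*/ 4096);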
82 
84 template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
85 struct UniversalGemmKernelArgs
86 {
88  const std::array<const void*, NumATensor> as_ptr;
90  const std::array<const void*, NumBTensor> bs_ptr;
92  const std::array<const void*, NumDTensor> ds_ptr;
94  void* e_ptr;
96  index_t M;
98  index_t N;
100  index_t K;
103  std::array<index_t, NumATensor> stride_As;
106  std::array<index_t, NumBTensor> stride_Bs;
109  std::array<index_t, NumDTensor> stride_Ds;
112  index_t stride_E;
113  index_t k_batch;
114 };
115 
152 template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
153 struct UniversalGemmKernel
154 {
155  using TilePartitioner = remove_cvref_t<TilePartitioner_>;
156  using GemmPipeline = remove_cvref_t<GemmPipeline_>;
157  using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
158 
159  static constexpr bool ADataTypeIsTuple =
161  static constexpr bool BDataTypeIsTuple =
163  static constexpr bool DDataTypeIsTuple =
165  static constexpr bool ALayoutIsTuple =
167  static constexpr bool BLayoutIsTuple =
169  static constexpr bool DLayoutIsTuple =
171 
172  using AsLayout = std::conditional_t<ALayoutIsTuple,
173  remove_cvref_t<typename GemmPipeline::AsLayout>,
174  remove_cvref_t<tuple<typename GemmPipeline::ALayout>>>;
175  using BsLayout = std::conditional_t<BLayoutIsTuple,
176  remove_cvref_t<typename GemmPipeline::BsLayout>,
177  remove_cvref_t<tuple<typename GemmPipeline::BLayout>>>;
178 
179  using DsLayout = std::conditional_t<DLayoutIsTuple,
180  remove_cvref_t<typename EpiloguePipeline::DsLayout>,
181  remove_cvref_t<tuple<typename EpiloguePipeline::DsLayout>>>;
182 
183  using AsDataType = std::conditional_t<ADataTypeIsTuple,
184  remove_cvref_t<typename GemmPipeline::AsDataType>,
185  remove_cvref_t<tuple<typename GemmPipeline::ADataType>>>;
186 
187  using BsDataType = std::conditional_t<BDataTypeIsTuple,
188  remove_cvref_t<typename GemmPipeline::BsDataType>,
189  remove_cvref_t<tuple<typename GemmPipeline::BDataType>>>;
190 
191  using DsDataType =
192  std::conditional_t<DDataTypeIsTuple,
193  remove_cvref_t<typename EpiloguePipeline::DsDataType>,
194  remove_cvref_t<tuple<typename EpiloguePipeline::DsDataType>>>;
195 
196  using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
197  using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
198 
199  using AElementWise = remove_cvref_t<typename GemmPipeline::AElementWise>;
200  using BElementWise = remove_cvref_t<typename GemmPipeline::BElementWise>;
201 
202  static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
203 
204  // Get the persistent kernel if the pipeline has it available
206  {
207  template <typename T>
208  using has_persistent_type = decltype(T::UsePersistentKernel);
209 
210  static constexpr bool value = []() {
211  if constexpr(is_detected<has_persistent_type, GemmPipeline>::value)
212  return GemmPipeline::UsePersistentKernel;
213  else
214  return false;
215  }();
216  };
218 
219  // Check if TilePartitioner has GetOutputOffset method with kargs and k_id
221  {
222  template <typename T, typename KernelArgs>
223  using has_get_output_offset_t =
224  decltype(T::GetOutputOffset(std::declval<KernelArgs>(), std::declval<index_t>()));
225 
226  static constexpr bool value = []() {
228  return true;
229  else
230  return false;
231  }();
232  };
233  static constexpr bool has_tile_partitioner_output_offset =
235 
236  static constexpr auto I0 = number<0>();
237  static constexpr auto I1 = number<1>();
238  static constexpr auto I2 = number<2>();
239  static constexpr auto I3 = number<3>{};
240 
241  static constexpr index_t NumATensor = AsDataType::size();
242  static constexpr index_t NumBTensor = BsDataType::size();
243  static constexpr index_t NumDTensor = DsDataType::size();
244 
245  using ADataType = remove_cvref_t<std::tuple_element_t<I0, AsDataType>>;
246  using BDataType = remove_cvref_t<std::tuple_element_t<I0, BsDataType>>;
247 
248  static_assert(AsLayout::size() == AsDataType::size(),
249  "The size of AsLayout and AsDataType should be the same");
250 
251  static_assert(BsLayout::size() == BsDataType::size(),
252  "The size of BsLayout and BsDataType should be the same");
253 
254  static_assert(DsLayout::size() == DsDataType::size(),
255  "The size of DsLayout and DsDataType should be the same");
256 
257  using KernelArgs =
258  UniversalGemmKernelArgs<AsLayout::size(), BsLayout::size(), DsLayout::size()>;
259 
260  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
261  {
262  // clang-format off
263  return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>(), GemmPipeline::GetName());
264  // clang-format on
265  }
266 
267  CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
268  {
269  return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
270  }
271 
278  CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
279  {
281  const auto kernel = kentry<1, Kernel, KernelArgs>;
282  int occupancy;
283  hip_check_error(
284  hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize().x, 0));
285 
286  const int grid_size = get_available_compute_units(s) * occupancy;
287  return dim3(grid_size, 1, 1);
288  }
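For the persistent entry point, this occupancy-based grid is typically paired with BlockSize(); a hedged sketch, where Kernel is an assumed alias for a concrete instantiation of this class and s is a ck_tile::stream_config:

const dim3 blocks = Kernel::BlockSize();
const dim3 grids  = Kernel::MaxOccupancyGridSize(s); // one workgroup per co-resident block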
289 
290  CK_TILE_HOST static auto BlockSize()
291  {
292  if(ck_tile::is_wave32())
293  {
294  return dim3(kBlockSize / 2);
295  }
296  else
297  {
298  return dim3(kBlockSize);
299  }
300  }
301 
302  CK_TILE_HOST static constexpr KernelArgs
303  MakeKernelArgs(const UniversalGemmHostArgs<NumATensor, NumBTensor, NumDTensor>& hostArgs)
304  {
305  return KernelArgs{hostArgs.as_ptr,
306  hostArgs.bs_ptr,
307  hostArgs.ds_ptr,
308  hostArgs.e_ptr,
309  hostArgs.M,
310  hostArgs.N,
311  hostArgs.K,
312  hostArgs.stride_As,
313  hostArgs.stride_Bs,
314  hostArgs.stride_Ds,
315  hostArgs.stride_E,
316  hostArgs.k_batch};
317  }
318 
319  CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
320  {
321  return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
322  }
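// Note: the GEMM and epilogue pipelines share one LDS allocation, so the kernel
// reserves the larger of the two sizes.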
323 
324  struct SplitKBatchOffset
325  {
326  __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
327  {
328  constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
329  const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1);
330  const index_t KRead = amd_wave_read_first_lane((kargs.K + K_t - 1) / K_t * K1);
331 
332  static_for<0, NumATensor, 1>{}([&](auto index) {
333  using AiLayout = remove_cvref_t<std::tuple_element_t<index.value, AsLayout>>;
334  if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, AiLayout>)
335  {
336  as_k_split_offset[index] = amd_wave_read_first_lane(k_id * KRead);
337  }
338  else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, AiLayout>)
339  {
340  as_k_split_offset[index] =
341  amd_wave_read_first_lane(k_id * KRead * kargs.stride_As[index]);
342  }
343  });
344 
345  static_for<0, NumBTensor, 1>{}([&](auto index) {
346  using BiLayout = remove_cvref_t<std::tuple_element_t<index.value, BsLayout>>;
347  if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BiLayout>)
348  {
349  bs_k_split_offset[index] =
350  amd_wave_read_first_lane(k_id * KRead * kargs.stride_Bs[index]);
351  }
352  else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BiLayout>)
353  {
354  bs_k_split_offset[index] = amd_wave_read_first_lane(k_id * KRead);
355  }
356  });
357 
358  if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
359  {
360  splitted_k = amd_wave_read_first_lane(KRead);
361  }
362  else
363  {
364  splitted_k = amd_wave_read_first_lane(kargs.K - KRead * (kargs.k_batch - 1));
365  }
366  }
367 
368  std::array<index_t, NumATensor> as_k_split_offset;
369  std::array<index_t, NumBTensor> bs_k_split_offset;
370  index_t splitted_k;
371  };
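To make the split-K bookkeeping above concrete, an illustrative example (the numbers are assumptions, not taken from this file):

// K = 512, k_batch = 4, K1 (warp-tile K) = 32:
//   K_t   = k_batch * K1             = 128
//   KRead = (K + K_t - 1) / K_t * K1 = 128
//   splitted_k = 128 for every k_id; with K = 500 the last split instead gets
//   500 - 3 * 128 = 116, while the first three splits still read 128 each.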
372 
373  CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs)
374  {
375  if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
377  {
378  if(kargs.k_batch != 1)
379  {
380  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
381  {
382  CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
383  }
384  return false;
385  }
386  }
387 
388  const auto vectorSizeA = is_wave32() ? GemmPipeline::template GetVectorSizeA<true>()
389  : GemmPipeline::template GetVectorSizeA<false>();
390  bool AsTesnorIsValid = {true};
391  static_for<0, NumATensor, 1>{}([&](auto index) {
392  using AiLayout = remove_cvref_t<std::tuple_element_t<index.value, AsLayout>>;
393  if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
394  {
395  if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
396  GemmPipeline::kPadK == false)
397  {
398  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
399  {
400  CK_TILE_ERROR(
401  "Can't support K that is not a multiple of k_batch * KPerBlock "
402  "without padding!");
403  }
404  AsTesnorIsValid = false;
405  }
406  if(kargs.K % vectorSizeA != 0)
407  {
408  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
409  {
410  CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
411  }
412  AsTesnorIsValid = false;
413  }
414  }
415  else
416  {
417  if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
418  {
419  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
420  {
421  CK_TILE_ERROR(
422  "Can't support M that is not a multiple of MPerBlock without padding!");
423  }
424  AsTesnorIsValid = false;
425  }
426  if(kargs.M % vectorSizeA != 0)
427  {
428  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
429  {
430  CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
431  }
432  AsTesnorIsValid = false;
433  }
434  }
435  });
436 
437  bool BsTesnorIsValid = {true};
438  const auto vectorSizeB = is_wave32() ? GemmPipeline::template GetVectorSizeB<true>()
439  : GemmPipeline::template GetVectorSizeB<false>();
440  static_for<0, NumBTensor, 1>{}([&](auto index) {
441  using BiLayout = remove_cvref_t<std::tuple_element_t<index.value, BsLayout>>;
442  if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::RowMajor>)
443  {
444  if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
445  {
446  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
447  {
448  CK_TILE_ERROR(
449  "Can't support N that is not a multiple of NPerBlock without padding!");
450  }
451  BsTesnorIsValid = false;
452  }
453  if(kargs.N % vectorSizeB != 0)
454  {
455  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
456  {
457  CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
458  }
459  BsTesnorIsValid = false;
460  }
461  }
462  else
463  {
464  if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
465  GemmPipeline::kPadK == false)
466  {
467  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
468  {
469  CK_TILE_ERROR(
470  "Can't support K that is not a multiple of k_batch * KPerBlock "
471  "without padding!");
472  }
473  BsTesnorIsValid = false;
474  }
475  if(kargs.K % vectorSizeB != 0)
476  {
477  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
478  {
479  CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!");
480  }
481  BsTesnorIsValid = false;
482  }
483  }
484  });
485 
486  bool DTesnorIsValid = {true};
487  static_for<0, NumDTensor, 1>{}([&](auto index) {
488  using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
489  if(std::is_same_v<DiLayout, CLayout> == false)
490  {
491  DTesnorIsValid = false;
492  }
493  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
494  {
495  if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
496  {
497  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
498  {
499  CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of "
500  "NPerBlock without padding!");
501  }
502  DTesnorIsValid = false;
503  }
504  if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0)
505  {
506  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
507  {
508  CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!");
509  }
510  DTesnorIsValid = false;
511  }
512  }
513  else
514  {
515  if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
516  {
517  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
518  {
519  CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of "
520  "MPerBlock without padding!");
521  }
522  DTesnorIsValid = false;
523  }
524  if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0)
525  {
526  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
527  {
528  CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!");
529  }
530  DTesnorIsValid = false;
531  }
532  }
533  });
534 
535  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
536  {
537  if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
538  {
539  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
540  {
541  CK_TILE_ERROR(
542  "Can't support N that is not a multiple of NPerBlock without padding!");
543  }
544  return false;
545  }
546  if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
547  {
548  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
549  {
550  CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!");
551  }
552  return false;
553  }
554  }
555  else
556  {
557  if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
558  {
559  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
560  {
561  CK_TILE_ERROR(
562  "Can't support M that is not a multiple of MPerBlock without padding!");
563  }
564  return false;
565  }
566  if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
567  {
568  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
569  {
570  CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!");
571  }
572  return false;
573  }
574  }
575  return AsTesnorIsValid && BsTesnorIsValid && DTesnorIsValid;
576  }
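On the host this check is normally made before launching; a minimal sketch, assuming host_args has been filled and Kernel aliases a concrete instantiation of this class:

const auto kargs = Kernel::MakeKernelArgs(host_args);
if(!Kernel::IsSupportedArgument(kargs))
{
    // fall back to another kernel configuration or reject the problem shape
    return;
}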
577 
578  template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
579  CK_TILE_DEVICE static auto
580  MakeGemmTensorViews(const std::array<const ADataType*, NumATensor>& as_ptr,
581  const std::array<const BDataType*, NumBTensor>& bs_ptr,
582  const std::array<const void*, NumDTensor>& ds_ptr,
583  EDataType* e_ptr,
584  const KernelArgs& kargs,
585  const index_t k_size)
586  {
587  static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
588 
589  const auto& as_tensor_view = generate_tuple(
590  [&](auto i) {
591  using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
592  using AiDataType = remove_cvref_t<std::tuple_element_t<i.value, AsDataType>>;
593  if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
594  {
595  return make_naive_tensor_view<address_space_enum::global>(
596  static_cast<const AiDataType*>(as_ptr[i]),
597  make_tuple(kargs.M, k_size),
598  make_tuple(kargs.stride_As[i], 1),
599  number<GemmPipeline::GetVectorSizeA()>{},
600  number<1>{});
601  }
602  else
603  {
604  return make_naive_tensor_view<address_space_enum::global>(
605  static_cast<const AiDataType*>(as_ptr[i]),
606  make_tuple(k_size, kargs.M),
607  make_tuple(kargs.stride_As[i], 1),
608  number<GemmPipeline::GetVectorSizeA()>{},
609  number<1>{});
610  }
611  },
612  number<NumATensor>{});
613 
614  const auto& bs_tensor_view = generate_tuple(
615  [&](auto i) {
616  using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
617  using BiDataType = remove_cvref_t<std::tuple_element_t<i.value, BsDataType>>;
618  if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::RowMajor>)
619  {
620  if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
621  {
622  constexpr index_t K1 = GemmPipeline::GetSmemPackB();
623  const index_t K0 = k_size / K1;
624  constexpr index_t VectorSizeB =
625  std::min(K1, GemmPipeline::GetVectorSizeB());
626  const auto b_k0_n_k1_desc =
628  make_tuple(kargs.N * K1, K1, I1),
630  number<1>{});
631  const auto b_n_k_desc = transform_tensor_descriptor(
632  b_k0_n_k1_desc,
637  return make_tensor_view<address_space_enum::global>(
638  static_cast<const BiDataType*>(bs_ptr[i]), b_n_k_desc);
639  }
640  else
641  {
642  return make_naive_tensor_view<address_space_enum::global>(
643  bs_ptr[i],
644  make_tuple(k_size, kargs.N),
645  make_tuple(kargs.stride_Bs[i], 1),
646  number<GemmPipeline::GetVectorSizeB()>{},
647  number<1>{});
648  }
649  }
650  else
651  {
652  if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
653  {
654  constexpr index_t K1 = GemmPipeline::GetSmemPackB();
655  const index_t K0 = k_size / K1;
656  constexpr index_t VectorSizeB =
657  std::min(K1, GemmPipeline::GetVectorSizeB());
658  const auto b_k0_n_k1_desc =
660  make_tuple(kargs.N * K1, K1, I1),
662  number<1>{});
663  const auto b_n_k_desc = transform_tensor_descriptor(
664  b_k0_n_k1_desc,
669  return make_tensor_view<address_space_enum::global>(
670  static_cast<const BiDataType*>(bs_ptr[i]), b_n_k_desc);
671  }
672  else
673  {
674  if constexpr(GemmPipeline::Preshuffle)
675  {
676  index_t kFlatK =
677  GemmPipeline::BlockGemmShape::flatKPerWarp *
678  (k_size /
679  TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}));
680  index_t kFlatN = kargs.N * kargs.K / kFlatK;
681 
682  return make_naive_tensor_view<address_space_enum::global>(
683  bs_ptr[i],
684  make_tuple(kFlatN, kFlatK),
685  make_tuple(kFlatK, 1),
686  number<GemmPipeline::GetVectorSizeB()>{},
687  number<1>{});
688  }
689  else
690  {
691  return make_naive_tensor_view<address_space_enum::global>(
692  bs_ptr[i],
693  make_tuple(kargs.N, k_size),
694  make_tuple(kargs.stride_Bs[i], 1),
695  number<GemmPipeline::GetVectorSizeB()>{},
696  number<1>{});
697  }
698  }
699  }
700  },
701  number<NumBTensor>{});
702 
703  const auto& ds_tensor_view = generate_tuple(
704  [&](auto i) {
705  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
706  using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
707  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
708  {
709  return make_naive_tensor_view<address_space_enum::global>(
710  static_cast<const DDataType_*>(ds_ptr[i]),
711  make_tuple(kargs.M, kargs.N),
712  make_tuple(kargs.stride_Ds[i], 1),
713  number<EpiloguePipeline::GetVectorSizeD(i)>{},
714  number<1>{});
715  }
716  else
717  {
718  return make_naive_tensor_view<address_space_enum::global>(
719  static_cast<const DDataType_*>(ds_ptr[i]),
720  make_tuple(kargs.N, kargs.M),
721  make_tuple(kargs.stride_Ds[i], 1),
722  number<EpiloguePipeline::GetVectorSizeD(i)>{},
723  number<1>{});
724  }
725  },
726  number<NumDTensor>{});
727 
728  // TODO: enable vector write for C in ColMajor
729  const auto& e_tensor_view = [&]() {
730  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
731  {
732  return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
733  e_ptr,
734  make_tuple(kargs.M, kargs.N), // arguments not matching with flatmm.
735  make_tuple(kargs.stride_E, 1),
736  number<EpiloguePipeline::GetVectorSizeC()>{},
737  number<1>{});
738  }
739  else
740  {
741  return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
742  e_ptr,
743  make_tuple(kargs.M, kargs.N),
744  make_tuple(1, kargs.stride_E),
745  number<1>{},
746  number<1>{});
747  }
748  }();
749 
750  return make_tuple(as_tensor_view, bs_tensor_view, ds_tensor_view, e_tensor_view);
751  }
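// The tuple returned above is ordered (As, Bs, Ds, E) and is indexed with I0..I3
// by the pad-view and tile-window helpers below.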
752 
753  template <typename TensorView>
754  CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
755  {
756  const auto& as_pad_view = generate_tuple(
757  [&](auto i) {
758  const auto& a_tensor_view = views.at(I0);
759  using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
760  if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
761  {
762  return pad_tensor_view(a_tensor_view[i],
766  }
767  else
768  {
769  return pad_tensor_view(a_tensor_view[i],
773  }
774  },
775  number<NumATensor>{});
776 
777  const auto& b_flat_pad_view = views.at(I1);
778 
779  const auto& bs_pad_view = generate_tuple(
780  [&](auto i) {
781  const auto& b_tensor_view = views.at(I1);
782  using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
783  if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::ColumnMajor>)
784  {
785  return pad_tensor_view(b_tensor_view[i],
789  }
790  else
791  {
792  return pad_tensor_view(b_tensor_view[i],
796  }
797  },
798  number<NumBTensor>{});
799 
800  const auto& ds_pad_view = generate_tuple(
801  [&](auto i) {
802  const auto& d_tensor_view = views.at(I2);
803  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
804  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
805  {
806  return pad_tensor_view(d_tensor_view[i],
810  }
811  else
812  {
813  return pad_tensor_view(d_tensor_view[i],
817  }
818  },
819  number<NumDTensor>{});
820 
821  // TODO vector write in for C in ColMajor
822  const auto& e_pad_view = [&]() {
823  const auto& e_tensor_view = views.at(I3);
824  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
825  {
826  return pad_tensor_view(e_tensor_view,
830  }
831  else
832  {
833  return pad_tensor_view(e_tensor_view,
837  }
838  }();
839 
840  if constexpr(GemmPipeline::Preshuffle)
841  {
842  // For flatmm, we need to use the flat B tensor view
843  return make_tuple(as_pad_view, b_flat_pad_view, ds_pad_view, e_pad_view);
844  }
845  else
846  {
847  return make_tuple(as_pad_view, bs_pad_view, ds_pad_view, e_pad_view);
848  }
849  }
850 
851  template <typename PadView>
852  CK_TILE_DEVICE static auto
853  MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
854  {
855  const auto& as_pad_view = views.at(I0);
856  const auto& bs_pad_view = views.at(I1);
857  const auto& ds_pad_view = views.at(I2);
858  const auto& e_pad_view = views.at(I3);
859 
860  const auto& as_block_window = generate_tuple(
861  [&](auto i) {
862  using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
863  if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
864  {
865  return make_tile_window(as_pad_view[i],
868  {i_m, 0});
869  }
870  else
871  {
872  return make_tile_window(as_pad_view[i],
875  {0, i_m});
876  }
877  },
878  number<NumATensor>{});
879 
880  const auto& bs_block_window = generate_tuple(
881  [&](auto i) {
882  using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
883  if constexpr(GemmPipeline::Preshuffle)
884  {
885  return make_tile_window(
886  bs_pad_view[i],
889  {static_cast<int>(i_n / GemmPipeline::BlockGemmShape::WarpTile::at(I1)),
890  0});
891  }
892  else
893  {
894  if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::ColumnMajor>)
895  {
896  return make_tile_window(bs_pad_view[i],
899  {i_n, 0});
900  }
901  else
902  {
903  return make_tile_window(bs_pad_view[i],
906  {0, i_n});
907  }
908  }
909  },
910  number<NumBTensor>{});
911 
912  const auto ds_block_window = generate_tuple(
913  [&](auto i) {
914  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
915  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
916  {
917  return make_tile_window(ds_pad_view[i],
920  {i_m, i_n});
921  }
922  else
923  {
924  return make_tile_window(ds_pad_view[i],
927  {i_n, i_m});
928  }
929  },
930  number<NumDTensor>{});
931 
932  auto e_block_window = make_tile_window(
933  e_pad_view,
934  make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
935  {i_m, i_n});
936 
937  return make_tuple(as_block_window, bs_block_window, ds_block_window, e_block_window);
938  }
939 
954  template <bool UseDefaultScheduler = true>
955  CK_TILE_DEVICE static void RunGemm(const std::array<const ADataType*, NumATensor>& as_ptr,
956  const std::array<const BDataType*, NumBTensor>& bs_ptr,
957  const std::array<const void*, NumDTensor>& ds_ptr,
958  EDataType* e_ptr,
959  void* smem_ptr_0,
960  const KernelArgs& kargs,
961  const SplitKBatchOffset& splitk_batch_offset,
962  const index_t block_idx_m,
963  const index_t block_idx_n)
964  {
965  // Create Gemm tensor views, pad views and tile windows
966  const auto& gemm_tensor_views_tuple =
967  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
968  as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k);
969 
970  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
971  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
972 
973  const index_t num_loop =
974  amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
975 
976  // Run GEMM cooperatively by whole workgroup.
977  const auto& as_block_window = gemm_tile_windows.at(I0);
978  const auto& bs_block_window = gemm_tile_windows.at(I1);
979  const auto& ds_block_window = gemm_tile_windows.at(I2);
980 
981  const auto& c_block_tile = GemmPipeline{}.template operator()(
982  as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr_0);
983 
984  if(UseDefaultScheduler || (get_warp_id() == 0))
985  {
986  // Run Epilogue Pipeline
987  auto& c_block_window = gemm_tile_windows.at(I3);
988 
989  EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
990  }
991  }
992 
1010  CK_TILE_DEVICE static void RunGemm2LDS(const std::array<const ADataType*, NumATensor>& as_ptr,
1011  const std::array<const BDataType*, NumBTensor>& bs_ptr,
1012  const std::array<const void*, NumDTensor>& ds_ptr,
1013  EDataType* e_ptr,
1014  void* __restrict__ smem_ptr_0,
1015  void* __restrict__ smem_ptr_1,
1016  const KernelArgs& kargs,
1017  const SplitKBatchOffset& splitk_batch_offset,
1018  const index_t block_idx_m,
1019  const index_t block_idx_n)
1020  {
1021  // Create Gemm tensor views, pad views and tile windows
1022  const auto& gemm_tensor_views_tuple =
1023  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
1024  as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k);
1025 
1026  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
1027  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
1028 
1029  const index_t num_loop =
1030  amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
1031 
1032  // Run GEMM cooperatively by whole workgroup.
1033  const auto& as_block_window = gemm_tile_windows.at(I0);
1034  const auto& bs_block_window = gemm_tile_windows.at(I1);
1035  const auto& ds_block_window = gemm_tile_windows.at(I2);
1036 
1037  const auto& c_block_tile = GemmPipeline{}.template operator()(as_block_window,
1038  AElementWise{},
1039  bs_block_window,
1040  BElementWise{},
1041  num_loop,
1042  smem_ptr_0,
1043  smem_ptr_1);
1044 
1045  // Run Epilogue Pipeline
1046  auto& c_block_window = gemm_tile_windows.at(I3);
1047 
1048  EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
1049  }
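// Note: RunGemm2LDS mirrors RunGemm but feeds the pipeline two LDS buffers
// (smem_ptr_0 / smem_ptr_1) instead of one; the entry points below select it
// whenever GemmPipeline::DoubleSmemBuffer is true.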
1050 
1051  // Non-persistent kernel entry point
1052  template <bool U = !PersistentKernel, typename = std::enable_if_t<U>>
1053  CK_TILE_DEVICE void operator()(KernelArgs kargs) const
1054  {
1055  const auto blockId = amd_wave_read_first_lane(blockIdx.x);
1056  const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
1057  const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
1058  const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
1059 
1060  const SplitKBatchOffset splitk_batch_offset(kargs);
1061 
1062  // Apply the split-K offsets to the As/Bs base pointers
1063  std::array<const ADataType*, NumATensor> as_ptr;
1064  static_for<0, NumATensor, 1>{}([&](auto i) {
1065  as_ptr[i] = static_cast<const ADataType*>(kargs.as_ptr[i]) +
1066  splitk_batch_offset.as_k_split_offset[i];
1067  });
1068 
1069  std::array<const BDataType*, NumBTensor> bs_ptr;
1070  static_for<0, NumBTensor, 1>{}([&](auto i) {
1071  bs_ptr[i] = static_cast<const BDataType*>(kargs.bs_ptr[i]) +
1072  splitk_batch_offset.bs_k_split_offset[i];
1073  });
1074 
1075  // Calculate output offset from tile partitioner and apply to output pointer
1076  EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
1077  if constexpr(has_tile_partitioner_output_offset)
1078  {
1079  const index_t output_offset = TilePartitioner::GetOutputOffset(kargs, blockIdx.z);
1080  e_ptr += output_offset;
1081  }
1082 
1083  // allocate LDS
1084  __shared__ char smem_ptr_0[GetSmemSize()];
1085 
1086  if constexpr(GemmPipeline::DoubleSmemBuffer == true)
1087  {
1088  __shared__ char smem_ptr_1[GetSmemSize()];
1089  if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
1090  EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
1092  {
1093  RunGemm2LDS(as_ptr,
1094  bs_ptr,
1095  kargs.ds_ptr,
1096  e_ptr,
1097  smem_ptr_0,
1098  smem_ptr_1,
1099  kargs,
1100  splitk_batch_offset,
1101  i_m,
1102  i_n);
1103  }
1104  }
1105  else
1106  {
1107  if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
1108  EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
1110  {
1111  constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1);
1112  RunGemm<scheduler_type>(as_ptr,
1113  bs_ptr,
1114  kargs.ds_ptr,
1115  e_ptr,
1116  smem_ptr_0,
1117  kargs,
1118  splitk_batch_offset,
1119  i_m,
1120  i_n);
1121  }
1122  }
1123  }
1124 
1125  // Persistent kernel entry point
1126  template <bool U = PersistentKernel, typename = std::enable_if_t<U>, typename = void>
1127  CK_TILE_DEVICE void operator()(KernelArgs kargs) const
1128  {
1129  const auto grid_size = amd_wave_read_first_lane(get_grid_size());
1130  const auto num_tiles =
1131  amd_wave_read_first_lane(TilePartitioner::GridSize(kargs.M, kargs.N));
1132  const auto num_work = amd_wave_read_first_lane(num_tiles * kargs.k_batch);
1133  auto block_id = amd_wave_read_first_lane(get_block_id());
1134 
1135  while(block_id < num_work)
1136  {
1137  // Get the tile index for this block
1138  const auto tile_idx = amd_wave_read_first_lane(block_id % num_tiles);
1139  const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(tile_idx);
1140  const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
1141  const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
1142 
1143  // Get the SplitK offset for this block
1144  const auto k_batch = amd_wave_read_first_lane(block_id / num_tiles);
1145  const SplitKBatchOffset splitk_batch_offset(kargs, k_batch);
1146 
1147  std::array<const ADataType*, NumATensor> as_ptr;
1148  static_for<0, NumATensor, 1>{}([&](auto i) {
1149  as_ptr[i] = static_cast<const ADataType*>(kargs.as_ptr[i]) +
1150  splitk_batch_offset.as_k_split_offset[i];
1151  });
1152 
1153  std::array<const BDataType*, NumBTensor> bs_ptr;
1154  static_for<0, NumBTensor, 1>{}([&](auto i) {
1155  bs_ptr[i] = static_cast<const BDataType*>(kargs.bs_ptr[i]) +
1156  splitk_batch_offset.bs_k_split_offset[i];
1157  });
1158 
1159  // Calculate output offset from tile partitioner and apply to output pointer
1160  EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
1161  if constexpr(has_tile_partitioner_output_offset)
1162  {
1163  const index_t output_offset = TilePartitioner::GetOutputOffset(kargs, k_batch);
1164  e_ptr += output_offset;
1165  }
1166 
1167  // allocate LDS
1168  __shared__ char smem_ptr_0[GetSmemSize()];
1169  // Run the GEMM
1170  if constexpr(GemmPipeline::DoubleSmemBuffer == true)
1171  {
1172  __shared__ char smem_ptr_1[GetSmemSize()];
1173  if constexpr(!(EpiloguePipeline::MemoryOperation ==
1174  memory_operation_enum::atomic_add &&
1175  EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
1177  {
1178  RunGemm2LDS(as_ptr,
1179  bs_ptr,
1180  kargs.ds_ptr,
1181  e_ptr,
1182  smem_ptr_0,
1183  smem_ptr_1,
1184  kargs,
1185  splitk_batch_offset,
1186  i_m,
1187  i_n);
1188  }
1189  }
1190  else
1191  {
1192  if constexpr(!(EpiloguePipeline::MemoryOperation ==
1193  memory_operation_enum::atomic_add &&
1194  EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
1196  {
1197  RunGemm(as_ptr,
1198  bs_ptr,
1199  kargs.ds_ptr,
1200  e_ptr,
1201  smem_ptr_0,
1202  kargs,
1203  splitk_batch_offset,
1204  i_m,
1205  i_n);
1206  }
1207  }
1208  // Advance to the next work item
1209  block_id += grid_size;
1210  if(block_id >= num_work)
1211  {
1212  break;
1213  }
1214  }
1215  }
1216 };
1217 } // namespace ck_tile
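Taken together, a host-side launch might look roughly like the sketch below. Kernel, host_args and s are assumptions (a concrete instantiation of UniversalGemmKernel, a filled UniversalGemmHostArgs, and a ck_tile::stream_config); the actual dispatch goes through CK-Tile's host launch helpers, which are outside this header.

using Kernel = ck_tile::UniversalGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;

const auto kargs  = Kernel::MakeKernelArgs(host_args);
const dim3 blocks = Kernel::BlockSize();
const dim3 grids  = Kernel::PersistentKernel
                        ? Kernel::MaxOccupancyGridSize(s)
                        : Kernel::GridSize(host_args.M, host_args.N, host_args.k_batch);

if(Kernel::IsSupportedArgument(kargs))
{
    // launch Kernel{} with `grids`/`blocks` via the CK-Tile host launch helpers
}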