Composable Kernel: include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp Source File

universal_gemm_kernel.hpp
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <string>
8 
9 #include "ck_tile/core.hpp"
10 #include "ck_tile/ops/common.hpp"
11 #include "ck_tile/host/concat.hpp"
16 
17 namespace ck_tile {
18 
30 template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
31  struct UniversalGemmHostArgs
32  {
33  CK_TILE_HOST UniversalGemmHostArgs(const std::array<const void*, NumATensor>& as_ptr_,
34  const std::array<const void*, NumBTensor>& bs_ptr_,
35  const std::array<const void*, NumDTensor>& ds_ptr_,
36  void* e_ptr_,
37  index_t k_batch_,
38  index_t M_,
39  index_t N_,
40  index_t K_,
41  const std::array<index_t, NumATensor>& stride_As_,
42  const std::array<index_t, NumBTensor>& stride_Bs_,
43  const std::array<index_t, NumDTensor>& stride_Ds_,
44  index_t stride_E_)
45  : as_ptr(as_ptr_),
46  bs_ptr(bs_ptr_),
47  ds_ptr(ds_ptr_),
48  e_ptr(e_ptr_),
49  M(M_),
50  N(N_),
51  K(K_),
52  stride_As(stride_As_),
53  stride_Bs(stride_Bs_),
54  stride_Ds(stride_Ds_),
55  stride_E(stride_E_),
56  k_batch(k_batch_)
57  {
58  }
59 
60  const std::array<const void*, NumATensor> as_ptr;
61  const std::array<const void*, NumBTensor> bs_ptr;
62  const std::array<const void*, NumDTensor> ds_ptr;
63  union
64  {
65  void* e_ptr;
66  void* c_ptr;
67  };
68  index_t M;
69  index_t N;
70  index_t K;
71  const std::array<index_t, NumATensor> stride_As;
72  const std::array<index_t, NumBTensor> stride_Bs;
73  const std::array<index_t, NumDTensor> stride_Ds;
74  union
75  {
76  index_t stride_E;
77  index_t stride_C;
78  };
79 
80  index_t k_batch;
81  };
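As a usage illustration (not part of this header), the following hedged sketch shows how UniversalGemmHostArgs might be filled for a single-A, single-B GEMM with no D tensors. The problem sizes, the row-major A / column-major B / row-major E layout assumption, and the device pointers a_dev, b_dev and e_dev are hypothetical.

#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"

// Hypothetical device buffers (e.g. allocated with hipMalloc) are passed in by the caller.
inline ck_tile::UniversalGemmHostArgs<1, 1, 0>
make_example_host_args(const void* a_dev, const void* b_dev, void* e_dev)
{
    const ck_tile::index_t M = 3840, N = 4096, K = 2048;

    return ck_tile::UniversalGemmHostArgs<1, 1, 0>(
        {a_dev},          // as_ptr: one A tensor
        {b_dev},          // bs_ptr: one B tensor
        {},               // ds_ptr: no D tensors
        e_dev,            // e_ptr (aliased as c_ptr)
        /*k_batch_=*/1,   // no split-K
        M, N, K,
        {K},              // stride_As: row-major A, leading dimension K
        {K},              // stride_Bs: column-major B, leading dimension K
        {},               // stride_Ds
        /*stride_E_=*/N); // row-major E, leading dimension N
}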
82 
84 template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
85  struct UniversalGemmKernelArgs
86  {
88  const std::array<const void*, NumATensor> as_ptr;
90  const std::array<const void*, NumBTensor> bs_ptr;
92  const std::array<const void*, NumDTensor> ds_ptr;
94  void* e_ptr;
96  index_t M;
98  index_t N;
100  index_t K;
103  std::array<index_t, NumATensor> stride_As;
106  std::array<index_t, NumBTensor> stride_Bs;
109  std::array<index_t, NumDTensor> stride_Ds;
112  index_t stride_E;
113  index_t k_batch;
114  };
115 
152 template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
153  struct UniversalGemmKernel
154  {
155  using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
156  using GemmPipeline     = remove_cvref_t<GemmPipeline_>;
157  using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
158 
159  static constexpr bool ADataTypeIsTuple =
161  static constexpr bool BDataTypeIsTuple =
163  static constexpr bool DDataTypeIsTuple =
165  static constexpr bool ALayoutIsTuple =
167  static constexpr bool BLayoutIsTuple =
169  static constexpr bool DLayoutIsTuple =
171 
172  using AsLayout = std::conditional_t<ALayoutIsTuple,
173                                      remove_cvref_t<typename GemmPipeline::AsLayout>,
174                                      remove_cvref_t<tuple<typename GemmPipeline::ALayout>>>;
175  using BsLayout = std::conditional_t<BLayoutIsTuple,
176                                      remove_cvref_t<typename GemmPipeline::BsLayout>,
177                                      remove_cvref_t<tuple<typename GemmPipeline::BLayout>>>;
178 
179  using DsLayout = std::conditional_t<DLayoutIsTuple,
180                                      remove_cvref_t<typename EpiloguePipeline::DsLayout>,
181                                      remove_cvref_t<tuple<typename EpiloguePipeline::DsLayout>>>;
182 
183  using AsDataType = std::conditional_t<ADataTypeIsTuple,
184                                        remove_cvref_t<typename GemmPipeline::AsDataType>,
185                                        remove_cvref_t<tuple<typename GemmPipeline::ADataType>>>;
186 
187  using BsDataType = std::conditional_t<BDataTypeIsTuple,
188                                        remove_cvref_t<typename GemmPipeline::BsDataType>,
189                                        remove_cvref_t<tuple<typename GemmPipeline::BDataType>>>;
190 
191  using DsDataType =
192      std::conditional_t<DDataTypeIsTuple,
193                         remove_cvref_t<typename EpiloguePipeline::DsDataType>,
194                         remove_cvref_t<tuple<typename EpiloguePipeline::DsDataType>>>;
195 
196  using CLayout   = remove_cvref_t<typename GemmPipeline::CLayout>;
197  using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
198 
199  using AElementWise = remove_cvref_t<typename GemmPipeline::AElementWise>;
200  using BElementWise = remove_cvref_t<typename GemmPipeline::BElementWise>;
201 
202  static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
203 
204  // Get the persistent kernel if the pipeline has it available
206  {
207  template <typename T>
208  using has_persistent_type = decltype(T::UsePersistentKernel);
209 
210  static constexpr bool value = []() {
211  if constexpr(is_detected<has_persistent_type, GemmPipeline>::value)
212  return GemmPipeline::UsePersistentKernel;
213  else
214  return false;
215  }();
216  };
218 
219  // Check if TilePartitioner has GetOutputOffset method with kargs and k_id
221  {
222  template <typename T, typename KernelArgs>
223  using has_get_output_offset_t =
224  decltype(T::GetOutputOffset(std::declval<KernelArgs>(), std::declval<index_t>()));
225 
226  static constexpr bool value = []() {
228  return true;
229  else
230  return false;
231  }();
232  };
233  static constexpr bool has_tile_partitioner_output_offset =
235 
236  static constexpr auto I0 = number<0>();
237  static constexpr auto I1 = number<1>();
238  static constexpr auto I2 = number<2>();
239  static constexpr auto I3 = number<3>{};
240 
241  static constexpr index_t NumATensor = AsDataType::size();
242  static constexpr index_t NumBTensor = BsDataType::size();
243  static constexpr index_t NumDTensor = DsDataType::size();
244 
245  using ADataType = remove_cvref_t<std::tuple_element_t<I0, AsDataType>>;
246  using BDataType = remove_cvref_t<std::tuple_element_t<I0, BsDataType>>;
247 
248  static_assert(AsLayout::size() == AsDataType::size(),
249  "The size of AsLayout and AsDataType should be the same");
250 
251  static_assert(BsLayout::size() == BsDataType::size(),
252  "The size of BsLayout and BsDataType should be the same");
253 
254  static_assert(DsLayout::size() == DsDataType::size(),
255  "The size of DsLayout and DsDataType should be the same");
256 
257  static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
258 
259  using KernelArgs =
260  UniversalGemmKernelArgs<AsLayout::size(), BsLayout::size(), DsLayout::size()>;
261 
262  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
263  {
264  // clang-format off
265  return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>(), GemmPipeline::GetName());
266  // clang-format on
267  }
268 
269  CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
270  {
271  return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
272  }
273 
280  CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
281  {
283  const auto kernel = kentry<1, Kernel, KernelArgs>;
284  int occupancy;
285  hip_check_error(
286  hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize().x, 0));
287 
288  const int grid_size = get_available_compute_units(s) * occupancy;
289  return dim3(grid_size, 1, 1);
290  }
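For example (hypothetical numbers): if hipOccupancyMaxActiveBlocksPerMultiprocessor reports an occupancy of 2 workgroups per compute unit and get_available_compute_units(s) returns 304, the persistent grid becomes dim3(2 * 304, 1, 1), i.e. 608 workgroups.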
291 
292  CK_TILE_HOST static auto BlockSize()
293  {
294  if(ck_tile::is_wave32())
295  {
296  return dim3(kBlockSize / 2);
297  }
298  else
299  {
300  return dim3(kBlockSize);
301  }
302  }
303 
304  CK_TILE_HOST static constexpr KernelArgs
305  MakeKernelArgs(const UniversalGemmHostArgs<NumATensor, NumBTensor, NumDTensor>& hostArgs)
306  {
307  return KernelArgs{hostArgs.as_ptr,
308  hostArgs.bs_ptr,
309  hostArgs.ds_ptr,
310  hostArgs.e_ptr,
311  hostArgs.M,
312  hostArgs.N,
313  hostArgs.K,
314  hostArgs.stride_As,
315  hostArgs.stride_Bs,
316  hostArgs.stride_Ds,
317  hostArgs.stride_E,
318  hostArgs.k_batch};
319  }
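A hedged host-side sketch (the helper name prepare_launch and the single-A/single-B/no-D instantiation are assumptions, and the header above is assumed to be included) of how the pieces above are typically combined before a launch: convert the host arguments with MakeKernelArgs, validate them with IsSupportedArgument, and derive the launch geometry from GridSize and BlockSize. The actual dispatch goes through CK's kernel-launch helpers and is omitted here.

template <typename Kernel>
bool prepare_launch(const ck_tile::UniversalGemmHostArgs<1, 1, 0>& host_args)
{
    // Convert host-side arguments into the device-side argument struct.
    const auto kargs = Kernel::MakeKernelArgs(host_args);

    // Reject shapes the chosen pipeline/epilogue cannot handle.
    if(!Kernel::IsSupportedArgument(kargs))
    {
        return false;
    }

    // Launch geometry: one workgroup per output tile in x, k_batch along z.
    const dim3 grids  = Kernel::GridSize(kargs.M, kargs.N, kargs.k_batch);
    const dim3 blocks = Kernel::BlockSize();

    (void)grids;  // would be passed to the kernel dispatch, not shown here
    (void)blocks;
    return true;
}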
320 
321  CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
322  {
323  return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
324  }
325 
326  struct SplitKBatchOffset
327  {
328  // This structure distributes work evenly among split-K workgroups.
329  // It is based on the principle that, if there is enough work to fill all workgroups,
330  // the (K / K1) parts can be distributed among the k_batch workgroups so that each
331  // workgroup processes either ceil((K / K1) / k_batch) or ceil((K / K1) / k_batch) - 1
332  // of them, leaving any potential tail to the workgroup with index (k_batch - 1). (See the worked example after this struct.)
333  __device__ SplitKBatchOffset(const KernelArgs& kargs, const index_t k_id = blockIdx.z)
334  {
335  constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
336  const index_t num_all = amd_wave_read_first_lane(
337  kargs.K / K1); // num of all loops not including potential tail
338  index_t num_full = amd_wave_read_first_lane(num_all % kargs.k_batch);
339  num_full = num_full == 0 ? kargs.k_batch : num_full;
340 
341  const index_t num_full_iters =
342  amd_wave_read_first_lane(integer_divide_ceil(num_all, kargs.k_batch));
343  const index_t full_k_read = num_full_iters * K1;
344  const index_t partial_k_read = (num_full_iters - 1) * K1;
345 
346  static_for<0, NumATensor, 1>{}([&](auto index) {
347  using AiLayout = remove_cvref_t<std::tuple_element_t<index.value, AsLayout>>;
348  if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, AiLayout>)
349  {
350  as_k_split_offset[index] =
351  amd_wave_read_first_lane(std::min(k_id, num_full) * full_k_read +
352  std::max(k_id - num_full, 0) * partial_k_read);
353  }
354  else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, AiLayout>)
355  {
356  as_k_split_offset[index] =
357  amd_wave_read_first_lane((std::min(k_id, num_full) * full_k_read +
358  std::max(k_id - num_full, 0) * partial_k_read) *
359  kargs.stride_As[index]);
360  }
361  });
362 
363  static_for<0, NumBTensor, 1>{}([&](auto index) {
364  using BiLayout = remove_cvref_t<std::tuple_element_t<index.value, BsLayout>>;
365  if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BiLayout>)
366  {
367  bs_k_split_offset[index] =
368  amd_wave_read_first_lane((std::min(k_id, num_full) * full_k_read +
369  std::max(k_id - num_full, 0) * partial_k_read) *
370  kargs.stride_Bs[index]);
371  }
372  else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BiLayout>)
373  {
374  bs_k_split_offset[index] =
375  amd_wave_read_first_lane(std::min(k_id, num_full) * full_k_read +
376  std::max(k_id - num_full, 0) * partial_k_read);
377  }
378  });
379 
380  if(k_id == kargs.k_batch - 1)
381  {
382  splitted_k = kargs.K - std::min(k_id, num_full) * full_k_read -
383  std::max(k_id - num_full, 0) * partial_k_read;
384  }
385  else if(k_id < num_full)
386  {
387  splitted_k = full_k_read;
388  }
389  else
390  {
391  splitted_k = partial_k_read;
392  }
393  }
394 
395  std::array<index_t, NumATensor> as_k_split_offset;
396  std::array<index_t, NumBTensor> bs_k_split_offset;
397  index_t splitted_k;
398  };
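A worked example of the distribution implemented above, with hypothetical sizes: for K = 1024, K1 = 32 and k_batch = 5 there are K / K1 = 32 loop chunks; num_full = 32 % 5 = 2, num_full_iters = ceil(32 / 5) = 7, so full_k_read = 7 * 32 = 224 and partial_k_read = 6 * 32 = 192. Workgroups 0 and 1 each read 224 elements of K, workgroups 2 to 4 each read 192, and 2 * 224 + 3 * 192 = 1024 = K, so the K dimension is fully covered with per-workgroup shares that differ by at most one chunk.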
399 
400  CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs)
401  {
402  if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
404  {
405  if(kargs.k_batch != 1)
406  {
407  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
408  {
409  CK_TILE_ERROR("Conditions not met for KBatch > 1!");
410  }
411  return false;
412  }
413  }
414 
415  if(kargs.K < GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}) * kargs.k_batch)
416  {
417  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
418  {
419  CK_TILE_ERROR("KBatch is too large; part of the GPU would not be utilized!");
420  }
421  return false;
422  }
423 
424  const auto vectorSizeA = is_wave32() ? GemmPipeline::template GetVectorSizeA<true>()
425  : GemmPipeline::template GetVectorSizeA<false>();
426  bool AsTensorIsValid = {true};
427  static_for<0, NumATensor, 1>{}([&](auto index) {
428  using AiLayout = remove_cvref_t<std::tuple_element_t<index.value, AsLayout>>;
429  if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
430  {
431  if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
432  GemmPipeline::kPadK == false)
433  {
434  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
435  {
436  CK_TILE_ERROR(
437  "Can't support K that is not a multiple of k_batch * KPerBlock "
438  "without padding!");
439  }
440  AsTensorIsValid = false;
441  }
442  if(kargs.K % vectorSizeA != 0)
443  {
444  const auto remainder = kargs.K % vectorSizeA;
445  constexpr ck_tile::index_t APackedSize =
446  numeric_traits<ADataType>::PackedSize;
447  const auto remainder_in_bytes = remainder * sizeof(ADataType) / APackedSize;
448  // out-of-bounds reads can be handled down to dword (4-byte) granularity
449  if(remainder_in_bytes % 4 == 0)
450  {
451  AsTensorIsValid = true;
452  }
453  else
454  {
455  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
456  {
457  CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
458  }
459  AsTensorIsValid = false;
460  }
461  }
462  }
463  else
464  {
465  if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
466  {
467  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
468  {
469  CK_TILE_ERROR(
470  "Can't support M that is not a multiple of MPerBlock without padding!");
471  }
472  AsTensorIsValid = false;
473  }
474  if(kargs.M % vectorSizeA != 0)
475  {
476  const auto remainder = kargs.M % vectorSizeA;
477  constexpr ck_tile::index_t APackedSize =
478  numeric_traits<ADataType>::PackedSize;
479  const auto remainder_in_bytes = remainder * sizeof(ADataType) / APackedSize;
480  // out-of-bounds reads can be handled down to dword (4-byte) granularity
481  if(remainder_in_bytes % 4 == 0)
482  {
483 
484  AsTensorIsValid = true;
485  }
486  else
487  {
488  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
489  {
490  CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
491  }
492  AsTensorIsValid = false;
493  }
494  }
495  }
496  });
497 
498  bool BsTensorIsValid = {true};
499  const auto vectorSizeB = is_wave32() ? GemmPipeline::template GetVectorSizeB<true>()
500  : GemmPipeline::template GetVectorSizeB<false>();
501  static_for<0, NumBTensor, 1>{}([&](auto index) {
502  using BiLayout = remove_cvref_t<std::tuple_element_t<index.value, BsLayout>>;
503  if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::RowMajor>)
504  {
505  if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
506  {
507  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
508  {
509  CK_TILE_ERROR(
510  "Can't support N that is not a multiple of NPerBlock without padding!");
511  }
512  BsTensorIsValid = false;
513  }
514  if(kargs.N % vectorSizeB != 0)
515  {
516  const auto remainder = kargs.N % vectorSizeB;
517  constexpr ck_tile::index_t BPackedSize =
518  numeric_traits<BDataType>::PackedSize;
519  const auto remainder_in_bytes = remainder * sizeof(BDataType) / BPackedSize;
520  // out-of-bounds reads can be handled down to dword (4-byte) granularity
521  if(remainder_in_bytes % 4 == 0)
522  {
523  BsTensorIsValid = true;
524  }
525  else
526  {
527  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
528  {
529  CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
530  }
531  BsTensorIsValid = false;
532  }
533  }
534  else
535  {
536  if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
537  GemmPipeline::kPadK == false)
538  {
539  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
540  {
541  CK_TILE_ERROR(
542  "Can't support K that is not a multiple of k_batch * KPerBlock "
543  "without padding!");
544  }
545  BsTensorIsValid = false;
546  }
547  if(kargs.K % vectorSizeB != 0)
548  {
549  const auto remainder = kargs.K % vectorSizeB;
550  constexpr ck_tile::index_t BPackedSize =
551  numeric_traits<BDataType>::PackedSize;
552  const auto remainder_in_bytes = remainder * sizeof(BDataType) / BPackedSize;
553  // out-of-bounds reads can be handled down to dword (4-byte) granularity
554  if(remainder_in_bytes % 4 == 0)
555  {
556  BsTensorIsValid = true;
557  }
558  else
559  {
560  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
561  {
562  CK_TILE_ERROR(
563  "K is not a multiple of vector load size for B tensor!");
564  }
565  BsTensorIsValid = false;
566  }
567  }
568  }
569  }
570  });
571 
572  bool DTensorIsValid = {true};
573  static_for<0, NumDTensor, 1>{}([&](auto index) {
574  using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
575  if(std::is_same_v<DiLayout, CLayout> == false)
576  {
577  DTensorIsValid = false;
578  }
579  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
580  {
581  if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
582  {
583  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
584  {
585  CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of "
586  "NPerBlock without padding!");
587  }
588  DTensorIsValid = false;
589  }
590  if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0)
591  {
592  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
593  {
594  CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!");
595  }
596  DTensorIsValid = false;
597  }
598  }
599  else
600  {
601  if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
602  {
603  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
604  {
605  CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of "
606  "MPerBlock without padding!");
607  }
608  DTensorIsValid = false;
609  }
610  if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0)
611  {
612  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
613  {
614  CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!");
615  }
616  DTensorIsValid = false;
617  }
618  }
619  });
620 
621  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
622  {
623  if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
624  {
625  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
626  {
627  CK_TILE_ERROR(
628  "Can't support N that is not a multiple of NPerBlock without padding!");
629  }
630  return false;
631  }
632  if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
633  {
634  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
635  {
636  CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!");
637  }
638  return false;
639  }
640  }
641  else
642  {
643  if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
644  {
645  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
646  {
647  CK_TILE_ERROR(
648  "Can't support M that is not a multiple of MPerBlock without padding!");
649  }
650  return false;
651  }
652  if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
653  {
654  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
655  {
656  CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!");
657  }
658  return false;
659  }
660  }
661  return AsTensorIsValid && BsTensorIsValid && DTensorIsValid;
662  }
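A worked example of the dword-level remainder check above, with hypothetical types and sizes: for a row-major fp16 A tensor (2 bytes per element, packed size 1) and vectorSizeA = 8, K = 100 gives remainder = 100 % 8 = 4 and remainder_in_bytes = 4 * 2 = 8, which is a multiple of 4, so the argument is accepted; K = 97 gives remainder_in_bytes = 2, which is not, so it is rejected unless K padding is enabled.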
663 
664  CK_TILE_DEVICE static auto
665  MakeABlockWindows(const std::array<const ADataType*, NumATensor>& as_ptr,
666  const KernelArgs& kargs,
667  const index_t k_size,
668  const index_t i_m)
669  {
670  // Step 1: Create tensor views for A tensors (from MakeGemmTensorViews)
671  const auto& as_tensor_view = generate_tuple(
672  [&](auto i) {
673  using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
674  using AiDataType = remove_cvref_t<std::tuple_element_t<i.value, AsDataType>>;
675  if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
676  {
677  return make_naive_tensor_view<address_space_enum::global>(
678  static_cast<const AiDataType*>(as_ptr[i]),
679  make_tuple(kargs.M, k_size),
680  make_tuple(kargs.stride_As[i], 1),
681  number<GemmPipeline::GetVectorSizeA()>{},
682  number<1>{});
683  }
684  else
685  {
686  return make_naive_tensor_view<address_space_enum::global>(
687  static_cast<const AiDataType*>(as_ptr[i]),
688  make_tuple(k_size, kargs.M),
689  make_tuple(kargs.stride_As[i], 1),
690  number<GemmPipeline::GetVectorSizeA()>{},
691  number<1>{});
692  }
693  },
694  number<NumATensor>{});
695 
696  // Step 2: Create padded views (from MakeGemmPadViews)
697  const auto& as_pad_view = generate_tuple(
698  [&](auto i) {
699  using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
700  if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
701  {
702  return pad_tensor_view(as_tensor_view[i],
706  }
707  else
708  {
709  return pad_tensor_view(as_tensor_view[i],
713  }
714  },
715  number<NumATensor>{});
716 
717  // Step 3: Create tile windows (from MakeGemmTileWindows)
718  const auto& as_block_window = generate_tuple(
719  [&](auto i) {
720  using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
721  if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
722  {
723  return make_tile_window(as_pad_view[i],
726  {i_m, 0});
727  }
728  else
729  {
730  return make_tile_window(as_pad_view[i],
733  {0, i_m});
734  }
735  },
736  number<NumATensor>{});
737 
738  return as_block_window;
739  }
740 
741  CK_TILE_DEVICE static auto
742  MakeBBlockWindows(const std::array<const BDataType*, NumBTensor>& bs_ptr,
743  const KernelArgs& kargs,
744  const index_t k_size,
745  const index_t i_n)
746  {
747  // Step 1: Create tensor views for B tensors (from MakeGemmTensorViews)
748  const auto& bs_tensor_view = generate_tuple(
749  [&](auto i) {
750  using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
751  using BiDataType = remove_cvref_t<std::tuple_element_t<i.value, BsDataType>>;
752  if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::RowMajor>)
753  {
754  if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
755  {
756  constexpr index_t K1 = GemmPipeline::GetSmemPackB();
757  const index_t K0 = k_size / K1;
758  constexpr index_t VectorSizeB =
759  std::min(K1, GemmPipeline::GetVectorSizeB());
760  const auto b_k0_n_k1_desc =
762  make_tuple(kargs.N * K1, K1, I1),
764  number<1>{});
765  const auto b_n_k_desc = transform_tensor_descriptor(
766  b_k0_n_k1_desc,
771  return make_tensor_view<address_space_enum::global>(
772  static_cast<const BiDataType*>(bs_ptr[i]), b_n_k_desc);
773  }
774  else
775  {
776  return make_naive_tensor_view<address_space_enum::global>(
777  bs_ptr[i],
778  make_tuple(k_size, kargs.N),
779  make_tuple(kargs.stride_Bs[i], 1),
780  number<GemmPipeline::GetVectorSizeB()>{},
781  number<1>{});
782  }
783  }
784  else
785  {
786  if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
787  {
788  constexpr index_t K1 = GemmPipeline::GetSmemPackB();
789  const index_t K0 = k_size / K1;
790  constexpr index_t VectorSizeB =
791  std::min(K1, GemmPipeline::GetVectorSizeB());
792  const auto b_k0_n_k1_desc =
794  make_tuple(kargs.N * K1, K1, I1),
796  number<1>{});
797  const auto b_n_k_desc = transform_tensor_descriptor(
798  b_k0_n_k1_desc,
803  return make_tensor_view<address_space_enum::global>(
804  static_cast<const BiDataType*>(bs_ptr[i]), b_n_k_desc);
805  }
806  else
807  {
808  if constexpr(GemmPipeline::Preshuffle)
809  {
810  index_t kFlatK =
811  GemmPipeline::BlockGemmShape::flatKPerWarp *
812  (k_size / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}));
813  index_t kFlatN = kargs.N * kargs.K / kFlatK;
814 
815  return make_naive_tensor_view<address_space_enum::global>(
816  bs_ptr[i],
817  make_tuple(kFlatN, kFlatK),
818  make_tuple(kFlatK, 1),
819  number<GemmPipeline::GetVectorSizeB()>{},
820  number<1>{});
821  }
822  else
823  {
824  return make_naive_tensor_view<address_space_enum::global>(
825  bs_ptr[i],
826  make_tuple(kargs.N, k_size),
827  make_tuple(kargs.stride_Bs[i], 1),
828  number<GemmPipeline::GetVectorSizeB()>{},
829  number<1>{});
830  }
831  }
832  }
833  },
834  number<NumBTensor>{});
835 
836  // Step 2: Create padded views (from MakeGemmPadViews)
837  const auto& bs_pad_view = generate_tuple(
838  [&](auto i) {
839  using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
840  if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::ColumnMajor>)
841  {
842  return pad_tensor_view(bs_tensor_view[i],
846  }
847  else
848  {
849  return pad_tensor_view(bs_tensor_view[i],
853  }
854  },
855  number<NumBTensor>{});
856 
857  // Step 3: Create tile windows (from MakeGemmTileWindows)
858  const auto& bs_block_window = generate_tuple(
859  [&](auto i) {
860  using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
861  if constexpr(GemmPipeline::Preshuffle)
862  {
863  return make_tile_window(
864  bs_pad_view[i],
867  {static_cast<int>(i_n / GemmPipeline::BlockGemmShape::WarpTile::at(I1)),
868  0});
869  }
870  else
871  {
872  if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::ColumnMajor>)
873  {
874  return make_tile_window(bs_pad_view[i],
877  {i_n, 0});
878  }
879  else
880  {
881  return make_tile_window(bs_pad_view[i],
884  {0, i_n});
885  }
886  }
887  },
888  number<NumBTensor>{});
889 
890  return bs_block_window;
891  }
892 
893  CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
894  const KernelArgs& kargs,
895  const index_t i_m,
896  const index_t i_n)
897  {
898  // Step 1: Create tensor views for D tensors (from MakeGemmTensorViews)
899  const auto& ds_tensor_view = generate_tuple(
900  [&](auto i) {
901  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
902  using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
903  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
904  {
905  return make_naive_tensor_view<address_space_enum::global>(
906  static_cast<const DDataType_*>(ds_ptr[i]),
907  make_tuple(kargs.M, kargs.N),
908  make_tuple(kargs.stride_Ds[i], 1),
909  number<EpiloguePipeline::GetVectorSizeD(i)>{},
910  number<1>{});
911  }
912  else
913  {
914  return make_naive_tensor_view<address_space_enum::global>(
915  static_cast<const DDataType_*>(ds_ptr[i]),
916  make_tuple(kargs.N, kargs.M),
917  make_tuple(kargs.stride_Ds[i], 1),
918  number<EpiloguePipeline::GetVectorSizeD(i)>{},
919  number<1>{});
920  }
921  },
922  number<NumDTensor>{});
923 
924  // Step 2: Create padded views (from MakeGemmPadViews)
925  const auto& ds_pad_view = generate_tuple(
926  [&](auto i) {
927  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
928  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
929  {
930  return pad_tensor_view(ds_tensor_view[i],
934  }
935  else
936  {
937  return pad_tensor_view(ds_tensor_view[i],
941  }
942  },
943  number<NumDTensor>{});
944 
945  // Step 3: Create tile windows (from MakeGemmTileWindows)
946  const auto& ds_block_window = generate_tuple(
947  [&](auto i) {
948  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
949  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
950  {
951  return make_tile_window(ds_pad_view[i],
954  {i_m, i_n});
955  }
956  else
957  {
958  return make_tile_window(ds_pad_view[i],
961  {i_n, i_m});
962  }
963  },
964  number<NumDTensor>{});
965 
966  return ds_block_window;
967  }
968 
969  template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
970  CK_TILE_DEVICE static auto MakeCBlockWindows(EDataType* e_ptr,
971  const KernelArgs& kargs,
972  const index_t i_m,
973  const index_t i_n)
974  {
975  // Step 1: Create tensor view for E/C tensor (from MakeGemmTensorViews)
976  const auto& e_tensor_view = [&]() {
977  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
978  {
979  return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
980  e_ptr,
981  make_tuple(kargs.M, kargs.N),
982  make_tuple(kargs.stride_E, 1),
983  number<EpiloguePipeline::GetVectorSizeC()>{},
984  number<1>{});
985  }
986  else
987  {
988  return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
989  e_ptr,
990  make_tuple(kargs.M, kargs.N),
991  make_tuple(1, kargs.stride_E),
992  number<1>{},
993  number<1>{});
994  }
995  }();
996 
997  // Step 2: Create padded view (from MakeGemmPadViews)
998  const auto& e_pad_view = [&]() {
999  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
1000  {
1001  return pad_tensor_view(e_tensor_view,
1005  }
1006  else
1007  {
1008  return pad_tensor_view(e_tensor_view,
1012  }
1013  }();
1014 
1015  // Step 3: Create tile window (from MakeGemmTileWindows)
1016  auto e_block_window = make_tile_window(
1017  e_pad_view,
1019  {i_m, i_n});
1020 
1021  return e_block_window;
1022  }
1023 
1038  template <bool UseDefaultScheduler = true>
1039  CK_TILE_DEVICE static void RunGemm(const std::array<const ADataType*, NumATensor>& as_ptr,
1040  const std::array<const BDataType*, NumBTensor>& bs_ptr,
1041  const std::array<const void*, NumDTensor>& ds_ptr,
1042  EDataType* e_ptr,
1043  void* smem_ptr,
1044  const KernelArgs& kargs,
1045  const SplitKBatchOffset& splitk_batch_offset,
1046  const index_t block_idx_m,
1047  const index_t block_idx_n)
1048  {
1049  // Create block windows using specialized methods
1050  const auto& as_block_window =
1051  MakeABlockWindows(as_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
1052  const auto& bs_block_window =
1053  MakeBBlockWindows(bs_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
1054  const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
1055 
1056  const index_t num_loop =
1057  amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
1058 
1059  // Run GEMM cooperatively by the whole workgroup.
1060  const auto& c_block_tile = GemmPipeline{}.template operator()(
1061  as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr);
1062 
1063  const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch);
1064  // Run Epilogue Pipeline
1065  if(k_batch == 1)
1066  {
1067  auto c_block_window = MakeCBlockWindows<memory_operation_enum::set>(
1068  e_ptr, kargs, block_idx_m, block_idx_n);
1069  EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr);
1070  }
1071  else
1072  {
1073  auto c_block_window = MakeCBlockWindows<memory_operation_enum::atomic_add>(
1074  e_ptr, kargs, block_idx_m, block_idx_n);
1075  EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr);
1076  }
1077  }
1078 
1079  CK_TILE_DEVICE static auto
1080  GetTileCoordinates(const KernelArgs& kargs) -> tuple<index_t, index_t>
1081  {
1082  index_t iM, iN;
1083 
1084  // Regular launch: use 1D block indexing
1085  const auto blockId = amd_wave_read_first_lane(blockIdx.x);
1086  const auto [tile_m, tile_n] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
1087  iM = tile_m;
1088  iN = tile_n;
1089 
1090  const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
1091  const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
1092 
1093  return make_tuple(i_m, i_n);
1094  }
1095 
1096  // Helper functions
1097  CK_TILE_DEVICE static auto GetBlockId() -> index_t
1098  {
1099  // For 1D regular launch
1100  return amd_wave_read_first_lane(get_block_id());
1101  }
1102 
1103  CK_TILE_DEVICE static auto GetGridSize() -> index_t
1104  {
1105  // For 1D regular launch
1107  }
1108 
1109  // Helper to get total number of tiles, handling both dim3 and index_t return types
1110  template <typename... Args>
1111  CK_TILE_HOST_DEVICE static auto GetNumTiles(Args&&... args) -> index_t
1112  {
1113  auto grid_size = TilePartitioner::GridSize(std::forward<Args>(args)...);
1114 
1115  using GridSizeType = decltype(grid_size);
1116 
1117  if constexpr(std::is_same_v<GridSizeType, dim3>)
1118  {
1119  // GridSize returns dim3: compute total tiles as x * y * z
1120  return amd_wave_read_first_lane(grid_size.x * grid_size.y * grid_size.z);
1121  }
1122  else
1123  {
1124  // GridSize returns scalar (index_t): use directly
1125  return amd_wave_read_first_lane(grid_size);
1126  }
1127  }
1128 
1129  // Non-persistent kernel entry point
1130  template <bool U = !PersistentKernel, typename = std::enable_if_t<U>>
1131  CK_TILE_DEVICE void operator()(KernelArgs kargs) const
1132  {
1133  const auto blockId = amd_wave_read_first_lane(blockIdx.x);
1134  const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
1135  const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
1136  const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
1137 
1138  const SplitKBatchOffset splitk_batch_offset(kargs);
1139 
1140  // Apply the split-K offsets to the A and B base pointers
1141  std::array<const ADataType*, NumATensor> as_ptr;
1142  static_for<0, NumATensor, 1>{}([&](auto i) {
1143  as_ptr[i] = static_cast<const ADataType*>(kargs.as_ptr[i]) +
1144  splitk_batch_offset.as_k_split_offset[i];
1145  });
1146 
1147  std::array<const BDataType*, NumBTensor> bs_ptr;
1148  static_for<0, NumBTensor, 1>{}([&](auto i) {
1149  bs_ptr[i] = static_cast<const BDataType*>(kargs.bs_ptr[i]) +
1150  splitk_batch_offset.bs_k_split_offset[i];
1151  });
1152 
1153  // Calculate output offset from tile partitioner and apply to output pointer
1154  EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
1155  if constexpr(has_tile_partitioner_output_offset)
1156  {
1157  const index_t output_offset = TilePartitioner::GetOutputOffset(kargs, blockIdx.z);
1158  e_ptr += output_offset;
1159  }
1160 
1161  // allocate LDS
1162  __shared__ char smem_ptr[GetSmemSize()];
1163 
1164  constexpr auto scheduler_type =
1165  GemmPipeline::DoubleSmemBuffer || (GemmPipeline::NumWaveGroups == 1);
1166  RunGemm<scheduler_type>(
1167  as_ptr, bs_ptr, kargs.ds_ptr, e_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
1168  }
1169 
1170  // Persistent kernel entry point
1171  template <bool U = PersistentKernel, typename = std::enable_if_t<U>, typename = void>
1172  CK_TILE_DEVICE void operator()(KernelArgs kargs) const
1173  {
1174  const auto grid_size = amd_wave_read_first_lane(get_grid_size());
1175  const auto num_tiles =
1176  amd_wave_read_first_lane(TilePartitioner::GridSize(kargs.M, kargs.N));
1177  const auto num_work = amd_wave_read_first_lane(num_tiles * kargs.k_batch);
1178  auto block_id = amd_wave_read_first_lane(get_block_id());
1179 
1180  while(block_id < num_work)
1181  {
1182  s_waitcnt_barrier();
1183  // Get the tile index for this block
1184  const auto tile_idx = amd_wave_read_first_lane(block_id % num_tiles);
1185  const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(tile_idx);
1186  const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
1187  const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
1188 
1189  // Get the SplitK offset for this block
1190  const auto k_batch = amd_wave_read_first_lane(block_id / num_tiles);
1191  const SplitKBatchOffset splitk_batch_offset(kargs, k_batch);
1192 
1193  std::array<const ADataType*, NumATensor> as_ptr;
1194  static_for<0, NumATensor, 1>{}([&](auto i) {
1195  as_ptr[i] = static_cast<const ADataType*>(kargs.as_ptr[i]) +
1196  splitk_batch_offset.as_k_split_offset[i];
1197  });
1198 
1199  std::array<const BDataType*, NumBTensor> bs_ptr;
1200  static_for<0, NumBTensor, 1>{}([&](auto i) {
1201  bs_ptr[i] = static_cast<const BDataType*>(kargs.bs_ptr[i]) +
1202  splitk_batch_offset.bs_k_split_offset[i];
1203  });
1204 
1205  // Calculate output offset from tile partitioner and apply to output pointer
1206  EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
1207  if constexpr(has_tile_partitioner_output_offset)
1208  {
1209  const index_t output_offset = TilePartitioner::GetOutputOffset(kargs, k_batch);
1210  e_ptr += output_offset;
1211  }
1212 
1213  // allocate LDS
1214  __shared__ char smem_ptr[GetSmemSize()];
1215  // Run the GEMM
1216 
1217  RunGemm(as_ptr,
1218  bs_ptr,
1219  kargs.ds_ptr,
1220  e_ptr,
1221  smem_ptr,
1222  kargs,
1223  splitk_batch_offset,
1224  i_m,
1225  i_n);
1226 
1227  // Advance to the next work item
1228  block_id += grid_size;
1229  if(block_id >= num_work)
1230  {
1231  break;
1232  }
1233  }
1234  }
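As an illustration of the persistent scheduling above, with hypothetical sizes: for num_tiles = 12 output tiles and kargs.k_batch = 2 there are num_work = 24 work items; with a persistent grid of grid_size = 8 workgroups, workgroup 3 processes work items 3, 11 and 19, i.e. output tiles 3 and 11 for split-K slice 0 and output tile 7 for split-K slice 1.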
1235 };
1236 } // namespace ck_tile