Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp Source File

gridwise_gemm_xdlops_streamk.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
20 
21 namespace ck {
22 
23 template <typename GridwiseGemm>
24 __global__ void
25 #if CK_USE_LAUNCH_BOUNDS
26  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
27 #endif
28  kernel_gemm_xdlops_streamk(const typename GridwiseGemm::FloatAB* p_a_grid,
29  const typename GridwiseGemm::FloatAB* p_b_grid,
30  typename GridwiseGemm::FloatC* p_c_grid,
31  void* p_workspace,
32  index_t M,
33  index_t N,
34  index_t K,
35  index_t StrideA,
36  index_t StrideB,
37  index_t StrideC,
38  typename GridwiseGemm::Block2CTileMap block_mapping)
39 {
40 #if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__))
41  constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
42 
43  __shared__ uint8_t p_shared[shared_size];
44 
45  GridwiseGemm::Run(p_a_grid,
46  p_b_grid,
47  p_c_grid,
48  p_workspace,
49  M,
50  N,
51  K,
52  StrideA,
53  StrideB,
54  StrideC,
55  block_mapping,
56  static_cast<void*>(p_shared));
57 #else
58  ignore = p_a_grid;
59  ignore = p_b_grid;
60  ignore = p_c_grid;
61  ignore = p_workspace;
62  ignore = M;
63  ignore = N;
64  ignore = K;
65  ignore = StrideA;
66  ignore = StrideB;
67  ignore = StrideC;
68  ignore = block_mapping;
69 #endif // end of if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__))
70 }
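// Grid-level entry point: this wrapper statically sizes LDS via
// GridwiseGemm::GetSharedMemoryNumberOfByte() and forwards all arguments to
// GridwiseGemm::Run(). The body is compiled only for gfx908 / gfx90a / gfx94x
// targets; elsewhere every argument is consumed by `ignore` and the kernel is an
// empty stub.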
71 
72 template <index_t BlockSize,
73  typename Block2CTileMap_,
74  typename FloatAB_,
75  typename FloatAcc_,
76  typename FloatC_,
77  typename ALayout,
78  typename BLayout,
79  typename CLayout,
80  typename AElementwiseOperation,
81  typename BElementwiseOperation,
82  typename CElementwiseOperation,
83  index_t MPerBlock,
84  index_t NPerBlock,
85  index_t K0PerBlock,
86  index_t MPerXDL,
87  index_t NPerXDL,
88  index_t K1Value,
89  index_t MRepeat,
90  index_t NRepeat,
91  typename ABlockTransferThreadClusterLengths_K0_M_K1,
92  typename ABlockTransferThreadClusterArrangeOrder,
93  typename ABlockTransferSrcAccessOrder,
94  index_t ABlockTransferSrcVectorDim,
95  index_t ABlockTransferSrcScalarPerVector,
96  index_t ABlockTransferDstScalarPerVector_K1,
97  bool AThreadTransferSrcResetCoordinateAfterRun,
98  index_t ABlockLdsExtraM,
99  typename BBlockTransferThreadClusterLengths_K0_N_K1,
100  typename BBlockTransferThreadClusterArrangeOrder,
101  typename BBlockTransferSrcAccessOrder,
102  index_t BBlockTransferSrcVectorDim,
103  index_t BBlockTransferSrcScalarPerVector,
104  index_t BBlockTransferDstScalarPerVector_K1,
105  bool BThreadTransferSrcResetCoordinateAfterRun,
106  index_t BBlockLdsExtraN,
107  index_t CShuffleMRepeatPerShuffle,
108  index_t CShuffleNRepeatPerShuffle,
109  index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
110  typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>
112 {
113  static constexpr auto I0 = Number<0>{};
114  static constexpr auto I1 = Number<1>{};
115  static constexpr auto I2 = Number<2>{};
116  static constexpr auto I3 = Number<3>{};
117  static constexpr auto I4 = Number<4>{};
118  static constexpr auto I5 = Number<5>{};
119  static constexpr auto I6 = Number<6>{};
120  static constexpr auto I7 = Number<7>{};
121 
122  // K1 should be Number<...>
123  static constexpr auto K1 = Number<K1Value>{};
124  static constexpr auto M01 = 1;
125  static constexpr auto N01 = 1;
126  static constexpr auto KPerBlock = K0PerBlock * K1;
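 // The GEMM K dimension is consumed as K0 * K1: K1 is the fixed innermost
 // (vector) length of the LDS tiles, and each main-loop step advances by
 // K0PerBlock * K1 = KPerBlock elements of K.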
127 
128  using ThisThreadBlock = ThisThreadBlock<BlockSize>;
129  using FloatAcc = FloatAcc_;
130  using FloatCShuffle = FloatAcc;
131 
132  using Block2CTileMap = Block2CTileMap_;
133  using FloatAB = FloatAB_;
134  using FloatC = FloatC_;
135 
137  {
138  const FloatAB* p_a_grid;
139  const FloatAB* p_b_grid;
140  FloatC* p_c_grid;
141  index_t M;
142  index_t N;
143  index_t K;
144  index_t StrideA;
145  index_t StrideB;
146  index_t StrideC;
147  Block2CTileMap block_mapping;
148 
149  Argument(const FloatAB* p_a_grid_,
150  const FloatAB* p_b_grid_,
151  FloatC* p_c_grid_,
152  index_t M_,
153  index_t N_,
154  index_t K_,
155  index_t StrideA_,
156  index_t StrideB_,
157  index_t StrideC_,
158  uint32_t num_cu,
159  uint32_t occupancy,
160  uint32_t num_sk_blocks_)
161  : p_a_grid(p_a_grid_),
162  p_b_grid(p_b_grid_),
163  p_c_grid(p_c_grid_),
164  M(M_),
165  N(N_),
166  K(K_),
167  StrideA(StrideA_),
168  StrideB(StrideB_),
169  StrideC(StrideC_),
170  block_mapping(M, N, K, num_cu, occupancy, num_sk_blocks_)
171  {
172  }
173 
174  void Print() const
175  {
176  std::cout << "arg {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
177  << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC << "}"
178  << std::endl;
179  }
180  };
181 
182  __host__ __device__ static auto CalculateGridSize(const Argument& karg)
183  {
184  return std::make_tuple(math::integer_divide_ceil(karg.N, NPerBlock),
185  math::integer_divide_ceil(karg.M, MPerBlock),
186  karg.k_batch);
187  }
188 
189  __host__ __device__ static auto CalculateK0(index_t KPad) { return KPad / K1; }
190 
191  __host__ __device__ static auto
192  MakeAGridDescriptor_K0_M_K1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA)
193  {
194  const index_t K0 = CalculateK0(KPad);
195 
196  const auto a_grid_desc_m_k = [&]() {
198  {
199  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
200  }
202  {
203  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
204  }
205  }();
206 
207  const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
208  a_grid_desc_m_k,
212 
213  return transform_tensor_descriptor(a_grid_desc_m_kpad,
215  make_right_pad_transform(M, MPad - M)),
218  }
219 
220  __host__ __device__ static auto
221  MakeBGridDescriptor_K0_N_K1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB)
222  {
223  const index_t K0 = CalculateK0(KPad);
224 
225  const auto b_grid_desc_k_n = [&]() {
227  {
228  return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
229  }
231  {
232  return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
233  }
234  }();
235 
236  const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
237  b_grid_desc_k_n,
241 
242  return transform_tensor_descriptor(b_grid_desc_kpad_n,
244  make_right_pad_transform(N, NPad - N)),
247  }
248 
249  __host__ __device__ static auto
250  MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
251  {
252  const auto c_grid_desc_m_n = [&]() {
254  {
255  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
256  }
258  {
259  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
260  }
261  }();
262 
263  return transform_tensor_descriptor(c_grid_desc_m_n,
265  make_right_pad_transform(N, NPad - N)),
268  }
269 
270  __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
271  {
272  // A matrix in LDS memory, dst of blockwise copy
276  }
277 
278  __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
279  {
280  // B matrix in LDS memory, dst of blockwise copy
284  }
285 
286  __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
287  {
288  constexpr auto max_lds_align = K1;
289 
290  // LDS allocation for A and B: be careful of alignment
291  constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
292  constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
293 
294  constexpr auto a_block_space_size_aligned =
295  math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
296 
297  constexpr auto b_block_space_size_aligned =
298  math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
299 
300  constexpr auto c_block_size =
302 
303  return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
304  sizeof(FloatAB),
305  c_block_size * sizeof(FloatCShuffle));
306  }
307 
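 // Validity check: every vectorized transfer must evenly divide the length of
 // the dimension it reads or writes along, which depends on whether A, B, and C
 // are row- or column-major (the layout branches below).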
308  __host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
309  {
311  {
312  if(karg.K % ABlockTransferSrcScalarPerVector != 0)
313  return false;
314  }
315  else
316  {
317  if(karg.M % ABlockTransferSrcScalarPerVector != 0)
318  return false;
319  }
320 
322  {
323  if(karg.N % BBlockTransferSrcScalarPerVector != 0)
324  return false;
325  }
326  else
327  {
328  if(karg.K % BBlockTransferSrcScalarPerVector != 0)
329  return false;
330  }
331 
333  {
334  if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
335  return false;
336  }
337  else
338  {
339  if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
340  return false;
341  }
342 
343  return true;
344  }
345 
346  __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
347  {
348  const bool has_main_k0_block_loop = K0 > K0PerBlock;
349 
350  return has_main_k0_block_loop;
351  }
352 
353  template <typename CGridDesc>
354  __host__ __device__ static constexpr auto
355  MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc& c_m_n_grid_desc)
356  {
357  const auto M = c_m_n_grid_desc.GetLength(I0);
358  const auto N = c_m_n_grid_desc.GetLength(I1);
359 
360  const auto MBlock = M / MPerBlock;
361  const auto NBlock = N / NPerBlock;
362 
364  c_m_n_grid_desc,
369  }
370 
371  // return block_id to C matrix tile idx (m0, n0) mapping
372  template <typename CGridDesc>
373  __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
374  const CGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch)
375  {
377  c_m_n_grid_desc, 8, KBatch);
378  }
379 
380  __host__ __device__ static constexpr auto
381  GetCBlockDescriptor_MBlock_MPerShuffle_NBlock_NPerShuffle()
382  {
383  constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL);
384  constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL);
385 
387  make_tuple(I1,
389  I1,
391  }
392 
393  __host__ __device__ static constexpr auto
394  GetCBlockDescriptor_MShuffleRepeat_MPerShuffle_NShuffleRepeat_NPerShuffle()
395  {
396  constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL);
397  constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL);
398 
402  Number<NRepeat / CShuffleNRepeatPerShuffle>{},
404  }
405 
406  __host__ __device__ static constexpr auto GetClusterLengthReduction()
407  {
408  // TODO: assume C is row major
409  // TODO: we always first loop over N, then M
410  constexpr auto NPerBlockPow2 = math::next_power_of_two<NPerBlock>();
411  constexpr auto NPerBlockReduction =
412  NPerBlockPow2 / CBlockTransferScalarPerVector_NWaveNPerXDL;
413  constexpr auto MPerBlockReduction =
414  (BlockSize + NPerBlockReduction - 1) / NPerBlockReduction;
416  }
417 
418  __host__ __device__ static constexpr auto GetPartialAccBlockDescriptor()
419  {
420  const auto c_partial_acc_block_m_n = [&]() {
422  {
423  return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock),
424  make_tuple(NPerBlock, I1));
425  }
427  {
428  return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock),
429  make_tuple(I1, MPerBlock));
430  }
431  }();
432  return c_partial_acc_block_m_n;
433  }
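 // Under the Reduction strategy, each stream-K block stages one
 // MPerBlock x NPerBlock partial-accumulation tile (described above) per output
 // tile it touches; the tiles live in the caller-provided workspace and are
 // summed later by the reduction workgroups.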
434 
435  using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1, 1))>;
436 
437  __device__ static void Run(const FloatAB* p_a_grid,
438  const FloatAB* p_b_grid,
439  FloatC* p_c_grid,
440  void* p_workspace,
441  index_t M,
442  index_t N,
443  index_t K,
444  index_t StrideA,
445  index_t StrideB,
446  index_t StrideC,
447  Block2CTileMap block_mapping,
448  void* __restrict__ p_shared_block)
449  {
450  uint32_t m = M;
451  uint32_t n = N;
452  uint32_t k = K;
453  uint32_t pad_m = (m + MPerBlock - 1) / MPerBlock * MPerBlock;
454  uint32_t pad_n = (n + NPerBlock - 1) / NPerBlock * NPerBlock;
455  uint32_t pad_k = (k + KPerBlock - 1) / KPerBlock * KPerBlock;
456  uint32_t stride_a = StrideA;
457  uint32_t stride_b = StrideB;
458  uint32_t stride_c = StrideC;
459 
460  const auto a_k0_m_k1_grid_desc = MakeAGridDescriptor_K0_M_K1(m, pad_m, k, pad_k, stride_a);
461  const auto b_k0_n_k1_grid_desc = MakeBGridDescriptor_K0_N_K1(k, pad_k, n, pad_n, stride_b);
462  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(m, pad_m, n, pad_n, stride_c);
463 
464  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
466  const AElementwiseOperation a_element_op = AElementwiseOperation{};
467  const BElementwiseOperation b_element_op = BElementwiseOperation{};
468  const CElementwiseOperation c_element_op = CElementwiseOperation{};
469 
470  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
471  p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize());
472  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
473  p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize());
474  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
475  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
476 
477  // lds max alignment
478  constexpr auto max_lds_align = K1;
479 
480  // A matrix in LDS memory, dst of blockwise copy
481  constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
482 
483  // B matrix in LDS memory, dst of blockwise copy
484  constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
485 
486  auto blockwise_gemm =
488  FloatAB,
489  FloatAB,
490  FloatAcc,
491  decltype(a_block_desc_k0_m_k1),
492  decltype(b_block_desc_k0_n_k1),
493  MPerXDL,
494  NPerXDL,
495  MRepeat,
496  NRepeat,
497  K1>{};
498 
499  auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
500 
501  // LDS allocation for A and B: be careful of alignment
502  constexpr auto a_block_space_size =
503  math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
504 
505  FloatAB* p_a_block = static_cast<FloatAB*>(p_shared_block);
506  FloatAB* p_b_block = static_cast<FloatAB*>(p_shared_block) + a_block_space_size;
507 
508  constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
509  constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
510 
511  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
512  p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize());
513  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
514  p_b_block, b_block_desc_k0_n_k1.GetElementSpaceSize());
515 
516  // gridwise GEMM pipeline
517  const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_v3();
518 
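  // Stream-K work partitioning: workgroups are classified by block index.
  // Blocks below sk_num_blocks are stream-K blocks that cooperate on split
  // K-ranges of output tiles; blocks in [dp_start_block_idx,
  // reduction_start_block_idx) are data-parallel blocks that own whole tiles;
  // blocks at or above reduction_start_block_idx perform the final reduction;
  // anything between the stream-K and data-parallel ranges is padding and
  // returns immediately.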
519  uint32_t block_idx = block_mapping.get_block_idx();
520  bool is_sk_block = block_idx < block_mapping.sk_num_blocks;
521  bool is_dp_block = block_idx >= block_mapping.dp_start_block_idx &&
522  block_idx < block_mapping.reduction_start_block_idx;
523  bool is_reduction_block = block_idx >= block_mapping.reduction_start_block_idx;
524  bool is_padding_block = block_idx >= block_mapping.sk_num_blocks &&
525  block_idx < block_mapping.dp_start_block_idx;
526  uint32_t iter_start, iter_end;
527  block_mapping.get_block_itr(block_idx, iter_start, iter_end);
528  uint32_t total_iter_length = iter_end - iter_start;
529 
530  if(is_padding_block)
531  return;
532 
533  uint32_t* p_semaphore =
534  reinterpret_cast<uint32_t*>(reinterpret_cast<char*>(p_workspace) +
535  block_mapping.get_workspace_size_for_acc(sizeof(FloatAcc)));
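  // Workspace layout: the partial-accumulation tiles come first, followed by a
  // semaphore array (one counter per output tile) that the workgroup_barrier
  // uses to track how many stream-K contributions have been flushed for each
  // tile before the reduction blocks consume them.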
536 
537  if constexpr(Block2CTileMap::ReductionStrategy == StreamKReductionStrategy::Reduction)
538  {
539  if(is_reduction_block)
540  {
541  // descriptors
542  constexpr auto cluster_length_reduce = GetClusterLengthReduction();
543  constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce);
544  const auto reduce_thread_cluster_idx =
545  reduce_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
546  const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
547  const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1];
548 
549  constexpr auto MReduceIters =
550  math::integer_divide_ceil(Number<MPerBlock>{}, cluster_length_reduce.At(I0));
551  constexpr auto NReduceIters = math::integer_divide_ceil(
553  cluster_length_reduce.At(I1) *
555 
556  constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed(
558  constexpr auto acc_thread_buf_store_desc = make_naive_tensor_descriptor_packed(
560 
561  constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor();
562 
563  constexpr auto partial_acc_load_step_n = make_multi_index(
564  0, cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
565  constexpr auto partial_acc_load_step_n_reverse =
567  -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
568  CBlockTransferScalarPerVector_NWaveNPerXDL);
569  constexpr auto partial_acc_load_step_m =
570  make_multi_index(cluster_length_reduce.At(I0), 0);
571 
572  constexpr auto partial_acc_store_step_n = make_multi_index(
573  0,
574  0,
575  0,
576  cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
577  constexpr auto partial_acc_store_step_n_reverse =
579  0,
580  0,
581  -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
582  CBlockTransferScalarPerVector_NWaveNPerXDL);
583  constexpr auto partial_acc_store_step_m =
584  make_multi_index(0, cluster_length_reduce.At(I0), 0, 0);
585 
587  FloatAcc,
588  CBlockTransferScalarPerVector_NWaveNPerXDL,
589  true>
590  partial_acc_buf;
592  FloatAcc,
593  CBlockTransferScalarPerVector_NWaveNPerXDL,
594  true>
595  acc_buf;
596 
597  // start to compute
598  auto reduction_idx = blockIdx.x - block_mapping.reduction_start_block_idx;
599  auto spatial_idx = block_mapping.tile_to_spatial(reduction_idx, m, n);
600 
601  workgroup_barrier wg_barrier(p_semaphore);
602 
603  uint32_t tile_acc_offset_start =
604  block_mapping.get_acc_buffer_offset_from_tile(reduction_idx);
605  uint32_t tile_acc_offset_end =
606  block_mapping.get_acc_buffer_offset_from_tile(reduction_idx + 1);
607 
608  auto acc_load = ThreadwiseTensorSliceTransfer_v2<
609  FloatAcc, // SrcData,
610  FloatAcc, // DstData,
611  decltype(c_partial_acc_block_m_n), // SrcDesc,
612  decltype(acc_thread_buf_load_desc), // DstDesc,
614  Sequence<0, 1>, // DimAccessOrder,
615  1, // SrcVectorDim,
616  CBlockTransferScalarPerVector_NWaveNPerXDL, // SrcScalarPerVector,
617  1, // SrcScalarStrideInVector,
618  false // SrcResetCoordinateAfterRun,
619  >{c_partial_acc_block_m_n,
620  make_multi_index(thread_m_cluster_id,
621  thread_n_cluster_id *
622  CBlockTransferScalarPerVector_NWaveNPerXDL)};
623 
624  auto acc_store = ThreadwiseTensorSliceTransfer_v1r3<
625  FloatAcc, // SrcData,
626  FloatC, // DstData,
627  decltype(acc_thread_buf_store_desc), // SrcDesc,
628  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), // DstDesc,
629  CElementwiseOperation, // ElementwiseOperation,
631  Sequence<0, 1, 2, 3>, // DimAccessOrder,
632  3, // DstVectorDim,
633  CBlockTransferScalarPerVector_NWaveNPerXDL, // DstScalarPerVector,
634  InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum DstInMemOp,
635  1, // DstScalarStrideInVector,
636  false // DstResetCoordinateAfterRun,
637  >{c_grid_desc_mblock_mperblock_nblock_nperblock,
638  make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
639  thread_m_cluster_id,
640  __builtin_amdgcn_readfirstlane(spatial_idx[I1]),
641  thread_n_cluster_id *
642  CBlockTransferScalarPerVector_NWaveNPerXDL),
643  CElementwiseOperation{}};
644 
645  // block synchronization
646  wg_barrier.wait_eq(reduction_idx, tile_acc_offset_end - tile_acc_offset_start);
647 
648 #if 0
649  if(threadIdx.x == 0) {
650  printf("bid:%d, rid:%d, os:%d,%d, spatial:%d,%d\n", static_cast<int>(blockIdx.x),
651  reduction_idx, __builtin_amdgcn_readfirstlane(tile_acc_offset_start), __builtin_amdgcn_readfirstlane(tile_acc_offset_end),
652  __builtin_amdgcn_readfirstlane(spatial_idx[I0]),
653  __builtin_amdgcn_readfirstlane(spatial_idx[I1]));
654  }
655 #endif
656 
657  using Accumulation = ck::detail::
658  AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, FloatAcc>;
659 
660  for(int i_m = 0; i_m < MReduceIters; i_m++)
661  {
662  static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) {
663  acc_buf.Clear();
664  for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++)
665  {
666  auto c_partial_acc_buf =
669  reinterpret_cast<FloatAcc*>(p_workspace) +
670  i * c_partial_acc_block_m_n.GetElementSpaceSize(),
671  c_partial_acc_block_m_n.GetElementSpaceSize());
672 
673  acc_load.Run(c_partial_acc_block_m_n,
674  c_partial_acc_buf,
675  acc_thread_buf_load_desc,
676  make_tuple(I0, I0),
677  partial_acc_buf);
678 
680  [&](auto i_vec) {
681  constexpr auto offset =
682  acc_thread_buf_load_desc.CalculateOffset(
683  make_tuple(0, i_vec));
684  Accumulation::Calculate(acc_buf(Number<offset>{}),
685  partial_acc_buf[Number<offset>{}]);
686  });
687  }
688 
689  if(thread_n_cluster_id * CBlockTransferScalarPerVector_NWaveNPerXDL <
690  NPerBlock)
691  {
692  acc_store.Run(acc_thread_buf_store_desc,
693  make_tuple(I0, I0, I0, I0),
694  acc_buf,
695  c_grid_desc_mblock_mperblock_nblock_nperblock,
696  c_grid_buf);
697  }
698  if constexpr(NReduceIters != 1)
699  {
700  if constexpr(i_n_reduce != (NReduceIters - 1))
701  {
702  acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
703  partial_acc_load_step_n);
704  acc_store.MoveDstSliceWindow(
705  c_grid_desc_mblock_mperblock_nblock_nperblock,
706  partial_acc_store_step_n);
707  }
708  else
709  {
710  acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
711  partial_acc_load_step_n_reverse);
712  acc_store.MoveDstSliceWindow(
713  c_grid_desc_mblock_mperblock_nblock_nperblock,
714  partial_acc_store_step_n_reverse);
715  }
716  }
717  });
718  {
719  acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
720  partial_acc_load_step_m);
721  acc_store.MoveDstSliceWindow(c_grid_desc_mblock_mperblock_nblock_nperblock,
722  partial_acc_store_step_m);
723  }
724  }
725  return;
726  }
727  }
728 
729  // offset for last acc buffer of this block
730  uint32_t block_acc_offset =
731  (block_mapping.get_acc_buffer_offset_from_block(block_idx + 1) - 1) * MPerBlock *
732  NPerBlock;
733 
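  // Main persistent loop: the block repeatedly claims the next contiguous range
  // of K iterations in [iter_start, iter_end), runs the blockwise GEMM pipeline
  // over that range, writes out either a full tile (data-parallel blocks) or a
  // partial result (stream-K blocks), and exits once its iteration range is
  // exhausted.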
734  while(true)
735  {
736  uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
737  block_mapping.get_current_iter_length(iter_start, iter_end, total_iter_length));
738  uint32_t tile_idx, iter_offset;
739  block_mapping.get_tile_idx_with_offset(iter_end - 1, tile_idx, iter_offset);
740  iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1);
741  auto spatial_idx = block_mapping.tile_to_spatial(tile_idx, m, n);
742 
743  const index_t m_block_data_idx_on_grid =
744  __builtin_amdgcn_readfirstlane(spatial_idx[I0] * MPerBlock);
745 
746  const index_t n_block_data_idx_on_grid =
747  __builtin_amdgcn_readfirstlane(spatial_idx[I1] * NPerBlock);
748 
749  const index_t k0_block_data_idx_on_grid =
750  __builtin_amdgcn_readfirstlane(iter_offset * K0PerBlock);
751 
752  // A matrix blockwise copy
753  auto a_blockwise_copy =
755  AElementwiseOperation,
759  ABlockTransferThreadClusterLengths_K0_M_K1,
760  ABlockTransferThreadClusterArrangeOrder,
761  FloatAB,
762  FloatAB,
763  decltype(a_k0_m_k1_grid_desc),
764  decltype(a_block_desc_k0_m_k1),
765  ABlockTransferSrcAccessOrder,
767  ABlockTransferSrcVectorDim,
768  2,
769  ABlockTransferSrcScalarPerVector,
770  ABlockTransferDstScalarPerVector_K1,
771  1,
772  1,
773  AThreadTransferSrcResetCoordinateAfterRun,
774  true>(
775  a_k0_m_k1_grid_desc,
776  make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0),
777  a_element_op,
778  a_block_desc_k0_m_k1,
779  make_multi_index(0, 0, 0),
781 
782  // B matrix blockwise copy
783  auto b_blockwise_copy =
785  BElementwiseOperation,
789  BBlockTransferThreadClusterLengths_K0_N_K1,
790  BBlockTransferThreadClusterArrangeOrder,
791  FloatAB,
792  FloatAB,
793  decltype(b_k0_n_k1_grid_desc),
794  decltype(b_block_desc_k0_n_k1),
795  BBlockTransferSrcAccessOrder,
797  BBlockTransferSrcVectorDim,
798  2,
799  BBlockTransferSrcScalarPerVector,
800  BBlockTransferDstScalarPerVector_K1,
801  1,
802  1,
803  BThreadTransferSrcResetCoordinateAfterRun,
804  true>(
805  b_k0_n_k1_grid_desc,
806  make_multi_index(k0_block_data_idx_on_grid, n_block_data_idx_on_grid, 0),
807  b_element_op,
808  b_block_desc_k0_n_k1,
809  make_multi_index(0, 0, 0),
811 
812  const index_t num_k_block_main_loop = current_iter_length;
813 
814  gridwise_gemm_pipeline.Run(a_k0_m_k1_grid_desc,
815  a_block_desc_k0_m_k1,
816  a_blockwise_copy,
817  a_grid_buf,
818  a_block_buf,
819  a_block_slice_copy_step,
820  b_k0_n_k1_grid_desc,
821  b_block_desc_k0_n_k1,
822  b_blockwise_copy,
823  b_grid_buf,
824  b_block_buf,
825  b_block_slice_copy_step,
826  blockwise_gemm,
827  c_thread_buf,
828  num_k_block_main_loop);
829 
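   // Epilogue: accumulator registers are shuffled through LDS tile by tile and
   // written out. Data-parallel blocks store straight to the C grid; stream-K
   // blocks either stage their partial tile in the workspace (Reduction
   // strategy) or write it to C directly (Atomic strategy).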
830  // output: register to global memory
831  {
832  constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL);
833  constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL);
834 
835  constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
836  blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
837 
838  constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
839  blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
840 
841  constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
842  constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
843  constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
844  constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
845  constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
846  constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
847  constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
848  constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);
849 
850  constexpr auto c_block_desc_mblock_mpershuffle_nblock_npershuffle =
852 
853  constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle =
855 
856  auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
857  reinterpret_cast<FloatCShuffle*>(p_shared_block),
858  c_block_desc_mblock_mpershuffle_nblock_npershuffle.GetElementSpaceSize());
859 
860  auto c_partial_acc_buf =
861  make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
862  reinterpret_cast<FloatAcc*>(p_workspace) + block_acc_offset,
863  c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
864  .GetElementSpaceSize());
865 
866  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
867  c_block_desc_mblock_mpershuffle_nblock_npershuffle,
868  make_tuple(make_freeze_transform(I0), // freeze mblock
870  make_tuple(CShuffleMRepeatPerShuffle,
871  M1,
872  M2,
873  M3,
874  M4)), // M1 = MWave, M2 * M3 * M4 = MPerXDL
875  make_freeze_transform(I0), // freeze nblock
877  make_tuple(CShuffleNRepeatPerShuffle,
878  N1,
879  N2))), // N1 = NWave, N2 = NPerXDL
883  Sequence<>{},
884  Sequence<1, 3, 7>{}));
885 
886  // calculate origin of thread output tensor on global memory
887  // blockwise GEMM c matrix starting index
888  const auto c_thread_mtx_on_block =
889  blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
890 
891  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
892  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
893 
894  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
896  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
899 
900  const auto m_thread_data_on_block_idx =
901  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
902  make_multi_index(m_thread_data_on_block));
903 
904  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
909 
910  const auto n_thread_data_on_block_idx =
911  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
912  make_multi_index(n_thread_data_on_block));
913 
914  // VGPR to LDS
915  auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
916  FloatAcc,
918  decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
919  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
921  Sequence<CShuffleMRepeatPerShuffle,
922  CShuffleNRepeatPerShuffle,
923  I1,
924  I1,
925  M2,
926  I1,
927  M4,
928  I1>,
930  7,
931  1,
933  1,
934  true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
936  0,
937  m_thread_data_on_block_idx[I1],
938  n_thread_data_on_block_idx[I1],
939  m_thread_data_on_block_idx[I2],
940  m_thread_data_on_block_idx[I3],
941  m_thread_data_on_block_idx[I4],
942  n_thread_data_on_block_idx[I2]),
944 
945  // LDS to global
946  auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1r2<
947  ThisThreadBlock, // index_t BlockSize,
948  CElementwiseOperation, // ElementwiseOperation,
949  // InMemoryDataOperationEnum::Set, // DstInMemOp,
950  Sequence<1,
951  CShuffleMRepeatPerShuffle * MWave * MPerXDL,
952  1,
953  CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths,
954  CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
955  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
956  FloatCShuffle, // typename SrcData,
957  FloatC, // typename DstData,
958  decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle),
959  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
960  Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
961  3, // index_t VectorDim,
962  CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
963  false, // bool ThreadTransferSrcResetCoordinateAfterRun,
964  false> // bool ThreadTransferDstResetCoordinateAfterRun
965  {c_block_desc_mblock_mpershuffle_nblock_npershuffle,
966  make_multi_index(0, 0, 0, 0),
967  c_grid_desc_mblock_mperblock_nblock_nperblock,
968  make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
969  0,
970  __builtin_amdgcn_readfirstlane(spatial_idx[I1]),
971  0),
972  c_element_op};
973 
974  // LDS to global partial acc
975  auto c_block_copy_lds_to_partial_acc = ThreadGroupTensorSliceTransfer_v6r1r2<
976  ThisThreadBlock, // index_t BlockSize,
977  CElementwiseOperation, // ElementwiseOperation,
978  // InMemoryDataOperationEnum::Set, // DstInMemOp,
979  Sequence<1,
980  CShuffleMRepeatPerShuffle * MWave * MPerXDL,
981  1,
982  CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths,
983  CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
984  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
985  FloatCShuffle, // typename SrcData,
986  FloatCShuffle, // typename DstData,
987  decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle),
988  decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle),
989  Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
990  3, // index_t VectorDim,
991  CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
992  false, // bool ThreadTransferSrcResetCoordinateAfterRun, => needs to be false,
993  // otherwise it spills to scratch
994  false> // bool ThreadTransferDstResetCoordinateAfterRun, => needs to be false,
995  // otherwise it spills to scratch
996  {c_block_desc_mblock_mpershuffle_nblock_npershuffle,
997  make_multi_index(0, 0, 0, 0),
998  c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
999  make_multi_index(0, 0, 0, 0),
1000  c_element_op};
1001 
1002  constexpr auto mxdlperwave_forward_step =
1003  make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0);
1004  constexpr auto nxdlperwave_forward_step =
1005  make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL);
1006  constexpr auto nxdlperwave_backward_step =
1007  make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL);
1008 
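    // The store sweeps the MRepeat x NRepeat space in a serpentine order: N is
    // walked forward on even M shuffle steps and backward on odd ones, so the
    // destination slice window only ever moves by a single shuffle step.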
1009  static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) {
1010  constexpr auto mxdlperwave = mxdlperwave_iter;
1011 
1012  static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) {
1013  constexpr bool nxdlperwave_forward_sweep =
1014  (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0);
1015 
1016  constexpr index_t nxdlperwave_value =
1017  nxdlperwave_forward_sweep
1018  ? nxdlperwave_iter
1019  : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle);
1020 
1021  constexpr auto nxdlperwave = Number<nxdlperwave_value>{};
1022 
1023  // make sure it's safe to do ds_write
1024  block_sync_lds();
1025 
1026  // VGPR to LDS
1027  c_thread_copy_vgpr_to_lds.Run(
1028  c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
1029  make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
1030  c_thread_buf,
1031  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1032  c_block_buf);
1033 
1034  // make sure it's safe to do ds_read
1035  block_sync_lds();
1036 
1037  c_block_copy_lds_to_global.SetSrcSliceOrigin(
1038  c_block_desc_mblock_mpershuffle_nblock_npershuffle,
1039  make_tuple(0, 0, 0, 0));
1040 
1041  // LDS to global
1042  if(is_dp_block)
1043  c_block_copy_lds_to_global.template Run<decltype(c_block_buf),
1044  decltype(c_grid_buf),
1046  c_block_desc_mblock_mpershuffle_nblock_npershuffle,
1047  c_block_buf,
1048  c_grid_desc_mblock_mperblock_nblock_nperblock,
1049  c_grid_buf);
1050  else if(is_sk_block)
1051  {
1052  if constexpr(Block2CTileMap::ReductionStrategy ==
1054  {
1055  // constexpr offset
1056  c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
1057  c_block_desc_mblock_mpershuffle_nblock_npershuffle,
1058  make_tuple(0, 0, 0, 0));
1059 
1060  c_block_copy_lds_to_partial_acc.SetDstSliceOrigin(
1061  c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
1062  make_tuple(mxdlperwave.value, 0, nxdlperwave.value, 0));
1063 
1064  c_block_copy_lds_to_partial_acc
1065  .template Run<decltype(c_block_buf),
1066  decltype(c_partial_acc_buf),
1068  c_block_desc_mblock_mpershuffle_nblock_npershuffle,
1069  c_block_buf,
1070  c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
1071  c_partial_acc_buf);
1072  }
1073  else if constexpr(Block2CTileMap::ReductionStrategy ==
1075  {
1076  c_block_copy_lds_to_global
1077  .template Run<decltype(c_block_buf),
1078  decltype(c_grid_buf),
1080  c_block_desc_mblock_mpershuffle_nblock_npershuffle,
1081  c_block_buf,
1082  c_grid_desc_mblock_mperblock_nblock_nperblock,
1083  c_grid_buf);
1084  }
1085  }
1086 
1087  // move on nxdlperwave dimension
1088  if constexpr(nxdlperwave_forward_sweep &&
1089  (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle))
1090  {
1091  c_block_copy_lds_to_global.MoveDstSliceWindow(
1092  c_grid_desc_mblock_mperblock_nblock_nperblock,
1093  nxdlperwave_forward_step);
1094  }
1095  else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
1096  {
1097  c_block_copy_lds_to_global.MoveDstSliceWindow(
1098  c_grid_desc_mblock_mperblock_nblock_nperblock,
1099  nxdlperwave_backward_step);
1100  }
1101  });
1102 
1103  // move on mxdlperwave dimension
1104  if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle)
1105  {
1106  c_block_copy_lds_to_global.MoveDstSliceWindow(
1107  c_grid_desc_mblock_mperblock_nblock_nperblock,
1108  mxdlperwave_forward_step);
1109  }
1110  });
1111 
1112  if constexpr(Block2CTileMap::ReductionStrategy ==
1114  {
1115  if(is_sk_block)
1116  {
1117  // increase the counter for this tile
1118  workgroup_barrier wg_barrier(p_semaphore);
1119  wg_barrier.inc(tile_idx);
1120  }
1121  }
1122  }
1123 
1124  // exit condition
1125  iter_end -= current_iter_length;
1126  if(iter_end <= iter_start)
1127  break;
1128 
1129  if constexpr(Block2CTileMap::ReductionStrategy == StreamKReductionStrategy::Reduction)
1130  {
1131  block_acc_offset -= MPerBlock * NPerBlock;
1132  }
1133  // make sure next loop LDS is ready for use
1134  block_sync_lds();
1135  }
1136  }
1137 
1138  template <typename Layout>
1139  struct LStr
1140  {
1141  static std::string Get() { return ""; }
1142  };
1143 
1144  template <>
1146  {
1147  static std::string Get() { return "R"; }
1148  };
1149 
1150  template <>
1152  {
1153  static std::string Get() { return "C"; }
1154  };
1155 
1156  static std::string GetTypeString()
1157  {
1158  auto str = std::stringstream();
1159 
1160  // clang-format off
1161  str << "GemmXdlStreamK_"
1162  << std::string(ALayout::name)[0]
1163  << std::string(BLayout::name)[0]
1164  << std::string(CLayout::name)[0]
1165  << "_"
1166  << "B" << BlockSize << "_"
1167  << "Vec" << ABlockTransferSrcScalarPerVector << "x"
1168  << BBlockTransferSrcScalarPerVector << "x"
1169  << CBlockTransferScalarPerVector_NWaveNPerXDL << "_"
1170  << MPerBlock << "x"
1171  << NPerBlock << "x"
1172  << K0PerBlock << "x"
1173  << K1 ;
1174  // clang-format on
1175 
1176  return str.str();
1177  }
1178 };
1179 
1180 } // namespace ck
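
Below is a minimal, hypothetical host-side launch sketch for the kernel entry point declared in this header. It is not part of the header: the alias GridwiseGemmInstance, the helper launch_streamk_gemm, and the num_workgroups / block_size parameters are illustrative assumptions, and the workspace is assumed to have been sized by the caller from the block mapping. LDS is allocated statically inside the kernel, so no dynamic shared memory is passed at launch.

// Hypothetical usage sketch -- assumes this header has been included and a fully
// specialized gridwise struct is aliased as GridwiseGemmInstance. All names below
// are illustrative, not part of Composable Kernel.
#include <cstdint>
#include <hip/hip_runtime.h>

template <typename GridwiseGemmInstance>
void launch_streamk_gemm(const typename GridwiseGemmInstance::FloatAB* p_a,
                         const typename GridwiseGemmInstance::FloatAB* p_b,
                         typename GridwiseGemmInstance::FloatC* p_c,
                         void* p_workspace, // partial-acc tiles + semaphores, sized by the caller
                         ck::index_t M,
                         ck::index_t N,
                         ck::index_t K,
                         ck::index_t StrideA,
                         ck::index_t StrideB,
                         ck::index_t StrideC,
                         typename GridwiseGemmInstance::Block2CTileMap block_mapping,
                         uint32_t num_workgroups, // stream-K + data-parallel + reduction blocks
                         uint32_t block_size,     // must equal the BlockSize template argument
                         hipStream_t stream)
{
    // LDS is declared statically inside the kernel, so the dynamic-LDS argument is 0.
    hipLaunchKernelGGL(ck::kernel_gemm_xdlops_streamk<GridwiseGemmInstance>,
                       dim3(num_workgroups),
                       dim3(block_size),
                       0,
                       stream,
                       p_a,
                       p_b,
                       p_c,
                       p_workspace,
                       M,
                       N,
                       K,
                       StrideA,
                       StrideB,
                       StrideC,
                       block_mapping);
}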