gridwise_gemm_xdlops_streamk.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
20 
21 namespace ck {
22 
23 template <typename GridwiseGemm>
24 __global__ void
25 #if CK_USE_LAUNCH_BOUNDS
 26  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 27 #endif
28  kernel_gemm_xdlops_streamk(const typename GridwiseGemm::FloatAB* p_a_grid,
29  const typename GridwiseGemm::FloatAB* p_b_grid,
30  typename GridwiseGemm::FloatC* p_c_grid,
31  void* p_workspace,
32  index_t M,
33  index_t N,
34  index_t K,
35  index_t StrideA,
36  index_t StrideB,
37  index_t StrideC,
38  typename GridwiseGemm::Block2CTileMap block_mapping)
39 {
40 #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx12__)
41  if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
42  {
43  constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
44 
45  __shared__ uint8_t p_shared[shared_size];
46 
47  GridwiseGemm::Run(p_a_grid,
48  p_b_grid,
49  p_c_grid,
50  p_workspace,
51  M,
52  N,
53  K,
54  StrideA,
55  StrideB,
56  StrideC,
57  block_mapping,
58  static_cast<void*>(p_shared));
59  }
60 #else
61  ignore = p_a_grid;
62  ignore = p_b_grid;
63  ignore = p_c_grid;
64  ignore = p_workspace;
65  ignore = M;
66  ignore = N;
67  ignore = K;
68  ignore = StrideA;
69  ignore = StrideB;
70  ignore = StrideC;
71  ignore = block_mapping;
 72 #endif // defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx12__)
73 }
74 
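// A minimal host-side launch sketch (illustrative, not part of this header). It assumes the
// HIP runtime headers are available and that the grid/block dimensions have already been
// derived from the GridwiseGemm's BlockSize and Block2CTileMap; the wrapper name and its
// dim3/stream parameters are hypothetical.
template <typename GridwiseGemm>
void launch_kernel_gemm_xdlops_streamk(const typename GridwiseGemm::FloatAB* p_a_grid,
                                       const typename GridwiseGemm::FloatAB* p_b_grid,
                                       typename GridwiseGemm::FloatC* p_c_grid,
                                       void* p_workspace,
                                       index_t M,
                                       index_t N,
                                       index_t K,
                                       index_t StrideA,
                                       index_t StrideB,
                                       index_t StrideC,
                                       typename GridwiseGemm::Block2CTileMap block_mapping,
                                       dim3 grid_dim,
                                       dim3 block_dim,
                                       hipStream_t stream)
{
    kernel_gemm_xdlops_streamk<GridwiseGemm><<<grid_dim, block_dim, 0, stream>>>(
        p_a_grid, p_b_grid, p_c_grid, p_workspace, M, N, K, StrideA, StrideB, StrideC,
        block_mapping);
}
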
75 template <index_t BlockSize,
76  typename Block2CTileMap_,
77  typename FloatAB_,
78  typename FloatAcc_,
79  typename FloatC_,
80  typename ALayout,
81  typename BLayout,
82  typename CLayout,
83  typename AElementwiseOperation,
84  typename BElementwiseOperation,
85  typename CElementwiseOperation,
86  index_t MPerBlock,
87  index_t NPerBlock,
88  index_t K0PerBlock,
89  index_t MPerXdl,
90  index_t NPerXdl,
91  index_t K1Value,
92  index_t MRepeat,
93  index_t NRepeat,
94  typename ABlockTransferThreadClusterLengths_K0_M_K1,
95  typename ABlockTransferThreadClusterArrangeOrder,
96  typename ABlockTransferSrcAccessOrder,
97  index_t ABlockTransferSrcVectorDim,
98  index_t ABlockTransferSrcScalarPerVector,
99  index_t ABlockTransferDstScalarPerVector_K1,
100  bool AThreadTransferSrcResetCoordinateAfterRun,
101  index_t ABlockLdsExtraM,
102  typename BBlockTransferThreadClusterLengths_K0_N_K1,
103  typename BBlockTransferThreadClusterArrangeOrder,
104  typename BBlockTransferSrcAccessOrder,
105  index_t BBlockTransferSrcVectorDim,
106  index_t BBlockTransferSrcScalarPerVector,
107  index_t BBlockTransferDstScalarPerVector_K1,
108  bool BThreadTransferSrcResetCoordinateAfterRun,
109  index_t BBlockLdsExtraN,
110  index_t CShuffleMRepeatPerShuffle,
111  index_t CShuffleNRepeatPerShuffle,
112  index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
113  typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>
115 {
116  static constexpr auto I0 = Number<0>{};
117  static constexpr auto I1 = Number<1>{};
118  static constexpr auto I2 = Number<2>{};
119  static constexpr auto I3 = Number<3>{};
120  static constexpr auto I4 = Number<4>{};
121  static constexpr auto I5 = Number<5>{};
122  static constexpr auto I6 = Number<6>{};
123  static constexpr auto I7 = Number<7>{};
124 
125  // K1 should be Number<...>
126  static constexpr auto K1 = Number<K1Value>{};
127  static constexpr auto M01 = 1;
128  static constexpr auto N01 = 1;
129  static constexpr auto KPerBlock = K0PerBlock * K1;
130 
 131  using ThisThreadBlock = ThisThreadBlock<BlockSize>;
 132  using FloatAcc = FloatAcc_;
 133  using FloatCShuffle = FloatAcc;
 134 
135  using Block2CTileMap = Block2CTileMap_;
136  using FloatAB = FloatAB_;
137  using FloatC = FloatC_;
138 
 139  struct Argument
 140  {
 141  const FloatAB* p_a_grid;
 142  const FloatAB* p_b_grid;
 143  FloatC* p_c_grid;
 144  index_t M;
 145  index_t N;
 146  index_t K;
 147  index_t StrideA;
 148  index_t StrideB;
 149  index_t StrideC;
 150  Block2CTileMap block_mapping;
 151 
152  Argument(const FloatAB* p_a_grid_,
153  const FloatAB* p_b_grid_,
154  FloatC* p_c_grid_,
155  index_t M_,
156  index_t N_,
157  index_t K_,
158  index_t StrideA_,
159  index_t StrideB_,
160  index_t StrideC_,
161  uint32_t num_cu,
162  uint32_t occupancy,
163  uint32_t num_sk_blocks_)
164  : p_a_grid(p_a_grid_),
165  p_b_grid(p_b_grid_),
166  p_c_grid(p_c_grid_),
167  M(M_),
168  N(N_),
169  K(K_),
170  StrideA(StrideA_),
171  StrideB(StrideB_),
172  StrideC(StrideC_),
173  block_mapping(M, N, K, num_cu, occupancy, num_sk_blocks_)
174  {
175  }
176 
177  void Print() const
178  {
179  std::cout << "arg {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
180  << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
 181  << "}" << std::endl;
182  }
183  };
184 
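  // Sketch of how an Argument is typically filled in (illustrative; the variable names and
  // runtime queries are assumptions, not taken from this file): num_cu usually comes from
  // hipDeviceGetAttribute(..., hipDeviceAttributeMultiprocessorCount, ...), occupancy from an
  // occupancy query for the chosen kernel, and num_sk_blocks selects how many stream-K blocks
  // the Block2CTileMap carves out of the grid:
  //
  //   Argument arg(p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC,
  //                num_cu, occupancy, num_sk_blocks);
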
185  __host__ __device__ static auto CalculateGridSize(const Argument& karg)
186  {
187  return std::make_tuple(math::integer_divide_ceil(karg.N, NPerBlock),
188  math::integer_divide_ceil(karg.M, MPerBlock),
189  karg.k_batch);
190  }
191 
192  __host__ __device__ static auto CalculateK0(index_t KPad) { return KPad / K1; }
193 
194  __host__ __device__ static auto
196  {
197  const index_t K0 = CalculateK0(KPad);
198 
199  const auto a_grid_desc_m_k = [&]() {
201  {
202  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
203  }
205  {
206  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
207  }
208  }();
209 
210  const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
211  a_grid_desc_m_k,
215 
216  return transform_tensor_descriptor(a_grid_desc_m_kpad,
218  make_right_pad_transform(M, MPad - M)),
221  }
222 
223  __host__ __device__ static auto
225  {
226  const index_t K0 = CalculateK0(KPad);
227 
228  const auto b_grid_desc_k_n = [&]() {
230  {
231  return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
232  }
234  {
235  return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
236  }
237  }();
238 
239  const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
240  b_grid_desc_k_n,
244 
245  return transform_tensor_descriptor(b_grid_desc_kpad_n,
247  make_right_pad_transform(N, NPad - N)),
250  }
251 
252  __host__ __device__ static auto
254  {
255  const auto c_grid_desc_m_n = [&]() {
257  {
258  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
259  }
261  {
262  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
263  }
264  }();
265 
266  return transform_tensor_descriptor(c_grid_desc_m_n,
268  make_right_pad_transform(N, NPad - N)),
271  }
272 
273  __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
274  {
275  // A matrix in LDS memory, dst of blockwise copy
279  }
280 
281  __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
282  {
283  // B matrix in LDS memory, dst of blockwise copy
287  }
288 
289  __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
290  {
291  constexpr auto max_lds_align = K1;
292 
293  // LDS allocation for A and B: be careful of alignment
294  constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
295  constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
296 
297  constexpr auto a_block_space_size_aligned =
298  math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
299 
300  constexpr auto b_block_space_size_aligned =
301  math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
302 
303  constexpr auto c_block_size =
305 
306  return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
307  sizeof(FloatAB),
308  c_block_size * sizeof(FloatCShuffle));
309  }
310 
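  // Rough LDS budget example for the computation above (hypothetical tile sizes, ignoring the
  // ABlockLdsExtraM/BBlockLdsExtraN padding): with fp16 A/B, K0PerBlock = 4, K1 = 8,
  // MPerBlock = 256 and NPerBlock = 128, the A tile holds 4 * 256 * 8 = 8192 elements and the
  // B tile 4 * 128 * 8 = 4096 elements, so the GEMM stage needs roughly (8192 + 4096) * 2 B
  // = 24 KiB; the C-shuffle stage reuses the same allocation, and the larger of the two
  // requirements is what this function returns.
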
311  static constexpr index_t MXdlPerWave = MRepeat;
312  static constexpr index_t NXdlPerWave = NRepeat;
314 
315  __host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
316  {
318  {
319  if(karg.K % ABlockTransferSrcScalarPerVector != 0)
320  return false;
321  }
322  else
323  {
324  if(karg.M % ABlockTransferSrcScalarPerVector != 0)
325  return false;
326  }
327 
329  {
330  if(karg.N % BBlockTransferSrcScalarPerVector != 0)
331  return false;
332  }
333  else
334  {
335  if(karg.K % BBlockTransferSrcScalarPerVector != 0)
336  return false;
337  }
338 
340  {
341  if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
342  return false;
343  }
344  else
345  {
346  if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
347  return false;
348  }
349 
350  return true;
351  }
352 
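  // Example of the vector-width checks above (a sketch; which operand dimension is checked
  // depends on the ALayout/BLayout/CLayout branch taken): with
  // ABlockTransferSrcScalarPerVector = 8 and A loaded with its vector dimension along K,
  // CheckValidity rejects any problem whose K is not a multiple of 8, because a partial
  // vector load cannot be issued.
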
353  __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
354  {
355  const bool has_main_k0_block_loop = K0 > K0PerBlock;
356 
357  return has_main_k0_block_loop;
358  }
359 
360  template <typename CGridDesc>
361  __host__ __device__ static constexpr auto
362  MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc& c_m_n_grid_desc)
363  {
364  const auto M = c_m_n_grid_desc.GetLength(I0);
365  const auto N = c_m_n_grid_desc.GetLength(I1);
366 
367  const auto MBlock = M / MPerBlock;
368  const auto NBlock = N / NPerBlock;
369 
371  c_m_n_grid_desc,
376  }
377 
378  // return block_id to C matrix tile idx (m0, n0) mapping
379  template <typename CGridDesc>
380  __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
381  const CGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch)
382  {
384  c_m_n_grid_desc, 8, KBatch);
385  }
386 
387  __host__ __device__ static constexpr auto
389  {
390  constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl);
391  constexpr index_t NWave = NRepeat * NPerXdl == 0 ? 1 : NPerBlock / (NRepeat * NPerXdl);
392 
394  make_tuple(I1,
396  I1,
398  }
399 
400  __host__ __device__ static constexpr auto
402  {
403  constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl);
404  constexpr index_t NWave = NRepeat * NPerXdl == 0 ? 1 : NPerBlock / (NRepeat * NPerXdl);
405 
409  Number<NRepeat / CShuffleNRepeatPerShuffle>{},
411  }
412 
413  __host__ __device__ static constexpr auto GetClusterLengthReduction()
414  {
415  // TODO: assume C is row major
416  // TODO: we always first loop over N, then M
417  constexpr auto NPerBlockPow2 = math::next_power_of_two<NPerBlock>();
418  constexpr auto NPerBlockReduction =
419  NPerBlockPow2 / CBlockTransferScalarPerVector_NWaveNPerXDL;
420  constexpr auto MPerBlockReduction =
421  (BlockSize + NPerBlockReduction - 1) / NPerBlockReduction;
423  }
424 
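  // Worked example for the reduction cluster shape above (hypothetical parameters): with
  // NPerBlock = 128, CBlockTransferScalarPerVector_NWaveNPerXDL = 8 and BlockSize = 256,
  // NPerBlockPow2 = 128, NPerBlockReduction = 128 / 8 = 16 lanes along N, and
  // MPerBlockReduction = (256 + 16 - 1) / 16 = 16 rows covered per reduction iteration.
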
425  __host__ __device__ static constexpr auto GetPartialAccBlockDescriptor()
426  {
427  const auto c_partial_acc_block_m_n = [&]() {
429  {
430  return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock),
431  make_tuple(NPerBlock, I1));
432  }
434  {
435  return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock),
436  make_tuple(I1, MPerBlock));
437  }
438  }();
439  return c_partial_acc_block_m_n;
440  }
441 
442  using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1, 1))>;
443 
444  __device__ static void Run(const FloatAB* p_a_grid,
445  const FloatAB* p_b_grid,
446  FloatC* p_c_grid,
447  void* p_workspace,
448  index_t M,
449  index_t N,
450  index_t K,
451  index_t StrideA,
452  index_t StrideB,
453  index_t StrideC,
454  Block2CTileMap block_mapping,
455  void* __restrict__ p_shared_block)
456  {
457  uint32_t m = M;
458  uint32_t n = N;
459  uint32_t k = K;
460  uint32_t pad_m = (m + MPerBlock - 1) / MPerBlock * MPerBlock;
461  uint32_t pad_n = (n + NPerBlock - 1) / NPerBlock * NPerBlock;
462  uint32_t pad_k = (k + KPerBlock - 1) / KPerBlock * KPerBlock;
463  uint32_t stride_a = StrideA;
464  uint32_t stride_b = StrideB;
465  uint32_t stride_c = StrideC;
466 
467  const auto a_k0_m_k1_grid_desc = MakeAGridDescriptor_K0_M_K1(m, pad_m, k, pad_k, stride_a);
468  const auto b_k0_n_k1_grid_desc = MakeBGridDescriptor_K0_N_K1(k, pad_k, n, pad_n, stride_b);
469  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(m, pad_m, n, pad_n, stride_c);
470 
471  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
473  const AElementwiseOperation a_element_op = AElementwiseOperation{};
474  const BElementwiseOperation b_element_op = BElementwiseOperation{};
475  const CElementwiseOperation c_element_op = CElementwiseOperation{};
476 
477  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
478  p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize());
479  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
480  p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize());
481  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
482  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
483 
484  // lds max alignment
485  constexpr auto max_lds_align = K1;
486 
487  // A matrix in LDS memory, dst of blockwise copy
488  constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
489 
490  // B matrix in LDS memory, dst of blockwise copy
491  constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
492 
493  auto blockwise_gemm =
495  FloatAB,
496  FloatAB,
497  FloatAcc,
498  decltype(a_block_desc_k0_m_k1),
499  decltype(b_block_desc_k0_n_k1),
500  MPerXdl,
501  NPerXdl,
502  MRepeat,
503  NRepeat,
504  K1>{};
505 
506  auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
507 
508  // LDS allocation for A and B: be careful of alignment
509  constexpr auto a_block_space_size =
510  math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
511 
512  FloatAB* p_a_block = static_cast<FloatAB*>(p_shared_block);
513  FloatAB* p_b_block = static_cast<FloatAB*>(p_shared_block) + a_block_space_size;
514 
515  constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
516  constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
517 
518  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
519  p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize());
520  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
521  p_b_block, b_block_desc_k0_n_k1.GetElementSpaceSize());
522 
523  // gridwise GEMM pipeline
524  const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_v3();
525 
526  uint32_t block_idx = block_mapping.get_block_idx();
527  bool is_sk_block = block_idx < block_mapping.sk_num_blocks;
528  bool is_dp_block = block_idx >= block_mapping.dp_start_block_idx &&
529  block_idx < block_mapping.reduction_start_block_idx;
530  bool is_reduction_block = block_idx >= block_mapping.reduction_start_block_idx;
531  bool is_padding_block = block_idx >= block_mapping.sk_num_blocks &&
532  block_idx < block_mapping.dp_start_block_idx;
533  uint32_t iter_start, iter_end;
534  block_mapping.get_block_itr(block_idx, iter_start, iter_end);
535  uint32_t total_iter_length = iter_end - iter_start;
536 
537  if(is_padding_block)
538  return;
539 
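  // Workspace layout assumed by the pointer arithmetic below: the front of p_workspace holds
  // the partial-accumulation buffers (FloatAcc, one MPerBlock x NPerBlock tile per stream-K
  // accumulation slot), and the uint32_t semaphores used by workgroup_barrier live directly
  // after them, at offset get_workspace_size_for_acc(sizeof(FloatAcc)).
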
540  uint32_t* p_semaphore =
541  reinterpret_cast<uint32_t*>(reinterpret_cast<char*>(p_workspace) +
542  block_mapping.get_workspace_size_for_acc(sizeof(FloatAcc)));
543 
544  if constexpr(Block2CTileMap::ReductionStrategy == StreamKReductionStrategy::Reduction)
545  {
546  if(is_reduction_block)
547  {
548  // descriptors
549  constexpr auto cluster_length_reduce = GetClusterLengthReduction();
550  constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce);
551  const auto reduce_thread_cluster_idx =
552  reduce_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
553  const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
554  const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1];
555 
556  constexpr auto MReduceIters =
557  math::integer_divide_ceil(Number<MPerBlock>{}, cluster_length_reduce.At(I0));
558  constexpr auto NReduceIters = math::integer_divide_ceil(
560  cluster_length_reduce.At(I1) *
562 
563  constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed(
565  constexpr auto acc_thread_buf_store_desc = make_naive_tensor_descriptor_packed(
567 
568  constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor();
569 
570  constexpr auto partial_acc_load_step_n = make_multi_index(
571  0, cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
572  constexpr auto partial_acc_load_step_n_reverse =
574  -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
575  CBlockTransferScalarPerVector_NWaveNPerXDL);
576  constexpr auto partial_acc_load_step_m =
577  make_multi_index(cluster_length_reduce.At(I0), 0);
578 
579  constexpr auto partial_acc_store_step_n = make_multi_index(
580  0,
581  0,
582  0,
583  cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
584  constexpr auto partial_acc_store_step_n_reverse =
586  0,
587  0,
588  -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
589  CBlockTransferScalarPerVector_NWaveNPerXDL);
590  constexpr auto partial_acc_store_step_m =
591  make_multi_index(0, cluster_length_reduce.At(I0), 0, 0);
592 
594  FloatAcc,
595  CBlockTransferScalarPerVector_NWaveNPerXDL,
596  true>
 597  partial_acc_buf;
599  FloatAcc,
600  CBlockTransferScalarPerVector_NWaveNPerXDL,
601  true>
602  acc_buf;
603 
604  // start to compute
605  auto reduction_idx = blockIdx.x - block_mapping.reduction_start_block_idx;
606  auto spatial_idx = block_mapping.tile_to_spatial(reduction_idx, m, n);
607 
608  workgroup_barrier wg_barrier(p_semaphore);
609 
610  uint32_t tile_acc_offset_start =
611  block_mapping.get_acc_buffer_offset_from_tile(reduction_idx);
612  uint32_t tile_acc_offset_end =
613  block_mapping.get_acc_buffer_offset_from_tile(reduction_idx + 1);
614 
615  auto acc_load = ThreadwiseTensorSliceTransfer_v2<
616  FloatAcc, // SrcData,
617  FloatAcc, // DstData,
618  decltype(c_partial_acc_block_m_n), // SrcDesc,
619  decltype(acc_thread_buf_load_desc), // DstDesc,
621  Sequence<0, 1>, // DimAccessOrder,
622  1, // SrcVectorDim,
623  CBlockTransferScalarPerVector_NWaveNPerXDL, // SrcScalarPerVector,
624  1, // SrcScalarStrideInVector,
625  false // SrcResetCoordinateAfterRun,
626  >{c_partial_acc_block_m_n,
627  make_multi_index(thread_m_cluster_id,
628  thread_n_cluster_id *
629  CBlockTransferScalarPerVector_NWaveNPerXDL)};
630 
631  auto acc_store = ThreadwiseTensorSliceTransfer_v1r3<
632  FloatAcc, // SrcData,
633  FloatC, // DstData,
634  decltype(acc_thread_buf_store_desc), // SrcDesc,
635  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), // DstDesc,
636  CElementwiseOperation, // ElementwiseOperation,
638  Sequence<0, 1, 2, 3>, // DimAccessOrder,
639  3, // DstVectorDim,
640  CBlockTransferScalarPerVector_NWaveNPerXDL, // DstScalarPerVector,
641  InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum DstInMemOp,
642  1, // DstScalarStrideInVector,
643  false // DstResetCoordinateAfterRun,
644  >{c_grid_desc_mblock_mperblock_nblock_nperblock,
645  make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
646  thread_m_cluster_id,
647  __builtin_amdgcn_readfirstlane(spatial_idx[I1]),
648  thread_n_cluster_id *
649  CBlockTransferScalarPerVector_NWaveNPerXDL),
650  CElementwiseOperation{}};
651 
652  // block synchronization
653  wg_barrier.wait_eq(reduction_idx, tile_acc_offset_end - tile_acc_offset_start);
654 
655 #if 0
656  if(threadIdx.x == 0) {
657  printf("bid:%d, rid:%d, os:%d,%d, spatial:%d,%d\n", static_cast<int>(blockIdx.x),
658  reduction_idx, __builtin_amdgcn_readfirstlane(tile_acc_offset_start), __builtin_amdgcn_readfirstlane(tile_acc_offset_end),
659  __builtin_amdgcn_readfirstlane(spatial_idx[I0]),
660  __builtin_amdgcn_readfirstlane(spatial_idx[I1]));
661  }
662 #endif
663 
664  using Accumulation = ck::detail::
665  AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, FloatAcc>;
666 
667  for(int i_m = 0; i_m < MReduceIters; i_m++)
668  {
669  static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) {
670  acc_buf.Clear();
671  for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++)
672  {
673  auto c_partial_acc_buf =
676  reinterpret_cast<FloatAcc*>(p_workspace) +
677  i * c_partial_acc_block_m_n.GetElementSpaceSize(),
678  c_partial_acc_block_m_n.GetElementSpaceSize());
679 
680  acc_load.Run(c_partial_acc_block_m_n,
681  c_partial_acc_buf,
682  acc_thread_buf_load_desc,
683  make_tuple(I0, I0),
 684  partial_acc_buf);
685 
687  [&](auto i_vec) {
688  constexpr auto offset =
689  acc_thread_buf_load_desc.CalculateOffset(
690  make_tuple(0, i_vec));
691  Accumulation::Calculate(acc_buf(Number<offset>{}),
 692  partial_acc_buf[Number<offset>{}]);
693  });
694  }
695 
696  if(thread_n_cluster_id * CBlockTransferScalarPerVector_NWaveNPerXDL <
697  NPerBlock)
698  {
699  acc_store.Run(acc_thread_buf_store_desc,
700  make_tuple(I0, I0, I0, I0),
701  acc_buf,
702  c_grid_desc_mblock_mperblock_nblock_nperblock,
703  c_grid_buf);
704  }
705  if constexpr(NReduceIters != 1)
706  {
707  if constexpr(i_n_reduce != (NReduceIters - 1))
708  {
709  acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
710  partial_acc_load_step_n);
711  acc_store.MoveDstSliceWindow(
712  c_grid_desc_mblock_mperblock_nblock_nperblock,
713  partial_acc_store_step_n);
714  }
715  else
716  {
717  acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
718  partial_acc_load_step_n_reverse);
719  acc_store.MoveDstSliceWindow(
720  c_grid_desc_mblock_mperblock_nblock_nperblock,
721  partial_acc_store_step_n_reverse);
722  }
723  }
724  });
725  {
726  acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
727  partial_acc_load_step_m);
728  acc_store.MoveDstSliceWindow(c_grid_desc_mblock_mperblock_nblock_nperblock,
729  partial_acc_store_step_m);
730  }
731  }
732  return;
733  }
734  }
735 
736  // offset for last acc buffer of this block
737  uint32_t block_acc_offset =
738  (block_mapping.get_acc_buffer_offset_from_block(block_idx + 1) - 1) * MPerBlock *
739  NPerBlock;
740 
741  while(true)
742  {
743  uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
744  block_mapping.get_current_iter_length(iter_start, iter_end, total_iter_length));
745  uint32_t tile_idx, iter_offset;
746  block_mapping.get_tile_idx_with_offset(iter_end - 1, tile_idx, iter_offset);
747  iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1);
748  auto spatial_idx = block_mapping.tile_to_spatial(tile_idx, m, n);
749 
750  const index_t m_block_data_idx_on_grid =
751  __builtin_amdgcn_readfirstlane(spatial_idx[I0] * MPerBlock);
752 
753  const index_t n_block_data_idx_on_grid =
754  __builtin_amdgcn_readfirstlane(spatial_idx[I1] * NPerBlock);
755 
756  const index_t k0_block_data_idx_on_grid =
757  __builtin_amdgcn_readfirstlane(iter_offset * K0PerBlock);
758 
759  // A matrix blockwise copy
760  auto a_blockwise_copy =
762  AElementwiseOperation,
766  ABlockTransferThreadClusterLengths_K0_M_K1,
767  ABlockTransferThreadClusterArrangeOrder,
768  FloatAB,
769  FloatAB,
770  decltype(a_k0_m_k1_grid_desc),
771  decltype(a_block_desc_k0_m_k1),
772  ABlockTransferSrcAccessOrder,
774  ABlockTransferSrcVectorDim,
775  2,
776  ABlockTransferSrcScalarPerVector,
777  ABlockTransferDstScalarPerVector_K1,
778  1,
779  1,
780  AThreadTransferSrcResetCoordinateAfterRun,
781  true>(
782  a_k0_m_k1_grid_desc,
783  make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0),
784  a_element_op,
785  a_block_desc_k0_m_k1,
786  make_multi_index(0, 0, 0),
788 
789  // B matrix blockwise copy
790  auto b_blockwise_copy =
792  BElementwiseOperation,
796  BBlockTransferThreadClusterLengths_K0_N_K1,
797  BBlockTransferThreadClusterArrangeOrder,
798  FloatAB,
799  FloatAB,
800  decltype(b_k0_n_k1_grid_desc),
801  decltype(b_block_desc_k0_n_k1),
802  BBlockTransferSrcAccessOrder,
804  BBlockTransferSrcVectorDim,
805  2,
806  BBlockTransferSrcScalarPerVector,
807  BBlockTransferDstScalarPerVector_K1,
808  1,
809  1,
810  BThreadTransferSrcResetCoordinateAfterRun,
811  true>(
812  b_k0_n_k1_grid_desc,
813  make_multi_index(k0_block_data_idx_on_grid, n_block_data_idx_on_grid, 0),
814  b_element_op,
815  b_block_desc_k0_n_k1,
816  make_multi_index(0, 0, 0),
818 
819  const index_t num_k_block_main_loop = current_iter_length;
820 
821  gridwise_gemm_pipeline.Run(a_k0_m_k1_grid_desc,
822  a_block_desc_k0_m_k1,
823  a_blockwise_copy,
824  a_grid_buf,
825  a_block_buf,
826  a_block_slice_copy_step,
827  b_k0_n_k1_grid_desc,
828  b_block_desc_k0_n_k1,
829  b_blockwise_copy,
830  b_grid_buf,
831  b_block_buf,
832  b_block_slice_copy_step,
833  blockwise_gemm,
834  c_thread_buf,
835  num_k_block_main_loop);
836 
837  // output: register to global memory
838  {
839  constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl);
840  constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl);
841 
842  constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
843  blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
844 
845  constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
846  blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
847 
848  constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
849  constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
850  constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
851  constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
852  constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
853  constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
854  constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
855  constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);
856 
857  constexpr auto c_block_desc_mblock_mpershuffle_nblock_npershuffle =
859 
860  constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle =
862 
863  auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
864  reinterpret_cast<FloatCShuffle*>(p_shared_block),
865  c_block_desc_mblock_mpershuffle_nblock_npershuffle.GetElementSpaceSize());
866 
867  auto c_partial_acc_buf =
868  make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
869  reinterpret_cast<FloatAcc*>(p_workspace) + block_acc_offset,
870  c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
871  .GetElementSpaceSize());
872 
873  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
874  c_block_desc_mblock_mpershuffle_nblock_npershuffle,
875  make_tuple(make_freeze_transform(I0), // freeze mblock
877  make_tuple(CShuffleMRepeatPerShuffle,
878  M1,
879  M2,
880  M3,
881  M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl
882  make_freeze_transform(I0), // freeze nblock
884  make_tuple(CShuffleNRepeatPerShuffle,
885  N1,
 886  N2))), // N1 = NWave, N2 = NPerXdl
890  Sequence<>{},
891  Sequence<1, 3, 7>{}));
892 
893  // calculate origin of thread output tensor on global memory
894  // blockwise GEMM c matrix starting index
895  const auto c_thread_mtx_on_block =
896  blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
897 
898  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
899  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
900 
901  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
903  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
906 
907  const auto m_thread_data_on_block_idx =
908  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
909  make_multi_index(m_thread_data_on_block));
910 
911  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
916 
917  const auto n_thread_data_on_block_idx =
918  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
919  make_multi_index(n_thread_data_on_block));
920 
921  // VGPR to LDS
922  auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
923  FloatAcc,
925  decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
926  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
928  Sequence<CShuffleMRepeatPerShuffle,
929  CShuffleNRepeatPerShuffle,
930  I1,
931  I1,
932  M2,
933  I1,
934  M4,
935  I1>,
937  7,
938  1,
940  1,
941  true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
943  0,
944  m_thread_data_on_block_idx[I1],
945  n_thread_data_on_block_idx[I1],
946  m_thread_data_on_block_idx[I2],
947  m_thread_data_on_block_idx[I3],
948  m_thread_data_on_block_idx[I4],
949  n_thread_data_on_block_idx[I2]),
951 
952  // LDS to global
953  auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1r2<
954  ThisThreadBlock, // index_t BlockSize,
955  CElementwiseOperation, // ElementwiseOperation,
956  // InMemoryDataOperationEnum::Set, // DstInMemOp,
957  Sequence<1,
958  CShuffleMRepeatPerShuffle * MWave * MPerXdl,
959  1,
960  CShuffleNRepeatPerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
961  CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
962  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
963  FloatCShuffle, // typename SrcData,
964  FloatC, // typename DstData,
965  decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle),
966  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
967  Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
968  3, // index_t VectorDim,
969  CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
970  false, // bool ThreadTransferSrcResetCoordinateAfterRun,
971  false> // bool ThreadTransferDstResetCoordinateAfterRun
972  {c_block_desc_mblock_mpershuffle_nblock_npershuffle,
973  make_multi_index(0, 0, 0, 0),
974  c_grid_desc_mblock_mperblock_nblock_nperblock,
975  make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
976  0,
977  __builtin_amdgcn_readfirstlane(spatial_idx[I1]),
978  0),
979  c_element_op};
980 
981  // LDS to global partial acc
982  auto c_block_copy_lds_to_partial_acc = ThreadGroupTensorSliceTransfer_v6r1r2<
983  ThisThreadBlock, // index_t BlockSize,
984  CElementwiseOperation, // ElementwiseOperation,
985  // InMemoryDataOperationEnum::Set, // DstInMemOp,
986  Sequence<1,
987  CShuffleMRepeatPerShuffle * MWave * MPerXdl,
988  1,
989  CShuffleNRepeatPerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
990  CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
991  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
992  FloatCShuffle, // typename SrcData,
993  FloatCShuffle, // typename DstData,
994  decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle),
995  decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle),
996  Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
997  3, // index_t VectorDim,
998  CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
 999  false, // bool ThreadTransferSrcResetCoordinateAfterRun => needs to be false,
 1000  // otherwise scratch is used
 1001  false> // bool ThreadTransferDstResetCoordinateAfterRun => needs to be false,
 1002  // otherwise scratch is used
1003  {c_block_desc_mblock_mpershuffle_nblock_npershuffle,
1004  make_multi_index(0, 0, 0, 0),
1005  c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
1006  make_multi_index(0, 0, 0, 0),
1007  c_element_op};
1008 
1009  constexpr auto mxdlperwave_forward_step =
1010  make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXdl, 0, 0);
1011  constexpr auto nxdlperwave_forward_step =
1012  make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXdl);
1013  constexpr auto nxdlperwave_backward_step =
1014  make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXdl);
1015 
1016  static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) {
1017  constexpr auto mxdlperwave = mxdlperwave_iter;
1018 
1019  static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) {
1020  constexpr bool nxdlperwave_forward_sweep =
1021  (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0);
1022 
1023  constexpr index_t nxdlperwave_value =
1024  nxdlperwave_forward_sweep
1025  ? nxdlperwave_iter
1026  : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle);
1027 
1028  constexpr auto nxdlperwave = Number<nxdlperwave_value>{};
1029 
1030  // make sure it's safe to do ds_write
1031  block_sync_lds();
1032 
1033  // VGPR to LDS
1034  c_thread_copy_vgpr_to_lds.Run(
1035  c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
1036  make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
1037  c_thread_buf,
1038  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1039  c_block_buf);
1040 
1041  // make sure it's safe to do ds_read
1042  block_sync_lds();
1043 
1044  c_block_copy_lds_to_global.SetSrcSliceOrigin(
1045  c_block_desc_mblock_mpershuffle_nblock_npershuffle,
1046  make_tuple(0, 0, 0, 0));
1047 
1048  // LDS to global
1049  if(is_dp_block)
1050  c_block_copy_lds_to_global.template Run<decltype(c_block_buf),
1051  decltype(c_grid_buf),
1053  c_block_desc_mblock_mpershuffle_nblock_npershuffle,
1054  c_block_buf,
1055  c_grid_desc_mblock_mperblock_nblock_nperblock,
1056  c_grid_buf);
1057  else if(is_sk_block)
1058  {
1059  if constexpr(Block2CTileMap::ReductionStrategy ==
1061  {
1062  // constexpr offset
1063  c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
1064  c_block_desc_mblock_mpershuffle_nblock_npershuffle,
1065  make_tuple(0, 0, 0, 0));
1066 
1067  c_block_copy_lds_to_partial_acc.SetDstSliceOrigin(
1068  c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
1069  make_tuple(mxdlperwave.value, 0, nxdlperwave.value, 0));
1070 
1071  c_block_copy_lds_to_partial_acc
1072  .template Run<decltype(c_block_buf),
1073  decltype(c_partial_acc_buf),
1075  c_block_desc_mblock_mpershuffle_nblock_npershuffle,
1076  c_block_buf,
1077  c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
1078  c_partial_acc_buf);
1079  }
1080  else if constexpr(Block2CTileMap::ReductionStrategy ==
1082  {
1083  c_block_copy_lds_to_global
1084  .template Run<decltype(c_block_buf),
1085  decltype(c_grid_buf),
1087  c_block_desc_mblock_mpershuffle_nblock_npershuffle,
1088  c_block_buf,
1089  c_grid_desc_mblock_mperblock_nblock_nperblock,
1090  c_grid_buf);
1091  }
1092  }
1093 
1094  // move on nxdlperwave dimension
1095  if constexpr(nxdlperwave_forward_sweep &&
1096  (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle))
1097  {
1098  c_block_copy_lds_to_global.MoveDstSliceWindow(
1099  c_grid_desc_mblock_mperblock_nblock_nperblock,
1100  nxdlperwave_forward_step);
1101  }
1102  else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
1103  {
1104  c_block_copy_lds_to_global.MoveDstSliceWindow(
1105  c_grid_desc_mblock_mperblock_nblock_nperblock,
1106  nxdlperwave_backward_step);
1107  }
1108  });
1109 
1110  // move on mxdlperwave dimension
1111  if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle)
1112  {
1113  c_block_copy_lds_to_global.MoveDstSliceWindow(
1114  c_grid_desc_mblock_mperblock_nblock_nperblock,
1115  mxdlperwave_forward_step);
1116  }
1117  });
1118 
1119  if constexpr(Block2CTileMap::ReductionStrategy ==
1121  {
1122  if(is_sk_block)
1123  {
1124  // increase the counter for this tile
1125  workgroup_barrier wg_barrier(p_semaphore);
1126  wg_barrier.inc(tile_idx);
1127  }
1128  }
1129  }
1130 
1131  // exit condition
1132  iter_end -= current_iter_length;
1133  if(iter_end <= iter_start)
1134  break;
1135 
1136  if constexpr(Block2CTileMap::ReductionStrategy == StreamKReductionStrategy::Reduction)
1137  {
1138  block_acc_offset -= MPerBlock * NPerBlock;
1139  }
1140  // make sure next loop LDS is ready for use
1141  block_sync_lds();
1142  }
1143  }
1144 
1145  template <typename Layout>
1146  struct LStr
1147  {
1148  static std::string Get() { return ""; }
1149  };
1150 
1151  template <>
1153  {
1154  static std::string Get() { return "R"; }
1155  };
1156 
1157  template <>
1159  {
1160  static std::string Get() { return "C"; }
1161  };
1162 
1163  static std::string GetTypeString()
1164  {
1165  auto str = std::stringstream();
1166 
1167  // clang-format off
1168  str << "GemmXdlStreamK_"
1169  << std::string(ALayout::name)[0]
1170  << std::string(BLayout::name)[0]
1171  << std::string(CLayout::name)[0]
1172  << "_"
1173  << "B" << BlockSize << "_"
1174  << "Vec" << ABlockTransferSrcScalarPerVector << "x"
1175  << BBlockTransferSrcScalarPerVector << "x"
1176  << CBlockTransferScalarPerVector_NWaveNPerXDL << "_"
1177  << MPerBlock << "x"
1178  << NPerBlock << "x"
1179  << K0PerBlock << "x"
1180  << K1 ;
1181  // clang-format on
1182 
1183  return str.str();
1184  }
1185 };
1186 
1187 } // namespace ck