Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp (source listing)
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
18 
19 namespace ck {
20 
21 template <typename GridwiseGemm, bool HasMainKBlockLoop>
22 __global__ void
23 #if CK_USE_LAUNCH_BOUNDS
24  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
25 #endif
26 #if CK_USE_WAVES_PER_EU
27  __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU)))
28 #endif
29  kernel_gemm_dpp(const typename GridwiseGemm::Argument karg)
30 {
31 #if(defined(__gfx103__) || defined(__gfx11__))
32  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
33 
34  const auto a_grid_desc_ak0_m_ak1 = amd_wave_read_first_lane(
35  GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(karg.M, karg.K, karg.AK0, karg.StrideA));
36  const auto b_grid_desc_bk0_n_bk1 = amd_wave_read_first_lane(
37  GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(karg.K, karg.N, karg.BK0, karg.StrideB));
38  const auto c_grid_desc_m_n = amd_wave_read_first_lane(
39  GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC));
40 
41  GridwiseGemm::template Run<HasMainKBlockLoop>(karg.p_a_grid,
42  karg.p_b_grid,
43  karg.p_c_grid,
44  p_shared,
45  a_grid_desc_ak0_m_ak1,
46  b_grid_desc_bk0_n_bk1,
47  c_grid_desc_m_n);
48 #else
49  ignore = karg;
50 #endif
51 }
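
The amd_wave_read_first_lane calls above make the freshly built grid descriptors wave-uniform before GridwiseGemm::Run consumes them. For illustration, a minimal standalone sketch of that idiom using the same Clang builtin that appears further down in Run(); the helper name is ad hoc and not part of this header:

#include <hip/hip_runtime.h>

// Broadcast lane 0's value to every lane of the wave. Because all lanes then hold
// the same value, the compiler can keep it in scalar registers (SGPRs) instead of
// replicating it across vector registers.
__device__ int to_wave_uniform(int v) { return __builtin_amdgcn_readfirstlane(v); }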
52 
53 template <index_t BlockSize,
54  typename ABDataType,
55  typename AccDataType,
56  typename CDataType,
57  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
58  typename ALayout,
59  typename BLayout,
60  typename CLayout,
61  typename AElementwiseOperation,
62  typename BElementwiseOperation,
63  typename CElementwiseOperation,
64  tensor_operation::device::GemmSpecialization GemmSpec,
65  index_t MPerBlock,
66  index_t NPerBlock,
67  index_t KPerBlock,
68  index_t MPerDpp,
69  index_t NPerDpp,
70  index_t AK1Value,
71  index_t BK1Value,
72  index_t MDppPerWave,
73  index_t NDppPerWave,
74  typename ABlockTransferThreadClusterLengths_K0_M_K1,
75  typename ABlockTransferThreadClusterArrangeOrder,
76  typename ABlockTransferSrcAccessOrder,
77  index_t ABlockTransferSrcVectorDim,
78  index_t ABlockTransferSrcScalarPerVector,
79  index_t ABlockTransferDstScalarPerVector_K1,
80  bool AThreadTransferSrcResetCoordinateAfterRun,
81  bool ABlockLdsExtraM,
82  typename BBlockTransferThreadClusterLengths_K0_N_K1,
83  typename BBlockTransferThreadClusterArrangeOrder,
84  typename BBlockTransferSrcAccessOrder,
85  index_t BBlockTransferSrcVectorDim,
86  index_t BBlockTransferSrcScalarPerVector,
87  index_t BBlockTransferDstScalarPerVector_K1,
88  bool BThreadTransferSrcResetCoordinateAfterRun,
89  bool BBlockLdsExtraN,
90  typename CThreadTransferSrcDstAccessOrder,
91  index_t CThreadTransferSrcDstVectorDim,
92  index_t CThreadTransferDstScalarPerVector,
93  index_t NumGemmKPrefetchStage = 1,
94  PipelineVersion PipelineVer = PipelineVersion::v1>
96 {
97  static constexpr auto I0 = Number<0>{};
98  static constexpr auto I1 = Number<1>{};
99  static constexpr auto I2 = Number<2>{};
100  static constexpr auto I3 = Number<3>{};
101  static constexpr auto I4 = Number<4>{};
102  static constexpr auto I5 = Number<5>{};
103 
104  static constexpr auto AK1 = Number<AK1Value>{};
105  static constexpr auto BK1 = Number<BK1Value>{};
106  static constexpr auto AK0PerBlock = Number<KPerBlock / AK1Value>{};
107  static constexpr auto BK0PerBlock = Number<KPerBlock / BK1Value>{};
108 
109  static constexpr auto max_lds_align = math::lcm(AK1, BK1);
110 
112  // return block_id to C matrix tile idx (m0, n0) mapping
114 
115  __host__ static auto CalculateGridSize(index_t M, index_t N)
116  {
117  return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1);
118  }
119 
120  __host__ static auto CalculateMPadded(index_t M)
121  {
122  return math::integer_divide_ceil(M, MPerBlock) * MPerBlock;
123  }
124 
125  __host__ static auto CalculateNPadded(index_t N)
126  {
127  return math::integer_divide_ceil(N, NPerBlock) * NPerBlock;
128  }
129 
130  __host__ static auto CalculateAK0(index_t K) { return math::integer_divide_floor(K, AK1Value); }
131  __host__ static auto CalculateBK0(index_t K) { return math::integer_divide_floor(K, BK1Value); }
132 
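
For illustration, the four helpers above replayed as standalone host code; the tile sizes and K1 values below are assumptions picked for the example, not values defined in this header:

#include <cstdio>

int main()
{
    // Assumed example parameters (not from this header).
    const int M = 1000, N = 900, K = 512;
    const int MPerBlock = 128, NPerBlock = 64, AK1Value = 8, BK1Value = 8;

    const int MPadded = (M + MPerBlock - 1) / MPerBlock * MPerBlock; // round up   -> 1024
    const int NPadded = (N + NPerBlock - 1) / NPerBlock * NPerBlock; // round up   -> 960
    const int AK0     = K / AK1Value;                                // round down -> 64
    const int BK0     = K / BK1Value;                                // round down -> 64

    std::printf("MPadded=%d NPadded=%d AK0=%d BK0=%d\n", MPadded, NPadded, AK0, BK0);
    return 0;
}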
133  // Argument
134  struct Problem
135  {
136  __host__ Problem(index_t M_,
137  index_t N_,
138  index_t K_,
139  index_t StrideA_,
140  index_t StrideB_,
141  index_t StrideC_)
142  : M{M_},
143  N{N_},
144  K{K_},
145  StrideA{StrideA_},
146  StrideB{StrideB_},
147  StrideC{StrideC_},
148  MPadded{CalculateMPadded(M)},
149  NPadded{CalculateNPadded(N)},
150  AK0{CalculateAK0(K)},
151  BK0{CalculateBK0(K)}
152  {
153  }
154 
155  __host__ void Print() const
156  {
157  std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
158  << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
159  << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
160  << "AK0:" << AK0 << ", " << "BK0:" << BK0 << "}" << std::endl;
161  }
162 
163  index_t M;
164  index_t N;
165  index_t K;
166  index_t StrideA;
167  index_t StrideB;
168  index_t StrideC;
169  index_t MPadded;
170  index_t NPadded;
171  index_t AK0;
172  index_t BK0;
173  };
174 
175  // Argument
176  struct Argument : public Problem
177  {
178  __host__ Argument(const ABDataType* p_a_grid_,
179  const ABDataType* p_b_grid_,
180  CDataType* p_c_grid_,
181  index_t M_,
182  index_t N_,
183  index_t K_,
184  index_t StrideA_,
185  index_t StrideB_,
186  index_t StrideC_)
187  : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
188  p_a_grid{p_a_grid_},
189  p_b_grid{p_b_grid_},
190  p_c_grid{p_c_grid_}
191  {
192  }
193 
194  const ABDataType* p_a_grid;
195  const ABDataType* p_b_grid;
196  CDataType* p_c_grid;
197  };
198 
199  using GridwiseGemmPipe = remove_cvref_t<
200  decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
201 
202  __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
203  {
204  // A matrix in LDS memory, dst of blockwise copy
205  constexpr auto a_block_desc_ak0_m_ak1 = [&]() {
206  if constexpr(ABlockLdsExtraM)
207  {
208  return make_naive_tensor_descriptor(
209  make_tuple(AK0PerBlock, Number<MPerBlock>{}, AK1),
210  make_tuple(Number<MPerBlock + 1>{} * AK1, AK1, I1));
211  }
212  else
213  {
214  return make_naive_tensor_descriptor_aligned(
215  make_tuple(AK0PerBlock, Number<MPerBlock>{}, AK1), max_lds_align);
216  }
217  }();
218 
219  return a_block_desc_ak0_m_ak1;
220  }
221 
222  __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
223  {
224  // B matrix in LDS memory, dst of blockwise copy
225  constexpr auto b_block_desc_bk0_n_bk1 = [&]() {
226  if constexpr(BBlockLdsExtraN)
227  {
228  return make_naive_tensor_descriptor(
229  make_tuple(BK0PerBlock, Number<NPerBlock>{}, BK1),
230  make_tuple(Number<NPerBlock + 1>{} * BK1, BK1, I1));
231  }
232  else
233  {
234  return make_naive_tensor_descriptor_aligned(
235  make_tuple(BK0PerBlock, Number<NPerBlock>{}, BK1), max_lds_align);
236  }
237  }();
238 
239  return b_block_desc_bk0_n_bk1;
240  }
241 
242  __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
243  {
244  // LDS allocation for A and B: be careful of alignment
245  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
246  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
247 
248  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
249  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
250  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
251  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
252 
253  return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(ABDataType);
254  }
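
For illustration, a standalone sketch of the LDS budget this function computes, under one assumed configuration: fp16 A/B data, MPerBlock = NPerBlock = 128, KPerBlock = 32, AK1 = BK1 = 8, and the non-padded (aligned) block descriptors. None of these numbers come from this header:

#include <cstdio>
#include <numeric> // std::lcm

int main()
{
    // Assumed example configuration (not from this header).
    const int MPerBlock = 128, NPerBlock = 128, KPerBlock = 32;
    const int AK1 = 8, BK1 = 8;
    const int AK0PerBlock   = KPerBlock / AK1;    // 4
    const int BK0PerBlock   = KPerBlock / BK1;    // 4
    const int max_lds_align = std::lcm(AK1, BK1); // 8 elements

    auto round_up = [](int x, int m) { return (x + m - 1) / m * m; };

    // Element counts of the A and B tiles in LDS, each rounded up to the alignment.
    const int a_elems = round_up(AK0PerBlock * MPerBlock * AK1, max_lds_align); // 4096
    const int b_elems = round_up(BK0PerBlock * NPerBlock * BK1, max_lds_align); // 4096

    const int lds_bytes = (a_elems + b_elems) * 2; // sizeof(fp16) == 2  -> 16384 bytes
    std::printf("LDS bytes = %d\n", lds_bytes);
    return 0;
}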
255 
256  __host__ static constexpr bool CheckValidity(const Problem& problem)
257  {
258  static_assert(is_known_at_compile_time<remove_cv_t<decltype(AK1)>>::value,
259  "Wrong! AK1 must be known at the time of compilation.");
260  static_assert(is_known_at_compile_time<remove_cv_t<decltype(BK1)>>::value,
261  "Wrong! BK1 must be known at the time of compilation.");
262 
263  static_assert(
264  MPerBlock % (MPerDpp * MDppPerWave) == 0,
265  "Invalid tuning parameters! MPerBlock must be divisible by MPerDpp * MDppPerWave.");
266  static_assert(
267  NPerBlock % (NPerDpp * NDppPerWave) == 0,
268  "Invalid tuning parameters! NPerBlock must be divisible by NPerDpp * NDppPerWave.");
269 
270  static_assert(
271  KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0,
272  "Invalid tuning parameters! KPerBlock must be divisible by both AK1 and BK1.");
273 
274  static_assert(AK1Value % ABlockTransferDstScalarPerVector_K1 == 0,
275  "Invalid tuning parameters! AK1Value must be divisible by "
276  "ABlockTransferDstScalarPerVector_K1");
277 
278  static_assert(BK1Value % BBlockTransferDstScalarPerVector_K1 == 0,
279  "Invalid tuning parameters! BK1Value must be divisible by "
280  "BBlockTransferDstScalarPerVector_K1");
281 
282  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
283  GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
284  GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
285  GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
286  {
287  if(!(problem.M % MPerBlock == 0))
288  {
289  return false;
290  }
291  }
292 
293  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
294  GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
295  GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
296  GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
297  {
298  if(!(problem.N % NPerBlock == 0))
299  {
300  return false;
301  }
302  }
303 
304  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
305  {
306  if(problem.K % ABlockTransferSrcScalarPerVector != 0)
307  {
308  return false;
309  }
310  }
311  else
312  {
313  if(problem.M % ABlockTransferSrcScalarPerVector != 0)
314  {
315  return false;
316  }
317  }
318 
319  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
320  {
321  if(problem.N % BBlockTransferSrcScalarPerVector != 0)
322  {
323  return false;
324  }
325  }
326  else
327  {
328  if(problem.K % BBlockTransferSrcScalarPerVector != 0)
329  {
330  return false;
331  }
332  }
333 
334  if(problem.K % KPerBlock != 0)
335  {
336  return false;
337  }
338 
339  // check gridwise gemm pipeline
340  const auto num_k_loop = problem.K / KPerBlock;
341  if(!GridwiseGemmPipe::IsSupported(num_k_loop))
342  {
343  return false;
344  }
345 
346  return true;
347  }
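
For illustration, the host-side constraints above replayed as a standalone example with assumed parameters (MPerBlock = NPerBlock = 128, KPerBlock = 32, an A source vector width of 8 along K, i.e. row-major A). It only shows which problem shapes force a padded GemmSpecialization:

#include <cstdio>

int main()
{
    // Assumed tile/vector configuration (illustrative only).
    const int MPerBlock = 128, NPerBlock = 128, KPerBlock = 32;
    const int ABlockTransferSrcScalarPerVector = 8;

    const int M = 1000, N = 1024, K = 4096;

    // M = 1000 is not a multiple of 128, so a GemmSpecialization that pads M is required.
    std::printf("M needs padding:       %s\n", (M % MPerBlock != 0) ? "yes" : "no");
    std::printf("N needs padding:       %s\n", (N % NPerBlock != 0) ? "yes" : "no");
    // K must be a multiple of KPerBlock and of the A load vector width (row-major A).
    std::printf("K %% KPerBlock == 0:    %s\n", (K % KPerBlock == 0) ? "yes" : "no");
    std::printf("K %% vector width == 0: %s\n",
                (K % ABlockTransferSrcScalarPerVector == 0) ? "yes" : "no");
    // K-loop trip count checked against the pipeline: 4096 / 32 = 128 iterations.
    std::printf("num_k_loop = %d\n", K / KPerBlock);
    return 0;
}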
348 
349  __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
350  {
351  const auto num_loop = K / KPerBlock;
352 
353  return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
354  }
355 
356  template <typename CGridDesc>
357  __host__ __device__ static constexpr auto
358  MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(const CGridDesc& c_grid_desc_m_n)
359  {
360  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
361  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
362 
363  constexpr index_t KPack = math::max(
365 
366  using BlockwiseGemm =
368  ABDataType,
369  AccDataType,
370  decltype(a_block_desc_ak0_m_ak1),
371  decltype(b_block_desc_bk0_n_bk1),
372  MPerDpp,
373  NPerDpp,
374  MDppPerWave,
375  NDppPerWave,
376  KPack>;
377 
378  return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(c_grid_desc_m_n);
379  }
380 
381  static constexpr auto matrix_padder =
382  ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
383  MPerBlock, NPerBlock, KPerBlock};
384 
385  __device__ static auto
386  MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t K, index_t AK0, index_t StrideA)
387  {
388  const auto a_grid_desc_mraw_kraw = [&]() {
389  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
390  {
391  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
392  }
393  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
394  {
395  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
396  }
397  }();
398 
399  const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
400  return transform_tensor_descriptor(
401  a_grid_desc_m_k,
402  make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
403  make_pass_through_transform(M)),
404  make_tuple(Sequence<1>{}, Sequence<0>{}),
405  make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
406  }
407 
408  __device__ static auto
409  MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t N, index_t BK0, index_t StrideB)
410  {
411  const auto b_grid_desc_nraw_kraw = [&]() {
412  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
413  {
414  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
415  }
416  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
417  {
418  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
419  }
420  }();
421 
422  const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
423  return transform_tensor_descriptor(
424  b_grid_desc_n_k,
425  make_tuple(make_pass_through_transform(N),
426  make_unmerge_transform(make_tuple(BK0, BK1Value))),
427  make_tuple(Sequence<0>{}, Sequence<1>{}),
428  make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
429  }
430 
431  __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
432  {
433  const auto c_grid_desc_mraw_nraw = [&]() {
434  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, CLayout>)
435  {
436  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
437  }
438  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, CLayout>)
439  {
440  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
441  }
442  }();
443 
444  return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
445  }
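
For illustration, the stride convention behind the three descriptor builders above: a row-major M x N matrix is described with strides (StrideC, 1), a column-major one with (1, StrideC). The helper below is ad hoc, not a CK API:

#include <cstdio>

// Linear offset of element (m, n) in a 2-D descriptor with the given strides.
static int offset(int m, int n, int stride_m, int stride_n) { return m * stride_m + n * stride_n; }

int main()
{
    const int StrideC = 6; // row pitch (row-major) or column pitch (column-major)
    // Row-major: strides (StrideC, 1); element (2, 3) -> 2*6 + 3 = 15.
    std::printf("row-major    (2,3) -> %d\n", offset(2, 3, StrideC, 1));
    // Column-major: strides (1, StrideC); element (2, 3) -> 2 + 3*6 = 20.
    std::printf("column-major (2,3) -> %d\n", offset(2, 3, 1, StrideC));
    return 0;
}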
446 
447  template <bool HasMainKBlockLoop,
448  typename AGridDesc_AK0_M_AK1,
449  typename BGridDesc_BK0_N_BK1,
450  typename CGridDesc_M_N>
451  __device__ static void Run(const ABDataType* __restrict__ p_a_grid,
452  const ABDataType* __restrict__ p_b_grid,
453  CDataType* __restrict__ p_c_grid,
454  void* __restrict__ p_shared,
455  const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
456  const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
457  const CGridDesc_M_N& c_grid_desc_m_n)
458  {
459  const auto c_grid_desc_m0_n0_m1_n1_m2_n2 =
460  MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(c_grid_desc_m_n);
461 
462  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
463  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
464  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
465  p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
466  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
467  p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_n2.GetElementSpaceSize());
468 
469  const AElementwiseOperation a_element_op{};
470  const BElementwiseOperation b_element_op{};
471  const CElementwiseOperation c_element_op{};
472 
473  const auto block_2_ctile_map =
474  Block2CTileMap{c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)};
475 
476  // divide block work by [M, N]
477  const auto block_work_idx =
478  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
479 
480  if(!block_2_ctile_map.ValidCTileIndex(
481  block_work_idx,
482  make_tuple(c_grid_desc_m0_n0_m1_n1_m2_n2.GetLength(I0),
483  c_grid_desc_m0_n0_m1_n1_m2_n2.GetLength(I1))))
484  {
485  return;
486  }
487 
488  // HACK: this forces m/n_block_data_idx_on_grid into SGPR
489  const index_t m_block_data_idx_on_grid =
490  __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
491  const index_t n_block_data_idx_on_grid =
492  __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
493 
494  // A matrix in LDS memory, dst of blockwise copy
495  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
496  // B matrix in LDS memory, dst of blockwise copy
497  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
498 
499  auto a_blockwise_copy =
500  ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
501  AElementwiseOperation,
505  ABlockTransferThreadClusterLengths_K0_M_K1,
506  ABlockTransferThreadClusterArrangeOrder,
507  ABDataType,
508  ABDataType,
509  decltype(a_grid_desc_ak0_m_ak1),
510  decltype(a_block_desc_ak0_m_ak1),
511  ABlockTransferSrcAccessOrder,
513  ABlockTransferSrcVectorDim,
514  2,
515  ABlockTransferSrcScalarPerVector,
516  ABlockTransferDstScalarPerVector_K1,
517  1,
518  1,
519  AThreadTransferSrcResetCoordinateAfterRun,
520  true,
521  NumGemmKPrefetchStage>(
522  a_grid_desc_ak0_m_ak1,
523  make_multi_index(0, m_block_data_idx_on_grid, 0),
524  a_element_op,
525  a_block_desc_ak0_m_ak1,
526  make_multi_index(0, 0, 0),
528 
529  auto b_blockwise_copy =
530  ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
531  BElementwiseOperation,
535  BBlockTransferThreadClusterLengths_K0_N_K1,
536  BBlockTransferThreadClusterArrangeOrder,
537  ABDataType,
538  ABDataType,
539  decltype(b_grid_desc_bk0_n_bk1),
540  decltype(b_block_desc_bk0_n_bk1),
541  BBlockTransferSrcAccessOrder,
543  BBlockTransferSrcVectorDim,
544  2,
545  BBlockTransferSrcScalarPerVector,
546  BBlockTransferDstScalarPerVector_K1,
547  1,
548  1,
549  BThreadTransferSrcResetCoordinateAfterRun,
550  true,
551  NumGemmKPrefetchStage>(
552  b_grid_desc_bk0_n_bk1,
553  make_multi_index(0, n_block_data_idx_on_grid, 0),
554  b_element_op,
555  b_block_desc_bk0_n_bk1,
556  make_multi_index(0, 0, 0),
558 
559  // GEMM definition
560  // c_mtx += transpose(a_mtx) * b_mtx
561  // a_mtx[AK0PerBlock, MPerBlock] is in LDS
562  // b_mtx[BK0PerBlock, NPerBlock] is in LDS
563  // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
564  // register
565  constexpr index_t KPack = math::max(
567  auto blockwise_gemm =
569  ABDataType,
570  AccDataType,
571  decltype(a_block_desc_ak0_m_ak1),
572  decltype(b_block_desc_bk0_n_bk1),
573  MPerDpp,
574  NPerDpp,
575  MDppPerWave,
576  NDppPerWave,
577  KPack>();
578 
579  auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
580 
581  // LDS allocation for A and B: be careful of alignment
582  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
583  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
584 
585  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
586  static_cast<ABDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
587 
588  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
589  static_cast<ABDataType*>(p_shared) + a_block_space_size_aligned,
590  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
591 
592  constexpr auto a_block_slice_copy_step = make_multi_index(AK0PerBlock, 0, 0);
593  constexpr auto b_block_slice_copy_step = make_multi_index(BK0PerBlock, 0, 0);
594 
595  // gridwise GEMM pipeline
596  const auto AK0 = a_grid_desc_ak0_m_ak1.GetLength(I0);
597  // (AK0 / AK0PerBlock) is always equal to (BK0 / BK0PerBlock)
598  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(AK0 / AK0PerBlock);
599 
600  GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
601  a_block_desc_ak0_m_ak1,
602  a_blockwise_copy,
603  a_grid_buf,
604  a_block_buf,
605  a_block_slice_copy_step,
606  b_grid_desc_bk0_n_bk1,
607  b_block_desc_bk0_n_bk1,
608  b_blockwise_copy,
609  b_grid_buf,
610  b_block_buf,
611  b_block_slice_copy_step,
612  blockwise_gemm,
613  c_thread_buf,
614  num_k_block_main_loop);
615 
616  // output: register to global memory
617  {
618  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2 =
619  blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2();
620 
621  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
622  blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2();
623 
624  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
625  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
626  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
627  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
628  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I4);
629  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I5);
630 
631  constexpr auto MPerThread = c_thread_desc_m0_n0_m1_n1_m2_n2.GetLength(I4);
632  constexpr auto NPerThread = c_thread_desc_m0_n0_m1_n1_m2_n2.GetLength(I5);
633 
634  // calculate origin of thread output tensor on global memory
635  // blockwise GEMM c matrix starting index
636  const auto c_thread_mtx_on_block =
637  blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0);
638 
639  const index_t m_thread_data_on_grid =
640  m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
641 
642  const index_t n_thread_data_on_grid =
643  n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
644 
645  const auto m_thread_data_on_grid_to_m0_m1_m2_adaptor = make_single_stage_tensor_adaptor(
649 
650  const auto m_thread_data_on_grid_idx =
651  m_thread_data_on_grid_to_m0_m1_m2_adaptor.CalculateBottomIndex(
652  make_multi_index(m_thread_data_on_grid));
653 
654  const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
658 
659  const auto n_thread_data_on_grid_idx =
660  n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
661  make_multi_index(n_thread_data_on_grid));
662 
663  auto c_thread_copy =
664  ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
665  CDataType,
666  decltype(c_thread_desc_m0_n0_m1_n1_m2_n2),
667  decltype(c_grid_desc_m0_n0_m1_n1_m2_n2),
668  CElementwiseOperation,
670  CThreadTransferSrcDstAccessOrder,
671  CThreadTransferSrcDstVectorDim,
672  CThreadTransferDstScalarPerVector,
673  CGlobalMemoryDataOperation,
674  1,
675  true>{
676  c_grid_desc_m0_n0_m1_n1_m2_n2,
677  make_multi_index(m_thread_data_on_grid_idx[I0],
678  n_thread_data_on_grid_idx[I0],
679  m_thread_data_on_grid_idx[I1],
680  n_thread_data_on_grid_idx[I1],
681  m_thread_data_on_grid_idx[I2],
682  n_thread_data_on_grid_idx[I2]),
683  c_element_op};
684 
685  c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_n2,
686  make_tuple(I0, I0, I0, I0, I0, I0),
687  c_thread_buf,
688  c_grid_desc_m0_n0_m1_n1_m2_n2,
689  c_grid_buf);
690  }
691  }
692 };
693 
694 } // namespace ck
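
For context, a hedged sketch of how a host might drive this struct. The alias GridwiseGemmInstance, the block_size argument, and the helper itself are assumptions made for illustration; in Composable Kernel the real host plumbing lives in the device operation layer, not in this header:

#include <hip/hip_runtime.h>
#include <tuple>

// GridwiseGemmInstance stands for some concrete instantiation of the gridwise GEMM
// struct defined above; block_size must match the BlockSize it was instantiated with.
template <typename GridwiseGemmInstance>
void launch_gemm_dpp(const typename GridwiseGemmInstance::Argument& karg, int block_size)
{
    if(!GridwiseGemmInstance::CheckValidity(karg)) // reject unsupported problem shapes
    {
        return;
    }

    const auto grid = GridwiseGemmInstance::CalculateGridSize(karg.M, karg.N);
    const dim3 grid_dim(std::get<0>(grid), std::get<1>(grid), std::get<2>(grid));

    // Dispatch on the K-loop trip count, mirroring the HasMainKBlockLoop template flag.
    if(GridwiseGemmInstance::CalculateHasMainKBlockLoop(karg.K))
    {
        ck::kernel_gemm_dpp<GridwiseGemmInstance, true><<<grid_dim, dim3(block_size)>>>(karg);
    }
    else
    {
        ck::kernel_gemm_dpp<GridwiseGemmInstance, false><<<grid_dim, dim3(block_size)>>>(karg);
    }
}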