Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp Source File

gridwise_gemm_xdlops_v2r4.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
16 
17 namespace ck {
18 
19 template <typename GridwiseGemm,
20  typename FloatAB,
21  typename FloatC,
22  typename ABK0MK1GridDesc,
23  typename BBK0NK1GridDesc,
24  typename CM0N0M1N1M2M3M4N2GridDesc,
25  typename AElementwiseOperation,
26  typename BElementwiseOperation,
27  typename CElementwiseOperation,
28  typename CBlockClusterAdaptor,
29  bool HasMainKBlockLoop>
30 __global__ void
31 #if CK_USE_LAUNCH_BOUNDS
32  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
33 #endif
34  kernel_gemm_xdlops_v2r4(const FloatAB* __restrict__ p_a_grid,
35  const FloatAB* __restrict__ p_b_grid,
36  FloatC* __restrict__ p_c_grid,
37  const ABK0MK1GridDesc a_b_k0_m_k1_grid_desc,
38  const BBK0NK1GridDesc b_b_k0_n_k1_grid_desc,
39  const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
40  const AElementwiseOperation a_element_op,
41  const BElementwiseOperation b_element_op,
42  const CElementwiseOperation c_element_op,
43  const CBlockClusterAdaptor c_block_cluster_adaptor)
44 {
45 #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
46  defined(__gfx12__)
47  if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
48  {
49  constexpr index_t shared_block_size =
50  GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
51 
52  __shared__ FloatAB p_shared_block[shared_block_size];
53 
54  GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
55  p_b_grid,
56  p_c_grid,
57  p_shared_block,
58  a_b_k0_m_k1_grid_desc,
59  b_b_k0_n_k1_grid_desc,
60  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
61  a_element_op,
62  b_element_op,
63  c_element_op,
64  c_block_cluster_adaptor);
65  }
66 #else
67  ignore = p_a_grid;
68  ignore = p_b_grid;
69  ignore = p_c_grid;
70  ignore = a_b_k0_m_k1_grid_desc;
71  ignore = b_b_k0_n_k1_grid_desc;
72  ignore = c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc;
73  ignore = a_element_op;
74  ignore = b_element_op;
75  ignore = c_element_op;
76  ignore = c_block_cluster_adaptor;
77 #endif // end of architecture guard
78 }
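// Hypothetical host-side launch sketch (not part of this header): a device-op
// wrapper normally builds the grid descriptors and the block-cluster adaptor,
// then launches one workgroup of BlockSize threads per entry of the adaptor
// (KBatch x M-tiles x N-tiles). grid_size and stream are placeholder names.
//
//   const auto kernel = kernel_gemm_xdlops_v2r4<GridwiseGemm,
//                                               FloatAB,
//                                               FloatC,
//                                               decltype(a_b_k0_m_k1_grid_desc),
//                                               decltype(b_b_k0_n_k1_grid_desc),
//                                               decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc),
//                                               AElementwiseOperation,
//                                               BElementwiseOperation,
//                                               CElementwiseOperation,
//                                               decltype(c_block_cluster_adaptor),
//                                               true /* HasMainKBlockLoop */>;
//   kernel<<<dim3(grid_size), dim3(BlockSize), 0, stream>>>(
//       p_a_grid, p_b_grid, p_c_grid,
//       a_b_k0_m_k1_grid_desc, b_b_k0_n_k1_grid_desc,
//       c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
//       a_element_op, b_element_op, c_element_op, c_block_cluster_adaptor);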
79 
80 template <index_t BlockSize,
81  typename FloatAB,
82  typename FloatAcc,
83  typename FloatC,
84  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
85  typename ABK0MK1GridDesc,
86  typename BBK0NK1GridDesc,
87  typename CMNGridDesc,
88  typename AElementwiseOperation,
89  typename BElementwiseOperation,
90  typename CElementwiseOperation,
91  index_t MPerBlock,
92  index_t NPerBlock,
93  index_t K0PerBlock,
94  index_t MPerXDL,
95  index_t NPerXDL,
96  index_t K1Value,
97  index_t MRepeat,
98  index_t NRepeat,
99  typename ABlockTransferThreadClusterLengths_K0_M_K1,
100  typename ABlockTransferThreadClusterArrangeOrder,
101  typename ABlockTransferSrcAccessOrder,
102  index_t ABlockTransferSrcVectorDim,
103  index_t ABlockTransferSrcScalarPerVector,
104  index_t ABlockTransferDstScalarPerVector_K1,
105  bool AThreadTransferSrcResetCoordinateAfterRun,
106  bool ABlockLdsExtraM,
107  typename BBlockTransferThreadClusterLengths_K0_N_K1,
108  typename BBlockTransferThreadClusterArrangeOrder,
109  typename BBlockTransferSrcAccessOrder,
110  index_t BBlockTransferSrcVectorDim,
111  index_t BBlockTransferSrcScalarPerVector,
112  index_t BBlockTransferDstScalarPerVector_K1,
113  bool BThreadTransferSrcResetCoordinateAfterRun,
114  bool BBlockLdsExtraN,
115  typename CThreadTransferSrcDstAccessOrder,
116  index_t CThreadTransferSrcDstVectorDim,
117  index_t CThreadTransferDstScalarPerVector>
119 {
120  static constexpr auto I0 = Number<0>{};
121  static constexpr auto I1 = Number<1>{};
122  static constexpr auto I2 = Number<2>{};
123  static constexpr auto I3 = Number<3>{};
124  static constexpr auto I4 = Number<4>{};
125  static constexpr auto I5 = Number<5>{};
126  static constexpr auto I6 = Number<6>{};
127  static constexpr auto I7 = Number<7>{};
128 
129  // K1 should be Number<...>
130  static constexpr auto K1 = Number<K1Value>{};
131 
132  using ThisThreadBlock = ThisThreadBlock<BlockSize>;
133 
134  __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
135  {
136  constexpr auto max_lds_align = K1;
137 
138  // A matrix in LDS memory, dst of blockwise copy
139  constexpr auto a_k0_m_k1_block_desc = [&]() {
140  if constexpr(ABlockLdsExtraM)
141  {
142  return make_naive_tensor_descriptor(
143  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
144  make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
145  }
146  else
147  {
148  return make_naive_tensor_descriptor_aligned(
149  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
150  }
151  }();
152 
153  // B matrix in LDS memory, dst of blockwise copy
154  constexpr auto b_k0_n_k1_block_desc = [&]() {
155  if constexpr(BBlockLdsExtraN)
156  {
157  return make_naive_tensor_descriptor(
158  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
159  make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
160  }
161  else
162  {
163  return make_naive_tensor_descriptor_aligned(
164  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
165  }
166  }();
167 
168  // LDS allocation for A and B: be careful of alignment
169  constexpr auto a_block_space_size =
170  math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
171 
172  constexpr auto b_block_space_size =
173  math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
174 
175  return (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
176  }
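// Worked example (hypothetical tile configuration): with K0PerBlock = 4,
// MPerBlock = NPerBlock = 128, K1 = 8, FloatAB = half and no extra LDS
// padding, each block descriptor spans 4 * 128 * 8 = 4096 elements, which is
// already a multiple of max_lds_align = 8, so the kernel requests
// (4096 + 4096) * sizeof(half) = 16384 bytes of LDS.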
177 
178  template <
179  InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set>
180  __device__ static bool constexpr IsValidCompilationParameter()
181  {
182  return ck::tensor_operation::device::IsValidGemmCompilationParameter<
183  BlockSize,
184  MPerBlock,
185  NPerBlock,
186  MPerXDL,
187  NPerXDL,
188  MRepeat,
189  NRepeat,
190  FloatC,
191  CGlobalMemoryDataOperation>();
192  }
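// This check runs inside the kernel's architecture guard above, so a template
// instantiation with an unsupported block/wave/tile combination compiles to an
// empty kernel body instead of emitting invalid MFMA code for that target.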
193 
194  // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
195  template <typename Block2CTileMap>
196  __host__ __device__ static constexpr bool
197  CheckValidity(const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc,
198  const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc,
199  const CMNGridDesc& c_m_n_grid_desc,
200  const Block2CTileMap& block_2_ctile_map)
201  {
202  static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
203  "wrong! K1 need to be known at compile-time");
204 
205  static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) &&
206  (NPerBlock % (NRepeat * NPerXDL)) == 0,
207  "Invalid tuning param!");
208 
209  const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2);
210  const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2);
211  const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
212  const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0);
213 
214  if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
215  K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) &&
216  K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) &&
217  K1 == b_b_k0_n_k1_grid_desc.GetLength(I3) &&
218  KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0)))
219  return false;
220 
221  if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
222  return false;
223 
224  if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc))
225  {
226  return false;
227  }
228 
229  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
230  return true;
231  }
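// Example of the divisibility requirements (hypothetical problem sizes):
// with MPerBlock = NPerBlock = 128, K0PerBlock = 4, K1 = 8 and KBatch = 2,
// a GEMM with M = 1024, N = 2048, K = 256 is encoded as K0 = K / (KBatch * K1)
// = 16, and 1024 % 128 == 0, 2048 % 128 == 0, 16 % 4 == 0, so CheckValidity
// can pass provided the A/B/C descriptors agree as tested above.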
232 
233  __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
234  {
235  const bool has_main_k0_block_loop = K0 > K0PerBlock;
236 
237  return has_main_k0_block_loop;
238  }
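// For example, K0 = 16 with K0PerBlock = 4 returns true (the main loop in Run
// executes), while K0 = 4 returns false and only the tail GEMM runs.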
239 
240  __host__ __device__ static constexpr auto
241  MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
242  {
243  constexpr auto max_lds_align = K1;
244 
245  // A matrix in LDS memory, dst of blockwise copy
246  constexpr auto a_k0_m_k1_block_desc = [&]() {
247  if constexpr(ABlockLdsExtraM)
248  {
249  return make_naive_tensor_descriptor(
250  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
251  make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
252  }
253  else
254  {
255  return make_naive_tensor_descriptor_aligned(
256  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
257  }
258  }();
259 
260  // B matrix in LDS memory, dst of blockwise copy
261  constexpr auto b_k0_n_k1_block_desc = [&]() {
262  if constexpr(BBlockLdsExtraN)
263  {
264  return make_naive_tensor_descriptor(
265  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
266  make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
267  }
268  else
269  {
270  return make_naive_tensor_descriptor_aligned(
271  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
272  }
273  }();
274 
275  using BlockwiseGemm =
277  FloatAB,
278  FloatAcc,
279  decltype(a_k0_m_k1_block_desc),
280  decltype(b_k0_n_k1_block_desc),
281  MPerXDL,
282  NPerXDL,
283  MRepeat,
284  NRepeat,
285  K1>;
286 
287  return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_m_n_grid_desc);
288  }
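// Rough meaning of the 8-dimensional C layout: M0/N0 are the MRepeat/NRepeat
// tiles, M1/N1 the waves of the block in M/N, and M2/M3/M4/N2 describe how a
// single MFMA instruction's accumulator registers map onto rows and columns;
// the thread-wise C copy in Run() writes through this view.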
289 
290  // return block_id to C matrix tile idx (m0, n0) mapping
291  __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
292  const CMNGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch)
293  {
295  c_m_n_grid_desc, 8, KBatch);
296  }
297 
299  using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1));
300 
301  template <bool HasMainKBlockLoop>
302  __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
303  const FloatAB* __restrict__ p_b_grid,
304  FloatC* __restrict__ p_c_grid,
305  FloatAB* __restrict__ p_shared_block,
306  const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc,
307  const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc,
308  const CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
309  const AElementwiseOperation& a_element_op,
310  const BElementwiseOperation& b_element_op,
311  const CElementwiseOperation& c_element_op,
312  const CBlockClusterAdaptor& c_block_cluster_adaptor)
313  {
314  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
315  p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize());
316  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
317  p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize());
318  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
319  p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize());
320 
321  const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
322 
323  // divide block work by [M, N]
324  const auto block_work_idx =
325  c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
326 
327  if(!c_block_cluster_adaptor.ValidCTileIndex(
328  make_tuple(block_work_idx[I1], block_work_idx[I2]),
329  make_tuple(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetLength(I0),
330  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetLength(I1))))
331  {
332  return;
333  }
334 
335  const index_t k_batch_id = block_work_idx[I0];
336 
337  // HACK: this forces m/n_block_data_idx_on_grid into SGPR
338  const index_t m_block_data_idx_on_grid =
339  __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);
340 
341  const index_t n_block_data_idx_on_grid =
342  __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);
343 
344  // lds max alignment
345  constexpr auto max_lds_align = K1;
346 
347  // A matrix in LDS memory, dst of blockwise copy
348  constexpr auto a_k0_m_k1_block_desc = [&]() {
349  if constexpr(ABlockLdsExtraM)
350  {
351  return make_naive_tensor_descriptor(
352  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
353  make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
354  }
355  else
356  {
357  return make_naive_tensor_descriptor_aligned(
358  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
359  }
360  }();
361 
362  constexpr auto a_b_k0_m_k1_block_desc = [&]() {
363  if constexpr(ABlockLdsExtraM)
364  {
365  return make_naive_tensor_descriptor(
366  make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
367  make_tuple(Number<K0PerBlock>{} * Number<MPerBlock + 1>{} * K1,
368  Number<MPerBlock + 1>{} * K1,
369  K1,
370  I1));
371  }
372  else
373  {
374  return make_naive_tensor_descriptor_aligned(
375  make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
376  max_lds_align);
377  }
378  }();
379  // B matrix in LDS memory, dst of blockwise copy
380  constexpr auto b_k0_n_k1_block_desc = [&]() {
381  if constexpr(BBlockLdsExtraN)
382  {
383  return make_naive_tensor_descriptor(
384  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
385  make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
386  }
387  else
388  {
389  return make_naive_tensor_descriptor_aligned(
390  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
391  }
392  }();
393 
394  constexpr auto b_b_k0_n_k1_block_desc = [&]() {
395  if constexpr(BBlockLdsExtraN)
396  {
397  return make_naive_tensor_descriptor(
398  make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
399  make_tuple(Number<K0PerBlock>{} * Number<NPerBlock + 1>{} * K1,
400  Number<NPerBlock + 1>{} * K1,
401  K1,
402  I1));
403  }
404  else
405  {
406  return make_naive_tensor_descriptor_aligned(
407  make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
408  max_lds_align);
409  }
410  }();
411  // A matrix blockwise copy
412  auto a_blockwise_copy =
414  AElementwiseOperation,
418  ABlockTransferThreadClusterLengths_K0_M_K1,
419  ABlockTransferThreadClusterArrangeOrder,
420  FloatAB,
421  FloatAB,
422  decltype(a_b_k0_m_k1_grid_desc),
423  decltype(a_b_k0_m_k1_block_desc),
424  ABlockTransferSrcAccessOrder,
426  ABlockTransferSrcVectorDim,
427  3,
428  ABlockTransferSrcScalarPerVector,
429  ABlockTransferDstScalarPerVector_K1,
430  1,
431  1,
432  AThreadTransferSrcResetCoordinateAfterRun,
433  true>(
434  a_b_k0_m_k1_grid_desc,
435  make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0),
436  a_element_op,
437  a_b_k0_m_k1_block_desc,
438  make_multi_index(0, 0, 0, 0),
440 
441  // B matrix blockwise copy
442  auto b_blockwise_copy =
444  BElementwiseOperation,
448  BBlockTransferThreadClusterLengths_K0_N_K1,
449  BBlockTransferThreadClusterArrangeOrder,
450  FloatAB,
451  FloatAB,
452  decltype(b_b_k0_n_k1_grid_desc),
453  decltype(b_b_k0_n_k1_block_desc),
454  BBlockTransferSrcAccessOrder,
456  BBlockTransferSrcVectorDim,
457  3,
458  BBlockTransferSrcScalarPerVector,
459  BBlockTransferDstScalarPerVector_K1,
460  1,
461  1,
462  BThreadTransferSrcResetCoordinateAfterRun,
463  true>(
464  b_b_k0_n_k1_grid_desc,
465  make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0),
466  b_element_op,
467  b_b_k0_n_k1_block_desc,
468  make_multi_index(0, 0, 0, 0),
470 
471  // GEMM definition
472  // c_mtx += transpose(a_mtx) * b_mtx
473  // a_mtx[K0PerBlock, MPerBlock] is in LDS
474  // b_mtx[K0PerBlock, NPerBlock] is in LDS
475  // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
476  // register
477  // sanity check
478 
479  auto blockwise_gemm =
481  FloatAB,
482  FloatAcc,
483  decltype(a_k0_m_k1_block_desc),
484  decltype(b_k0_n_k1_block_desc),
485  MPerXDL,
486  NPerXDL,
487  MRepeat,
488  NRepeat,
489  K1>{};
490 
491  auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
492 
493  // LDS allocation for A and B: be careful of alignment
494  constexpr auto a_block_space_size =
495  math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
496 
497  FloatAB* p_a_block = p_shared_block;
498  FloatAB* p_b_block = p_shared_block + a_block_space_size;
499 
500  constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
501  constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
502 
503  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
504  p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
505  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
506  p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());
507 
508  // preload data into LDS
509  {
510  a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
511  b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);
512 
513  a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
514  b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
515  }
516 
517  // Initialize C
518  c_thread_buf.Clear();
519 
520  // main body
521  if constexpr(HasMainKBlockLoop)
522  {
523  index_t k0_block_data_begin = 0;
524 
525  do
526  {
527  a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, a_block_slice_copy_step);
528  b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, b_block_slice_copy_step);
529 
530  a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
531 
532  block_sync_lds();
533 
534  b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);
535 
536  blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
537 
538  block_sync_lds();
539 
540  a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
541  b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
542 
543  k0_block_data_begin += K0PerBlock;
544  } while(k0_block_data_begin < (K0 - K0PerBlock));
545  }
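// Pipeline note: in each iteration the global-memory reads of the next
// K0PerBlock slice are issued around the block-level GEMM on the slice already
// resident in LDS, and the LDS writes of the freshly read slice are deferred
// until after the second block_sync_lds(), so the single LDS buffer is never
// overwritten while the MFMA stage is still reading it.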
546 
547  // tail
548  {
549  block_sync_lds();
550 
551  blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
552  }
553 
554  // output: register to global memory
555  {
556  constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
557  blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
558 
559  constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
560  constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
561  constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
562  constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
563  constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
564  constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
565  constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
566  constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);
567 
568  constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
569  make_naive_tensor_descriptor_packed(make_tuple(
570  Number<M0>{}, Number<N0>{}, I1, I1, Number<M2>{}, I1, Number<M4>{}, I1));
571 
572  // calculate origin of thread output tensor on global memory
573  // blockwise GEMM c matrix starting index
574  const auto c_thread_mtx_on_block =
575  blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
576 
577  const index_t m_thread_data_on_grid =
578  m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
579 
580  const index_t n_thread_data_on_grid =
581  n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
582 
583  const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
584  make_single_stage_tensor_adaptor(
585  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
586  make_tuple(Sequence<0, 1, 2, 3, 4>{}),
587  make_tuple(Sequence<0>{}));
588 
589  const auto m_thread_data_on_grid_idx =
590  m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
591  make_multi_index(m_thread_data_on_grid));
592 
593  const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
594  make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
595  make_tuple(Sequence<0, 1, 2>{}),
596  make_tuple(Sequence<0>{}));
597 
598  const auto n_thread_data_on_grid_idx =
599  n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
600  make_multi_index(n_thread_data_on_grid));
601 
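// The two single-stage adaptors above invert the merges M = M0*M1*M2*M3*M4 and
// N = N0*N1*N2, turning this thread's scalar row/column offset on the grid
// into the multi-index expected by the m0_n0_m1_n1_m2_m3_m4_n2 C descriptor.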
602  auto c_thread_copy =
604  FloatC,
605  decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
606  decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc),
607  CElementwiseOperation,
609  CThreadTransferSrcDstAccessOrder,
610  CThreadTransferSrcDstVectorDim,
611  CThreadTransferDstScalarPerVector,
612  CGlobalMemoryDataOperation,
613  1,
614  true>{
615 
616  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
617  make_multi_index(m_thread_data_on_grid_idx[I0],
618  n_thread_data_on_grid_idx[I0],
619  m_thread_data_on_grid_idx[I1],
620  n_thread_data_on_grid_idx[I1],
621  m_thread_data_on_grid_idx[I2],
622  m_thread_data_on_grid_idx[I3],
623  m_thread_data_on_grid_idx[I4],
624  n_thread_data_on_grid_idx[I2]),
625  c_element_op};
626 
627  c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
628  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
629  c_thread_buf,
630  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
631  c_grid_buf);
632  }
633  }
634 };
635 
636 } // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition: ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:30
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
Definition: ck.hpp:268
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition: ck.hpp:277
__global__ void kernel_gemm_xdlops_v2r4(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const ABK0MK1GridDesc a_b_k0_m_k1_grid_desc, const BBK0NK1GridDesc b_b_k0_n_k1_grid_desc, const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const CBlockClusterAdaptor c_block_cluster_adaptor)
Definition: gridwise_gemm_xdlops_v2r4.hpp:34
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_naive_tensor_descriptor_aligned(const Tuple< Lengths... > &lengths, Align align)
Definition: tensor_descriptor_helper.hpp:132
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:47
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:299
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
typename remove_cv< T >::type remove_cv_t
Definition: type.hpp:295
Definition: block_to_ctile_map.hpp:541
Definition: blockwise_gemm_smfmac_xdlops.hpp:44
Definition: gridwise_gemm_xdlops_v2r4.hpp:119
static constexpr auto I3
Definition: gridwise_gemm_xdlops_v2r4.hpp:123
static __device__ constexpr bool IsValidCompilationParameter()
Definition: gridwise_gemm_xdlops_v2r4.hpp:180
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_gemm_xdlops_v2r4.hpp:132
static constexpr auto I6
Definition: gridwise_gemm_xdlops_v2r4.hpp:126
__host__ static constexpr __device__ auto MakeCBlockClusterAdaptor(const CMNGridDesc &c_m_n_grid_desc, index_t, index_t, index_t KBatch)
Definition: gridwise_gemm_xdlops_v2r4.hpp:291
static constexpr auto K1
Definition: gridwise_gemm_xdlops_v2r4.hpp:130
static constexpr auto I5
Definition: gridwise_gemm_xdlops_v2r4.hpp:125
static constexpr auto I7
Definition: gridwise_gemm_xdlops_v2r4.hpp:127
static __device__ void Run(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, FloatAB *__restrict__ p_shared_block, const ABK0MK1GridDesc &a_b_k0_m_k1_grid_desc, const BBK0NK1GridDesc &b_b_k0_n_k1_grid_desc, const CM0N0M1N1M2M3M4N2GridDesc &c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CElementwiseOperation &c_element_op, const CBlockClusterAdaptor &c_block_cluster_adaptor)
Definition: gridwise_gemm_xdlops_v2r4.hpp:302
__host__ static constexpr __device__ auto MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc &c_m_n_grid_desc)
Definition: gridwise_gemm_xdlops_v2r4.hpp:241
static constexpr auto I4
Definition: gridwise_gemm_xdlops_v2r4.hpp:124
__host__ static constexpr __device__ index_t GetSharedMemoryNumberOfByte()
Definition: gridwise_gemm_xdlops_v2r4.hpp:134
decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1)) CBlockClusterAdaptor
Definition: gridwise_gemm_xdlops_v2r4.hpp:299
static constexpr auto I1
Definition: gridwise_gemm_xdlops_v2r4.hpp:121
static constexpr auto I0
Definition: gridwise_gemm_xdlops_v2r4.hpp:120
__host__ static constexpr __device__ bool CalculateHasMainK0BlockLoop(index_t K0)
Definition: gridwise_gemm_xdlops_v2r4.hpp:233
static constexpr auto I2
Definition: gridwise_gemm_xdlops_v2r4.hpp:122
__host__ static constexpr __device__ bool CheckValidity(const ABK0MK1GridDesc &a_b_k0_m_k1_grid_desc, const BBK0NK1GridDesc &b_b_k0_n_k1_grid_desc, const CMNGridDesc &c_m_n_grid_desc, const Block2CTileMap &block_2_ctile_map)
Definition: gridwise_gemm_xdlops_v2r4.hpp:197
decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{})) CM0N0M1N1M2M3M4N2GridDesc
Definition: gridwise_gemm_xdlops_v2r4.hpp:298
Definition: sequence.hpp:43
Blockwise data transfer.
Definition: thread_group_tensor_slice_transfer_v4r1.hpp:46
__device__ void Run(const SrcDesc &src_desc, const SrcBuffer &src_buf, const DstDesc &dst_desc, DstBuffer &dst_buf, Number< ThreadScratchId > thread_scratch_id)
Definition: thread_group_tensor_slice_transfer_v4r1.hpp:143
__device__ void RunRead(const SrcDesc &src_desc, const SrcBuffer &src_buf, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition: thread_group_tensor_slice_transfer_v4r1.hpp:119
__device__ void MoveSrcSliceWindow(const SrcDesc &src_desc, const Index &step)
Definition: thread_group_tensor_slice_transfer_v4r1.hpp:153
__device__ void RunWrite(const DstDesc &dst_desc, DstBuffer &dst_buf, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition: thread_group_tensor_slice_transfer_v4r1.hpp:131
Definition: threadwise_tensor_slice_transfer.hpp:39
Definition: integral_constant.hpp:20
Definition: is_known_at_compile_time.hpp:14
Definition: unary_element_wise_operation.hpp:334