Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp Source File

gridwise_gemm_xdlops_v2r4.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
16 
17 namespace ck {
18 
19 template <typename GridwiseGemm,
20  typename FloatAB,
21  typename FloatC,
22  typename ABK0MK1GridDesc,
23  typename BBK0NK1GridDesc,
24  typename CM0N0M1N1M2M3M4N2GridDesc,
25  typename AElementwiseOperation,
26  typename BElementwiseOperation,
27  typename CElementwiseOperation,
28  typename CBlockClusterAdaptor,
29  bool HasMainKBlockLoop>
30 __global__ void
31 #if CK_USE_LAUNCH_BOUNDS
32  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
33 #endif
34  kernel_gemm_xdlops_v2r4(const FloatAB* __restrict__ p_a_grid,
35  const FloatAB* __restrict__ p_b_grid,
36  FloatC* __restrict__ p_c_grid,
37  const ABK0MK1GridDesc a_b_k0_m_k1_grid_desc,
38  const BBK0NK1GridDesc b_b_k0_n_k1_grid_desc,
39  const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
40  const AElementwiseOperation a_element_op,
41  const BElementwiseOperation b_element_op,
42  const CElementwiseOperation c_element_op,
43  const CBlockClusterAdaptor c_block_cluster_adaptor)
44 {
45 #if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__))
46  constexpr index_t shared_block_size =
47  GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
48 
49  __shared__ FloatAB p_shared_block[shared_block_size];
50 
51  GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
52  p_b_grid,
53  p_c_grid,
54  p_shared_block,
55  a_b_k0_m_k1_grid_desc,
56  b_b_k0_n_k1_grid_desc,
57  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
58  a_element_op,
59  b_element_op,
60  c_element_op,
61  c_block_cluster_adaptor);
62 #else
63  ignore = p_a_grid;
64  ignore = p_b_grid;
65  ignore = p_c_grid;
66  ignore = a_b_k0_m_k1_grid_desc;
67  ignore = b_b_k0_n_k1_grid_desc;
68  ignore = c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc;
69  ignore = a_element_op;
70  ignore = b_element_op;
71  ignore = c_element_op;
72  ignore = c_block_cluster_adaptor;
73 #endif // end of if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__))
74 }
75 
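For context, a host-side caller typically instantiates this wrapper for both values of HasMainKBlockLoop and selects one at launch time via GridwiseGemm::CalculateHasMainK0BlockLoop (defined later in this file). The fragment below is only a sketch under assumed names: p_a, p_b, p_c are device pointers, a_desc and b_desc are the (KBatch, K0, M/N, K1) grid descriptors, c_desc is the M0..N2-transformed C descriptor, cta_map is the block-cluster adaptor, the element-op objects mirror the kernel's template parameters, and grid_size, block_size (equal to the BlockSize the GridwiseGemm was instantiated with) and stream come from the caller; none of these names are defined in this header, and <hip/hip_runtime.h> plus <type_traits> are assumed to be included.

    // Hedged sketch: pick the HasMainKBlockLoop instantiation at run time and launch it.
    // The LDS tile buffer is declared statically inside the kernel, so the dynamic
    // shared-memory argument of the launch is 0.
    const auto launch = [&](auto has_main_k0_loop) {
        const auto kernel = ck::kernel_gemm_xdlops_v2r4<GridwiseGemm,
                                                        FloatAB,
                                                        FloatC,
                                                        decltype(a_desc),
                                                        decltype(b_desc),
                                                        decltype(c_desc),
                                                        AElementwiseOperation,
                                                        BElementwiseOperation,
                                                        CElementwiseOperation,
                                                        decltype(cta_map),
                                                        decltype(has_main_k0_loop)::value>;
        hipLaunchKernelGGL(kernel,
                           dim3(grid_size),
                           dim3(block_size),
                           0,
                           stream,
                           p_a, p_b, p_c,
                           a_desc, b_desc, c_desc,
                           a_element_op, b_element_op, c_element_op, cta_map);
    };

    const auto K0 = a_desc.GetLength(ck::Number<1>{}); // (KBatch, K0, M, K1) layout
    if(GridwiseGemm::CalculateHasMainK0BlockLoop(K0))
        launch(std::integral_constant<bool, true>{});
    else
        launch(std::integral_constant<bool, false>{});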
76 template <index_t BlockSize,
77  typename FloatAB,
78  typename FloatAcc,
79  typename FloatC,
80  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
81  typename ABK0MK1GridDesc,
82  typename BBK0NK1GridDesc,
83  typename CMNGridDesc,
84  typename AElementwiseOperation,
85  typename BElementwiseOperation,
86  typename CElementwiseOperation,
87  index_t MPerBlock,
88  index_t NPerBlock,
89  index_t K0PerBlock,
90  index_t MPerXDL,
91  index_t NPerXDL,
92  index_t K1Value,
93  index_t MRepeat,
94  index_t NRepeat,
95  typename ABlockTransferThreadClusterLengths_K0_M_K1,
96  typename ABlockTransferThreadClusterArrangeOrder,
97  typename ABlockTransferSrcAccessOrder,
98  index_t ABlockTransferSrcVectorDim,
99  index_t ABlockTransferSrcScalarPerVector,
100  index_t ABlockTransferDstScalarPerVector_K1,
101  bool AThreadTransferSrcResetCoordinateAfterRun,
102  bool ABlockLdsExtraM,
103  typename BBlockTransferThreadClusterLengths_K0_N_K1,
104  typename BBlockTransferThreadClusterArrangeOrder,
105  typename BBlockTransferSrcAccessOrder,
106  index_t BBlockTransferSrcVectorDim,
107  index_t BBlockTransferSrcScalarPerVector,
108  index_t BBlockTransferDstScalarPerVector_K1,
109  bool BThreadTransferSrcResetCoordinateAfterRun,
110  bool BBlockLdsExtraN,
111  typename CThreadTransferSrcDstAccessOrder,
112  index_t CThreadTransferSrcDstVectorDim,
113  index_t CThreadTransferDstScalarPerVector>
114 struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
115 {
116  static constexpr auto I0 = Number<0>{};
117  static constexpr auto I1 = Number<1>{};
118  static constexpr auto I2 = Number<2>{};
119  static constexpr auto I3 = Number<3>{};
120  static constexpr auto I4 = Number<4>{};
121  static constexpr auto I5 = Number<5>{};
122  static constexpr auto I6 = Number<6>{};
123  static constexpr auto I7 = Number<7>{};
124 
125  // K1 should be Number<...>
126  static constexpr auto K1 = Number<K1Value>{};
127 
128  using ThisThreadBlock = ThisThreadBlock<BlockSize>;
129 
130  __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
131  {
132  constexpr auto max_lds_align = K1;
133 
134  // A matrix in LDS memory, dst of blockwise copy
135  constexpr auto a_k0_m_k1_block_desc = [&]() {
136  if constexpr(ABlockLdsExtraM)
137  {
141  }
142  else
143  {
144  return make_naive_tensor_descriptor_aligned(
145  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
146  }
147  }();
148 
149  // B matrix in LDS memory, dst of blockwise copy
150  constexpr auto b_k0_n_k1_block_desc = [&]() {
151  if constexpr(BBlockLdsExtraN)
152  {
156  }
157  else
158  {
159  return make_naive_tensor_descriptor_aligned(
160  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
161  }
162  }();
163 
164  // LDS allocation for A and B: be careful of alignment
165  constexpr auto a_block_space_size =
166  math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
167 
168  constexpr auto b_block_space_size =
169  math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
170 
171  return (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
172  }
173 
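As a quick worked example of the formula above (the numbers are hypothetical tuning parameters, not defaults of this file): with K0PerBlock = 4, MPerBlock = NPerBlock = 128, K1 = 8, no extra LDS padding, and a 2-byte FloatAB such as half, the function returns 16 KiB:

    // Illustrative arithmetic only; the tile sizes are assumptions.
    constexpr ck::index_t a_block_space_size = 4 * 128 * 8; // K0PerBlock * MPerBlock * K1, already a multiple of K1
    constexpr ck::index_t b_block_space_size = 4 * 128 * 8; // K0PerBlock * NPerBlock * K1
    constexpr ck::index_t lds_byte = (a_block_space_size + b_block_space_size) * 2; // * sizeof(FloatAB) = 16384 bytes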
174  // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
175  template <typename Block2CTileMap>
176  __host__ __device__ static constexpr bool
177  CheckValidity(const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc,
178  const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc,
179  const CMNGridDesc& c_m_n_grid_desc,
180  const Block2CTileMap& block_2_ctile_map)
181  {
182  static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
183  "wrong! K1 need to be known at compile-time");
184 
185  static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) &&
186  (NPerBlock % (NRepeat * NPerXDL)) == 0,
187  "Invalid tuning param!");
188 
189  const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2);
190  const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2);
191  const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
192  const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0);
193 
194  if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
195  K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) &&
196  K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) &&
197  K1 == b_b_k0_n_k1_grid_desc.GetLength(I3) &&
198  KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0)))
199  return false;
200 
201  if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
202  return false;
203 
204  if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc))
205  {
206  return false;
207  }
208 
209  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
210  return true;
211  }
212 
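A host-side caller is expected to reject kernel instances for which these checks fail rather than launching them; a minimal sketch, where the descriptor and map objects are placeholders supplied by the caller:

    // Hedged sketch: gate the launch on the validity checks above.
    if(!GridwiseGemm::CheckValidity(
           a_b_k0_m_k1_grid_desc, b_b_k0_n_k1_grid_desc, c_m_n_grid_desc, block_2_ctile_map))
    {
        return false; // this tile configuration does not support the given problem size
    }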
213  __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
214  {
215  const bool has_main_k0_block_loop = K0 > K0PerBlock;
216 
217  return has_main_k0_block_loop;
218  }
219 
220  __host__ __device__ static constexpr auto
221  MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
222  {
223  constexpr auto max_lds_align = K1;
224 
225  // A matrix in LDS memory, dst of blockwise copy
226  constexpr auto a_k0_m_k1_block_desc = [&]() {
227  if constexpr(ABlockLdsExtraM)
228  {
232  }
233  else
234  {
235  return make_naive_tensor_descriptor_aligned(
236  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
237  }
238  }();
239 
240  // B matrix in LDS memory, dst of blockwise copy
241  constexpr auto b_k0_n_k1_block_desc = [&]() {
242  if constexpr(BBlockLdsExtraN)
243  {
247  }
248  else
249  {
250  return make_naive_tensor_descriptor_aligned(
251  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
252  }
253  }();
254 
255  using BlockwiseGemm =
257  FloatAB,
258  FloatAcc,
259  decltype(a_k0_m_k1_block_desc),
260  decltype(b_k0_n_k1_block_desc),
261  MPerXDL,
262  NPerXDL,
263  MRepeat,
264  NRepeat,
265  K1>;
266 
267  return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_m_n_grid_desc);
268  }
269 
270  // return block_id to C matrix tile idx (m0, n0) mapping
271  __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
272  const CMNGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch)
273  {
275  c_m_n_grid_desc, 8, KBatch);
276  }
277 
278  using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{}));
279  using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1));
280 
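The two Make* helpers above are what a host-side caller uses to derive the kernel-facing C descriptor and the block-to-C-tile map from the plain M x N descriptor; a minimal sketch, assuming c_m_n_grid_desc and k_batch exist in the caller (M01 and N01 are ignored by this MakeCBlockClusterAdaptor, so 1 is passed):

    // Hedged sketch: host-side preparation of the transformed C descriptor and tile map.
    const auto c_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
        GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc);
    const auto cta_map = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc, 1, 1, k_batch);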
281  template <bool HasMainKBlockLoop>
282  __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
283  const FloatAB* __restrict__ p_b_grid,
284  FloatC* __restrict__ p_c_grid,
285  FloatAB* __restrict__ p_shared_block,
286  const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc,
287  const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc,
288  const CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
289  const AElementwiseOperation& a_element_op,
290  const BElementwiseOperation& b_element_op,
291  const CElementwiseOperation& c_element_op,
292  const CBlockClusterAdaptor& c_block_cluster_adaptor)
293  {
294  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
295  p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize());
296  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
297  p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize());
298  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
299  p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize());
300 
301  const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
302 
303  // divide block work by [M, N]
304  const auto block_work_idx =
305  c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
306 
307  if(!c_block_cluster_adaptor.ValidCTileIndex(
308  make_tuple(block_work_idx[I1], block_work_idx[I2]),
309  make_tuple(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetLength(I0),
310  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetLength(I1))))
311  {
312  return;
313  }
314 
315  const index_t k_batch_id = block_work_idx[I0];
316 
317  // HACK: this forces m/n_block_data_idx_on_grid into SGPR
318  const index_t m_block_data_idx_on_grid =
319  __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);
320 
321  const index_t n_block_data_idx_on_grid =
322  __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);
323 
324  // lds max alignment
325  constexpr auto max_lds_align = K1;
326 
327  // A matrix in LDS memory, dst of blockwise copy
328  constexpr auto a_k0_m_k1_block_desc = [&]() {
329  if constexpr(ABlockLdsExtraM)
330  {
334  }
335  else
336  {
337  return make_naive_tensor_descriptor_aligned(
338  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
339  }
340  }();
341 
342  constexpr auto a_b_k0_m_k1_block_desc = [&]() {
343  if constexpr(ABlockLdsExtraM)
344  {
349  K1,
350  I1));
351  }
352  else
353  {
354  return make_naive_tensor_descriptor_aligned(
355  make_tuple(I1, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
356  max_lds_align);
357  }
358  }();
359  // B matrix in LDS memory, dst of blockwise copy
360  constexpr auto b_k0_n_k1_block_desc = [&]() {
361  if constexpr(BBlockLdsExtraN)
362  {
366  }
367  else
368  {
369  return make_naive_tensor_descriptor_aligned(
370  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
371  }
372  }();
373 
374  constexpr auto b_b_k0_n_k1_block_desc = [&]() {
375  if constexpr(BBlockLdsExtraN)
376  {
381  K1,
382  I1));
383  }
384  else
385  {
386  return make_naive_tensor_descriptor_aligned(
387  make_tuple(I1, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
388  max_lds_align);
389  }
390  }();
391  // A matrix blockwise copy
392  auto a_blockwise_copy =
394  AElementwiseOperation,
398  ABlockTransferThreadClusterLengths_K0_M_K1,
399  ABlockTransferThreadClusterArrangeOrder,
400  FloatAB,
401  FloatAB,
402  decltype(a_b_k0_m_k1_grid_desc),
403  decltype(a_b_k0_m_k1_block_desc),
404  ABlockTransferSrcAccessOrder,
406  ABlockTransferSrcVectorDim,
407  3,
408  ABlockTransferSrcScalarPerVector,
409  ABlockTransferDstScalarPerVector_K1,
410  1,
411  1,
412  AThreadTransferSrcResetCoordinateAfterRun,
413  true>(
414  a_b_k0_m_k1_grid_desc,
415  make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0),
416  a_element_op,
417  a_b_k0_m_k1_block_desc,
418  make_multi_index(0, 0, 0, 0),
420 
421  // B matrix blockwise copy
422  auto b_blockwise_copy =
424  BElementwiseOperation,
428  BBlockTransferThreadClusterLengths_K0_N_K1,
429  BBlockTransferThreadClusterArrangeOrder,
430  FloatAB,
431  FloatAB,
432  decltype(b_b_k0_n_k1_grid_desc),
433  decltype(b_b_k0_n_k1_block_desc),
434  BBlockTransferSrcAccessOrder,
436  BBlockTransferSrcVectorDim,
437  3,
438  BBlockTransferSrcScalarPerVector,
439  BBlockTransferDstScalarPerVector_K1,
440  1,
441  1,
442  BThreadTransferSrcResetCoordinateAfterRun,
443  true>(
444  b_b_k0_n_k1_grid_desc,
445  make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0),
446  b_element_op,
447  b_b_k0_n_k1_block_desc,
448  make_multi_index(0, 0, 0, 0),
450 
451  // GEMM definition
452  // c_mtx += transpose(a_mtx) * b_mtx
453  // a_mtx[K0PerBlock, MPerBlock] is in LDS
454  // b_mtx[K0PerBlock, NPerBlock] is in LDS
455  // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
456  // register
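  // equivalently, for the k-batch index b owned by this workgroup:
  //   C(m, n) += sum over k0 in [0, K0PerBlock) and k1 in [0, K1) of A(b, k0, m, k1) * B(b, k0, n, k1)
  // i.e. the reduction dimension is split as K = KBatch * K0 * K1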
457  // sanity check
458 
459  auto blockwise_gemm =
461  FloatAB,
462  FloatAcc,
463  decltype(a_k0_m_k1_block_desc),
464  decltype(b_k0_n_k1_block_desc),
465  MPerXDL,
466  NPerXDL,
467  MRepeat,
468  NRepeat,
469  K1>{};
470 
471  auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
472 
473  // LDS allocation for A and B: be careful of alignment
474  constexpr auto a_block_space_size =
475  math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
476 
477  FloatAB* p_a_block = p_shared_block;
478  FloatAB* p_b_block = p_shared_block + a_block_space_size;
479 
480  constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
481  constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
482 
483  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
484  p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
485  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
486  p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());
487 
488  // preload data into LDS
489  {
490  a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
491  b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);
492 
493  a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
494  b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
495  }
496 
497  // Initialize C
498  c_thread_buf.Clear();
499 
500  // main body
501  if constexpr(HasMainKBlockLoop)
502  {
503  index_t k0_block_data_begin = 0;
504 
505  do
506  {
507  a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, a_block_slice_copy_step);
508  b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, b_block_slice_copy_step);
509 
510  a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
511 
512  block_sync_lds();
513 
514  b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);
515 
516  blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
517 
518  block_sync_lds();
519 
520  a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
521  b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
522 
523  k0_block_data_begin += K0PerBlock;
524  } while(k0_block_data_begin < (K0 - K0PerBlock));
525  }
526 
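  // Note: each iteration above prefetches the next K0 slice from global memory while the
  // xdlops GEMM consumes the slice already resident in LDS; the final slice is therefore
  // still in LDS when the loop exits and is handled by the tail below.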
527  // tail
528  {
529  block_sync_lds();
530 
531  blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
532  }
533 
534  // output: register to global memory
535  {
536  constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
537  blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
538 
539  constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
540  constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
541  constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
542  constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
543  constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
544  constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
545  constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
546  constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);
547 
548  constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
549  make_naive_tensor_descriptor_packed(make_tuple(
550  Number<M0>{}, Number<N0>{}, I1, I1, Number<M2>{}, I1, Number<M4>{}, I1));
551 
552  // calculate origin of thread output tensor on global memory
553  // blockwise GEMM c matrix starting index
554  const auto c_thread_mtx_on_block =
555  blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
556 
557  const index_t m_thread_data_on_grid =
558  m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
559 
560  const index_t n_thread_data_on_grid =
561  n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
562 
563  const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
564  make_single_stage_tensor_adaptor(
565  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
566  make_tuple(Sequence<0, 1, 2, 3, 4>{}),
567  make_tuple(Sequence<0>{}));
568 
569  const auto m_thread_data_on_grid_idx =
570  m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
571  make_multi_index(m_thread_data_on_grid));
572 
573  const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
574  make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
575  make_tuple(Sequence<0, 1, 2>{}),
576  make_tuple(Sequence<0>{}));
577 
578  const auto n_thread_data_on_grid_idx =
579  n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
580  make_multi_index(n_thread_data_on_grid));
581 
582  auto c_thread_copy =
584  FloatC,
585  decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
586  decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc),
587  CElementwiseOperation,
589  CThreadTransferSrcDstAccessOrder,
590  CThreadTransferSrcDstVectorDim,
591  CThreadTransferDstScalarPerVector,
592  CGlobalMemoryDataOperation,
593  1,
594  true>{
595 
596  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
597  make_multi_index(m_thread_data_on_grid_idx[I0],
598  n_thread_data_on_grid_idx[I0],
599  m_thread_data_on_grid_idx[I1],
600  n_thread_data_on_grid_idx[I1],
601  m_thread_data_on_grid_idx[I2],
602  m_thread_data_on_grid_idx[I3],
603  m_thread_data_on_grid_idx[I4],
604  n_thread_data_on_grid_idx[I2]),
605  c_element_op};
606 
607  c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
608  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
609  c_thread_buf,
610  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
611  c_grid_buf);
612  }
613  }
614 };
615 
616 } // namespace ck