Composable Kernel source listing: include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp

gridwise_gemm_dl_multiple_d.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
18 
19 namespace ck {
20 
21 template <index_t BlockSize,
22  typename FloatAB,
23  typename FloatAcc,
24  typename DsDataType,
25  typename FloatC,
26  typename AElementwiseOperation,
27  typename BElementwiseOperation,
28  typename CDEElementwiseOperation,
29  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
30  typename AGridDesc_K0_M_K1,
31  typename BGridDesc_K0_N_K1,
32  typename CGridDesc_M_N,
33  index_t MPerBlock,
34  index_t NPerBlock,
35  index_t K0PerBlock,
36  index_t K1Value,
37  index_t M1PerThreadM111,
38  index_t N1PerThreadN111,
39  index_t KPerThread,
40  typename M11N11ThreadClusterM110Xs,
41  typename M11N11ThreadClusterN110Xs,
42  typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
43  typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
44  typename ABlockTransferThreadClusterArrangeOrder,
45  typename ABlockTransferSrcAccessOrder,
46  typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
47  typename ABlockTransferSrcVectorTensorContiguousDimOrder,
48  typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
49  typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
50  typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
51  typename BBlockTransferThreadClusterArrangeOrder,
52  typename BBlockTransferSrcAccessOrder,
53  typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
54  typename BBlockTransferSrcVectorTensorContiguousDimOrder,
55  typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
56  typename CThreadTransferSrcDstAccessOrder,
57  index_t CThreadTransferSrcDstVectorDim,
58  index_t CThreadTransferDstScalarPerVector>
 59 struct GridwiseGemmDlMultipleD_km_kn_mn
 60 {
61  static constexpr index_t NumDTensor = DsDataType::Size();
62 
63  static constexpr auto I0 = Number<0>{};
64  static constexpr auto I1 = Number<1>{};
65  static constexpr auto I2 = Number<2>{};
66  static constexpr auto I3 = Number<3>{};
67 
68  // K1 should be Number<...>
69  static constexpr auto K1 = Number<K1Value>{};
70 
71  // ck::Tuple<const D0DataType*, const D1DataType*, ...>
72  static constexpr auto MakeDsGridPointer()
73  {
74  return generate_tuple(
75  [&](auto i) {
76  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
77 
78  return static_cast<const DDataType*>(nullptr);
79  },
 80  Number<NumDTensor>{});
 81  }
82 
83  __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
84  {
85  // TODO: change this. I think it needs multi-dimensional alignment
86  constexpr auto max_lds_align = K1;
87 
88  // TODO: check alignment
89  // A matrix in LDS memory, dst of blockwise copy
90  constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_aligned(
91  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
92 
93  // TODO: check alignment
94  // B matrix in LDS memory, dst of blockwise copy
95  constexpr auto b_block_desc_k_n = make_naive_tensor_descriptor_aligned(
96  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
97 
98  // TODO: check alignment
99  // LDS allocation for A and B: be careful of alignment
100  constexpr auto a_block_aligned_space_size =
101  math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align);
102 
103  constexpr auto b_block_aligned_space_size =
104  math::integer_least_multiple(b_block_desc_k_n.GetElementSpaceSize(), max_lds_align);
105 
106  return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
107  }
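// Worked example (illustrative tile sizes, not prescribed by this file): with
// K0PerBlock = 8, MPerBlock = NPerBlock = 128, K1 = 2 and FloatAB = half (2 bytes),
// each A and B block tile holds 8 * 128 * 2 = 2048 elements, so the double-buffered
// allocation computed above is 2 * (2048 + 2048) * 2 bytes = 16 KiB of LDS,
// assuming the K1-aligned descriptors introduce no padding beyond the K1 alignment.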
108 
109  __host__ __device__ static constexpr bool
110  CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
111  const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
112  const CGridDesc_M_N& c_grid_desc_m_n)
113  {
114  constexpr long_index_t TwoGB = (long_index_t{1} << 31);
115 
116  if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&
117  b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&
118  c_grid_desc_m_n.GetElementSpaceSize() * sizeof(FloatC) <= TwoGB))
119  {
120  return false;
121  }
122 
123  const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
124  const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
125  const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
126 
127  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
128 
129  return (M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
130  K0 == b_grid_desc_k0_n_k1.GetLength(I0) &&
131  K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
132  K1 == b_grid_desc_k0_n_k1.GetLength(I2)) &&
133  (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0);
134  }
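// Note: CheckValidity is meant to be evaluated on the host before launch (typically
// from the owning device operation's argument check). It rejects problems whose A, B
// or C element space exceeds the 2 GB (1 << 31 bytes) limit, whose descriptor extents
// are inconsistent, or whose M, N, K0 are not multiples of the block tile sizes,
// since the main loop below assumes exact tiling.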
135 
136  __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
137  {
138  const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
139 
140  return grid_size;
141  }
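// Example (illustrative values): for M = 1024, N = 768 with MPerBlock = NPerBlock = 128,
// the grid is (1024 / 128) * (768 / 128) = 8 * 6 = 48 workgroups, one per
// MPerBlock x NPerBlock tile of C.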
142 
143  __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0)
144  {
145  const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1;
146 
147  return has_main_k_block_loop;
148  }
149 
150  __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
151  {
152  const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0;
153 
154  return has_double_tail_k_block_loop;
155  }
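// Example of how the two flags classify the K loop (illustrative K0PerBlock = 8,
// so the main do-while body consumes 2 * K0PerBlock = 16 K0-steps per iteration):
//   K0 = 8  -> (8 + 8) / 16 = 1, not > 1, so no main loop; 8 / 8 = 1 is odd  -> single tail
//   K0 = 16 -> (16 + 8) / 16 = 1, not > 1, so no main loop; 16 / 8 = 2 is even -> double tail
//   K0 = 32 -> (32 + 8) / 16 = 2 > 1, so main loop runs;   32 / 8 = 4 is even -> double tail
// The tail branches below handle the final one or two K0PerBlock chunks.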
156 
157  __host__ __device__ static constexpr auto
158  MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1)
159  {
160  const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
161  const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
162 
163  const auto M1 = Number<MPerBlock>{};
164  const auto M0 = M / M1;
165 
166  const auto a_grid_desc_k0_m0_m1_k1 =
167  transform_tensor_descriptor(a_grid_desc_k0_m_k1,
168  make_tuple(make_pass_through_transform(K0),
169  make_unmerge_transform(make_tuple(M0, M1)),
170  make_pass_through_transform(K1)),
171  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
172  make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
173 
174  return a_grid_desc_k0_m0_m1_k1;
175  }
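// Note: the unmerge above only splits M into (M0, M1) with M1 = MPerBlock, so a
// workgroup's A tile can be addressed directly by its block index im0 along M0,
// while K0 and K1 are left unchanged. MakeBGridDescriptor_K0_N0_N1_K1 below applies
// the same split to N.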
176 
177  __host__ __device__ static constexpr auto
178  MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1)
179  {
180  const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0);
181  const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
182 
183  const auto N1 = Number<NPerBlock>{};
184  const auto N0 = N / N1;
185 
186  const auto b_grid_desc_k0_n0_n1_k1 =
187  transform_tensor_descriptor(b_grid_desc_k0_n_k1,
188  make_tuple(make_pass_through_transform(K0),
189  make_unmerge_transform(make_tuple(N0, N1)),
190  make_pass_through_transform(K1)),
191  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
192  make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
193 
194  return b_grid_desc_k0_n0_n1_k1;
195  }
196 
197  // E (output) descriptor, used as the destination of the threadwise copy
198  template <typename CGridDesc_M_N_>
199  __host__ __device__ static constexpr auto
200  MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N_& c_grid_desc_m_n)
201  {
202  const auto M = c_grid_desc_m_n.GetLength(I0);
203  const auto N = c_grid_desc_m_n.GetLength(I1);
204 
205  constexpr auto M1 = Number<MPerBlock>{};
206  constexpr auto N1 = Number<NPerBlock>{};
207 
208  const auto M0 = M / M1;
209  const auto N0 = N / N1;
210 
211  constexpr auto M11 =
212  Number<container_reduce(M11N11ThreadClusterM110Xs{}, math::multiplies{}, I1) *
213  M1PerThreadM111>{};
214  constexpr auto N11 =
215  Number<container_reduce(M11N11ThreadClusterN110Xs{}, math::multiplies{}, I1) *
216  N1PerThreadN111>{};
217 
218  constexpr auto M10 = M1 / M11;
219  constexpr auto N10 = N1 / N11;
220 
221  const auto c_grid_desc_m0_m10_m11_n0_n10_n11 = transform_tensor_descriptor(
222  c_grid_desc_m_n,
223  make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
224  make_unmerge_transform(make_tuple(N0, N10, N11))),
225  make_tuple(Sequence<0>{}, Sequence<1>{}),
226  make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
227 
228  return c_grid_desc_m0_m10_m11_n0_n10_n11;
229  }
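// Worked example of the decomposition above (illustrative parameters): with
// MPerBlock = 128, M11N11ThreadClusterM110Xs = Sequence<8, 2> and M1PerThreadM111 = 4,
// M11 = (8 * 2) * 4 = 64 and M10 = 128 / 64 = 2, i.e. each 128-row block tile is
// covered by 2 repetitions of a 64-row thread-cluster tile; N is decomposed the same way.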
230 
231  // Ds descriptors, used as the source of the threadwise copy
232  template <typename DsGridDesc_M_N>
233  __host__ __device__ static constexpr auto
234  MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(const DsGridDesc_M_N& ds_grid_desc_m_n)
235  {
236  return generate_tuple(
237  [&](auto i) { return MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(ds_grid_desc_m_n[i]); },
238  Number<NumDTensor>{});
239  }
240  // return block_id to C matrix tile idx (m0, n0) mapping
241  __host__ __device__ static constexpr auto
242  MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
243  {
244  return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
245  c_grid_desc_m_n);
246  }
247 
248  using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
249  using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
250  using CGridDesc_M0_M10_M11_N0_N10_N11 =
251  decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
252 
253  using DsGridPointer = decltype(MakeDsGridPointer());
254 
255  template <typename DsGridDesc_M0_M10_M11_N0_N10_N11,
256  bool HasMainKBlockLoop,
257  bool HasDoubleTailKBlockLoop,
258  typename Block2CTileMap>
259  __device__ static void
260  Run(const FloatAB* __restrict__ p_a_grid,
261  const FloatAB* __restrict__ p_b_grid,
262  DsGridPointer p_ds_grid,
263  FloatC* __restrict__ p_c_grid,
264  void* __restrict__ p_shared_block,
265  const AElementwiseOperation&,
266  const BElementwiseOperation&,
267  const CDEElementwiseOperation& cde_element_op,
268  const AGridDesc_K0_M0_M1_K1& a_grid_desc_k0_m0_m1_k1,
269  const BGridDesc_K0_N0_N1_K1& b_grid_desc_k0_n0_n1_k1,
270  const DsGridDesc_M0_M10_M11_N0_N10_N11& ds_grid_desc_m0_m10_m11_n0_n10_n11,
271  const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11,
272  const Block2CTileMap& block_2_ctile_map,
273  integral_constant<bool, HasMainKBlockLoop>,
274  integral_constant<bool, HasDoubleTailKBlockLoop>)
275  {
276  const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
277  p_a_grid, a_grid_desc_k0_m0_m1_k1.GetElementSpaceSize());
278  const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
279  p_b_grid, b_grid_desc_k0_n0_n1_k1.GetElementSpaceSize());
280  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
281  p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());
282 
283  // divide block work by [M, N]
284  const auto c_m0_n0_block_cluster_idx =
285  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
286 
287  // HACK: this forces index data into SGPR
288  const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
289  const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);
290 
291  if(!block_2_ctile_map.ValidCTileIndex(
292  make_tuple(im0, in0),
293  make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0),
294  c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3))))
295  {
296  return;
297  }
298 
299  // TODO: change this. I think it needs multi-dimensional alignment
300  constexpr auto max_lds_align = K1;
301 
302  // TODO: check alignment
303  // A matrix in LDS memory, dst of blockwise copy
304  // be careful of LDS alignment
305  constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
306  make_tuple(Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);
307 
308  // TODO: check alignment
309  // B matrix in LDS memory, dst of blockwise copy
310  // be careful of LDS alignment
311  constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor_aligned(
312  make_tuple(Number<K0PerBlock>{}, I1, Number<NPerBlock>{}, K1), max_lds_align);
313 
314  // TODO: check alignment
315  // A matrix in LDS memory, for blockwise GEMM
316  constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
317  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
318 
319  // TODO: check alignment
320  // B matrix in LDS memory, for blockwise GEMM
321  constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
322  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
323 
324  static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() ==
325  a_k0_m_k1_block_desc.GetElementSpaceSize() &&
326  b_block_desc_k0_n0_n1_k1.GetElementSpaceSize() ==
327  b_k0_n_k1_block_desc.GetElementSpaceSize() &&
328  "wrong!");
329 
330  // A matrix blockwise copy
331  auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
332  BlockSize,
333  InMemoryDataOperationEnum::Set,
334  Sequence<K0PerBlock, 1, MPerBlock, K1.value>,
335  ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
336  ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
337  ABlockTransferThreadClusterArrangeOrder,
338  FloatAB,
339  FloatAB,
340  remove_reference_t<decltype(a_grid_desc_k0_m0_m1_k1)>,
341  decltype(a_block_desc_k0_m0_m1_k1),
342  ABlockTransferSrcAccessOrder,
343  Sequence<0, 1, 2, 3>,
344  ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths
345  ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, // DstVectorTensorLengths
346  ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
347  Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
348  false,
349  true>(a_grid_desc_k0_m0_m1_k1,
350  make_multi_index(0, im0, 0, 0),
351  a_block_desc_k0_m0_m1_k1,
352  make_multi_index(0, 0, 0, 0));
353 
354  // B matrix blockwise copy
355  auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
356  BlockSize,
357  InMemoryDataOperationEnum::Set,
358  Sequence<K0PerBlock, 1, NPerBlock, K1.value>,
359  BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
360  BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
361  BBlockTransferThreadClusterArrangeOrder,
362  FloatAB,
363  FloatAB,
364  remove_reference_t<decltype(b_grid_desc_k0_n0_n1_k1)>,
365  decltype(b_block_desc_k0_n0_n1_k1),
366  BBlockTransferSrcAccessOrder,
367  Sequence<0, 1, 2, 3>,
368  BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths
369  BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, // DstVectorTensorLengths
370  BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
371  Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
372  false,
373  true>(b_grid_desc_k0_n0_n1_k1,
374  make_multi_index(0, in0, 0, 0),
375  b_block_desc_k0_n0_n1_k1,
376  make_multi_index(0, 0, 0, 0));
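// Note: each blockwise copy owns a [K0PerBlock, 1, MPerBlock (or NPerBlock), K1] tile;
// the source origin (0, im0, 0, 0) / (0, in0, 0, 0) pins the workgroup to its own
// M0 / N0 block, and MoveSrcSliceWindow advances the window by K0PerBlock along K0
// in the main loop below.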
377 
378  // GEMM definition
379  // c_mtx += transpose(a_mtx) * b_mtx
380  // a_mtx[K0PerBlock, MPerBlock] is in LDS
381  // b_mtx[K0PerBlock, NPerBlock] is in LDS
382  // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
383  // register
384  const auto blockwise_gemm =
386  BlockSize,
387  FloatAB,
388  FloatAB,
389  FloatAcc,
390  decltype(a_k0_m_k1_block_desc),
391  decltype(b_k0_n_k1_block_desc),
392  M1PerThreadM111,
393  N1PerThreadN111,
394  KPerThread,
395  M11N11ThreadClusterM110Xs,
396  M11N11ThreadClusterN110Xs,
397  M1PerThreadM111,
398  N1PerThreadN111>{};
399 
400  constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
401  decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
402 
403  constexpr auto c_thread_desc_m10_m11_n10_n11 = make_naive_tensor_descriptor_packed(
404  sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
405 
406  // LDS allocation for A and B: be careful of alignment
407  constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
408  a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);
409 
410  constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
411  b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);
412 
413  FloatAB* p_a_block_double = static_cast<FloatAB*>(p_shared_block);
414  FloatAB* p_b_block_double =
415  static_cast<FloatAB*>(p_shared_block) + 2 * a_block_aligned_space_size;
416 
417  // register allocation for output
418  auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
419  c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize());
420 
421  // Initialize C
422  c_thread_buf.Clear();
423 
424  constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
425  constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
426 
427  auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
428  p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
429  auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
430  p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
431 
432  auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
433  p_a_block_double + a_block_aligned_space_size,
434  a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
435  auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
436  p_b_block_double + b_block_aligned_space_size,
437  b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
438 
439  // LDS double buffer: preload data into LDS
440  {
441  a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
442  b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
443 
444  a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
445  b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
446  }
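// Note on the schedule that follows: the two LDS buffers are used in ping-pong
// fashion. Each main-loop iteration issues the global loads for the next K chunk,
// synchronizes, runs the blockwise GEMM on the chunk already resident in one buffer,
// and writes the freshly loaded chunk into the other buffer, so the global loads for
// the next chunk can overlap with the math on the current one.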
447 
448  if constexpr(HasMainKBlockLoop)
449  {
450  const auto K0 = a_grid_desc_k0_m0_m1_k1.GetLength(I0);
451 
452  index_t k_block_data_begin = 0;
453 
454  // LDS double buffer: main body
455  // use Do-While loop instead of For loop to simplify control flow
456  do
457  {
458  // even iteration
459  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
460  a_block_slice_copy_step);
461  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
462  b_block_slice_copy_step);
463 
464  // LDS double buffer: load next data from device mem
465  a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
466  b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
467 
468  block_sync_lds();
469 
470  // LDS double buffer: GEMM on current data
471  blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11,
472  a_block_even_buf,
473  b_block_even_buf,
474  c_thread_buf);
475 
476  // LDS double buffer: store next data to LDS
477  a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
478  b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);
479 
480  // odd iteration
481  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
482  a_block_slice_copy_step);
483  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
484  b_block_slice_copy_step);
485 
486  // LDS double buffer: load next data from device mem
487  a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
488  b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
489 
490  block_sync_lds();
491 
492  // LDS double buffer: GEMM on current data
493  blockwise_gemm.Run(
494  c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
495 
496  // LDS double buffer: store next data to LDS
497  a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
498  b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
499 
500  k_block_data_begin += 2 * K0PerBlock;
501  } while(k_block_data_begin < K0 - 2 * K0PerBlock);
502  }
503 
504  // LDS double buffer: tail
505  if constexpr(HasDoubleTailKBlockLoop) // if 2 iterations are left
506  {
507  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step);
508  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step);
509 
510  block_sync_lds();
511 
512  // LDS double buffer: load last data from device mem
513  a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
514  b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
515 
516  // LDS double buffer: GEMM on 2nd-last data
517  blockwise_gemm.Run(
518  c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
519 
520  // LDS double buffer: store last data to LDS
521  a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
522  b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);
523 
524  block_sync_lds();
525 
526  // LDS double buffer: GEMM on last data
527  blockwise_gemm.Run(
528  c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
529  }
530  else // if 1 iteration is left
531  {
532  __syncthreads();
533 
534  // LDS double buffer: GEMM on last data
535  blockwise_gemm.Run(
536  c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
537  }
538 
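// Epilogue overview: each thread walks its (m10, m11, n10) sub-tiles, loads the
// matching elements of every D tensor into registers, applies cde_element_op to the
// accumulated C values together with those D values, and finally streams the result
// to global memory with a single threadwise transfer.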
539  // output: register to global memory
540  {
541  constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 =
542  make_naive_tensor_descriptor_packed(
543  make_tuple(I1,
544  Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
545  Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
546  I1,
547  Number<c_m10_m11_n10_n11_thread_tensor_lengths[I2]>{},
548  Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{}));
549 
550  const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
551  blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
552  get_thread_local_1d_id());
553 
554  const auto ds_grid_buf = generate_tuple(
555  [&](auto i) {
556  return make_dynamic_buffer<AddressSpaceEnum::Global>(
557  p_ds_grid[i], ds_grid_desc_m0_m10_m11_n0_n10_n11[i].GetElementSpaceSize());
558  },
559  Number<NumDTensor>{});
560 
561  auto ds_thread_buf = generate_tuple(
562  [&](auto i) {
563  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
564 
565  return StaticBuffer<AddressSpaceEnum::Vgpr,
566  DDataType,
567  c_m10_m11_n10_n11_thread_tensor_lengths[I3],
568  true>{};
569  },
570  Number<NumDTensor>{});
571 
572  auto ds_threadwise_copy = generate_tuple(
573  [&](auto i) {
574  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
575 
576  return ThreadwiseTensorSliceTransfer_v2<
577  DDataType,
578  DDataType,
579  decltype(ds_grid_desc_m0_m10_m11_n0_n10_n11[i]),
580  decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
581  Sequence<I1,
582  I1,
583  I1,
584  I1,
585  I1,
586  c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
587  CThreadTransferSrcDstAccessOrder,
588  CThreadTransferSrcDstVectorDim,
589  CThreadTransferDstScalarPerVector,
590  1,
591  false>(ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
592  make_multi_index(im0,
593  c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
594  c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
595  in0,
596  c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
597  c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]));
598  },
599  Number<NumDTensor>{});
600 
601  static_for<0, c_m10_m11_n10_n11_thread_tensor_lengths[I0], 1>{}([&](auto m10) {
602  static_for<0, c_m10_m11_n10_n11_thread_tensor_lengths[I1], 1>{}([&](auto m11) {
603  static_for<0, c_m10_m11_n10_n11_thread_tensor_lengths[I2], 1>{}([&](auto n10) {
604  // load d matrix data
605  static_for<0, NumDTensor, 1>{}([&](auto i) {
606  ds_threadwise_copy(i).Run(ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
607  ds_grid_buf[i],
608  c_thread_desc_m0_m10_m11_n0_n10_n11,
609  make_tuple(I0, I0, I0, I0, I0, I0),
610  ds_thread_buf(i));
611  });
612  // apply the CDE elementwise operation for each element along N11
613  static_for<0, c_m10_m11_n10_n11_thread_tensor_lengths[I3], 1>{}(
614  [&](auto i) {
615  // get reference to src data
616  const auto src_data_refs = generate_tie(
617  // return type should be lvalue
618  [&](auto iSrc) -> const auto& {
619  return ds_thread_buf[iSrc][i];
620  },
621  Number<NumDTensor>{});
622 
623  // get reference to dst data
624  constexpr index_t c_offset =
625  c_thread_desc_m0_m10_m11_n0_n10_n11.CalculateOffset(
626  make_tuple(0, m10, m11, 0, n10, i));
627  auto dst_data_refs = generate_tie(
628  // return type should be lvalue
629  [&](auto) -> auto& { return c_thread_buf(Number<c_offset>{}); },
630  Number<2>{});
631 
632  unpack2(cde_element_op, dst_data_refs, src_data_refs);
633  });
634 
635  static_for<0, NumDTensor, 1>{}([&](auto i) {
636  ds_threadwise_copy(i).MoveSrcSliceWindow(
637  ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
638  make_multi_index(0, 0, 0, 0, 1, 0));
639  });
640  });
641  static_for<0, NumDTensor, 1>{}([&](auto i) {
642  ds_threadwise_copy(i).MoveSrcSliceWindow(
643  ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
644  make_multi_index(
645  0, 0, 1, 0, -c_m10_m11_n10_n11_thread_tensor_lengths[I2], 0));
646  });
647  });
648  static_for<0, NumDTensor, 1>{}([&](auto i) {
649  ds_threadwise_copy(i).MoveSrcSliceWindow(
650  ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
651  make_multi_index(
652  0, 1, -c_m10_m11_n10_n11_thread_tensor_lengths[I1], 0, 0, 0));
653  });
654  });
655 
656  ThreadwiseTensorSliceTransfer_v1r3<
657  FloatAcc,
658  FloatC,
659  decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
660  decltype(c_grid_desc_m0_m10_m11_n0_n10_n11),
661  tensor_operation::element_wise::PassThrough,
662  Sequence<1,
663  c_m10_m11_n10_n11_thread_tensor_lengths[I0],
664  c_m10_m11_n10_n11_thread_tensor_lengths[I1],
665  1,
666  c_m10_m11_n10_n11_thread_tensor_lengths[I2],
667  c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
668  CThreadTransferSrcDstAccessOrder,
669  CThreadTransferSrcDstVectorDim,
670  CThreadTransferDstScalarPerVector,
671  CGlobalMemoryDataOperation,
672  1,
673  true>{c_grid_desc_m0_m10_m11_n0_n10_n11,
674  make_multi_index(im0,
675  c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
676  c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
677  in0,
678  c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
679  c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]),
680  tensor_operation::element_wise::PassThrough{}}
681  .Run(c_thread_desc_m0_m10_m11_n0_n10_n11,
682  make_tuple(I0, I0, I0, I0, I0, I0),
683  c_thread_buf,
684  c_grid_desc_m0_m10_m11_n0_n10_n11,
685  c_grid_buf);
686  }
687  }
688 };
689 
690 } // namespace ck
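This header only defines the device-side struct; the LDS size and the two K-loop flags are expected to be supplied by the launch side. The sketch below is illustrative only: the kernel name, template parameter spelling, and argument names are assumptions rather than part of this header. It shows how GetSharedMemoryNumberOfByte, the host-computed HasMainKBlockLoop / HasDoubleTailKBlockLoop flags (from CalculateHasMainKBlockLoop / CalculateHasDoubleTailKBlockLoop), and Run are meant to fit together, assuming the descriptors passed in are the already-transformed _k0_m0_m1_k1 / _m0_m10_m11_n0_n10_n11 types expected by Run.

// Illustrative launch-side sketch; every name here except the GridwiseGemm member
// functions is a placeholder, not taken from this file.
template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename DsPointer,
          typename AOp,
          typename BOp,
          typename CDEOp,
          typename AGridDesc,
          typename BGridDesc,
          typename DsGridDesc,
          typename CGridDesc,
          typename Block2CTileMap,
          bool HasMainKBlockLoop,
          bool HasDoubleTailKBlockLoop>
__global__ void kernel_gemm_dl_multiple_d(const FloatAB* p_a_grid,
                                          const FloatAB* p_b_grid,
                                          DsPointer p_ds_grid,
                                          FloatC* p_c_grid,
                                          const AOp a_op,
                                          const BOp b_op,
                                          const CDEOp cde_op,
                                          const AGridDesc a_grid_desc_k0_m0_m1_k1,
                                          const BGridDesc b_grid_desc_k0_n0_n1_k1,
                                          const DsGridDesc ds_grid_desc_m0_m10_m11_n0_n10_n11,
                                          const CGridDesc c_grid_desc_m0_m10_m11_n0_n10_n11,
                                          const Block2CTileMap block_2_ctile_map)
{
    // LDS is sized by the gridwise struct itself (double-buffered A and B tiles).
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    GridwiseGemm::Run(p_a_grid,
                      p_b_grid,
                      p_ds_grid,
                      p_c_grid,
                      p_shared,
                      a_op,
                      b_op,
                      cde_op,
                      a_grid_desc_k0_m0_m1_k1,
                      b_grid_desc_k0_n0_n1_k1,
                      ds_grid_desc_m0_m10_m11_n0_n10_n11,
                      c_grid_desc_m0_m10_m11_n0_n10_n11,
                      block_2_ctile_map,
                      ck::integral_constant<bool, HasMainKBlockLoop>{},
                      ck::integral_constant<bool, HasDoubleTailKBlockLoop>{});
}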