Composable Kernel: include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

// NOTE: the include list was elided in the rendered page; the headers below
// are a plausible reconstruction based on the entities this file uses.
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/warp/wmma_gemm.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"

#define CK_MNK_LOOP

namespace ck {

#ifdef __gfx12__
template <index_t BlockSize,
          typename FloatA,
          typename FloatB,
          typename FloatAcc,
          typename ABlockDesc,
          typename BBlockDesc,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerWMMA,
          index_t NPerWMMA,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack,
          bool AEnableLds = true,
          bool BEnableLds = true,
          bool TransposeC = false>
/* Option: read from LDS; a big buffer holds the data required by all threads.
 * Source
 *   A: K0PerBlock x MPerBlock x K1
 *   B: K0PerBlock x NPerBlock x K1
 * Destination
 *   C, non-transposed
 *   thread level: MRepeat x NRepeat x MAccVgprs
 *   block level:  MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
 *   KPACK == WMMA_K = 16
 *
 * Option: read from VMEM; a small buffer holds only each thread's own data (skip LDS).
 * Source:
 *   A (if skipping LDS): MRepeat x KPack
 *   B (if skipping LDS): NRepeat x KPack
 * Destination
 *   C, non-transposed
 *   block level:  MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
 */
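// Illustrative sketch (not part of the original source): a hypothetical
// instantiation for a 128x128 macro tile on a 128-thread block, assuming
// half-precision inputs with float accumulation and compile-time block
// descriptors ABlockDesc/BBlockDesc defined elsewhere:
//
//   using BlockGemm = BlockwiseGemmWMMA<128,            // BlockSize
//                                       half_t, half_t, // FloatA, FloatB
//                                       float,          // FloatAcc
//                                       ABlockDesc, BBlockDesc,
//                                       128, 128, 64,   // MPerBlock, NPerBlock, KPerBlock
//                                       16, 16,         // MPerWMMA, NPerWMMA
//                                       4, 4,           // MRepeat, NRepeat
//                                       16>;            // KPack == WmmaK
//
// With WaveSize = 32 this yields MWaves = 128 / (4 * 16) = 2 and
// NWaves = 128 / (4 * 16) = 2, so MWaves * NWaves * WaveSize = 128 matches
// BlockSize, satisfying the static_assert in the constructor below.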
struct BlockwiseGemmWMMA
{
    static constexpr auto I0    = Number<0>{};
    static constexpr auto I1    = Number<1>{};
    static constexpr auto I2    = Number<2>{};
    static constexpr auto I3    = Number<3>{};
    static constexpr auto I4    = Number<4>{};
    static constexpr auto I5    = Number<5>{};
    static constexpr auto WmmaK = Number<16>{};

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    // WaveSize is hard-coded because the current HIP runtime (5.4.0-10984)
    // does not report the correct value.
    static constexpr index_t WaveSize = 32;

    // When LDS is used, each row (16 consecutive lanes) reads all of its data
    // from the source buffer. When LDS is skipped, each row reads half of the
    // data and the halves are exchanged via permutation.
    static constexpr index_t A_KRow = 2;
    static constexpr index_t B_KRow = 2;

    static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
    static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);

    static constexpr auto wmma_gemm =
        WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};

    static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
    static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);

    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
                              FloatAcc,
                              MRepeat * NRepeat,
                              wmma_gemm.GetRegSizePerWmma(),
                              true>
        c_thread_buf_;

    __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }

    __device__ static auto GetWaveIdx()
    {
        const index_t thread_id = ThisThreadBlock::GetThreadId();

        constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
            make_tuple(Sequence<0, 1, 2>{}),
            make_tuple(Sequence<0>{}));

        return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
    }
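    // Worked example (illustrative, assuming MWaves = NWaves = 2 and
    // WaveSize = 32): the merge transform maps the bottom index
    // (waveId_m, waveId_n, lane) to
    // thread_id = (waveId_m * NWaves + waveId_n) * WaveSize + lane, and
    // CalculateBottomIndex inverts it. For thread_id = 70:
    // 70 = (1 * 2 + 0) * 32 + 6, so waveId_m = 1, waveId_n = 0, lane = 6.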

    // Default: block buffer in LDS, thread-level offset enabled.
    __device__ static auto CalculateAThreadOriginDataIndex()
    {
        if constexpr(AEnableLds)
        {
            const auto wave_idx   = GetWaveIdx();
            const auto waveId_m   = wave_idx[I0];
            const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();

            //                |KRepeat |MRepeat |MWave    |KRow                      |MLane      |KPack
            return make_tuple(0,       0,       waveId_m, wmma_gemm.GetSubGroupId(), WMMA_a_idx, 0);
        }
        else
        {
            return make_tuple(0, 0, 0, 0, 0, 0);
        }
    }

    __device__ static auto CalculateBThreadOriginDataIndex()
    {
        if constexpr(BEnableLds)
        {
            const auto wave_idx   = GetWaveIdx();
            const auto waveId_n   = wave_idx[I1];
            const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();

            //                |KRepeat |NRepeat |NWave    |KRow                      |NLane      |KPack
            return make_tuple(0,       0,       waveId_n, wmma_gemm.GetSubGroupId(), WMMA_b_idx, 0);
        }
        else
        {
            return make_tuple(0, 0, 0, 0, 0, 0);
        }
    }

    template <index_t m0, index_t n0>
    __device__ static auto CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>)
    {
        const auto wave_idx = GetWaveIdx();

        const auto waveId_m = wave_idx[I0];
        const auto waveId_n = wave_idx[I1];

        const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk();

        constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0, 1, 2>{}));

        constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWMMA))),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0, 1, 2>{}));

        const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex(
            make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
        const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex(
            make_tuple(n0, waveId_n, blk_idx[I1]))[I0];

        return make_tuple(c_thread_m, c_thread_n);
    }
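    // Worked example (illustrative, assuming MWaves = 2, MPerWMMA = 16): the
    // unmerge adaptor computes
    // c_thread_m = m0 * (MWaves * MPerWMMA) + waveId_m * MPerWMMA + blk_idx[I0],
    // e.g. m0 = 1, waveId_m = 1, blk_idx[I0] = 3 gives 1 * 32 + 1 * 16 + 3 = 51.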

    template <index_t m0, index_t n0>
    __device__ static auto CalculateCThreadOriginDataIndex7D(Number<m0>, Number<n0>)
    {
        const auto wave_idx = GetWaveIdx();

        const auto waveId_m = wave_idx[I0];
        const auto waveId_n = wave_idx[I1];

        const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk3D();

        return make_tuple(
            Number<m0>{}, waveId_m, blk_idx[I0], Number<n0>{}, waveId_n, blk_idx[I1], blk_idx[I2]);
    }

    using Tuple6 = decltype(CalculateAThreadOriginDataIndex());
    __host__ __device__ BlockwiseGemmWMMA(Tuple6 a_origin = CalculateAThreadOriginDataIndex(),
                                          Tuple6 b_origin = CalculateBThreadOriginDataIndex())
        : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
    {
        static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
                      "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");

        static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
                          NPerBlock % (NPerWMMA * NRepeat) == 0,
                      "wrong!");
    }

    // transposed WMMA output: C' = B' * A'
    __host__ __device__ static constexpr auto
    GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
    {
        constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
            wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();

        constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];

        return make_naive_tensor_descriptor_packed(
            //        |MRepeat           |MWave |MSubGroup |NRepeat           |NWave
            //        |NThreadPerSubGroup |MAccVgprs
            make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, NAccVgprs));
    }

    // Thread-level register descriptor. Vector write.
    __host__ __device__ static constexpr auto
    GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
    {
        constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
            wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();

        constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
        constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3];
        return make_naive_tensor_descriptor(
            //        |MRepeat           |MWave |MSubGroup |NRepeat           |NWave
            //        |NThreadPerSubGroup |MAccVgprs
            make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, MAccVgprs),
            make_tuple(Number<NRepeat>{} * MAccVgprs * AccStride,
                       Number<NRepeat>{} * MAccVgprs * AccStride,
                       Number<NRepeat>{} * MAccVgprs * AccStride,
                       MAccVgprs * AccStride,
                       MAccVgprs * AccStride,
                       MAccVgprs * AccStride,
                       AccStride));
    }

    template <typename CGridDesc_M_N>
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
        const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma =
            transform_tensor_descriptor(
                c_grid_desc_m_n,
                make_tuple(
                    make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)),
                    make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));

        return wmma_gemm
            .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
                c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
    }

    // transposed WMMA output: C' = B' * A'
    __host__ __device__ static constexpr auto
    GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
    {
        constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
                                                           Number<MWaves>{},
                                                           Number<MPerWMMA>{},
                                                           Number<NRepeat>{},
                                                           Number<NWaves>{},
                                                           Number<NPerWMMA>{}));

        return wmma_gemm
            .MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(
                c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
    }

    // Provides the dimension sizes.
    __host__ __device__ static constexpr auto
    GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
    {
        constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
                                                           Number<MWaves>{},
                                                           Number<MPerWMMA>{},
                                                           Number<NRepeat>{},
                                                           Number<NWaves>{},
                                                           Number<NPerWMMA>{}));

        return wmma_gemm
            .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
                c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
    }

    // Describes how data is laid out in the thread-copy source buffer.
    // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
    static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1;
    static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1;

    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
    __device__ void Run(const ABlockBuffer& a_block_buf,
                        const BBlockBuffer& b_block_buf,
                        CThreadBuffer& c_thread_buf) const
    {
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
            b_thread_desc_.GetElementSpaceSize());

        static_assert(KPack % (A_K1 * A_KRow) == 0, "");
        static_assert(KPack % (B_K1 * B_KRow) == 0, "");

        // simple heuristic to choose the loop order
        if constexpr(MRepeat < NRepeat)
        {
            static_for<0, KPerBlock / KPack, 1>{}(
                [&](auto k) { // k = 0, 1, 2, ... instead of k = 0, KPack, ...
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        // read A
                        a_thread_copy_.Run(
                            a_block_desc_k0_m0_m1_m2_k1,
                            make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
                            a_block_buf,
                            a_thread_desc_,
                            make_tuple(I0, m0, I0, I0, I0, I0),
                            a_thread_buf);

                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            // read B
                            b_thread_copy_.Run(
                                b_block_desc_k0_n0_n1_n2_k1,
                                make_tuple(Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
                                b_block_buf,
                                b_thread_desc_,
                                make_tuple(I0, n0, I0, I0, I0, I0),
                                b_thread_buf);

                            vector_type<FloatA, KPack / A_KRow> a_thread_vec;
                            vector_type<FloatB, KPack / B_KRow> b_thread_vec;

                            static_for<0, KPack / A_KRow, 1>{}([&](auto i) {
                                a_thread_vec.template AsType<FloatA>()(i) =
                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                        make_tuple(i / A_K1, m0, 0, 0, 0, i % A_K1))>{}];
                            });

                            static_for<0, KPack / B_KRow, 1>{}([&](auto i) {
                                b_thread_vec.template AsType<FloatB>()(i) =
                                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                        make_tuple(i / B_K1, n0, 0, 0, 0, i % B_K1))>{}];
                            });

                            using wmma_input_type_a =
                                typename vector_type<FloatA, WmmaK / A_KRow>::type;
                            using wmma_input_type_b =
                                typename vector_type<FloatB, WmmaK / B_KRow>::type;

                            constexpr index_t c_offset =
                                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                            wmma_gemm.template Run<>(
                                a_thread_vec.template AsType<wmma_input_type_a>(),
                                b_thread_vec.template AsType<wmma_input_type_b>(),
                                c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                        });
                    });
                });
        }
        else
        {
            static_for<0, NRepeat, 1>{}([&](auto n0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, KPerBlock / KPack, 1>{}([&](auto k) { // k = 0, 1, 2, ...
                        // read B
                        b_thread_copy_.Run(
                            b_block_desc_k0_n0_n1_n2_k1,
                            make_tuple(Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
                            b_block_buf,
                            b_thread_desc_,
                            make_tuple(I0, n0, I0, I0, I0, I0),
                            b_thread_buf);
                        // read A
                        a_thread_copy_.Run(
                            a_block_desc_k0_m0_m1_m2_k1,
                            make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
                            a_block_buf,
                            a_thread_desc_,
                            make_tuple(I0, m0, I0, I0, I0, I0),
                            a_thread_buf);

                        vector_type<FloatA, KPack / A_KRow> a_thread_vec;
                        vector_type<FloatB, KPack / B_KRow> b_thread_vec;

                        static_for<0, KPack / A_KRow, 1>{}([&](auto i) {
                            a_thread_vec.template AsType<FloatA>()(i) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(i / A_K1, m0, 0, 0, 0, i % A_K1))>{}];
                        });

                        static_for<0, KPack / B_KRow, 1>{}([&](auto i) {
                            b_thread_vec.template AsType<FloatB>()(i) =
                                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(i / B_K1, n0, 0, 0, 0, i % B_K1))>{}];
                        });

                        using wmma_input_type_a =
                            typename vector_type<FloatA, WmmaK / A_KRow>::type;
                        using wmma_input_type_b =
                            typename vector_type<FloatB, WmmaK / B_KRow>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        wmma_gemm.template Run<>(
                            a_thread_vec.template AsType<wmma_input_type_a>(),
                            b_thread_vec.template AsType<wmma_input_type_b>(),
                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });
        }
    }
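    // Illustrative usage sketch (assumed driver loop, not from this file): a
    // gridwise kernel typically owns the K loop and calls Run() once per
    // K-tile after staging the A/B tiles, roughly:
    //
    //   auto blockwise_gemm = BlockwiseGemmWMMA<...>{};
    //   auto& c_thread_buf  = blockwise_gemm.GetCThreadBuffer();
    //   for(index_t k = 0; k < K; k += KPerBlock)
    //   {
    //       // ... copy the next A/B tiles (to LDS when A/BEnableLds), sync ...
    //       blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
    //   }
    //   // ... write c_thread_buf back through the C descriptors above ...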

    protected:
    static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
        make_tuple(Number<KPack / A_K1 / A_KRow>{}, Number<MRepeat>{}, I1, I1, I1, Number<A_K1>{}),
        make_tuple(Number<A_K1>{},
                   Number<KPack / A_KRow>{},
                   Number<A_K1>{},
                   Number<A_K1>{},
                   Number<A_K1>{},
                   Number<1>{}));

    static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
        make_tuple(Number<KPack / B_K1 / B_KRow>{}, Number<NRepeat>{}, I1, I1, I1, Number<B_K1>{}),
        make_tuple(Number<B_K1>{},
                   Number<KPack / B_KRow>{},
                   Number<B_K1>{},
                   Number<B_K1>{},
                   Number<B_K1>{},
                   Number<1>{}));

    // C[M, N, NumRegWMMA]
    static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));
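    // Worked example (illustrative): c_thread_desc_ is packed, so
    // CalculateOffset(make_tuple(m0, n0, 0)) = (m0 * NRepeat + n0) * RegSize,
    // where RegSize = wmma_gemm.GetRegSizePerWmma(). Assuming MRepeat = 4,
    // NRepeat = 4, RegSize = 8: (m0 = 2, n0 = 1) maps to (2 * 4 + 1) * 8 = 72.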

    template <bool EnableLds>
    struct AThreadCopySelector;

    template <>
    struct AThreadCopySelector<true>
    {
        using type =
            ThreadwiseTensorSliceTransfer_v4<FloatA,
                                             FloatA,
                                             decltype(a_block_desc_k0_m0_m1_m2_k1),
                                             decltype(a_thread_desc_),
                                             Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
                                             Sequence<0, 1, 2, 3, 4, 5>,
                                             5,
                                             A_K1,
                                             A_K1>;
    };

    template <>
    struct AThreadCopySelector<false>
    {
        using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow<
            FloatA,
            FloatA,
            decltype(a_block_desc_k0_m0_m1_m2_k1),
            decltype(a_thread_desc_),
            tensor_operation::element_wise::PassThrough,
            Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
            Sequence<0, 1, 2, 3, 4, 5>,
            5,
            A_K1,
            false>;
    };

    template <bool EnableLds>
    struct BThreadCopySelector;

    template <>
    struct BThreadCopySelector<true>
    {
        using type =
            ThreadwiseTensorSliceTransfer_v4<FloatB,
                                             FloatB,
                                             decltype(b_block_desc_k0_n0_n1_n2_k1),
                                             decltype(b_thread_desc_),
                                             Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
                                             Sequence<0, 1, 2, 3, 4, 5>,
                                             5,
                                             B_K1,
                                             B_K1>;
    };

    template <>
    struct BThreadCopySelector<false>
    {
        using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow<
            FloatB,
            FloatB,
            decltype(b_block_desc_k0_n0_n1_n2_k1),
            decltype(b_thread_desc_),
            tensor_operation::element_wise::PassThrough,
            Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
            Sequence<0, 1, 2, 3, 4, 5>,
            5,
            B_K1,
            false>;
    };

    typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
    typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
};
#else
template <index_t BlockSize,
          typename FloatA,
          typename FloatB,
          typename FloatAcc,
          typename ABlockDesc,
          typename BBlockDesc,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerWMMA,
          index_t NPerWMMA,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack,
          bool AEnableLds = true,
          bool BEnableLds = true,
          bool TransposeC = false>
/* Option: read from LDS; a big buffer holds the data required by all threads.
 * Source
 *   A: K0PerBlock x MPerBlock x K1
 *   B: K0PerBlock x NPerBlock x K1
 * Destination
 *   C, non-transposed
 *   thread level: MRepeat x NRepeat x MAccVgprs
 *   block level:  MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
 *   KPACK == WMMA_K = 16
 *
 * Option: read from VMEM; a small buffer holds only each thread's own data (skip LDS).
 * Source:
 *   A (if skipping LDS): MRepeat x KPack
 *   B (if skipping LDS): NRepeat x KPack
 * Destination
 *   C, non-transposed
 *   block level:  MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
 */
struct BlockwiseGemmWMMA
{
    static constexpr auto I0    = Number<0>{};
    static constexpr auto I1    = Number<1>{};
    static constexpr auto I2    = Number<2>{};
    static constexpr auto I3    = Number<3>{};
    static constexpr auto I4    = Number<4>{};
    static constexpr auto I5    = Number<5>{};
    static constexpr auto WmmaK = Number<16>{};

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    // WaveSize is hard-coded because the current HIP runtime (5.4.0-10984)
    // does not report the correct value.
    static constexpr index_t WaveSize = 32;

    // When LDS is used, each row (16 consecutive lanes) reads all of its data
    // from the source buffer. When LDS is skipped, each row reads half of the
    // data and the halves are exchanged via permutation.
    static constexpr index_t A_KRow = AEnableLds ? 1 : 2;
    static constexpr index_t B_KRow = BEnableLds ? 1 : 2;
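    // Illustrative note (assumed numbers for concreteness): with KPack = 16
    // and AEnableLds = false, A_KRow = 2 splits each KPack chunk across the
    // two 16-lane rows of a 32-lane wave, so a lane holds KPack / A_KRow = 8
    // values of A and obtains the other half through the lane permutation
    // performed by the skip-LDS thread copy selected below.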
    static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
    static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);

    static constexpr auto wmma_gemm =
        WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};

    static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
    static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);

    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
                              FloatAcc,
                              MRepeat * NRepeat,
                              wmma_gemm.GetRegSizePerWmma(),
                              true>
        c_thread_buf_;

    __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }

    __device__ static auto GetWaveIdx()
    {
        const index_t thread_id = ThisThreadBlock::GetThreadId();

        constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
            make_tuple(Sequence<0, 1, 2>{}),
            make_tuple(Sequence<0>{}));

        return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
    }

    // Default: block buffer in LDS, thread-level offset enabled.
    __device__ static auto CalculateAThreadOriginDataIndex()
    {
        if constexpr(AEnableLds)
        {
            const auto wave_idx   = GetWaveIdx();
            const auto waveId_m   = wave_idx[I0];
            const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();

            //                |KRepeat |MRepeat |MWave    |KRow |MLane      |KPack
            return make_tuple(0,       0,       waveId_m, 0,    WMMA_a_idx, 0);
        }
        else
        {
            return make_tuple(0, 0, 0, 0, 0, 0);
        }
    }

    __device__ static auto CalculateBThreadOriginDataIndex()
    {
        if constexpr(BEnableLds)
        {
            const auto wave_idx   = GetWaveIdx();
            const auto waveId_n   = wave_idx[I1];
            const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();

            //                |KRepeat |NRepeat |NWave    |KRow |NLane      |KPack
            return make_tuple(0,       0,       waveId_n, 0,    WMMA_b_idx, 0);
        }
        else
        {
            return make_tuple(0, 0, 0, 0, 0, 0);
        }
    }

    template <index_t m0, index_t n0>
    __device__ static auto CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>)
    {
        const auto wave_idx = GetWaveIdx();

        const auto waveId_m = wave_idx[I0];
        const auto waveId_n = wave_idx[I1];

        const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk();

        constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0, 1, 2>{}));

        constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWMMA))),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0, 1, 2>{}));

        const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex(
            make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
        const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex(
            make_tuple(n0, waveId_n, blk_idx[I1]))[I0];

        return make_tuple(c_thread_m, c_thread_n);
    }

    template <index_t m0, index_t n0>
    __device__ static auto CalculateCThreadOriginDataIndex7D(Number<m0>, Number<n0>)
    {
        const auto wave_idx = GetWaveIdx();

        const auto waveId_m = wave_idx[I0];
        const auto waveId_n = wave_idx[I1];

        const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk3D();

        return make_tuple(
            Number<m0>{}, waveId_m, blk_idx[I0], Number<n0>{}, waveId_n, blk_idx[I1], blk_idx[I2]);
    }

    using Tuple6 = decltype(CalculateAThreadOriginDataIndex());
    __host__ __device__ BlockwiseGemmWMMA(Tuple6 a_origin = CalculateAThreadOriginDataIndex(),
                                          Tuple6 b_origin = CalculateBThreadOriginDataIndex())
        : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
    {
        static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
                      "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");

        static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
                          NPerBlock % (NPerWMMA * NRepeat) == 0,
                      "wrong!");
    }

    // transposed WMMA output: C' = B' * A'
    __host__ __device__ static constexpr auto
    GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
    {
        constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
            wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();

        constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];

        return make_naive_tensor_descriptor_packed(
            //        |MRepeat           |MWave |MSubGroup |NRepeat           |NWave
            //        |NThreadPerSubGroup |MAccVgprs
            make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, NAccVgprs));
    }

    // Thread-level register descriptor. Vector write.
    __host__ __device__ static constexpr auto
    GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
    {
        constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
            wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();

        constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
        constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3];
        return make_naive_tensor_descriptor(
            //        |MRepeat           |MWave |MSubGroup |NRepeat           |NWave
            //        |NThreadPerSubGroup |MAccVgprs
            make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, MAccVgprs),
            make_tuple(Number<NRepeat>{} * MAccVgprs * AccStride,
                       Number<NRepeat>{} * MAccVgprs * AccStride,
                       Number<NRepeat>{} * MAccVgprs * AccStride,
                       MAccVgprs * AccStride,
                       MAccVgprs * AccStride,
                       MAccVgprs * AccStride,
                       AccStride));
    }

    template <typename CGridDesc_M_N>
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
        const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma =
            transform_tensor_descriptor(
                c_grid_desc_m_n,
                make_tuple(
                    make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)),
                    make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));

        return wmma_gemm
            .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
                c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
    }

    // transposed WMMA output: C' = B' * A'
    __host__ __device__ static constexpr auto
    GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
    {
        constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
                                                           Number<MWaves>{},
                                                           Number<MPerWMMA>{},
                                                           Number<NRepeat>{},
                                                           Number<NWaves>{},
                                                           Number<NPerWMMA>{}));

        return wmma_gemm
            .MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(
                c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
    }

    // Provides the dimension sizes.
    __host__ __device__ static constexpr auto
    GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
    {
        constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
                                                           Number<MWaves>{},
                                                           Number<MPerWMMA>{},
                                                           Number<NRepeat>{},
                                                           Number<NWaves>{},
                                                           Number<NPerWMMA>{}));

        return wmma_gemm
            .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
                c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
    }

    // Describes how data is laid out in the thread-copy source buffer.
    // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
    static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1;
    static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1;

    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
    __device__ void Run(const ABlockBuffer& a_block_buf,
                        const BBlockBuffer& b_block_buf,
                        CThreadBuffer& c_thread_buf) const
    {
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
            b_thread_desc_.GetElementSpaceSize());

        // simple heuristic to choose the loop order
        if constexpr(MRepeat < NRepeat)
        {
            static_for<0, KPerBlock / KPack, 1>{}(
                [&](auto k) { // k = 0, 1, 2, ... instead of k = 0, KPack, ...
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        // read A
                        a_thread_copy_.Run(
                            a_block_desc_k0_m0_m1_m2_k1,
                            make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
                            a_block_buf,
                            a_thread_desc_,
                            make_tuple(I0, m0, I0, I0, I0, I0),
                            a_thread_buf);

                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            // read B
                            b_thread_copy_.Run(
                                b_block_desc_k0_n0_n1_n2_k1,
                                make_tuple(Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
                                b_block_buf,
                                b_thread_desc_,
                                make_tuple(I0, n0, I0, I0, I0, I0),
                                b_thread_buf);

                            vector_type<FloatA, KPack> a_thread_vec;
                            vector_type<FloatB, KPack> b_thread_vec;

                            static_for<0, KPack, 1>{}([&](auto i) {
                                a_thread_vec.template AsType<FloatA>()(i) =
                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                        make_tuple(i / A_K1 / A_KRow,
                                                   m0,
                                                   0,
                                                   (i / A_K1) % A_KRow,
                                                   0,
                                                   i % A_K1))>{}];
                                b_thread_vec.template AsType<FloatB>()(i) =
                                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                        make_tuple(i / B_K1 / B_KRow,
                                                   n0,
                                                   0,
                                                   (i / B_K1) % B_KRow,
                                                   0,
                                                   i % B_K1))>{}];
                            });
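                            // Worked example of the offset arithmetic above
                            // (illustrative, assuming A_K1 = 8, A_KRow = 2,
                            // KPack = 16): for i = 11 the A index tuple is
                            // (11/8/2, m0, 0, (11/8) % 2, 0, 11 % 8)
                            // = (0, m0, 0, 1, 0, 3).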

                            using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
                            using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;

                            constexpr index_t c_offset =
                                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                            wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                                          b_thread_vec.template AsType<wmma_input_type_b>(),
                                          c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                        });
                    });
                });
        }
        else
        {
            static_for<0, NRepeat, 1>{}([&](auto n0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, KPerBlock / KPack, 1>{}([&](auto k) { // k = 0, 1, 2, ...
                        // read B
                        b_thread_copy_.Run(
                            b_block_desc_k0_n0_n1_n2_k1,
                            make_tuple(Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
                            b_block_buf,
                            b_thread_desc_,
                            make_tuple(I0, n0, I0, I0, I0, I0),
                            b_thread_buf);
                        // read A
                        a_thread_copy_.Run(
                            a_block_desc_k0_m0_m1_m2_k1,
                            make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
                            a_block_buf,
                            a_thread_desc_,
                            make_tuple(I0, m0, I0, I0, I0, I0),
                            a_thread_buf);

                        vector_type<FloatA, KPack> a_thread_vec;
                        vector_type<FloatB, KPack> b_thread_vec;

                        static_for<0, KPack, 1>{}([&](auto i) {
                            b_thread_vec.template AsType<FloatB>()(i) =
                                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(i / B_K1 / B_KRow,
                                               n0,
                                               0,
                                               (i / B_K1) % B_KRow,
                                               0,
                                               i % B_K1))>{}];
                            a_thread_vec.template AsType<FloatA>()(i) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(i / A_K1 / A_KRow,
                                               m0,
                                               0,
                                               (i / A_K1) % A_KRow,
                                               0,
                                               i % A_K1))>{}];
                        });

                        using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
                        using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                                      b_thread_vec.template AsType<wmma_input_type_b>(),
                                      c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });
        }
    }

    protected:
    static constexpr auto a_thread_desc_ =
        make_naive_tensor_descriptor(make_tuple(Number<KPack / A_K1 / A_KRow>{},
                                                Number<MRepeat>{},
                                                I1,
                                                Number<A_KRow>{},
                                                I1,
                                                Number<A_K1>{}),
                                     make_tuple(Number<A_K1 * A_KRow>{},
                                                Number<KPack>{},
                                                Number<A_K1 * A_KRow>{},
                                                Number<A_K1>{},
                                                Number<A_K1>{},
                                                Number<1>{}));

    static constexpr auto b_thread_desc_ =
        make_naive_tensor_descriptor(make_tuple(Number<KPack / B_K1 / B_KRow>{},
                                                Number<NRepeat>{},
                                                I1,
                                                Number<B_KRow>{},
                                                I1,
                                                Number<B_K1>{}),
                                     make_tuple(Number<B_K1 * B_KRow>{},
                                                Number<KPack>{},
                                                Number<B_K1 * B_KRow>{},
                                                Number<B_K1>{},
                                                Number<B_K1>{},
                                                Number<1>{}));

    // C[M, N, NumRegWMMA]
    static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));

    template <bool EnableLds>
    struct AThreadCopySelector;

    template <>
    struct AThreadCopySelector<true>
    {
        using type =
            ThreadwiseTensorSliceTransfer_v4<FloatA,
                                             FloatA,
                                             decltype(a_block_desc_k0_m0_m1_m2_k1),
                                             decltype(a_thread_desc_),
                                             Sequence<KPack / A_K1 / A_KRow, 1, 1, A_KRow, 1, A_K1>,
                                             Sequence<0, 1, 2, 3, 4, 5>,
                                             5,
                                             A_K1,
                                             A_K1>;
    };

    template <>
    struct AThreadCopySelector<false>
    {
        using type = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow<
            FloatA,
            FloatA,
            decltype(a_block_desc_k0_m0_m1_m2_k1),
            decltype(a_thread_desc_),
            tensor_operation::element_wise::PassThrough,
            Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
            Sequence<0, 1, 2, 3, 4, 5>,
            5,
            A_K1,
            0x76543210,
            0xfedcba98,
            TransposeC ? false : true>;
    };
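    // Note (assumption, inferred from the surrounding code rather than stated
    // on this page): 0x76543210 and 0xfedcba98 appear to be per-half
    // lane-selector patterns for the cross-row data exchange used when LDS is
    // skipped, and the trailing boolean toggles the swizzle based on
    // TransposeC.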

    template <bool EnableLds>
    struct BThreadCopySelector;

    template <>
    struct BThreadCopySelector<true>
    {
        using type =
            ThreadwiseTensorSliceTransfer_v4<FloatB,
                                             FloatB,
                                             decltype(b_block_desc_k0_n0_n1_n2_k1),
                                             decltype(b_thread_desc_),
                                             Sequence<KPack / B_K1 / B_KRow, 1, 1, B_KRow, 1, B_K1>,
                                             Sequence<0, 1, 2, 3, 4, 5>,
                                             5,
                                             B_K1,
                                             B_K1>;
    };

    template <>
    struct BThreadCopySelector<false>
    {
        using type = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow<
            FloatB,
            FloatB,
            decltype(b_block_desc_k0_n0_n1_n2_k1),
            decltype(b_thread_desc_),
            tensor_operation::element_wise::PassThrough,
            Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
            Sequence<0, 1, 2, 3, 4, 5>,
            5,
            B_K1,
            0x76543210,
            0xfedcba98,
            TransposeC ? true : false>;
    };

    typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
    typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
};
#endif

} // namespace ck