Composable Kernel: blockwise_gemm_pipeline_wmmaops_base.hpp Source File
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

// (include directives elided in the rendered page)

namespace ck {

template <index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeTypeA,
          typename ComputeTypeB,
          typename AccDataType,
          typename AWmmaTileDesc,
          typename BWmmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerWmma,
          index_t NPerWmma,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack,
          index_t KInner,
          bool TransposeC = false>
struct BlockwiseGemmWmmaops_pipeline_base
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    static constexpr index_t WaveSize = 32;

    static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
    static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
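    // Worked example (illustrative values, not from this header): with
    // MPerBlock = NPerBlock = 128, MPerWmma = NPerWmma = 16 and MRepeat = NRepeat = 2,
    //     MWaves = 128 / (2 * 16) = 4 and NWaves = 128 / (2 * 16) = 4,
    // so the block must be launched with MWaves * NWaves * WaveSize = 4 * 4 * 32 = 512
    // threads; the static_assert in the constructor below checks exactly this.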

#if defined(__gfx12__)
    static constexpr index_t A_KRow = 2;
    static constexpr index_t B_KRow = 2;
#else
    static constexpr index_t A_KRow = 1;
    static constexpr index_t B_KRow = 1;
#endif

    static constexpr auto wmma_gemm = WmmaGemm<ComputeTypeA,
                                               ComputeTypeB,
                                               AccDataType,
                                               MPerWmma,
                                               NPerWmma,
                                               KPack / KInner,
                                               TransposeC>{};

    static constexpr index_t KPerThread = wmma_gemm.wmma_instr.k_per_blk * KInner;
    static constexpr index_t A_K1 = ck::math::min(AWmmaTileDesc{}.GetLength(I6), KPerThread);
    static constexpr index_t B_K1 = ck::math::min(BWmmaTileDesc{}.GetLength(I6), KPerThread);

    static_assert(KPack % (A_K1 * A_KRow) == 0, "wrong!");
    static_assert(KPack % (B_K1 * B_KRow) == 0, "wrong!");
    static constexpr index_t KRepeat = KPerBlock / KPack;

    static constexpr auto WmmaK = Number<wmma_gemm.wmma_instr.k_per_wmma>{};
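    // Worked example (illustrative values, not from this header): assuming a WMMA
    // instruction with k_per_wmma = 16 and k_per_blk = 8, and KPerBlock = 64,
    // KPack = 16, KInner = 1:
    //     KPerThread = 8 * 1 = 8   and   KRepeat = 64 / 16 = 4,
    // while the WmmaGemm above is instantiated with KPack / KInner = 16 K-elements
    // per issued instruction. On gfx12 (A_KRow = 2) the static_asserts above
    // additionally require A_K1 and B_K1 to divide KPack / 2.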

    // Alias target inferred from the surviving template arguments; the declaration
    // itself was elided in the rendered page.
    using HotLoopInstList =
        ck::BlockwiseGemmWmmaops_pipeline_hotloop_inst<BlockSize,
                                                       MPerBlock,
                                                       NPerBlock,
                                                       KPerBlock,
                                                       ABlockTransferSrcScalarPerVector,
                                                       BBlockTransferSrcScalarPerVector,
                                                       A_K1,
                                                       B_K1,
                                                       A_K1,
                                                       B_K1,
                                                       MRepeat,
                                                       NRepeat,
                                                       MPerWmma,
                                                       NPerWmma,
                                                       wmma_gemm.wmma_instr.k_per_wmma>;

    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
                              AccDataType,
                              MRepeat * NRepeat,
                              wmma_gemm.GetRegSizePerWmma(),
                              true>
        c_thread_buf_;

    struct Empty
    {
        __device__ Empty() {}

        template <index_t NBuffer>
        __device__ void GlobalLoad(bool cond)
        {
            ignore = NBuffer;
            ignore = cond;
        }
    };
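    // Illustrative sketch (not part of this header): Empty is the no-op stand-in
    // used when a pipeline runs without A/B scaling, so the pipeline can invoke the
    // loader hook unconditionally. A hypothetical caller:
    //
    //     template <typename ScaleLoader>
    //     __device__ void HotLoopStep(ScaleLoader& loader)
    //     {
    //         loader.template GlobalLoad<0>(true); // compiles to nothing for Empty
    //     }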

    template <index_t ScaleSliceSizeMN,
              index_t ScaleSliceStrideMN,
              index_t ScaleSliceSizeK,
              index_t NumberOfBuffers,
              index_t RegSizePerWmma,
              typename GridDesc,
              typename ThreadCopy,
              typename GridBuffer,
              typename ThreadStaticBuffer,
              typename ThreadDesc>
    struct ABScale
    {
        __device__ ABScale(GridDesc scale_grid_desc_,
                           ThreadCopy scale_thread_copy_,
                           GridBuffer scale_grid_buf_)
            : scale_thread_copy(scale_thread_copy_),
              scale_grid_desc(scale_grid_desc_),
              scale_grid_buf(scale_grid_buf_)
        {
        }

        static constexpr index_t num_scale_k_block = ThreadDesc{}.GetLength(Number<1>{});
        // static constexpr index_t num_scale_krepeat = ...; (initializer elided in
        // the rendered page)

        static constexpr index_t num_slice_mn      = ScaleSliceSizeMN;
        static constexpr index_t num_slice_k       = ScaleSliceSizeK;
        static constexpr index_t reg_size_per_wmma = RegSizePerWmma;

        static constexpr auto scale_thread_desc = ThreadDesc{};

        static constexpr auto scale_thread_copy_step =
            make_tuple(make_multi_index(ScaleSliceStrideMN, 0),
                       make_multi_index(-ScaleSliceSizeMN / RegSizePerWmma * ScaleSliceStrideMN, 0),
                       make_multi_index(-ScaleSliceSizeMN / RegSizePerWmma * ScaleSliceStrideMN,
                                        ScaleSliceSizeK));

        template <index_t NBuffer>
        __device__ void GlobalLoad(bool cond)
        {
            static_for<0, ScaleSliceSizeMN / RegSizePerWmma, 1>{}([&](auto m0) {
                // (the Run(...) copy into scale_thread_bufs(Number<NBuffer>{}) was
                // elided in the rendered page)
                scale_thread_copy.MoveSrcSliceWindow(scale_grid_desc,
                                                     scale_thread_copy_step.At(I0));
            });

            // Which step each branch takes was elided in the rendered page; the
            // assignment below is inferred from the step definitions above.
            if(cond)
            {
                scale_thread_copy.MoveSrcSliceWindow(scale_grid_desc,
                                                     scale_thread_copy_step.At(I2));
            }
            else
            {
                scale_thread_copy.MoveSrcSliceWindow(scale_grid_desc,
                                                     scale_thread_copy_step.At(I1));
            }
        }

        ThreadCopy scale_thread_copy;
        GridDesc scale_grid_desc;
        GridBuffer scale_grid_buf;
        StaticallyIndexedArray<ThreadStaticBuffer, Number<NumberOfBuffers>{}> scale_thread_bufs;
    };
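    // Reading of scale_thread_copy_step (under the inferred branch choices in
    // GlobalLoad above): element 0 advances the MN window by ScaleSliceStrideMN on
    // each of the ScaleSliceSizeMN / RegSizePerWmma inner iterations; element 1
    // rewinds the MN window to its start without touching K; element 2 rewinds MN
    // and advances K by ScaleSliceSizeK. E.g. with ScaleSliceSizeMN = 4,
    // RegSizePerWmma = 2, ScaleSliceStrideMN = 32: two inner moves of +32, then
    // either (-64, 0) or (-64, ScaleSliceSizeK).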

    template <typename AScaleStruct, typename BScaleStruct>
    struct CScale
    {
        __device__ CScale() {}

        static constexpr auto reg_size_per_wmma =
            ck::is_same_v<BScaleStruct, Empty> && ck::is_same_v<AScaleStruct, Empty>
                ? 1
                : wmma_gemm.GetRegSizePerWmma();
        // The M and N lengths of this descriptor were elided in the rendered page;
        // the second and third Number<...> arguments below are inferred from the
        // A/B scale structs.
        static constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
            Number<ck::math::max(AScaleStruct::num_slice_k, BScaleStruct::num_slice_k)>{},
            Number<AScaleStruct::num_slice_mn>{},
            Number<BScaleStruct::num_slice_mn>{}));
        using CScaleThreadDesc = decltype(c_scale_thread_desc);
        static constexpr auto num_scale_k_block = CScaleThreadDesc{}.GetLength(Number<0>{});
        static constexpr auto num_scale_m_block = CScaleThreadDesc{}.GetLength(Number<1>{});
        static constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{});
        using ThreadStaticBuffer = decltype(make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
            c_scale_thread_desc.GetElementSpaceSize()));

        __device__ void Load(AScaleStruct& a_scale_struct, BScaleStruct& b_scale_struct)
        {
            using AScaleThreadDesc = decltype(AScaleStruct::scale_thread_desc);
            using BScaleThreadDesc = decltype(BScaleStruct::scale_thread_desc);

            static_for<0, num_scale_m_block, 1>{}([&](auto m0) {
                static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
                    static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
                        constexpr index_t c_offset =
                            CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
                        constexpr index_t a_offset =
                            AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
                        constexpr index_t b_offset =
                            BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));

                        c_scale_thread_bufs(I0)(Number<c_offset>{}) =
                            a_scale_struct.scale_thread_bufs(I0)[Number<a_offset>{}] *
                            b_scale_struct.scale_thread_bufs(I0)[Number<b_offset>{}];
                    });
                });
            });
        }

        __device__ void Clear()
        {
            // loop header inferred (elided in the rendered page)
            static_for<0, reg_size_per_wmma, 1>{}([&](auto t) {
                c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
                    .template AsType<AccDataType>()(Number<t>{}) = 0;
            });
        }

        template <index_t k_index, index_t m_index, index_t n_index, typename CThreadBuf>
        __device__ void UpdateCThreadBuf(CThreadBuf& c_thread_buf)
        {
            // loop header inferred (elided in the rendered page)
            static_for<0, reg_size_per_wmma, 1>{}([&](auto t) {
                constexpr index_t c_offset =
                    c_thread_desc_.CalculateOffset(make_tuple(m_index, n_index, t));
                constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(make_tuple(
                    k_index,
                    (m_index * num_scale_m_block / MRepeat) % num_scale_m_block +
                        (Number<t / (reg_size_per_wmma / AScaleStruct::reg_size_per_wmma)>{}) %
                            AScaleStruct::reg_size_per_wmma,
                    (n_index * num_scale_n_block / NRepeat) % num_scale_n_block));
                c_thread_buf(Number<c_offset>{}) +=
                    c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
                        .template AsType<AccDataType>()[Number<t>{}] *
                    type_convert<AccDataType>(c_scale_thread_bufs(I0)[Number<cscale_offset>{}]);
            });
        }

        // The element count of c_scale_thread_bufs was elided in the rendered page;
        // a single buffer is assumed here.
        StaticallyIndexedArray<ThreadStaticBuffer, I1> c_scale_thread_bufs;
        StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, AccDataType, 1, reg_size_per_wmma, true>
            c_thread_buf_per_scale;
    };
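    // The Load()/UpdateCThreadBuf() pair above implements block-scaled
    // accumulation: partial products are computed on quantized data and then
    // rescaled, i.e. in scalar form
    //     C[m][n] += a_scale[m][k] * b_scale[n][k] * (A_q[m][k] . B_q[k][n]),
    // with Load() precomputing c_scale[k][m][n] = a_scale[m][k] * b_scale[n][k]
    // once per K scale block. (Scalar reading of the code above, for orientation
    // only.)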

    __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }

    __device__ static auto GetWaveIdx()
    {
        const index_t thread_id = ThisThreadBlock::GetThreadId();

        constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
            make_tuple(Sequence<0, 1, 2>{}),
            make_tuple(Sequence<0>{}));

        return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
    }
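    // Worked example (illustrative values): with MWaves = 4, NWaves = 2 and
    // WaveSize = 32, the merge transform decomposes thread_id = 100 as
    //     100 = 1 * (2 * 32) + 1 * 32 + 4  ->  (waveId_m, waveId_n, lane) = (1, 1, 4).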

    __device__ static auto CalculateAThreadOriginDataIndex()
    {
        const auto wave_idx = GetWaveIdx();

        const auto waveId_m = wave_idx[I0];

        const auto wmma_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();

#if defined(__gfx12__)
        const auto wmma_krow = wmma_gemm.GetSubGroupId();
#else
        const auto wmma_krow = 0;
#endif

        return make_tuple(0, 0, 0, waveId_m, wmma_krow, wmma_a_idx, 0);
    }

    __device__ static auto CalculateBThreadOriginDataIndex()
    {
        const auto wave_idx = GetWaveIdx();

        const auto waveId_n = wave_idx[I1];

        const auto wmma_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();

#if defined(__gfx12__)
        const auto wmma_krow = wmma_gemm.GetSubGroupId();
#else
        const auto wmma_krow = 0;
#endif

        return make_tuple(0, 0, 0, waveId_n, wmma_krow, wmma_b_idx, 0);
    }

    template <index_t m0, index_t n0>
    __device__ static auto CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>)
    {
        const auto wave_idx = GetWaveIdx();

        const auto waveId_m = wave_idx[I0];
        const auto waveId_n = wave_idx[I1];

        const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk();

        constexpr auto mrepeat_mwave_mperwmma_to_m_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWmma))),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0, 1, 2>{}));

        constexpr auto nrepeat_nwave_nperwmma_to_n_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWmma))),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0, 1, 2>{}));

        const index_t c_thread_m = mrepeat_mwave_mperwmma_to_m_adaptor.CalculateBottomIndex(
            make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
        const index_t c_thread_n = nrepeat_nwave_nperwmma_to_n_adaptor.CalculateBottomIndex(
            make_tuple(n0, waveId_n, blk_idx[I1]))[I0];

        return make_tuple(c_thread_m, c_thread_n);
    }
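    // Worked example (illustrative values): the unmerge adaptors compute
    //     c_thread_m = (m0 * MWaves + waveId_m) * MPerWmma + blk_idx[I0].
    // With MRepeat = 2, MWaves = 4, MPerWmma = 16, m0 = 1, waveId_m = 2 and a
    // within-tile row of 5: c_thread_m = (1 * 4 + 2) * 16 + 5 = 101.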

    using Tuple7 = decltype(CalculateAThreadOriginDataIndex());

    /// \brief Constructor for BlockwiseGemmWmmaops_pipeline_base.
    /// (The remainder of the original Doxygen comment was elided in the rendered page.)
    __host__ __device__
    BlockwiseGemmWmmaops_pipeline_base(Tuple7 a_origin = CalculateAThreadOriginDataIndex(),
                                       Tuple7 b_origin = CalculateBThreadOriginDataIndex())
        : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
    {
        static_assert(AWmmaTileDesc::IsKnownAtCompileTime() &&
                          BWmmaTileDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
                      "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize");

        static_assert(MPerBlock % (MPerWmma * MRepeat) == 0 &&
                          NPerBlock % (NPerWmma * NRepeat) == 0,
                      "wrong!");
    }
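    // A parameterization that satisfies all three static_asserts above
    // (illustrative): BlockSize = 256 with MWaves = 4 and NWaves = 2
    // (4 * 2 * 32 = 256), MPerBlock = 128 = 16 * 2 * 4 and
    // NPerBlock = 64 = 16 * 2 * 2, with compile-time-known A/B tile descriptors.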

    // transposed WMMA output C' = B' * A'
    __host__ __device__ static constexpr auto
    GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
    {
        constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
            wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();

        constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];

        return make_naive_tensor_descriptor_packed(
            //        |MRepeat           |MWave |MSubGroup |NRepeat           |NWave
            //        |NThreadPerSubGroup |MAccVgprs
            make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, NAccVgprs));
    }

    static constexpr auto MAccVgprs =
        wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()[I2];

    __host__ __device__ static constexpr auto
    GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
    {
        constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
            wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();

        constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3];
        return make_naive_tensor_descriptor(
            //        |MRepeat           |MWave |MSubGroup |NRepeat           |NWave
            //        |NThreadPerSubGroup |MAccVgprs
            make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, MAccVgprs),
            make_tuple(Number<NRepeat>{} * MAccVgprs * AccStride,
                       Number<NRepeat>{} * MAccVgprs * AccStride,
                       Number<NRepeat>{} * MAccVgprs * AccStride,
                       MAccVgprs * AccStride,
                       MAccVgprs * AccStride,
                       MAccVgprs * AccStride,
                       AccStride));
    }
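    // Because the MWave/MSubGroup/NWave/NThreadPerSubGroup lengths are all I1, the
    // strides above collapse to the packed offset
    //     offset = ((m0 * NRepeat + n0) * MAccVgprs + v) * AccStride,
    // i.e. per-thread accumulator VGPRs laid out M-major, then N, then vector lane.
    // (Algebraic reading of the descriptor above.)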

    __host__ __device__ static constexpr auto
    GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
    {
        constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
                                                           Number<MWaves>{},
                                                           Number<MPerWmma>{},
                                                           Number<NRepeat>{},
                                                           Number<NWaves>{},
                                                           Number<NPerWmma>{}));

        return wmma_gemm
            .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
                c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
    }

    // Describes how data is laid out in the thread copy source buffer:
    // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
    static constexpr AWmmaTileDesc a_block_desc_k0_m0_m1_m2_k1;
    static constexpr BWmmaTileDesc b_block_desc_k0_n0_n1_n2_k1;

    protected:
    // The first length of a_thread_desc_/b_thread_desc_ was elided in the rendered
    // page; it is inferred here from the matching slice lengths of AThreadCopy /
    // BThreadCopy below.
    static constexpr auto a_thread_desc_ =
        make_naive_tensor_descriptor(make_tuple(Number<KPack / A_K1 / A_KRow>{},
                                                Number<MRepeat>{},
                                                Number<KRepeat>{},
                                                I1,
                                                I1,
                                                I1,
                                                Number<A_K1>{}),
                                     make_tuple(Number<A_K1>{},
                                                Number<KPack / A_KRow>{},
                                                Number<KPack / A_KRow * MRepeat>{},
                                                I0,
                                                I0,
                                                I0,
                                                I1));

    static constexpr auto b_thread_desc_ =
        make_naive_tensor_descriptor(make_tuple(Number<KPack / B_K1 / B_KRow>{},
                                                Number<NRepeat>{},
                                                Number<KRepeat>{},
                                                I1,
                                                I1,
                                                I1,
                                                Number<B_K1>{}),
                                     make_tuple(Number<B_K1>{},
                                                Number<KPack / B_KRow>{},
                                                Number<KPack / B_KRow * NRepeat>{},
                                                I0,
                                                I0,
                                                I0,
                                                I1));

    // C[M, N, NumRegWmma]
    static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));

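    // For the packed C thread descriptor above, the offset used by
    // CScale::UpdateCThreadBuf() is simply
    //     c_offset = (m_index * NRepeat + n_index) * GetRegSizePerWmma() + t.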
    // The transfer class name and the dimension access order below were elided in
    // the rendered page; ThreadwiseTensorSliceTransfer_v4 and the identity order
    // are inferred from the surviving arguments.
    using AThreadCopy =
        ThreadwiseTensorSliceTransfer_v4<ADataType,
                                         ComputeTypeA,
                                         decltype(a_block_desc_k0_m0_m1_m2_k1),
                                         decltype(a_thread_desc_),
                                         Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, 1, A_K1>,
                                         Sequence<0, 1, 2, 3, 4, 5, 6>,
                                         6,
                                         A_K1,
                                         A_K1>;

    using BThreadCopy =
        ThreadwiseTensorSliceTransfer_v4<BDataType,
                                         ComputeTypeB,
                                         decltype(b_block_desc_k0_n0_n1_n2_k1),
                                         decltype(b_thread_desc_),
                                         Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, 1, B_K1>,
                                         Sequence<0, 1, 2, 3, 4, 5, 6>,
                                         6,
                                         B_K1,
                                         B_K1>;

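    // Reading of the copy types above (under the inferences noted there): each
    // Run() moves one (KPack / A_KRow)-element K-slice out of the block buffer into
    // a_thread_desc_ / b_thread_desc_, as KPack / A_K1 / A_KRow vectors of A_K1
    // (resp. B_K1) contiguous elements, with dimension 6 as the vectorized axis.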

    AThreadCopy a_thread_copy_;
    BThreadCopy b_thread_copy_;
};

} // namespace ck