Composable Kernel: blockwise_gemm_pipeline_wmmaops_v3.hpp
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp"

namespace ck {

// Compute optimized pipeline
// GlobalPrefetchStages: 2
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1

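// In other words: two K tiles are prefetched from global memory before the hot loop,
// one of them is pre-filled into LDS, and a single LDS buffer is reused, so each
// hot-loop iteration overlaps WMMA compute on previously fetched registers with the
// LDS write of the next tile and the global read of the tile after that.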
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
          index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeTypeA,
          typename ComputeTypeB,
          typename AccDataType,
          typename AWmmaTileDesc,
          typename BWmmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerWmma,
          index_t NPerWmma,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack,
          index_t KInner,
          bool TransposeC = false,
          bool BSkipLDS = false>
struct BlockwiseGemmWmmaops_pipeline_v3
{
};

template <index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeTypeA,
          typename ComputeTypeB,
          typename AccDataType,
          typename AWmmaTileDesc,
          typename BWmmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerWmma,
          index_t NPerWmma,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack,
          index_t KInner,
          bool TransposeC>
struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
                                        BlockSize,
                                        ADataType,
                                        BDataType,
                                        ComputeTypeA,
                                        ComputeTypeB,
                                        AccDataType,
                                        AWmmaTileDesc,
                                        BWmmaTileDesc,
                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,
                                        MPerBlock,
                                        NPerBlock,
                                        KPerBlock,
                                        MPerWmma,
                                        NPerWmma,
                                        MRepeat,
                                        NRepeat,
                                        KPack,
                                        KInner,
                                        TransposeC,
                                        false>
    : BlockwiseGemmWmmaops_pipeline_base<BlockSize,
                                         ADataType,
                                         BDataType,
                                         ComputeTypeA,
                                         ComputeTypeB,
                                         AccDataType,
                                         AWmmaTileDesc,
                                         BWmmaTileDesc,
                                         ABlockTransferSrcScalarPerVector,
                                         BBlockTransferSrcScalarPerVector,
                                         MPerBlock,
                                         NPerBlock,
                                         KPerBlock,
                                         MPerWmma,
                                         NPerWmma,
                                         MRepeat,
                                         NRepeat,
                                         KPack,
                                         KInner,
                                         TransposeC>
{
    using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
                                                    ADataType,
                                                    BDataType,
                                                    ComputeTypeA,
                                                    ComputeTypeB,
                                                    AccDataType,
                                                    AWmmaTileDesc,
                                                    BWmmaTileDesc,
                                                    ABlockTransferSrcScalarPerVector,
                                                    BBlockTransferSrcScalarPerVector,
                                                    MPerBlock,
                                                    NPerBlock,
                                                    KPerBlock,
                                                    MPerWmma,
                                                    NPerWmma,
                                                    MRepeat,
                                                    NRepeat,
                                                    KPack,
                                                    KInner,
                                                    TransposeC>;
    using Base::I0;
    using Base::I1;
    using Base::I2;
    using Base::I3;

    using Base::A_K1;
    using Base::A_KRow;
    using Base::B_K1;
    using Base::B_KRow;
    using Base::KRepeat;
    using Base::WmmaK;

    using Base::wmma_gemm;
    using typename Base::HotLoopInstList;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::
        GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
    using Base::GetCThreadBuffer;
    using Base::
        GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
    using Base::
        GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs;

    using Base::a_block_desc_k0_m0_m1_m2_k1;
    using Base::b_block_desc_k0_n0_n1_n2_k1;

    using typename Base::Empty;

    static constexpr index_t PrefetchStages = 2;
    static constexpr index_t PrefillStages = 1;
    static constexpr index_t GlobalBufferNum = 1;
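    // PrefetchStages matches the two global RunRead passes issued before the hot loop,
    // PrefillStages matches the single RunWrite into LDS before the first compute, and
    // GlobalBufferNum reflects the single global-prefetch staging buffer used by the
    // blockwise copies.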

    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        if(BlockHasHotloop(num_loop))
        {
            return TailNumber::Full;
        }
        else
        {
            if(num_loop == 1)
            {
                return TailNumber::Odd;
            }
            else
            {
                return TailNumber::Even;
            }
        }
    }
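    // Dispatch example: num_loop == 1 runs only the tail (Odd), num_loop == 2 runs the
    // pre-tail plus the tail (Even), and num_loop >= 3 additionally runs the hot loop
    // for num_loop - 2 iterations (Full).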

    __device__ static constexpr auto HotLoopScheduler()
    {
        // TODO: Calculation of the number of instructions may require changes for WMMA
        /*
        // A/B split schedule
        // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
        constexpr auto num_ds_read_inst_a =
            HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
                ? HotLoopInstList::A_LDS_Read_Inst_Num
                : HotLoopInstList::A_LDS_Read_Inst_Num / 2;
        constexpr auto num_ds_read_inst_b =
            HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
                ? HotLoopInstList::B_LDS_Read_Inst_Num
                : HotLoopInstList::B_LDS_Read_Inst_Num / 2;

        constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
        constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;

        constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
        constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;

        constexpr auto num_wmma_inst = HotLoopInstList::C_WMMA_Inst_Num;

        constexpr auto wmma_cycle = NPerWmma == 16 ? 16 : 32;
        constexpr auto ds_read_a_issue_cycle =
            HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
        constexpr auto ds_read_b_issue_cycle =
            HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
        constexpr auto ds_read_a_wmma_rate =
            (wmma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
        constexpr auto ds_read_b_wmma_rate =
            (wmma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);

        constexpr auto num_dsread_a_wmma =
            (num_ds_read_inst_a + ds_read_a_wmma_rate - 1) / ds_read_a_wmma_rate;
        constexpr auto num_dsread_b_wmma =
            (num_ds_read_inst_b + ds_read_b_wmma_rate - 1) / ds_read_b_wmma_rate;

        // stage 1
        // Separate this part?
        // constexpr auto num_wmma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
        //                                               sizeof(ComputeDataType) / sizeof(BDataType)
        //                                           ? sizeof(ComputeDataType) / sizeof(ADataType)
        //                                           : sizeof(ComputeDataType) / sizeof(BDataType);
        constexpr auto num_wmma_stage1 = num_wmma_inst - (num_dsread_a_wmma + num_dsread_b_wmma);
        constexpr auto num_wmma_per_issue =
            num_wmma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
        constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
        constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;

        static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
            ignore = i;
            static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
                ignore = idswrite;
                __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
            });
            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            __builtin_amdgcn_sched_group_barrier(
                0x008, num_wmma_per_issue - num_dswrite_per_issue_a, 0); // WMMA
        });
        static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
            ignore = i;
            static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
                ignore = idswrite;
                __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
            });
            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            __builtin_amdgcn_sched_group_barrier(
                0x008, num_wmma_per_issue - num_dswrite_per_issue_b, 0); // WMMA
        });

        // stage 2
        static_for<0, num_dsread_a_wmma, 1>{}([&](auto i) {
            if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_wmma_rate) >=
                         ds_read_a_wmma_rate)
            {
                __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_wmma_rate, 0); // DS read
            }
            else
            {
                __builtin_amdgcn_sched_group_barrier(0x100,
                                                     num_ds_read_inst_a - (num_dsread_a_wmma - 1) *
                                                                              ds_read_a_wmma_rate,
                                                     0); // DS read
            }
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
        });

        static_for<0, num_dsread_b_wmma, 1>{}([&](auto i) {
            if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_wmma_rate) >=
                         ds_read_b_wmma_rate)
            {
                __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_wmma_rate, 0); // DS read
            }
            else
            {
                __builtin_amdgcn_sched_group_barrier(0x100,
                                                     num_ds_read_inst_b - (num_dsread_b_wmma - 1) *
                                                                              ds_read_b_wmma_rate,
                                                     0); // DS read
            }
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
        });
        */
    }
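    // The (currently disabled) schedule above interleaves instruction groups with
    // __builtin_amdgcn_sched_group_barrier; the mask arguments select the group being
    // released, matching the inline comments: 0x008 = MFMA/WMMA, 0x020 = VMEM read,
    // 0x100 = DS read, 0x200 = DS write.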

    template <typename ABlockBuffer,
              typename AThreadBuffer,
              typename BBlockBuffer,
              typename BThreadBuffer,
              typename BScaleStruct>
    __device__ inline void LocalLoad(ABlockBuffer& a_block_buf,
                                     AThreadBuffer& a_thread_buf,
                                     BBlockBuffer& b_block_buf,
                                     BThreadBuffer& b_thread_buf,
                                     BScaleStruct& b_scale_struct) const
    {
        static_for<0, KRepeat, 1>{}([&](auto k0) {
            static_for<0, MRepeat, 1>{}([&](auto m0) {
                a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
                                   make_tuple(I0, m0, k0, I0, I0, I0, I0),
                                   a_block_buf,
                                   a_thread_desc_,
                                   make_tuple(I0, m0, k0, I0, I0, I0, I0),
                                   a_thread_buf);
            });

            if constexpr(ck::is_same_v<BScaleStruct, Empty>)
            {
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
                                       make_tuple(I0, n0, k0, I0, I0, I0, I0),
                                       b_block_buf,
                                       b_thread_desc_,
                                       make_tuple(I0, n0, k0, I0, I0, I0, I0),
                                       b_thread_buf);
                });
            }
            else
            {
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
                                       make_tuple(I0, n0, k0, I0, I0, I0, I0),
                                       b_block_buf,
                                       b_scale_struct.scale_thread_bufs(
                                           I0)[Number<n0 * BScaleStruct::num_scale_k_block +
                                                      k0 / BScaleStruct::num_scale_krepeat>{}],
                                       b_thread_desc_,
                                       make_tuple(I0, n0, k0, I0, I0, I0, I0),
                                       b_thread_buf);
                });
            }
        });
    }
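    // LocalLoad copies the LDS tiles into per-thread VGPR buffers for the next round of
    // WMMA. A is always copied directly; B is copied either directly (BScaleStruct is
    // Empty) or with the per-(n0, k0) scale value from b_scale_struct passed to the copy.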

    template <bool HasMainLoop,
              TailNumber TailNum,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer,
              typename AScaleStruct,
              typename BScaleStruct,
              typename enable_if<ck::is_same_v<AScaleStruct, Empty>, bool>::type = false>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        AScaleStruct&,
                        BScaleStruct& b_scale_struct,
                        index_t num_loop,
                        index_t num_loop_per_scale) const
    {
        __builtin_amdgcn_sched_barrier(0);

        constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();

        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
            b_thread_desc_.GetElementSpaceSize());

        // Global prefetch 1
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Scales global load
        b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);

        // Local prefill 1
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        // Global prefetch 2, perform when at least 2 loops exist.
        if constexpr(TailNum == TailNumber::Even || TailNum == TailNumber::Full)
        {
            a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
            b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
        }

        // Initialize C
        c_thread_buf.Clear();

        // Local prefetch 1
        block_sync_lds();

        LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);

        __builtin_amdgcn_sched_barrier(0);

        // Main body, perform when at least 3 loops exist.
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                block_sync_lds();

                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);

                static_for<0, KRepeat, 1>{}([&](auto k0) {
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            static_for<0, KInner, 1>{}([&](auto k_inner) {
                                vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
                                vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;

                                static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
                                    constexpr index_t kk = ik + k_inner * KPerWaveBlock;
                                    a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                                        a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                            make_tuple(Number<kk / A_K1>{},
                                                       m0,
                                                       k0,
                                                       I0,
                                                       I0,
                                                       I0,
                                                       Number<kk % A_K1>{}))>{}];
                                });
                                static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
                                    constexpr index_t kk = ik + k_inner * KPerWaveBlock;
                                    b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                                        b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                            make_tuple(Number<kk / B_K1>{},
                                                       n0,
                                                       k0,
                                                       I0,
                                                       I0,
                                                       I0,
                                                       Number<kk % B_K1>{}))>{}];
                                });

                                using wmma_input_type_a =
                                    typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
                                using wmma_input_type_b =
                                    typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

                                constexpr index_t c_offset =
                                    c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

                                wmma_gemm.Run(
                                    a_thread_vec.template AsType<wmma_input_type_a>(),
                                    b_thread_vec.template AsType<wmma_input_type_b>(),
                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                            });
                        });
                    });
                });

                block_sync_lds();

                LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);

                HotLoopScheduler();
                __builtin_amdgcn_sched_barrier(0);

                i += 1;
            } while(i < (num_loop - 2));
        }
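        // The hot loop runs num_loop - 2 iterations: the compute above consumes the tile
        // prefetched into VGPRs in the previous iteration, while the RunWrite/RunRead pair
        // stages the next tile into LDS and the tile after that from global memory.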

        // Pre-tail, perform when at least 2 loops exist.
        if constexpr(TailNum == TailNumber::Even || TailNum == TailNumber::Full)
        {
            block_sync_lds();

            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
            b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

            // No RunRead or MoveSrcSliceWindow here, already finished them all!

            b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);

            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        static_for<0, KInner, 1>{}([&](auto k_inner) {
                            vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
                            vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;

                            static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
                                constexpr index_t kk = ik + k_inner * KPerWaveBlock;
                                a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                        make_tuple(Number<kk / A_K1>{},
                                                   m0,
                                                   k0,
                                                   I0,
                                                   I0,
                                                   I0,
                                                   Number<kk % A_K1>{}))>{}];
                            });
                            static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
                                constexpr index_t kk = ik + k_inner * KPerWaveBlock;
                                b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                        make_tuple(Number<kk / B_K1>{},
                                                   n0,
                                                   k0,
                                                   I0,
                                                   I0,
                                                   I0,
                                                   Number<kk % B_K1>{}))>{}];
                            });

                            using wmma_input_type_a =
                                typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
                            using wmma_input_type_b =
                                typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

                            constexpr index_t c_offset =
                                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

                            wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                                          b_thread_vec.template AsType<wmma_input_type_b>(),
                                          c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                        });
                    });
                });
            });

            block_sync_lds();

            LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);

            HotLoopScheduler();
            __builtin_amdgcn_sched_barrier(0);
        }

        // Tail, always perform.
        {
            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        static_for<0, KInner, 1>{}([&](auto k_inner) {
                            vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
                            vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;

                            static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
                                constexpr index_t kk = ik + k_inner * KPerWaveBlock;
                                a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                        make_tuple(Number<kk / A_K1>{},
                                                   m0,
                                                   k0,
                                                   I0,
                                                   I0,
                                                   I0,
                                                   Number<kk % A_K1>{}))>{}];
                            });
                            static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
                                constexpr index_t kk = ik + k_inner * KPerWaveBlock;
                                b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                        make_tuple(Number<kk / B_K1>{},
                                                   n0,
                                                   k0,
                                                   I0,
                                                   I0,
                                                   I0,
                                                   Number<kk % B_K1>{}))>{}];
                            });

                            using wmma_input_type_a =
                                typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
                            using wmma_input_type_b =
                                typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

                            constexpr index_t c_offset =
                                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

                            wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                                          b_thread_vec.template AsType<wmma_input_type_b>(),
                                          c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                        });
                    });
                });
            });
            // Let's leak last WMMA block to epilogue region, cover the potential lds-shuffle
            // latency
            // __builtin_amdgcn_sched_barrier(0);
        }
    }

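    // Overload used when both A and B carry block scales (AScaleStruct and BScaleStruct
    // are non-Empty). Partial sums are accumulated per scale block and folded into
    // c_thread_buf through CScaleStruct::UpdateCThreadBuf.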
    template <bool HasMainLoop,
              TailNumber TailNum,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer,
              typename AScaleStruct,
              typename BScaleStruct,
              typename enable_if<!ck::is_same_v<AScaleStruct, Empty> &&
                                     !ck::is_same_v<BScaleStruct, Empty>,
                                 bool>::type = false>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        AScaleStruct& a_scale_struct,
                        BScaleStruct& b_scale_struct,
                        index_t num_loop,
                        index_t num_loop_per_scale) const
    {
        __builtin_amdgcn_sched_barrier(0);

        constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
        static constexpr auto NumScaleKBlock =
            Number<ck::math::max(AScaleStruct::num_slice_k, BScaleStruct::num_slice_k)>{};

        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
            b_thread_desc_.GetElementSpaceSize());

        using CScaleStruct = typename Base::template CScale<AScaleStruct, BScaleStruct>;
        auto c_scale_struct = CScaleStruct{};

        // Global prefetch 1
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Scales global load
        a_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
        b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);

        c_scale_struct.Load(a_scale_struct, b_scale_struct);

        // Local prefill 1
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        // Global prefetch 2, perform when at least 2 loops exist.
        if constexpr(TailNum == TailNumber::Even || TailNum == TailNumber::Full)
        {
            a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
            b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
        }

        // Initialize C
        c_thread_buf.Clear();

        // Local prefetch 1
        block_sync_lds();

        auto local_load_func = [&]() {
            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
                                       make_tuple(I0, m0, k0, I0, I0, I0, I0),
                                       a_block_buf,
                                       a_thread_desc_,
                                       make_tuple(I0, m0, k0, I0, I0, I0, I0),
                                       a_thread_buf);
                });
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
                                       make_tuple(I0, n0, k0, I0, I0, I0, I0),
                                       b_block_buf,
                                       b_thread_desc_,
                                       make_tuple(I0, n0, k0, I0, I0, I0, I0),
                                       b_thread_buf);
                });
            });
        };
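        // Unlike LocalLoad() above, this lambda copies A and B from LDS without touching
        // the scale buffers: in this path the A/B block scales are combined and applied
        // at accumulation time through c_scale_struct instead.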

        local_load_func();

        __builtin_amdgcn_sched_barrier(0);

        // Main body, perform when at least 3 loops exist.
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                block_sync_lds();

                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                a_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
                b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);

                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) {
                            c_scale_struct.Clear();
                            static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
                                static_for<0, KInner, 1>{}([&](auto k_inner) {
                                    vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
                                    vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;

                                    static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
                                        constexpr index_t kk = ik + k_inner * KPerWaveBlock;
                                        constexpr index_t k_index =
                                            kscale0 * (KRepeat / NumScaleKBlock) + k0;
                                        a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                                            a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                                make_tuple(Number<kk / A_K1>{},
                                                           m0,
                                                           k_index,
                                                           I0,
                                                           I0,
                                                           I0,
                                                           Number<kk % A_K1>{}))>{}];
                                    });
                                    static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
                                        constexpr index_t kk = ik + k_inner * KPerWaveBlock;
                                        constexpr index_t k_index =
                                            kscale0 * (KRepeat / NumScaleKBlock) + k0;
                                        b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                                            b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                                make_tuple(Number<kk / B_K1>{},
                                                           n0,
                                                           k_index,
                                                           I0,
                                                           I0,
                                                           I0,
                                                           Number<kk % B_K1>{}))>{}];
                                    });

                                    using wmma_input_type_a =
                                        typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
                                    using wmma_input_type_b =
                                        typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

                                    wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                                                  b_thread_vec.template AsType<wmma_input_type_b>(),
                                                  c_scale_struct.c_thread_buf_per_scale
                                                      .GetVectorTypeReference(Number<0>{}));
                                });
                            });
                            c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
                        });
                    });
                });

                c_scale_struct.Load(a_scale_struct, b_scale_struct);
                block_sync_lds();

                local_load_func();

                HotLoopScheduler();
                __builtin_amdgcn_sched_barrier(0);

                i += 1;
            } while(i < (num_loop - 2));
        }
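        // Note: unlike the unscaled Run above, this hot loop keeps m0/n0/kscale0 outermost
        // so each scale block accumulates into c_thread_buf_per_scale first and is only
        // then merged into c_thread_buf via UpdateCThreadBuf.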

        // Pre-tail, perform when at least 2 loops exist.
        if constexpr(TailNum == TailNumber::Even || TailNum == TailNumber::Full)
        {
            block_sync_lds();

            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
            b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

            // No RunRead or MoveSrcSliceWindow here, already finished them all!
            a_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
            b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);

            static_for<0, MRepeat, 1>{}([&](auto m0) {
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) {
                        c_scale_struct.Clear();
                        static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
                            static_for<0, KInner, 1>{}([&](auto k_inner) {
                                vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
                                vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;

                                static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
                                    constexpr index_t kk = ik + k_inner * KPerWaveBlock;
                                    constexpr index_t k_index =
                                        kscale0 * (KRepeat / NumScaleKBlock) + k0;
                                    a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                                        a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                            make_tuple(Number<kk / A_K1>{},
                                                       m0,
                                                       k_index,
                                                       I0,
                                                       I0,
                                                       I0,
                                                       Number<kk % A_K1>{}))>{}];
                                });
                                static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
                                    constexpr index_t kk = ik + k_inner * KPerWaveBlock;
                                    constexpr index_t k_index =
                                        kscale0 * (KRepeat / NumScaleKBlock) + k0;
                                    b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                                        b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                            make_tuple(Number<kk / B_K1>{},
                                                       n0,
                                                       k_index,
                                                       I0,
                                                       I0,
                                                       I0,
                                                       Number<kk % B_K1>{}))>{}];
                                });

                                using wmma_input_type_a =
                                    typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
                                using wmma_input_type_b =
                                    typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

                                wmma_gemm.Run(
                                    a_thread_vec.template AsType<wmma_input_type_a>(),
                                    b_thread_vec.template AsType<wmma_input_type_b>(),
                                    c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(
                                        Number<0>{}));
                            });
                        });
                        c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
                    });
                });
            });

            c_scale_struct.Load(a_scale_struct, b_scale_struct);
            block_sync_lds();

            local_load_func();

            HotLoopScheduler();
            __builtin_amdgcn_sched_barrier(0);
        }

        // Tail, always perform.
        {
            static_for<0, MRepeat, 1>{}([&](auto m0) {
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) {
                        c_scale_struct.Clear();
                        static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
                            vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
                            vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
                            static_for<0, KInner, 1>{}([&](auto k_inner) {
                                static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
                                    constexpr index_t kk = ik + k_inner * KPerWaveBlock;
                                    constexpr index_t k_index =
                                        kscale0 * (KRepeat / NumScaleKBlock) + k0;
                                    a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                                        a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                            make_tuple(Number<kk / A_K1>{},
                                                       m0,
                                                       k_index,
                                                       I0,
                                                       I0,
                                                       I0,
                                                       Number<kk % A_K1>{}))>{}];
                                });
                                static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
                                    constexpr index_t kk = ik + k_inner * KPerWaveBlock;
                                    constexpr index_t k_index =
                                        kscale0 * (KRepeat / NumScaleKBlock) + k0;
                                    b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                                        b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                            make_tuple(Number<kk / B_K1>{},
                                                       n0,
                                                       k_index,
                                                       I0,
                                                       I0,
                                                       I0,
                                                       Number<kk % B_K1>{}))>{}];
                                });

                                using wmma_input_type_a =
                                    typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
                                using wmma_input_type_b =
                                    typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

                                wmma_gemm.Run(
                                    a_thread_vec.template AsType<wmma_input_type_a>(),
                                    b_thread_vec.template AsType<wmma_input_type_b>(),
                                    c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(
                                        Number<0>{}));
                            });
                        });
                        c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
                    });
                });
            });
            // Let's leak last WMMA block to epilogue region, cover the potential lds-shuffle
            // latency
            // __builtin_amdgcn_sched_barrier(0);
        }
    }

    protected:
    using Base::a_thread_copy_;
    using Base::a_thread_desc_;
    using Base::b_thread_copy_;
    using Base::b_thread_desc_;
    using Base::c_thread_desc_;
};

} // namespace ck
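The tail/hot-loop dispatch used by this pipeline can also be exercised on the host. Below is a minimal standalone sketch, not part of this header: the enum, constant, and functions are reproduced here purely for illustration and mirror BlockHasHotloop and BlockLoopTailNum above.

// Standalone illustration only; mirrors the constexpr dispatch logic of the pipeline.
#include <cstdio>

enum class TailNumber { Odd, Even, Full };

constexpr int PrefetchStages = 2;

constexpr bool BlockHasHotloop(int num_loop) { return num_loop > PrefetchStages; }

constexpr TailNumber BlockLoopTailNum(int num_loop)
{
    if(BlockHasHotloop(num_loop))
    {
        return TailNumber::Full;
    }
    return num_loop == 1 ? TailNumber::Odd : TailNumber::Even;
}

int main()
{
    // Print, for a few K-loop counts, whether a hot loop exists and which tail is chosen.
    for(int num_loop : {1, 2, 3, 8})
    {
        const TailNumber tail = BlockLoopTailNum(num_loop);
        const char* name      = tail == TailNumber::Full  ? "Full"
                                : tail == TailNumber::Odd ? "Odd"
                                                          : "Even";
        std::printf("num_loop = %d -> hot loop: %s, tail: %s\n",
                    num_loop,
                    BlockHasHotloop(num_loop) ? "yes" : "no",
                    name);
    }
    return 0;
}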