/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp Source File
gridwise_gemm_pipeline_v1.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
9 
10 namespace ck {
11 
12 template <index_t NumPrefetch, bool AEnableLds, bool BEnableLds>
14 
15 // 1-stage prefetch
16 template <>
17 struct GridwiseGemmPipeline_v1<1, true, true>
18 {
19  static constexpr auto I0 = Number<0>{};
20  static constexpr auto I1 = Number<1>{};
21 
22  __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
23 
24  __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
25  {
26  return num_loop > 1;
27  }
28 
29  template <bool HasMainLoop,
30  typename AGridDesc,
31  typename ABlockDesc,
32  typename ABlockTransfer,
33  typename AGridBuffer,
34  typename ABlockBuffer,
35  typename ABlockTransferStep,
36  typename BGridDesc,
37  typename BBlockDesc,
38  typename BBlockTransfer,
39  typename BGridBuffer,
40  typename BBlockBuffer,
41  typename BBlockTransferStep,
42  typename BlockwiseGemm,
43  typename CThreadBuffer>
44  __device__ static void Run(const AGridDesc& a_grid_desc,
45  const ABlockDesc& a_block_desc,
46  ABlockTransfer& a_blockwise_copy,
47  const AGridBuffer& a_grid_buf,
48  ABlockBuffer& a_block_buf,
49  const ABlockTransferStep& a_block_copy_step,
50  const BGridDesc& b_grid_desc,
51  const BBlockDesc& b_block_desc,
52  BBlockTransfer& b_blockwise_copy,
53  const BGridBuffer& b_grid_buf,
54  BBlockBuffer& b_block_buf,
55  const BBlockTransferStep& b_block_copy_step,
56  const BlockwiseGemm& blockwise_gemm,
57  CThreadBuffer& c_thread_buf,
58  index_t num_loop)
59  {
60  // preload data into LDS
61  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
62  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
63 
64  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
65  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
66 
67  // Initialize C
68  c_thread_buf.Clear();
69 
70  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
71  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
72 
73  // main body
74  if constexpr(HasMainLoop)
75  {
76  index_t i = 0;
77 
78  do
79  {
80  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
81 
83 
84  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
85 
86  blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
87 
89 
90  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
91  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
92 
93  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
94  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
95 
96  ++i;
97  } while(i < (num_loop - 1));
98  }
99 
100  // tail
101  {
102  block_sync_lds();
103 
104  blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
105  }
106  }
107 };
108 
109 // 2-stage prefetch
110 template <>
111 struct GridwiseGemmPipeline_v1<2, true, true>
112 {
113  static constexpr auto I0 = Number<0>{};
114  static constexpr auto I1 = Number<1>{};
115 
116  __host__ __device__ static constexpr bool IsSupported(index_t num_loop)
117  {
118  // TODO: improve applicability
119  return num_loop % 2 == 0;
120  }
121 
122  __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
123  {
124  return (num_loop / 2) > 1;
125  }
126 
127  template <bool HasMainLoop,
128  typename AGridDesc,
129  typename ABlockDesc,
130  typename ABlockTransfer,
131  typename AGridBuffer,
132  typename ABlockBuffer,
133  typename ABlockTransferStep,
134  typename BGridDesc,
135  typename BBlockDesc,
136  typename BBlockTransfer,
137  typename BGridBuffer,
138  typename BBlockBuffer,
139  typename BBlockTransferStep,
140  typename BlockwiseGemm,
141  typename CThreadBuffer>
142  static __device__ void Run(const AGridDesc& a_grid_desc,
143  const ABlockDesc& a_block_desc,
144  ABlockTransfer& a_blockwise_copy,
145  const AGridBuffer& a_grid_buf,
146  ABlockBuffer& a_block_buf,
147  const ABlockTransferStep& a_block_copy_step,
148  const BGridDesc& b_grid_desc,
149  const BBlockDesc& b_block_desc,
150  BBlockTransfer& b_blockwise_copy,
151  const BGridBuffer& b_grid_buf,
152  BBlockBuffer& b_block_buf,
153  const BBlockTransferStep& b_block_copy_step,
154  const BlockwiseGemm& blockwise_gemm,
155  CThreadBuffer& c_thread_buf,
156  index_t num_loop)
157  {
158  // preload data into LDS
159  {
160  // Read 0
161  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
162  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
163 
164  // Move
165  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
166  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
167 
168  // Read 1
169  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
170  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);
171  }
172 
173  // Initialize C
174  c_thread_buf.Clear();
175 
176  // main body
177  if constexpr(HasMainLoop)
178  {
179  index_t i = 0;
180 
181  do
182  {
183  // Move
184  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
185  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
186 
187  // Write i
188  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
189  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
190 
191  // Read i+2
192  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
193  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
194 
195  // Sync
196  block_sync_lds();
197 
198  // Gemm i
199  blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
200 
201  // Sync
202  block_sync_lds();
203 
204  // Move
205  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
206  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
207 
208  // Write i+1
209  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1);
210  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1);
211 
212  // Read i+3
213  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
214  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);
215 
216  // Sync
217  block_sync_lds();
218 
219  // Gemm i+1
220  blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
221 
222  // Sync
223  block_sync_lds();
224 
225  i += 2;
226  } while(i < (num_loop - 2));
227  }
228 
229  // tail
230  {
231  // Write num_loop - 2
232  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
233  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
234 
235  // Sync
236  block_sync_lds();
237 
238  // Gemm num_loop - 2
239  blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
240 
241  // Sync
242  block_sync_lds();
243 
244  // Write num_loop - 1
245  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1);
246  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1);
247 
248  // Sync
249  block_sync_lds();
250 
251  // Gemm num_loop - 1
252  blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
253  }
254  }
255 };
256 
257 template <>
258 struct GridwiseGemmPipeline_v1<1, false, true>
259 {
260  static constexpr auto I0 = Number<0>{};
261  static constexpr auto I1 = Number<1>{};
262 
263  __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
264 
265  __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
266  {
267  return num_loop > 1;
268  }
269 
270  template <bool HasMainLoop,
271  typename AGridDesc,
272  typename ABlockDesc,
273  typename ABlockTransfer,
274  typename AGridBuffer,
275  typename ABlockBuffer,
276  typename ABlockTransferStep,
277  typename BGridDesc,
278  typename BBlockDesc,
279  typename BBlockTransfer,
280  typename BGridBuffer,
281  typename BBlockBuffer,
282  typename BBlockTransferStep,
283  typename BlockwiseGemm,
284  typename CThreadBuffer>
285  __device__ static void Run(const AGridDesc& a_grid_desc,
286  const ABlockDesc& a_block_desc,
287  ABlockTransfer& a_blockwise_copy,
288  const AGridBuffer& a_grid_buf,
289  ABlockBuffer& a_block_buf,
290  const ABlockTransferStep& a_block_copy_step,
291  const BGridDesc& b_grid_desc,
292  const BBlockDesc& b_block_desc,
293  BBlockTransfer& b_blockwise_copy,
294  const BGridBuffer& b_grid_buf,
295  BBlockBuffer& b_block_buf,
296  const BBlockTransferStep& b_block_copy_step,
297  const BlockwiseGemm& blockwise_gemm,
298  CThreadBuffer& c_thread_buf,
299  index_t num_loop)
300  {
301  constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
302  auto a_block_buf_switch = a_block_buf;
303 
304  // preload data into LDS
305  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
306  a_blockwise_copy.Run(
307  a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf);
308 
309  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
310  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
311 
312  // Initialize C
313  c_thread_buf.Clear();
314 
315  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
316 
317  // main body
318  if constexpr(HasMainLoop)
319  {
320  index_t i = 0;
321 
322  do
323  {
324  a_blockwise_copy.Run(
325  a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf_switch);
326 
327  block_sync_lds();
328 
329  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
330 
331  blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
332 
333  block_sync_lds();
334 
335  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
336  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
337 
338  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
339 
340  a_block_buf = a_block_buf_switch;
341  ++i;
342  } while(i < (num_loop - 1));
343  }
344 
345  // tail
346  {
347  block_sync_lds();
348 
349  blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
350 
351  block_sync_lds();
352  }
353  }
354 };
355 
356 template <>
357 struct GridwiseGemmPipeline_v1<1, true, false>
358 {
359  static constexpr auto I0 = Number<0>{};
360  static constexpr auto I1 = Number<1>{};
361 
362  __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
363 
364  __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
365  {
366  return num_loop > 1;
367  }
368 
369  template <bool HasMainLoop,
370  typename AGridDesc,
371  typename ABlockDesc,
372  typename ABlockTransfer,
373  typename AGridBuffer,
374  typename ABlockBuffer,
375  typename ABlockTransferStep,
376  typename BGridDesc,
377  typename BBlockDesc,
378  typename BBlockTransfer,
379  typename BGridBuffer,
380  typename BBlockBuffer,
381  typename BBlockTransferStep,
382  typename BlockwiseGemm,
383  typename CThreadBuffer>
384  __device__ static void Run(const AGridDesc& a_grid_desc,
385  const ABlockDesc& a_block_desc,
386  ABlockTransfer& a_blockwise_copy,
387  const AGridBuffer& a_grid_buf,
388  ABlockBuffer& a_block_buf,
389  const ABlockTransferStep& a_block_copy_step,
390  const BGridDesc& b_grid_desc,
391  const BBlockDesc& b_block_desc,
392  BBlockTransfer& b_blockwise_copy,
393  const BGridBuffer& b_grid_buf,
394  BBlockBuffer& b_block_buf,
395  const BBlockTransferStep& b_block_copy_step,
396  const BlockwiseGemm& blockwise_gemm,
397  CThreadBuffer& c_thread_buf,
398  index_t num_loop)
399  {
400  constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
401  auto b_block_buf_switch = b_block_buf;
402 
403  // preload data into LDS
404  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
405  b_blockwise_copy.Run(
406  b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf);
407 
408  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
409  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
410 
411  // Initialize C
412  c_thread_buf.Clear();
413 
414  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
415 
416  // main body
417  if constexpr(HasMainLoop)
418  {
419  index_t i = 0;
420 
421  do
422  {
423  b_blockwise_copy.Run(
424  b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf_switch);
425 
426  block_sync_lds();
427 
428  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
429 
430  blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
431 
432  block_sync_lds();
433 
434  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
435  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
436 
437  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
438 
439  b_block_buf = b_block_buf_switch;
440  ++i;
441  } while(i < (num_loop - 1));
442  }
443 
444  // tail
445  {
446  block_sync_lds();
447 
448  blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
449 
450  block_sync_lds();
451  }
452  }
453 };
454 
455 template <>
456 struct GridwiseGemmPipeline_v1<1, false, false>
457 {
458  static constexpr auto I0 = Number<0>{};
459  static constexpr auto I1 = Number<1>{};
460 
461  __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
462 
463  __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
464  {
465  return num_loop > 1;
466  }
467 
468  template <bool HasMainLoop,
469  typename AGridDesc,
470  typename ABlockDesc,
471  typename ABlockTransfer,
472  typename AGridBuffer,
473  typename ABlockBuffer,
474  typename ABlockTransferStep,
475  typename BGridDesc,
476  typename BBlockDesc,
477  typename BBlockTransfer,
478  typename BGridBuffer,
479  typename BBlockBuffer,
480  typename BBlockTransferStep,
481  typename BlockwiseGemm,
482  typename CThreadBuffer>
483  __device__ static void Run(const AGridDesc& a_grid_desc,
484  const ABlockDesc& a_block_desc,
485  ABlockTransfer& a_blockwise_copy,
486  const AGridBuffer& a_grid_buf,
487  ABlockBuffer& a_block_buf,
488  const ABlockTransferStep& a_block_copy_step,
489  const BGridDesc& b_grid_desc,
490  const BBlockDesc& b_block_desc,
491  BBlockTransfer& b_blockwise_copy,
492  const BGridBuffer& b_grid_buf,
493  BBlockBuffer& b_block_buf,
494  const BBlockTransferStep& b_block_copy_step,
495  const BlockwiseGemm& blockwise_gemm,
496  CThreadBuffer& c_thread_buf,
497  index_t num_loop)
498  {
499  constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
500  constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
501  auto b_block_buf_switch = b_block_buf;
502  auto a_block_buf_switch = a_block_buf;
503 
504  // preload data into LDS
505  a_blockwise_copy.Run(
506  a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf);
507  b_blockwise_copy.Run(
508  b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf);
509 
510  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
511  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
512 
513  // Initialize C
514  c_thread_buf.Clear();
515 
516  // main body
517  if constexpr(HasMainLoop)
518  {
519  index_t i = 0;
520 
521  do
522  {
523  a_blockwise_copy.Run(
524  a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf_switch);
525  b_blockwise_copy.Run(
526  b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf_switch);
527 
528  block_sync_lds();
529 
530  blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
531 
532  block_sync_lds();
533 
534  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
535  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
536 
537  a_block_buf = a_block_buf_switch;
538  b_block_buf = b_block_buf_switch;
539  ++i;
540  } while(i < (num_loop - 1));
541  }
542 
543  // tail
544  {
545  block_sync_lds();
546 
547  blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
548 
549  block_sync_lds();
550  }
551  }
552 };
553 
554 template <index_t NumPrefetch, bool AEnableLds, bool BEnableLds>
556 
557 template <>
559 {
560  static constexpr auto I0 = Number<0>{};
561  static constexpr auto I1 = Number<1>{};
562 
563  __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
564 
565  __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
566  {
567  return num_loop > 1;
568  }
569 
570  template <bool HasMainLoop,
571  typename AGridDesc,
572  typename ABlockDesc,
573  typename ABlockTransfer,
574  typename AGridBuffer,
575  typename ABlockBuffer,
576  typename ABlockTransferStep,
577  typename BGridDesc,
578  typename BBlockDesc,
579  typename BBlockTransfer,
580  typename BGridBuffer,
581  typename BBlockBuffer,
582  typename BBlockTransferStep,
583  typename ScaleGridDesc,
584  typename ScaleGridBuffer,
585  typename BlockwiseGemm,
586  typename CThreadBuffer>
587  __device__ static void Run(const AGridDesc& a_grid_desc,
588  const ABlockDesc& a_block_desc,
589  ABlockTransfer& a_blockwise_copy,
590  const AGridBuffer& a_grid_buf,
591  ABlockBuffer& a_block_buf,
592  const ABlockTransferStep& a_block_copy_step,
593  const BGridDesc& b_grid_desc,
594  const BBlockDesc& b_block_desc,
595  BBlockTransfer& b_blockwise_copy,
596  const BGridBuffer& b_grid_buf,
597  BBlockBuffer& b_block_buf,
598  const BBlockTransferStep& b_block_copy_step,
599  const ScaleGridDesc& scale_grid_desc,
600  const ScaleGridBuffer& scale_grid_buf,
601  const BlockwiseGemm& blockwise_gemm,
602  CThreadBuffer& c_thread_buf,
603  index_t num_loop)
604  {
605  // Global Prefetch Stage 1
606  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
607  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
608  // Scale read once
609  b_blockwise_copy.RunScaleRead(scale_grid_desc, scale_grid_buf);
610 
611  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
612  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
613 
614  // Initialize C
615  c_thread_buf.Clear();
616 
617  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
618  // Dequantization fused in blockwise_copy
619  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
620 
621  // main body
622  if constexpr(HasMainLoop)
623  {
624  index_t i = 0;
625 
626  do
627  {
628  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
629 
630  block_sync_lds();
631 
632  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
633 
634  blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
635 
636  block_sync_lds();
637 
638  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
639  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
640 
641  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
642  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
643 
644  ++i;
645  } while(i < (num_loop - 1));
646  }
647 
648  // tail
649  {
650  block_sync_lds();
651 
652  blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
653  }
654  }
655 };
656 
657 template <index_t NumPrefetch>
659 
660 template <>
662 {
663  __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
664 
665  __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
666  {
667  return num_loop > 1;
668  }
669 
670  template <bool HasMainLoop,
671  typename AGridDesc,
672  typename ABlockDesc,
673  typename ABlockTransfer,
674  typename AGridBuffer,
675  typename ABlockBuffer,
676  typename ABlockTransferStep,
677  typename BGridDesc,
678  typename BBlockDesc,
679  typename BBlockTransfer,
680  typename BGridBuffer,
681  typename BBlockBuffer,
682  typename BBlockTransferStep,
683  typename BlockwiseGemm,
684  typename CThreadBuffer>
685  static __device__ void Run(const AGridDesc& a_grid_desc,
686  const ABlockDesc& a_block_desc,
687  ABlockTransfer& a_blockwise_copy,
688  const AGridBuffer& a_grid_buf,
689  ABlockBuffer& a_block_buf,
690  const ABlockTransferStep& a_block_copy_step,
691  const BGridDesc& b_grid_desc,
692  const BBlockDesc& b_block_desc,
693  BBlockTransfer& b_blockwise_copy,
694  const BGridBuffer& b_grid_buf,
695  BBlockBuffer& b_block_buf,
696  const BBlockTransferStep& b_block_copy_step,
697  const BlockwiseGemm& blockwise_gemm,
698  CThreadBuffer& c_thread_buf,
699  index_t num_loop)
700  {
701  // preload data into LDS
702  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
703  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
704 
705  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
706  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
707 
708  // Initialize C
709  c_thread_buf.Clear();
710 
711  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
712  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
713 
714  // main body
715  if constexpr(HasMainLoop)
716  {
717  index_t i = 0;
718 
719  do
720  {
721  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
722 
723  block_sync_lds();
724 
725  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
726 
727  blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
728 
729  // block_sync_lds(); // moved into blockwise_gemm
730 
731  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
732  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
733 
734  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
735  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
736 
737  ++i;
738  } while(i < (num_loop - 1));
739  }
740 
741  // tail
742  {
743  block_sync_lds();
744 
745  blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
746  }
747  }
748 };
749 
750 // Note: 2 stage prefetch not optimized for inter-wave loop scheduler
751 template <>
753 {
754 };
755 
756 // TODO: deprecate as GridwiseGemmPipeline_Selector covers the functionality
757 template <index_t NumPrefetch, LoopScheduler LoopSched>
759 {
760  if constexpr(LoopSched == LoopScheduler::Default)
761  {
763  }
764  else if constexpr(LoopSched == LoopScheduler::Interwave)
765  {
767  }
768 }
769 
770 } // namespace ck
Definition: ck.hpp:267
constexpr auto GridwiseGemmPipeline_v1_Selector()
Definition: gridwise_gemm_pipeline_v1.hpp:758
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:298
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
static __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, const BlockwiseGemm &blockwise_gemm, CThreadBuffer &c_thread_buf, index_t num_loop)
Definition: gridwise_gemm_pipeline_v1.hpp:483
__host__ static constexpr __device__ bool CalculateHasMainLoop(index_t num_loop)
Definition: gridwise_gemm_pipeline_v1.hpp:463
__host__ static constexpr __device__ bool IsSupported(index_t)
Definition: gridwise_gemm_pipeline_v1.hpp:461
__host__ static constexpr __device__ bool CalculateHasMainLoop(index_t num_loop)
Definition: gridwise_gemm_pipeline_v1.hpp:265
__host__ static constexpr __device__ bool IsSupported(index_t)
Definition: gridwise_gemm_pipeline_v1.hpp:263
static __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, const BlockwiseGemm &blockwise_gemm, CThreadBuffer &c_thread_buf, index_t num_loop)
Definition: gridwise_gemm_pipeline_v1.hpp:285
__host__ static constexpr __device__ bool IsSupported(index_t)
Definition: gridwise_gemm_pipeline_v1.hpp:362
static __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, const BlockwiseGemm &blockwise_gemm, CThreadBuffer &c_thread_buf, index_t num_loop)
Definition: gridwise_gemm_pipeline_v1.hpp:384
__host__ static constexpr __device__ bool CalculateHasMainLoop(index_t num_loop)
Definition: gridwise_gemm_pipeline_v1.hpp:364
static __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, const BlockwiseGemm &blockwise_gemm, CThreadBuffer &c_thread_buf, index_t num_loop)
Definition: gridwise_gemm_pipeline_v1.hpp:44
__host__ static constexpr __device__ bool IsSupported(index_t)
Definition: gridwise_gemm_pipeline_v1.hpp:22
__host__ static constexpr __device__ bool CalculateHasMainLoop(index_t num_loop)
Definition: gridwise_gemm_pipeline_v1.hpp:24
static __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, const BlockwiseGemm &blockwise_gemm, CThreadBuffer &c_thread_buf, index_t num_loop)
Definition: gridwise_gemm_pipeline_v1.hpp:142
__host__ static constexpr __device__ bool CalculateHasMainLoop(index_t num_loop)
Definition: gridwise_gemm_pipeline_v1.hpp:122
__host__ static constexpr __device__ bool IsSupported(index_t num_loop)
Definition: gridwise_gemm_pipeline_v1.hpp:116
__host__ static constexpr __device__ bool IsSupported(index_t)
Definition: gridwise_gemm_pipeline_v1.hpp:563
__host__ static constexpr __device__ bool CalculateHasMainLoop(index_t num_loop)
Definition: gridwise_gemm_pipeline_v1.hpp:565
static __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, const ScaleGridDesc &scale_grid_desc, const ScaleGridBuffer &scale_grid_buf, const BlockwiseGemm &blockwise_gemm, CThreadBuffer &c_thread_buf, index_t num_loop)
Definition: gridwise_gemm_pipeline_v1.hpp:587
Definition: gridwise_gemm_pipeline_v1.hpp:555
Definition: gridwise_gemm_pipeline_v1.hpp:13
__host__ static constexpr __device__ bool IsSupported(index_t)
Definition: gridwise_gemm_pipeline_v1.hpp:663
__host__ static constexpr __device__ bool CalculateHasMainLoop(index_t num_loop)
Definition: gridwise_gemm_pipeline_v1.hpp:665
static __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, const BlockwiseGemm &blockwise_gemm, CThreadBuffer &c_thread_buf, index_t num_loop)
Definition: gridwise_gemm_pipeline_v1.hpp:685
Definition: gridwise_gemm_pipeline_v1.hpp:658
Definition: integral_constant.hpp:20