/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v3.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v3.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v3.hpp Source File
gridwise_gemm_pipeline_v3.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 
8 namespace ck {
9 
11 {
12  __host__ __device__ static constexpr bool IsSupported(index_t)
13  {
14  // TODO: improve applicability
15  return true;
16  }
17 
18  template <typename AGridDesc,
19  typename ABlockDesc,
20  typename ABlockTransfer,
21  typename AGridBuffer,
22  typename ABlockBuffer,
23  typename ABlockTransferStep,
24  typename BGridDesc,
25  typename BBlockDesc,
26  typename BBlockTransfer,
27  typename BGridBuffer,
28  typename BBlockBuffer,
29  typename BBlockTransferStep,
30  typename BlockwiseGemm,
31  typename CThreadBuffer>
32  __device__ static void Run(const AGridDesc& a_grid_desc,
33  const ABlockDesc& a_block_desc,
34  ABlockTransfer& a_blockwise_copy,
35  const AGridBuffer& a_grid_buf,
36  ABlockBuffer& a_block_buf,
37  const ABlockTransferStep& a_block_copy_step,
38  const BGridDesc& b_grid_desc,
39  const BBlockDesc& b_block_desc,
40  BBlockTransfer& b_blockwise_copy,
41  const BGridBuffer& b_grid_buf,
42  BBlockBuffer& b_block_buf,
43  const BBlockTransferStep& b_block_copy_step,
44  const BlockwiseGemm& blockwise_gemm,
45  CThreadBuffer& c_thread_buf,
46  index_t num_loop)
47  {
48  // global read 0
49  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
50  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
51 
52  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
53  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
54 
55  // Initialize C
56  c_thread_buf.Clear();
57 
58  // LDS write 0
59  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
60  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
61 
62  num_loop--;
63 
64  while(num_loop > 0)
65  {
66  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
68  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
69 
70  blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
71 
73 
74  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
75  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
76  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
77  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
78 
79  num_loop--;
80  }
81  // tail
82  {
84  blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
85  }
86  }
87 };
88 
89 } // namespace ck
Definition: ck.hpp:267
int32_t index_t
Definition: ck.hpp:298
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
Definition: gridwise_gemm_pipeline_v3.hpp:11
static __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, const BlockwiseGemm &blockwise_gemm, CThreadBuffer &c_thread_buf, index_t num_loop)
Definition: gridwise_gemm_pipeline_v3.hpp:32
__host__ static constexpr __device__ bool IsSupported(index_t)
Definition: gridwise_gemm_pipeline_v3.hpp:12