10 template <
typename TileLoadThreadGroup, index_t NumGemmKPrefetchStage>
14 template <
typename TileLoadThreadGroup>
28 template <
bool HasMainLoop,
31 typename ABlockTransfer,
33 typename ABlockBuffer,
34 typename ABlockTransferStep,
37 typename BBlockTransfer,
39 typename BBlockBuffer,
40 typename BBlockTransferStep>
42 const ABlockDesc& a_block_desc,
43 ABlockTransfer& a_blockwise_copy,
44 const AGridBuffer& a_grid_buf,
45 ABlockBuffer& a_block_buf,
46 const ABlockTransferStep& a_block_copy_step,
47 const BGridDesc& b_grid_desc,
48 const BBlockDesc& b_block_desc,
49 BBlockTransfer& b_blockwise_copy,
50 const BGridBuffer& b_grid_buf,
51 BBlockBuffer& b_block_buf,
52 const BBlockTransferStep& b_block_copy_step,
56 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
57 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
60 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
61 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
64 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
65 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
67 if constexpr(HasMainLoop)
76 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
77 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
80 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
81 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
87 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
88 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
91 }
while(i < (num_loop - 1));
102 template <
typename TileMathThreadGroup, index_t NumGemmKPrefetchStage>
105 template <
typename TileMathThreadGroup>
116 template <
bool HasMainLoop,
117 typename ABlockBuffer,
118 typename BBlockBuffer,
119 typename BlockwiseGemm,
120 typename CThreadBuffer>
122 BBlockBuffer& b_block_buf,
123 const BlockwiseGemm& block_gemm,
124 CThreadBuffer& c_thread_buf,
128 c_thread_buf.Clear();
131 if constexpr(HasMainLoop)
140 block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
144 }
while(i < (num_loop - 1));
152 block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
int32_t index_t
Definition: ck.hpp:289
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
static __device__ void RunLoadWavePipeline(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, index_t num_loop)
Definition: gridwise_gemm_waveletmodel.hpp:41
__host__ static constexpr __device__ bool IsSupported(index_t)
Definition: gridwise_gemm_waveletmodel.hpp:17
__host__ static constexpr __device__ bool CalculateHasMainLoop(index_t num_loop)
Definition: gridwise_gemm_waveletmodel.hpp:23
Definition: gridwise_gemm_waveletmodel.hpp:11
static __device__ void RunMathWavePipeline(ABlockBuffer &a_block_buf, BBlockBuffer &b_block_buf, const BlockwiseGemm &block_gemm, CThreadBuffer &c_thread_buf, index_t num_loop)
Definition: gridwise_gemm_waveletmodel.hpp:121
__host__ static constexpr __device__ bool IsSupported(index_t)
Definition: gridwise_gemm_waveletmodel.hpp:109
__host__ static constexpr __device__ bool CalculateHasMainLoop(index_t num_loop)
Definition: gridwise_gemm_waveletmodel.hpp:111
Definition: gridwise_gemm_waveletmodel.hpp:103