/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_default_policy.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_default_policy.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_default_policy.hpp Source File
smoothquant_pipeline_default_policy.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 namespace ck_tile {
11 
13 {
14  template <typename Problem>
16  {
17  using S = typename Problem::BlockShape;
18 
21  sequence<>,
28  }
29 
30  template <typename Problem>
32  {
33  using S = typename Problem::BlockShape;
34 
42  sequence<0, 3>>{});
43  }
44 
// Returns a default-constructed BlockReduce2d instance for the given Problem.
//
// The reduce problem type (P_) is a BlockReduce2dProblem whose first two type
// arguments are both Problem::ComputeDataType and whose third is
// Problem::BlockShape. Problem must therefore expose the nested types
// ComputeDataType and BlockShape.
45  template <typename Problem>
46  CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d()
47  {
48  using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
49  typename Problem::ComputeDataType,
50  typename Problem::BlockShape>;
// The returned object is value-initialised with `{}`; no arguments are passed.
51  return BlockReduce2d<P_>{};
52  }
53 
// Returns a default-constructed BlockReduce2dSync instance for the given Problem.
//
// Uses the same reduce problem type as GetBlockReduce2d(): a
// BlockReduce2dProblem built from Problem::ComputeDataType (twice) and
// Problem::BlockShape. Problem must expose the nested types ComputeDataType
// and BlockShape.
54  template <typename Problem>
55  CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync()
56  {
57  using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
58  typename Problem::ComputeDataType,
59  typename Problem::BlockShape>;
// The returned object is value-initialised with `{}`; no arguments are passed.
60  return BlockReduce2dSync<P_>{};
61  }
62 
// NOTE(review): this listing is incomplete. Listing lines 64 (the function
// signature) and 69 (the statement after the alias) are absent from this
// rendering. The symbol index at the end of the page gives line 64 as
// `static constexpr CK_TILE_HOST_DEVICE auto GetBlockReduce2dCrossWarpSync()`.
// Consult the original header before relying on the body shown here.
63  template <typename Problem>
65  {
// Reduce problem type: Problem::ComputeDataType for the first two type
// arguments and Problem::BlockShape for the third, matching the alias used by
// GetBlockReduce2d() and GetBlockReduce2dSync().
66  using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
67  typename Problem::ComputeDataType,
68  typename Problem::BlockShape>;
// Listing line 69 is missing here; presumably it returns an object built from
// P_, as the sibling getters do -- confirm against the header.
70  }
71 
// NOTE(review): listing line 73 (the function signature) is absent from this
// rendering. The symbol index at the end of the page gives it as
// `static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()`.
//
// Returns the size reported by the cross-warp reduce helper when
// Problem::kNeedCrossWarpSync is true, and 1 otherwise. The branch is chosen
// with `if constexpr`.
72  template <typename Problem>
74  {
75  if constexpr(Problem::kNeedCrossWarpSync)
76  {
// Unlike the Get* helper getters in this struct, this alias uses
// Problem::XDataType as the first type argument. It is used only to name
// block_reduce2d for the tile-type deduction below.
77  using P_ = BlockReduce2dProblem<typename Problem::XDataType,
78  typename Problem::ComputeDataType,
79  typename Problem::BlockShape>;
80 
81  using block_reduce2d = BlockReduce2d<P_>;
// x_block_tile: type of the static distributed tensor of Problem::XDataType
// created with MakeXBlockTileDistribution<Problem>(). Only the type is taken,
// via decltype.
82  using x_block_tile =
83  decltype(make_static_distributed_tensor<typename Problem::XDataType>(
84  MakeXBlockTileDistribution<Problem>()));
// y_block_tile: type returned by block_reduce2d::MakeYBlockTile for that x tile.
85  using y_block_tile = decltype(block_reduce2d::template MakeYBlockTile<x_block_tile>());
86 
// Delegate the size computation to the cross-warp sync object, passing
// y_block_tile as its template argument.
87  return GetBlockReduce2dCrossWarpSync<Problem>().template GetSmemSize<y_block_tile>();
88  }
89  else
90  {
// Falls back to 1 rather than 0. Presumably the value sizes an array, which
// must not be zero-length in standard C++ -- confirm at the call site.
91  return 1; // zero size arrays are an extension
92  }
93  }
94 };
95 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:480
Definition: block_reduce2d.hpp:200
Definition: block_reduce2d.hpp:45
Definition: block_reduce2d_problem.hpp:12
Definition: block_reduce2d.hpp:135
Definition: smoothquant_pipeline_default_policy.hpp:13
static constexpr CK_TILE_HOST_DEVICE auto GetBlockReduce2d()
Definition: smoothquant_pipeline_default_policy.hpp:46
static constexpr CK_TILE_HOST_DEVICE auto GetBlockReduce2dCrossWarpSync()
Definition: smoothquant_pipeline_default_policy.hpp:64
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: smoothquant_pipeline_default_policy.hpp:73
static constexpr CK_TILE_DEVICE auto MakeXBlockTileDistribution()
Definition: smoothquant_pipeline_default_policy.hpp:15
static constexpr CK_TILE_DEVICE auto MakeSmoothScaleBlockTileDistribution()
Definition: smoothquant_pipeline_default_policy.hpp:31
static constexpr CK_TILE_HOST_DEVICE auto GetBlockReduce2dSync()
Definition: smoothquant_pipeline_default_policy.hpp:55
Definition: sequence.hpp:49
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192