Composable Kernel: include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm.hpp Source File

device_moe_mx_gemm.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <array>
#include <iostream>
#include <map>
#include <sstream>

// Project headers. (The original include list was rendered as hyperlinks and lost from
// this page; the set below is reconstructed from the symbols the file uses.)
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename CLayout,
          typename ADataType,
          typename AScaleDataType,
          typename BDataType,
          typename BScaleDataType,
          typename DsDataType,
          typename CDataType,
          typename GemmAccDataType,
          typename CShuffleDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          GemmSpecialization GemmSpec,
          index_t ScaleBlockSize,
          index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t AK1,
          index_t BK1,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MXdlPerWave,
          index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          typename CDEShuffleBlockTransferScalarPerVectors,
          // (two pipeline-selection parameters restored; they were dropped by the link
          // renderer, defaults assumed per CK convention)
          BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
          BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
          index_t ActivationOP = 0,
          bool NSwizzle = false,
          bool IsInputGemm = true,
          bool MulRoutedWeight = true,
          typename IndexType = index_t,
          typename ComputeTypeA = ADataType,
          typename ComputeTypeB = BDataType>
// NOTE: the declaration line below was rendered as hyperlinks and lost from this page;
// the struct name and its base interface (declared in device_gemm_multiple_d.hpp) are
// reconstructions.
struct DeviceMoeGemmMX : public DeviceMoeGemmMXBase<ALayout,
                                                    BLayout,
                                                    DsLayout,
                                                    CLayout,
                                                    ADataType,
                                                    AScaleDataType,
                                                    BDataType,
                                                    BScaleDataType,
                                                    DsDataType,
                                                    CDataType,
                                                    ScaleBlockSize,
                                                    AElementwiseOperation,
                                                    BElementwiseOperation,
                                                    CElementwiseOperation>
{
    GET_NXDL_PER_WAVE_IMPL
    static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
    static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
    static constexpr index_t NumDTensor = DsDataType::Size();

    // Gridwise GEMM parameterized on NXdlPerWave so that wave64 and wave32 variants can
    // be instantiated from one device op. (The alias head was lost in rendering;
    // GridwiseGemmT is a reconstructed name.)
    template <index_t NXdlPerWave_>
    using GridwiseGemmT =
        GridwiseMoeGemmMX<ALayout,
                          BLayout,
                          DsLayout,
                          CLayout,
                          ADataType,
                          AScaleDataType,
                          BDataType,
                          BScaleDataType,
                          GemmAccDataType,
                          CShuffleDataType,
                          DsDataType,
                          CDataType,
                          AElementwiseOperation,
                          BElementwiseOperation,
                          CElementwiseOperation,
                          GemmSpec,
                          ScaleBlockSize,
                          BlockSize,
                          MPerBlock,
                          NPerBlock,
                          KPerBlock,
                          AK1,
                          BK1,
                          MPerXDL,
                          NPerXDL,
                          MXdlPerWave,
                          NXdlPerWave_,
                          ABlockTransferThreadClusterLengths_AK0_M_AK1,
                          ABlockTransferThreadClusterArrangeOrder,
                          ABlockTransferSrcAccessOrder,
                          ABlockTransferSrcVectorDim,
                          ABlockTransferSrcScalarPerVector,
                          ABlockTransferDstScalarPerVector_AK1,
                          false,
                          ABlockLdsExtraM,
                          BBlockTransferThreadClusterLengths_BK0_N_BK1,
                          BBlockTransferThreadClusterArrangeOrder,
                          BBlockTransferSrcAccessOrder,
                          BBlockTransferSrcVectorDim,
                          BBlockTransferSrcScalarPerVector,
                          BBlockTransferDstScalarPerVector_BK1,
                          false,
                          BBlockLdsExtraN,
                          CShuffleMXdlPerWavePerShuffle,
                          CShuffleNXdlPerWavePerShuffle,
                          CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                          CDEShuffleBlockTransferScalarPerVectors,
                          BlkGemmPipeSched,
                          BlkGemmPipelineVer,
                          ActivationOP,
                          NSwizzle,
                          IsInputGemm,
                          MulRoutedWeight,
                          IndexType,
                          ComputeTypeA,
                          ComputeTypeB>;

    using GridwiseGemm64 = GridwiseGemmT<NXdlPerWave64>;
    using GridwiseGemm32 = GridwiseGemmT<NXdlPerWave32>;

    using Argument = typename GridwiseGemm64::Argument;

    static constexpr index_t APackedSize = packed_size_v<ADataType>;
    static constexpr index_t BPackedSize = packed_size_v<BDataType>;

    int GetPreShuffleParameters() override { return NPerXDL; }
    // Invoker
    struct Invoker : public BaseInvoker
    {
        template <typename GridwiseGemm>
        float RunImp(const typename GridwiseGemm::Argument& arg,
                     const StreamConfig& stream_config = StreamConfig{})
        {
            if(stream_config.log_level_ > 0)
            {
                arg.Print();
            }

            if(!GridwiseGemm::CheckValidity(arg))
            {
                throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
            }

            index_t gdx, gdy, gdz;
            std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);

            float ave_time = 0;

            index_t k_grain = arg.KBatch * KPerBlock;
            index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;

            const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
            const auto RunKernel = [&](const auto& kernel) {
                if(stream_config.flush_cache)
                {
                    std::array<std::size_t, NumDTensor> DsSize;

                    auto arg_ = arg;

                    const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
                        arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
                    const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
                        arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);

                    auto size_a_buffer =
                        a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
                    auto size_b_buffer =
                        b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);

                    const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
                        arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);

                    static_for<0, NumDTensor, 1>{}([&](auto i) {
                        using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
                        DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
                    });
                    ck::utility::RotatingMemWrapperMultiD<typename GridwiseGemm::Argument,
                                                          DsDataType>
                        rotating_mem(arg_,
                                     stream_config.rotating_count,
                                     size_a_buffer,
                                     size_b_buffer,
                                     DsSize);
                    rotating_mem.Print();
                    auto run_flush_cache = [&]() {
                        // flush icache
                        ck::utility::flush_icache();
                        // rotating mem
                        rotating_mem.Next();
                        // clear c mem
                        if(arg_.KBatch > 1)
                            hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
                                                             0,
                                                             arg_.M * arg_.N * sizeof(CDataType),
                                                             stream_config.stream_id_));
                    };
237 
238  ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
239  stream_config,
240  run_flush_cache,
241  kernel,
242  dim3(gdx, gdy, gdz),
243  dim3(BlockSize),
244  0,
245  arg_);
246  }
247  else
248  {
249  if(arg.KBatch > 1)
250  hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
251  0,
252  arg.M * arg.N * sizeof(CDataType),
253  stream_config.stream_id_));
254 
255  ave_time = launch_and_time_kernel(
256  stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
257  }
258  };
259 
            // TODO: Check if this is the right algorithm for minimum_occupancy
            constexpr index_t minimum_occupancy =
                BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave
                    ? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
                       MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2)
                          ? 2
                          : 1
                    : 2;

            // (Initializer lost in rendering; reconstructed after the sibling MoE GEMM ops:
            // the up-projection GEMM can simply Set, while the down-projection GEMM
            // accumulates per-token partial results across experts with AtomicAdd.)
            constexpr auto MemoryDataOp =
                IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;

            // (The TailNumber template arguments in the kernel instantiations below were
            // lost in rendering and are reconstructed: Full for v1, Odd/Even for v3.)
            if(has_main_k_block_loop)
            {
                // Tail number always full
                if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
                {
                    const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
                                                               true,
                                                               MemoryDataOp,
                                                               minimum_occupancy,
                                                               TailNumber::Full>;
                    RunKernel(kernel);
                }
                else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
                {
                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                    {
                        const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
                                                                   true,
                                                                   MemoryDataOp,
                                                                   minimum_occupancy,
                                                                   TailNumber::Odd>;
                        RunKernel(kernel);
                    }
                    else
                    {
                        const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
                                                                   true,
                                                                   MemoryDataOp,
                                                                   minimum_occupancy,
                                                                   TailNumber::Even>;
                        RunKernel(kernel);
                    }
                }
                else
                {
                    throw std::runtime_error("todo: only v1 & v3 supported now");
                }
            }
            else
            {
                // Tail number always full
                if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
                {
                    const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
                                                               false,
                                                               MemoryDataOp,
                                                               minimum_occupancy,
                                                               TailNumber::Full>;
                    RunKernel(kernel);
                }
                else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
                {
                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                    {
                        const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
                                                                   false,
                                                                   MemoryDataOp,
                                                                   minimum_occupancy,
                                                                   TailNumber::Odd>;
                        RunKernel(kernel);
                    }
                    else
                    {
                        const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
                                                                   false,
                                                                   MemoryDataOp,
                                                                   minimum_occupancy,
                                                                   TailNumber::Even>;
                        RunKernel(kernel);
                    }
                }
            }

            return ave_time;
        }

        // INVOKER_RUN3_IMPL (from device_base.hpp) was lost in rendering and is restored
        // here; it supplies the Run(const Argument&, ...) overload that forwards to
        // RunImp above.
        INVOKER_RUN3_IMPL

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }
    static bool IsSupportedArgument(const Argument& arg)
    {
        // only KBatch == 1 is implemented for now
        if(arg.KBatch > 1)
        {
            return false;
        }
        if(!ck::is_xdl_wmma_supported<ComputeTypeA, ComputeTypeB, MPerXDL, NPerXDL>())
        {
            return false;
        }
        if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
        {
            return false;
        }

        if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
                                                       GemmSpec == GemmSpecialization::NKPadding ||
                                                       GemmSpec == GemmSpecialization::MNKPadding ||
                                                       GemmSpec == GemmSpecialization::KPadding))
        {
            return false;
        }
        if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
        {
            return false;
        }

        if(get_warp_size() == 64)
        {
            if constexpr(NXdlPerWave64 > 0)
            {
                return GridwiseGemm64::CheckValidity(arg);
            }
        }
        else
        {
            if constexpr(NXdlPerWave32 > 0)
            {
                return GridwiseGemm32::CheckValidity(
                    reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg));
            }
        }
        return false;
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }
    static auto MakeArgument(const void* p_sorted_token_ids,
                             const void* p_sorted_expert_ids,
                             const void* p_max_token_id,
                             const void* p_a,
                             const void* p_a_scale,
                             const void* p_b,
                             const void* p_b_scale,
                             std::array<const void*, NumDTensor> p_ds,
                             void* p_c,
                             index_t NumTokens,
                             index_t TopK,
                             index_t M,
                             index_t N,
                             index_t K,
                             index_t StrideA,
                             index_t StrideScaleA,
                             index_t StrideB,
                             index_t StrideScaleB,
                             std::array<index_t, NumDTensor> StrideDs,
                             index_t StrideC,
                             index_t KBatch,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CElementwiseOperation c_element_op)
    {
        return Argument{static_cast<const index_t*>(p_sorted_token_ids),
                        static_cast<const index_t*>(p_sorted_expert_ids),
                        static_cast<const index_t*>(p_max_token_id),
                        static_cast<const ADataType*>(p_a),
                        static_cast<const AScaleDataType*>(p_a_scale),
                        static_cast<const BDataType*>(p_b),
                        static_cast<const BScaleDataType*>(p_b_scale),
                        p_ds,
                        static_cast<CDataType*>(p_c),
                        NumTokens,
                        TopK,
                        M,
                        N,
                        K,
                        StrideA,
                        StrideScaleA,
                        StrideB,
                        StrideScaleB,
                        StrideDs,
                        StrideC,
                        KBatch,
                        a_element_op,
                        b_element_op,
                        c_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_a_scale,
                                                      const void* p_b,
                                                      const void* p_b_scale,
                                                      std::array<const void*, NumDTensor> p_ds,
                                                      void* p_c,
                                                      index_t M,
                                                      index_t N,
                                                      index_t K,
                                                      index_t StrideA,
                                                      index_t StrideScaleA,
                                                      index_t StrideB,
                                                      index_t StrideScaleB,
                                                      std::array<ck::index_t, NumDTensor> StrideDs,
                                                      index_t StrideC,
                                                      index_t KBatch,
                                                      AElementwiseOperation a_element_op,
                                                      BElementwiseOperation b_element_op,
                                                      CElementwiseOperation c_element_op) override
    {
        return std::make_unique<Argument>(nullptr,
                                          nullptr,
                                          nullptr,
                                          static_cast<const ADataType*>(p_a),
                                          static_cast<const AScaleDataType*>(p_a_scale),
                                          static_cast<const BDataType*>(p_b),
                                          static_cast<const BScaleDataType*>(p_b_scale),
                                          p_ds,
                                          static_cast<CDataType*>(p_c),
                                          M, // placeholder NumTokens; unused by this overload
                                          0, // placeholder TopK; unused by this overload
                                          M,
                                          N,
                                          K,
                                          StrideA,
                                          StrideScaleA,
                                          StrideB,
                                          StrideScaleB,
                                          StrideDs,
                                          StrideC,
                                          KBatch,
                                          a_element_op,
                                          b_element_op,
                                          c_element_op);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // (Map initializers lost in rendering; reconstructed after the standard CK
        // pipeline-name tables.)
        std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
            {BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
            {BlockGemmPipelineScheduler::Interwave, "Interwave"}};

        std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
            {BlockGemmPipelineVersion::v1, "v1"},
            {BlockGemmPipelineVersion::v2, "v2"},
            {BlockGemmPipelineVersion::v3, "v3"},
            {BlockGemmPipelineVersion::v4, "v4"},
            {BlockGemmPipelineVersion::v5, "v5"}};

        // clang-format off
        str << "DeviceMoeGemmMX"
            << "<"
            << getGemmSpecializationString(GemmSpec) << ", "
            << std::string(ALayout::name)[0]
            << std::string(BLayout::name)[0]
            << std::string(CLayout::name)[0]
            << ">"
            << " BlkSize: "
            << BlockSize << ", "
            << "BlkTile: "
            << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", "
            << "WaveTile: "
            << MPerXDL << "x" << NPerXDL << ", "
            << "WaveMap: "
            << MXdlPerWave << "x" << NXdlPerWave << ", "
            << "VmemReadVec: "
            << ABlockTransferSrcScalarPerVector << "x" << BBlockTransferSrcScalarPerVector << ", "
            << "BlkGemmPipelineScheduler: "
            << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
            << "BlkGemmPipelineVersion: "
            << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
            << "BlkGemmPipelinePrefetchStages: "
            << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
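
For orientation, here is a minimal host-side usage sketch of this device op. It is not part of the header: the concrete template arguments behind DeviceOp, the device pointers, the sizes, strides, and element-wise ops are hypothetical placeholders (tuned instantiations normally come from CK's instance library), DeviceMoeGemmMX is the struct name as reconstructed above, and only KBatch == 1 passes IsSupportedArgument.

// Hypothetical tuned instantiation; the long template-argument list is elided here.
using DeviceOp = ck::tensor_operation::device::DeviceMoeGemmMX</* ...tuned parameters... */>;

DeviceOp device_op;

// All pointers are assumed to be device allocations prepared by the caller.
auto argument = device_op.MakeArgument(p_sorted_token_ids, p_sorted_expert_ids, p_max_token_id,
                                       p_a, p_a_scale, p_b, p_b_scale,
                                       {/* p_ds */}, p_c,
                                       NumTokens, TopK, M, N, K,
                                       StrideA, StrideScaleA, StrideB, StrideScaleB,
                                       {/* StrideDs */}, StrideC,
                                       /*KBatch=*/1, // only KBatch == 1 is supported
                                       a_element_op, b_element_op, c_element_op);

if(!DeviceOp::IsSupportedArgument(argument))
{
    throw std::runtime_error(device_op.GetTypeString() + " does not support this problem");
}

auto invoker = device_op.MakeInvoker();
float avg_ms = invoker.Run(argument, StreamConfig{nullptr, /*time_kernel=*/true});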
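
As an illustration of the GetTypeString() format: an instance with GemmSpec = MNKPadding, row-major A, column-major B, row-major C, BlockSize = 256, a 128x128x128 block tile, 16x16 wave tile, 4x4 wave map, 16-element read vectors, and the Intrawave v3 pipeline would report a string along these lines (all values here are illustrative; the trailing prefetch-stage count comes from the selected pipeline):

DeviceMoeGemmMX<MNKPadding, RCR> BlkSize: 256, BlkTile: 128x128x128, WaveTile: 16x16, WaveMap: 4x4, VmemReadVec: 16x16, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3, BlkGemmPipelinePrefetchStages: 2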