Composable Kernel — device_moe_mx_gemm_bns.hpp
(include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bns.hpp — source listing)
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
19 
20 namespace ck {
21 namespace tensor_operation {
22 namespace device {
23 
24 template <typename ALayout,
25  typename BLayout,
26  typename DsLayout,
27  typename CLayout,
28  typename ADataType,
29  typename AScaleDataType,
30  typename BDataType,
31  typename BScaleDataType,
32  typename DsDataType,
33  typename CDataType,
34  typename GemmAccDataType,
35  typename CShuffleDataType,
36  typename AElementwiseOperation,
37  typename BElementwiseOperation,
38  typename CElementwiseOperation,
39  GemmSpecialization GemmSpec,
40  index_t ScaleBlockSize,
41  index_t BlockSize,
42  index_t MPerBlock,
43  index_t NPerBlock,
44  index_t KPerBlock,
45  index_t AK1,
46  index_t BK1,
47  index_t MPerXDL,
48  index_t NPerXDL,
49  index_t MXdlPerWave,
50  index_t NXdlPerWave,
51  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
52  typename ABlockTransferThreadClusterArrangeOrder,
53  typename ABlockTransferSrcAccessOrder,
54  index_t ABlockTransferSrcVectorDim,
55  index_t ABlockTransferSrcScalarPerVector,
56  index_t ABlockTransferDstScalarPerVector_AK1,
57  bool ABlockLdsExtraM,
58  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
59  typename BBlockTransferThreadClusterArrangeOrder,
60  typename BBlockTransferSrcAccessOrder,
61  index_t BBlockTransferSrcVectorDim,
62  index_t BBlockTransferSrcScalarPerVector,
63  index_t BBlockTransferDstScalarPerVector_BK1,
64  bool BBlockLdsExtraN,
65  index_t CShuffleMXdlPerWavePerShuffle,
66  index_t CShuffleNXdlPerWavePerShuffle,
67  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
68  typename CDEShuffleBlockTransferScalarPerVectors,
71  index_t ActivationOP = 0,
72  bool NSwizzle = false,
73  bool IsInputGemm = true,
74  bool MulRoutedWeight = true,
75  typename IndexType = index_t,
76  typename ComputeTypeA = ADataType,
77  typename ComputeTypeB = BDataType>
79  BLayout,
80  DsLayout,
81  CLayout,
82  ADataType,
83  AScaleDataType,
84  BDataType,
85  BScaleDataType,
86  DsDataType,
87  CDataType,
88  ScaleBlockSize,
89  AElementwiseOperation,
90  BElementwiseOperation,
91  CElementwiseOperation>
92 {
 // NXdlPerWave resolved separately for wave64 and wave32 execution; the
 // GetNXdlPerWave helper is provided by the GET_NXDL_PER_WAVE_IMPL macro
 // declared in device_base.hpp.
 94  static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
 95  static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
 // Number of auxiliary D tensors fused into the epilogue.
 96  static constexpr index_t NumDTensor = DsDataType::Size();
 // Helper alias instantiating the gridwise MoE MX GEMM kernel with this
 // device op's full template configuration; only NXdlPerWave_ varies, so a
 // wave64 and a wave32 kernel can be materialized from one definition
 // (referenced below as GridwiseGemm64 / GridwiseGemm32).
 // NOTE(review): the "using <Name> =" introducer (source line 98) was lost in
 // this doc extraction — consult the original header for the alias name.
 97  template <index_t NXdlPerWave_>
 99  GridwiseMoeGemmMXBNS<ALayout,
 100  BLayout,
 101  DsLayout,
 102  CLayout,
 103  ADataType,
 104  AScaleDataType,
 105  BDataType,
 106  BScaleDataType,
 107  GemmAccDataType,
 108  CShuffleDataType,
 109  DsDataType,
 110  CDataType,
 111  AElementwiseOperation,
 112  BElementwiseOperation,
 113  CElementwiseOperation,
 114  GemmSpec,
 115  ScaleBlockSize,
 116  BlockSize,
 117  MPerBlock,
 118  NPerBlock,
 119  KPerBlock,
 120  AK1,
 121  BK1,
 122  MPerXDL,
 123  NPerXDL,
 124  MXdlPerWave,
 125  NXdlPerWave_,
 126  ABlockTransferThreadClusterLengths_AK0_M_AK1,
 127  ABlockTransferThreadClusterArrangeOrder,
 128  ABlockTransferSrcAccessOrder,
 129  ABlockTransferSrcVectorDim,
 130  ABlockTransferSrcScalarPerVector,
 131  ABlockTransferDstScalarPerVector_AK1,
 // The two literal 'false' arguments below disable a per-operand option in
 // the gridwise kernel — presumably A/B LDS bypass or thread-cluster tuning;
 // confirm against GridwiseMoeGemmMXBNS's parameter list.
 132  false,
 133  ABlockLdsExtraM,
 134  BBlockTransferThreadClusterLengths_BK0_N_BK1,
 135  BBlockTransferThreadClusterArrangeOrder,
 136  BBlockTransferSrcAccessOrder,
 137  BBlockTransferSrcVectorDim,
 138  BBlockTransferSrcScalarPerVector,
 139  BBlockTransferDstScalarPerVector_BK1,
 140  false,
 141  BBlockLdsExtraN,
 142  CShuffleMXdlPerWavePerShuffle,
 143  CShuffleNXdlPerWavePerShuffle,
 144  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
 145  CDEShuffleBlockTransferScalarPerVectors,
 146  BlkGemmPipeSched,
 147  BlkGemmPipelineVer,
 148  ActivationOP,
 149  NSwizzle,
 150  IsInputGemm,
 151  MulRoutedWeight,
 152  IndexType,
 153  ComputeTypeA,
 154  ComputeTypeB>;
157 
159 
 // Logical elements packed per stored element of A/B, from the
 // packed_size_v trait of the corresponding data type.
 160  static constexpr index_t APackedSize = packed_size_v<ADataType>;
 161  static constexpr index_t BPackedSize = packed_size_v<BDataType>;
162 
163  int GetPreShuffleParameters() override { return NPerXDL; }
164 
165  // Invoker
166  struct Invoker : public BaseInvoker
167  {
 /// Validates the kernel argument, derives the launch grid, and times one of
 /// the kernel_moe_mxgemm instantiations selected by pipeline version and
 /// K-loop shape.
 /// @param arg           fully-populated gridwise kernel argument
 /// @param stream_config stream handle plus benchmarking knobs (log level,
 ///                      cache flushing, rotating-buffer count)
 /// @return average kernel time as reported by the launch/timing helpers
 /// @throws std::runtime_error if CheckValidity fails or the pipeline
 ///         version is neither v1 nor v3
 /// NOTE(review): this doc extraction dropped several source lines (228, 271,
 /// 281, 292, 301, 318, 329, 338) — the flush_icache call, the MemoryDataOp
 /// initializer, and the TailNumber template argument closing each
 /// kernel_moe_mxgemm instantiation are not visible here; consult the
 /// original header before editing this function.
 168  template <typename GridwiseGemm>
 169  float RunImp(const typename GridwiseGemm::Argument& arg,
 170  const StreamConfig& stream_config = StreamConfig{})
 171  {
 172  if(stream_config.log_level_ > 0)
 173  {
 174  arg.Print();
 175  }
 176 
 // Fail fast on an inconsistent argument before any launch.
 177  if(!GridwiseGemm::CheckValidity(arg))
 178  {
 179  throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
 180  }
 181 
 // 3D launch grid computed from the (padded) problem dimensions.
 182  index_t gdx, gdy, gdz;
 183  std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
 184 
 185  float ave_time = 0;
 186 
 // K coverage per split-K batch, rounded up to whole KPerBlock tiles.
 187  index_t k_grain = arg.KBatch * KPerBlock;
 188  index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
 189 
 190  const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
 191 
 // Shared launch helper: with flush_cache it rotates A/B/D buffers and
 // clears caches between timed iterations so each run sees cold memory;
 // otherwise it launches and times the kernel directly.
 192  const auto RunKernel = [&](const auto& kernel) {
 193  if(stream_config.flush_cache)
 194  {
 195 
 196  std::array<std::size_t, NumDTensor> DsSize;
 197 
 198  auto arg_ = arg;
 199 
 // Recreate the grid descriptors only to size the rotating buffers.
 200  const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
 201  arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
 202  const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
 203  arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
 204 
 205  auto size_a_buffer =
 206  a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
 207  auto size_b_buffer =
 208  b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);
 209 
 210  const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
 211  arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
 212 
 213  static_for<0, NumDTensor, 1>{}([&](auto i) {
 214  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
 215  DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
 216  });
 217  ck::utility::RotatingMemWrapperMultiD<typename GridwiseGemm::Argument,
 218  DsDataType>
 219  rotating_mem(arg_,
 220  stream_config.rotating_count,
 221  size_a_buffer,
 222  size_b_buffer,
 223  DsSize);
 224  rotating_mem.Print();
 225 
 226  auto run_flush_cache = [&]() {
 227  // flush icache
 229  // rotating mem
 230  rotating_mem.Next();
 231  // clear c mem
 // C is cleared before each timed run when split-K is active —
 // presumably because KBatch > 1 accumulates into C; confirm in the
 // gridwise kernel.
 232  if(arg_.KBatch > 1)
 233  hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
 234  0,
 235  arg_.M * arg_.N * sizeof(CDataType),
 236  stream_config.stream_id_));
 237  };
 238 
 239  ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
 240  stream_config,
 241  run_flush_cache,
 242  kernel,
 243  dim3(gdx, gdy, gdz),
 244  dim3(BlockSize),
 245  0,
 246  arg_);
 247  }
 248  else
 249  {
 250  if(arg.KBatch > 1)
 251  hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
 252  0,
 253  arg.M * arg.N * sizeof(CDataType),
 254  stream_config.stream_id_));
 255 
 256  ave_time = launch_and_time_kernel(
 257  stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
 258  }
 259  };
 260 
 261  // TODO: Check if this is the right algorithm for minimum_occupancy
 262  constexpr index_t minimum_occupancy =
 263  BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave
 264  ? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
 265  MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2)
 266  ? 2
 267  : 1
 268  : 2;
 269 
 // NOTE(review): the initializer of MemoryDataOp (source line 271) is
 // missing from this extraction; verify against the original header.
 270  constexpr auto MemoryDataOp =
 // Dispatch: pipeline v1 uses a single instantiation; v3 additionally
 // selects by the K-block tail count. Other versions are rejected.
 272  if(has_main_k_block_loop)
 273  {
 274  // Tail number always full
 275  if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
 276  {
 277  const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
 278  true,
 279  MemoryDataOp,
 280  minimum_occupancy,
 282  RunKernel(kernel);
 283  }
 284  else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
 285  {
 286  if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
 287  {
 288  const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
 289  true,
 290  MemoryDataOp,
 291  minimum_occupancy,
 293  RunKernel(kernel);
 294  }
 295  else
 296  {
 297  const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
 298  true,
 299  MemoryDataOp,
 300  minimum_occupancy,
 302  RunKernel(kernel);
 303  }
 304  }
 305  else
 306  {
 307  throw std::runtime_error("todo: only v1 & v3 support now");
 308  }
 309  }
 310  else
 311  {
 312  if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
 313  {
 314  const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
 315  false,
 316  MemoryDataOp,
 317  minimum_occupancy,
 319  RunKernel(kernel);
 320  }
 321  else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
 322  {
 323  if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
 324  {
 325  const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
 326  false,
 327  MemoryDataOp,
 328  minimum_occupancy,
 330  RunKernel(kernel);
 331  }
 332  else
 333  {
 334  const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
 335  false,
 336  MemoryDataOp,
 337  minimum_occupancy,
 339  RunKernel(kernel);
 340  }
 341  }
 342  }
 343 
 344  return ave_time;
 345  }
346 
348 
349  // polymorphic
350  float Run(const BaseArgument* p_arg,
351  const StreamConfig& stream_config = StreamConfig{}) override
352  {
353  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
354  }
355  };
356 
/// Compile-time sanity check of the chosen tuning parameters.
/// Placeholder: currently reports every configuration as valid.
/// TODO: properly implement this check
static constexpr bool IsValidCompilationParameter() { return true; }
362 
 /// Checks whether this instance can execute the given argument: rejects
 /// split-K, unsupported hardware/data-type combinations, and unpadded
 /// problem sizes that do not divide the tile configuration, then defers to
 /// the wave-size-matching gridwise kernel's CheckValidity.
 363  static bool IsSupportedArgument(const Argument& arg)
 364  {
 365  // only impl kbatch 1 now
 366  if(arg.KBatch > 1)
 367  {
 368  return false;
 369  }
 // Requires XDL/WMMA hardware support for the compute types and MFMA tile.
 370  if(!ck::is_xdl_wmma_supported<ComputeTypeA, ComputeTypeB, MPerXDL, NPerXDL>())
 371  {
 372  return false;
 373  }
 // NOTE(review): this clause is dead code while the KBatch > 1 early-return
 // above exists — kept for when split-K support lands.
 374  if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
 375  {
 376  return false;
 377  }
 378 
 // K must divide the vector-load widths unless a K-padding specialization
 // was compiled in.
 379  if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
 380  GemmSpec == GemmSpecialization::NKPadding ||
 381  GemmSpec == GemmSpecialization::MNKPadding ||
 382  GemmSpec == GemmSpecialization::KPadding))
 383  {
 384  return false;
 385  }
 386  if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
 387  {
 388  return false;
 389  }
 390 
 // Defer to the gridwise kernel matching the device's wave size; falls
 // through to false when no kernel was instantiated for that wave size.
 391  if(get_warp_size() == 64)
 392  {
 393  if constexpr(NXdlPerWave64 > 0)
 394  {
 395  return GridwiseGemm64::CheckValidity(arg);
 396  }
 397  }
 398  else
 399  {
 400  if constexpr(NXdlPerWave32 > 0)
 401  {
 // NOTE(review): the call introducer (source line 402, presumably
 // "return GridwiseGemm32::CheckValidity(") is missing from this doc
 // extraction; only its argument line survived below.
 403  reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg));
 404  }
 405  }
 406  return false;
 407  }
408 
409  // polymorphic
410  bool IsSupportedArgument(const BaseArgument* p_arg) override
411  {
412  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
413  }
414 
415  static auto MakeArgument(const void* p_sorted_token_ids,
416  const void* p_sorted_expert_ids,
417  const void* p_max_token_id,
418  const void* p_a,
419  const void* p_a_scale,
420  const void* p_b,
421  const void* p_b_scale,
422  std::array<const void*, NumDTensor> p_ds,
423  void* p_c,
424  index_t NumTokens,
425  index_t TopK,
426  index_t M,
427  index_t N,
428  index_t K,
429  index_t StrideA,
430  index_t StrideScaleA,
431  index_t StrideB,
432  index_t StrideScaleB,
433  std::array<index_t, NumDTensor> StrideDs,
434  index_t StrideC,
435  index_t KBatch,
436  AElementwiseOperation a_element_op,
437  BElementwiseOperation b_element_op,
438  CElementwiseOperation c_element_op)
439  {
440  return Argument{static_cast<const index_t*>(p_sorted_token_ids),
441  static_cast<const index_t*>(p_sorted_expert_ids),
442  static_cast<const index_t*>(p_max_token_id),
443  static_cast<const ADataType*>(p_a),
444  static_cast<const AScaleDataType*>(p_a_scale),
445  static_cast<const BDataType*>(p_b),
446  static_cast<const BScaleDataType*>(p_b_scale),
447  p_ds,
448  static_cast<CDataType*>(p_c),
449  NumTokens,
450  TopK,
451  M,
452  N,
453  K,
454  StrideA,
455  StrideScaleA,
456  StrideB,
457  StrideScaleB,
458  StrideDs,
459  StrideC,
460  KBatch,
461  a_element_op,
462  b_element_op,
463  c_element_op};
464  }
465 
466  static auto MakeInvoker() { return Invoker{}; }
467 
468  // polymorphic
469  std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
470  const void* p_a_scale,
471  const void* p_b,
472  const void* p_b_scale,
473  std::array<const void*, NumDTensor> p_ds,
474  void* p_c,
475  index_t M,
476  index_t N,
477  index_t K,
478  index_t StrideA,
479  index_t StrideScaleA,
480  index_t StrideB,
481  index_t StrideScaleB,
482  std::array<ck::index_t, NumDTensor> StrideDs,
483  index_t StrideC,
484  index_t KBatch,
485  AElementwiseOperation a_element_op,
486  BElementwiseOperation b_element_op,
487  CElementwiseOperation c_element_op) override
488  {
489  return std::make_unique<Argument>(nullptr,
490  nullptr,
491  nullptr,
492  static_cast<const ADataType*>(p_a),
493  static_cast<const AScaleDataType*>(p_a_scale),
494  static_cast<const BDataType*>(p_b),
495  static_cast<const BScaleDataType*>(p_b_scale),
496  p_ds,
497  static_cast<CDataType*>(p_c),
498  M, // randoms set, no use
499  0,
500  M,
501  N,
502  K,
503  StrideA,
504  StrideScaleA,
505  StrideB,
506  StrideScaleB,
507  StrideDs,
508  StrideC,
509  KBatch,
510  a_element_op,
511  b_element_op,
512  c_element_op);
513  }
514 
515  // polymorphic
516  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
517  {
518  return std::make_unique<Invoker>(Invoker{});
519  }
520 
521  // polymorphic
 /// Builds a human-readable identifier encoding this instance's tuning
 /// parameters (specialization, layouts, tile sizes, pipeline settings).
 /// NOTE(review): the initializer lists of both lookup maps (source lines
 /// 527-528 and 531-535) are missing from this doc extraction; see the
 /// original header for the scheduler/version → string entries.
 /// NOTE(review): "DeviceMoeGEmmMx" capitalization looks like a typo for
 /// "DeviceMoeGemmMx" — confirm before relying on the string for matching.
 522  std::string GetTypeString() const override
 523  {
 524  auto str = std::stringstream();
 525 
 526  std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
 529 
 530  std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
 536 
 537  // clang-format off
 538  str << "DeviceMoeGEmmMx"
 539  << "<"
 540  << getGemmSpecializationString(GemmSpec) << ", "
 541  << std::string(ALayout::name)[0]
 542  << std::string(BLayout::name)[0]
 543  << std::string(CLayout::name)[0]
 544  << ">"
 545  << " BlkSize: "
 546  << BlockSize << ", "
 547  << "BlkTile: "
 548  << MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
 549  << "WaveTile: "
 550  << MPerXDL<<"x"<<NPerXDL << ", "
 551  << "WaveMap: "
 552  << MXdlPerWave<<"x" << NXdlPerWave<<", "
 553  << "VmemReadVec: "
 554  << ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
 555  << "BlkGemmPipelineScheduler: "
 556  << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
 557  << "BlkGemmPipelineVersion: "
 558  << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
 559  << "BlkGemmPipelinePrefetchStages: "
 560  << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
 561  // clang-format on
 562 
 563  return str.str();
 564  }
565 };
566 
567 } // namespace device
568 } // namespace tensor_operation
569 } // namespace ck
#define INVOKER_RUN3_IMPL
Definition: device_base.hpp:114
#define GET_NXDL_PER_WAVE_IMPL
Definition: device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:14
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition: gemm_specialization.hpp:32
GemmSpecialization
Definition: gemm_specialization.hpp:11
void flush_icache()
Definition: flush_cache.hpp:361
Definition: ck.hpp:268
__global__ void kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_mx_gemm_bns.hpp:48
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
BlockGemmPipelineVersion
Definition: blkgemmpipe_scheduler.hpp:12
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
int32_t index_t
Definition: ck.hpp:299
bool is_bf16_atomic_supported()
Definition: device_prop.hpp:108
Definition: stream_config.hpp:10
Definition: gridwise_moe_mx_gemm_bns.hpp:654
Definition: gridwise_moe_mx_gemm_bns.hpp:179
static constexpr __host__ bool CheckValidity(const Argument &karg)
Definition: gridwise_moe_mx_gemm_bns.hpp:1111
Definition: functional2.hpp:33
Definition: device_base.hpp:197
Definition: device_base.hpp:208
Definition: device_gemm_multiple_d.hpp:167
Definition: device_moe_mx_gemm_bns.hpp:167
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_moe_mx_gemm_bns.hpp:350
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_moe_mx_gemm_bns.hpp:169
Definition: device_moe_mx_gemm_bns.hpp:92
static constexpr index_t BPackedSize
Definition: device_moe_mx_gemm_bns.hpp:161
std::string GetTypeString() const override
Definition: device_moe_mx_gemm_bns.hpp:522
static auto MakeInvoker()
Definition: device_moe_mx_gemm_bns.hpp:466
static constexpr index_t APackedSize
Definition: device_moe_mx_gemm_bns.hpp:160
typename GridwiseGemm64::Argument Argument
Definition: device_moe_mx_gemm_bns.hpp:158
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_a_scale, const void *p_b, const void *p_b_scale, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideScaleA, index_t StrideB, index_t StrideScaleB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition: device_moe_mx_gemm_bns.hpp:469
static auto MakeArgument(const void *p_sorted_token_ids, const void *p_sorted_expert_ids, const void *p_max_token_id, const void *p_a, const void *p_a_scale, const void *p_b, const void *p_b_scale, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t NumTokens, index_t TopK, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideScaleA, index_t StrideB, index_t StrideScaleB, std::array< index_t, NumDTensor > StrideDs, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition: device_moe_mx_gemm_bns.hpp:415
static constexpr auto NXdlPerWave32
Definition: device_moe_mx_gemm_bns.hpp:95
static bool IsSupportedArgument(const Argument &arg)
Definition: device_moe_mx_gemm_bns.hpp:363
int GetPreShuffleParameters() override
Definition: device_moe_mx_gemm_bns.hpp:163
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_moe_mx_gemm_bns.hpp:410
static constexpr GET_NXDL_PER_WAVE_IMPL auto NXdlPerWave64
Definition: device_moe_mx_gemm_bns.hpp:94
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_moe_mx_gemm_bns.hpp:516
static constexpr index_t NumDTensor
Definition: device_moe_mx_gemm_bns.hpp:96
static constexpr bool IsValidCompilationParameter()
Definition: device_moe_mx_gemm_bns.hpp:357
Definition: flush_cache.hpp:165