/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp Source File
device_moe_mx_gemm_bpreshuffle.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
19 
20 namespace ck {
21 namespace tensor_operation {
22 namespace device {
23 
// ---------------------------------------------------------------------------
// Device-level MoE mx-GEMM operation with a pre-shuffled B operand.
// NOTE(review): this listing was extracted from generated documentation and
// is missing lines — in particular the `struct … : public DeviceMoeGemmMX<`
// declaration between the template parameter list and the base-class argument
// list below. Consult the original header before editing.
// ---------------------------------------------------------------------------
// Layout / data-type parameters for A, B, fused Ds and C tensors:
24 template <typename ALayout,
25  typename BLayout,
26  typename DsLayout,
27  typename CLayout,
28  typename ADataType,
29  typename AScaleDataType,
30  typename BDataType,
31  typename BScaleDataType,
32  typename DsDataType,
33  typename CDataType,
34  typename GemmAccDataType,
35  typename CShuffleDataType,
// Elementwise epilogue operations applied to A, B and C:
36  typename AElementwiseOperation,
37  typename BElementwiseOperation,
38  typename CElementwiseOperation,
// Padding specialization and block/wave tiling parameters:
39  GemmSpecialization GemmSpec,
40  index_t ScaleBlockSize,
41  index_t BlockSize,
42  index_t MPerBlock,
43  index_t NPerBlock,
44  index_t KPerBlock,
45  index_t AK1,
46  index_t BK1,
47  index_t MPerXDL,
48  index_t NPerXDL,
49  index_t MXdlPerWave,
50  index_t NXdlPerWave,
// A-operand block-transfer configuration:
51  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
52  typename ABlockTransferThreadClusterArrangeOrder,
53  typename ABlockTransferSrcAccessOrder,
54  index_t ABlockTransferSrcVectorDim,
55  index_t ABlockTransferSrcScalarPerVector,
56  index_t ABlockTransferDstScalarPerVector_AK1,
57  bool ABlockLdsExtraM,
// B-operand block-transfer configuration:
58  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
59  typename BBlockTransferThreadClusterArrangeOrder,
60  typename BBlockTransferSrcAccessOrder,
61  index_t BBlockTransferSrcVectorDim,
62  index_t BBlockTransferSrcScalarPerVector,
63  index_t BBlockTransferDstScalarPerVector_BK1,
64  bool BBlockLdsExtraN,
// C-shuffle epilogue configuration:
65  index_t CShuffleMXdlPerWavePerShuffle,
66  index_t CShuffleNXdlPerWavePerShuffle,
67  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
68  typename CDEShuffleBlockTransferScalarPerVectors,
// MoE-specific options (activation selector, N-swizzle, up/down gemm role,
// routed-weight multiplication) — defaults preserve existing behavior:
71  index_t ActivationOP = 0,
72  bool NSwizzle = false,
73  bool IsInputGemm = true,
74  bool MulRoutedWeight = true,
75  typename IndexType = index_t,
76  typename ComputeTypeA = ADataType,
77  typename ComputeTypeB = BDataType>
// NOTE(review): the `struct DeviceMoeGemmMXBPreShuffle : public
// DeviceMoeGemmMX<ALayout,` line is missing here in this extracted listing.
79  BLayout,
80  DsLayout,
81  CLayout,
82  ADataType,
83  AScaleDataType,
84  BDataType,
85  BScaleDataType,
86  DsDataType,
87  CDataType,
88  ScaleBlockSize,
89  AElementwiseOperation,
90  BElementwiseOperation,
91  CElementwiseOperation>
92 {
// NXdlPerWave variants resolved for wave64 vs wave32 hardware.
// NOTE(review): GetNXdlPerWave appears to come from the GET_NXDL_PER_WAVE_IMPL
// macro in device_base.hpp — not visible in this listing; confirm there.
94  static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
95  static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
// Number of extra D tensors fused into the epilogue.
96  static constexpr index_t NumDTensor = DsDataType::Size();
// Gridwise GEMM instantiation parameterized on NXdlPerWave_ so that both
// wave64 and wave32 variants can be built from one alias template.
// NOTE(review): the `using GridwiseGemm… = GridwiseMoeGemmMX<…` line that
// names this instantiation is missing from this extracted listing.
97  template <index_t NXdlPerWave_>
99  ALayout,
100  BLayout,
101  DsLayout,
102  CLayout,
103  ADataType,
104  AScaleDataType,
105  BDataType,
106  BScaleDataType,
107  GemmAccDataType,
108  CShuffleDataType,
109  DsDataType,
110  CDataType,
111  AElementwiseOperation,
112  BElementwiseOperation,
113  CElementwiseOperation,
114  GemmSpec,
115  ScaleBlockSize,
116  BlockSize,
117  MPerBlock,
118  NPerBlock,
119  KPerBlock,
120  AK1,
121  BK1,
122  MPerXDL,
123  NPerXDL,
124  MXdlPerWave,
125  NXdlPerWave_,
126  ABlockTransferThreadClusterLengths_AK0_M_AK1,
127  ABlockTransferThreadClusterArrangeOrder,
128  ABlockTransferSrcAccessOrder,
129  ABlockTransferSrcVectorDim,
130  ABlockTransferSrcScalarPerVector,
131  ABlockTransferDstScalarPerVector_AK1,
132  false,
133  ABlockLdsExtraM,
134  BBlockTransferThreadClusterLengths_BK0_N_BK1,
135  BBlockTransferThreadClusterArrangeOrder,
136  BBlockTransferSrcAccessOrder,
137  BBlockTransferSrcVectorDim,
138  BBlockTransferSrcScalarPerVector,
139  BBlockTransferDstScalarPerVector_BK1,
140  false,
141  BBlockLdsExtraN,
142  CShuffleMXdlPerWavePerShuffle,
143  CShuffleNXdlPerWavePerShuffle,
144  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
145  CDEShuffleBlockTransferScalarPerVectors,
146  BlkGemmPipeSched,
147  BlkGemmPipelineVer,
148  ActivationOP,
149  NSwizzle,
150  IsInputGemm,
151  MulRoutedWeight,
152  IndexType,
153  ComputeTypeA,
154  ComputeTypeB>;
157 
159 
// Packed element counts — presumably how many logical values pack into one
// storage element for sub-byte types (inferred from packed_size_v's name;
// TODO confirm against its definition).
160  static constexpr index_t APackedSize = packed_size_v<ADataType>;
161  static constexpr index_t BPackedSize = packed_size_v<BDataType>;
162 
163  int GetPreShuffleParameters() override { return NPerXDL; }
164 
165  // Invoker
// Invoker: validates the argument, selects the kernel variant matching the
// pipeline version / tail number, launches it and returns the averaged time.
166  struct Invoker : public BaseInvoker
167  {
168  template <typename GridwiseGemm>
169  float RunImp(const typename GridwiseGemm::Argument& arg,
170  const StreamConfig& stream_config = StreamConfig{})
171  {
172  if(stream_config.log_level_ > 0)
173  {
174  arg.Print();
175  }
176 
// Reject configurations the gridwise GEMM cannot execute.
177  if(!GridwiseGemm::CheckValidity(arg))
178  {
179  throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
180  }
181 
182  index_t gdx, gdy, gdz;
183  std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
184 
185  float ave_time = 0;
186 
// K_split: K rounded up to a multiple of (KBatch * KPerBlock), expressed in
// units of KPerBlock — used for main-loop / tail-number decisions below.
187  index_t k_grain = arg.KBatch * KPerBlock;
188  index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
189 
190  const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
191 
// RunKernel launches the chosen kernel; with flush_cache it benchmarks using
// rotating A/B/D buffers so repeated runs do not hit warm caches.
192  const auto RunKernel = [&](const auto& kernel) {
193  if(stream_config.flush_cache)
194  {
195 
196  std::array<std::size_t, NumDTensor> DsSize;
197 
198  auto arg_ = arg;
199 
200  const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
201  arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
202  const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
203  arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
204 
205  auto size_a_buffer =
206  a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
207  auto size_b_buffer =
208  b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);
209 
210  const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
211  arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
212 
// Compute the byte size of each fused D tensor for the rotating buffers.
213  static_for<0, NumDTensor, 1>{}([&](auto i) {
214  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
215  DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
216  });
217  ck::utility::RotatingMemWrapperMultiD<typename GridwiseGemm::Argument,
218  DsDataType>
219  rotating_mem(arg_,
220  stream_config.rotating_count,
221  size_a_buffer,
222  size_b_buffer,
223  DsSize);
224  rotating_mem.Print();
225 
// Pre-launch hook: rotate input buffers and zero C when split-K accumulates.
226  auto run_flush_cache = [&]() {
227  // flush icache
229  // rotating mem
230  rotating_mem.Next();
231  // clear c mem
232  if(arg_.KBatch > 1)
233  hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
234  0,
235  arg_.M * arg_.N * sizeof(CDataType),
236  stream_config.stream_id_));
237  };
238 
239  ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
240  stream_config,
241  run_flush_cache,
242  kernel,
243  dim3(gdx, gdy, gdz),
244  dim3(BlockSize),
245  0,
246  arg_);
247  }
248  else
249  {
// Non-benchmark path: zero C once (split-K accumulation) and launch.
250  if(arg.KBatch > 1)
251  hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
252  0,
253  arg.M * arg.N * sizeof(CDataType),
254  stream_config.stream_id_));
255 
256  ave_time = launch_and_time_kernel(
257  stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
258  }
259  };
260 
261  // TODO: Check if this is the right algorithm for minimum_occupancy
262  constexpr index_t minimum_occupancy =
263  BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave
264  ? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
265  MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2)
266  ? 2
267  : 1
268  : 2;
269 
// NOTE(review): the initializer of MemoryDataOp (original line 271) is
// missing from this extracted listing.
270  constexpr auto MemoryDataOp =
272 
// Kernel dispatch: pipeline v1 uses kernel_moe_mxgemm, v3 the 2-LDS variant;
// the bool template argument mirrors has_main_k_block_loop and the branch on
// CalculateKBlockLoopTailNum selects the tail-number specialization.
// NOTE(review): each kernel's trailing TailNumber template argument line is
// missing from this extracted listing.
273  if(has_main_k_block_loop)
274  {
275  // Tail number always full
276  if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
277  {
278  {
279  if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
280  {
281  const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
282  true,
283  MemoryDataOp,
284  minimum_occupancy,
286  RunKernel(kernel);
287  }
288  else
289  {
290  const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
291  true,
292  MemoryDataOp,
293  minimum_occupancy,
295  RunKernel(kernel);
296  }
297  }
298  }
299  else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
300  {
301  if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
302  {
303  const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
304  true,
305  MemoryDataOp,
306  minimum_occupancy,
308  RunKernel(kernel);
309  }
310  else
311  {
312  const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
313  true,
314  MemoryDataOp,
315  minimum_occupancy,
317  RunKernel(kernel);
318  }
319  }
320  else
321  {
322  throw std::runtime_error("todo: only v1 & v3 support now");
323  }
324  }
325  else
326  {
327  if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
328  {
329  if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
330  {
331  const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
332  false,
333  MemoryDataOp,
334  minimum_occupancy,
336  RunKernel(kernel);
337  }
338  else
339  {
340  const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
341  false,
342  MemoryDataOp,
343  minimum_occupancy,
345  RunKernel(kernel);
346  }
347  }
348  else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
349  {
350  if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
351  {
352  const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
353  false,
354  MemoryDataOp,
355  minimum_occupancy,
357  RunKernel(kernel);
358  }
359  else
360  {
361  const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
362  false,
363  MemoryDataOp,
364  minimum_occupancy,
366  RunKernel(kernel);
367  }
368  }
369  }
370 
371  return ave_time;
372  }
373 
// NOTE(review): a Run(const Argument&) overload — presumably generated by the
// INVOKER_RUN3_IMPL macro (device_base.hpp) — is not visible in this listing;
// the polymorphic Run below forwards to it after downcasting.
375 
376  // polymorphic
377  float Run(const BaseArgument* p_arg,
378  const StreamConfig& stream_config = StreamConfig{}) override
379  {
// Assumes p_arg actually points at an Argument; a failed dynamic_cast here
// would dereference nullptr — TODO confirm callers guarantee the type.
380  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
381  }
382  };
383 
/// Compile-time sanity check of the chosen template configuration.
/// TODO: properly implement this check — currently accepts everything.
static constexpr bool IsValidCompilationParameter() { return true; }
389 
// Runtime feasibility check: rejects split-K, unsupported MFMA shapes for the
// compute types, and K/N extents incompatible with the vector/block tiling.
390  static bool IsSupportedArgument(const Argument& arg)
391  {
392  // only impl kbatch 1 now
393  if(arg.KBatch > 1)
394  {
395  return false;
396  }
397  if(!ck::is_xdl_wmma_supported<ComputeTypeA, ComputeTypeB, MPerXDL, NPerXDL>())
398  {
399  return false;
400  }
// NOTE(review): `arg.KBatch > 1` is always false here — the first check
// already returned — so this bf16-atomics guard is currently unreachable.
401  if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
402  {
403  return false;
404  }
405 
// K must be vector-aligned for both operands unless a K-padding
// specialization is selected.
406  if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
407  GemmSpec == GemmSpecialization::NKPadding ||
408  GemmSpec == GemmSpecialization::MNKPadding ||
409  GemmSpec == GemmSpecialization::KPadding))
410  {
411  return false;
412  }
413  if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
414  {
415  return false;
416  }
417 
// Delegate the final check to the gridwise GEMM variant matching the wave
// size of the running device (wave64 vs wave32).
418  if(get_warp_size() == 64)
419  {
420  if constexpr(NXdlPerWave64 > 0)
421  {
422  return GridwiseGemm64::CheckValidity(arg);
423  }
424  }
425  else
426  {
427  if constexpr(NXdlPerWave32 > 0)
428  {
// NOTE(review): the `return GridwiseGemm32::CheckValidity(` line (original
// line 429) is missing from this extracted listing.
430  reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg));
431  }
432  }
// Reached when the compiled NXdlPerWave variant for this wave size is absent.
433  return false;
434  }
435 
436  // polymorphic
437  bool IsSupportedArgument(const BaseArgument* p_arg) override
438  {
439  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
440  }
441 
442  static auto MakeArgument(const void* p_sorted_token_ids,
443  const void* p_sorted_expert_ids,
444  const void* p_max_token_id,
445  const void* p_a,
446  const void* p_a_scale,
447  const void* p_b,
448  const void* p_b_scale,
449  std::array<const void*, NumDTensor> p_ds,
450  void* p_c,
451  index_t NumTokens,
452  index_t TopK,
453  index_t M,
454  index_t N,
455  index_t K,
456  index_t StrideA,
457  index_t StrideScaleA,
458  index_t StrideB,
459  index_t StrideScaleB,
460  std::array<index_t, NumDTensor> StrideDs,
461  index_t StrideC,
462  index_t KBatch,
463  AElementwiseOperation a_element_op,
464  BElementwiseOperation b_element_op,
465  CElementwiseOperation c_element_op)
466  {
467  return Argument{static_cast<const index_t*>(p_sorted_token_ids),
468  static_cast<const index_t*>(p_sorted_expert_ids),
469  static_cast<const index_t*>(p_max_token_id),
470  static_cast<const ADataType*>(p_a),
471  static_cast<const AScaleDataType*>(p_a_scale),
472  static_cast<const BDataType*>(p_b),
473  static_cast<const BScaleDataType*>(p_b_scale),
474  p_ds,
475  static_cast<CDataType*>(p_c),
476  NumTokens,
477  TopK,
478  M,
479  N,
480  K,
481  StrideA,
482  StrideScaleA,
483  StrideB,
484  StrideScaleB,
485  StrideDs,
486  StrideC,
487  KBatch,
488  a_element_op,
489  b_element_op,
490  c_element_op};
491  }
492 
493  static auto MakeInvoker() { return Invoker{}; }
494 
495  // polymorphic
496  std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
497  const void* p_a_scale,
498  const void* p_b,
499  const void* p_b_scale,
500  std::array<const void*, NumDTensor> p_ds,
501  void* p_c,
502  index_t M,
503  index_t N,
504  index_t K,
505  index_t StrideA,
506  index_t StrideScaleA,
507  index_t StrideB,
508  index_t StrideScaleB,
509  std::array<ck::index_t, NumDTensor> StrideDs,
510  index_t StrideC,
511  index_t KBatch,
512  AElementwiseOperation a_element_op,
513  BElementwiseOperation b_element_op,
514  CElementwiseOperation c_element_op) override
515  {
516  return std::make_unique<Argument>(nullptr,
517  nullptr,
518  nullptr,
519  static_cast<const ADataType*>(p_a),
520  static_cast<const AScaleDataType*>(p_a_scale),
521  static_cast<const BDataType*>(p_b),
522  static_cast<const BScaleDataType*>(p_b_scale),
523  p_ds,
524  static_cast<CDataType*>(p_c),
525  M, // randoms set, no use
526  0,
527  M,
528  N,
529  K,
530  StrideA,
531  StrideScaleA,
532  StrideB,
533  StrideScaleB,
534  StrideDs,
535  StrideC,
536  KBatch,
537  a_element_op,
538  b_element_op,
539  c_element_op);
540  }
541 
542  // polymorphic
543  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
544  {
545  return std::make_unique<Invoker>(Invoker{});
546  }
547 
548  // polymorphic
// Human-readable description of this instantiation's tuning parameters,
// used for logging/profiling output.
549  std::string GetTypeString() const override
550  {
551  auto str = std::stringstream();
552 
// NOTE(review): both map initializer lists (original lines 554-555 and
// 558-562) are missing from this extracted listing.
553  std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
556 
557  std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
563 
564  // clang-format off
// NOTE(review): "DeviceMoeGEmmMx" looks like a typo for "DeviceMoeGemmMx" —
// confirm the intended type string before changing (consumers may match it).
565  str << "DeviceMoeGEmmMx"
566  << "<"
567  << getGemmSpecializationString(GemmSpec) << ", "
568  << std::string(ALayout::name)[0]
569  << std::string(BLayout::name)[0]
570  << std::string(CLayout::name)[0]
571  << ">"
572  << " BlkSize: "
573  << BlockSize << ", "
574  << "BlkTile: "
575  << MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
576  << "WaveTile: "
577  << MPerXDL<<"x"<<NPerXDL << ", "
578  << "WaveMap: "
579  << MXdlPerWave<<"x" << NXdlPerWave<<", "
580  << "VmemReadVec: "
581  << ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
582  << "BlkGemmPipelineScheduler: "
583  << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
584  << "BlkGemmPipelineVersion: "
585  << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
586  << "BlkGemmPipelinePrefetchStages: "
587  << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
588  // clang-format on
589 
590  return str.str();
591  }
592 };
593 
594 } // namespace device
595 } // namespace tensor_operation
596 } // namespace ck
#define INVOKER_RUN3_IMPL
Definition: device_base.hpp:114
#define GET_NXDL_PER_WAVE_IMPL
Definition: device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:14
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition: gemm_specialization.hpp:32
GemmSpecialization
Definition: gemm_specialization.hpp:11
void flush_icache()
Definition: flush_cache.hpp:361
Definition: ck.hpp:268
__global__ void kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_mx_gemm_bns.hpp:48
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
BlockGemmPipelineVersion
Definition: blkgemmpipe_scheduler.hpp:12
__global__ void kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_mx_gemm.hpp:90
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
int32_t index_t
Definition: ck.hpp:299
bool is_bf16_atomic_supported()
Definition: device_prop.hpp:108
Definition: stream_config.hpp:10
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:751
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:177
static constexpr __host__ bool CheckValidity(const Argument &karg)
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:1071
Definition: functional2.hpp:33
Definition: device_base.hpp:197
Definition: device_base.hpp:208
Definition: device_gemm_multiple_d.hpp:167
Definition: device_moe_mx_gemm_bpreshuffle.hpp:167
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_moe_mx_gemm_bpreshuffle.hpp:169
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_moe_mx_gemm_bpreshuffle.hpp:377
Definition: device_moe_mx_gemm_bpreshuffle.hpp:92
typename GridwiseGemm64::Argument Argument
Definition: device_moe_mx_gemm_bpreshuffle.hpp:158
static constexpr index_t APackedSize
Definition: device_moe_mx_gemm_bpreshuffle.hpp:160
static constexpr index_t NumDTensor
Definition: device_moe_mx_gemm_bpreshuffle.hpp:96
static constexpr bool IsValidCompilationParameter()
Definition: device_moe_mx_gemm_bpreshuffle.hpp:384
static constexpr index_t BPackedSize
Definition: device_moe_mx_gemm_bpreshuffle.hpp:161
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_moe_mx_gemm_bpreshuffle.hpp:543
static bool IsSupportedArgument(const Argument &arg)
Definition: device_moe_mx_gemm_bpreshuffle.hpp:390
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_a_scale, const void *p_b, const void *p_b_scale, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideScaleA, index_t StrideB, index_t StrideScaleB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition: device_moe_mx_gemm_bpreshuffle.hpp:496
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_moe_mx_gemm_bpreshuffle.hpp:437
int GetPreShuffleParameters() override
Definition: device_moe_mx_gemm_bpreshuffle.hpp:163
std::string GetTypeString() const override
Definition: device_moe_mx_gemm_bpreshuffle.hpp:549
static auto MakeInvoker()
Definition: device_moe_mx_gemm_bpreshuffle.hpp:493
static constexpr GET_NXDL_PER_WAVE_IMPL auto NXdlPerWave64
Definition: device_moe_mx_gemm_bpreshuffle.hpp:94
static constexpr auto NXdlPerWave32
Definition: device_moe_mx_gemm_bpreshuffle.hpp:95
static auto MakeArgument(const void *p_sorted_token_ids, const void *p_sorted_expert_ids, const void *p_max_token_id, const void *p_a, const void *p_a_scale, const void *p_b, const void *p_b_scale, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t NumTokens, index_t TopK, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideScaleA, index_t StrideB, index_t StrideScaleB, std::array< index_t, NumDTensor > StrideDs, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition: device_moe_mx_gemm_bpreshuffle.hpp:442
Definition: flush_cache.hpp:165