/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp Source File
device_moe_mx_gemm_bpreshuffle.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
19 
20 namespace ck {
21 namespace tensor_operation {
22 namespace device {
23 
24 template <typename ALayout,
25  typename BLayout,
26  typename DsLayout,
27  typename CLayout,
28  typename ADataType,
29  typename AScaleDataType,
30  typename BDataType,
31  typename BScaleDataType,
32  typename DsDataType,
33  typename CDataType,
34  typename GemmAccDataType,
35  typename CShuffleDataType,
36  typename AElementwiseOperation,
37  typename BElementwiseOperation,
38  typename CElementwiseOperation,
39  GemmSpecialization GemmSpec,
40  index_t ScaleBlockSize,
41  index_t BlockSize,
42  index_t MPerBlock,
43  index_t NPerBlock,
44  index_t KPerBlock,
45  index_t AK1,
46  index_t BK1,
47  index_t MPerXDL,
48  index_t NPerXDL,
49  index_t MXdlPerWave,
50  index_t NXdlPerWave,
51  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
52  typename ABlockTransferThreadClusterArrangeOrder,
53  typename ABlockTransferSrcAccessOrder,
54  index_t ABlockTransferSrcVectorDim,
55  index_t ABlockTransferSrcScalarPerVector,
56  index_t ABlockTransferDstScalarPerVector_AK1,
57  bool ABlockLdsExtraM,
58  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
59  typename BBlockTransferThreadClusterArrangeOrder,
60  typename BBlockTransferSrcAccessOrder,
61  index_t BBlockTransferSrcVectorDim,
62  index_t BBlockTransferSrcScalarPerVector,
63  index_t BBlockTransferDstScalarPerVector_BK1,
64  bool BBlockLdsExtraN,
65  index_t CShuffleMXdlPerWavePerShuffle,
66  index_t CShuffleNXdlPerWavePerShuffle,
67  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
68  typename CDEShuffleBlockTransferScalarPerVectors,
71  index_t ActivationOP = 0,
72  bool NSwizzle = false,
73  bool IsInputGemm = true,
74  bool MulRoutedWeight = true,
75  typename IndexType = index_t,
76  typename ComputeTypeA = ADataType,
77  typename ComputeTypeB = BDataType>
79  BLayout,
80  DsLayout,
81  CLayout,
82  ADataType,
83  AScaleDataType,
84  BDataType,
85  BScaleDataType,
86  DsDataType,
87  CDataType,
88  ScaleBlockSize,
89  AElementwiseOperation,
90  BElementwiseOperation,
91  CElementwiseOperation>
92 {
// Number of auxiliary D tensors, taken from the size of the DsDataType tuple.
93  static constexpr index_t NumDTensor = DsDataType::Size();
    // GridwiseGemm: the gridwise kernel implementation this device op forwards its
    // configuration to (see the GridwiseGemm cross-reference entry at the end of this page).
    // NOTE(review): the opening line of this alias (original line 94,
    // `using GridwiseGemm = GridwiseMoeGemmMX_BPreshuffle<`) is missing from this listing.
    // NOTE(review): BlkGemmPipeSched / BlkGemmPipelineVer, used below, are template
    // parameters whose declarations (original lines 69-70) are also missing from the
    // template header shown above.
95  ALayout,
96  BLayout,
97  DsLayout,
98  CLayout,
99  ADataType,
100  AScaleDataType,
101  BDataType,
102  BScaleDataType,
103  GemmAccDataType,
104  CShuffleDataType,
105  DsDataType,
106  CDataType,
107  AElementwiseOperation,
108  BElementwiseOperation,
109  CElementwiseOperation,
110  GemmSpec,
111  ScaleBlockSize,
112  BlockSize,
113  MPerBlock,
114  NPerBlock,
115  KPerBlock,
116  AK1,
117  BK1,
118  MPerXDL,
119  NPerXDL,
120  MXdlPerWave,
121  NXdlPerWave,
122  ABlockTransferThreadClusterLengths_AK0_M_AK1,
123  ABlockTransferThreadClusterArrangeOrder,
124  ABlockTransferSrcAccessOrder,
125  ABlockTransferSrcVectorDim,
126  ABlockTransferSrcScalarPerVector,
127  ABlockTransferDstScalarPerVector_AK1,
    // Hard-coded flag; presumably AThreadTransferSrcResetCoordinateAfterRun as in other
    // CK gridwise GEMMs -- confirm against GridwiseMoeGemmMX_BPreshuffle.
128  false,
129  ABlockLdsExtraM,
130  BBlockTransferThreadClusterLengths_BK0_N_BK1,
131  BBlockTransferThreadClusterArrangeOrder,
132  BBlockTransferSrcAccessOrder,
133  BBlockTransferSrcVectorDim,
134  BBlockTransferSrcScalarPerVector,
135  BBlockTransferDstScalarPerVector_BK1,
    // Hard-coded flag; presumably BThreadTransferSrcResetCoordinateAfterRun -- confirm.
136  false,
137  BBlockLdsExtraN,
138  CShuffleMXdlPerWavePerShuffle,
139  CShuffleNXdlPerWavePerShuffle,
140  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
141  CDEShuffleBlockTransferScalarPerVectors,
142  BlkGemmPipeSched,
143  BlkGemmPipelineVer,
144  ActivationOP,
145  NSwizzle,
146  IsInputGemm,
147  MulRoutedWeight,
148  IndexType,
149  ComputeTypeA,
150  ComputeTypeB>;
151 
    // NOTE(review): original line 152 (`using Argument = typename GridwiseGemm::Argument;`,
    // per the Argument cross-reference entry) is missing from this listing.
153 
    // Values of packed_size_v for the A and B element types; presumably the number of
    // logical elements stored per ADataType/BDataType value (>1 for sub-byte packed
    // types) -- confirm. Neither constant is referenced elsewhere in this listing.
154  static constexpr index_t APackedSize = packed_size_v<ADataType>;
155  static constexpr index_t BPackedSize = packed_size_v<BDataType>;
156 
    // Returns NPerXDL; presumably the N granularity callers must use when pre-shuffling
    // the B tensor for this "bpreshuffle" variant -- confirm with the caller.
157  int GetPreShuffleParameters() override { return NPerXDL; }
158 
159  // Invoker
160  struct Invoker : public BaseInvoker
161  {
// Launches the MoE MX GEMM kernel for `arg` and returns the time reported by the launch
// helper: launch_and_time_kernel_with_preprocess when stream_config.flush_cache is set,
// launch_and_time_kernel otherwise. Returns 0 if no kernel branch is taken.
// Throws std::runtime_error("wrong! GridwiseGemm has invalid setting") when the validity
// guard fails, and std::runtime_error("todo: only v1 & v3 support now") when a main
// K-block loop is needed and the pipeline version is neither v1 nor v3.
// NOTE(review): this listing omits original lines 169, 209, 215, 258, 266, 272, 281,
// 288, 294, 303, 316, 322, 331, 337, 343 and 352; comments describing them are hedged.
162  float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
163  {
164  if(stream_config.log_level_ > 0)
165  {
166  arg.Print();
167  }
168 
     // NOTE(review): the condition guarding this throw (original line 169) is missing;
     // presumably `if(!GridwiseGemm::CheckValidity(arg))` -- confirm.
170  {
171  throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
172  }
173 
174  index_t gdx, gdy, gdz;
175  std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
176 
177  float ave_time = 0;
178 
     // Per-split K length: ceil(K / (KBatch * KPerBlock)) * KPerBlock.
     // NOTE(review): KBatch <= 0 makes k_grain 0 and this division undefined; Run itself
     // does not guard against it.
179  index_t k_grain = arg.KBatch * KPerBlock;
180  index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
181 
     // Selects between the `true` and `false` kernel instantiations dispatched below.
182  const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
183 
     // Launches `kernel` on a (gdx, gdy, gdz) grid of BlockSize threads with 0 bytes of
     // dynamic LDS and stores the reported time in ave_time.
184  const auto RunKernel = [&](const auto& kernel) {
185  if(stream_config.flush_cache)
186  {
187 
     // Cache-flushing path: works on a copy of `arg` and computes the byte sizes of the
     // A, B and D buffers from their grid descriptors for the rotating-memory helper.
188  std::array<std::size_t, NumDTensor> DsSize;
189 
190  Argument arg_ = arg;
191 
192  const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
193  arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
194  const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
195  arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
196 
197  auto size_a_buffer =
198  a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
199  auto size_b_buffer =
200  b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);
201 
202  const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
203  arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
204 
205  static_for<0, NumDTensor, 1>{}([&](auto i) {
206  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
207  DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
208  });
     // NOTE(review): the start of this declaration (original line 209) is missing;
     // presumably it constructs `rotating_mem` from these arguments -- confirm.
210  arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize);
211  rotating_mem.Print();
212 
     // Pre-launch hook passed to launch_and_time_kernel_with_preprocess.
213  auto run_flush_cache = [&]() {
214  // flush icache
     // NOTE(review): original line 215 is missing; presumably the flush_icache() call
     // listed in the cross-references at the end of this page.
216  // rotating mem
217  rotating_mem.Next();
218  // clear c mem
     // NOTE(review): arg_.M * arg_.N is evaluated before the size_t multiply; if M and N
     // are index_t (as passed to MakeArgument) this can overflow for very large M*N.
     // The hipMemsetAsync status is only passed to hipGetErrorString and otherwise ignored.
219  if(arg_.KBatch > 1)
220  hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
221  0,
222  arg_.M * arg_.N * sizeof(CDataType),
223  stream_config.stream_id_));
224  };
225 
226  ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
227  stream_config,
228  run_flush_cache,
229  kernel,
230  dim3(gdx, gdy, gdz),
231  dim3(BlockSize),
232  0,
233  arg_);
234  }
235  else
236  {
     // Plain path: zero C when KBatch > 1 (same overflow/ignored-status caveats as
     // above), then launch directly on `arg`.
237  if(arg.KBatch > 1)
238  hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
239  0,
240  arg.M * arg.N * sizeof(CDataType),
241  stream_config.stream_id_));
242 
243  ave_time = launch_and_time_kernel(
244  stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
245  }
246  };
247 
248  // TODO: Check if this is the right algorithm for minimum_occupancy
     // Intrawave scheduling: 2 for pipeline v3 when
     // MPerBlock*NPerBlock*KPerBlock*sizeof(ADataType) <= 128*128*64*2, otherwise 1.
     // Any other scheduler: 2.
249  constexpr index_t minimum_occupancy =
250  BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave
251  ? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
252  MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2)
253  ? 2
254  : 1
255  : 2;
256 
     // NOTE(review): the initializer (original line 258) is missing; presumably it
     // chooses the C memory operation (e.g. set vs. atomic add) -- confirm.
257  constexpr auto MemoryDataOp =
259 
     // Dispatch: pipeline v1 uses kernel_moe_mxgemm, v3 uses kernel_moe_mxgemm_2lds; the
     // first bool template argument mirrors has_main_k_block_loop.
     // NOTE(review): the inner if/else conditions and the trailing template arguments
     // are among the missing lines; presumably they select on
     // GridwiseGemm::CalculateKBlockLoopTailNum(K_split) -- confirm.
260  if(has_main_k_block_loop)
261  {
262  // Tail number always full
263  if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
264  {
265  {
267  {
268  const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
269  true,
270  MemoryDataOp,
271  minimum_occupancy,
273  RunKernel(kernel);
274  }
275  else
276  {
277  const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
278  true,
279  MemoryDataOp,
280  minimum_occupancy,
282  RunKernel(kernel);
283  }
284  }
285  }
286  else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
287  {
289  {
290  const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
291  true,
292  MemoryDataOp,
293  minimum_occupancy,
295  RunKernel(kernel);
296  }
297  else
298  {
299  const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
300  true,
301  MemoryDataOp,
302  minimum_occupancy,
304  RunKernel(kernel);
305  }
306  }
307  else
308  {
309  throw std::runtime_error("todo: only v1 & v3 support now");
310  }
311  }
312  else
313  {
     // NOTE(review): unlike the branch above there is no fallback `else` here, so for
     // pipeline versions other than v1/v3 no kernel is launched and 0 is returned.
314  if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
315  {
317  {
318  const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
319  false,
320  MemoryDataOp,
321  minimum_occupancy,
323  RunKernel(kernel);
324  }
325  else
326  {
327  const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
328  false,
329  MemoryDataOp,
330  minimum_occupancy,
332  RunKernel(kernel);
333  }
334  }
335  else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
336  {
338  {
339  const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
340  false,
341  MemoryDataOp,
342  minimum_occupancy,
344  RunKernel(kernel);
345  }
346  else
347  {
348  const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
349  false,
350  MemoryDataOp,
351  minimum_occupancy,
353  RunKernel(kernel);
354  }
355  }
356  }
357 
358  return ave_time;
359  }
360 
361  // polymorphic
362  float Run(const BaseArgument* p_arg,
363  const StreamConfig& stream_config = StreamConfig{}) override
364  {
365  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
366  }
367  };
368 
// Compile-time hook for rejecting unsupported template configurations.
// TODO: properly implement this check -- at present every configuration is accepted.
static constexpr bool IsValidCompilationParameter()
{
    constexpr bool accepted = true;
    return accepted;
}
374 
375  static bool IsSupportedArgument(const Argument& arg)
376  {
377  // only impl kbatch 1 now
378  if(arg.KBatch > 1)
379  {
380  return false;
381  }
382  if(!ck::is_xdl_supported())
383  {
384  return false;
385  }
386 
387  if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
388  {
389  return false;
390  }
391 
392  if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
393  GemmSpec == GemmSpecialization::NKPadding ||
394  GemmSpec == GemmSpecialization::MNKPadding ||
395  GemmSpec == GemmSpecialization::KPadding))
396  {
397  return false;
398  }
399  if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
400  {
401  return false;
402  }
403 
404  return GridwiseGemm::CheckValidity(arg);
405  }
406 
407  // polymorphic
408  bool IsSupportedArgument(const BaseArgument* p_arg) override
409  {
410  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
411  }
412 
413  static auto MakeArgument(const void* p_sorted_token_ids,
414  const void* p_sorted_expert_ids,
415  const void* p_max_token_id,
416  const void* p_a,
417  const void* p_a_scale,
418  const void* p_b,
419  const void* p_b_scale,
420  std::array<const void*, NumDTensor> p_ds,
421  void* p_c,
422  index_t NumTokens,
423  index_t TopK,
424  index_t M,
425  index_t N,
426  index_t K,
427  index_t StrideA,
428  index_t StrideScaleA,
429  index_t StrideB,
430  index_t StrideScaleB,
431  std::array<index_t, NumDTensor> StrideDs,
432  index_t StrideC,
433  index_t KBatch,
434  AElementwiseOperation a_element_op,
435  BElementwiseOperation b_element_op,
436  CElementwiseOperation c_element_op)
437  {
438  return Argument{static_cast<const index_t*>(p_sorted_token_ids),
439  static_cast<const index_t*>(p_sorted_expert_ids),
440  static_cast<const index_t*>(p_max_token_id),
441  static_cast<const ADataType*>(p_a),
442  static_cast<const AScaleDataType*>(p_a_scale),
443  static_cast<const BDataType*>(p_b),
444  static_cast<const BScaleDataType*>(p_b_scale),
445  p_ds,
446  static_cast<CDataType*>(p_c),
447  NumTokens,
448  TopK,
449  M,
450  N,
451  K,
452  StrideA,
453  StrideScaleA,
454  StrideB,
455  StrideScaleB,
456  StrideDs,
457  StrideC,
458  KBatch,
459  a_element_op,
460  b_element_op,
461  c_element_op};
462  }
463 
464  static auto MakeInvoker() { return Invoker{}; }
465 
466  // polymorphic
467  std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
468  const void* p_a_scale,
469  const void* p_b,
470  const void* p_b_scale,
471  std::array<const void*, NumDTensor> p_ds,
472  void* p_c,
473  index_t M,
474  index_t N,
475  index_t K,
476  index_t StrideA,
477  index_t StrideScaleA,
478  index_t StrideB,
479  index_t StrideScaleB,
480  std::array<ck::index_t, NumDTensor> StrideDs,
481  index_t StrideC,
482  index_t KBatch,
483  AElementwiseOperation a_element_op,
484  BElementwiseOperation b_element_op,
485  CElementwiseOperation c_element_op) override
486  {
487  return std::make_unique<Argument>(nullptr,
488  nullptr,
489  nullptr,
490  static_cast<const ADataType*>(p_a),
491  static_cast<const AScaleDataType*>(p_a_scale),
492  static_cast<const BDataType*>(p_b),
493  static_cast<const BScaleDataType*>(p_b_scale),
494  p_ds,
495  static_cast<CDataType*>(p_c),
496  M, // randoms set, no use
497  0,
498  M,
499  N,
500  K,
501  StrideA,
502  StrideScaleA,
503  StrideB,
504  StrideScaleB,
505  StrideDs,
506  StrideC,
507  KBatch,
508  a_element_op,
509  b_element_op,
510  c_element_op);
511  }
512 
513  // polymorphic
514  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
515  {
516  return std::make_unique<Invoker>(Invoker{});
517  }
518 
519  // polymorphic
// Builds a one-line description of this instance containing:
//   - GemmSpec and the first character of the A, B and C layout names;
//   - BlockSize, the MPerBlock x NPerBlock x KPerBlock tile and the MPerXDL x NPerXDL tile;
//   - the MXdlPerWave x NXdlPerWave map;
//   - the A/B source scalars-per-vector (labelled "VmemReadVec");
//   - the pipeline scheduler and version names;
//   - GridwiseGemm::BlockwiseGemmPipe::PrefetchStages.
520  std::string GetTypeString() const override
521  {
522  auto str = std::stringstream();
523 
     // NOTE(review): the initializer entries of both maps (original lines 525-526 and
     // 529-533) are missing from this listing; presumably enum-to-name pairs. operator[]
     // on these non-const local maps inserts an empty string for any enum value without
     // an entry.
524  std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
527 
528  std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
534 
535  // clang-format off
     // NOTE(review): the prefix "DeviceMoeGEmmMx" has an unusual capital E; it is left
     // unchanged because external tooling may match on the exact string.
536  str << "DeviceMoeGEmmMx"
537  << "<"
538  << getGemmSpecializationString(GemmSpec) << ", "
539  << std::string(ALayout::name)[0]
540  << std::string(BLayout::name)[0]
541  << std::string(CLayout::name)[0]
542  << ">"
543  << " BlkSize: "
544  << BlockSize << ", "
545  << "BlkTile: "
546  << MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
547  << "WaveTile: "
548  << MPerXDL<<"x"<<NPerXDL << ", "
549  << "WaveMap: "
550  << MXdlPerWave<<"x" << NXdlPerWave<<", "
551  << "VmemReadVec: "
552  << ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
553  << "BlkGemmPipelineScheduler: "
554  << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
555  << "BlkGemmPipelineVersion: "
556  << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
557  << "BlkGemmPipelinePrefetchStages: "
558  << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
559  // clang-format on
560 
561  return str.str();
562  }
563 };
564 
565 } // namespace device
566 } // namespace tensor_operation
567 } // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:14
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition: gemm_specialization.hpp:32
GemmSpecialization
Definition: gemm_specialization.hpp:11
void flush_icache()
Definition: flush_cache.hpp:216
Definition: ck.hpp:267
bool is_xdl_supported()
Definition: device_prop.hpp:68
__global__ void kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_mx_gemm_bns.hpp:48
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
BlockGemmPipelineVersion
Definition: blkgemmpipe_scheduler.hpp:12
__global__ void kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_mx_gemm.hpp:87
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
int32_t index_t
Definition: ck.hpp:298
bool is_bf16_atomic_supported()
Definition: device_prop.hpp:85
Definition: stream_config.hpp:10
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:745
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:171
static __host__ auto CalculateGridSize(index_t M, index_t N)
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:253
__host__ static __device__ auto MakeDsGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, std::array< index_t, NumDTensor > StrideDs)
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:649
static constexpr __host__ TailNumber CalculateKBlockLoopTailNum(index_t K)
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:1248
__host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:369
__host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:477
static constexpr __host__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:1241
static constexpr __host__ bool CheckValidity(const Argument &karg)
Definition: gridwise_moe_mx_gemm_bpreshuffle.hpp:1063
Definition: functional2.hpp:33
Definition: device_base.hpp:51
Definition: device_base.hpp:62
Definition: device_gemm_multiple_d.hpp:167
Definition: device_moe_mx_gemm_bpreshuffle.hpp:161
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_moe_mx_gemm_bpreshuffle.hpp:362
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_moe_mx_gemm_bpreshuffle.hpp:162
Definition: device_moe_mx_gemm_bpreshuffle.hpp:92
static constexpr index_t APackedSize
Definition: device_moe_mx_gemm_bpreshuffle.hpp:154
static constexpr index_t NumDTensor
Definition: device_moe_mx_gemm_bpreshuffle.hpp:93
static constexpr bool IsValidCompilationParameter()
Definition: device_moe_mx_gemm_bpreshuffle.hpp:369
static constexpr index_t BPackedSize
Definition: device_moe_mx_gemm_bpreshuffle.hpp:155
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_moe_mx_gemm_bpreshuffle.hpp:514
static bool IsSupportedArgument(const Argument &arg)
Definition: device_moe_mx_gemm_bpreshuffle.hpp:375
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_a_scale, const void *p_b, const void *p_b_scale, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideScaleA, index_t StrideB, index_t StrideScaleB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition: device_moe_mx_gemm_bpreshuffle.hpp:467
typename GridwiseGemm::Argument Argument
Definition: device_moe_mx_gemm_bpreshuffle.hpp:152
GridwiseMoeGemmMX_BPreshuffle< ALayout, BLayout, DsLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ActivationOP, NSwizzle, IsInputGemm, MulRoutedWeight, IndexType, ComputeTypeA, ComputeTypeB > GridwiseGemm
Definition: device_moe_mx_gemm_bpreshuffle.hpp:150
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_moe_mx_gemm_bpreshuffle.hpp:408
int GetPreShuffleParameters() override
Definition: device_moe_mx_gemm_bpreshuffle.hpp:157
std::string GetTypeString() const override
Definition: device_moe_mx_gemm_bpreshuffle.hpp:520
static auto MakeInvoker()
Definition: device_moe_mx_gemm_bpreshuffle.hpp:464
static auto MakeArgument(const void *p_sorted_token_ids, const void *p_sorted_expert_ids, const void *p_max_token_id, const void *p_a, const void *p_a_scale, const void *p_b, const void *p_b_scale, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t NumTokens, index_t TopK, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideScaleA, index_t StrideB, index_t StrideScaleB, std::array< index_t, NumDTensor > StrideDs, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition: device_moe_mx_gemm_bpreshuffle.hpp:413
Definition: flush_cache.hpp:20