/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bns.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bns.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bns.hpp Source File
device_moe_mx_gemm_bns.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
19 
20 namespace ck {
21 namespace tensor_operation {
22 namespace device {
23 
// Device operation for the MoE GEMM with separate A and B scale tensors
// (AScaleDataType / BScaleDataType, grouped by ScaleBlockSize). It forwards its template
// parameters to GridwiseMoeGemmMXBNS and exposes the usual device-op interface
// (Invoker, IsSupportedArgument, MakeArgument, GetTypeString).
24 template <typename ALayout,
25  typename BLayout,
26  typename DsLayout,
27  typename CLayout,
28  typename ADataType,
29  typename AScaleDataType,
30  typename BDataType,
31  typename BScaleDataType,
32  typename DsDataType,
33  typename CDataType,
34  typename GemmAccDataType,
35  typename CShuffleDataType,
36  typename AElementwiseOperation,
37  typename BElementwiseOperation,
38  typename CElementwiseOperation,
39  GemmSpecialization GemmSpec,
40  index_t ScaleBlockSize,
41  index_t BlockSize,
42  index_t MPerBlock,
43  index_t NPerBlock,
44  index_t KPerBlock,
45  index_t AK1,
46  index_t BK1,
47  index_t MPerXDL,
48  index_t NPerXDL,
49  index_t MXdlPerWave,
50  index_t NXdlPerWave,
51  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
52  typename ABlockTransferThreadClusterArrangeOrder,
53  typename ABlockTransferSrcAccessOrder,
54  index_t ABlockTransferSrcVectorDim,
55  index_t ABlockTransferSrcScalarPerVector,
56  index_t ABlockTransferDstScalarPerVector_AK1,
57  bool ABlockLdsExtraM,
58  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
59  typename BBlockTransferThreadClusterArrangeOrder,
60  typename BBlockTransferSrcAccessOrder,
61  index_t BBlockTransferSrcVectorDim,
62  index_t BBlockTransferSrcScalarPerVector,
63  index_t BBlockTransferDstScalarPerVector_BK1,
64  bool BBlockLdsExtraN,
65  index_t CShuffleMXdlPerWavePerShuffle,
66  index_t CShuffleNXdlPerWavePerShuffle,
67  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
68  typename CDEShuffleBlockTransferScalarPerVectors,
// NOTE(review): source lines 69-70 are missing from this rendered listing. The body uses
// BlkGemmPipeSched and BlkGemmPipelineVer, which are not declared in the visible list, so
// these lines presumably declare them — confirm against the header.
71  index_t ActivationOP = 0,
72  bool NSwizzle = false,
73  bool IsInputGemm = true,
74  bool MulRoutedWeight = true,
75  typename IndexType = index_t,
76  typename ComputeTypeA = ADataType,
77  typename ComputeTypeB = BDataType>
// NOTE(review): source line 78 (the struct name and the opening of its base-class
// template-argument list) is missing from this rendered listing.
79  BLayout,
80  DsLayout,
81  CLayout,
82  ADataType,
83  AScaleDataType,
84  BDataType,
85  BScaleDataType,
86  DsDataType,
87  CDataType,
88  ScaleBlockSize,
89  AElementwiseOperation,
90  BElementwiseOperation,
91  CElementwiseOperation>
92 {
// Number of auxiliary D tensors, taken from the DsDataType tuple.
93  static constexpr index_t NumDTensor = DsDataType::Size();
// Gridwise kernel implementation. Note that the argument order differs from this struct's
// list (GemmAccDataType/CShuffleDataType come before DsDataType/CDataType). A hard-coded
// `false` is inserted after each operand's DstScalarPerVector argument; its meaning is
// defined in gridwise_moe_mx_gemm_bns.hpp.
94  using GridwiseGemm =
95  GridwiseMoeGemmMXBNS<ALayout,
96  BLayout,
97  DsLayout,
98  CLayout,
99  ADataType,
100  AScaleDataType,
101  BDataType,
102  BScaleDataType,
103  GemmAccDataType,
104  CShuffleDataType,
105  DsDataType,
106  CDataType,
107  AElementwiseOperation,
108  BElementwiseOperation,
109  CElementwiseOperation,
110  GemmSpec,
111  ScaleBlockSize,
112  BlockSize,
113  MPerBlock,
114  NPerBlock,
115  KPerBlock,
116  AK1,
117  BK1,
118  MPerXDL,
119  NPerXDL,
120  MXdlPerWave,
121  NXdlPerWave,
122  ABlockTransferThreadClusterLengths_AK0_M_AK1,
123  ABlockTransferThreadClusterArrangeOrder,
124  ABlockTransferSrcAccessOrder,
125  ABlockTransferSrcVectorDim,
126  ABlockTransferSrcScalarPerVector,
127  ABlockTransferDstScalarPerVector_AK1,
128  false,
129  ABlockLdsExtraM,
130  BBlockTransferThreadClusterLengths_BK0_N_BK1,
131  BBlockTransferThreadClusterArrangeOrder,
132  BBlockTransferSrcAccessOrder,
133  BBlockTransferSrcVectorDim,
134  BBlockTransferSrcScalarPerVector,
135  BBlockTransferDstScalarPerVector_BK1,
136  false,
137  BBlockLdsExtraN,
138  CShuffleMXdlPerWavePerShuffle,
139  CShuffleNXdlPerWavePerShuffle,
140  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
141  CDEShuffleBlockTransferScalarPerVectors,
142  BlkGemmPipeSched,
143  BlkGemmPipelineVer,
144  ActivationOP,
145  NSwizzle,
146  IsInputGemm,
147  MulRoutedWeight,
148  IndexType,
149  ComputeTypeA,
150  ComputeTypeB>;
151 
// NOTE(review): source line 152 is missing from this rendered listing. Per the
// cross-reference index it is `using Argument = typename GridwiseGemm::Argument;`.
153 
// Packing factors reported by packed_size_v for the A and B element types. Presumably
// this is the number of logical values per stored element (greater than 1 for sub-byte
// formats) — confirm against packed_size_v. Neither constant is referenced elsewhere in
// this listing.
154  static constexpr index_t APackedSize = packed_size_v<ADataType>;
155  static constexpr index_t BPackedSize = packed_size_v<BDataType>;
156 
// Reports NPerXDL as this operation's pre-shuffle parameter. Presumably this is the
// granularity at which the caller must pre-shuffle B — confirm with the base-class contract.
157  int GetPreShuffleParameters() override { return NPerXDL; }
158 
159  // Invoker
160  struct Invoker : public BaseInvoker
161  {
// Typed launch path. It validates the configuration, derives the launch grid and the
// per-split K size from `arg`, and selects a kernel_moe_mxgemm instantiation from the
// pipeline version and whether there is a main K-block loop. It returns the time reported
// by the launch helper, or 0 if no kernel is launched.
// Throws std::runtime_error for an invalid GridwiseGemm setting. It also throws for
// pipeline versions other than v1/v3 when there is a main K-block loop.
162  float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
163  {
164  if(stream_config.log_level_ > 0)
165  {
166  arg.Print();
167  }
168 
// NOTE(review): source line 169 (the condition guarding this throw) is missing from this
// rendered listing; presumably `if(!GridwiseGemm::CheckValidity(arg))` — confirm.
170  {
171  throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
172  }
173 
174  index_t gdx, gdy, gdz;
175  std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
176 
177  float ave_time = 0;
178 
// K handled per split: ceil(K / (KBatch * KPerBlock)) * KPerBlock. Each of the KBatch
// splits is rounded up to a whole number of KPerBlock tiles.
179  index_t k_grain = arg.KBatch * KPerBlock;
180  index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
181 
182  const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
183 
// Launches and times one kernel instantiation.
// - With stream_config.flush_cache set, it runs on a copy of `arg`. The A/B/D buffer sizes
//   come from the grid descriptors, and the buffers are rotated in a preprocess step
//   before each timed run.
// - Otherwise it launches directly on `arg`.
184  const auto RunKernel = [&](const auto& kernel) {
185  if(stream_config.flush_cache)
186  {
187 
188  std::array<std::size_t, NumDTensor> DsSize;
189 
190  Argument arg_ = arg;
191 
192  const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
193  arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
194  const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
195  arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
196 
197  auto size_a_buffer =
198  a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
199  auto size_b_buffer =
200  b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);
201 
202  const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
203  arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
204 
// Byte size of each D tensor, computed per tuple element at compile-time index i.
205  static_for<0, NumDTensor, 1>{}([&](auto i) {
206  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
207  DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
208  });
// NOTE(review): source line 209 (the declaration of `rotating_mem` that these
// constructor arguments belong to) is missing from this rendered listing.
210  arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize);
211  rotating_mem.Print();
212 
213  auto run_flush_cache = [&]() {
214  // flush icache
// NOTE(review): source line 215 is missing here; presumably the flush_icache() call.
216  // rotating mem
217  rotating_mem.Next();
218  // clear c mem
// NOTE(review): `arg_.M * arg_.N` is evaluated in index_t (int32_t) before the
// multiplication by sizeof widens it, so it can overflow for large outputs.
// hipGetErrorString() only turns the status into text, so a failed memset goes
// unnoticed. IsSupportedArgument rejects KBatch > 1, so this path is currently only
// reachable with arguments that were not validated.
219  if(arg_.KBatch > 1)
220  hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
221  0,
222  arg_.M * arg_.N * sizeof(CDataType),
223  stream_config.stream_id_));
224  };
225 
226  ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
227  stream_config,
228  run_flush_cache,
229  kernel,
230  dim3(gdx, gdy, gdz),
231  dim3(BlockSize),
232  0,
233  arg_);
234  }
235  else
236  {
// NOTE(review): this has the same int32 `M * N` overflow and ignored-memset-status
// caveat as the flush_cache path.
237  if(arg.KBatch > 1)
238  hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
239  0,
240  arg.M * arg.N * sizeof(CDataType),
241  stream_config.stream_id_));
242 
243  ave_time = launch_and_time_kernel(
244  stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
245  }
246  };
247 
// Occupancy selection:
// - 2 for the Intrawave scheduler with pipeline v3 and a tile of at most
//   128*128*64*2 bytes of ADataType;
// - 1 for any other Intrawave configuration;
// - 2 for every other scheduler.
248  // TODO: Check if this is the right algorithm for minimum_occupancy
249  constexpr index_t minimum_occupancy =
250  BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave
251  ? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
252  MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2)
253  ? 2
254  : 1
255  : 2;
256 
257  constexpr auto MemoryDataOp =
// NOTE(review): source line 258 (the initializer of MemoryDataOp) is missing from this
// rendered listing.
// NOTE(review): two groups of source lines are also missing from this listing:
// - the last template argument of every kernel_moe_mxgemm instantiation below (source
//   lines 268, 279, 288, 305, 316, 325);
// - the v3 branch conditions (source lines 273, 310).
// Presumably these are a TailNumber value and a CalculateKBlockLoopTailNum(K_split) test
// — confirm against the header.
259  if(has_main_k_block_loop)
260  {
261  // Tail number always full
262  if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
263  {
264  const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
265  true,
266  MemoryDataOp,
267  minimum_occupancy,
269  RunKernel(kernel);
270  }
271  else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
272  {
274  {
275  const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
276  true,
277  MemoryDataOp,
278  minimum_occupancy,
280  RunKernel(kernel);
281  }
282  else
283  {
284  const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
285  true,
286  MemoryDataOp,
287  minimum_occupancy,
289  RunKernel(kernel);
290  }
291  }
292  else
293  {
294  throw std::runtime_error("todo: only v1 & v3 support now");
295  }
296  }
// NOTE(review): unlike the branch above, this branch has no fallback for pipeline
// versions other than v1/v3. Such versions launch nothing here, and Run returns 0.
297  else
298  {
299  if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
300  {
301  const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
302  false,
303  MemoryDataOp,
304  minimum_occupancy,
306  RunKernel(kernel);
307  }
308  else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
309  {
311  {
312  const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
313  false,
314  MemoryDataOp,
315  minimum_occupancy,
317  RunKernel(kernel);
318  }
319  else
320  {
321  const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
322  false,
323  MemoryDataOp,
324  minimum_occupancy,
326  RunKernel(kernel);
327  }
328  }
329  }
330 
331  return ave_time;
332  }
333 
334  // polymorphic
335  float Run(const BaseArgument* p_arg,
336  const StreamConfig& stream_config = StreamConfig{}) override
337  {
338  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
339  }
340  };
341 
// Compile-time hook for rejecting unsupported template configurations.
// Every configuration is currently accepted; a real check has not been written yet.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool every_configuration_accepted = true;
    return every_configuration_accepted;
}
347 
348  static bool IsSupportedArgument(const Argument& arg)
349  {
350  // only impl kbatch 1 now
351  if(arg.KBatch > 1)
352  {
353  return false;
354  }
355  if(!ck::is_xdl_supported())
356  {
357  return false;
358  }
359 
360  if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
361  {
362  return false;
363  }
364 
365  if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
366  GemmSpec == GemmSpecialization::NKPadding ||
367  GemmSpec == GemmSpecialization::MNKPadding ||
368  GemmSpec == GemmSpecialization::KPadding))
369  {
370  return false;
371  }
372  if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
373  {
374  return false;
375  }
376 
377  return GridwiseGemm::CheckValidity(arg);
378  }
379 
380  // polymorphic
381  bool IsSupportedArgument(const BaseArgument* p_arg) override
382  {
383  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
384  }
385 
386  static auto MakeArgument(const void* p_sorted_token_ids,
387  const void* p_sorted_expert_ids,
388  const void* p_max_token_id,
389  const void* p_a,
390  const void* p_a_scale,
391  const void* p_b,
392  const void* p_b_scale,
393  std::array<const void*, NumDTensor> p_ds,
394  void* p_c,
395  index_t NumTokens,
396  index_t TopK,
397  index_t M,
398  index_t N,
399  index_t K,
400  index_t StrideA,
401  index_t StrideScaleA,
402  index_t StrideB,
403  index_t StrideScaleB,
404  std::array<index_t, NumDTensor> StrideDs,
405  index_t StrideC,
406  index_t KBatch,
407  AElementwiseOperation a_element_op,
408  BElementwiseOperation b_element_op,
409  CElementwiseOperation c_element_op)
410  {
411  return Argument{static_cast<const index_t*>(p_sorted_token_ids),
412  static_cast<const index_t*>(p_sorted_expert_ids),
413  static_cast<const index_t*>(p_max_token_id),
414  static_cast<const ADataType*>(p_a),
415  static_cast<const AScaleDataType*>(p_a_scale),
416  static_cast<const BDataType*>(p_b),
417  static_cast<const BScaleDataType*>(p_b_scale),
418  p_ds,
419  static_cast<CDataType*>(p_c),
420  NumTokens,
421  TopK,
422  M,
423  N,
424  K,
425  StrideA,
426  StrideScaleA,
427  StrideB,
428  StrideScaleB,
429  StrideDs,
430  StrideC,
431  KBatch,
432  a_element_op,
433  b_element_op,
434  c_element_op};
435  }
436 
437  static auto MakeInvoker() { return Invoker{}; }
438 
439  // polymorphic
440  std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
441  const void* p_a_scale,
442  const void* p_b,
443  const void* p_b_scale,
444  std::array<const void*, NumDTensor> p_ds,
445  void* p_c,
446  index_t M,
447  index_t N,
448  index_t K,
449  index_t StrideA,
450  index_t StrideScaleA,
451  index_t StrideB,
452  index_t StrideScaleB,
453  std::array<ck::index_t, NumDTensor> StrideDs,
454  index_t StrideC,
455  index_t KBatch,
456  AElementwiseOperation a_element_op,
457  BElementwiseOperation b_element_op,
458  CElementwiseOperation c_element_op) override
459  {
460  return std::make_unique<Argument>(nullptr,
461  nullptr,
462  nullptr,
463  static_cast<const ADataType*>(p_a),
464  static_cast<const AScaleDataType*>(p_a_scale),
465  static_cast<const BDataType*>(p_b),
466  static_cast<const BScaleDataType*>(p_b_scale),
467  p_ds,
468  static_cast<CDataType*>(p_c),
469  M, // randoms set, no use
470  0,
471  M,
472  N,
473  K,
474  StrideA,
475  StrideScaleA,
476  StrideB,
477  StrideScaleB,
478  StrideDs,
479  StrideC,
480  KBatch,
481  a_element_op,
482  b_element_op,
483  c_element_op);
484  }
485 
486  // polymorphic
487  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
488  {
489  return std::make_unique<Invoker>(Invoker{});
490  }
491 
// Builds a human-readable description of this instance. It lists the GemmSpec, the first
// letter of each A/B/C layout name, the block and wave tiles, the global-read vector
// widths, the pipeline scheduler and version names, and the pipeline's prefetch-stage count.
492  // polymorphic
493  std::string GetTypeString() const override
494  {
495  auto str = std::stringstream();
496 
// NOTE(review): the initializers of both lookup maps (source lines 498-499 and 502-506)
// are missing from this rendered listing. Because std::map::operator[] is used below, a
// scheduler/version without an entry would be inserted and printed as an empty string.
497  std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
500 
501  std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
507 
// NOTE(review): the "DeviceMoeGEmmMx" spelling is irregular but is runtime output, so it
// is left unchanged here.
508  // clang-format off
509  str << "DeviceMoeGEmmMx"
510  << "<"
511  << getGemmSpecializationString(GemmSpec) << ", "
512  << std::string(ALayout::name)[0]
513  << std::string(BLayout::name)[0]
514  << std::string(CLayout::name)[0]
515  << ">"
516  << " BlkSize: "
517  << BlockSize << ", "
518  << "BlkTile: "
519  << MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
520  << "WaveTile: "
521  << MPerXDL<<"x"<<NPerXDL << ", "
522  << "WaveMap: "
523  << MXdlPerWave<<"x" << NXdlPerWave<<", "
524  << "VmemReadVec: "
525  << ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
526  << "BlkGemmPipelineScheduler: "
527  << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
528  << "BlkGemmPipelineVersion: "
529  << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
530  << "BlkGemmPipelinePrefetchStages: "
531  << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
532  // clang-format on
533 
534  return str.str();
535  }
536 };
537 
538 } // namespace device
539 } // namespace tensor_operation
540 } // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:14
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition: gemm_specialization.hpp:32
GemmSpecialization
Definition: gemm_specialization.hpp:11
void flush_icache()
Definition: flush_cache.hpp:216
Definition: ck.hpp:267
bool is_xdl_supported()
Definition: device_prop.hpp:68
__global__ void kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_mx_gemm_bns.hpp:48
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
BlockGemmPipelineVersion
Definition: blkgemmpipe_scheduler.hpp:12
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
int32_t index_t
Definition: ck.hpp:298
bool is_bf16_atomic_supported()
Definition: device_prop.hpp:85
Definition: stream_config.hpp:10
Definition: gridwise_moe_mx_gemm_bns.hpp:648
Definition: gridwise_moe_mx_gemm_bns.hpp:173
static constexpr __host__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_moe_mx_gemm_bns.hpp:1281
__host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
Definition: gridwise_moe_mx_gemm_bns.hpp:397
static constexpr __host__ TailNumber CalculateKBlockLoopTailNum(index_t K)
Definition: gridwise_moe_mx_gemm_bns.hpp:1288
__host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
Definition: gridwise_moe_mx_gemm_bns.hpp:315
static constexpr __host__ bool CheckValidity(const Argument &karg)
Definition: gridwise_moe_mx_gemm_bns.hpp:1103
__host__ static __device__ auto MakeDsGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, std::array< index_t, NumDTensor > StrideDs)
Definition: gridwise_moe_mx_gemm_bns.hpp:552
static __host__ auto CalculateGridSize(index_t M, index_t N)
Definition: gridwise_moe_mx_gemm_bns.hpp:234
Definition: functional2.hpp:33
Definition: device_base.hpp:51
Definition: device_base.hpp:62
Definition: device_gemm_multiple_d.hpp:167
Definition: device_moe_mx_gemm_bns.hpp:161
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_moe_mx_gemm_bns.hpp:162
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_moe_mx_gemm_bns.hpp:335
Definition: device_moe_mx_gemm_bns.hpp:92
static constexpr index_t BPackedSize
Definition: device_moe_mx_gemm_bns.hpp:155
std::string GetTypeString() const override
Definition: device_moe_mx_gemm_bns.hpp:493
static auto MakeInvoker()
Definition: device_moe_mx_gemm_bns.hpp:437
static constexpr index_t APackedSize
Definition: device_moe_mx_gemm_bns.hpp:154
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_a_scale, const void *p_b, const void *p_b_scale, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideScaleA, index_t StrideB, index_t StrideScaleB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition: device_moe_mx_gemm_bns.hpp:440
static auto MakeArgument(const void *p_sorted_token_ids, const void *p_sorted_expert_ids, const void *p_max_token_id, const void *p_a, const void *p_a_scale, const void *p_b, const void *p_b_scale, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t NumTokens, index_t TopK, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideScaleA, index_t StrideB, index_t StrideScaleB, std::array< index_t, NumDTensor > StrideDs, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition: device_moe_mx_gemm_bns.hpp:386
GridwiseMoeGemmMXBNS< ALayout, BLayout, DsLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ActivationOP, NSwizzle, IsInputGemm, MulRoutedWeight, IndexType, ComputeTypeA, ComputeTypeB > GridwiseGemm
Definition: device_moe_mx_gemm_bns.hpp:150
static bool IsSupportedArgument(const Argument &arg)
Definition: device_moe_mx_gemm_bns.hpp:348
int GetPreShuffleParameters() override
Definition: device_moe_mx_gemm_bns.hpp:157
typename GridwiseGemm::Argument Argument
Definition: device_moe_mx_gemm_bns.hpp:152
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_moe_mx_gemm_bns.hpp:381
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_moe_mx_gemm_bns.hpp:487
static constexpr index_t NumDTensor
Definition: device_moe_mx_gemm_bns.hpp:93
static constexpr bool IsValidCompilationParameter()
Definition: device_moe_mx_gemm_bns.hpp:342
Definition: flush_cache.hpp:20