Composable Kernel — source listing for
include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp

device_moe_gemm_blockscale.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 #include <hip/hip_runtime.h>
9 
20 
21 namespace ck {
22 namespace tensor_operation {
23 namespace device {
24 
25 template <typename ALayout,
26  typename BLayout,
27  typename DsLayout,
28  typename CLayout,
29  typename ADataType,
30  typename AScaleDataType,
31  typename BDataType,
32  typename BScaleDataType,
33  typename DsDataType,
34  typename CDataType,
35  typename GemmAccDataType,
36  typename CShuffleDataType,
37  typename AElementwiseOperation,
38  typename BElementwiseOperation,
39  typename CElementwiseOperation,
40  GemmSpecialization GemmSpec,
41  index_t BlockSize,
42  index_t ScaleBlockM,
43  index_t ScaleBlockN,
44  index_t ScaleBlockK,
45  index_t MPerBlock,
46  index_t NPerBlock,
47  index_t KPerBlock,
48  index_t AK1,
49  index_t BK1,
50  index_t MPerXDL,
51  index_t NPerXDL,
52  index_t MXdlPerWave,
53  index_t NXdlPerWave,
54  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
55  typename ABlockTransferThreadClusterArrangeOrder,
56  typename ABlockTransferSrcAccessOrder,
57  index_t ABlockTransferSrcVectorDim,
58  index_t ABlockTransferSrcScalarPerVector,
59  index_t ABlockTransferDstScalarPerVector_AK1,
60  bool ABlockLdsExtraM,
61  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
62  typename BBlockTransferThreadClusterArrangeOrder,
63  typename BBlockTransferSrcAccessOrder,
64  index_t BBlockTransferSrcVectorDim,
65  index_t BBlockTransferSrcScalarPerVector,
66  index_t BBlockTransferDstScalarPerVector_BK1,
67  bool BBlockLdsExtraN,
68  index_t CShuffleMXdlPerWavePerShuffle,
69  index_t CShuffleNXdlPerWavePerShuffle,
70  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
71  typename CDEShuffleBlockTransferScalarPerVectors,
74  index_t ActivationOP = 0,
75  bool NSwizzle = false,
76  bool IsInputGemm = true,
77  bool MulRoutedWeight = false,
78  typename IndexType = index_t,
79  typename ComputeTypeA = CDataType,
80  typename ComputeTypeB = ComputeTypeA,
81  typename LDSTypeA = ComputeTypeA,
82  typename LDSTypeB = ComputeTypeB>
85  BLayout,
86  DsLayout,
87  CLayout,
88  ADataType,
89  AScaleDataType,
90  BDataType,
91  BScaleDataType,
92  DsDataType,
93  CDataType,
94  ScaleBlockM,
95  ScaleBlockN,
96  ScaleBlockK,
97  AElementwiseOperation,
98  BElementwiseOperation,
99  CElementwiseOperation>
100 {
102  static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
103  static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
104  static constexpr index_t NumDTensor = DsDataType::Size();
105  template <index_t NXdlPerWave_>
107  ALayout,
108  BLayout,
109  DsLayout,
110  CLayout,
111  ADataType,
112  BDataType,
113  GemmAccDataType,
114  CShuffleDataType,
115  DsDataType,
116  CDataType,
117  AElementwiseOperation,
118  BElementwiseOperation,
119  CElementwiseOperation,
120  GemmSpec,
121  BlockSize,
122  ScaleBlockM,
123  ScaleBlockN,
124  ScaleBlockK,
125  MPerBlock,
126  NPerBlock,
127  KPerBlock,
128  AK1,
129  BK1,
130  MPerXDL,
131  NPerXDL,
132  MXdlPerWave,
133  NXdlPerWave_,
134  ABlockTransferThreadClusterLengths_AK0_M_AK1,
135  ABlockTransferThreadClusterArrangeOrder,
136  ABlockTransferSrcAccessOrder,
137  ABlockTransferSrcVectorDim,
138  ABlockTransferSrcScalarPerVector,
139  ABlockTransferDstScalarPerVector_AK1,
140  false,
141  ABlockLdsExtraM,
142  BBlockTransferThreadClusterLengths_BK0_N_BK1,
143  BBlockTransferThreadClusterArrangeOrder,
144  BBlockTransferSrcAccessOrder,
145  BBlockTransferSrcVectorDim,
146  BBlockTransferSrcScalarPerVector,
147  BBlockTransferDstScalarPerVector_BK1,
148  false,
149  BBlockLdsExtraN,
150  CShuffleMXdlPerWavePerShuffle,
151  math::min(CShuffleNXdlPerWavePerShuffle, NXdlPerWave_),
152  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
153  CDEShuffleBlockTransferScalarPerVectors,
154  BlkGemmPipeSched,
155  BlkGemmPipelineVer,
156  ActivationOP,
157  NSwizzle,
158  IsInputGemm,
159  MulRoutedWeight,
160  IndexType,
161  ComputeTypeA,
162  ComputeTypeB,
163  LDSTypeA,
164  LDSTypeB>;
167 
169 
170  static constexpr index_t APackedSize = []() {
172  return 2;
173  else
174  return 1;
175  }();
176 
177  static constexpr index_t BPackedSize = []() {
179  return 2;
180  else
181  return 1;
182  }();
183 
 // Reports the XDL tile size along N used by the pre-shuffled B-operand
 // layout; overrides the base-class hook with this instantiation's NPerXDL.
 int GetPreShuffleParameters() override { return NPerXDL; }
185 
186  // Invoker
187  struct Invoker : public BaseInvoker
188  {
189  template <typename GridwiseGemm>
190  float RunImp(const typename GridwiseGemm::Argument& arg,
191  const StreamConfig& stream_config = StreamConfig{})
192  {
193  if(stream_config.log_level_ > 0)
194  {
195  arg.Print();
196  }
197 
198  if(!GridwiseGemm::CheckValidity(arg))
199  {
200  throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
201  }
202 
203  index_t gdx, gdy, gdz;
204  std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
205 
206  float ave_time = 0;
207 
208  index_t k_grain = arg.KBatch * KPerBlock;
209  index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
210 
211  const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
212  const auto RunKernel = [&](const auto& kernel) {
213  if(stream_config.flush_cache)
214  {
215 
216  std::array<std::size_t, NumDTensor> DsSize;
217 
218  auto arg_ = arg;
219 
220  const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
221  arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
222  const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
223  arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
224 
225  auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
226  sizeof(ADataType) / APackedSize;
227  auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
228  sizeof(BDataType) / BPackedSize;
229 
230  const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
231  arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
232 
233  static_for<0, NumDTensor, 1>{}([&](auto i) {
234  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
235  DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
236  });
237  ck::utility::RotatingMemWrapperMultiD<typename GridwiseGemm::Argument,
238  DsDataType>
239  rotating_mem(arg_,
240  stream_config.rotating_count,
241  size_a_buffer,
242  size_b_buffer,
243  DsSize);
244  rotating_mem.Print();
245 
246  auto run_flush_cache = [&]() {
247  // flush icache
249  // rotating mem
250  rotating_mem.Next();
251  // clear c mem
252  if(arg_.KBatch > 1)
253  hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
254  0,
255  arg_.M * arg_.N * sizeof(CDataType),
256  stream_config.stream_id_));
257  };
258 
259  ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
260  stream_config,
261  run_flush_cache,
262  kernel,
263  dim3(gdx, gdy, gdz),
264  dim3(BlockSize),
265  0,
266  arg_);
267  }
268  else
269  {
270  if(arg.KBatch > 1)
271  hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
272  0,
273  arg.M * arg.N * sizeof(CDataType),
274  stream_config.stream_id_));
275 
276  ave_time = launch_and_time_kernel(
277  stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
278  }
279  };
280 
281  constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) / BlockSize /
282  4 * (1 + GridwiseGemm::NWave);
283  constexpr auto estimated_reg_b = NPerBlock * KPerBlock * sizeof(BDataType) / BlockSize /
284  4 * (2) * (IsInputGemm ? 2 : 1);
285  constexpr auto estimated_reg_c = MPerBlock * NPerBlock * sizeof(GemmAccDataType) /
286  BlockSize / 4 * (IsInputGemm ? 2 : 1);
287  constexpr auto estimated_reg_total =
288  estimated_reg_a + estimated_reg_b + estimated_reg_c;
289 
290  constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2;
291 
292  constexpr auto MemoryDataOp =
294 
295  if(has_main_k_block_loop)
296  {
297  // Tail number always full
298  if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
299  {
300  {
301  if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
302  {
303  const auto kernel = kernel_moe_gemm<GridwiseGemm,
304  true,
305  MemoryDataOp,
306  minimum_occupancy,
308  RunKernel(kernel);
309  }
310  else
311  {
312  const auto kernel = kernel_moe_gemm<GridwiseGemm,
313  true,
314  MemoryDataOp,
315  minimum_occupancy,
317  RunKernel(kernel);
318  }
319  }
320  }
321  else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
322  BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
323  {
324  if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
325  {
326  const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
327  true,
328  MemoryDataOp,
329  minimum_occupancy,
331  RunKernel(kernel);
332  }
333  else
334  {
335  const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
336  true,
337  MemoryDataOp,
338  minimum_occupancy,
340  RunKernel(kernel);
341  }
342  }
343  else
344  {
345  throw std::runtime_error("todo: only v1 & v2 support now");
346  }
347  }
348 #if 1
349  else
350  {
351  // Tail number always 1
352  if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
353  {
354  if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
355  {
356  const auto kernel = kernel_moe_gemm<GridwiseGemm,
357  false,
358  MemoryDataOp,
359  minimum_occupancy,
361  RunKernel(kernel);
362  }
363  else
364  {
365  const auto kernel = kernel_moe_gemm<GridwiseGemm,
366  false,
367  MemoryDataOp,
368  minimum_occupancy,
370  RunKernel(kernel);
371  }
372  }
373  else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
374  BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
375  {
376  if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
377  {
378  const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
379  false,
380  MemoryDataOp,
381  minimum_occupancy,
383  RunKernel(kernel);
384  }
385  else
386  {
387  const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
388  false,
389  MemoryDataOp,
390  minimum_occupancy,
392  RunKernel(kernel);
393  }
394  }
395  }
396 #endif
397 
398  return ave_time;
399  }
400 
402 
403  // polymorphic
404  float Run(const BaseArgument* p_arg,
405  const StreamConfig& stream_config = StreamConfig{}) override
406  {
407  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
408  }
409  };
410 
411  static constexpr bool IsValidCompilationParameter()
412  {
413  // TODO: properly implement this check
414  return true;
415  }
416 
417  static bool IsSupportedArgument(const Argument& arg)
418  {
419  // only impl kbatch 1 now
420  if(arg.KBatch > 1)
421  {
422  return false;
423  }
424  if(!ck::is_xdl_wmma_supported<ComputeTypeA, ComputeTypeB, MPerXDL, NPerXDL>())
425  {
426  return false;
427  }
428  if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
429  {
430  return false;
431  }
432 
433  if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
434  GemmSpec == GemmSpecialization::NKPadding ||
435  GemmSpec == GemmSpecialization::MNKPadding ||
436  GemmSpec == GemmSpecialization::KPadding))
437  {
438  return false;
439  }
440  if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
441  {
442  return false;
443  }
444 
445  if(get_warp_size() == 64)
446  {
447  if constexpr(NXdlPerWave64 > 0)
448  {
449  return GridwiseGemm64::CheckValidity(arg);
450  }
451  }
452  else
453  {
454  if constexpr(NXdlPerWave32 > 0)
455  {
457  reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg));
458  }
459  }
460  return false;
461  }
462 
463  // polymorphic
464  bool IsSupportedArgument(const BaseArgument* p_arg) override
465  {
466  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
467  }
468 
 // Builds a fully-populated kernel Argument from type-erased buffers.
 // p_sorted_token_ids / p_sorted_expert_ids / p_max_token_id are cast to
 // const index_t* — presumably MoE token-routing metadata (confirm against
 // the Argument definition). p_a/p_b/p_ds/p_c are the GEMM operands and
 // output; p_a_scale/p_b_scale are the block-scale tensors.
 // NOTE: the initializer order below must match the Argument constructor
 // parameter order exactly — braced init gives no named-field safety.
 static auto MakeArgument(const void* p_sorted_token_ids,
                          const void* p_sorted_expert_ids,
                          const void* p_max_token_id,
                          const void* p_a,
                          const void* p_b,
                          std::array<const void*, NumDTensor> p_ds,
                          void* p_c,
                          index_t NumTokens,
                          index_t TopK,
                          index_t M,
                          index_t N,
                          index_t K,
                          index_t StrideA,
                          index_t StrideB,
                          std::array<index_t, NumDTensor> StrideDs,
                          index_t StrideC,
                          const void* p_a_scale,
                          const void* p_b_scale,
                          index_t KBatch,
                          AElementwiseOperation a_element_op,
                          BElementwiseOperation b_element_op,
                          CElementwiseOperation c_element_op)
 {
     return Argument{static_cast<const index_t*>(p_sorted_token_ids),
                     static_cast<const index_t*>(p_sorted_expert_ids),
                     static_cast<const index_t*>(p_max_token_id),
                     static_cast<const ADataType*>(p_a),
                     static_cast<const BDataType*>(p_b),
                     p_ds,
                     static_cast<CDataType*>(p_c),
                     NumTokens,
                     TopK,
                     M,
                     N,
                     K,
                     StrideA,
                     StrideB,
                     StrideDs,
                     StrideC,
                     static_cast<const AScaleDataType*>(p_a_scale),
                     static_cast<const BScaleDataType*>(p_b_scale),
                     KBatch,
                     a_element_op,
                     b_element_op,
                     c_element_op};
 }
515 
 // Non-polymorphic factory for the concrete Invoker (stateless, so a
 // value-constructed instance is all that is needed).
 static auto MakeInvoker() { return Invoker{}; }
517 
 // polymorphic
 // Generic DeviceOp entry point: adapts the base-class GEMM signature (which
 // carries no MoE routing inputs) to the MoE Argument. Routing pointers are
 // set to nullptr and NumTokens/TopK to placeholder values (M and 0) —
 // per the original note these are unused on this path; KBatch is fixed to 1
 // (IsSupportedArgument rejects KBatch > 1 anyway).
 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                   const void* p_b,
                                                   std::array<const void*, NumDTensor> p_ds,
                                                   void* p_c,
                                                   index_t M,
                                                   index_t N,
                                                   index_t K,
                                                   index_t StrideA,
                                                   index_t StrideB,
                                                   std::array<ck::index_t, NumDTensor> StrideDs,
                                                   index_t StrideC,
                                                   const void* p_a_scale,
                                                   const void* p_b_scale,
                                                   // index_t KBatch,
                                                   AElementwiseOperation a_element_op,
                                                   BElementwiseOperation b_element_op,
                                                   CElementwiseOperation c_element_op) override
 {
     return std::make_unique<Argument>(nullptr, // p_sorted_token_ids: no routing data
                                       nullptr, // p_sorted_expert_ids
                                       nullptr, // p_max_token_id
                                       static_cast<const ADataType*>(p_a),
                                       static_cast<const BDataType*>(p_b),
                                       p_ds,
                                       static_cast<CDataType*>(p_c),
                                       M, // NumTokens: placeholder value, not used
                                       0, // TopK: placeholder value, not used
                                       M,
                                       N,
                                       K,
                                       StrideA,
                                       StrideB,
                                       StrideDs,
                                       StrideC,
                                       static_cast<const AScaleDataType*>(p_a_scale),
                                       static_cast<const BScaleDataType*>(p_b_scale),
                                       1, // KBatch fixed to 1 on this path
                                       a_element_op,
                                       b_element_op,
                                       c_element_op);
 }
560 
561  // polymorphic
562  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
563  {
564  return std::make_unique<Invoker>(Invoker{});
565  }
566 
567  // polymorphic
568  std::string GetTypeString() const override
569  {
570  auto str = std::stringstream();
571 
572  std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
575 
576  std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
580 
581  // clang-format off
582  str << "DeviceMoeGEmm"
583  << "<"
584  << getGemmSpecializationString(GemmSpec) << ", "
585  << std::string(ALayout::name)[0]
586  << std::string(BLayout::name)[0]
587  << std::string(CLayout::name)[0]
588  << ">"
589  << " BlkSize: "
590  << BlockSize << ", "
591  << "BlkTile: "
592  << MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
593  << "WaveTile: "
594  << MPerXDL<<"x"<<NPerXDL << ", "
595  << "WaveMap: "
596  << MXdlPerWave<<"x" << NXdlPerWave<<", "
597  << "VmemReadVec: "
598  << ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
599  << "BlkGemmPipelineScheduler: "
600  << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
601  << "BlkGemmPipelineVersion: "
602  << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
603  << "BlkGemmPipelinePrefetchStages: "
604  << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
605  // clang-format on
606 
607  return str.str();
608  }
609 };
610 
611 } // namespace device
612 } // namespace tensor_operation
613 } // namespace ck
#define INVOKER_RUN3_IMPL
Definition: device_base.hpp:114
#define GET_NXDL_PER_WAVE_IMPL
Definition: device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:14
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition: gemm_specialization.hpp:32
GemmSpecialization
Definition: gemm_specialization.hpp:11
void flush_icache()
Definition: flush_cache.hpp:361
Definition: ck.hpp:268
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
BlockGemmPipelineVersion
Definition: blkgemmpipe_scheduler.hpp:12
__global__ void kernel_moe_gemm(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_gemm.hpp:46
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
constexpr bool is_same_v
Definition: type.hpp:283
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
int32_t index_t
Definition: ck.hpp:299
__global__ void kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_gemm.hpp:84
bool is_bf16_atomic_supported()
Definition: device_prop.hpp:108
Definition: stream_config.hpp:10
Definition: gridwise_moe_gemm_blockscale.hpp:666
Definition: gridwise_moe_gemm_blockscale.hpp:177
static constexpr __host__ bool CheckValidity(const Argument &karg)
Definition: gridwise_moe_gemm_blockscale.hpp:968
Definition: data_type.hpp:187
Definition: functional2.hpp:33
Definition: device_base.hpp:197
Definition: device_base.hpp:208
Definition: device_gemm_multiple_d_ab_scale.hpp:80
Definition: device_moe_gemm_blockscale.hpp:188
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_moe_gemm_blockscale.hpp:190
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_moe_gemm_blockscale.hpp:404
Definition: device_moe_gemm_blockscale.hpp:100
static constexpr index_t BPackedSize
Definition: device_moe_gemm_blockscale.hpp:177
static constexpr auto NXdlPerWave32
Definition: device_moe_gemm_blockscale.hpp:103
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_moe_gemm_blockscale.hpp:562
static bool IsSupportedArgument(const Argument &arg)
Definition: device_moe_gemm_blockscale.hpp:417
static constexpr index_t APackedSize
Definition: device_moe_gemm_blockscale.hpp:170
int GetPreShuffleParameters() override
Definition: device_moe_gemm_blockscale.hpp:184
std::string GetTypeString() const override
Definition: device_moe_gemm_blockscale.hpp:568
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideC, const void *p_a_scale, const void *p_b_scale, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition: device_moe_gemm_blockscale.hpp:519
static constexpr GET_NXDL_PER_WAVE_IMPL auto NXdlPerWave64
Definition: device_moe_gemm_blockscale.hpp:102
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_moe_gemm_blockscale.hpp:464
typename GridwiseGemm64::Argument Argument
Definition: device_moe_gemm_blockscale.hpp:168
static auto MakeArgument(const void *p_sorted_token_ids, const void *p_sorted_expert_ids, const void *p_max_token_id, const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t NumTokens, index_t TopK, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideC, const void *p_a_scale, const void *p_b_scale, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition: device_moe_gemm_blockscale.hpp:469
static constexpr bool IsValidCompilationParameter()
Definition: device_moe_gemm_blockscale.hpp:411
static constexpr index_t NumDTensor
Definition: device_moe_gemm_blockscale.hpp:104
static auto MakeInvoker()
Definition: device_moe_gemm_blockscale.hpp:516
Definition: flush_cache.hpp:165