/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp Source File
device_moe_gemm_blockscale.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 #include <hip/hip_runtime.h>
9 
20 
21 namespace ck {
22 namespace tensor_operation {
23 namespace device {
24 
// NOTE(review): Doxygen-extracted listing — several original lines are missing
// from this span: the remaining #includes ("10"-"19"), the declaration line(s)
// ("83"-"84", the `struct DeviceMoeGemmBlockScale : public ...<` header that
// the base-class argument list at "85" belongs to), and the
// `using GridwiseGemm = GridwiseMoeGemmBlockScale<` line ("102") that the long
// argument list starting at "103" belongs to (the complete alias is visible in
// the index entries at the bottom of this page). Code left byte-identical.
25 template <typename ALayout,
26  typename BLayout,
27  typename DsLayout,
28  typename CLayout,
29  typename ADataType,
30  typename AScaleDataType,
31  typename BDataType,
32  typename BScaleDataType,
33  typename DsDataType,
34  typename CDataType,
35  typename GemmAccDataType,
36  typename CShuffleDataType,
37  typename AElementwiseOperation,
38  typename BElementwiseOperation,
39  typename CElementwiseOperation,
40  GemmSpecialization GemmSpec,
41  index_t BlockSize,
42  index_t ScaleBlockM,
43  index_t ScaleBlockN,
44  index_t ScaleBlockK,
45  index_t MPerBlock,
46  index_t NPerBlock,
47  index_t KPerBlock,
48  index_t AK1,
49  index_t BK1,
50  index_t MPerXDL,
51  index_t NPerXDL,
52  index_t MXdlPerWave,
53  index_t NXdlPerWave,
54  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
55  typename ABlockTransferThreadClusterArrangeOrder,
56  typename ABlockTransferSrcAccessOrder,
57  index_t ABlockTransferSrcVectorDim,
58  index_t ABlockTransferSrcScalarPerVector,
59  index_t ABlockTransferDstScalarPerVector_AK1,
60  bool ABlockLdsExtraM,
61  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
62  typename BBlockTransferThreadClusterArrangeOrder,
63  typename BBlockTransferSrcAccessOrder,
64  index_t BBlockTransferSrcVectorDim,
65  index_t BBlockTransferSrcScalarPerVector,
66  index_t BBlockTransferDstScalarPerVector_BK1,
67  bool BBlockLdsExtraN,
68  index_t CShuffleMXdlPerWavePerShuffle,
69  index_t CShuffleNXdlPerWavePerShuffle,
70  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
71  typename CDEShuffleBlockTransferScalarPerVectors,
// NOTE(review): template parameters "72"-"73" (presumably BlkGemmPipeSched and
// BlkGemmPipelineVer, which are referenced below) are missing here.
74  index_t ActivationOP = 0,
75  bool NSwizzle = false,
76  bool IsInputGemm = true,
77  bool MulRoutedWeight = false,
78  typename IndexType = index_t,
79  typename ComputeTypeA = CDataType,
80  typename ComputeTypeB = ComputeTypeA,
81  typename LDSTypeA = ComputeTypeA,
82  typename LDSTypeB = ComputeTypeB>
85  BLayout,
86  DsLayout,
87  CLayout,
88  ADataType,
89  AScaleDataType,
90  BDataType,
91  BScaleDataType,
92  DsDataType,
93  CDataType,
94  ScaleBlockM,
95  ScaleBlockN,
96  ScaleBlockK,
97  AElementwiseOperation,
98  BElementwiseOperation,
99  CElementwiseOperation>
100 {
// Number of extra D tensors fused into the epilogue.
101  static constexpr index_t NumDTensor = DsDataType::Size();
// NOTE(review): the alias head `using GridwiseGemm = GridwiseMoeGemmBlockScale<`
// ("102") is missing; the argument list below belongs to it.
103  ALayout,
104  BLayout,
105  DsLayout,
106  CLayout,
107  ADataType,
108  BDataType,
109  GemmAccDataType,
110  CShuffleDataType,
111  DsDataType,
112  CDataType,
113  AElementwiseOperation,
114  BElementwiseOperation,
115  CElementwiseOperation,
116  GemmSpec,
117  BlockSize,
118  ScaleBlockM,
119  ScaleBlockN,
120  ScaleBlockK,
121  MPerBlock,
122  NPerBlock,
123  KPerBlock,
124  AK1,
125  BK1,
126  MPerXDL,
127  NPerXDL,
128  MXdlPerWave,
129  NXdlPerWave,
130  ABlockTransferThreadClusterLengths_AK0_M_AK1,
131  ABlockTransferThreadClusterArrangeOrder,
132  ABlockTransferSrcAccessOrder,
133  ABlockTransferSrcVectorDim,
134  ABlockTransferSrcScalarPerVector,
135  ABlockTransferDstScalarPerVector_AK1,
136  false,
137  ABlockLdsExtraM,
138  BBlockTransferThreadClusterLengths_BK0_N_BK1,
139  BBlockTransferThreadClusterArrangeOrder,
140  BBlockTransferSrcAccessOrder,
141  BBlockTransferSrcVectorDim,
142  BBlockTransferSrcScalarPerVector,
143  BBlockTransferDstScalarPerVector_BK1,
144  false,
145  BBlockLdsExtraN,
146  CShuffleMXdlPerWavePerShuffle,
147  CShuffleNXdlPerWavePerShuffle,
148  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
149  CDEShuffleBlockTransferScalarPerVectors,
150  BlkGemmPipeSched,
151  BlkGemmPipelineVer,
152  ActivationOP,
153  NSwizzle,
154  IsInputGemm,
155  MulRoutedWeight,
156  IndexType,
157  ComputeTypeA,
158  ComputeTypeB,
159  LDSTypeA,
160  LDSTypeB>;
161 
163 
// Packing factor for A: elements per storage unit (2 when packed, 1 otherwise).
// NOTE(review): the `if constexpr` condition line ("165") is missing from this
// extraction — presumably a check for a packed sub-byte type (e.g. pk_i4_t);
// confirm against the repository.
164  static constexpr index_t APackedSize = []() {
166  return 2;
167  else
168  return 1;
169  }();
170 
// Packing factor for B; the condition line ("172") is likewise missing here.
171  static constexpr index_t BPackedSize = []() {
173  return 2;
174  else
175  return 1;
176  }();
177 
178  int GetPreShuffleParameters() override { return NPerXDL; }
179 
180  // Invoker
// Invoker: configures the launch grid and launches/times the grid-wise MoE
// GEMM kernel. NOTE(review): this Doxygen-extracted listing is missing several
// source lines (marked inline below); code is left byte-identical.
181  struct Invoker : public BaseInvoker
182  {
// Runs the kernel described by `arg`; returns the measured average kernel
// time from the launch helpers (0 if nothing was timed).
183  float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
184  {
185  if(stream_config.log_level_ > 0)
186  {
187  arg.Print();
188  }
189 
// NOTE(review): the guard condition ("190") is missing here — presumably
// `if(!GridwiseGemm::CheckValidity(arg))`; confirm against the repository.
191  {
192  throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
193  }
194 
195  index_t gdx, gdy, gdz;
196  std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
197 
198  float ave_time = 0;
199 
// K rounded up to a multiple of KBatch * KPerBlock decides whether the main
// K-block loop is required.
200  index_t k_grain = arg.KBatch * KPerBlock;
201  index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
202 
203  const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
// Shared launch helper: in flush_cache mode it sizes the A/B/D buffers,
// rotates memory and clears C between timed iterations; otherwise it launches
// the kernel directly.
204  const auto RunKernel = [&](const auto& kernel) {
205  if(stream_config.flush_cache)
206  {
207 
208  std::array<std::size_t, NumDTensor> DsSize;
209 
210  Argument arg_ = arg;
211 
212  const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
213  arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
214  const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
215  arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
216 
217  auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
218  sizeof(ADataType) / APackedSize;
219  auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
220  sizeof(BDataType) / BPackedSize;
221 
222  const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
223  arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
224 
225  static_for<0, NumDTensor, 1>{}([&](auto i) {
226  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
227  DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
228  });
// NOTE(review): the declaration of `rotating_mem` ("229") is missing from this
// extraction (presumably a RotatingMemWrapper from flush_cache.hpp — confirm).
230  arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize);
231  rotating_mem.Print();
232 
233  auto run_flush_cache = [&]() {
234  // flush icache
// NOTE(review): the flush call ("235", presumably ck::utility::flush_icache())
// is missing from this extraction.
236  // rotating mem
237  rotating_mem.Next();
238  // clear c mem
// C is zeroed before timed runs when KBatch > 1 — presumably because split-K
// accumulates into C; confirm against the kernel implementation.
239  if(arg_.KBatch > 1)
240  hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
241  0,
242  arg_.M * arg_.N * sizeof(CDataType),
243  stream_config.stream_id_));
244  };
245 
246  ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
247  stream_config,
248  run_flush_cache,
249  kernel,
250  dim3(gdx, gdy, gdz),
251  dim3(BlockSize),
252  0,
253  arg_);
254  }
255  else
256  {
257  if(arg.KBatch > 1)
258  hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
259  0,
260  arg.M * arg.N * sizeof(CDataType),
261  stream_config.stream_id_));
262 
263  ave_time = launch_and_time_kernel(
264  stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
265  }
266  };
267 
// Heuristic register-pressure estimate; its only use is selecting the
// minimum_occupancy launch bound below.
268  constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) / BlockSize /
269  4 * (1 + GridwiseGemm::NWave);
270  constexpr auto estimated_reg_b = NPerBlock * KPerBlock * sizeof(BDataType) / BlockSize /
271  4 * (2) * (IsInputGemm ? 2 : 1);
272  constexpr auto estimated_reg_c = MPerBlock * NPerBlock * sizeof(GemmAccDataType) /
273  BlockSize / 4 * (IsInputGemm ? 2 : 1);
274  constexpr auto estimated_reg_total =
275  estimated_reg_a + estimated_reg_b + estimated_reg_c;
276 
277  constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2;
278 
// NOTE(review): the initializer ("280") is missing — presumably an
// InMemoryDataOperationEnum selected from KBatch (Set vs AtomicAdd); confirm.
279  constexpr auto MemoryDataOp =
281 
282  if(has_main_k_block_loop)
283  {
284  // Tail number always full
285  if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
286  {
// NOTE(review): the tail-number branch conditions ("288", "311", "341", "363",
// ...) and the final TailNumber template argument of each kernel instantiation
// ("294", "303", "317", ...) are missing from this extraction.
287  {
289  {
290  const auto kernel = kernel_moe_gemm<GridwiseGemm,
291  true,
292  MemoryDataOp,
293  minimum_occupancy,
295  RunKernel(kernel);
296  }
297  else
298  {
299  const auto kernel = kernel_moe_gemm<GridwiseGemm,
300  true,
301  MemoryDataOp,
302  minimum_occupancy,
304  RunKernel(kernel);
305  }
306  }
307  }
308  else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
309  BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
310  {
312  {
313  const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
314  true,
315  MemoryDataOp,
316  minimum_occupancy,
318  RunKernel(kernel);
319  }
320  else
321  {
322  const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
323  true,
324  MemoryDataOp,
325  minimum_occupancy,
327  RunKernel(kernel);
328  }
329  }
330  else
331  {
332  throw std::runtime_error("todo: only v1 & v2 support now");
333  }
334  }
335 #if 1
336  else
337  {
338  // Tail number always 1
339  if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
340  {
342  {
343  const auto kernel = kernel_moe_gemm<GridwiseGemm,
344  false,
345  MemoryDataOp,
346  minimum_occupancy,
348  RunKernel(kernel);
349  }
350  else
351  {
352  const auto kernel = kernel_moe_gemm<GridwiseGemm,
353  false,
354  MemoryDataOp,
355  minimum_occupancy,
357  RunKernel(kernel);
358  }
359  }
360  else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
361  BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
362  {
364  {
365  const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
366  false,
367  MemoryDataOp,
368  minimum_occupancy,
370  RunKernel(kernel);
371  }
372  else
373  {
374  const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
375  false,
376  MemoryDataOp,
377  minimum_occupancy,
379  RunKernel(kernel);
380  }
381  }
382  }
383 #endif
384 
385  return ave_time;
386  }
387 
388  // polymorphic
// Type-erased overload required by BaseInvoker; downcasts and forwards.
// NOTE(review): the dynamic_cast result is dereferenced without a null check —
// undefined behavior if p_arg is not an Argument; worth fixing upstream.
389  float Run(const BaseArgument* p_arg,
390  const StreamConfig& stream_config = StreamConfig{}) override
391  {
392  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
393  }
394  };
395 
/// Compile-time sanity check of this instance's template configuration.
/// NOTE: still a stub — every configuration is currently accepted.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    return true;
}
401 
402  static bool IsSupportedArgument(const Argument& arg)
403  {
404  // only impl kbatch 1 now
405  if(arg.KBatch > 1)
406  {
407  return false;
408  }
409  if(!ck::is_xdl_supported())
410  {
411  return false;
412  }
413 
414  if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
415  {
416  return false;
417  }
418 
419  if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
420  GemmSpec == GemmSpecialization::NKPadding ||
421  GemmSpec == GemmSpecialization::MNKPadding ||
422  GemmSpec == GemmSpecialization::KPadding))
423  {
424  return false;
425  }
426  if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
427  {
428  return false;
429  }
430 
431  return GridwiseGemm::CheckValidity(arg);
432  }
433 
434  // polymorphic
435  bool IsSupportedArgument(const BaseArgument* p_arg) override
436  {
437  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
438  }
439 
440  static auto MakeArgument(const void* p_sorted_token_ids,
441  const void* p_sorted_expert_ids,
442  const void* p_max_token_id,
443  const void* p_a,
444  const void* p_b,
445  std::array<const void*, NumDTensor> p_ds,
446  void* p_c,
447  index_t NumTokens,
448  index_t TopK,
449  index_t M,
450  index_t N,
451  index_t K,
452  index_t StrideA,
453  index_t StrideB,
454  std::array<index_t, NumDTensor> StrideDs,
455  index_t StrideC,
456  const void* p_a_scale,
457  const void* p_b_scale,
458  index_t KBatch,
459  AElementwiseOperation a_element_op,
460  BElementwiseOperation b_element_op,
461  CElementwiseOperation c_element_op)
462  {
463  return Argument{static_cast<const index_t*>(p_sorted_token_ids),
464  static_cast<const index_t*>(p_sorted_expert_ids),
465  static_cast<const index_t*>(p_max_token_id),
466  static_cast<const ADataType*>(p_a),
467  static_cast<const BDataType*>(p_b),
468  p_ds,
469  static_cast<CDataType*>(p_c),
470  NumTokens,
471  TopK,
472  M,
473  N,
474  K,
475  StrideA,
476  StrideB,
477  StrideDs,
478  StrideC,
479  static_cast<const AScaleDataType*>(p_a_scale),
480  static_cast<const BScaleDataType*>(p_b_scale),
481  KBatch,
482  a_element_op,
483  b_element_op,
484  c_element_op};
485  }
486 
487  static auto MakeInvoker() { return Invoker{}; }
488 
489  // polymorphic
490  std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
491  const void* p_b,
492  std::array<const void*, NumDTensor> p_ds,
493  void* p_c,
494  index_t M,
495  index_t N,
496  index_t K,
497  index_t StrideA,
498  index_t StrideB,
499  std::array<ck::index_t, NumDTensor> StrideDs,
500  index_t StrideC,
501  const void* p_a_scale,
502  const void* p_b_scale,
503  // index_t KBatch,
504  AElementwiseOperation a_element_op,
505  BElementwiseOperation b_element_op,
506  CElementwiseOperation c_element_op) override
507  {
508  return std::make_unique<Argument>(nullptr,
509  nullptr,
510  nullptr,
511  static_cast<const ADataType*>(p_a),
512  static_cast<const BDataType*>(p_b),
513  p_ds,
514  static_cast<CDataType*>(p_c),
515  M, // randoms set, no use
516  0,
517  M,
518  N,
519  K,
520  StrideA,
521  StrideB,
522  StrideDs,
523  StrideC,
524  static_cast<const AScaleDataType*>(p_a_scale),
525  static_cast<const BScaleDataType*>(p_b_scale),
526  1, // KBatch,
527  a_element_op,
528  b_element_op,
529  c_element_op);
530  }
531 
532  // polymorphic
533  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
534  {
535  return std::make_unique<Invoker>(Invoker{});
536  }
537 
538  // polymorphic
// Builds a human-readable description of this instance's tile configuration.
539  std::string GetTypeString() const override
540  {
541  auto str = std::stringstream();
542 
// NOTE(review): the initializer lists of both maps ("544"-"545" and
// "548"-"550") are missing from this extraction; each presumably maps the
// scheduler / pipeline-version enumerators to display names — confirm.
543  std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
546 
547  std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
551 
552  // clang-format off
// NOTE(review): "DeviceMoeGEmm" looks like a capitalization typo for
// "DeviceMoeGemm" — runtime string, intentionally left unchanged here.
553  str << "DeviceMoeGEmm"
554  << "<"
555  << getGemmSpecializationString(GemmSpec) << ", "
556  << std::string(ALayout::name)[0]
557  << std::string(BLayout::name)[0]
558  << std::string(CLayout::name)[0]
559  << ">"
560  << " BlkSize: "
561  << BlockSize << ", "
562  << "BlkTile: "
563  << MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
564  << "WaveTile: "
565  << MPerXDL<<"x"<<NPerXDL << ", "
566  << "WaveMap: "
567  << MXdlPerWave<<"x" << NXdlPerWave<<", "
568  << "VmemReadVec: "
569  << ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
570  << "BlkGemmPipelineScheduler: "
571  << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
572  << "BlkGemmPipelineVersion: "
573  << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
574  << "BlkGemmPipelinePrefetchStages: "
575  << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
576  // clang-format on
577 
578  return str.str();
579  }
580 };
581 
582 } // namespace device
583 } // namespace tensor_operation
584 } // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:14
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition: gemm_specialization.hpp:32
GemmSpecialization
Definition: gemm_specialization.hpp:11
void flush_icache()
Definition: flush_cache.hpp:216
Definition: ck.hpp:267
bool is_xdl_supported()
Definition: device_prop.hpp:68
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
BlockGemmPipelineVersion
Definition: blkgemmpipe_scheduler.hpp:12
__global__ void kernel_moe_gemm(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_gemm.hpp:46
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
constexpr bool is_same_v
Definition: type.hpp:283
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
int32_t index_t
Definition: ck.hpp:298
__global__ void kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_gemm.hpp:81
bool is_bf16_atomic_supported()
Definition: device_prop.hpp:85
Definition: stream_config.hpp:10
Definition: gridwise_moe_gemm_blockscale.hpp:660
Definition: gridwise_moe_gemm_blockscale.hpp:171
__host__ static constexpr __device__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_moe_gemm_blockscale.hpp:1133
static constexpr __host__ bool CheckValidity(const Argument &karg)
Definition: gridwise_moe_gemm_blockscale.hpp:960
__host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
Definition: gridwise_moe_gemm_blockscale.hpp:329
__host__ static __device__ auto MakeDsGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, std::array< index_t, NumDTensor > StrideDs)
Definition: gridwise_moe_gemm_blockscale.hpp:569
__host__ static constexpr __device__ TailNumber CalculateKBlockLoopTailNum(index_t K)
Definition: gridwise_moe_gemm_blockscale.hpp:1140
static __host__ auto CalculateGridSize(index_t M, index_t N)
Definition: gridwise_moe_gemm_blockscale.hpp:246
static constexpr index_t NWave
Definition: gridwise_moe_gemm_blockscale.hpp:213
__host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
Definition: gridwise_moe_gemm_blockscale.hpp:421
Definition: data_type.hpp:186
Definition: functional2.hpp:33
Definition: device_base.hpp:51
Definition: device_base.hpp:62
Definition: device_gemm_multiple_d_ab_scale.hpp:80
Definition: device_moe_gemm_blockscale.hpp:182
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_moe_gemm_blockscale.hpp:389
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_moe_gemm_blockscale.hpp:183
Definition: device_moe_gemm_blockscale.hpp:100
static constexpr index_t BPackedSize
Definition: device_moe_gemm_blockscale.hpp:171
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_moe_gemm_blockscale.hpp:533
static bool IsSupportedArgument(const Argument &arg)
Definition: device_moe_gemm_blockscale.hpp:402
typename GridwiseGemm::Argument Argument
Definition: device_moe_gemm_blockscale.hpp:162
static constexpr index_t APackedSize
Definition: device_moe_gemm_blockscale.hpp:164
int GetPreShuffleParameters() override
Definition: device_moe_gemm_blockscale.hpp:178
std::string GetTypeString() const override
Definition: device_moe_gemm_blockscale.hpp:539
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideC, const void *p_a_scale, const void *p_b_scale, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition: device_moe_gemm_blockscale.hpp:490
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_moe_gemm_blockscale.hpp:435
GridwiseMoeGemmBlockScale< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockM, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ActivationOP, NSwizzle, IsInputGemm, MulRoutedWeight, IndexType, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB > GridwiseGemm
Definition: device_moe_gemm_blockscale.hpp:160
static auto MakeArgument(const void *p_sorted_token_ids, const void *p_sorted_expert_ids, const void *p_max_token_id, const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t NumTokens, index_t TopK, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideC, const void *p_a_scale, const void *p_b_scale, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition: device_moe_gemm_blockscale.hpp:440
static constexpr bool IsValidCompilationParameter()
Definition: device_moe_gemm_blockscale.hpp:396
static constexpr index_t NumDTensor
Definition: device_moe_gemm_blockscale.hpp:101
static auto MakeInvoker()
Definition: device_moe_gemm_blockscale.hpp:487
Definition: flush_cache.hpp:20