Composable Kernel: device_moe_gemm.hpp Source File
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

namespace ck {
namespace tensor_operation {
namespace device {

template <typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename CLayout,
          typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename CDataType,
          typename GemmAccDataType,
          typename CShuffleDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          GemmSpecialization GemmSpec,
          index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t AK1,
          index_t BK1,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MXdlPerWave,
          index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          typename CDEShuffleBlockTransferScalarPerVectors,
          BlockGemmPipelineScheduler BlkGemmPipeSched,
          BlockGemmPipelineVersion BlkGemmPipelineVer,
          index_t ActivationOP = 0,
          bool NSwizzle = false,
          bool IsInputGemm = true,
          bool MulRoutedWeight = true,
          bool PerTokenQuant = true,
          typename IndexType = index_t,
          typename ComputeTypeA = CDataType,
          typename ComputeTypeB = ComputeTypeA,
          typename LDSTypeA = ComputeTypeA,
          typename LDSTypeB = ComputeTypeB>
// NOTE: the declaration line of this struct is not rendered in this listing; the base-class
// name "DeviceMoeGemmBase" below is an assumed placeholder.
struct DeviceMoeGemm : public DeviceMoeGemmBase<ALayout,
                                                BLayout,
                                                DsLayout,
                                                CLayout,
                                                ADataType,
                                                BDataType,
                                                DsDataType,
                                                CDataType,
                                                AElementwiseOperation,
                                                BElementwiseOperation,
                                                CElementwiseOperation>
{
    static constexpr index_t NumDTensor = DsDataType::Size();
    using GridwiseGemm =
        GridwiseMoeGemm<ALayout,
                        BLayout,
                        DsLayout,
                        CLayout,
                        ADataType,
                        BDataType,
                        GemmAccDataType,
                        CShuffleDataType,
                        DsDataType,
                        CDataType,
                        AElementwiseOperation,
                        BElementwiseOperation,
                        CElementwiseOperation,
                        GemmSpec,
                        BlockSize,
                        MPerBlock,
                        NPerBlock,
                        KPerBlock,
                        AK1,
                        BK1,
                        MPerXDL,
                        NPerXDL,
                        MXdlPerWave,
                        NXdlPerWave,
                        ABlockTransferThreadClusterLengths_AK0_M_AK1,
                        ABlockTransferThreadClusterArrangeOrder,
                        ABlockTransferSrcAccessOrder,
                        ABlockTransferSrcVectorDim,
                        ABlockTransferSrcScalarPerVector,
                        ABlockTransferDstScalarPerVector_AK1,
                        false,
                        ABlockLdsExtraM,
                        BBlockTransferThreadClusterLengths_BK0_N_BK1,
                        BBlockTransferThreadClusterArrangeOrder,
                        BBlockTransferSrcAccessOrder,
                        BBlockTransferSrcVectorDim,
                        BBlockTransferSrcScalarPerVector,
                        BBlockTransferDstScalarPerVector_BK1,
                        false,
                        BBlockLdsExtraN,
                        CShuffleMXdlPerWavePerShuffle,
                        CShuffleNXdlPerWavePerShuffle,
                        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                        CDEShuffleBlockTransferScalarPerVectors,
                        BlkGemmPipeSched,
                        BlkGemmPipelineVer,
                        ActivationOP,
                        NSwizzle,
                        IsInputGemm,
                        MulRoutedWeight,
                        PerTokenQuant,
                        IndexType,
                        ComputeTypeA,
                        ComputeTypeB,
                        LDSTypeA,
                        LDSTypeB>;

    using Argument = typename GridwiseGemm::Argument;

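    // APackedSize/BPackedSize: logical elements stored per addressed byte for the A/B operands
    // (2 for packed 4-bit types, 1 otherwise); used further down to convert element-space
    // sizes into byte counts when sizing the rotating benchmark buffers.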
    static constexpr index_t APackedSize = []() {
        if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t>)
            return 2;
        else
            return 1;
    }();

    static constexpr index_t BPackedSize = []() {
        if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
            return 2;
        else
            return 1;
    }();

    int GetPreShuffleParameters() override { return NPerXDL; }

    // Invoker
    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(stream_config.log_level_ > 0)
            {
                arg.Print();
            }

            if(!GridwiseGemm::CheckValidity(arg))
            {
                throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
            }

            index_t gdx, gdy, gdz;
            std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);

            float ave_time = 0;

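            // K handled by each K-batch split, rounded up to a whole number of KPerBlock
            // tiles; this decides whether the kernel needs a main K block loop.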
            index_t k_grain = arg.KBatch * KPerBlock;
            index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;

            const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);

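            // RunKernel launches (and times) the selected kernel. With flush_cache enabled it
            // benchmarks against rotating copies of the A/B/Ds buffers and flushes the
            // instruction cache between repetitions; otherwise it launches the kernel directly.
            // In both paths C is cleared first when KBatch > 1 so accumulation starts from zero.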
            const auto RunKernel = [&](const auto& kernel) {
                if(stream_config.flush_cache)
                {
                    std::array<std::size_t, NumDTensor> DsSize;

                    Argument arg_ = arg;

                    const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
                        arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
                    const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
                        arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);

                    auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
                                         sizeof(ADataType) / APackedSize;
                    auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
                                         sizeof(BDataType) / BPackedSize;

                    const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
                        arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);

                    static_for<0, NumDTensor, 1>{}([&](auto i) {
                        using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
                        DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
                    });

                    // NOTE: the rotating-memory wrapper type below is assumed; its declaration
                    // line is not rendered in this listing.
                    ck::utility::RotatingMemWrapperMultiD<Argument, DsDataType> rotating_mem(
                        arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize);
                    rotating_mem.Print();

                    auto run_flush_cache = [&]() {
                        // flush icache
                        ck::utility::flush_icache();
                        // rotating mem
                        rotating_mem.Next();
                        // clear c mem
                        if(arg_.KBatch > 1)
                            hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
                                                             0,
                                                             arg_.M * arg_.N * sizeof(CDataType),
                                                             stream_config.stream_id_));
                    };

                    ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
                        stream_config,
                        run_flush_cache,
                        kernel,
                        dim3(gdx, gdy, gdz),
                        dim3(BlockSize),
                        0,
                        arg_);
                }
                else
                {
                    if(arg.KBatch > 1)
                        hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
                                                         0,
                                                         arg.M * arg.N * sizeof(CDataType),
                                                         stream_config.stream_id_));

                    ave_time = launch_and_time_kernel(
                        stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
                }
            };

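            // Rough per-thread register estimate (in 4-byte VGPR units) for the A/B prefetch
            // buffers and the C accumulators; if the total reaches 256 registers the kernels
            // below are built with a minimum-occupancy hint of 1, otherwise 2.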
            constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) /
                                             BlockSize / 4 * (1 + GridwiseGemm::NWave);
            constexpr auto estimated_reg_b = NPerBlock * KPerBlock * sizeof(BDataType) /
                                             BlockSize / 4 * 2 * (IsInputGemm ? 2 : 1);
            constexpr auto estimated_reg_c = MPerBlock * NPerBlock * sizeof(GemmAccDataType) /
                                             BlockSize / 4 * (IsInputGemm ? 2 : 1);
            constexpr auto estimated_reg_total =
                estimated_reg_a + estimated_reg_b + estimated_reg_c;

            constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2;

            // NOTE: the initializer of MemoryDataOp and the tail-number dispatch conditions in
            // this branch (and in the no-main-loop branch below) are not rendered in this
            // listing; the values used here are an assumed reconstruction.
            constexpr auto MemoryDataOp =
                IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;

            if(has_main_k_block_loop)
            {
                // Tail number always full
                if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
                {
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                        {
                            const auto kernel = kernel_moe_gemm<GridwiseGemm,
                                                                true,
                                                                MemoryDataOp,
                                                                minimum_occupancy,
                                                                TailNumber::Odd>;
                            RunKernel(kernel);
                        }
                        else
                        {
                            const auto kernel = kernel_moe_gemm<GridwiseGemm,
                                                                true,
                                                                MemoryDataOp,
                                                                minimum_occupancy,
                                                                TailNumber::Even>;
                            RunKernel(kernel);
                        }
                    }
                }
                else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
                                  BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
                {
                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                    {
                        const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
                                                                 true,
                                                                 MemoryDataOp,
                                                                 minimum_occupancy,
                                                                 TailNumber::Odd>;
                        RunKernel(kernel);
                    }
                    else
                    {
                        const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
                                                                 true,
                                                                 MemoryDataOp,
                                                                 minimum_occupancy,
                                                                 TailNumber::Even>;
                        RunKernel(kernel);
                    }
                }
                else
                {
                    throw std::runtime_error("todo: only v1 & v2 support now");
                }
            }
#if 1
            else
            {
                // Tail number always 1
                if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
                {
                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                    {
                        const auto kernel = kernel_moe_gemm<GridwiseGemm,
                                                            false,
                                                            MemoryDataOp,
                                                            minimum_occupancy,
                                                            TailNumber::Odd>;
                        RunKernel(kernel);
                    }
                    else
                    {
                        const auto kernel = kernel_moe_gemm<GridwiseGemm,
                                                            false,
                                                            MemoryDataOp,
                                                            minimum_occupancy,
                                                            TailNumber::Even>;
                        RunKernel(kernel);
                    }
                }
                else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
                                  BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
                {
                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                    {
                        const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
                                                                 false,
                                                                 MemoryDataOp,
                                                                 minimum_occupancy,
                                                                 TailNumber::Odd>;
                        RunKernel(kernel);
                    }
                    else
                    {
                        const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
                                                                 false,
                                                                 MemoryDataOp,
                                                                 minimum_occupancy,
                                                                 TailNumber::Even>;
                        RunKernel(kernel);
                    }
                }
                else
                {
                    throw std::runtime_error("todo: only v1 & v2 support now");
                }
            }
#endif

            return ave_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        // only KBatch == 1 is implemented for now
        if(arg.KBatch > 1)
        {
            return false;
        }
        if(!ck::is_xdl_supported())
        {
            return false;
        }

        if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> &&
           arg.KBatch > 1)
        {
            return false;
        }

        if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
                                                       GemmSpec == GemmSpecialization::NKPadding ||
                                                       GemmSpec == GemmSpecialization::MNKPadding ||
                                                       GemmSpec == GemmSpecialization::KPadding))
        {
            return false;
        }
        if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
        {
            return false;
        }

        return GridwiseGemm::CheckValidity(arg);
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const void* p_sorted_token_ids,
                             const void* p_sorted_expert_ids,
                             const void* p_max_token_id,
                             const void* p_a,
                             const void* p_b,
                             std::array<const void*, NumDTensor> p_ds,
                             void* p_c,
                             index_t NumTokens,
                             index_t TopK,
                             index_t M,
                             index_t N,
                             index_t K,
                             index_t StrideA,
                             index_t StrideB,
                             std::array<index_t, NumDTensor> StrideDs,
                             index_t StrideC,
                             index_t KBatch,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CElementwiseOperation c_element_op)
    {
        return Argument{static_cast<const index_t*>(p_sorted_token_ids),
                        static_cast<const index_t*>(p_sorted_expert_ids),
                        static_cast<const index_t*>(p_max_token_id),
                        static_cast<const ADataType*>(p_a),
                        static_cast<const BDataType*>(p_b),
                        p_ds,
                        static_cast<CDataType*>(p_c),
                        NumTokens,
                        TopK,
                        M,
                        N,
                        K,
                        StrideA,
                        StrideB,
                        StrideDs,
                        StrideC,
                        KBatch,
                        a_element_op,
                        b_element_op,
                        c_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
                                                      std::array<const void*, NumDTensor> p_ds,
                                                      void* p_c,
                                                      index_t M,
                                                      index_t N,
                                                      index_t K,
                                                      index_t StrideA,
                                                      index_t StrideB,
                                                      std::array<ck::index_t, NumDTensor> StrideDs,
                                                      index_t StrideC,
                                                      index_t KBatch,
                                                      AElementwiseOperation a_element_op,
                                                      BElementwiseOperation b_element_op,
                                                      CElementwiseOperation c_element_op) override
    {
        return std::make_unique<Argument>(nullptr,
                                          nullptr,
                                          nullptr,
                                          static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
                                          p_ds,
                                          static_cast<CDataType*>(p_c),
                                          M, // NumTokens/TopK placeholders; unused by this entry point
                                          0,
                                          M,
                                          N,
                                          K,
                                          StrideA,
                                          StrideB,
                                          StrideDs,
                                          StrideC,
                                          KBatch,
                                          a_element_op,
                                          b_element_op,
                                          c_element_op);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
            {BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
            {BlockGemmPipelineScheduler::Interwave, "Interwave"}};

        std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
            {BlockGemmPipelineVersion::v1, "v1"},
            {BlockGemmPipelineVersion::v2, "v2"},
            {BlockGemmPipelineVersion::v3, "v3"},
            {BlockGemmPipelineVersion::v4, "v4"},
            {BlockGemmPipelineVersion::v5, "v5"}};

        // clang-format off
        str << "DeviceMoeGemm"
            << "<"
            << getGemmSpecializationString(GemmSpec) << ", "
            << std::string(ALayout::name)[0]
            << std::string(BLayout::name)[0]
            << std::string(CLayout::name)[0]
            << ">"
            << " BlkSize: "
            << BlockSize << ", "
            << "BlkTile: "
            << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", "
            << "WaveTile: "
            << MPerXDL << "x" << NPerXDL << ", "
            << "WaveMap: "
            << MXdlPerWave << "x" << NXdlPerWave << ", "
            << "VmemReadVec: "
            << ABlockTransferSrcScalarPerVector << "x" << BBlockTransferSrcScalarPerVector << ", "
            << "BlkGemmPipelineScheduler: "
            << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
            << "BlkGemmPipelineVersion: "
            << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
            << "BlkGemmPipelinePrefetchStages: "
            << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
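
For orientation, the snippet below is a minimal host-side usage sketch and is not part of the header above. It assumes a concrete instantiation of DeviceMoeGemm (referred to as DeviceOp) whose element-wise operations are PassThrough and whose Ds tuple is empty; the helper name run_moe_gemm_once and all pointer/size parameters are illustrative placeholders supplied by the caller.

// Hypothetical helper; DeviceOp must be a concrete DeviceMoeGemm instantiation with
// PassThrough element-wise operations and an empty Ds tuple.
#include "ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

template <typename DeviceOp>
float run_moe_gemm_once(const void* p_sorted_token_ids,
                        const void* p_sorted_expert_ids,
                        const void* p_max_token_id,
                        const void* p_a,
                        const void* p_b,
                        void* p_c,
                        ck::index_t NumTokens,
                        ck::index_t TopK,
                        ck::index_t M,
                        ck::index_t N,
                        ck::index_t K,
                        ck::index_t StrideA,
                        ck::index_t StrideB,
                        ck::index_t StrideC)
{
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    DeviceOp device_op;

    // Empty braces stand in for the (empty) D-tensor pointer and stride arrays.
    auto argument = device_op.MakeArgument(p_sorted_token_ids,
                                           p_sorted_expert_ids,
                                           p_max_token_id,
                                           p_a,
                                           p_b,
                                           {},
                                           p_c,
                                           NumTokens,
                                           TopK,
                                           M,
                                           N,
                                           K,
                                           StrideA,
                                           StrideB,
                                           {},
                                           StrideC,
                                           1, // KBatch: only 1 is accepted by IsSupportedArgument
                                           PassThrough{},
                                           PassThrough{},
                                           PassThrough{});

    if(!DeviceOp::IsSupportedArgument(argument))
    {
        return -1.f; // this tile configuration does not fit the problem
    }

    auto invoker = DeviceOp::MakeInvoker();
    // time_kernel = true: Run returns the averaged kernel time in milliseconds.
    return invoker.Run(argument, StreamConfig{nullptr, true});
}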