/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm.hpp Source File
device_moe_mx_gemm.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
19 
20 namespace ck {
21 namespace tensor_operation {
22 namespace device {
23 
24 template <typename ALayout,
25  typename BLayout,
26  typename DsLayout,
27  typename CLayout,
28  typename ADataType,
29  typename AScaleDataType,
30  typename BDataType,
31  typename BScaleDataType,
32  typename DsDataType,
33  typename CDataType,
34  typename GemmAccDataType,
35  typename CShuffleDataType,
36  typename AElementwiseOperation,
37  typename BElementwiseOperation,
38  typename CElementwiseOperation,
39  GemmSpecialization GemmSpec,
40  index_t ScaleBlockSize,
41  index_t BlockSize,
42  index_t MPerBlock,
43  index_t NPerBlock,
44  index_t KPerBlock,
45  index_t AK1,
46  index_t BK1,
47  index_t MPerXDL,
48  index_t NPerXDL,
49  index_t MXdlPerWave,
50  index_t NXdlPerWave,
51  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
52  typename ABlockTransferThreadClusterArrangeOrder,
53  typename ABlockTransferSrcAccessOrder,
54  index_t ABlockTransferSrcVectorDim,
55  index_t ABlockTransferSrcScalarPerVector,
56  index_t ABlockTransferDstScalarPerVector_AK1,
57  bool ABlockLdsExtraM,
58  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
59  typename BBlockTransferThreadClusterArrangeOrder,
60  typename BBlockTransferSrcAccessOrder,
61  index_t BBlockTransferSrcVectorDim,
62  index_t BBlockTransferSrcScalarPerVector,
63  index_t BBlockTransferDstScalarPerVector_BK1,
64  bool BBlockLdsExtraN,
65  index_t CShuffleMXdlPerWavePerShuffle,
66  index_t CShuffleNXdlPerWavePerShuffle,
67  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
68  typename CDEShuffleBlockTransferScalarPerVectors,
71  index_t ActivationOP = 0,
72  bool NSwizzle = false,
73  bool IsInputGemm = true,
74  bool MulRoutedWeight = true,
75  typename IndexType = index_t,
76  typename ComputeTypeA = ADataType,
77  typename ComputeTypeB = BDataType>
79  BLayout,
80  DsLayout,
81  CLayout,
82  ADataType,
83  AScaleDataType,
84  BDataType,
85  BScaleDataType,
86  DsDataType,
87  CDataType,
88  ScaleBlockSize,
89  AElementwiseOperation,
90  BElementwiseOperation,
91  CElementwiseOperation>
92 {
// Number of auxiliary D tensors, taken from the size of the DsDataType tuple.
93  static constexpr index_t NumDTensor = DsDataType::Size();
// Gridwise implementation this device operation forwards to.
// NOTE: the argument order here is NOT the device op's template-parameter order:
// GemmAccDataType/CShuffleDataType come before DsDataType/CDataType, and a literal
// `false` is inserted after ABlockTransferDstScalarPerVector_AK1 and after
// BBlockTransferDstScalarPerVector_BK1. The meaning of those slots is defined by
// GridwiseMoeGemmMX (not visible here) - check its parameter list before editing.
94  using GridwiseGemm =
95  GridwiseMoeGemmMX<ALayout,
96  BLayout,
97  DsLayout,
98  CLayout,
99  ADataType,
100  AScaleDataType,
101  BDataType,
102  BScaleDataType,
103  GemmAccDataType,
104  CShuffleDataType,
105  DsDataType,
106  CDataType,
107  AElementwiseOperation,
108  BElementwiseOperation,
109  CElementwiseOperation,
110  GemmSpec,
111  ScaleBlockSize,
112  BlockSize,
113  MPerBlock,
114  NPerBlock,
115  KPerBlock,
116  AK1,
117  BK1,
118  MPerXDL,
119  NPerXDL,
120  MXdlPerWave,
121  NXdlPerWave,
122  ABlockTransferThreadClusterLengths_AK0_M_AK1,
123  ABlockTransferThreadClusterArrangeOrder,
124  ABlockTransferSrcAccessOrder,
125  ABlockTransferSrcVectorDim,
126  ABlockTransferSrcScalarPerVector,
127  ABlockTransferDstScalarPerVector_AK1,
128  false,
129  ABlockLdsExtraM,
130  BBlockTransferThreadClusterLengths_BK0_N_BK1,
131  BBlockTransferThreadClusterArrangeOrder,
132  BBlockTransferSrcAccessOrder,
133  BBlockTransferSrcVectorDim,
134  BBlockTransferSrcScalarPerVector,
135  BBlockTransferDstScalarPerVector_BK1,
136  false,
137  BBlockLdsExtraN,
138  CShuffleMXdlPerWavePerShuffle,
139  CShuffleNXdlPerWavePerShuffle,
140  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
141  CDEShuffleBlockTransferScalarPerVectors,
142  BlkGemmPipeSched,
143  BlkGemmPipelineVer,
144  ActivationOP,
145  NSwizzle,
146  IsInputGemm,
147  MulRoutedWeight,
148  IndexType,
149  ComputeTypeA,
150  ComputeTypeB>;
151 
153 
154  static constexpr index_t APackedSize = packed_size_v<ADataType>;
155  static constexpr index_t BPackedSize = packed_size_v<BDataType>;
156 
157  int GetPreShuffleParameters() override { return NPerXDL; }
158 
159  // Invoker
160  struct Invoker : public BaseInvoker
161  {
// Launches the MoE MX GEMM kernel for `arg` and returns ave_time, the value
// produced by the launch/timing helper (it stays 0 if no kernel is launched).
// Throws std::runtime_error when the validity check fails (its `if` line is
// elided from this listing; presumably GridwiseGemm::CheckValidity(arg) -
// verify) and when a main K-block loop is needed but the pipeline version is
// neither v1 nor v3.
162  float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
163  {
164  if(stream_config.log_level_ > 0)
165  {
166  arg.Print();
167  }
168 
170  {
171  throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
172  }
173 
// Grid size depends only on M and N.
174  index_t gdx, gdy, gdz;
175  std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
176 
177  float ave_time = 0;
178 
// Per-K-batch K extent: ceil(K / (KBatch * KPerBlock)) * KPerBlock.
179  index_t k_grain = arg.KBatch * KPerBlock;
180  index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
181 
182  const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
183 
// Launches one kernel instantiation and records its timing in ave_time.
// With stream_config.flush_cache set, a copy of `arg` is launched through
// launch_and_time_kernel_with_preprocess, using rotating buffers sized from
// the A, B and D grid descriptors. Otherwise `arg` is launched directly
// with launch_and_time_kernel.
184  const auto RunKernel = [&](const auto& kernel) {
185  if(stream_config.flush_cache)
186  {
187 
188  std::array<std::size_t, NumDTensor> DsSize;
189 
190  Argument arg_ = arg;
191 
// Byte sizes of the A, B and each D buffer, handed to the rotating memory.
192  const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
193  arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
194  const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
195  arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
196 
197  auto size_a_buffer =
198  a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
199  auto size_b_buffer =
200  b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);
201 
202  const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
203  arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
204 
205  static_for<0, NumDTensor, 1>{}([&](auto i) {
206  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
207  DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
208  });
// The declaration of rotating_mem (line 209) is elided from this listing;
// the next line holds the tail of its constructor arguments.
210  arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize);
211  rotating_mem.Print();
212 
// Preprocess hook passed to launch_and_time_kernel_with_preprocess (presumably
// invoked before each timed launch). Per its inline comments it flushes the
// icache (the call on line 215 is elided), advances the rotating buffers, and
// zeroes C when KBatch > 1.
// NOTE(review): if M and N are index_t (as passed to MakeArgument), M * N is
// computed in index_t before scaling by sizeof(CDataType) and can overflow for
// large problems. The string returned by hipGetErrorString is discarded, so a
// failed memset goes unreported. IsSupportedArgument rejects KBatch > 1, so this
// only runs for arguments that bypassed that check.
213  auto run_flush_cache = [&]() {
214  // flush icache
216  // rotating mem
217  rotating_mem.Next();
218  // clear c mem
219  if(arg_.KBatch > 1)
220  hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
221  0,
222  arg_.M * arg_.N * sizeof(CDataType),
223  stream_config.stream_id_));
224  };
225 
226  ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
227  stream_config,
228  run_flush_cache,
229  kernel,
230  dim3(gdx, gdy, gdz),
231  dim3(BlockSize),
232  0,
233  arg_);
234  }
235  else
236  {
// NOTE(review): same index_t-overflow and discarded-error concerns as the
// memset inside run_flush_cache above.
237  if(arg.KBatch > 1)
238  hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
239  0,
240  arg.M * arg.N * sizeof(CDataType),
241  stream_config.stream_id_));
242 
243  ave_time = launch_and_time_kernel(
244  stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
245  }
246  };
247 
// Occupancy value passed as a kernel template argument. With the Intrawave
// scheduler it is 2 for pipeline v3 when
// MPerBlock*NPerBlock*KPerBlock*sizeof(ADataType) <= 128*128*64*2, and 1
// otherwise. With any other scheduler it is 2.
248  // TODO: Check if this is the right algorithm for minimum_occupancy
249  constexpr index_t minimum_occupancy =
250  BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave
251  ? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
252  MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2)
253  ? 2
254  : 1
255  : 2;
256 
// Memory operation template argument for the kernel; its initializer
// (line 258) is elided from this listing.
257  constexpr auto MemoryDataOp =
259 
// Kernel selection. The second template argument of kernel_moe_mxgemm_2lds
// mirrors has_main_k_block_loop, which is decided at run time; the pipeline
// version is decided at compile time. Each v3 branch chooses between two
// instantiations whose condition and trailing template arguments are elided
// from this listing.
260  if(has_main_k_block_loop)
261  {
262  // Tail number always full
263  if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
264  {
265  const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
266  true,
267  MemoryDataOp,
268  minimum_occupancy,
270  RunKernel(kernel);
271  }
272  else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
273  {
275  {
276  const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
277  true,
278  MemoryDataOp,
279  minimum_occupancy,
281  RunKernel(kernel);
282  }
283  else
284  {
285  const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
286  true,
287  MemoryDataOp,
288  minimum_occupancy,
290  RunKernel(kernel);
291  }
292  }
293  else
294  {
295  throw std::runtime_error("todo: only v1 & v3 support now");
296  }
297  }
298  else
299  {
300  // Tail number always full
301  if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
302  {
303  const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
304  false,
305  MemoryDataOp,
306  minimum_occupancy,
308  RunKernel(kernel);
309  }
310  else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
311  {
313  {
314  const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
315  false,
316  MemoryDataOp,
317  minimum_occupancy,
319  RunKernel(kernel);
320  }
321  else
322  {
323  const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
324  false,
325  MemoryDataOp,
326  minimum_occupancy,
328  RunKernel(kernel);
329  }
330  }
// NOTE(review): unlike the main-loop branch above, there is no final `else`
// here, so a pipeline version other than v1/v3 launches nothing and Run
// returns 0.
331  }
332 
333  return ave_time;
334  }
335 
336  // polymorphic
337  float Run(const BaseArgument* p_arg,
338  const StreamConfig& stream_config = StreamConfig{}) override
339  {
340  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
341  }
342  };
343 
// Compile-time hook for validating the template configuration.
// TODO: properly implement this check - at present every configuration is accepted.
static constexpr bool IsValidCompilationParameter() { return true; }
349 
350  static bool IsSupportedArgument(const Argument& arg)
351  {
352  // only impl kbatch 1 now
353  if(arg.KBatch > 1)
354  {
355  return false;
356  }
357  if(!ck::is_xdl_supported())
358  {
359  return false;
360  }
361 
362  if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
363  {
364  return false;
365  }
366 
367  if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
368  GemmSpec == GemmSpecialization::NKPadding ||
369  GemmSpec == GemmSpecialization::MNKPadding ||
370  GemmSpec == GemmSpecialization::KPadding))
371  {
372  return false;
373  }
374  if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
375  {
376  return false;
377  }
378 
379  return GridwiseGemm::CheckValidity(arg);
380  }
381 
382  // polymorphic
383  bool IsSupportedArgument(const BaseArgument* p_arg) override
384  {
385  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
386  }
387 
388  static auto MakeArgument(const void* p_sorted_token_ids,
389  const void* p_sorted_expert_ids,
390  const void* p_max_token_id,
391  const void* p_a,
392  const void* p_a_scale,
393  const void* p_b,
394  const void* p_b_scale,
395  std::array<const void*, NumDTensor> p_ds,
396  void* p_c,
397  index_t NumTokens,
398  index_t TopK,
399  index_t M,
400  index_t N,
401  index_t K,
402  index_t StrideA,
403  index_t StrideScaleA,
404  index_t StrideB,
405  index_t StrideScaleB,
406  std::array<index_t, NumDTensor> StrideDs,
407  index_t StrideC,
408  index_t KBatch,
409  AElementwiseOperation a_element_op,
410  BElementwiseOperation b_element_op,
411  CElementwiseOperation c_element_op)
412  {
413  return Argument{static_cast<const index_t*>(p_sorted_token_ids),
414  static_cast<const index_t*>(p_sorted_expert_ids),
415  static_cast<const index_t*>(p_max_token_id),
416  static_cast<const ADataType*>(p_a),
417  static_cast<const AScaleDataType*>(p_a_scale),
418  static_cast<const BDataType*>(p_b),
419  static_cast<const BScaleDataType*>(p_b_scale),
420  p_ds,
421  static_cast<CDataType*>(p_c),
422  NumTokens,
423  TopK,
424  M,
425  N,
426  K,
427  StrideA,
428  StrideScaleA,
429  StrideB,
430  StrideScaleB,
431  StrideDs,
432  StrideC,
433  KBatch,
434  a_element_op,
435  b_element_op,
436  c_element_op};
437  }
438 
439  static auto MakeInvoker() { return Invoker{}; }
440 
441  // polymorphic
442  std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
443  const void* p_a_scale,
444  const void* p_b,
445  const void* p_b_scale,
446  std::array<const void*, NumDTensor> p_ds,
447  void* p_c,
448  index_t M,
449  index_t N,
450  index_t K,
451  index_t StrideA,
452  index_t StrideScaleA,
453  index_t StrideB,
454  index_t StrideScaleB,
455  std::array<ck::index_t, NumDTensor> StrideDs,
456  index_t StrideC,
457  index_t KBatch,
458  AElementwiseOperation a_element_op,
459  BElementwiseOperation b_element_op,
460  CElementwiseOperation c_element_op) override
461  {
462  return std::make_unique<Argument>(nullptr,
463  nullptr,
464  nullptr,
465  static_cast<const ADataType*>(p_a),
466  static_cast<const AScaleDataType*>(p_a_scale),
467  static_cast<const BDataType*>(p_b),
468  static_cast<const BScaleDataType*>(p_b_scale),
469  p_ds,
470  static_cast<CDataType*>(p_c),
471  M, // randoms set, no use
472  0,
473  M,
474  N,
475  K,
476  StrideA,
477  StrideScaleA,
478  StrideB,
479  StrideScaleB,
480  StrideDs,
481  StrideC,
482  KBatch,
483  a_element_op,
484  b_element_op,
485  c_element_op);
486  }
487 
488  // polymorphic
489  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
490  {
491  return std::make_unique<Invoker>(Invoker{});
492  }
493 
494  // polymorphic
// Returns a human-readable description of this instance. It contains:
//  - the GemmSpec name;
//  - the first character of the A/B/C layout names;
//  - block size and block tile (MPerBlock x NPerBlock x KPerBlock);
//  - XDL tile (MPerXDL x NPerXDL) and MXdlPerWave x NXdlPerWave;
//  - the A/B source vector widths;
//  - the block GEMM pipeline scheduler and version;
//  - GridwiseGemm::BlockwiseGemmPipe::PrefetchStages.
495  std::string GetTypeString() const override
496  {
497  auto str = std::stringstream();
498 
// Enum-to-name tables; their entries are elided from this listing.
// NOTE(review): std::map::operator[] (used below) inserts and returns an empty
// string for any enum value without an entry, so such values print as "".
499  std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
502 
503  std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
509 
// NOTE(review): the prefix is spelled "DeviceMoeGEmmMx". Check whether any
// tooling matches on this exact string before correcting the capitalization.
510  // clang-format off
511  str << "DeviceMoeGEmmMx"
512  << "<"
513  << getGemmSpecializationString(GemmSpec) << ", "
514  << std::string(ALayout::name)[0]
515  << std::string(BLayout::name)[0]
516  << std::string(CLayout::name)[0]
517  << ">"
518  << " BlkSize: "
519  << BlockSize << ", "
520  << "BlkTile: "
521  << MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
522  << "WaveTile: "
523  << MPerXDL<<"x"<<NPerXDL << ", "
524  << "WaveMap: "
525  << MXdlPerWave<<"x" << NXdlPerWave<<", "
526  << "VmemReadVec: "
527  << ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
528  << "BlkGemmPipelineScheduler: "
529  << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
530  << "BlkGemmPipelineVersion: "
531  << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
532  << "BlkGemmPipelinePrefetchStages: "
533  << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
534  // clang-format on
535 
536  return str.str();
537  }
538 };
539 
540 } // namespace device
541 } // namespace tensor_operation
542 } // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:14
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition: gemm_specialization.hpp:32
GemmSpecialization
Definition: gemm_specialization.hpp:11
void flush_icache()
Definition: flush_cache.hpp:216
Definition: ck.hpp:267
bool is_xdl_supported()
Definition: device_prop.hpp:68
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
BlockGemmPipelineVersion
Definition: blkgemmpipe_scheduler.hpp:12
__global__ void kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_mx_gemm.hpp:87
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
int32_t index_t
Definition: ck.hpp:298
bool is_bf16_atomic_supported()
Definition: device_prop.hpp:85
Definition: stream_config.hpp:10
Definition: gridwise_moe_mx_gemm.hpp:715
Definition: gridwise_moe_mx_gemm.hpp:173
static __host__ auto CalculateGridSize(index_t M, index_t N)
Definition: gridwise_moe_mx_gemm.hpp:242
static constexpr __host__ bool CheckValidity(const Argument &karg)
Definition: gridwise_moe_mx_gemm.hpp:1171
__host__ static __device__ auto MakeDsGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, std::array< index_t, NumDTensor > StrideDs)
Definition: gridwise_moe_mx_gemm.hpp:619
__host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
Definition: gridwise_moe_mx_gemm.hpp:331
static constexpr __host__ TailNumber CalculateKBlockLoopTailNum(index_t K)
Definition: gridwise_moe_mx_gemm.hpp:1356
__host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
Definition: gridwise_moe_mx_gemm.hpp:446
static constexpr __host__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_moe_mx_gemm.hpp:1349
Definition: functional2.hpp:33
Definition: device_base.hpp:51
Definition: device_base.hpp:62
Definition: device_gemm_multiple_d.hpp:167
Definition: device_moe_mx_gemm.hpp:161
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_moe_mx_gemm.hpp:162
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_moe_mx_gemm.hpp:337
Definition: device_moe_mx_gemm.hpp:92
typename GridwiseGemm::Argument Argument
Definition: device_moe_mx_gemm.hpp:152
std::string GetTypeString() const override
Definition: device_moe_mx_gemm.hpp:495
static constexpr bool IsValidCompilationParameter()
Definition: device_moe_mx_gemm.hpp:344
static constexpr index_t APackedSize
Definition: device_moe_mx_gemm.hpp:154
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_moe_mx_gemm.hpp:489
static bool IsSupportedArgument(const Argument &arg)
Definition: device_moe_mx_gemm.hpp:350
static constexpr index_t NumDTensor
Definition: device_moe_mx_gemm.hpp:93
static auto MakeArgument(const void *p_sorted_token_ids, const void *p_sorted_expert_ids, const void *p_max_token_id, const void *p_a, const void *p_a_scale, const void *p_b, const void *p_b_scale, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t NumTokens, index_t TopK, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideScaleA, index_t StrideB, index_t StrideScaleB, std::array< index_t, NumDTensor > StrideDs, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition: device_moe_mx_gemm.hpp:388
GridwiseMoeGemmMX< ALayout, BLayout, DsLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ActivationOP, NSwizzle, IsInputGemm, MulRoutedWeight, IndexType, ComputeTypeA, ComputeTypeB > GridwiseGemm
Definition: device_moe_mx_gemm.hpp:150
int GetPreShuffleParameters() override
Definition: device_moe_mx_gemm.hpp:157
static constexpr index_t BPackedSize
Definition: device_moe_mx_gemm.hpp:155
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_a_scale, const void *p_b, const void *p_b_scale, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideScaleA, index_t StrideB, index_t StrideScaleB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition: device_moe_mx_gemm.hpp:442
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_moe_mx_gemm.hpp:383
static auto MakeInvoker()
Definition: device_moe_mx_gemm.hpp:439
Definition: flush_cache.hpp:20