device_moe_gemm.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <array>
#include <iostream>
#include <map>
#include <memory>
#include <sstream>
#include <tuple>

// (project-internal CK #include directives elided in this listing)
namespace ck {
namespace tensor_operation {
namespace device {

template <typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename CLayout,
          typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename CDataType,
          typename GemmAccDataType,
          typename CShuffleDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          GemmSpecialization GemmSpec,
          index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t AK1,
          index_t BK1,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MXdlPerWave,
          index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          typename CDEShuffleBlockTransferScalarPerVectors,
          // (the two pipeline parameters below were elided in the listing;
          // their names and types are recovered from their uses further
          // down, the default values are assumed)
          BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
          BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
          index_t ActivationOP = 0,
          bool NSwizzle = false,
          bool IsInputGemm = true,
          bool MulRoutedWeight = true,
          bool PerTokenQuant = true,
          typename IndexType = index_t,
          typename ComputeTypeA = CDataType,
          typename ComputeTypeB = ComputeTypeA,
          typename LDSTypeA = ComputeTypeA,
          typename LDSTypeB = ComputeTypeB>
// (the struct declaration was elided in the listing; the base-class name is
// assumed from the overridden interface, the template argument list below is
// preserved from the original)
struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
                                                                   BLayout,
                                                                   DsLayout,
                                                                   CLayout,
                                                                   ADataType,
                                                                   BDataType,
                                                                   DsDataType,
                                                                   CDataType,
                                                                   AElementwiseOperation,
                                                                   BElementwiseOperation,
                                                                   CElementwiseOperation>
{
    GET_NXDL_PER_WAVE_IMPL
    static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
    static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
    static constexpr index_t NumDTensor = DsDataType::Size();
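
    // Two gridwise-GEMM specializations are kept side by side: NXdlPerWave64
    // holds an NXdlPerWave value valid for wave64 devices and NXdlPerWave32
    // one for wave32 devices. IsSupportedArgument() below selects between
    // GridwiseGemm64 and GridwiseGemm32 based on get_warp_size() at run time.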
    template <index_t NXdlPerWave_>
    // (the alias name GridwiseGemmBase is assumed; the identifier on this
    // line was elided in the listing)
    using GridwiseGemmBase =
        GridwiseMoeGemm<ALayout,
                        BLayout,
                        DsLayout,
                        CLayout,
                        ADataType,
                        BDataType,
                        GemmAccDataType,
                        CShuffleDataType,
                        DsDataType,
                        CDataType,
                        AElementwiseOperation,
                        BElementwiseOperation,
                        CElementwiseOperation,
                        GemmSpec,
                        BlockSize,
                        MPerBlock,
                        NPerBlock,
                        KPerBlock,
                        AK1,
                        BK1,
                        MPerXDL,
                        NPerXDL,
                        MXdlPerWave,
                        NXdlPerWave_,
                        ABlockTransferThreadClusterLengths_AK0_M_AK1,
                        ABlockTransferThreadClusterArrangeOrder,
                        ABlockTransferSrcAccessOrder,
                        ABlockTransferSrcVectorDim,
                        ABlockTransferSrcScalarPerVector,
                        ABlockTransferDstScalarPerVector_AK1,
                        false,
                        ABlockLdsExtraM,
                        BBlockTransferThreadClusterLengths_BK0_N_BK1,
                        BBlockTransferThreadClusterArrangeOrder,
                        BBlockTransferSrcAccessOrder,
                        BBlockTransferSrcVectorDim,
                        BBlockTransferSrcScalarPerVector,
                        BBlockTransferDstScalarPerVector_BK1,
                        false,
                        BBlockLdsExtraN,
                        CShuffleMXdlPerWavePerShuffle,
                        math::min(CShuffleNXdlPerWavePerShuffle, NXdlPerWave_),
                        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                        CDEShuffleBlockTransferScalarPerVectors,
                        BlkGemmPipeSched,
                        BlkGemmPipelineVer,
                        ActivationOP,
                        NSwizzle,
                        IsInputGemm,
                        MulRoutedWeight,
                        PerTokenQuant,
                        IndexType,
                        ComputeTypeA,
                        ComputeTypeB,
                        LDSTypeA,
                        LDSTypeB>;
    // (the two concrete specializations were elided in the listing; their
    // names are recovered from the uses below)
    using GridwiseGemm64 = GridwiseGemmBase<NXdlPerWave64>;
    using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;

    using Argument = typename GridwiseGemm64::Argument;

    static constexpr index_t APackedSize = []() {
        // (the condition was elided in the listing; the pk_i4_t check is
        // assumed from the usual CK packed-int4 idiom)
        if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t>)
            return 2;
        else
            return 1;
    }();

    static constexpr index_t BPackedSize = []() {
        if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
            return 2;
        else
            return 1;
    }();
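
    // APackedSize/BPackedSize express how many logical elements share one
    // storage element (two for packed-int4 data), so the raw buffer sizes
    // computed below divide the element-space size by the packing factor.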

    int GetPreShuffleParameters() override { return NPerXDL; }

    // Invoker
    struct Invoker : public BaseInvoker
    {
        template <typename GridwiseGemm>
        float RunImp(const typename GridwiseGemm::Argument& arg,
                     const StreamConfig& stream_config = StreamConfig{})
        {
            if(stream_config.log_level_ > 0)
            {
                arg.Print();
            }

            if(!GridwiseGemm::CheckValidity(arg))
            {
                throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
            }

            index_t gdx, gdy, gdz;
            std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);

            float ave_time = 0;

            index_t k_grain = arg.KBatch * KPerBlock;
            index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
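            // K_split is the K extent each split-K slice covers, rounded up
            // to a whole number of KPerBlock tiles. E.g. K = 1000, KBatch = 2,
            // KPerBlock = 64: k_grain = 128, so
            // K_split = ceil(1000 / 128) * 64 = 8 * 64 = 512.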

            const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);

            const auto RunKernel = [&](const auto& kernel) {
                if(stream_config.flush_cache)
                {
                    std::array<std::size_t, NumDTensor> DsSize;

                    auto arg_ = arg;

                    const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
                        arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
                    const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
                        arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);

                    auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
                                         sizeof(ADataType) / APackedSize;
                    auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
                                         sizeof(BDataType) / BPackedSize;

                    const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
                        arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);

                    static_for<0, NumDTensor, 1>{}([&](auto i) {
                        using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
                        DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
                    });
                    ck::utility::RotatingMemWrapperMultiD<typename GridwiseGemm::Argument,
                                                          DsDataType>
                        rotating_mem(arg_,
                                     stream_config.rotating_count,
                                     size_a_buffer,
                                     size_b_buffer,
                                     DsSize);
                    rotating_mem.Print();

                    auto run_flush_cache = [&]() {
                        // flush icache
                        // (the call on this line was elided in the listing;
                        // flush_icache() is the utility referenced here)
                        ck::utility::flush_icache();
                        // rotating mem
                        rotating_mem.Next();
                        // clear c mem
                        if(arg_.KBatch > 1)
                            hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
                                                             0,
                                                             arg_.M * arg_.N * sizeof(CDataType),
                                                             stream_config.stream_id_));
                    };
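
                    // When flush_cache is requested, each timed iteration
                    // first runs the preprocess lambda above: the instruction
                    // cache is flushed, the rotating-memory wrapper swaps the
                    // A/B/D buffers so no iteration hits warm data caches,
                    // and (for split-K) C is zeroed because the kernel
                    // accumulates into it.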
                    ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
                        stream_config,
                        run_flush_cache,
                        kernel,
                        dim3(gdx, gdy, gdz),
                        dim3(BlockSize),
                        0,
                        arg_);
                }
                else
                {
                    if(arg.KBatch > 1)
                        hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
                                                         0,
                                                         arg.M * arg.N * sizeof(CDataType),
                                                         stream_config.stream_id_));

                    ave_time = launch_and_time_kernel(
                        stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
                }
            };

            constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) /
                                             BlockSize / 4 * (1 + GridwiseGemm::NWave);
            constexpr auto estimated_reg_b = NPerBlock * KPerBlock * sizeof(BDataType) /
                                             BlockSize / 4 * (2) * (IsInputGemm ? 2 : 1);
            constexpr auto estimated_reg_c = MPerBlock * NPerBlock * sizeof(GemmAccDataType) /
                                             BlockSize / 4 * (IsInputGemm ? 2 : 1);
            constexpr auto estimated_reg_total =
                estimated_reg_a + estimated_reg_b + estimated_reg_c;

            constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2;
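            // Rough per-thread VGPR estimate: bytes of tile data held per
            // thread divided by 4 (one VGPR holds 4 bytes), with factors for
            // the number of in-flight copies. If a thread is expected to need
            // 256 or more registers, only one wave can be resident per SIMD,
            // so the kernels below are compiled for a minimum occupancy of 1;
            // otherwise 2.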

            // (the right-hand side of MemoryDataOp was elided in the listing;
            // the Set/AtomicAdd selection below is assumed)
            constexpr auto MemoryDataOp = IsInputGemm ? InMemoryDataOperationEnum::Set
                                                      : InMemoryDataOperationEnum::AtomicAdd;

            // (the final TailNumber template argument of each kernel
            // instantiation below was elided in the listing and is recovered
            // from the enclosing branch condition)
            if(has_main_k_block_loop)
            {
                // Tail number always full
                if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
                {
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                        {
                            const auto kernel = kernel_moe_gemm<GridwiseGemm,
                                                                true,
                                                                MemoryDataOp,
                                                                minimum_occupancy,
                                                                TailNumber::Odd>;
                            RunKernel(kernel);
                        }
                        else
                        {
                            const auto kernel = kernel_moe_gemm<GridwiseGemm,
                                                                true,
                                                                MemoryDataOp,
                                                                minimum_occupancy,
                                                                TailNumber::Even>;
                            RunKernel(kernel);
                        }
                    }
                }
                else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
                                  BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
                {
                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                    {
                        const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
                                                                 true,
                                                                 MemoryDataOp,
                                                                 minimum_occupancy,
                                                                 TailNumber::Odd>;
                        RunKernel(kernel);
                    }
                    else
                    {
                        const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
                                                                 true,
                                                                 MemoryDataOp,
                                                                 minimum_occupancy,
                                                                 TailNumber::Even>;
                        RunKernel(kernel);
                    }
                }
                else
                {
                    throw std::runtime_error("todo: only v1, v2 & v3 supported for now");
                }
            }
#if 1
            else
            {
                // Tail number always 1
                if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
                {
                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                    {
                        const auto kernel = kernel_moe_gemm<GridwiseGemm,
                                                            false,
                                                            MemoryDataOp,
                                                            minimum_occupancy,
                                                            TailNumber::Odd>;
                        RunKernel(kernel);
                    }
                    else
                    {
                        const auto kernel = kernel_moe_gemm<GridwiseGemm,
                                                            false,
                                                            MemoryDataOp,
                                                            minimum_occupancy,
                                                            TailNumber::Even>;
                        RunKernel(kernel);
                    }
                }
                else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
                                  BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
                {
                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                    {
                        const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
                                                                 false,
                                                                 MemoryDataOp,
                                                                 minimum_occupancy,
                                                                 TailNumber::Odd>;
                        RunKernel(kernel);
                    }
                    else
                    {
                        const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
                                                                 false,
                                                                 MemoryDataOp,
                                                                 minimum_occupancy,
                                                                 TailNumber::Even>;
                        RunKernel(kernel);
                    }
                }
                else
                {
                    throw std::runtime_error("todo: only v1, v2 & v3 supported for now");
                }
            }
#endif

            return ave_time;
        }

        INVOKER_RUN3_IMPL

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        // only KBatch = 1 is implemented for now
        if(arg.KBatch > 1)
        {
            return false;
        }
        if(!ck::is_xdl_wmma_supported<ComputeTypeA, ComputeTypeB, MPerXDL, NPerXDL>())
        {
            return false;
        }
        if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
        {
            return false;
        }

        if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
                                                       GemmSpec == GemmSpecialization::NKPadding ||
                                                       GemmSpec == GemmSpecialization::MNKPadding ||
                                                       GemmSpec == GemmSpecialization::KPadding))
        {
            return false;
        }
        if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
        {
            return false;
        }
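        // Validity is delegated to whichever gridwise specialization matches
        // the device's wave size; if that specialization was compiled out
        // (NXdlPerWave == 0), the argument is reported as unsupported.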
        if(get_warp_size() == 64)
        {
            if constexpr(NXdlPerWave64 > 0)
            {
                return GridwiseGemm64::CheckValidity(arg);
            }
        }
        else
        {
            if constexpr(NXdlPerWave32 > 0)
            {
                return GridwiseGemm32::CheckValidity(
                    reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg));
            }
        }
        return false;
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const void* p_sorted_token_ids,
                             const void* p_sorted_expert_ids,
                             const void* p_max_token_id,
                             const void* p_a,
                             const void* p_b,
                             std::array<const void*, NumDTensor> p_ds,
                             void* p_c,
                             index_t NumTokens,
                             index_t TopK,
                             index_t M,
                             index_t N,
                             index_t K,
                             index_t StrideA,
                             index_t StrideB,
                             std::array<index_t, NumDTensor> StrideDs,
                             index_t StrideC,
                             index_t KBatch,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CElementwiseOperation c_element_op)
    {
        return Argument{static_cast<const index_t*>(p_sorted_token_ids),
                        static_cast<const index_t*>(p_sorted_expert_ids),
                        static_cast<const index_t*>(p_max_token_id),
                        static_cast<const ADataType*>(p_a),
                        static_cast<const BDataType*>(p_b),
                        p_ds,
                        static_cast<CDataType*>(p_c),
                        NumTokens,
                        TopK,
                        M,
                        N,
                        K,
                        StrideA,
                        StrideB,
                        StrideDs,
                        StrideC,
                        KBatch,
                        a_element_op,
                        b_element_op,
                        c_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
                                                      std::array<const void*, NumDTensor> p_ds,
                                                      void* p_c,
                                                      index_t M,
                                                      index_t N,
                                                      index_t K,
                                                      index_t StrideA,
                                                      index_t StrideB,
                                                      std::array<ck::index_t, NumDTensor> StrideDs,
                                                      index_t StrideC,
                                                      index_t KBatch,
                                                      AElementwiseOperation a_element_op,
                                                      BElementwiseOperation b_element_op,
                                                      CElementwiseOperation c_element_op) override
    {
        return std::make_unique<Argument>(nullptr,
                                          nullptr,
                                          nullptr,
                                          static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
                                          p_ds,
                                          static_cast<CDataType*>(p_c),
                                          M, // NumTokens: set arbitrarily, not used here
                                          0, // TopK: set arbitrarily, not used here
                                          M,
                                          N,
                                          K,
                                          StrideA,
                                          StrideB,
                                          StrideDs,
                                          StrideC,
                                          KBatch,
                                          a_element_op,
                                          b_element_op,
                                          c_element_op);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // (the map initializers were elided in the listing; the entries
        // below are assumed from CK's usual enum naming)
        std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
            {BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
            {BlockGemmPipelineScheduler::Interwave, "Interwave"}};

        std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
            {BlockGemmPipelineVersion::v1, "v1"},
            {BlockGemmPipelineVersion::v2, "v2"},
            {BlockGemmPipelineVersion::v3, "v3"},
            {BlockGemmPipelineVersion::v4, "v4"},
            {BlockGemmPipelineVersion::v5, "v5"}};

        // clang-format off
        str << "DeviceMoeGemm"
            << "<"
            << getGemmSpecializationString(GemmSpec) << ", "
            << std::string(ALayout::name)[0]
            << std::string(BLayout::name)[0]
            << std::string(CLayout::name)[0]
            << ">"
            << " BlkSize: "
            << BlockSize << ", "
            << "BlkTile: "
            << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", "
            << "WaveTile: "
            << MPerXDL << "x" << NPerXDL << ", "
            << "WaveMap: "
            << MXdlPerWave << "x" << NXdlPerWave << ", "
            << "VmemReadVec: "
            << ABlockTransferSrcScalarPerVector << "x" << BBlockTransferSrcScalarPerVector << ", "
            << "BlkGemmPipelineScheduler: "
            << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
            << "BlkGemmPipelineVersion: "
            << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
            << "BlkGemmPipelinePrefetchStages: "
            << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
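
// Usage sketch (illustrative, not part of the original file): `DeviceOp`
// stands for a fully specialized DeviceMoeGemm instantiation, and every
// pointer/size variable is assumed to refer to a prepared device buffer.
//
//   auto device_op = DeviceOp{};
//   auto argument  = DeviceOp::MakeArgument(p_sorted_token_ids,
//                                           p_sorted_expert_ids,
//                                           p_max_token_id,
//                                           p_a, p_b, p_ds, p_c,
//                                           NumTokens, TopK, M, N, K,
//                                           StrideA, StrideB, StrideDs, StrideC,
//                                           /*KBatch=*/1,
//                                           a_op, b_op, c_op);
//   if(device_op.IsSupportedArgument(argument))
//   {
//       auto invoker = DeviceOp::MakeInvoker();
//       const float ave_time_ms = invoker.Run(argument, StreamConfig{nullptr, true});
//   }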