/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp Source File
device_moe_gemm_blockscale.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 #include <hip/hip_runtime.h>
9 
20 
21 namespace ck {
22 namespace tensor_operation {
23 namespace device {
24 
25 template <typename ALayout,
26  typename BLayout,
27  typename DsLayout,
28  typename CLayout,
29  typename ADataType,
30  typename AScaleDataType,
31  typename BDataType,
32  typename BScaleDataType,
33  typename DsDataType,
34  typename CDataType,
35  typename GemmAccDataType,
36  typename CShuffleDataType,
37  typename AElementwiseOperation,
38  typename BElementwiseOperation,
39  typename CElementwiseOperation,
40  GemmSpecialization GemmSpec,
41  index_t BlockSize,
42  index_t ScaleBlockM,
43  index_t ScaleBlockN,
44  index_t ScaleBlockK,
45  index_t MPerBlock,
46  index_t NPerBlock,
47  index_t KPerBlock,
48  index_t AK1,
49  index_t BK1,
50  index_t MPerXDL,
51  index_t NPerXDL,
52  index_t MXdlPerWave,
53  index_t NXdlPerWave,
54  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
55  typename ABlockTransferThreadClusterArrangeOrder,
56  typename ABlockTransferSrcAccessOrder,
57  index_t ABlockTransferSrcVectorDim,
58  index_t ABlockTransferSrcScalarPerVector,
59  index_t ABlockTransferDstScalarPerVector_AK1,
60  bool ABlockLdsExtraM,
61  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
62  typename BBlockTransferThreadClusterArrangeOrder,
63  typename BBlockTransferSrcAccessOrder,
64  index_t BBlockTransferSrcVectorDim,
65  index_t BBlockTransferSrcScalarPerVector,
66  index_t BBlockTransferDstScalarPerVector_BK1,
67  bool BBlockLdsExtraN,
68  index_t CShuffleMXdlPerWavePerShuffle,
69  index_t CShuffleNXdlPerWavePerShuffle,
70  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
71  typename CDEShuffleBlockTransferScalarPerVectors,
74  index_t ActivationOP = 0,
75  bool NSwizzle = false,
76  bool IsInputGemm = true,
77  bool IsSplitK = false,
78  bool MulRoutedWeight = false,
79  typename IndexType = index_t,
80  typename ComputeTypeA = CDataType,
81  typename ComputeTypeB = ComputeTypeA,
82  typename LDSTypeA = ComputeTypeA,
83  typename LDSTypeB = ComputeTypeB>
86  BLayout,
87  DsLayout,
88  CLayout,
89  ADataType,
90  AScaleDataType,
91  BDataType,
92  BScaleDataType,
93  DsDataType,
94  CDataType,
95  ScaleBlockM,
96  ScaleBlockN,
97  ScaleBlockK,
98  AElementwiseOperation,
99  BElementwiseOperation,
100  CElementwiseOperation>
101 {
 // NXdlPerWave variants selected by runtime warp width (wave64 vs wave32).
 // GET_NXDL_PER_WAVE_IMPL presumably clamps NXdlPerWave for the given wave
 // width — TODO confirm; the macro body is not visible in this extraction.
 103  static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
 104  static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
 // Number of auxiliary D tensors fused into the epilogue (from DsDataType tuple).
 105  static constexpr index_t NumDTensor = DsDataType::Size();
106  template <index_t NXdlPerWave_>
108  ALayout,
109  BLayout,
110  DsLayout,
111  CLayout,
112  ADataType,
113  BDataType,
114  GemmAccDataType,
115  CShuffleDataType,
116  DsDataType,
117  CDataType,
118  AElementwiseOperation,
119  BElementwiseOperation,
120  CElementwiseOperation,
121  GemmSpec,
122  BlockSize,
123  ScaleBlockM,
124  ScaleBlockN,
125  ScaleBlockK,
126  MPerBlock,
127  NPerBlock,
128  KPerBlock,
129  AK1,
130  BK1,
131  MPerXDL,
132  NPerXDL,
133  MXdlPerWave,
134  NXdlPerWave_,
135  ABlockTransferThreadClusterLengths_AK0_M_AK1,
136  ABlockTransferThreadClusterArrangeOrder,
137  ABlockTransferSrcAccessOrder,
138  ABlockTransferSrcVectorDim,
139  ABlockTransferSrcScalarPerVector,
140  ABlockTransferDstScalarPerVector_AK1,
141  false,
142  ABlockLdsExtraM,
143  BBlockTransferThreadClusterLengths_BK0_N_BK1,
144  BBlockTransferThreadClusterArrangeOrder,
145  BBlockTransferSrcAccessOrder,
146  BBlockTransferSrcVectorDim,
147  BBlockTransferSrcScalarPerVector,
148  BBlockTransferDstScalarPerVector_BK1,
149  false,
150  BBlockLdsExtraN,
151  CShuffleMXdlPerWavePerShuffle,
152  math::min(CShuffleNXdlPerWavePerShuffle, NXdlPerWave_),
153  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
154  CDEShuffleBlockTransferScalarPerVectors,
155  BlkGemmPipeSched,
156  BlkGemmPipelineVer,
157  ActivationOP,
158  NSwizzle,
159  IsInputGemm,
160  IsSplitK,
161  MulRoutedWeight,
162  IndexType,
163  ComputeTypeA,
164  ComputeTypeB,
165  LDSTypeA,
166  LDSTypeB>;
169 
171 
172  static constexpr index_t APackedSize = []() {
174  return 2;
175  else
176  return 1;
177  }();
178 
179  static constexpr index_t BPackedSize = []() {
181  return 2;
182  else
183  return 1;
184  }();
185 
 // Returns NPerXDL; presumably consumed by weight pre-shuffle logic to pick
 // the wave-tile N granularity — confirm at call sites.
 186  int GetPreShuffleParameters() override { return NPerXDL; }
187 
188  // Invoker
189  struct Invoker : public BaseInvoker
190  {
 // Launches and times one MoE GEMM with the given (already-populated) kernel
 // argument. Steps: validate the argument, compute the launch grid, then
 // dispatch to the pipeline-version-specific kernel variant (v1 single-LDS,
 // v2/v3 double-LDS), with Odd/Even tail handling.
 // Returns the averaged kernel time in ms as reported by the launch helpers.
 // NOTE(review): several lines were dropped by the documentation extractor
 // (the MemoryDataOp initializer tail, the trailing template argument of each
 // kernel_moe_gemm/_2lds instantiation, and the flush_icache() call); the
 // surviving code below is kept byte-identical.
 191  template <typename GridwiseGemm>
 192  float RunImp(const typename GridwiseGemm::Argument& arg,
 193  const StreamConfig& stream_config = StreamConfig{})
 194  {
 195  if(stream_config.log_level_ > 0)
 196  {
 197  arg.Print();
 198  }
 199 
 // Reject argument/tile combinations the gridwise GEMM cannot execute.
 200  if(!GridwiseGemm::CheckValidity(arg))
 201  {
 202  throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
 203  }
 204 
 205  index_t gdx, gdy, gdz;
 // N is doubled for the gate+up (input) GEMM when split-K is active.
 206  std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(
 207  arg.M, arg.N * (IsInputGemm && IsSplitK ? 2 : 1), arg.K, arg.KBatch);
 208 
 209  float ave_time = 0;
 210 
 // Per-batch K extent used to decide main-loop/tail structure.
 211  index_t K_split = arg.KBatch == 1 ? arg.K : arg.KBatch * KPerBlock;
 212 
 213  const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
 // RunKernel launches a chosen kernel; in flush_cache mode it rotates A/B/Ds
 // buffers between timed iterations so each run starts cache-cold.
 214  const auto RunKernel = [&](const auto& kernel) {
 215  if(stream_config.flush_cache)
 216  {
 217 
 218  std::array<std::size_t, NumDTensor> DsSize;
 219 
 // Copy the argument: the rotating-memory wrapper mutates its pointers.
 220  auto arg_ = arg;
 221 
 // Rebuild grid descriptors only to size the rotating copies of A and B.
 222  const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
 223  arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
 224  const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
 225  arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
 226 
 // PackedSize divisors account for sub-byte (packed) element types.
 227  auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
 228  sizeof(ADataType) / APackedSize;
 229  auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
 230  sizeof(BDataType) / BPackedSize;
 231 
 232  const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
 233  arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
 234 
 235  static_for<0, NumDTensor, 1>{}([&](auto i) {
 236  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
 237  DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
 238  });
 239  ck::utility::RotatingMemWrapperMultiD<typename GridwiseGemm::Argument,
 240  DsDataType>
 241  rotating_mem(arg_,
 242  stream_config.rotating_count,
 243  size_a_buffer,
 244  size_b_buffer,
 245  DsSize);
 246  rotating_mem.Print();
 247 
 // Pre-launch hook run before every timed iteration.
 248  auto run_flush_cache = [&]() {
 249  // flush icache
 251  // rotating mem
 252  rotating_mem.Next();
 253  // clear c mem
 254  // if(arg_.KBatch > 1)
 255  // hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
 256  // 0,
 257  // arg_.M * arg_.N * sizeof(CDataType)
 258  // * (IsInputGemm && IsSplitK ? 2 : 1),
 259  // stream_config.stream_id_));
 260  };
 261 
 262  ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
 263  stream_config,
 264  run_flush_cache,
 265  kernel,
 266  dim3(gdx, gdy, gdz),
 267  dim3(BlockSize),
 268  0,
 269  arg_);
 270  }
 271  else
 272  {
 273  // if(arg.KBatch > 1)
 274  // hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
 275  // 0,
 276  // arg.M * arg.N * sizeof(CDataType) *
 277  // (IsInputGemm && IsSplitK ? 2 : 1),
 278  // stream_config.stream_id_));
 279 
 280  ave_time = launch_and_time_kernel(
 281  stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
 282  }
 283  };
 284 
 // Rough per-thread register-byte estimate (A + B + C accumulator tiles)
 // used only to pick the occupancy hint below; the /4 presumably converts
 // bytes to 32-bit registers — TODO confirm against the gridwise kernels.
 285  constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) / BlockSize /
 286  4 * (1 + GridwiseGemm::NWave);
 287  constexpr auto estimated_reg_b = NPerBlock * KPerBlock * sizeof(BDataType) / BlockSize /
 288  4 * (2) * (IsInputGemm ? 2 : 1);
 289  constexpr auto estimated_reg_c = MPerBlock * NPerBlock * sizeof(GemmAccDataType) /
 290  BlockSize / 4 * (IsInputGemm ? 2 : 1);
 291  constexpr auto estimated_reg_total =
 292  estimated_reg_a + estimated_reg_b + estimated_reg_c;
 293 
 // Heavy register usage => request occupancy 1, otherwise 2.
 294  constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2;
 295 
 // NOTE(review): the ternary's branches (extraction dropped them) presumably
 // select between a plain Set store and AtomicAdd for split-K reduction.
 296  constexpr auto MemoryDataOp = (IsInputGemm && !IsSplitK)
 299 
 // Dispatch: pipeline version x main-loop presence x tail parity.
 300  if(has_main_k_block_loop)
 301  {
 302  // Tail number always full
 303  if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
 304  {
 305  {
 306  if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
 307  {
 308  const auto kernel = kernel_moe_gemm<GridwiseGemm,
 309  true,
 310  MemoryDataOp,
 311  minimum_occupancy,
 313  RunKernel(kernel);
 314  }
 315  else
 316  {
 317  const auto kernel = kernel_moe_gemm<GridwiseGemm,
 318  true,
 319  MemoryDataOp,
 320  minimum_occupancy,
 322  RunKernel(kernel);
 323  }
 324  }
 325  }
 326  else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
 327  BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
 328  {
 329  if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
 330  {
 331  const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
 332  true,
 333  MemoryDataOp,
 334  minimum_occupancy,
 336  RunKernel(kernel);
 337  }
 338  else
 339  {
 340  const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
 341  true,
 342  MemoryDataOp,
 343  minimum_occupancy,
 345  RunKernel(kernel);
 346  }
 347  }
 348  else
 349  {
 350  throw std::runtime_error("todo: only v1 & v2 support now");
 351  }
 352  }
 353 #if 1
 354  else
 355  {
 356  // Tail number always 1
 357  if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
 358  {
 359  if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
 360  {
 361  const auto kernel = kernel_moe_gemm<GridwiseGemm,
 362  false,
 363  MemoryDataOp,
 364  minimum_occupancy,
 366  RunKernel(kernel);
 367  }
 368  else
 369  {
 370  const auto kernel = kernel_moe_gemm<GridwiseGemm,
 371  false,
 372  MemoryDataOp,
 373  minimum_occupancy,
 375  RunKernel(kernel);
 376  }
 377  }
 378  else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
 379  BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
 380  {
 381  if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
 382  {
 383  const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
 384  false,
 385  MemoryDataOp,
 386  minimum_occupancy,
 388  RunKernel(kernel);
 389  }
 390  else
 391  {
 392  const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
 393  false,
 394  MemoryDataOp,
 395  minimum_occupancy,
 397  RunKernel(kernel);
 398  }
 399  }
 400  }
 401 #endif
 402 
 403  return ave_time;
 404  }
405 
407 
 408  // polymorphic
 // Type-erased entry point: downcasts to the concrete Argument and forwards
 // to the Argument-typed Run overload (presumably generated by the
 // INVOKER_RUN3_IMPL macro on the dropped line above — confirm).
 // NOTE(review): `*dynamic_cast` dereferences without a null check; a
 // mismatched argument type would be UB here.
 409  float Run(const BaseArgument* p_arg,
 410  const StreamConfig& stream_config = StreamConfig{}) override
 411  {
 412  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
 413  }
414  };
415 
 // Compile-time template-parameter sanity check; currently a stub that
 // accepts everything (see TODO below).
 416  static constexpr bool IsValidCompilationParameter()
 417  {
 418  // TODO: properly implement this check
 419  return true;
 420  }
421 
 // Runtime support check for a concrete Argument: rejects unsupported
 // KBatch/data-type combinations, hardware without the needed XDL/WMMA or
 // bf16-atomic capability, and shapes the tile config cannot cover, then
 // defers to the warp-size-matched gridwise GEMM's CheckValidity.
 422  static bool IsSupportedArgument(const Argument& arg)
 423  {
 424  // only impl kbatch 1 for fp32
 425  if(arg.KBatch > 1 && !std::is_same_v<CDataType, float>)
 426  {
 427  return false;
 428  }
 // Requires XDL/WMMA instruction support for the chosen wave tile.
 429  if(!ck::is_xdl_wmma_supported<ComputeTypeA, ComputeTypeB, MPerXDL, NPerXDL>())
 430  {
 431  return false;
 432  }
 // Split-K bf16 output needs hardware bf16 atomic adds.
 433  if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
 434  {
 435  return false;
 436  }
 437 
 // K must be vector-aligned unless a K-padding specialization is active.
 438  if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
 439  GemmSpec == GemmSpecialization::NKPadding ||
 440  GemmSpec == GemmSpecialization::MNKPadding ||
 441  GemmSpec == GemmSpecialization::KPadding))
 442  {
 443  return false;
 444  }
 // N and K must tile exactly (no N/K padding path here).
 445  if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
 446  {
 447  return false;
 448  }
 449  if(arg.KBatch > 1 && arg.K % (KPerBlock * arg.KBatch) != 0)
 450  {
 451  // Not support Kpadding with KBatch > 1
 452  return false;
 453  }
 454 
 // Delegate to the gridwise GEMM instantiated for the actual warp size.
 455  if(get_warp_size() == 64)
 456  {
 457  if constexpr(NXdlPerWave64 > 0)
 458  {
 459  return GridwiseGemm64::CheckValidity(arg);
 460  }
 461  }
 462  else
 463  {
 464  if constexpr(NXdlPerWave32 > 0)
 465  {
 // NOTE(review): the call head (line 466, presumably
 // `return GridwiseGemm32::CheckValidity(`) was dropped by the extractor.
 467  reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg));
 468  }
 469  }
 // Unreached config (e.g. NXdlPerWave == 0 for this warp size).
 470  return false;
 471  }
472 
473  // polymorphic
474  bool IsSupportedArgument(const BaseArgument* p_arg) override
475  {
476  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
477  }
478 
479  static auto MakeArgument(const void* p_sorted_token_ids,
480  const void* p_sorted_expert_ids,
481  const void* p_max_token_id,
482  const void* p_a,
483  const void* p_b,
484  std::array<const void*, NumDTensor> p_ds,
485  void* p_c,
486  index_t NumTokens,
487  index_t TopK,
488  index_t M,
489  index_t N,
490  index_t K,
491  index_t StrideA,
492  index_t StrideB,
493  std::array<index_t, NumDTensor> StrideDs,
494  index_t StrideC,
495  const void* p_a_scale,
496  const void* p_b_scale,
497  index_t KBatch,
498  AElementwiseOperation a_element_op,
499  BElementwiseOperation b_element_op,
500  CElementwiseOperation c_element_op)
501  {
502  return Argument{static_cast<const index_t*>(p_sorted_token_ids),
503  static_cast<const index_t*>(p_sorted_expert_ids),
504  static_cast<const index_t*>(p_max_token_id),
505  static_cast<const ADataType*>(p_a),
506  static_cast<const BDataType*>(p_b),
507  p_ds,
508  static_cast<CDataType*>(p_c),
509  NumTokens,
510  TopK,
511  M,
512  N,
513  K,
514  StrideA,
515  StrideB,
516  StrideDs,
517  StrideC,
518  static_cast<const AScaleDataType*>(p_a_scale),
519  static_cast<const BScaleDataType*>(p_b_scale),
520  KBatch,
521  a_element_op,
522  b_element_op,
523  c_element_op};
524  }
525 
 // Creates an Invoker that launches and times the kernel for a built Argument.
 526  static auto MakeInvoker() { return Invoker{}; }
527 
528  // polymorphic
529  std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
530  const void* p_b,
531  std::array<const void*, NumDTensor> p_ds,
532  void* p_c,
533  index_t M,
534  index_t N,
535  index_t K,
536  index_t StrideA,
537  index_t StrideB,
538  std::array<ck::index_t, NumDTensor> StrideDs,
539  index_t StrideC,
540  const void* p_a_scale,
541  const void* p_b_scale,
542  // index_t KBatch,
543  AElementwiseOperation a_element_op,
544  BElementwiseOperation b_element_op,
545  CElementwiseOperation c_element_op) override
546  {
547  return std::make_unique<Argument>(nullptr,
548  nullptr,
549  nullptr,
550  static_cast<const ADataType*>(p_a),
551  static_cast<const BDataType*>(p_b),
552  p_ds,
553  static_cast<CDataType*>(p_c),
554  M, // randoms set, no use
555  0,
556  M,
557  N,
558  K,
559  StrideA,
560  StrideB,
561  StrideDs,
562  StrideC,
563  static_cast<const AScaleDataType*>(p_a_scale),
564  static_cast<const BScaleDataType*>(p_b_scale),
565  1, // KBatch,
566  a_element_op,
567  b_element_op,
568  c_element_op);
569  }
570 
571  // polymorphic
572  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
573  {
574  return std::make_unique<Invoker>(Invoker{});
575  }
576 
 577  // polymorphic
 // Builds a human-readable instance descriptor (specialization, layouts,
 // block/wave tiling, vector widths, pipeline scheduler/version/prefetch)
 // for logging and kernel selection output.
 // NOTE(review): the initializer lists of both enum->string maps (original
 // lines 583-589) were dropped by the documentation extractor; the code is
 // kept byte-identical. Also, "DeviceMoeGEmm" looks like a casing typo for
 // "DeviceMoeGemm" — confirm before changing, the string is part of the
 // emitted instance name.
 578  std::string GetTypeString() const override
 579  {
 580  auto str = std::stringstream();
 581 
 582  std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
 585 
 586  std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
 590 
 591  // clang-format off
 592  str << "DeviceMoeGEmm"
 593  << "<"
 594  << getGemmSpecializationString(GemmSpec) << ", "
 595  << std::string(ALayout::name)[0]
 596  << std::string(BLayout::name)[0]
 597  << std::string(CLayout::name)[0]
 598  << ">"
 599  << " BlkSize: "
 600  << BlockSize << ", "
 601  << "BlkTile: "
 602  << MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
 603  << "WaveTile: "
 604  << MPerXDL<<"x"<<NPerXDL << ", "
 605  << "WaveMap: "
 606  << MXdlPerWave<<"x" << NXdlPerWave<<", "
 607  << "VmemReadVec: "
 608  << ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
 609  << "BlkGemmPipelineScheduler: "
 610  << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
 611  << "BlkGemmPipelineVersion: "
 612  << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
 613  << "BlkGemmPipelinePrefetchStages: "
 614  << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
 615  // clang-format on
 616 
 617  return str.str();
 618  }
619 };
620 
621 } // namespace device
622 } // namespace tensor_operation
623 } // namespace ck
#define INVOKER_RUN3_IMPL
Definition: device_base.hpp:187
#define GET_NXDL_PER_WAVE_IMPL
Definition: device_base.hpp:87
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition: gemm_specialization.hpp:32
GemmSpecialization
Definition: gemm_specialization.hpp:11
void flush_icache()
Definition: flush_cache.hpp:383
Definition: ck.hpp:270
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
BlockGemmPipelineVersion
Block GEMM pipeline version enumeration.
Definition: scheduler_enum.hpp:17
@ v2
Memory-optimized pipeline.
@ v3
Compute-optimized pipeline.
__global__ void kernel_moe_gemm(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_gemm.hpp:46
@ Even
Even number of iterations.
@ Odd
Odd number of iterations.
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
constexpr bool is_same_v
Definition: type.hpp:283
BlockGemmPipelineScheduler
Block GEMM pipeline scheduler enumeration.
Definition: scheduler_enum.hpp:33
@ Intrawave
Schedule within a single wavefront.
@ Interwave
Schedule across multiple wavefronts.
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
int32_t index_t
Definition: ck.hpp:301
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:16
__global__ void kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_gemm.hpp:84
bool is_bf16_atomic_supported()
Definition: device_prop.hpp:113
Definition: stream_config.hpp:10
Definition: gridwise_moe_gemm_blockscale.hpp:673
Definition: gridwise_moe_gemm_blockscale.hpp:178
static constexpr __host__ bool CheckValidity(const Argument &karg)
Definition: gridwise_moe_gemm_blockscale.hpp:981
Definition: data_type.hpp:187
Definition: functional2.hpp:33
Definition: device_base.hpp:270
Definition: device_base.hpp:281
Definition: device_gemm_multiple_d_ab_scale.hpp:82
Definition: device_moe_gemm_blockscale.hpp:190
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_moe_gemm_blockscale.hpp:192
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_moe_gemm_blockscale.hpp:409
Definition: device_moe_gemm_blockscale.hpp:101
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_moe_gemm_blockscale.hpp:572
static constexpr index_t BPackedSize
Definition: device_moe_gemm_blockscale.hpp:179
static constexpr bool IsValidCompilationParameter()
Definition: device_moe_gemm_blockscale.hpp:416
static auto MakeArgument(const void *p_sorted_token_ids, const void *p_sorted_expert_ids, const void *p_max_token_id, const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t NumTokens, index_t TopK, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideC, const void *p_a_scale, const void *p_b_scale, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition: device_moe_gemm_blockscale.hpp:479
static constexpr index_t APackedSize
Definition: device_moe_gemm_blockscale.hpp:172
static constexpr index_t NumDTensor
Definition: device_moe_gemm_blockscale.hpp:105
typename GridwiseGemm64::Argument Argument
Definition: device_moe_gemm_blockscale.hpp:170
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_moe_gemm_blockscale.hpp:474
int GetPreShuffleParameters() override
Definition: device_moe_gemm_blockscale.hpp:186
static auto MakeInvoker()
Definition: device_moe_gemm_blockscale.hpp:526
static constexpr GET_NXDL_PER_WAVE_IMPL auto NXdlPerWave64
Definition: device_moe_gemm_blockscale.hpp:103
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideC, const void *p_a_scale, const void *p_b_scale, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition: device_moe_gemm_blockscale.hpp:529
static constexpr auto NXdlPerWave32
Definition: device_moe_gemm_blockscale.hpp:104
static bool IsSupportedArgument(const Argument &arg)
Definition: device_moe_gemm_blockscale.hpp:422
std::string GetTypeString() const override
Definition: device_moe_gemm_blockscale.hpp:578
Definition: flush_cache.hpp:174