grouped_flatmm_kernel.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <string>

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp"

namespace ck_tile {

template <class ScaleM = FlatmmScalePointer<-1>,
          class ScaleN = FlatmmScalePointer<-1>,
          index_t NumDTensor = 0>
struct GroupedFlatmmHostArgs
{
    CK_TILE_HOST GroupedFlatmmHostArgs() = default;
    CK_TILE_HOST GroupedFlatmmHostArgs(index_t group_count_,
                                       index_t* M_,
                                       index_t* N_,
                                       index_t* K_,
                                       const void** a_ptr_,
                                       index_t* stride_A_,
                                       const void** b_shuffle_ptr_,
                                       index_t* stride_B_,
                                       const std::array<const void*, NumDTensor>& ds_ptr_,
                                       const std::array<index_t, NumDTensor>& stride_Ds_,
                                       void** c_ptr_,
                                       index_t* stride_C_,
                                       index_t k_batch_,
                                       ScaleM* scale_m_ = nullptr,
                                       ScaleN* scale_n_ = nullptr)
        : group_count(group_count_),
          M(M_),
          N(N_),
          K(K_),
          a_ptr(a_ptr_),
          stride_A(stride_A_),
          b_shuffle_ptr(b_shuffle_ptr_),
          stride_B(stride_B_),
          ds_ptr(ds_ptr_),
          stride_Ds(stride_Ds_),
          c_ptr(c_ptr_),
          stride_C(stride_C_),
          k_batch(k_batch_),
          scale_m(scale_m_),
          scale_n(scale_n_)
    {
    }

    index_t group_count;
    index_t* M;
    index_t* N;
    index_t* K;
    const void** a_ptr;
    index_t* stride_A;
    const void** b_shuffle_ptr;
    index_t* stride_B;
    const std::array<const void*, NumDTensor> ds_ptr;
    const std::array<index_t, NumDTensor> stride_Ds;
    union
    {
        void** e_ptr;
        void** c_ptr;
    };
    index_t* stride_C;
    index_t k_batch;
    ScaleM* scale_m = nullptr;
    ScaleN* scale_n = nullptr;
};
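
// Host-side usage sketch (illustrative, not part of the original header): how the per-group
// arrays feed GroupedFlatmmHostArgs. All buffer names below (a0_dev, c_ptrs, ...) are
// hypothetical device allocations owned by the caller; no D tensors (NumDTensor = 0) and no
// scale pointers are used.
//
//   ck_tile::index_t m[2] = {128, 256}, n[2] = {512, 512}, k[2] = {1024, 1024};
//   ck_tile::index_t lda[2] = {1024, 1024}, ldb[2] = {1024, 1024}, ldc[2] = {512, 512};
//   const void* a_ptrs[2] = {a0_dev, a1_dev};
//   const void* b_ptrs[2] = {b0_shuffled_dev, b1_shuffled_dev};
//   void* c_ptrs[2]       = {c0_dev, c1_dev};
//   ck_tile::GroupedFlatmmHostArgs<> args(/*group_count_=*/2,
//                                         m, n, k,
//                                         a_ptrs, lda,
//                                         b_ptrs, ldb,
//                                         /*ds_ptr_=*/{}, /*stride_Ds_=*/{},
//                                         c_ptrs, ldc,
//                                         /*k_batch_=*/1);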

template <class ScaleM = FlatmmScalePointer<-1>,
          class ScaleN = FlatmmScalePointer<-1>,
          index_t NumDTensor = 0>
struct ContiguousGroupedFlatmmHostArgs
{
    CK_TILE_HOST ContiguousGroupedFlatmmHostArgs() = default;
    CK_TILE_HOST ContiguousGroupedFlatmmHostArgs(index_t* M_indices_,
                                                 index_t M_,
                                                 index_t N_,
                                                 index_t K_,
                                                 const void* a_ptr_,
                                                 index_t stride_A_,
                                                 const void* b_shuffle_ptr_,
                                                 index_t stride_B_,
                                                 const std::array<const void*, NumDTensor>& ds_ptr_,
                                                 const std::array<index_t, NumDTensor>& stride_Ds_,
                                                 void* c_ptr_,
                                                 index_t stride_C_,
                                                 index_t k_batch_,
                                                 ScaleM scale_m_ = nullptr,
                                                 ScaleN scale_n_ = nullptr)
        : group_count(1),
          M_indices(M_indices_),
          M(M_),
          N(N_),
          K(K_),
          a_ptr(a_ptr_),
          stride_A(stride_A_),
          b_shuffle_ptr(b_shuffle_ptr_),
          stride_B(stride_B_),
          ds_ptr(ds_ptr_),
          stride_Ds(stride_Ds_),
          c_ptr(c_ptr_),
          stride_C(stride_C_),
          k_batch(k_batch_),
          scale_m(scale_m_),
          scale_n(scale_n_)
    {
    }

    index_t group_count;
    index_t* M_indices;
    index_t M;
    index_t N;
    index_t K;
    const void* a_ptr;
    index_t stride_A;
    const void* b_shuffle_ptr;
    index_t stride_B;
    const std::array<const void*, NumDTensor> ds_ptr;
    const std::array<index_t, NumDTensor> stride_Ds;
    union
    {
        void* e_ptr;
        void* c_ptr;
    };
    index_t stride_C;
    index_t k_batch;
    ScaleM scale_m = nullptr;
    ScaleN scale_n = nullptr;
};

template <class ScaleM = FlatmmScalePointer<-1>,
          class ScaleN = FlatmmScalePointer<-1>,
          index_t NumDTensor = 0>
struct MaskedGroupedFlatmmHostArgs
{
    CK_TILE_HOST MaskedGroupedFlatmmHostArgs() = default;
    CK_TILE_HOST MaskedGroupedFlatmmHostArgs(index_t* M_indices_,
                                             index_t group_count_,
                                             index_t Max_M_,
                                             index_t N_,
                                             index_t K_,
                                             const void* a_ptr_,
                                             index_t stride_A_,
                                             const void* b_shuffle_ptr_,
                                             index_t stride_B_,
                                             const std::array<const void*, NumDTensor>& ds_ptr_,
                                             const std::array<index_t, NumDTensor>& stride_Ds_,
                                             void* c_ptr_,
                                             index_t stride_C_,
                                             index_t k_batch_,
                                             ScaleM scale_m_ = nullptr,
                                             ScaleN scale_n_ = nullptr)
        : M_indices(M_indices_),
          group_count(group_count_),
          M(Max_M_),
          N(N_),
          K(K_),
          a_ptr(a_ptr_),
          stride_A(stride_A_),
          b_shuffle_ptr(b_shuffle_ptr_),
          stride_B(stride_B_),
          ds_ptr(ds_ptr_),
          stride_Ds(stride_Ds_),
          c_ptr(c_ptr_),
          stride_C(stride_C_),
          k_batch(k_batch_),
          scale_m(scale_m_),
          scale_n(scale_n_)
    {
    }

    index_t* M_indices;
    index_t group_count;
    index_t M;
    index_t N;
    index_t K;
    const void* a_ptr;
    index_t stride_A;
    const void* b_shuffle_ptr;
    index_t stride_B;
    const std::array<const void*, NumDTensor> ds_ptr;
    const std::array<index_t, NumDTensor> stride_Ds;
    union
    {
        void* e_ptr;
        void* c_ptr;
    };
    index_t stride_C;
    index_t k_batch;
    ScaleM scale_m = nullptr;
    ScaleN scale_n = nullptr;
};
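
// Indexing sketch for the two single-pointer argument structs above (illustrative, values made
// up):
//
//   ContiguousGroupedFlatmmHostArgs: A and C describe one packed [M, K] / [M, N] problem and
//   M_indices has one entry per output row, naming the group whose shuffled B that row uses,
//   e.g. for M = 6 rows:
//       M_indices = {0, 0, 0, 1, 1, 2};  // rows 0..2 -> group 0, rows 3..4 -> group 1, row 5 -> group 2
//
//   MaskedGroupedFlatmmHostArgs: every group owns a fixed slab of Max_M rows in A and C and
//   M_indices has one entry per group, giving the number of valid rows in that slab,
//   e.g. for group_count = 3, Max_M = 4:
//       M_indices = {3, 1, 4};           // group 0 computes 3 rows of its slab, group 1 computes 1, ...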

template <typename TilePartitioner_, typename FlatmmPipeline_, typename EpiloguePipeline_>
struct GroupedFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, EpiloguePipeline_>
{
    using UnderlyingGemmKernel = FlatmmKernel<TilePartitioner_, FlatmmPipeline_, EpiloguePipeline_>;

    using TilePartitioner  = typename UnderlyingGemmKernel::TilePartitioner;
    using FlatmmPipeline   = typename UnderlyingGemmKernel::FlatmmPipeline;
    using EpiloguePipeline = typename UnderlyingGemmKernel::EpiloguePipeline;
    using BlockGemmShape   = typename UnderlyingGemmKernel::BlockGemmShape;

    using ADataType = typename UnderlyingGemmKernel::ADataType;
    using BDataType = typename UnderlyingGemmKernel::BDataType;

    using DsLayout   = typename UnderlyingGemmKernel::DsLayout;
    using DsDataType = typename UnderlyingGemmKernel::DsDataType;
    // Below type is actually accumulation data type - the output of block GEMM.
    using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;

    static constexpr index_t NumDTensor = DsDataType::size();
    static constexpr index_t kBlockSize = FlatmmPipeline_::BlockSize;

    static constexpr auto I0 = number<0>();
    static constexpr auto I1 = number<1>();
    static constexpr auto I2 = number<2>();
    static constexpr auto I3 = number<3>();

    static_assert(DsLayout::size() == DsDataType::size(),
                  "The size of DsLayout and DsDataType should be the same");

    CK_TILE_HOST static const std::string GetName()
    {
        return concat(
            '_', "grouped_flatmm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
    }

    template <class ScaleM = FlatmmScalePointer<-1>,
              class ScaleN = FlatmmScalePointer<-1>,
              index_t NumDTensor = 0>
    CK_TILE_HOST_DEVICE static auto
    GridSize([[maybe_unused]] const GroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor>& kernelArgs)
    {
        hipDeviceProp_t prop;
        int deviceId = 0; // default device

        constexpr int block_size = UnderlyingGemmKernel::BlockSize().x;
        int dync_smem_size = 0;
        int maxActiveBlocksPerCU;

        [[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);

        e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
            &maxActiveBlocksPerCU,
            reinterpret_cast<void*>(
                kentry<1, GroupedFlatmmKernel, GroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor>>),
            block_size,
            dync_smem_size);

        const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;

        // std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
        //           << ", persistent_block_size: " << persistent_block_size << std::endl;

        assert(kernelArgs.k_batch == 1);
        return dim3(persistent_block_size, 1, kernelArgs.k_batch);
    }

    template <class ScaleM = FlatmmScalePointer<-1>,
              class ScaleN = FlatmmScalePointer<-1>,
              index_t NumDTensor = 0>
    CK_TILE_HOST_DEVICE static auto
    GridSize([[maybe_unused]] const ContiguousGroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor>&
                 kernelArgs)
    {
        hipDeviceProp_t prop;
        int deviceId = 0; // default device

        constexpr int block_size = UnderlyingGemmKernel::BlockSize().x;
        int dync_smem_size = 0;
        int maxActiveBlocksPerCU;

        [[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);

        e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
            &maxActiveBlocksPerCU,
            reinterpret_cast<void*>(
                kentry<1,
                       GroupedFlatmmKernel,
                       ContiguousGroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor>>),
            block_size,
            dync_smem_size);

        const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
        const int total_work_tile_cnt = TilePartitioner::GridSize(kernelArgs.M, kernelArgs.N);

        // std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
        //           << ", persistent_block_size: " << persistent_block_size
        //           << ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl;

        assert(kernelArgs.k_batch == 1);
        return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kernelArgs.k_batch);
    }
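
    // Grid-size sketch (illustrative arithmetic, not from the original header): the persistent
    // launch uses at most one "wave" of blocks for the whole device. With made-up numbers
    //
    //   prop.multiProcessorCount = 80   // CUs reported by hipGetDeviceProperties
    //   maxActiveBlocksPerCU     = 2    // from hipOccupancyMaxActiveBlocksPerMultiprocessor
    //   persistent_block_size    = 80 * 2 = 160
    //
    // the grouped and masked overloads return a grid of 160 blocks, while the contiguous
    // overload above additionally clamps it, i.e. min(160, TilePartitioner::GridSize(M, N)),
    // so a small problem does not launch blocks that would find no work tile.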

    template <class ScaleM = FlatmmScalePointer<-1>,
              class ScaleN = FlatmmScalePointer<-1>,
              index_t NumDTensor = 0>
    CK_TILE_HOST_DEVICE static auto GridSize(
        [[maybe_unused]] const MaskedGroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor>& kernelArgs)
    {
        hipDeviceProp_t prop;
        int deviceId = 0; // default device

        constexpr int block_size = UnderlyingGemmKernel::BlockSize().x;
        int dync_smem_size = 0;
        int maxActiveBlocksPerCU;

        [[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);

        e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
            &maxActiveBlocksPerCU,
            reinterpret_cast<void*>(
                kentry<1,
                       GroupedFlatmmKernel,
                       MaskedGroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor>>),
            block_size,
            dync_smem_size);

        const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
        // const int total_work_tile_cnt = TilePartitioner::GridSize(kernelArgs.M, kernelArgs.N);

        // std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
        //           << ", persistent_block_size: " << persistent_block_size << std::endl;

        assert(kernelArgs.k_batch == 1);
        return dim3(persistent_block_size, 1, kernelArgs.k_batch);
    }

    template <typename HostArgs>
    CK_TILE_HOST static constexpr auto MakeKernelArgs(const HostArgs& hostArgs)
    {
        return hostArgs;
    }
    // CK_TILE_HOST static constexpr auto
    // MakeKernelArgs(const ContiguousGroupedFlatmmHostArgs& hostArgs)
    // {
    //     return hostArgs;
    // }
    // CK_TILE_HOST static constexpr auto
    // MakeKernelArgs(const MaskedGroupedFlatmmHostArgs& hostArgs)
    // {
    //     return hostArgs;
    // }

    template <class ScaleM = FlatmmScalePointer<-1>,
              class ScaleN = FlatmmScalePointer<-1>,
              index_t NumDTensor = 0>
    CK_TILE_DEVICE void operator()(GroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor> kargs) const
    {
        int group_idx = 0;
        int block_linear_idx = blockIdx.x;
        int total_block_cnt = gridDim.x;

        UnderlyingGemmKernel underlying_kernel{};
        for(; group_idx < kargs.group_count; ++group_idx)
        {
            const index_t M = kargs.M[group_idx];
            const index_t N = kargs.N[group_idx];
            const index_t group_block_cnt = TilePartitioner::GridSize(M, N);

            while(block_linear_idx < group_block_cnt)
            {
                // Found the group this block belongs to:
                // create the kernel args for the underlying flatmm kernel.
                FlatmmKernelArgs<ScaleM, ScaleN, NumDTensor> impl_kargs{
                    kargs.a_ptr[group_idx],
                    kargs.b_shuffle_ptr[group_idx],
                    kargs.ds_ptr,
                    kargs.c_ptr[group_idx],
                    kargs.M[group_idx],
                    kargs.N[group_idx],
                    kargs.K[group_idx],
                    kargs.stride_A[group_idx],
                    kargs.stride_B[group_idx],
                    kargs.stride_Ds,
                    kargs.stride_C[group_idx],
                    kargs.k_batch,
                    kargs.scale_m[group_idx],
                    kargs.scale_n[group_idx]};
                // call the underlying flatmm kernel
                underlying_kernel(impl_kargs, block_linear_idx);
                block_linear_idx += total_block_cnt;
            }
            block_linear_idx -= group_block_cnt;
        }
    }
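
    // Work-distribution sketch for the grouped overload above (illustrative, made-up sizes):
    // each persistent block walks the groups in order, keeping a running tile offset. With
    // gridDim.x = 4 and per-group tile counts {6, 3}:
    //
    //   block 0: group 0 tiles {0, 4}, group 1 tile {2}
    //   block 1: group 0 tiles {1, 5}, group 1 none
    //   block 2: group 0 tile  {2},    group 1 tile {0}
    //   block 3: group 0 tile  {3},    group 1 tile {1}
    //
    // The trailing `block_linear_idx -= group_block_cnt` carries each block's phase into the
    // next group, so every tile of every group is processed exactly once.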

    template <class ScaleM = FlatmmScalePointer<-1>,
              class ScaleN = FlatmmScalePointer<-1>,
              index_t NumDTensor = 0>
    CK_TILE_DEVICE void
    operator()(ContiguousGroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor> kargs) const
    {
        int block_linear_idx = blockIdx.x;
        int total_block_cnt = gridDim.x;
        int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);

        UnderlyingGemmKernel underlying_kernel{};
        for(; block_linear_idx < total_work_tile_cnt; block_linear_idx += total_block_cnt)
        {
            auto [block_m_idx, block_n_idx] =
                TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(block_linear_idx);
            // get the group index from the M_indices
            int group_idx = kargs.M_indices[block_m_idx * BlockGemmShape::kM];

            FlatmmKernelArgs<ScaleM, ScaleN, NumDTensor> impl_kargs{
                kargs.a_ptr,
                static_cast<const BDataType*>(kargs.b_shuffle_ptr) + group_idx * kargs.N * kargs.K,
                kargs.ds_ptr,
                kargs.c_ptr,
                kargs.M,
                kargs.N,
                kargs.K,
                kargs.stride_A,
                kargs.stride_B,
                kargs.stride_Ds,
                kargs.stride_C,
                kargs.k_batch,
                kargs.scale_m,
                kargs.scale_n};
            // call the underlying flatmm kernel
            underlying_kernel(impl_kargs, block_linear_idx);
        }
    }
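
    // Group-lookup sketch for the contiguous overload above (illustrative, made-up sizes):
    // assuming BlockGemmShape::kM = 64 and rows that are grouped contiguously, the first row of
    // each output M-tile selects which group's shuffled B the whole tile uses:
    //
    //   block_m_idx = 0 -> row   0 -> group_idx = M_indices[0]
    //   block_m_idx = 1 -> row  64 -> group_idx = M_indices[64]
    //   block_m_idx = 2 -> row 128 -> group_idx = M_indices[128]
    //
    // B is then offset by group_idx * N * K elements into the shuffled-B buffer, while A, C and
    // all strides remain those of the single packed problem.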

    template <class ScaleM = FlatmmScalePointer<-1>,
              class ScaleN = FlatmmScalePointer<-1>,
              index_t NumDTensor = 0>
    CK_TILE_DEVICE void
    operator()(MaskedGroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor> kargs) const
    {
        int group_idx = 0;
        int block_linear_idx = blockIdx.x;
        int total_block_cnt = gridDim.x;

        UnderlyingGemmKernel underlying_kernel{};
        for(; group_idx < kargs.group_count; ++group_idx)
        {
            const index_t valid_M = kargs.M_indices[group_idx];
            const index_t N = kargs.N;
            const index_t group_block_cnt = TilePartitioner::GridSize(valid_M, N);

            while(block_linear_idx < group_block_cnt)
            {
                // Found the group this block belongs to:
                // create the kernel args for the underlying flatmm kernel.
                FlatmmKernelArgs<ScaleM, ScaleN, NumDTensor> impl_kargs{
                    static_cast<const ADataType*>(kargs.a_ptr) + group_idx * kargs.M * kargs.K,
                    static_cast<const BDataType*>(kargs.b_shuffle_ptr) +
                        group_idx * kargs.N * kargs.K,
                    kargs.ds_ptr,
                    static_cast<CDataType*>(kargs.c_ptr) + group_idx * kargs.M * kargs.N,
                    valid_M,
                    kargs.N,
                    kargs.K,
                    kargs.stride_A,
                    kargs.stride_B,
                    kargs.stride_Ds,
                    kargs.stride_C,
                    kargs.k_batch,
                    kargs.scale_m + group_idx * kargs.M,
                    kargs.scale_n + group_idx * kargs.N};
                // call the underlying flatmm kernel
                underlying_kernel(impl_kargs, block_linear_idx);
                block_linear_idx += total_block_cnt;
            }
            block_linear_idx -= group_block_cnt;
        }
    }
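
    // Offset sketch for the masked overload above (illustrative, made-up sizes): kargs.M holds
    // Max_M, so group g reads A at a_ptr + g * Max_M * K, writes C at c_ptr + g * Max_M * N, and
    // computes only the first M_indices[g] rows of its slab; scale_m and scale_n are advanced by
    // g * Max_M and g * N respectively. E.g. with Max_M = 4 and M_indices = {3, 1}:
    //
    //   group 0: slab rows 0..3, rows 0..2 computed
    //   group 1: slab rows 4..7, row  4    computed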
};

} // namespace ck_tile