Composable Kernel: include/ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp — source listing
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include <iostream>
#include <string>

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp" // base FlatmmKernel

namespace ck_tile {

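// Mixed-precision flat GEMM kernel: f16 activations (A) times microscaled fp4 (MXF4) weights
// (B, two fp4 values packed per byte) with per-group e8m0 exponent scales, built on top of the
// generic FlatmmKernel infrastructure. B and its scales are consumed in a pre-shuffled "flat"
// layout (see MakeBFlatBlockWindow / MakeScaleBBlockWindow below).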
template <typename TilePartitioner_, typename FlatmmPipeline_, typename EpiloguePipeline_>
struct F16xMXF4FlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, EpiloguePipeline_>
{
    using Underlying = FlatmmKernel<TilePartitioner_, FlatmmPipeline_, EpiloguePipeline_>;

    using TilePartitioner  = typename Underlying::TilePartitioner;
    using FlatmmPipeline   = typename Underlying::FlatmmPipeline;
    using EpiloguePipeline = typename Underlying::EpiloguePipeline;
    using BlockGemmShape   = typename Underlying::BlockGemmShape;

    using ALayout  = typename Underlying::ALayout;
    using BLayout  = typename Underlying::BLayout;
    using ELayout  = typename Underlying::ELayout;
    using DsLayout = typename Underlying::DsLayout;

    using ADataType  = typename Underlying::ADataType;
    using BDataType  = typename Underlying::BDataType;
    using DsDataType = typename Underlying::DsDataType;
    using EDataType  = typename Underlying::EDataType;

    static constexpr index_t KernelBlockSize  = FlatmmPipeline::BlockSize;
    static constexpr bool UsePersistentKernel = FlatmmPipeline::UsePersistentKernel;

    // Below type is actually the accumulation data type - the output of block GEMM.
    using CDataType = remove_cvref_t<typename FlatmmPipeline::CDataType>;

    static constexpr int QuantPackedSize = 2; // two fp4 values are packed per byte of B
    static constexpr int N_Pack          = 2;

    static constexpr index_t NumDTensor = DsDataType::size();

    static constexpr auto I0 = number<0>();
    static constexpr auto I1 = number<1>();
    static constexpr auto I2 = number<2>();
    static constexpr auto I3 = number<3>();
    static constexpr auto I4 = number<4>();

    static_assert(DsLayout::size() == DsDataType::size(),
                  "The size of DsLayout and DsDataType should be the same");
    // using KernelArgs = FlatmmKernelArgs<DsLayout::size()>;
    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
    {
        // clang-format off
        return concat('_', "mixed_prec_gemm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
        // clang-format on
    }

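    // Grid sizing: the persistent variant launches only as many workgroups as can be resident
    // on the device at once (CU count x max occupancy per CU) and lets each workgroup iterate
    // over output tiles; the non-persistent variant launches one workgroup per output tile.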
    template <class ScaleM, class ScaleN>
    CK_TILE_HOST static constexpr auto
    GridSize(const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs)
    {
        if constexpr(UsePersistentKernel)
        {
            hipDeviceProp_t prop;
            int deviceId = 0; // default device

            constexpr int block_size = F16xMXF4FlatmmKernel::BlockSize().x;
            int dync_smem_size       = 0;
            int maxActiveBlocksPerCU = 0;

            [[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);

            e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
                &maxActiveBlocksPerCU,
                reinterpret_cast<void*>(
                    kentry<1,
                           F16xMXF4FlatmmKernel,
                           FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>>),
                block_size,
                dync_smem_size);

            const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
            const int total_work_tile_cnt   = TilePartitioner::GridSize(kargs.M, kargs.N);

            // std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
            //           << ", persistent_block_size: " << persistent_block_size
            //           << ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl;

            assert(kargs.k_batch == 1);
            return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kargs.k_batch);
        }
        else
        {
            return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch);
        }
    }

    using SplitKBatchOffset = typename Underlying::SplitKBatchOffset;

    template <typename KernelArgs>
    CK_TILE_DEVICE static auto MakeABlockWindow(const ADataType* a_ptr,
                                                const KernelArgs& kargs,
                                                const index_t k_size,
                                                const index_t block_idx_m)
    {
        // Step 1: Create tensor view
        const auto& a_tensor_view = [&]() {
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    a_ptr,
                    make_tuple(kargs.M, k_size),
                    make_tuple(kargs.stride_A, 1),
                    number<FlatmmPipeline::GetVectorSizeA()>{},
                    number<1>{});
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    a_ptr,
                    make_tuple(k_size, kargs.M),
                    make_tuple(kargs.stride_A, 1),
                    number<FlatmmPipeline::GetVectorSizeA()>{},
                    number<1>{});
            }
        }();

        // Step 2: Create padded view
        const auto& a_pad_view = [&]() {
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            {
                return pad_tensor_view(a_tensor_view,
                                       make_tuple(number<TilePartitioner::MPerBlock>{},
                                                  number<TilePartitioner::KPerBlock>{}),
                                       sequence<false, FlatmmPipeline::kPadK>{});
            }
            else
            {
                return pad_tensor_view(a_tensor_view,
                                       make_tuple(number<TilePartitioner::KPerBlock>{},
                                                  number<TilePartitioner::MPerBlock>{}),
                                       sequence<false, FlatmmPipeline::kPadM>{});
            }
        }();

        // Step 3: Create tile window
        if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
        {
            return make_tile_window(a_pad_view,
                                    make_tuple(number<TilePartitioner::MPerBlock>{},
                                               number<TilePartitioner::KPerBlock>{}),
                                    {block_idx_m, 0});
        }
        else
        {
            return make_tile_window(a_pad_view,
                                    make_tuple(number<TilePartitioner::KPerBlock>{},
                                               number<TilePartitioner::MPerBlock>{}),
                                    {0, block_idx_m});
        }
    }

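    // B is consumed in a pre-shuffled "flat" layout: each row of the flat tensor holds the K
    // elements of one warp-tile-wide group of N columns laid out contiguously, so
    // kFlatK = K * WarpTile_N and kFlatN = N * K / kFlatK = N / WarpTile_N. For example, with
    // N = 4096, K = 2048 and a warp tile of 32 in N: kFlatK = 65536 and kFlatN = 128.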
    template <typename KernelArgs>
    CK_TILE_DEVICE static auto MakeBFlatBlockWindow(const BDataType* b_flat_ptr,
                                                    const KernelArgs& kargs,
                                                    const index_t block_idx_n)
    {
        // Step 1: Create tensor view
        index_t kFlatK = kargs.K * BlockGemmShape::WarpTile::at(I1);
        index_t kFlatN = kargs.N * kargs.K / kFlatK;

        const auto& b_flat_tensor_view = make_naive_tensor_view<address_space_enum::global>(
            b_flat_ptr,
            make_tuple(kFlatN, kFlatK),
            make_tuple(kFlatK, 1),
            number<FlatmmPipeline::GetVectorSizeB()>{},
            number<1>{});

        // Step 2: No padding needed for b_flat
        // Step 3: Create tile window
        return make_tile_window(
            b_flat_tensor_view,
            make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
                       number<FlatmmPipeline::flatKPerWarp>{}),
            {static_cast<int>(block_idx_n / BlockGemmShape::WarpTile::at(I1)), 0});
    }

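    // Each auxiliary D tensor gets its own view/pad/window triple; generate_tuple evaluates the
    // per-tensor lambda for every index in [0, NumDTensor), so layout may differ per tensor.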
    template <typename KernelArgs>
    CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
                                                 const KernelArgs& kargs,
                                                 const index_t block_idx_m,
                                                 const index_t block_idx_n)
    {
        // Step 1: Create tensor views
        const auto& ds_tensor_view = generate_tuple(
            [&](auto i) {
                using DiLayout   = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
                using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
                if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
                {
                    return make_naive_tensor_view<address_space_enum::global>(
                        static_cast<const DDataType_*>(ds_ptr[i]),
                        make_tuple(kargs.M, kargs.N),
                        make_tuple(kargs.stride_Ds[i], 1),
                        number<EpiloguePipeline::GetVectorSizeD(i)>{},
                        number<1>{});
                }
                else
                {
                    return make_naive_tensor_view<address_space_enum::global>(
                        static_cast<const DDataType_*>(ds_ptr[i]),
                        make_tuple(kargs.N, kargs.M),
                        make_tuple(kargs.stride_Ds[i], 1),
                        number<EpiloguePipeline::GetVectorSizeD(i)>{},
                        number<1>{});
                }
            },
            number<NumDTensor>{});

        // Step 2: Create padded views
        const auto& ds_pad_view = generate_tuple(
            [&](auto i) {
                using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
                if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
                {
                    return pad_tensor_view(ds_tensor_view[i],
                                           make_tuple(number<TilePartitioner::MPerBlock>{},
                                                      number<TilePartitioner::NPerBlock>{}),
                                           sequence<false, FlatmmPipeline::kPadN>{});
                }
                else
                {
                    return pad_tensor_view(ds_tensor_view[i],
                                           make_tuple(number<TilePartitioner::NPerBlock>{},
                                                      number<TilePartitioner::MPerBlock>{}),
                                           sequence<false, FlatmmPipeline::kPadM>{});
                }
            },
            number<NumDTensor>{});

        // Step 3: Create tile windows
        return generate_tuple(
            [&](auto i) {
                using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
                if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
                {
                    return make_tile_window(ds_pad_view[i],
                                            make_tuple(number<TilePartitioner::MPerBlock>{},
                                                       number<TilePartitioner::NPerBlock>{}),
                                            {block_idx_m, block_idx_n});
                }
                else
                {
                    return make_tile_window(ds_pad_view[i],
                                            make_tuple(number<TilePartitioner::NPerBlock>{},
                                                       number<TilePartitioner::MPerBlock>{}),
                                            {block_idx_n, block_idx_m});
                }
            },
            number<NumDTensor>{});
    }

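    // The E (output) window is parameterized on the destination memory operation so the same
    // code path can either overwrite (set, k_batch == 1) or accumulate (atomic_add, split-K).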
    template <memory_operation_enum DstInMemOp = memory_operation_enum::set, typename KernelArgs>
    CK_TILE_DEVICE static auto MakeEBlockWindow(EDataType* e_ptr,
                                                const KernelArgs& kargs,
                                                const index_t block_idx_m,
                                                const index_t block_idx_n)
    {
        // Step 1: Create tensor view
        const auto& e_tensor_view = [&]() {
            if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
                    e_ptr,
                    make_tuple(kargs.M, kargs.N),
                    make_tuple(kargs.stride_E, 1),
                    number<EpiloguePipeline::GetVectorSizeC()>{},
                    number<1>{});
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
                    e_ptr,
                    make_tuple(kargs.N, kargs.M),
                    make_tuple(kargs.stride_E, 1),
                    number<1>{},
                    number<1>{});
            }
        }();

        // Step 2: Create padded view
        const auto& e_pad_view = [&]() {
            if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
            {
                return pad_tensor_view(e_tensor_view,
                                       make_tuple(number<TilePartitioner::MPerBlock>{},
                                                  number<TilePartitioner::NPerBlock>{}),
                                       sequence<false, FlatmmPipeline::kPadN>{});
            }
            else
            {
                return pad_tensor_view(e_tensor_view,
                                       make_tuple(number<TilePartitioner::NPerBlock>{},
                                                  number<TilePartitioner::MPerBlock>{}),
                                       sequence<false, FlatmmPipeline::kPadM>{});
            }
        }();

        // Step 3: Create tile window
        return make_tile_window(
            e_pad_view,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
            {block_idx_m, block_idx_n});
    }

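    // B scales are e8m0 (biased-exponent-only) microscaling factors, one scale per GranularityK
    // consecutive K elements. Like B itself they are stored pre-shuffled in a flat
    // (FlatScaleN, FlatScaleK) layout, with N_Pack adjacent warp-tile column groups packed
    // together along the K-contiguous axis.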
    template <typename KernelArgs>
    CK_TILE_DEVICE static auto MakeScaleBBlockWindow(const KernelArgs& kargs,
                                                     const index_t block_idx_n)
    {
        auto scale_n = kargs.scale_n_ptr;

        // Step 1: Create tensor view
        index_t FlatScaleK =
            (kargs.K / decltype(scale_n)::GranularityK) * N_Pack * BlockGemmShape::WarpTile::at(I1);
        index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);

        const auto scale_b_flat_view = make_naive_tensor_view<address_space_enum::global>(
            reinterpret_cast<const e8m0_t*>(scale_n.ptr),
            make_tuple(FlatScaleN, FlatScaleK),
            make_tuple(FlatScaleK, 1),
            number<8>{},
            number<1>{});

        // Step 2: Create tile window
        return make_tile_window(
            scale_b_flat_view,
            make_tuple(number<FlatmmPipeline::flatNPerWarp / N_Pack>{},
                       number<FlatmmPipeline::flatKPerWarp * N_Pack * 4 / 32>{}),
            {block_idx_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});
    }

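    // RunFlatmm is the per-tile body: it builds the block windows, runs the double-buffered
    // main loop over the ping/pong LDS buffers, then dispatches the epilogue with either a
    // plain store (k_batch == 1) or atomic accumulation (split-K).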
    template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
    CK_TILE_DEVICE static void
    RunFlatmm(const ADataType* a_ptr,
              const BDataType* b_flat_ptr,
              const std::array<const void*, NumDTensor>& ds_ptr,
              EDataType* e_ptr,
              void* smem_ptr_ping,
              void* smem_ptr_pong,
              const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs,
              const SplitKBatchOffset& splitk_batch_offset,
              const index_t block_idx_m,
              const index_t block_idx_n)
    {
        // Create block windows using specialized methods
        const auto& a_block_window =
            MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
        const auto& b_flat_block_window = MakeBFlatBlockWindow(b_flat_ptr, kargs, block_idx_n);
        const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
        const auto& scale_block_window = MakeScaleBBlockWindow(kargs, block_idx_n);

        const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);

        static_assert(ScaleM::GranularityK == ScaleN::GranularityK // have the same K granularity
                          || ScaleM::GranularityMN == -1           // or ScaleA is disabled
                          || ScaleN::GranularityMN == -1,          // or ScaleB is disabled
                      "ScaleM and ScaleN should have the same GranularityK");
        constexpr bool DoEpiScale =
            (ScaleM::GranularityMN != -1 && ScaleM::GranularityK == 0) || // per token
            (ScaleN::GranularityMN != -1 && ScaleN::GranularityK == 0);   // per channel

        // Run GEMM cooperatively by the whole workgroup.
        auto a_block_window_with_distr =
            ck_tile::make_tile_window(a_block_window.get_bottom_tensor_view(),
                                      a_block_window.get_window_lengths(),
                                      a_block_window.get_window_origin(),
                                      FlatmmPipeline::GetADramTileDistribution());
        const auto& c_block_tile = FlatmmPipeline{}(a_block_window_with_distr,
                                                    b_flat_block_window,
                                                    scale_block_window,
                                                    num_loop,
                                                    smem_ptr_ping,
                                                    smem_ptr_pong);

        // Run Epilogue Pipeline with k_batch dispatching
        if constexpr(DoEpiScale)
        {
            if(kargs.k_batch == 1)
            {
                auto e_block_window = MakeEBlockWindow<memory_operation_enum::set>(
                    e_ptr, kargs, block_idx_m, block_idx_n);
                EpiloguePipeline{}(e_block_window,
                                   c_block_tile,
                                   ds_block_window,
                                   smem_ptr_ping,
                                   kargs.scale_m_ptr + block_idx_m,
                                   kargs.scale_n_ptr + block_idx_n);
            }
            else
            {
                auto e_block_window = MakeEBlockWindow<memory_operation_enum::atomic_add>(
                    e_ptr, kargs, block_idx_m, block_idx_n);
                EpiloguePipeline{}(e_block_window,
                                   c_block_tile,
                                   ds_block_window,
                                   smem_ptr_ping,
                                   kargs.scale_m_ptr + block_idx_m,
                                   kargs.scale_n_ptr + block_idx_n);
            }
        }
        else if(UseDefaultScheduler || (get_warp_id() == 0))
        {
            if(kargs.k_batch == 1)
            {
                auto e_block_window = MakeEBlockWindow<memory_operation_enum::set>(
                    e_ptr, kargs, block_idx_m, block_idx_n);
                EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_window, smem_ptr_ping);
            }
            else
            {
                auto e_block_window = MakeEBlockWindow<memory_operation_enum::atomic_add>(
                    e_ptr, kargs, block_idx_m, block_idx_n);
                EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_window, smem_ptr_ping);
            }
        }
    }

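    // Kernel entry point. In persistent mode each workgroup strides through output tiles by
    // gridDim.x until all total_work_tile_cnt tiles are consumed; otherwise the loop body
    // executes exactly once for the tile selected by blockIdx.x.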
    template <class ScaleM, class ScaleN>
    CK_TILE_DEVICE void operator()(FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()> kargs,
                                   int partition_idx = blockIdx.x) const
    {
        int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);

        do
        {
            const auto [iM, iN] =
                TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
            const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
            const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);

            const SplitKBatchOffset splitk_batch_offset(kargs);
            // options
            const ADataType* a_ptr =
                static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
            const BDataType* b_flat_ptr = static_cast<const BDataType*>(kargs.b_ptr) +
                                          splitk_batch_offset.b_k_split_offset / QuantPackedSize;
            EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);

            // allocate LDS
            __shared__ char smem_ptr_ping[Underlying::GetSmemPingSize()];
            __shared__ char smem_ptr_pong[Underlying::GetSmemPongSize()];

            if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
                           is_any_of<CDataType, fp16_t, bf16_t>::value))
            {
                constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
                RunFlatmm<ScaleM, ScaleN, scheduler_type>(a_ptr,
                                                          b_flat_ptr,
                                                          kargs.ds_ptr,
                                                          e_ptr,
                                                          smem_ptr_ping,
                                                          smem_ptr_pong,
                                                          kargs,
                                                          splitk_batch_offset,
                                                          i_m,
                                                          i_n);
            }
            partition_idx += gridDim.x;
        } while(UsePersistentKernel && partition_idx < total_work_tile_cnt);
    }
};

} // namespace ck_tile
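A minimal host-side usage sketch follows. It is illustrative only: the concrete partitioner,
pipeline, epilogue, and scale-descriptor types (the My* names) and the argument filling are
hypothetical placeholders; only F16xMXF4FlatmmKernel, FlatmmKernelArgs, GridSize, BlockSize, and
the kentry launch wrapper come from this header and its flatmm_kernel.hpp base.

// Hypothetical concrete instantiation; MyTilePartitioner, MyFlatmmPipeline, MyEpiloguePipeline,
// MyScaleM, and MyScaleN stand in for real pipeline/policy types chosen by the caller.
using Kernel = ck_tile::F16xMXF4FlatmmKernel<MyTilePartitioner,
                                             MyFlatmmPipeline,
                                             MyEpiloguePipeline>;
using Args   = ck_tile::FlatmmKernelArgs<MyScaleM, MyScaleN, /*NumDTensor=*/0>;

Args kargs{}; // fill M/N/K, k_batch, a_ptr, pre-shuffled flat b_ptr, e_ptr, strides,
              // and the e8m0 scale pointers before launching

const dim3 grids  = Kernel::GridSize(kargs); // occupancy-bounded when UsePersistentKernel
const dim3 blocks = Kernel::BlockSize();
ck_tile::kentry<1, Kernel, Args><<<grids, blocks, 0, /*stream=*/0>>>(kargs);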