/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp Source File
smfmac_xdlops_gemm.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 #include "ck/utility/math.hpp"
9 
10 namespace ck {
11 
// SmfmacInstr: enumerates the smfmac instructions that smfmac_type<> is
// specialized for. NOTE(review): the enumerator lines (original lines 14-17)
// are missing from this rendering. Four explicit smfmac_type<> specializations
// follow, which suggests four enumerators -- confirm against the real header.
12 enum struct SmfmacInstr
13 {
18 };
19 
// Primary template is only declared, never defined. Every usable SmfmacInstr
// value needs an explicit specialization below; any other value leaves
// smfmac_type<> an incomplete type, so using its members fails to compile.
20 template <SmfmacInstr instr>
21 struct smfmac_type;
22 
// smfmac_type specialization: 16x16 output block (m_per_blk x n_per_blk),
// k_per_blk = 8, num_input_blks = 4 with is_k_reduction = true, so
// SmfmacSelector::GetKPerXdlops() evaluates to 4 * 8 = 32 for it.
// num_regs_per_blk = group_size * num_groups_per_blk = 4 * 1 = 4, and
// num_regs_per_blk * wave_size = 4 * 64 = 16 * 16, i.e. presumably 4 C values
// per lane of the 64-lane wave.
// NOTE(review): the line naming the SmfmacInstr enumerator (original line 24)
// is missing from this rendering, so the input data type is not visible here.
23 template <>
25 {
26  static constexpr index_t group_size = 4;
27  static constexpr index_t num_groups_per_blk = 1;
28  static constexpr index_t num_regs_per_blk = 4;
29  static constexpr index_t num_threads_per_blk = 16;
30  static constexpr index_t wave_size = 64;
31  static constexpr index_t num_input_blks = 4;
32  static constexpr index_t num_output_blks = 1;
33  static constexpr index_t m_per_blk = 16;
34  static constexpr index_t n_per_blk = 16;
35  static constexpr index_t k_per_blk = 8;
36  static constexpr bool is_k_reduction = true;
37 
// Issues one smfmac for this specialization by forwarding (a, b, idx, reg_c).
// NOTE(review): the head of the forwarded call (original line 46) is missing
// from this rendering. idx presumably carries the sparsity-index bits and reg_c
// the C accumulator registers -- TODO confirm against the intrinsic wrapper.
// MPerXdlops, NPerXdlops and idx_part are not referenced on the visible lines.
38  template <index_t MPerXdlops,
39  index_t NPerXdlops,
40  index_t idx_part,
41  class FloatA,
42  class FloatB,
43  class FloatC>
44  __device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const
45  {
47  a, b, idx, reg_c);
48  }
49 };
50 
// smfmac_type specialization: 32x32 output block (m_per_blk x n_per_blk),
// k_per_blk = 16, num_input_blks = 2 with is_k_reduction = true, so
// SmfmacSelector::GetKPerXdlops() evaluates to 2 * 16 = 32 for it.
// num_regs_per_blk = group_size * num_groups_per_blk = 4 * 4 = 16, and
// num_regs_per_blk * wave_size = 16 * 64 = 32 * 32, i.e. presumably 16 C values
// per lane of the 64-lane wave.
// NOTE(review): the line naming the SmfmacInstr enumerator (original line 52)
// is missing from this rendering, so the input data type is not visible here.
51 template <>
53 {
54  static constexpr index_t group_size = 4;
55  static constexpr index_t num_groups_per_blk = 4;
56  static constexpr index_t num_regs_per_blk = 16;
57  static constexpr index_t num_threads_per_blk = 32;
58  static constexpr index_t wave_size = 64;
59  static constexpr index_t num_input_blks = 2;
60  static constexpr index_t num_output_blks = 1;
61  static constexpr index_t m_per_blk = 32;
62  static constexpr index_t n_per_blk = 32;
63  static constexpr index_t k_per_blk = 16;
64  static constexpr bool is_k_reduction = true;
65 
// Issues one smfmac for this specialization by forwarding (a, b, idx, reg_c).
// NOTE(review): the head of the forwarded call (original line 74) is missing
// from this rendering. idx presumably carries the sparsity-index bits and reg_c
// the C accumulator registers -- TODO confirm against the intrinsic wrapper.
66  template <index_t MPerXdlops,
67  index_t NPerXdlops,
68  index_t idx_part,
69  class FloatA,
70  class FloatB,
71  class FloatC>
72  __device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const
73  {
75  a, b, idx, reg_c);
76  }
77 };
78 
// smfmac_type specialization: 16x16 output block, k_per_blk = 8,
// num_input_blks = 4 (K-reduced), 4 registers per block over a 64-lane wave.
// Its layout constants are identical to the other 16x16 specialization in this
// file; presumably it differs only in the input data type (e.g. bf16 vs f16).
// NOTE(review): the specialization's enumerator line (original line 80) is
// missing from this rendering -- confirm the data type against the header.
79 template <>
81 {
82  static constexpr index_t group_size = 4;
83  static constexpr index_t num_groups_per_blk = 1;
84  static constexpr index_t num_regs_per_blk = 4;
85  static constexpr index_t num_threads_per_blk = 16;
86  static constexpr index_t wave_size = 64;
87  static constexpr index_t num_input_blks = 4;
88  static constexpr index_t num_output_blks = 1;
89  static constexpr index_t m_per_blk = 16;
90  static constexpr index_t n_per_blk = 16;
91  static constexpr index_t k_per_blk = 8;
92  static constexpr bool is_k_reduction = true;
93 
// Issues one smfmac for this specialization by forwarding (a, b, idx, reg_c).
// NOTE(review): the head of the forwarded call (original line 102) is missing
// from this rendering -- TODO confirm which intrinsic wrapper it invokes.
94  template <index_t MPerXdlops,
95  index_t NPerXdlops,
96  index_t idx_part,
97  class FloatA,
98  class FloatB,
99  class FloatC>
100  __device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const
101  {
103  a, b, idx, reg_c);
104  }
105 };
106 
// smfmac_type specialization: 32x32 output block, k_per_blk = 16,
// num_input_blks = 2 (K-reduced), 16 registers per block over a 64-lane wave.
// Its layout constants are identical to the other 32x32 specialization in this
// file; presumably it differs only in the input data type (e.g. bf16 vs f16).
// NOTE(review): the specialization's enumerator line (original line 108) is
// missing from this rendering -- confirm the data type against the header.
107 template <>
109 {
110  static constexpr index_t group_size = 4;
111  static constexpr index_t num_groups_per_blk = 4;
112  static constexpr index_t num_regs_per_blk = 16;
113  static constexpr index_t num_threads_per_blk = 32;
114  static constexpr index_t wave_size = 64;
115  static constexpr index_t num_input_blks = 2;
116  static constexpr index_t num_output_blks = 1;
117  static constexpr index_t m_per_blk = 32;
118  static constexpr index_t n_per_blk = 32;
119  static constexpr index_t k_per_blk = 16;
120  static constexpr bool is_k_reduction = true;
121 
// Issues one smfmac for this specialization by forwarding (a, b, idx, reg_c).
// NOTE(review): the head of the forwarded call (original line 130) is missing
// from this rendering -- TODO confirm which intrinsic wrapper it invokes.
122  template <index_t MPerXdlops,
123  index_t NPerXdlops,
124  index_t idx_part,
125  class FloatA,
126  class FloatB,
127  class FloatC>
128  __device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const
129  {
131  a, b, idx, reg_c);
132  }
133 };
134 
// Compile-time selector that picks the smfmac instruction layout for
// (base_type, MPerXdlops, NPerXdlops, additional_type) through GetSmfmac<>.
// Only four GetSmfmac specializations are visible: half_t and bhalf_t, each at
// 16x16 and 32x32. The primary GetSmfmac is declared but never defined, so any
// other combination fails to compile.
// NOTE(review): the struct-name line (original 139), the GetSmfmac bodies
// (150, 156, 162, 168) and the selected_smfmac initializer (172) are missing
// from this rendering; they are not reconstructed here.
135 template <typename base_type,
136  index_t MPerXdlops,
137  index_t NPerXdlops,
138  typename additional_type = base_type>
140 {
141  template <typename base_type_,
142  index_t MPerXdlops_,
143  index_t NPerXdlops_,
144  typename additional_type_ = base_type_>
145  static constexpr auto GetSmfmac();
146 
// Explicit specializations: each presumably returns the matching smfmac_type<>
// object (return lines missing). additional_type_ is left at its default
// (== base_type_) in all four.
147  template <>
148  static constexpr auto GetSmfmac<half_t, 16, 16>()
149  {
151  }
152 
153  template <>
154  static constexpr auto GetSmfmac<half_t, 32, 32>()
155  {
157  }
158 
159  template <>
160  static constexpr auto GetSmfmac<bhalf_t, 16, 16>()
161  {
163  }
164 
165  template <>
166  static constexpr auto GetSmfmac<bhalf_t, 32, 32>()
167  {
169  }
170 
// The chosen instruction descriptor. Its initializer (original line 172) is
// missing -- presumably a GetSmfmac<...>() call on this struct's template
// arguments; confirm.
171  static constexpr auto selected_smfmac =
173 
// All checks are static_asserts, so this constructor does no runtime work. They
// verify that the selected instruction's layout constants are self-consistent:
// register count per block, N block width vs threads per block, M block height
// vs input blocks, output block count, and total registers vs tile size.
174  __host__ __device__ constexpr SmfmacSelector()
175  {
176  static_assert(selected_smfmac.group_size * selected_smfmac.num_groups_per_blk ==
177  selected_smfmac.num_regs_per_blk,
178  "wrong! num_regs_per_blk");
179 
180  static_assert(selected_smfmac.num_threads_per_blk == selected_smfmac.n_per_blk,
181  "n_per_blk != num_threads_per_blk");
182 
183  static_assert(selected_smfmac.num_regs_per_blk * selected_smfmac.num_input_blks ==
184  selected_smfmac.m_per_blk,
185  "m_per_blk != num_input_blks * num_regs_per_blk");
186 
187  static_assert(selected_smfmac.num_output_blks == selected_smfmac.num_input_blks ||
188  selected_smfmac.num_output_blks == 1,
189  "incorrect num_output_blks");
190 
191  static_assert(selected_smfmac.num_regs_per_blk * selected_smfmac.wave_size ==
192  selected_smfmac.m_per_blk * selected_smfmac.n_per_blk,
193  "num_regs_per_blk incorrect");
194 
195  static_assert(selected_smfmac.is_k_reduction ||
196  (selected_smfmac.num_input_blks == selected_smfmac.num_output_blks),
197  "is_k_reduction wrong!");
198  }
199 
// K per instruction: num_input_blks * k_per_blk when the input blocks are
// reduced along K (is_k_reduction), otherwise just k_per_blk.
200  static constexpr index_t GetKPerXdlops()
201  {
202  return (selected_smfmac.is_k_reduction ? selected_smfmac.num_input_blks : 1) *
203  selected_smfmac.k_per_blk;
204  }
205 
// K per input block (k_per_blk) of the selected instruction.
206  static constexpr index_t GetK1PerXdlops() { return selected_smfmac.k_per_blk; }
207 };
208 
// Per-wave wrapper around one smfmac instruction. The instruction is chosen by
// the `smfmac` member (presumably an SmfmacSelector instance; its initializer
// line is missing). This struct provides:
//   - C-tile descriptor transforms,
//   - the lane -> (block id, thread-in-block) mapping,
//   - Run(), which issues KPack / k_per_blk smfmac calls.
// NOTE(review): this is a rendered listing with some original source lines
// missing (e.g. 214, 223-224, 256-262, 290-297, 328, 344-347, 418, 429). Those
// lines are not reconstructed here.
209 template <typename base_type,
210  index_t MPerXdlops,
211  index_t NPerXdlops,
212  index_t KPack,
213  typename additional_type = base_type>
215 {
216  static constexpr auto I0 = Number<0>{};
217  static constexpr auto I1 = Number<1>{};
218  static constexpr auto I2 = Number<2>{};
219  static constexpr auto I3 = Number<3>{};
220  static constexpr auto I4 = Number<4>{};
221  static constexpr auto I5 = Number<5>{};
222 
225 
// Number of output blocks per instruction (num_output_blks; 1 for every
// visible smfmac_type specialization).
226  __device__ static constexpr index_t GetNumBlks() { return smfmac_instr.num_output_blks; }
227 
// Number of instruction output tiles needed to cover MPerXdlops x NPerXdlops.
// With the visible selector combinations (16x16 and 32x32, num_output_blks == 1)
// this evaluates to 1.
228  __device__ static constexpr index_t GetNumXdlops()
229  {
230  return MPerXdlops * NPerXdlops /
231  (smfmac_instr.m_per_blk * smfmac_instr.n_per_blk * smfmac_instr.num_output_blks);
232  }
233 
// Compile-time checks only: M and N per xdlops must each be 16 or 32, and KPack
// must be a multiple of the instruction's k_per_blk (Run() loops over
// KPack / k_per_blk steps).
234  __host__ __device__ constexpr SparseXdlopsGemm()
235  {
236  static_assert(NPerXdlops == 16 || NPerXdlops == 32,
237  "Only support GemmNPerXdlops == 16 or 32 for smfmac xdlops");
238 
239  static_assert(MPerXdlops == 16 || MPerXdlops == 32,
240  "Only support GemmMPerXdlops == 16 or 32 for smfmac xdlops");
241 
242  static_assert(KPack % smfmac_instr.k_per_blk == 0, "KPack cannot be divided by k_per_blk");
243  }
244 
// Reads lengths M0, N0, M1, N1 (dims 0-3) of the input descriptor. It then
// splits the per-xdlops M2 dimension into M2_M3_M4, producing an 8-D descriptor
// (upper ids up to Sequence<7>). The visible unmerge lengths include
// num_input_blks and group_size; the leading factor sits on a missing line and
// is presumably num_groups_per_blk -- confirm.
245  // XDL output supporting C = A * B
246  // M2_N2 -> M2_M3_M4_N2
247  template <typename CDesc_M0_N0_M1_N1_M2_N2>
248  __host__ __device__ static constexpr auto
249  MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
250  {
251  const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
252  const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
253  const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
254  const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
255 
257  c_desc_m0_n0_m1_n1_m2_n2,
263  Number<smfmac_instr.num_input_blks>{},
264  Number<smfmac_instr.group_size>{})),
267  Sequence<1>{},
268  Sequence<2>{},
269  Sequence<3>{},
270  Sequence<4>{},
271  Sequence<5>{}),
273  Sequence<1>{},
274  Sequence<2>{},
275  Sequence<3>{},
277  Sequence<7>{}));
278  }
279 
// Batched variant of the transform above: same split with a leading G
// dimension (read from dim 0), producing a 9-D descriptor (upper ids up to
// Sequence<8>). N2 is kept as a pass-through of length num_threads_per_blk.
280  template <typename CDesc_G_M0_N0_M1_N1_M2_N2>
281  __host__ __device__ static constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
282  const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2)
283  {
284  const auto G = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I0);
285  const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I1);
286  const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I2);
287  const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I3);
288  const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I4);
289 
291  c_desc_g_m0_n0_m1_n1_m2_n2,
298  smfmac_instr.num_input_blks,
299  smfmac_instr.group_size)),
300  make_pass_through_transform(smfmac_instr.num_threads_per_blk)),
302  Sequence<1>{},
303  Sequence<2>{},
304  Sequence<3>{},
305  Sequence<4>{},
306  Sequence<5>{},
307  Sequence<6>{}),
309  Sequence<1>{},
310  Sequence<2>{},
311  Sequence<3>{},
312  Sequence<4>{},
314  Sequence<8>{}));
315  }
316 
// C elements per lane for one MPerXdlops x NPerXdlops tile. This is 16*16/64 = 4
// or 32*32/64 = 16, matching num_regs_per_blk of the corresponding instruction.
317  __device__ static constexpr index_t GetRegSizePerXdlops()
318  {
319  return MPerXdlops * NPerXdlops / smfmac_instr.wave_size;
320  }
321 
322  __device__ static constexpr index_t GetWaveSize() { return smfmac_instr.wave_size; }
323 
// Issues KPack / k_per_blk smfmac calls. Each step k is a compile-time constant,
// since `k % 4` is used as a template argument. For step k the call receives
// p_a_wave[k], p_b_wave[k], idx[k / 4] and idx_part = k % 4, so each idx
// element is shared by four consecutive steps.
// NOTE(review): presumably each idx element packs index bits for four
// instructions -- confirm. The static_assert's opening line (original 328) is
// missing here. Its message reads "base base_type ..." (duplicated word) and
// should be fixed in the real header.
324  template <class FloatA, class FloatB, class Idx, class FloatC>
325  __device__ void
326  Run(const FloatA& p_a_wave, const FloatB& p_b_wave, const Idx& idx, FloatC& p_c_thread) const
327  {
329  "base base_type must be half or bfloat16!");
330 
331  static_for<0, KPack / smfmac_instr.k_per_blk, 1>{}([&](auto k) {
332  smfmac_instr.template run<MPerXdlops, NPerXdlops, k % 4>(
333  p_a_wave[k], p_b_wave[k], idx[k / 4], p_c_thread);
334  });
335  }
336 
// Lane index within the wave: the 1-D thread id modulo wave_size (64).
337  __device__ static auto GetLaneId() { return get_thread_local_1d_id() % smfmac_instr.wave_size; }
338 
// Maps the lane id through an adaptor with lengths
// (1, num_input_blks, num_threads_per_blk) and returns
// (blk_id, blk_td) = (component 1, component 2).
// Lane coverage is 4 x 16 = 64 or 2 x 32 = 64 lanes for the visible configs.
// The transform kind is on a missing line (344); presumably a merge, i.e.
// laneId == blk_id * num_threads_per_blk + blk_td -- confirm.
339  __device__ static auto GetBlkIdx()
340  {
341  const auto laneId = GetLaneId();
342 
343  constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor(
345  make_tuple(1, smfmac_instr.num_input_blks, smfmac_instr.num_threads_per_blk))),
348 
349  const auto blk_idx =
350  threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId));
351 
352  const auto blk_id = blk_idx[I1];
353  const auto blk_td = blk_idx[I2];
354 
355  return make_tuple(blk_id, blk_td);
356  }
357 
// This lane's A origin: (blk_id, blk_td) when is_k_reduction (true for every
// visible specialization), otherwise (0, laneId).
// NOTE(review): declared __host__ __device__ but calls the __device__-only
// GetLaneId()/GetBlkIdx(), so it is effectively device-only. Its body is
// identical to CalculateBThreadOriginDataIndex.
358  __host__ __device__ static auto CalculateAThreadOriginDataIndex()
359  {
360  const auto laneId = GetLaneId();
361  const auto blk_idx = GetBlkIdx();
362 
363  const auto blk_id = blk_idx[I0];
364  const auto blk_td = blk_idx[I1];
365 
366  if constexpr(smfmac_instr.is_k_reduction)
367  {
368  return make_tuple(blk_id, blk_td);
369  }
370  else
371  {
372  return make_tuple(0, laneId);
373  }
374  }
375 
// This lane's B origin: same mapping as the A origin above.
// NOTE(review): also __host__ __device__ while calling __device__-only helpers.
376  __host__ __device__ static auto CalculateBThreadOriginDataIndex()
377  {
378  const auto laneId = GetLaneId();
379  const auto blk_idx = GetBlkIdx();
380 
381  const auto blk_id = blk_idx[I0];
382  const auto blk_td = blk_idx[I1];
383 
384  if constexpr(smfmac_instr.is_k_reduction)
385  {
386  return make_tuple(blk_id, blk_td);
387  }
388  else
389  {
390  return make_tuple(0, laneId);
391  }
392  }
393 
// (m, n) offset of this lane's first C element for xdlops tile xdlops_i and
// output block blk_i:
//   m = xdlops_i * m_per_blk + blk_id * group_size
//   n = blk_i * n_per_blk + blk_td
394  __device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
395  {
396  const auto blk_idx = GetBlkIdx();
397 
398  const auto blk_id = blk_idx[I0];
399  const auto blk_td = blk_idx[I1];
400 
401  index_t n_offset = blk_i * smfmac_instr.n_per_blk + blk_td;
402  index_t m_offset = xdlops_i * smfmac_instr.m_per_blk + blk_id * smfmac_instr.group_size;
403 
404  return CIndex{m_offset, n_offset};
405  }
406 
// 4-D variant of the offset above: returns {0, blk_id, 0, blk_td}. Both
// parameters are ignored.
407  __device__ static CIndex4D GetBeginOfThreadBlk4D(index_t /* xdlops_i */, index_t /* blk_i */)
408  {
409  const auto blk_idx = GetBlkIdx();
410 
411  const auto blk_id = blk_idx[I0];
412  const auto blk_td = blk_idx[I1];
413 
414  return CIndex4D{I0, blk_id, I0, blk_td};
415  }
416 
// Instruction selection. The initializer (original line 418) is missing;
// smfmac is used below through .selected_smfmac, .GetKPerXdlops() and
// .GetK1PerXdlops().
417  static constexpr auto smfmac =
419 
420  static constexpr auto smfmac_instr = smfmac.selected_smfmac;
421 
// K0PerXdlops = KPerXdlops / K1PerXdlops, which equals num_input_blks when
// is_k_reduction is true (see SmfmacSelector::GetKPerXdlops).
422  static constexpr auto KPerXdlops = smfmac.GetKPerXdlops();
423  static constexpr auto K1PerXdlops = smfmac.GetK1PerXdlops();
424  static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;
425 
// NOTE(review): the tuple contents (original line 429) are missing from this
// rendering.
426  __host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths()
427  {
428  return make_tuple(
430  }
431 };
432 
433 } // namespace ck
Definition: ck.hpp:267
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:298
SmfmacInstr
Definition: smfmac_xdlops_gemm.hpp:13
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:52
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1249
Definition: array.hpp:14
Definition: sequence.hpp:43
Definition: smfmac_xdlops_gemm.hpp:140
static constexpr index_t GetKPerXdlops()
Definition: smfmac_xdlops_gemm.hpp:200
__host__ constexpr __device__ SmfmacSelector()
Definition: smfmac_xdlops_gemm.hpp:174
static constexpr auto GetSmfmac()
static constexpr auto selected_smfmac
Definition: smfmac_xdlops_gemm.hpp:171
static constexpr index_t GetK1PerXdlops()
Definition: smfmac_xdlops_gemm.hpp:206
Definition: smfmac_xdlops_gemm.hpp:215
static __device__ auto GetLaneId()
Definition: smfmac_xdlops_gemm.hpp:337
static constexpr __device__ index_t GetRegSizePerXdlops()
Definition: smfmac_xdlops_gemm.hpp:317
static constexpr auto K0PerXdlops
Definition: smfmac_xdlops_gemm.hpp:424
static constexpr auto I2
Definition: smfmac_xdlops_gemm.hpp:218
static __device__ CIndex4D GetBeginOfThreadBlk4D(index_t, index_t)
Definition: smfmac_xdlops_gemm.hpp:407
static constexpr auto I0
Definition: smfmac_xdlops_gemm.hpp:216
static constexpr auto I5
Definition: smfmac_xdlops_gemm.hpp:221
__device__ void Run(const FloatA &p_a_wave, const FloatB &p_b_wave, const Idx &idx, FloatC &p_c_thread) const
Definition: smfmac_xdlops_gemm.hpp:326
static constexpr auto K1PerXdlops
Definition: smfmac_xdlops_gemm.hpp:423
static constexpr __device__ index_t GetWaveSize()
Definition: smfmac_xdlops_gemm.hpp:322
static constexpr auto I1
Definition: smfmac_xdlops_gemm.hpp:217
static constexpr __device__ index_t GetNumBlks()
Definition: smfmac_xdlops_gemm.hpp:226
static __device__ auto GetBlkIdx()
Definition: smfmac_xdlops_gemm.hpp:339
__host__ static constexpr __device__ auto GetCM0M1M2NThreadBlkLengths()
Definition: smfmac_xdlops_gemm.hpp:426
static constexpr auto I4
Definition: smfmac_xdlops_gemm.hpp:220
static __device__ CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
Definition: smfmac_xdlops_gemm.hpp:394
__host__ static __device__ auto CalculateAThreadOriginDataIndex()
Definition: smfmac_xdlops_gemm.hpp:358
static constexpr auto smfmac_instr
Definition: smfmac_xdlops_gemm.hpp:420
static constexpr auto smfmac
Definition: smfmac_xdlops_gemm.hpp:417
__host__ static constexpr __device__ auto MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
Definition: smfmac_xdlops_gemm.hpp:249
static constexpr __device__ index_t GetNumXdlops()
Definition: smfmac_xdlops_gemm.hpp:228
__host__ constexpr __device__ SparseXdlopsGemm()
Definition: smfmac_xdlops_gemm.hpp:234
__host__ static constexpr __device__ auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_G_M0_N0_M1_N1_M2_N2 &c_desc_g_m0_n0_m1_n1_m2_n2)
Definition: smfmac_xdlops_gemm.hpp:281
static constexpr auto KPerXdlops
Definition: smfmac_xdlops_gemm.hpp:422
__host__ static __device__ auto CalculateBThreadOriginDataIndex()
Definition: smfmac_xdlops_gemm.hpp:376
static constexpr auto I3
Definition: smfmac_xdlops_gemm.hpp:219
Definition: integral_constant.hpp:20
Definition: amd_smfmac.hpp:34
Definition: amd_smfmac.hpp:10
Definition: amd_smfmac.hpp:78
Definition: amd_smfmac.hpp:56
Definition: type.hpp:177
__device__ void run(const FloatA &a, const FloatB &b, const index_t &idx, FloatC &reg_c) const
Definition: smfmac_xdlops_gemm.hpp:100
__device__ void run(const FloatA &a, const FloatB &b, const index_t &idx, FloatC &reg_c) const
Definition: smfmac_xdlops_gemm.hpp:44
__device__ void run(const FloatA &a, const FloatB &b, const index_t &idx, FloatC &reg_c) const
Definition: smfmac_xdlops_gemm.hpp:128
__device__ void run(const FloatA &a, const FloatB &b, const index_t &idx, FloatC &reg_c) const
Definition: smfmac_xdlops_gemm.hpp:72
Definition: smfmac_xdlops_gemm.hpp:21
Definition: functional2.hpp:33