Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp Source File

gridwise_gemm_xdlops_v2r3.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
17 
18 namespace ck {
19 
20 template <typename GridwiseGemm,
21  typename FloatAB,
22  typename FloatC,
23  typename AGridDesc_K0_M_K1,
24  typename BGridDesc_K0_N_K1,
25  typename CGridDesc_M_N,
26  bool HasMainKBlockLoop>
27 __global__ void
28 #if CK_USE_LAUNCH_BOUNDS
29     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
30 #endif
31 #if CK_USE_WAVES_PER_EU
32  __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU)))
33 #endif
34  kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
35  const FloatAB* __restrict__ p_b_grid,
36  FloatC* __restrict__ p_c_grid,
37  const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
38  const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
39  const CGridDesc_M_N c_grid_desc_m_n)
40 {
41 #if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__))
42  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
43 
44  GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
45  p_b_grid,
46  p_c_grid,
47  p_shared,
48  a_grid_desc_k0_m_k1,
49  b_grid_desc_k0_n_k1,
50  c_grid_desc_m_n);
51 #else
52  ignore = p_a_grid;
53  ignore = p_b_grid;
54  ignore = p_c_grid;
55  ignore = a_grid_desc_k0_m_k1;
56  ignore = b_grid_desc_k0_n_k1;
57  ignore = c_grid_desc_m_n;
58 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__))
59 }
60 
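// The overload below is a second device entry point: instead of taking prebuilt grid
// descriptors, it receives a single packed GridwiseGemm::Argument and rebuilds the A/B/C
// descriptors on the device. amd_wave_read_first_lane keeps the rebuilt descriptors uniform
// across the wavefront so they can stay in scalar registers rather than per-lane VGPRs.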
61 template <typename GridwiseGemm, bool HasMainKBlockLoop>
62 __global__ void
63 #if CK_USE_LAUNCH_BOUNDS
64     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
65 #endif
66 #if CK_USE_WAVES_PER_EU
67  __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU)))
68 #endif
69  kernel_gemm_xdlops_v2r3(const typename GridwiseGemm::Argument karg)
70 {
71 #if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__))
72  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
73 
74  const auto a_grid_desc_k0_m_k1 =
75  amd_wave_read_first_lane(GridwiseGemm::MakeAGridDescriptor_K0_M_K1(
76  karg.M, karg.MPadded, karg.K, karg.K0, karg.StrideA));
77  const auto b_grid_desc_k0_n_k1 =
78  amd_wave_read_first_lane(GridwiseGemm::MakeBGridDescriptor_K0_N_K1(
79  karg.K, karg.N, karg.NPadded, karg.K0, karg.StrideB));
80  const auto c_grid_desc_m_n = amd_wave_read_first_lane(GridwiseGemm::MakeCGridDescriptor_M_N(
81  karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC));
82 
83  GridwiseGemm::template Run<HasMainKBlockLoop>(karg.p_a_grid,
84  karg.p_b_grid,
85  karg.p_c_grid,
86  p_shared,
87  a_grid_desc_k0_m_k1,
88  b_grid_desc_k0_n_k1,
89  c_grid_desc_m_n);
90 #else
91  ignore = karg;
92 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__))
93 }
94 
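// The gridwise GEMM class below owns everything the kernels above delegate to: it derives the
// LDS block descriptors for the A and B tiles from the tuning parameters, reports the LDS
// footprint, validates a problem against the tile sizes and the selected pipeline, and its
// Run() performs the blockwise global->LDS copies, the xdlops blockwise GEMM, and the final
// register->global write of the C tile.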
95 template <index_t BlockSize,
96  typename FloatAB,
97  typename FloatAcc,
98  typename FloatC,
99  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
100  typename AElementwiseOperation,
101  typename BElementwiseOperation,
102  typename CElementwiseOperation,
103  index_t MPerBlock,
104  index_t NPerBlock,
105  index_t K0PerBlock,
106  index_t MPerXDL,
107  index_t NPerXDL,
108  index_t K1Value,
109  index_t MXdlPerWave,
110  index_t NXdlPerWave,
111  typename ABlockTransferThreadClusterLengths_K0_M_K1,
112  typename ABlockTransferThreadClusterArrangeOrder,
113  typename ABlockTransferSrcAccessOrder,
114  index_t ABlockTransferSrcVectorDim,
115  index_t ABlockTransferSrcScalarPerVector,
116  index_t ABlockTransferDstScalarPerVector_K1,
117  bool AThreadTransferSrcResetCoordinateAfterRun,
118  bool ABlockLdsExtraM,
119  typename BBlockTransferThreadClusterLengths_K0_N_K1,
120  typename BBlockTransferThreadClusterArrangeOrder,
121  typename BBlockTransferSrcAccessOrder,
122  index_t BBlockTransferSrcVectorDim,
123  index_t BBlockTransferSrcScalarPerVector,
124  index_t BBlockTransferDstScalarPerVector_K1,
125  bool BThreadTransferSrcResetCoordinateAfterRun,
126  bool BBlockLdsExtraN,
127  typename CThreadTransferSrcDstAccessOrder,
128  index_t CThreadTransferSrcDstVectorDim,
129  index_t CThreadTransferDstScalarPerVector,
130  index_t NumGemmKPrefetchStage = 1,
131           LoopScheduler LoopSched = make_default_loop_scheduler(),
132           PipelineVersion PipelineVer = PipelineVersion::v1>
134 {
135  static constexpr auto I0 = Number<0>{};
136  static constexpr auto I1 = Number<1>{};
137  static constexpr auto I2 = Number<2>{};
138  static constexpr auto I3 = Number<3>{};
139  static constexpr auto I4 = Number<4>{};
140  static constexpr auto I5 = Number<5>{};
141  static constexpr auto I6 = Number<6>{};
142  static constexpr auto I7 = Number<7>{};
143 
144  // K1 should be Number<...>
145  static constexpr auto K1 = Number<K1Value>{};
146 
147     using ThisThreadBlock = ThisThreadBlock<BlockSize>;
148 
149  __host__ static auto CalculateGridSize(index_t M, index_t N)
150  {
151  return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1);
152  }
153 
154  template <typename CGridDesc_M_N>
155  __host__ static auto CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
156  {
157  return std::make_tuple(Block2CTileMap::CalculateGridSize(c_grid_desc_m_n), 1, 1);
158  }
159 
160  template <typename>
161  __host__ static auto CalculateGridSize(index_t M, index_t N)
162  {
163  return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1);
164  }
165 
166  __host__ static auto CalculateMPadded(index_t M)
167  {
168  return math::integer_divide_ceil(M, MPerBlock) * MPerBlock;
169  }
170 
171  __host__ static auto CalculateNPadded(index_t N)
172  {
173  return math::integer_divide_ceil(N, NPerBlock) * NPerBlock;
174  }
175 
176  __host__ static auto CalculateK0(index_t K) { return math::integer_divide_ceil(K, K1Value); }
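// Illustrative arithmetic (example values, not from this file): with MPerBlock = 256, M = 1000
// gives MPadded = integer_divide_ceil(1000, 256) * 256 = 4 * 256 = 1024; with K1Value = 8,
// K = 4096 gives K0 = integer_divide_ceil(4096, 8) = 512.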
177 
178  // Argument
179  struct Problem
180  {
181  __host__ Problem(index_t M_,
182  index_t N_,
183  index_t K_,
184  index_t StrideA_,
185  index_t StrideB_,
186  index_t StrideC_)
187  : M{M_},
188  N{N_},
189  K{K_},
190  StrideA{StrideA_},
191  StrideB{StrideB_},
192  StrideC{StrideC_},
193           MPadded{CalculateMPadded(M_)},
194           NPadded{CalculateNPadded(N_)},
195           K0{CalculateK0(K_)}
196  {
197  }
198 
199  __host__ void Print() const
200  {
201  std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
202  << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
203  << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " << "K0:" << K0
204  << "}" << std::endl;
205  }
206 
206 
207         index_t M;
208         index_t N;
209         index_t K;
210         index_t StrideA;
211         index_t StrideB;
212         index_t StrideC;
213         index_t MPadded;
214         index_t NPadded;
215         index_t K0;
216     };
217 
218  // Argument
219     struct Argument : public Problem
220     {
221  __host__ Argument(const FloatAB* p_a_grid_,
222  const FloatAB* p_b_grid_,
223  FloatC* p_c_grid_,
224  index_t M_,
225  index_t N_,
226  index_t K_,
227  index_t StrideA_,
228  index_t StrideB_,
229  index_t StrideC_)
230  : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
231  p_a_grid{p_a_grid_},
232  p_b_grid{p_b_grid_},
233  p_c_grid{p_c_grid_}
234  {
235  }
236 
237  const FloatAB* p_a_grid;
238  const FloatAB* p_b_grid;
239  FloatC* p_c_grid;
240  };
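// Typical host-side flow (illustrative sketch, not part of this file): build an Argument from
// the device pointers and problem sizes, validate it, then launch the Argument-based kernel:
//   auto karg = GridwiseGemm::Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
//   if(GridwiseGemm::CheckValidity(karg)) { /* launch kernel_gemm_xdlops_v2r3<GridwiseGemm, ...>(karg) */ }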
241 
242     using GridwiseGemmPipe = remove_cvref_t<
243         decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
244 
245  // denorm test fix, required to work around fp16 mfma issue
246  // we convert fp16->fp32->bf16 and execute bf16 mfma instruction
247  // when the mfma issue is fixed, remove this section and update
248  // FloatABAdjusted -> FloatAB throughout this file
249 #if CK_GFX90A_DENORM_WORKAROUND
251 #else
252  using FloatABAdjusted = FloatAB;
253 #endif
254 
255  __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
256  {
257  constexpr auto max_lds_align = K1;
258 
259  // A matrix in LDS memory, dst of blockwise copy
260  constexpr auto a_block_desc_k0_m_k1 = [&]() {
261  if constexpr(ABlockLdsExtraM)
262  {
266  }
267  else
268  {
270  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
271  }
272  }();
273 
274  return a_block_desc_k0_m_k1;
275  }
276 
277  __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
278  {
279  constexpr auto max_lds_align = K1;
280 
281  // B matrix in LDS memory, dst of blockwise copy
282  constexpr auto b_block_desc_k0_n_k1 = [&]() {
283  if constexpr(BBlockLdsExtraN)
284  {
288  }
289  else
290  {
292  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
293  }
294  }();
295 
296  return b_block_desc_k0_n_k1;
297  }
298 
299  __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
300  {
301  // LDS allocation for A and B: be careful of alignment
302  constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
303 
304  constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
305 
306  constexpr auto max_lds_align = K1;
307 
308  constexpr auto a_block_space_size_aligned =
309  math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
310 
311  constexpr auto b_block_space_size_aligned =
312  math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
313 
314  return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB);
315  }
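// Illustrative sizing (assumed parameters, not from this file): with K0PerBlock = 4,
// MPerBlock = 256, NPerBlock = 128, K1 = 8, no ABlockLdsExtraM/BBlockLdsExtraN padding and
// FloatAB = half_t, the A tile needs 4 * 256 * 8 = 8192 elements and the B tile
// 4 * 128 * 8 = 4096 elements, so this function reports (8192 + 4096) * 2 bytes = 24 KiB of LDS.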
316 
317  template <typename AGridDesc_K0_M_K1, typename BGridDesc_K0_N_K1, typename CGridDesc_M_N>
318  __host__ __device__ static constexpr bool
319  CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
320  const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
321  const CGridDesc_M_N& c_grid_desc_m_n)
322  {
323  static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
324  "wrong! K1 need to be known at compile-time");
325 
326  static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) &&
327  (NPerBlock % (NXdlPerWave * NPerXDL)) == 0,
328  "Invalid tuning param!");
329 
330  const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
331  const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
332  const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
333 
334  if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
335  K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
336  K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
337  return false;
338 
339  if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
340  return false;
341 
342  // check gridwise gemm pipeline
343  const auto num_k_loop = K0 / K0PerBlock;
344 
345  if(!GridwiseGemmPipe::IsSupported(num_k_loop))
346  {
347  return false;
348  }
349 
350  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
351  return true;
352  }
353 
354  __host__ static constexpr bool CheckValidity(const Problem& problem)
355  {
356  static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
357  "wrong! K1 need to be known at compile-time");
358 
359  static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) &&
360  (NPerBlock % (NXdlPerWave * NPerXDL)) == 0,
361  "Invalid tuning param!");
362 
363  // check gridwise gemm pipeline
364  const auto num_k_loop = math::integer_divide_ceil(problem.K0, K0PerBlock);
365  if(!GridwiseGemmPipe::IsSupported(num_k_loop))
366  {
367  return false;
368  }
369 
370  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
371  return true;
372  }
373 
374  __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
375  {
376  const index_t num_loop = math::integer_divide_ceil(K, K0PerBlock * K1);
377 
378  return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
379  }
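// Illustrative count (assumed parameters): with K = 4096, K0PerBlock = 4 and K1 = 8 this gives
// num_loop = integer_divide_ceil(4096, 4 * 8) = 128 K-block iterations; the selected pipeline
// then decides whether a main K loop has to be emitted for that trip count.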
380 
381  template <typename CGridDesc>
382  __host__ __device__ static constexpr auto
383  MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc& c_grid_desc_m_n)
384  {
385  constexpr auto max_lds_align = K1;
386 
387  // A matrix in LDS memory, dst of blockwise copy
388  constexpr auto a_block_desc_k0_m_k1 = [&]() {
389  if constexpr(ABlockLdsExtraM)
390  {
394  }
395  else
396  {
398  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
399  }
400  }();
401 
402  // B matrix in LDS memory, dst of blockwise copy
403  constexpr auto b_block_desc_k0_n_k1 = [&]() {
404  if constexpr(BBlockLdsExtraN)
405  {
409  }
410  else
411  {
413  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
414  }
415  }();
416 
417  using BlockwiseGemm =
421  FloatAcc,
422  decltype(a_block_desc_k0_m_k1),
423  decltype(b_block_desc_k0_n_k1),
424  MPerXDL,
425  NPerXDL,
426  MXdlPerWave,
427  NXdlPerWave,
428  K1>;
429 
430  return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
431  }
432 
433  // return block_id to C matrix tile idx (m0, n0) mapping
435 
436  template <bool HasMainKBlockLoop,
437  typename AGridDesc_K0_M_K1,
438  typename BGridDesc_K0_N_K1,
439  typename CGridDesc_M_N>
440  __device__ static void Run(const FloatAB* p_a_grid,
441  const FloatAB* p_b_grid,
442  FloatC* p_c_grid,
443  void* __restrict__ p_shared,
444  const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
445  const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
446  const CGridDesc_M_N& c_grid_desc_m_n)
447  {
448  const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
449             MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
450 
451  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
452  p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
453  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
454  p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
455  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
456  p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
457 
458  const AElementwiseOperation a_element_op{};
459  const BElementwiseOperation b_element_op{};
460  const CElementwiseOperation c_element_op{};
461 
462  const auto block_2_ctile_map =
463  Block2CTileMap{c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)};
464 
465  // divide block work by [M, N]
466  const auto block_work_idx =
467  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
468 
469  if(!block_2_ctile_map.ValidCTileIndex(
470  block_work_idx,
471  make_tuple(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0),
472  c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1))))
473  {
474  return;
475  }
476 
477         // HACK: this forces m/n_block_data_idx_on_grid into SGPRs
478  const index_t m_block_data_idx_on_grid =
479  __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
480 
481  const index_t n_block_data_idx_on_grid =
482  __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
483 
484  // lds max alignment
485  constexpr auto max_lds_align = K1;
486 
487  // A matrix in LDS memory, dst of blockwise copy
488  constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
489 
490  // B matrix in LDS memory, dst of blockwise copy
491  constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
492 
493  // A matrix blockwise copy
494  auto a_blockwise_copy =
496  AElementwiseOperation,
500  ABlockTransferThreadClusterLengths_K0_M_K1,
501  ABlockTransferThreadClusterArrangeOrder,
502  FloatAB,
504  decltype(a_grid_desc_k0_m_k1),
505  decltype(a_block_desc_k0_m_k1),
506  ABlockTransferSrcAccessOrder,
508  ABlockTransferSrcVectorDim,
509  2,
510  ABlockTransferSrcScalarPerVector,
511  ABlockTransferDstScalarPerVector_K1,
512  1,
513  1,
514  AThreadTransferSrcResetCoordinateAfterRun,
515  true,
516  NumGemmKPrefetchStage>(
517  a_grid_desc_k0_m_k1,
518  make_multi_index(0, m_block_data_idx_on_grid, 0),
519  a_element_op,
520  a_block_desc_k0_m_k1,
521  make_multi_index(0, 0, 0),
523 
524  // B matrix blockwise copy
525  auto b_blockwise_copy =
527  BElementwiseOperation,
531  BBlockTransferThreadClusterLengths_K0_N_K1,
532  BBlockTransferThreadClusterArrangeOrder,
533  FloatAB,
535  decltype(b_grid_desc_k0_n_k1),
536  decltype(b_block_desc_k0_n_k1),
537  BBlockTransferSrcAccessOrder,
539  BBlockTransferSrcVectorDim,
540  2,
541  BBlockTransferSrcScalarPerVector,
542  BBlockTransferDstScalarPerVector_K1,
543  1,
544  1,
545  BThreadTransferSrcResetCoordinateAfterRun,
546  true,
547  NumGemmKPrefetchStage>(
548  b_grid_desc_k0_n_k1,
549  make_multi_index(0, n_block_data_idx_on_grid, 0),
550  b_element_op,
551  b_block_desc_k0_n_k1,
552  make_multi_index(0, 0, 0),
554 
555  // GEMM definition
556  // c_mtx += transpose(a_mtx) * b_mtx
557  // a_mtx[K0PerBlock, MPerBlock] is in LDS
558  // b_mtx[K0PerBlock, NPerBlock] is in LDS
559  // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
560  // register
561  // sanity check
563  BlockSize,
566  FloatAcc,
567  decltype(a_block_desc_k0_m_k1),
568  decltype(b_block_desc_k0_n_k1),
569  MPerXDL,
570  NPerXDL,
571  MXdlPerWave,
572  NXdlPerWave,
573  K1,
574  LoopSched>();
575 
576  auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
577 
578  // LDS allocation for A and B: be careful of alignment
579  constexpr auto a_block_space_size_aligned =
580  math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
581 
582  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
583  static_cast<FloatABAdjusted*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
584 
585  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
586  static_cast<FloatABAdjusted*>(p_shared) + a_block_space_size_aligned,
587  b_block_desc_k0_n_k1.GetElementSpaceSize());
588 
589  constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
590  constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
591 
592  // gridwise GEMM pipeline
593  const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
594  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
595 
596  GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
597  a_block_desc_k0_m_k1,
598  a_blockwise_copy,
599  a_grid_buf,
600  a_block_buf,
601  a_block_slice_copy_step,
602  b_grid_desc_k0_n_k1,
603  b_block_desc_k0_n_k1,
604  b_blockwise_copy,
605  b_grid_buf,
606  b_block_buf,
607  b_block_slice_copy_step,
608  blockwise_gemm,
609  c_thread_buf,
610  num_k_block_main_loop);
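// The selected pipeline interleaves the A/B global->LDS blockwise copies with the xdlops
// blockwise GEMM: each of the num_k_block_main_loop iterations advances both copy windows by
// K0PerBlock along K0 (the slice copy steps above) while accumulating into c_thread_buf.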
611 
612  // output: register to global memory
613  {
614  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
615  blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
616 
617  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
618  blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
619 
620  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
621  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
622  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
623  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
624  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
625  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
626  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
627  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
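// M0..M4 and N0..N2 are the lengths of the 8-D xdlops output view of the C block tile
// (M factored as M0*M1*M2*M3*M4 and N as N0*N1*N2); the threadwise copy below writes each
// thread's accumulator registers into its slice of this view in global memory.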
628 
629  // calculate origin of thread output tensor on global memory
630  // blockwise GEMM c matrix starting index
631  const auto c_thread_mtx_on_block =
632  blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
633 
634  const index_t m_thread_data_on_grid =
635  m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
636 
637  const index_t n_thread_data_on_grid =
638  n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
639 
640  const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
642  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
645 
646  const auto m_thread_data_on_grid_idx =
647  m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
648  make_multi_index(m_thread_data_on_grid));
649 
650  const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
654 
655  const auto n_thread_data_on_grid_idx =
656  n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
657  make_multi_index(n_thread_data_on_grid));
658 
659  auto c_thread_copy =
661  FloatC,
662  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
663  decltype(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2),
664  CElementwiseOperation,
666  CThreadTransferSrcDstAccessOrder,
667  CThreadTransferSrcDstVectorDim,
668  CThreadTransferDstScalarPerVector,
669  CGlobalMemoryDataOperation,
670  1,
671  true>{
672  c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
673  make_multi_index(m_thread_data_on_grid_idx[I0],
674  n_thread_data_on_grid_idx[I0],
675  m_thread_data_on_grid_idx[I1],
676  n_thread_data_on_grid_idx[I1],
677  m_thread_data_on_grid_idx[I2],
678  m_thread_data_on_grid_idx[I3],
679  m_thread_data_on_grid_idx[I4],
680  n_thread_data_on_grid_idx[I2]),
681  c_element_op};
682 
683  c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
684  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
685  c_thread_buf,
686  c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
687  c_grid_buf);
688  }
689  }
690 };
691 
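// The gridwise GEMM variant below (a GemmSpecialization/layout-aware extension of the class
// above, wired up through the Parent alias) adds ALayout/BLayout/CLayout parameters and the
// MakeAGridDescriptor_K0_M_K1 / MakeBGridDescriptor_K0_N_K1 / MakeCGridDescriptor_M_N builders
// used by the Argument-based kernel entry point, applying right-padding on M, N and K as the
// chosen GemmSpecialization requires.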
692 template <index_t BlockSize,
693  typename FloatAB,
694  typename FloatAcc,
695  typename FloatC,
696  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
697  typename ALayout,
698  typename BLayout,
699  typename CLayout,
700  typename AElementwiseOperation,
701  typename BElementwiseOperation,
702  typename CElementwiseOperation,
703           tensor_operation::device::GemmSpecialization GemmSpec,
704           index_t MPerBlock,
705  index_t NPerBlock,
706  index_t K0PerBlock,
707  index_t MPerXDL,
708  index_t NPerXDL,
709  index_t K1Value,
710  index_t MXdlPerWave,
711  index_t NXdlPerWave,
712  typename ABlockTransferThreadClusterLengths_K0_M_K1,
713  typename ABlockTransferThreadClusterArrangeOrder,
714  typename ABlockTransferSrcAccessOrder,
715  index_t ABlockTransferSrcVectorDim,
716  index_t ABlockTransferSrcScalarPerVector,
717  index_t ABlockTransferDstScalarPerVector_K1,
718  bool AThreadTransferSrcResetCoordinateAfterRun,
719  bool ABlockLdsExtraM,
720  typename BBlockTransferThreadClusterLengths_K0_N_K1,
721  typename BBlockTransferThreadClusterArrangeOrder,
722  typename BBlockTransferSrcAccessOrder,
723  index_t BBlockTransferSrcVectorDim,
724  index_t BBlockTransferSrcScalarPerVector,
725  index_t BBlockTransferDstScalarPerVector_K1,
726  bool BThreadTransferSrcResetCoordinateAfterRun,
727  bool BBlockLdsExtraN,
728  typename CThreadTransferSrcDstAccessOrder,
729  index_t CThreadTransferSrcDstVectorDim,
730  index_t CThreadTransferDstScalarPerVector,
731  index_t NumGemmKPrefetchStage = 1,
732           LoopScheduler LoopSched = make_default_loop_scheduler(),
733           PipelineVersion PipelineVer = PipelineVersion::v1>
736  FloatAB,
737  FloatAcc,
738  FloatC,
739  CGlobalMemoryDataOperation,
740  AElementwiseOperation,
741  BElementwiseOperation,
742  CElementwiseOperation,
743  MPerBlock,
744  NPerBlock,
745  K0PerBlock,
746  MPerXDL,
747  NPerXDL,
748  K1Value,
749  MXdlPerWave,
750  NXdlPerWave,
751  ABlockTransferThreadClusterLengths_K0_M_K1,
752  ABlockTransferThreadClusterArrangeOrder,
753  ABlockTransferSrcAccessOrder,
754  ABlockTransferSrcVectorDim,
755  ABlockTransferSrcScalarPerVector,
756  ABlockTransferDstScalarPerVector_K1,
757  AThreadTransferSrcResetCoordinateAfterRun,
758  ABlockLdsExtraM,
759  BBlockTransferThreadClusterLengths_K0_N_K1,
760  BBlockTransferThreadClusterArrangeOrder,
761  BBlockTransferSrcAccessOrder,
762  BBlockTransferSrcVectorDim,
763  BBlockTransferSrcScalarPerVector,
764  BBlockTransferDstScalarPerVector_K1,
765  BThreadTransferSrcResetCoordinateAfterRun,
766  BBlockLdsExtraN,
767  CThreadTransferSrcDstAccessOrder,
768  CThreadTransferSrcDstVectorDim,
769  CThreadTransferDstScalarPerVector,
770  NumGemmKPrefetchStage,
771  LoopSched,
772  PipelineVer>
773 {
774  using Parent =
776  FloatAB,
777  FloatAcc,
778  FloatC,
779  CGlobalMemoryDataOperation,
780  AElementwiseOperation,
781  BElementwiseOperation,
782  CElementwiseOperation,
783  MPerBlock,
784  NPerBlock,
785  K0PerBlock,
786  MPerXDL,
787  NPerXDL,
788  K1Value,
789  MXdlPerWave,
790  NXdlPerWave,
791  ABlockTransferThreadClusterLengths_K0_M_K1,
792  ABlockTransferThreadClusterArrangeOrder,
793  ABlockTransferSrcAccessOrder,
794  ABlockTransferSrcVectorDim,
795  ABlockTransferSrcScalarPerVector,
796  ABlockTransferDstScalarPerVector_K1,
797  AThreadTransferSrcResetCoordinateAfterRun,
798  ABlockLdsExtraM,
799  BBlockTransferThreadClusterLengths_K0_N_K1,
800  BBlockTransferThreadClusterArrangeOrder,
801  BBlockTransferSrcAccessOrder,
802  BBlockTransferSrcVectorDim,
803  BBlockTransferSrcScalarPerVector,
804  BBlockTransferDstScalarPerVector_K1,
805  BThreadTransferSrcResetCoordinateAfterRun,
806  BBlockLdsExtraN,
807  CThreadTransferSrcDstAccessOrder,
808  CThreadTransferSrcDstVectorDim,
809  CThreadTransferDstScalarPerVector,
810  NumGemmKPrefetchStage,
811  LoopSched,
812  PipelineVer>;
813 
814  using typename Parent::GridwiseGemmPipe;
815  using typename Parent::Problem;
816 
817  using Parent::I1;
818 
819  using Parent::K1;
820 
821  __device__ static auto
823  {
824  const auto a_grid_desc_m_k = [&]() {
826  {
827  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
828  }
830  {
831  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
832  }
833  }();
834 
836  {
837  const auto K0Pad = math::integer_divide_ceil(K0, K0PerBlock) * K0PerBlock;
838  const auto KPad = K0Pad * K1Value;
839 
840  const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
841  a_grid_desc_m_k,
845 
847  a_grid_desc_m_kpad,
849  make_right_pad_transform(M, MPad - M)),
852  }
853  else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
854  {
856  a_grid_desc_m_k,
858  make_right_pad_transform(M, MPad - M)),
861  }
862  else
863  {
865  a_grid_desc_m_k,
870  }
871  }
872 
873  __device__ static auto
875  {
876  const auto b_grid_desc_k_n = [&]() {
878  {
879  return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
880  }
882  {
883  return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
884  }
885  }();
886 
888  {
889  const auto K0Pad = math::integer_divide_ceil(K0, K0PerBlock) * K0PerBlock;
890  const auto KPad = K0Pad * K1Value;
891 
892  const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
893  b_grid_desc_k_n,
897 
899  b_grid_desc_kpad_n,
901  make_right_pad_transform(N, NPad - N)),
904  }
905 
906  else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
907  {
909  b_grid_desc_k_n,
911  make_right_pad_transform(N, NPad - N)),
914  }
915  else
916  {
918  b_grid_desc_k_n,
923  }
924  }
925 
926  __device__ static auto
928  {
929  const auto c_grid_desc_m_n = [&]() {
931  {
932  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
933  }
935  {
936  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
937  }
938  }();
939 
942  {
943  return transform_tensor_descriptor(c_grid_desc_m_n,
945  make_right_pad_transform(N, NPad - N)),
948  }
949  else
950  {
951 
953  c_grid_desc_m_n,
957  }
958  }
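// Each descriptor builder starts from a naive row- or column-major view selected by the
// corresponding layout and, when the GemmSpecialization asks for it, right-pads M, N and/or K
// up to MPad, NPad and a K0PerBlock multiple; for A and B the (possibly padded) K dimension is
// then unmerged into (K0, K1).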
959 
960  __host__ static constexpr bool CheckValidity(const Problem& problem)
961  {
962  static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
963  "wrong! K1 need to be known at compile-time");
964 
965  static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) &&
966  (NPerBlock % (NXdlPerWave * NPerXDL)) == 0,
967  "Invalid tuning param!");
968 
969  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
973  {
974  if(!(problem.M % MPerBlock == 0))
975  {
976  return false;
977  }
978  }
979 
980  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
984  {
985  if(!(problem.N % NPerBlock == 0))
986  {
987  return false;
988  }
989  }
990 
991  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
995  {
996  if(!(problem.K0 % K0PerBlock == 0))
997  {
998  return false;
999  }
1000  }
1001 
1003  {
1004  if(problem.K % ABlockTransferSrcScalarPerVector != 0)
1005  {
1006  return false;
1007  }
1008  }
1009  else
1010  {
1011  if(problem.M % ABlockTransferSrcScalarPerVector != 0)
1012  {
1013  return false;
1014  }
1015  }
1016 
1018  {
1019  if(problem.N % BBlockTransferSrcScalarPerVector != 0)
1020  {
1021  return false;
1022  }
1023  }
1024  else
1025  {
1026  if(problem.K % BBlockTransferSrcScalarPerVector != 0)
1027  {
1028  return false;
1029  }
1030  }
1031 
1032  // check gridwise gemm pipeline
1033  const auto num_k_loop = math::integer_divide_ceil(problem.K0, K0PerBlock);
1034 
1035  if(!GridwiseGemmPipe::IsSupported(num_k_loop))
1036  {
1037  return false;
1038  }
1039 
1040  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1041  return true;
1042  }
1043 };
1044 
1045 } // namespace ck