Composable Kernel source listing: include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"

#include "ck/tensor_description/tensor_space_filling_curve.hpp"

#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
// Assume:
//   1. src:
//     1. SrcDesc is known at compile-time
//     2. SrcBuffer is StaticBuffer
//     3. SrcSliceOriginIdx is known at compile-time
//   2. dst:
//     1. DstDesc is not known at compile-time
//     2. DstBuffer is DynamicBuffer
//     3. DstSliceOriginIdx is not known at compile-time
template <typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename ElementwiseOperation,
          typename SliceLengths,
          typename DimAccessOrder,
          index_t DstVectorDim,
          index_t DstScalarPerVector,
          InMemoryDataOperationEnum DstInMemOp,
          index_t DstScalarStrideInVector,
          bool DstResetCoordinateAfterRun,
          typename enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v1r3
{
    static constexpr index_t nDim = SliceLengths::Size();

    using Index = MultiIndex<nDim>;

    using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));

    using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));

    __device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(const DstDesc& dst_desc,
                                                            const Index& dst_slice_origin_idx,
                                                            const ElementwiseOperation& element_op)
        : dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)),
          element_op_{element_op}
    {
        static_assert(SrcDesc::IsKnownAtCompileTime(),
                      "wrong! SrcDesc needs to be known at compile-time");
        static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
                      "wrong! Not divisible");
    }

    __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
    {
        dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
    }

    template <typename SrcSliceOriginIdx, typename SrcBuffer, typename DstBuffer>
    __device__ void Run(const SrcDesc&,
                        const SrcSliceOriginIdx&,
                        const SrcBuffer& src_buf,
                        const DstDesc& dst_desc,
                        DstBuffer& dst_buf)
    {
        static_assert(SrcDesc::IsKnownAtCompileTime(),
                      "wrong! SrcDesc needs to be known at compile-time");

        static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value,
                      "wrong! SrcSliceOrigin needs to be known at compile-time");

        static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer needs to be a StaticBuffer");

        // SrcDesc and src_slice_origin_idx are known at compile-time
        constexpr auto src_desc             = remove_cvref_t<SrcDesc>{};
        constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});

        // scalar per access on each dim
        // TODO: don't use lambda_scalar_per_access
        constexpr auto dst_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});

        constexpr auto dst_scalar_step_in_vector =
            generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});

        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
                                                    DimAccessOrder,
                                                    remove_cv_t<decltype(dst_scalar_per_access)>>;

        // TODO: use SpaceFillingCurve::ScalarPerVector instead of DstScalarPerVector?
        static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
                      "wrong! DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");

        typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
        using dst_vector_t = typename vector_type_maker<DstData, DstScalarPerVector>::type::type;

        constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();

        static_for<0, num_access, 1>{}([&](auto idx_1d) {
            constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);

            // copy data from src_buf into dst_vector
            // TODO: it's a hack here to use \p dst_scalar_step_in_vector. Use SpaceFillingCurve?
            static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
                constexpr index_t src_offset = src_desc.CalculateOffset(
                    src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);

                DstData v;

                // apply element-wise operation
                element_op_(v, src_buf[Number<src_offset>{}]);

                dst_vector.template AsType<DstData>()(i) = v;
            });

            const bool is_dst_valid =
                coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);

            // copy data from dst_vector into dst_buf
            dst_buf.template Update<DstInMemOp, dst_vector_t>(
                dst_coord_.GetOffset(),
                is_dst_valid,
                dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);

            if constexpr(idx_1d.value != num_access - 1)
            {
                constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);

                move_tensor_coordinate(
                    dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
            }
        });

        // move dst coordinate back to the slice origin (or not)
        if constexpr(DstResetCoordinateAfterRun)
        {
            const auto dst_reset_step =
                make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());

            move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
        }
    }

    __device__ static constexpr auto GetDstCoordinateResetStep()
    {
        constexpr auto dst_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});

        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
                                                    DimAccessOrder,
                                                    remove_cv_t<decltype(dst_scalar_per_access)>>;

        constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
        if constexpr(num_access == 0)
        {
            return typename SpaceFillingCurve::Index{};
        }
        else
        {
            constexpr auto reset_step =
                SpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});

            return reset_step;
        }
    }

    // dst_slice_origin_step_idx needs to be known at compile-time, for performance reasons
    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
                                       const Index& dst_slice_origin_step_idx)
    {
        // if the dst coord was not reset by Run(), the step needs to be adjusted here
        const auto adjusted_step_idx =
            DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
                                       : dst_slice_origin_step_idx + GetDstCoordinateResetStep();

        // is it OK to construct a new step every time?
        const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);

        move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
    }

    private:
    DstCoord dst_coord_;
    const ElementwiseOperation element_op_;
}; // struct ThreadwiseTensorSliceTransfer_v1r3

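// A minimal usage sketch for ThreadwiseTensorSliceTransfer_v1r3 (illustrative only; the
// descriptor shapes, the 1x8 slice, and names such as src_thread_buf / dst_global_buf are
// assumptions, not taken from this header). It writes a register-resident slice out to a
// run-time global tensor through a PassThrough element-wise op:
//
//   constexpr auto src_desc = make_naive_tensor_descriptor_packed(
//       make_tuple(Number<1>{}, Number<8>{}));      // compile-time src descriptor
//   const auto dst_desc = make_naive_tensor_descriptor_packed(make_tuple(M, N));
//
//   auto store = ThreadwiseTensorSliceTransfer_v1r3<
//       float,                                      // SrcData
//       float,                                      // DstData
//       decltype(src_desc),
//       decltype(dst_desc),
//       tensor_operation::element_wise::PassThrough,
//       Sequence<1, 8>,                             // SliceLengths
//       Sequence<0, 1>,                             // DimAccessOrder
//       1,                                          // DstVectorDim
//       4,                                          // DstScalarPerVector
//       InMemoryDataOperationEnum::Set,
//       1,                                          // DstScalarStrideInVector
//       true>{                                      // DstResetCoordinateAfterRun
//       dst_desc, make_multi_index(m0, n0), tensor_operation::element_wise::PassThrough{}};
//
//   store.Run(src_desc,
//             make_tuple(Number<0>{}, Number<0>{}), // compile-time src slice origin
//             src_thread_buf,                       // StaticBuffer in VGPRs
//             dst_desc,
//             dst_global_buf);                      // DynamicBuffer in global memory
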
template <typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename SliceLengths,
          typename DimAccessOrder,
          index_t SrcVectorDim,
          index_t SrcScalarPerVector,
          index_t SrcScalarStrideInVector,
          bool SrcResetCoordinateAfterRun,
          bool InvalidElementAsNaN = false,
          typename enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v2
{
    static_assert((InvalidElementAsNaN && !ck::is_integral<DstData>::value) ||
                      (!InvalidElementAsNaN),
                  "Filling invalid elements with NaN is only allowed for floating-point types");

    static constexpr index_t nDim = SliceLengths::Size();

    using Index = MultiIndex<nDim>;

    using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));

    using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));

    static constexpr index_t PackedSize = []() {
        if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
            return 2;
        else
            return 1;
    }();

    __device__ constexpr ThreadwiseTensorSliceTransfer_v2(const SrcDesc& src_desc,
                                                          const Index& src_slice_origin_idx)
        : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin_idx))
    {
        static_assert(DstDesc::IsKnownAtCompileTime(),
                      "wrong! DstDesc needs to be known at compile-time");
        static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
                      "wrong! Not divisible");

        if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t> ||
                     is_same_v<remove_cvref_t<SrcData>, f4x2_pk_t>)
        {
            static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1");
        }
    }

    __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
    {
        src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
    }

    template <typename SrcBuffer, typename DstBuffer, typename DstSliceOriginIdx>
    __device__ void Run(const SrcDesc& src_desc,
                        const SrcBuffer& src_buf,
                        const DstDesc&,
                        const DstSliceOriginIdx&,
                        DstBuffer& dst_buf)
    {
        static_assert(DstDesc::IsKnownAtCompileTime(),
                      "wrong! DstDesc needs to be known at compile-time");

        static_assert(is_known_at_compile_time<remove_cvref_t<DstSliceOriginIdx>>::value,
                      "wrong! DstSliceOrigin needs to be known at compile-time");

        static_assert(
            is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
            "wrong! inconsistent type");

        // DstDesc and dst_slice_origin_idx are known at compile-time
        constexpr auto dst_desc             = remove_cvref_t<DstDesc>{};
        constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{};

        // scalar per access on each dim
        // TODO: don't use lambda_scalar_per_access
        constexpr auto src_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});

        constexpr auto src_scalar_step_in_vector =
            generate_sequence(detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});

        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
                                                    DimAccessOrder,
                                                    remove_cv_t<decltype(src_scalar_per_access)>>;

        // loop over the tensor and copy
        constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();

        static_for<0, num_access, 1>{}([&](auto idx_1d) {
            typename vector_type_maker<SrcData, SrcScalarPerVector / PackedSize>::type src_vector;

            using src_vector_t =
                typename vector_type_maker<SrcData, SrcScalarPerVector / PackedSize>::type::type;
            constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d);

            const bool is_src_valid =
                coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);

            // copy data from src_buf into src_vector
            src_vector.template AsType<src_vector_t>()(Number<0>{}) =
                src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize,
                                                   is_src_valid);

            // copy data from src_vector into dst_buf
            static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) {
                constexpr index_t dst_offset =
                    dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx +
                                             i * src_scalar_step_in_vector);

                if constexpr(InvalidElementAsNaN)
                {
                    dst_buf(Number<dst_offset>{}) =
                        is_src_valid
                            ? type_convert<DstData>(src_vector.template AsType<SrcData>()[i])
                            : NumericLimits<DstData>::QuietNaN();
                }
                else
                {
                    dst_buf(Number<dst_offset>{}) =
                        type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
                }
            });

            if constexpr(idx_1d.value != num_access - 1)
            {
                constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);

                move_tensor_coordinate(
                    src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step));
            }
        });

        // move src coordinate back to the slice origin (or not)
        if constexpr(SrcResetCoordinateAfterRun)
        {
            const auto src_reset_step =
                make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());

            move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
        }
    }

    __device__ static constexpr auto GetSrcCoordinateResetStep()
    {
        constexpr auto src_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});

        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
                                                    DimAccessOrder,
                                                    remove_cv_t<decltype(src_scalar_per_access)>>;

        constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
        if constexpr(num_access == 0)
        {
            return typename SpaceFillingCurve::Index{};
        }
        else
        {
            constexpr auto reset_step =
                SpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});

            return reset_step;
        }
    }

    // src_slice_origin_step_idx needs to be known at compile-time, for performance reasons
    __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
                                       const Index& src_slice_origin_step_idx)
    {
        // if the src coord was not reset by Run(), the step needs to be adjusted here
        const auto adjusted_step_idx =
            SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
                                       : src_slice_origin_step_idx + GetSrcCoordinateResetStep();

        // is it OK to construct a new step every time?
        const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);

        move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
    }

    // src_slice_origin_step_idx needs to be known at compile-time, for performance reasons
    template <typename SrcMoveSliceWindowStepHack>
    __device__ void
    MoveSrcSliceWindow(const SrcDesc& src_desc,
                       const Index& src_slice_origin_step_idx,
                       const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
    {
        // if the src coord was not reset by RunRead(), the step needs to be adjusted here
        const auto adjusted_step_idx =
            SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
                                       : src_slice_origin_step_idx + GetSrcCoordinateResetStep();

        // is it OK to construct a new step every time?
        const auto adjusted_step = make_tensor_coordinate_step(
            src_desc, adjusted_step_idx, src_move_slice_window_step_hack);

        move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
    }

    private:
    SrcCoord src_coord_;
}; // struct ThreadwiseTensorSliceTransfer_v2

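// A minimal usage sketch for ThreadwiseTensorSliceTransfer_v2 (illustrative only; names
// such as src_global_buf / dst_thread_buf and the tile sizes are assumptions). It is the
// mirror of v1r3: the source is a run-time global tensor, the destination a register slice:
//
//   auto load = ThreadwiseTensorSliceTransfer_v2<
//       float, float,
//       decltype(src_desc), decltype(dst_desc),
//       Sequence<1, 8>,                            // SliceLengths
//       Sequence<0, 1>,                            // DimAccessOrder
//       1,                                         // SrcVectorDim
//       4,                                         // SrcScalarPerVector
//       1,                                         // SrcScalarStrideInVector
//       true>{src_desc, make_multi_index(m0, n0)}; // SrcResetCoordinateAfterRun
//
//   load.Run(src_desc, src_global_buf,
//            dst_desc, make_tuple(Number<0>{}, Number<0>{}), dst_thread_buf);
//
//   // advance the window by one tile in the fastest dimension before the next Run()
//   load.MoveSrcSliceWindow(src_desc, make_multi_index(0, 8));
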
template <typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename SliceLengths,
          typename DimAccessOrder,
          index_t SrcVectorDim,
          index_t SrcScalarPerVector,
          index_t SrcScalarStrideInVector,
          bool SrcResetCoordinateAfterRun,
          index_t scale_gather_num,
          bool InvalidElementAsNaN = false,
          typename enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v2_gather
{
    static_assert((InvalidElementAsNaN && !ck::is_integral<DstData>::value) ||
                      (!InvalidElementAsNaN),
                  "Filling invalid elements with NaN is only allowed for floating-point types");

    static constexpr index_t nDim = SliceLengths::Size();

    using Index = MultiIndex<nDim>;

    using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));

    using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));

    static constexpr index_t PackedSize = []() {
        if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
            return 2;
        else
            return 1;
    }();

    __device__ constexpr ThreadwiseTensorSliceTransfer_v2_gather(
        const SrcDesc& src_desc,
        const Index& src_slice_origin_idx,
        const StaticallyIndexedArray<index_t, scale_gather_num>& scale_gather_offsets)
        : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin_idx)),
          scale_gather_offsets_(scale_gather_offsets)
    {
        static_assert(DstDesc::IsKnownAtCompileTime(),
                      "wrong! DstDesc needs to be known at compile-time");
        static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
                      "wrong! Not divisible");

        if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
        {
            static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1");
        }
    }

    __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
    {
        auto adjusted_origin_idx = [&]() {
            Index idx;

            static_for<0, nDim, 1>{}(
                [&](auto i) { idx(i) = i.value == 0 ? 0 : src_slice_origin_idx[Number<i>{}]; });

            return idx;
        }();

        src_coord_ = make_tensor_coordinate(src_desc, adjusted_origin_idx);
    }

    template <typename SrcBuffer, typename DstBuffer, typename DstSliceOriginIdx>
    __device__ void Run(const SrcDesc& src_desc,
                        const SrcBuffer& src_buf,
                        const DstDesc&,
                        const DstSliceOriginIdx&,
                        DstBuffer& dst_buf)
    {
        static_assert(DstDesc::IsKnownAtCompileTime(),
                      "wrong! DstDesc needs to be known at compile-time");

        static_assert(is_known_at_compile_time<remove_cvref_t<DstSliceOriginIdx>>::value,
                      "wrong! DstSliceOrigin needs to be known at compile-time");

        static_assert(
            is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
            "wrong! inconsistent type");

        // DstDesc and dst_slice_origin_idx are known at compile-time
        constexpr auto dst_desc             = remove_cvref_t<DstDesc>{};
        constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{};

        // scalar per access on each dim
        // TODO: don't use lambda_scalar_per_access
        constexpr auto src_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});

        constexpr auto src_scalar_step_in_vector =
            generate_sequence(detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});

        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
                                                    DimAccessOrder,
                                                    remove_cv_t<decltype(src_scalar_per_access)>>;

        // loop over the tensor and copy
        constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();

        static_for<0, scale_gather_num, 1>{}([&](auto gather_idx) {
            constexpr auto current_dst_origin =
                to_multi_index(dst_slice_origin_idx) + make_multi_index(gather_idx, 0);

            static_for<0, num_access, 1>{}([&](auto idx_1d) {
                typename vector_type_maker<SrcData, SrcScalarPerVector / PackedSize>::type
                    src_vector;

                using src_vector_t =
                    typename vector_type_maker<SrcData,
                                               SrcScalarPerVector / PackedSize>::type::type;
                constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d);

                const bool is_src_valid =
                    coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc,
                                                                                src_coord_);

                // copy data from src_buf into src_vector
                src_vector.template AsType<src_vector_t>()(Number<0>{}) =
                    src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize +
                                                           scale_gather_offsets_(gather_idx),
                                                       is_src_valid);

                // copy data from src_vector into dst_buf
                static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) {
                    constexpr index_t dst_offset =
                        dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) +
                                                 src_data_idx + i * src_scalar_step_in_vector);
                    constexpr auto full_dst_offset =
                        dst_desc.CalculateOffset(current_dst_origin) + dst_offset;

                    if constexpr(InvalidElementAsNaN)
                    {
                        dst_buf(Number<full_dst_offset>{}) =
                            is_src_valid
                                ? type_convert<DstData>(src_vector.template AsType<SrcData>()[i])
                                : NumericLimits<DstData>::QuietNaN();
                    }
                    else
                    {
                        dst_buf(Number<full_dst_offset>{}) =
                            type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
                    }
                });

                if constexpr(idx_1d.value != num_access - 1)
                {
                    constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);

                    move_tensor_coordinate(
                        src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step));
                }
            });
        });

        // move src coordinate back to the slice origin (or not)
        if constexpr(SrcResetCoordinateAfterRun)
        {
            const auto src_reset_step =
                make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());

            move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
        }
    }

    __device__ static constexpr auto GetSrcCoordinateResetStep()
    {
        constexpr auto src_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});

        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
                                                    DimAccessOrder,
                                                    remove_cv_t<decltype(src_scalar_per_access)>>;

        constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
        if constexpr(num_access == 0)
        {
            return typename SpaceFillingCurve::Index{};
        }
        else
        {
            constexpr auto reset_step =
                SpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});

            return reset_step;
        }
    }

    // src_slice_origin_step_idx needs to be known at compile-time, for performance reasons
    __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
                                       const Index& src_slice_origin_step_idx)
    {
        // if the src coord was not reset by Run(), the step needs to be adjusted here
        const auto adjusted_step_idx =
            SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
                                       : src_slice_origin_step_idx + GetSrcCoordinateResetStep();

        // is it OK to construct a new step every time?
        const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);

        move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
    }

    // src_slice_origin_step_idx needs to be known at compile-time, for performance reasons
    template <typename SrcMoveSliceWindowStepHack>
    __device__ void
    MoveSrcSliceWindow(const SrcDesc& src_desc,
                       const Index& src_slice_origin_step_idx,
                       const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
    {
        // if the src coord was not reset by RunRead(), the step needs to be adjusted here
        const auto adjusted_step_idx =
            SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
                                       : src_slice_origin_step_idx + GetSrcCoordinateResetStep();

        // is it OK to construct a new step every time?
        const auto adjusted_step = make_tensor_coordinate_step(
            src_desc, adjusted_step_idx, src_move_slice_window_step_hack);

        move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
    }

    private:
    SrcCoord src_coord_;
    StaticallyIndexedArray<index_t, scale_gather_num> scale_gather_offsets_;
}; // struct ThreadwiseTensorSliceTransfer_v2_gather

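// Sketch of the extra machinery in the _gather variant (a hedged reading of the code
// above, not an authoritative description): the constructor captures one run-time element
// offset per gathered row, and Run() replays the same slice read scale_gather_num times,
// adding scale_gather_offsets_(gather_idx) to the packed source offset of each access.
// The offsets array and tile sizes below are assumptions for illustration:
//
//   StaticallyIndexedArray<index_t, 4> offsets; // scale_gather_num == 4
//   static_for<0, 4, 1>{}([&](auto g) { offsets(g) = row_idx[g] * row_stride; });
//
//   auto gather_load = ThreadwiseTensorSliceTransfer_v2_gather<
//       float, float, decltype(src_desc), decltype(dst_desc),
//       Sequence<1, 8>, Sequence<0, 1>, 1, 4, 1, true,
//       4>{src_desc, make_multi_index(0, 0), offsets}; // scale_gather_num == 4
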
// Assume:
//   1. src_desc and dst_desc are not known at compile-time
//   2. SrcBuffer and DstBuffer are DynamicBuffer
//   3. src_slice_origin and dst_slice_origin are not known at compile-time
//   4. a thread (VGPR) buffer is used to stage data between RunRead() and RunWrite()
template <typename SliceLengths,
          InMemoryDataOperationEnum DstInMemOp,
          typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename SrcDimAccessOrder,
          typename DstDimAccessOrder,
          index_t SrcVectorDim,
          index_t DstVectorDim,
          index_t SrcScalarPerVector,
          index_t DstScalarPerVector,
          index_t SrcScalarStrideInVector,
          index_t DstScalarStrideInVector,
          bool SrcResetCoordinateAfterRun, // control whether to move back the src coordinate after
                                           // each RunRead(); fused with MoveSrcSliceWindow to
                                           // save address computation
          bool DstResetCoordinateAfterRun> // control whether to move back the dst coordinate after
                                           // each RunWrite(); fused with MoveDstSliceWindow to
                                           // save address computation
struct ThreadwiseTensorSliceTransfer_v3
{
    static constexpr index_t nDim = SliceLengths::Size();

    using Index = MultiIndex<nDim>;

    using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
    using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));

    using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
    using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));

    __device__ constexpr ThreadwiseTensorSliceTransfer_v3(const SrcDesc& src_desc,
                                                          const Index& src_slice_origin,
                                                          const DstDesc& dst_desc,
                                                          const Index& dst_slice_origin)
        : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
          dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin))
    {
        static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
                      "wrong! Not divisible");
        static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
                      "wrong! Not divisible");
    }

    __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
    {
        src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
    }

    __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
    {
        dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
    }

    template <typename SrcBuffer, typename SrcStepHacks>
    __device__ void
    RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
    {
        static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
                          SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
                      "wrong!");

        static_assert(
            is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value,
            "wrong! SrcBuffer and SrcData data types are inconsistent");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        // scalar per access on each dim
        // TODO: don't use lambda_scalar_per_access
        constexpr auto src_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});

        constexpr auto src_scalar_step_in_vector =
            generate_sequence(detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});

        constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;

        constexpr auto src_dim_access_order = SrcDimAccessOrder{};

        constexpr auto ordered_src_access_lengths =
            container_reorder_given_new2old(src_access_lengths, src_dim_access_order);

        // make forward steps
        const auto src_forward_steps = generate_tuple(
            [&](auto i) {
                Index forward_step_idx;

                static_for<0, nDim, 1>{}([&](auto j) {
                    forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
                });

                return make_tensor_coordinate_step(
                    src_desc, forward_step_idx, src_step_hacks[I0][i]);
            },
            Number<nDim>{});

        // make backward steps
        const auto src_backward_steps = generate_tuple(
            [&](auto i) {
                Index backward_step_idx;

                static_for<0, nDim, 1>{}([&](auto j) {
                    backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
                });

                return make_tensor_coordinate_step(
                    src_desc, backward_step_idx, src_step_hacks[I1][i]);
            },
            Number<nDim>{});

        // loop over the tensor and copy
        static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
            // judge whether to move forward or backward on each dim
            constexpr auto forward_sweep = [&]() {
                StaticallyIndexedArray<bool, nDim> forward_sweep_;

                forward_sweep_(I0) = true;

                static_for<1, nDim, 1>{}([&](auto i) {
                    index_t tmp = ordered_src_access_idx[I0];

                    static_for<1, i, 1>{}([&](auto j) {
                        tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
                    });

                    forward_sweep_(i) = tmp % 2 == 0;
                });

                return forward_sweep_;
            }();

            // calculate src data index
            constexpr auto src_data_idx = [&]() {
                Index ordered_idx;

                static_for<0, nDim, 1>{}([&](auto i) {
                    ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i]
                                                      : ordered_src_access_lengths[i] - 1 -
                                                            ordered_src_access_idx[i];
                });

                return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
                       src_scalar_per_access;
            }();

            vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector;

            using src_vector_t = typename decltype(src_tmp_vector)::type;

            const bool is_src_valid =
                coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);

            // copy data from src_buf to src_tmp_vector
            src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
                src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid);

            // copy data from src_tmp_vector to buffer_
            static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
                constexpr index_t buffer_offset =
                    buffer_desc_.CalculateOffset(src_data_idx + i * src_scalar_step_in_vector);

                buffer_(Number<buffer_offset>{}) = src_tmp_vector.template AsType<SrcData>()[i];
            });

            constexpr auto move_on_dim = [&]() constexpr {
                StaticallyIndexedArray<bool, nDim> move_on_dim_;

                static_for<0, nDim, 1>{}([&](auto i) {
                    move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1;

                    static_for<i + 1, nDim, 1>{}([&](auto j) {
                        move_on_dim_(i) &=
                            ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
                    });
                });

                return move_on_dim_;
            }();

            // move
            static_for<0, nDim, 1>{}([&](auto i) {
                if constexpr(move_on_dim[i])
                {
                    if constexpr(forward_sweep[i])
                    {
                        move_tensor_coordinate(
                            src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
                    }
                    else
                    {
                        move_tensor_coordinate(
                            src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
                    }
                }
            });
        });

        // move src coordinate back to the slice origin (or not)
        if constexpr(SrcResetCoordinateAfterRun)
        {
            const auto src_reset_step =
                make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());

            move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
        }
    }

    template <typename DstBuffer, typename DstStepHacks>
    __device__ void
    RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks)
    {
        static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
                          DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
                      "wrong!");

        static_assert(
            is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
            "wrong! SrcBuffer or DstBuffer data type is wrong");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        // dst scalar per access on each dim
        // TODO: don't use this
        constexpr auto dst_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});

        constexpr auto dst_scalar_step_in_vector =
            generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});

        constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;

        constexpr auto dst_dim_access_order = DstDimAccessOrder{};

        constexpr auto ordered_dst_access_lengths =
            container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);

        // make forward steps
        const auto dst_forward_steps = generate_tuple(
            [&](auto i) {
                Index forward_step_idx;

                static_for<0, nDim, 1>{}([&](auto j) {
                    forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
                });

                return make_tensor_coordinate_step(
                    dst_desc, forward_step_idx, dst_step_hacks[I0][i]);
            },
            Number<nDim>{});

        // make backward steps
        const auto dst_backward_steps = generate_tuple(
            [&](auto i) {
                Index backward_step_idx;

                static_for<0, nDim, 1>{}([&](auto j) {
                    backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
                });

                return make_tensor_coordinate_step(
                    dst_desc, backward_step_idx, dst_step_hacks[I1][i]);
            },
            Number<nDim>{});

        // loop over the tensor and copy
        static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
            // judge whether to move forward or backward on each dim
            constexpr auto forward_sweep = [&]() {
                StaticallyIndexedArray<bool, nDim> forward_sweep_;

                forward_sweep_(I0) = true;

                static_for<1, nDim, 1>{}([&](auto i) {
                    index_t tmp = ordered_dst_access_idx[I0];

                    static_for<1, i, 1>{}([&](auto j) {
                        tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
                    });

                    forward_sweep_(i) = tmp % 2 == 0;
                });

                return forward_sweep_;
            }();

            // calculate dst data index
            constexpr auto dst_data_idx = [&]() {
                Index ordered_idx;

                static_for<0, nDim, 1>{}([&](auto i) {
                    ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i]
                                                      : ordered_dst_access_lengths[i] - 1 -
                                                            ordered_dst_access_idx[i];
                });

                return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
                       dst_scalar_per_access;
            }();

            vector_type_maker_t<DstData, DstScalarPerVector> dst_tmp_vector;

            // copy data from buffer_ to dst_tmp_vector
            static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
                constexpr index_t buffer_offset =
                    buffer_desc_.CalculateOffset(dst_data_idx + i * dst_scalar_step_in_vector);

                dst_tmp_vector.template AsType<DstData>()(i) =
                    type_convert<DstData>(buffer_[Number<buffer_offset>{}]);
            });

            using dst_vector_t = typename decltype(dst_tmp_vector)::type;

            // copy data from dst_tmp_vector to dst_buf
            const bool is_dst_valid =
                coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);

            dst_buf.template Set<dst_vector_t>(
                dst_coord_.GetOffset(),
                is_dst_valid,
                dst_tmp_vector.template AsType<dst_vector_t>()[Number<0>{}]);

            constexpr auto move_on_dim = [&]() constexpr {
                StaticallyIndexedArray<bool, nDim> move_on_dim_;

                static_for<0, nDim, 1>{}([&](auto i) {
                    move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;

                    static_for<i + 1, nDim, 1>{}([&](auto j) {
                        move_on_dim_(i) &=
                            ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
                    });
                });

                return move_on_dim_;
            }();

            // move
            static_for<0, nDim, 1>{}([&](auto i) {
                if constexpr(move_on_dim[i])
                {
                    if constexpr(forward_sweep[i])
                    {
                        move_tensor_coordinate(
                            dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
                    }
                    else
                    {
                        move_tensor_coordinate(
                            dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
                    }
                }
            });
        });

        // move dst coordinate back to the slice origin (or not)
        if constexpr(DstResetCoordinateAfterRun)
        {
            const auto dst_reset_step =
                make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());

            move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
        }
    }

    template <typename SrcBuffer>
    __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
    {
        constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform();

        constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};

        constexpr auto src_step_hacks =
            make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
                       generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));

        RunRead(src_desc, src_buf, src_step_hacks);
    }

    template <typename DstBuffer>
    __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
    {
        constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform();

        constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};

        constexpr auto dst_step_hacks =
            make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
                       generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));

        RunWrite(dst_desc, dst_buf, dst_step_hacks);
    }

    __device__ static constexpr auto GetSrcCoordinateResetStep()
    {
        constexpr auto I0 = Number<0>{};

        // scalar per access on each dim
        // TODO: don't use lambda_scalar_per_access
        constexpr auto src_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});

        constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;

        constexpr auto src_dim_access_order = SrcDimAccessOrder{};

        constexpr auto ordered_src_access_lengths =
            container_reorder_given_new2old(src_access_lengths, src_dim_access_order);

        // judge whether the last iteration moved forward or backward on each dim
        constexpr auto forward_sweep = [&]() {
            StaticallyIndexedArray<bool, nDim> forward_sweep_;

            forward_sweep_(I0) = true;

            static_for<1, nDim, 1>{}([&](auto i) {
                index_t tmp = ordered_src_access_lengths[I0] - 1;

                static_for<1, i, 1>{}([&](auto j) {
                    tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
                });

                forward_sweep_(i) = tmp % 2 == 0;
            });

            return forward_sweep_;
        }();

        // calculate the src data index after the last iteration of RunRead(), assuming it
        // has not been reset by RunRead()
        constexpr auto src_data_idx = [&]() {
            Index ordered_idx;

            static_for<0, nDim, 1>{}([&](auto i) {
                ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0;
            });

            return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
                   src_scalar_per_access;
        }();

        //
        constexpr auto reset_src_data_step = [&]() {
            Index reset_src_data_step_;

            static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; });

            return reset_src_data_step_;
        }();

        return reset_src_data_step;
    }

    __device__ static constexpr auto GetDstCoordinateResetStep()
    {
        constexpr auto I0 = Number<0>{};

        // scalar per access on each dim
        // TODO: don't use lambda_scalar_per_access
        constexpr auto dst_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});

        constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;

        constexpr auto dst_dim_access_order = DstDimAccessOrder{};

        constexpr auto ordered_dst_access_lengths =
            container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);

        // judge whether the last iteration moved forward or backward on each dim
        constexpr auto forward_sweep = [&]() {
            StaticallyIndexedArray<bool, nDim> forward_sweep_;

            forward_sweep_(I0) = true;

            static_for<1, nDim, 1>{}([&](auto i) {
                index_t tmp = ordered_dst_access_lengths[I0] - 1;

                static_for<1, i, 1>{}([&](auto j) {
                    tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
                });

                forward_sweep_(i) = tmp % 2 == 0;
            });

            return forward_sweep_;
        }();

        // calculate the dst data index after the last iteration of RunWrite(), assuming it
        // has not been reset by RunWrite()
        constexpr auto dst_data_idx = [&]() {
            Index ordered_idx;

            static_for<0, nDim, 1>{}([&](auto i) {
                ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0;
            });

            return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
                   dst_scalar_per_access;
        }();

        //
        constexpr auto reset_dst_data_step = [&]() {
            Index reset_dst_data_step_;

            static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; });

            return reset_dst_data_step_;
        }();

        return reset_dst_data_step;
    }

    // src_slice_origin_step_idx needs to be known at compile-time, for performance reasons
    __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
                                       const Index& src_slice_origin_step_idx)
    {
        // if the src coord was not reset by RunRead(), the step needs to be adjusted here
        const auto adjusted_step_idx =
            SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
                                       : src_slice_origin_step_idx + GetSrcCoordinateResetStep();

        // is it OK to construct a new step every time?
        const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);

        move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
    }

    // src_slice_origin_step_idx needs to be known at compile-time, for performance reasons
    template <typename SrcMoveSliceWindowStepHack>
    __device__ void
    MoveSrcSliceWindow(const SrcDesc& src_desc,
                       const Index& src_slice_origin_step_idx,
                       const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
    {
        // if the src coord was not reset by RunRead(), the step needs to be adjusted here
        const auto adjusted_step_idx =
            SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
                                       : src_slice_origin_step_idx + GetSrcCoordinateResetStep();

        // is it OK to construct a new step every time?
        const auto adjusted_step = make_tensor_coordinate_step(
            src_desc, adjusted_step_idx, src_move_slice_window_step_hack);

        move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
    }
    // dst_slice_origin_step_idx needs to be known at compile-time, for performance reasons
    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
                                       const Index& dst_slice_origin_step_idx)
    {
        // if the dst coord was not reset by RunWrite(), the step needs to be adjusted here
        const auto adjusted_step_idx =
            DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
                                       : dst_slice_origin_step_idx + GetDstCoordinateResetStep();

        // is it OK to construct a new step every time?
        const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);

        move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
    }

    private:
    static constexpr auto buffer_desc_ =
        make_naive_tensor_descriptor_packed(sequence_to_tuple_of_number(SliceLengths{}));

    static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();

    StaticBuffer<AddressSpaceEnum::Vgpr, SrcData, buffer_size_, true> buffer_;

    SrcCoord src_coord_;
    DstCoord dst_coord_;
};

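// A minimal usage sketch for ThreadwiseTensorSliceTransfer_v3 (illustrative only; KPerLoop,
// the buffers, and the descriptors are assumptions). Data is staged through the private VGPR
// buffer_: RunRead() fills it from the source, RunWrite() drains it to the destination, and
// the two slice windows move independently:
//
//   auto copy = ThreadwiseTensorSliceTransfer_v3<
//       Sequence<1, 8>, InMemoryDataOperationEnum::Set,
//       float, float, decltype(src_desc), decltype(dst_desc),
//       Sequence<0, 1>, Sequence<0, 1>, // src/dst DimAccessOrder
//       1, 1,                           // src/dst vector dim
//       4, 4,                           // src/dst scalars per vector
//       1, 1,                           // src/dst scalar stride in vector
//       true, true>{                    // reset coordinates after each run
//       src_desc, make_multi_index(0, 0), dst_desc, make_multi_index(0, 0)};
//
//   for(index_t k = 0; k < K; k += KPerLoop)
//   {
//       copy.RunRead(src_desc, src_global_buf); // global -> VGPR staging buffer
//       copy.RunWrite(dst_desc, dst_lds_buf);   // VGPR staging buffer -> LDS
//       copy.MoveSrcSliceWindow(src_desc, make_multi_index(0, KPerLoop));
//       copy.MoveDstSliceWindow(dst_desc, make_multi_index(0, KPerLoop));
//   }
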
// Assume:
//   1. src:
//     1. SrcDesc is known at compile-time
//     2. SrcBuffer is DynamicBuffer
//     3. src_ref_idx is known at run-time
//     4. SrcRefToOriginDisplacement is known at compile-time
//     5. uses #-step
//   2. dst:
//     1. DstDesc is known at compile-time
//     2. DstBuffer is StaticBuffer
//     3. DstOriginIdx is known at compile-time
//     4. uses direct address calculation
//   3. vector access on src
template <typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename SliceLengths,
          typename DimAccessOrder,
          index_t SrcVectorDim,
          index_t SrcScalarPerVector,
          index_t SrcScalarStrideInVector,
          typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                             bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v4
{
    static constexpr index_t nDim = SliceLengths::Size();

    using Index = MultiIndex<nDim>;

    using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));

    using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));

    static constexpr index_t PackedSize = []() {
        if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
            return 2;
        else
            return 1;
    }();

    __device__ constexpr ThreadwiseTensorSliceTransfer_v4(const Index& src_ref_idx)
        : src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx))
    {
        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                      "wrong! SrcDesc and DstDesc need to be known at compile-time");

        if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t> ||
                     is_same_v<remove_cvref_t<SrcData>, f4x2_pk_t>)
        {
            static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1");
        }
    }

    template <typename SrcRefToOriginDisplacement,
              typename DstOriginIdx,
              typename SrcBuffer,
              typename DstBuffer>
    __device__ void Run(const SrcDesc&,
                        const SrcRefToOriginDisplacement&,
                        const SrcBuffer& src_buf,
                        const DstDesc&,
                        const DstOriginIdx&,
                        DstBuffer& dst_buf) const
    {
        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                      "wrong! SrcDesc and DstDesc need to be known at compile-time");

        static_assert(
            is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value &&
                is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
            "wrong! SrcBuffer or DstBuffer data type is wrong");

        static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer needs to be a StaticBuffer");

        static_assert(is_known_at_compile_time<remove_cvref_t<SrcRefToOriginDisplacement>>::value &&
                          is_known_at_compile_time<remove_cvref_t<DstOriginIdx>>::value,
                      "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
                      "at compile-time");

        // SrcDesc and DstDesc are known at compile-time
        constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
        constexpr auto dst_desc = remove_cvref_t<DstDesc>{};

        // SrcOriginToRefDistance and DstOriginToRefDistance are known at compile-time
        constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{});
        constexpr auto dst_origin_idx             = to_multi_index(DstOriginIdx{});

        // scalar per access of each dim
        constexpr auto src_scalar_per_access = generate_sequence_v2(
            [&](auto i) constexpr {
                if constexpr(i == SrcVectorDim)
                {
                    return Number<SrcScalarPerVector>{};
                }
                else
                {
                    return Number<1>{};
                }
            },
            Number<nDim>{});

        // scalar step (if stepping on SrcVectorDim) of each dim
        constexpr auto src_scalar_step_in_vector = generate_sequence_v2(
            [&](auto i) constexpr {
                if constexpr(i == SrcVectorDim)
                {
                    return Number<1>{};
                }
                else
                {
                    return Number<0>{};
                }
            },
            Number<nDim>{});

        constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access;

        constexpr auto dim_access_order = DimAccessOrder{};

        constexpr auto ordered_access_lengths =
            container_reorder_given_new2old(access_lengths, dim_access_order);

        static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
#if 0
            // TODO: unable to compile
            // position in slice window
            constexpr auto data_to_origin_disp_idx =
                container_reorder_given_old2new(ordered_access_idx, dim_access_order) *
                src_scalar_per_access;
#else
            // position in slice window
            constexpr auto data_to_origin_disp_idx =
                ordered_access_idx.ReorderGivenOld2New(dim_access_order) * src_scalar_per_access;
#endif
            // src coordinate
            constexpr auto src_ref_to_data_disp_idx =
                src_ref_to_origin_disp_idx + data_to_origin_disp_idx;

            constexpr auto src_ref_to_data_disp_coord_step =
                make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx);

            auto src_data_coord = src_ref_coord_;

            move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);

            vector_type_maker_t<SrcData, SrcScalarPerVector / PackedSize> src_tmp_vector;

            using src_vector_t = typename decltype(src_tmp_vector)::type;

            const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
                src_desc, src_data_coord);

            // copy data from src_buf into src_tmp_vector
            if constexpr(SrcBuffer::IsDynamicBuffer())
            {
                src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
                    src_buf.template Get<src_vector_t>(src_data_coord.GetOffset() / PackedSize,
                                                       is_src_valid);
            }
            else if constexpr(SrcBuffer::IsStaticBuffer())
            {
                static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
                    constexpr index_t src_offset = src_desc.CalculateOffset(
                        src_ref_to_origin_disp_idx + data_to_origin_disp_idx +
                        i * src_scalar_step_in_vector);

                    src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset>{}];
                });
            }

            if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value)
            {
                // copy data from src_tmp_vector to dst_tmp_vector (casting from SrcData to
                // DstData)
                vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;

                constexpr index_t pack_size = 8;

                static_assert(SrcScalarPerVector % pack_size == 0, "");

                using src_v_t = typename vector_type_maker_t<SrcData, pack_size / PackedSize>::type;
                using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;

                static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
                    ck::tensor_operation::element_wise::PassThroughPack8{}(
                        dst_tmp_vector.template AsType<dst_v_t>()(i),
                        src_tmp_vector.template AsType<src_v_t>()[i]);
                });

                // copy data from dst_tmp_vector into dst_buf
                static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
                    constexpr index_t dst_offset = dst_desc.CalculateOffset(
                        dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);

                    dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
                });
            }
            else if constexpr(is_same<remove_cvref_t<SrcData>, f8_t>::value &&
                              is_same<remove_cvref_t<DstData>, half_t>::value &&
                              SrcScalarPerVector % 2 == 0)
            {
                // copy data from src_tmp_vector to dst_tmp_vector (casting from SrcData to
                // DstData)
                vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;

                constexpr index_t pack_size = 2;

                using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
                using src_v_t = typename vector_type_maker_t<SrcData, pack_size>::type;
                static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
                    ck::tensor_operation::element_wise::PassThroughPack2{}(
                        dst_tmp_vector.template AsType<dst_v_t>()(i),
                        src_tmp_vector.template AsType<src_v_t>()[i]);
                });

                // copy data from dst_tmp_vector into dst_buf
                static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
                    constexpr index_t dst_offset = dst_desc.CalculateOffset(
                        dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);

                    dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
                });
            }
            else
            {
                // copy data from src_tmp_vector to dst_tmp_vector (casting from SrcData to
                // DstData)
                vector_type_maker_t<DstData, SrcScalarPerVector / PackedSize> dst_tmp_vector;

                // TODO: if SrcData and DstData are vector types, then static_cast may not compile
                static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) {
                    dst_tmp_vector.template AsType<DstData>()(i) =
                        type_convert<DstData>(src_tmp_vector.template AsType<SrcData>()[i]);
                });

                // copy data from dst_tmp_vector into dst_buf
                static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) {
                    constexpr index_t dst_offset = dst_desc.CalculateOffset(
                        dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);

                    dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
                });
            }
        });
    }

    // Fused-scale variant
    template <typename SrcRefToOriginDisplacement,
              typename DstOriginIdx,
              typename SrcBuffer,
              typename DstBuffer>
    __device__ void Run(const SrcDesc&,
                        const SrcRefToOriginDisplacement&,
                        const SrcBuffer& src_buf,
                        const DstData& scale,
                        const DstDesc&,
                        const DstOriginIdx&,
                        DstBuffer& dst_buf) const
    {
        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                      "wrong! SrcDesc and DstDesc need to be known at compile-time");

        static_assert(
            is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value &&
                is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
            "wrong! SrcBuffer or DstBuffer data type is wrong");

        static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer needs to be a StaticBuffer");

        static_assert(is_known_at_compile_time<remove_cvref_t<SrcRefToOriginDisplacement>>::value &&
                          is_known_at_compile_time<remove_cvref_t<DstOriginIdx>>::value,
                      "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
                      "at compile-time");

        // SrcDesc and DstDesc are known at compile-time
        constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
        constexpr auto dst_desc = remove_cvref_t<DstDesc>{};

        // SrcOriginToRefDistance and DstOriginToRefDistance are known at compile-time
        constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{});
        constexpr auto dst_origin_idx             = to_multi_index(DstOriginIdx{});

        // scalar per access of each dim
        constexpr auto src_scalar_per_access = generate_sequence_v2(
            [&](auto i) constexpr {
                if constexpr(i == SrcVectorDim)
                {
                    return Number<SrcScalarPerVector>{};
                }
                else
                {
                    return Number<1>{};
                }
            },
            Number<nDim>{});

        // scalar step (if stepping on SrcVectorDim) of each dim
        constexpr auto src_scalar_step_in_vector = generate_sequence_v2(
            [&](auto i) constexpr {
                if constexpr(i == SrcVectorDim)
                {
                    return Number<1>{};
                }
                else
                {
                    return Number<0>{};
                }
            },
            Number<nDim>{});

        constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access;

        constexpr auto dim_access_order = DimAccessOrder{};

        constexpr auto ordered_access_lengths =
            container_reorder_given_new2old(access_lengths, dim_access_order);

        static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
#if 0
            // TODO: unable to compile
            // position in slice window
            constexpr auto data_to_origin_disp_idx =
                container_reorder_given_old2new(ordered_access_idx, dim_access_order) *
                src_scalar_per_access;
#else
            // position in slice window
            constexpr auto data_to_origin_disp_idx =
                ordered_access_idx.ReorderGivenOld2New(dim_access_order) * src_scalar_per_access;
#endif
            // src coordinate
            constexpr auto src_ref_to_data_disp_idx =
                src_ref_to_origin_disp_idx + data_to_origin_disp_idx;

            constexpr auto src_ref_to_data_disp_coord_step =
                make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx);

            auto src_data_coord = src_ref_coord_;

            move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);

            vector_type_maker_t<SrcData, SrcScalarPerVector / PackedSize> src_tmp_vector;

            using src_vector_t = typename decltype(src_tmp_vector)::type;

            const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
                src_desc, src_data_coord);

            // copy data from src_buf into src_tmp_vector
            if constexpr(SrcBuffer::IsDynamicBuffer())
            {
                src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
                    src_buf.template Get<src_vector_t>(src_data_coord.GetOffset() / PackedSize,
                                                       is_src_valid);
            }
            else if constexpr(SrcBuffer::IsStaticBuffer())
            {
                static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
                    constexpr index_t src_offset = src_desc.CalculateOffset(
                        src_ref_to_origin_disp_idx + data_to_origin_disp_idx +
                        i * src_scalar_step_in_vector);

                    src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset>{}];
                });
            }

            if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value)
            {
                // copy data from src_tmp_vector to dst_tmp_vector (casting from SrcData to
                // DstData)
                vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
                vector_type<DstData, 2> scale_vector;
                scale_vector.template AsType<DstData>()(Number<0>{}) = scale;
                scale_vector.template AsType<DstData>()(Number<1>{}) = scale;

                constexpr index_t pack_size = 8;

                static_assert(SrcScalarPerVector % pack_size == 0, "");

                using src_v_t = typename vector_type_maker_t<SrcData, pack_size / PackedSize>::type;
                using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
                using scale_v_t = typename vector_type_maker_t<DstData, 2>::type;

                static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
                    ck::tensor_operation::element_wise::DequantPack8{}(
                        dst_tmp_vector.template AsType<dst_v_t>()(i),
                        src_tmp_vector.template AsType<src_v_t>()[i],
                        scale_vector.template AsType<scale_v_t>()[Number<0>{}]);
                });

                // copy data from dst_tmp_vector into dst_buf
                static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
                    constexpr index_t dst_offset = dst_desc.CalculateOffset(
                        dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);

                    dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
                });
            }
            else if constexpr(is_same<remove_cvref_t<SrcData>, f8_t>::value &&
                              is_same<remove_cvref_t<DstData>, half_t>::value &&
                              SrcScalarPerVector % 2 == 0)
            {
                // copy data from src_tmp_vector to dst_tmp_vector (casting from SrcData to
                // DstData)
                vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;

                constexpr index_t pack_size = 2;

                using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
                using src_v_t = typename vector_type_maker_t<SrcData, pack_size>::type;
                static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
                    ck::tensor_operation::element_wise::PassThroughPack2{}(
                        dst_tmp_vector.template AsType<dst_v_t>()(i),
                        src_tmp_vector.template AsType<src_v_t>()[i]);
                });

                // copy data from dst_tmp_vector into dst_buf
                static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
                    constexpr index_t dst_offset = dst_desc.CalculateOffset(
                        dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);

                    dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
                });
            }
            else
            {
                // copy data from src_tmp_vector to dst_tmp_vector (casting from SrcData to
                // DstData)
                vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;

                // TODO: if SrcData and DstData are vector types, then static_cast may not compile
                static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
                    dst_tmp_vector.template AsType<DstData>()(i) =
                        type_convert<DstData>(src_tmp_vector.template AsType<SrcData>()[i]);
                });

                // copy data from dst_tmp_vector into dst_buf
                static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
                    constexpr index_t dst_offset = dst_desc.CalculateOffset(
                        dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);

                    dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
                });
            }
        });
    }

    template <typename SrcSliceMoveStepIdx>
    __device__ void MoveSrcSliceWindow(const SrcDesc&,
                                       const SrcSliceMoveStepIdx& src_slice_move_step_idx)
    {
        constexpr auto src_desc = SrcDesc{};

        const auto src_slice_move_step_iter =
            make_tensor_coordinate_step(src_desc, to_multi_index(src_slice_move_step_idx));

        move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
    }
    __device__ void SetSrcCoord(const Index& src_ref_idx)
    {
        src_ref_coord_ = make_tensor_coordinate(SrcDesc{}, src_ref_idx);
    }

    private:
    SrcCoord src_ref_coord_;
};

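// A minimal usage sketch for ThreadwiseTensorSliceTransfer_v4 (illustrative only; k0, wave_m,
// and the buffers are assumptions, and k0 must be a compile-time constant). v4 keeps a single
// run-time reference coordinate; every access is addressed as a compile-time displacement from
// it, so the per-access index arithmetic folds away. A typical pattern is reading an LDS
// fragment into registers inside a GEMM inner loop:
//
//   auto frag_load = ThreadwiseTensorSliceTransfer_v4<
//       float, float, decltype(src_desc), decltype(dst_desc),
//       Sequence<1, 8>,                  // SliceLengths
//       Sequence<0, 1>,                  // DimAccessOrder
//       1,                               // SrcVectorDim
//       4,                               // SrcScalarPerVector
//       1>{make_multi_index(wave_m, 0)}; // run-time reference coordinate only
//
//   frag_load.Run(src_desc,
//                 make_tuple(Number<0>{}, Number<k0>{}), // compile-time displacement
//                 src_lds_buf,
//                 dst_desc,
//                 make_tuple(Number<0>{}, Number<0>{}),  // compile-time dst origin
//                 dst_thread_buf);
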
1708 template <typename SrcData,
1709  typename DstData,
1710  typename SrcDesc,
1711  typename DstDesc,
1712  typename ElementwiseOperation,
1713  typename SliceLengths,
1714  typename DimAccessOrder,
1715  index_t DstVectorDim,
1716  index_t DstScalarPerVector,
1717  typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
1718  bool>::type = false>
1720 {
1721  static constexpr index_t nDim = SliceLengths::Size();
1722 
1724 
1725  static constexpr index_t PackedSize = []() {
1726  if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
1727  return 2;
1728  else
1729  return 1;
1730  }();
1731 
1733  const ElementwiseOperation& element_op)
1734  : element_op_{element_op}
1735  {
1736  static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
1737  "wrong! Desc need to known at compile-time");
1738 
1739  static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
1740  "wrong! Not divisible");
1741  }
1742 
1743  template <typename SrcSliceOriginIdx,
1744  typename DstSliceOriginIdx,
1745  typename SrcBuffer,
1746  typename DstBuffer>
1747  __device__ void Run(const SrcDesc&,
1748  const SrcSliceOriginIdx&,
1749  const SrcBuffer& src_buf,
1750  const DstDesc&,
1751  const DstSliceOriginIdx&,
1752  DstBuffer& dst_buf) const
1753  {
1754  static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
1755  "wrong! Desc need to known at compile-time");
1756 
1759  "wrong! SliceOrigin need to known at compile-time");
1760 
1761  static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(),
1762  "wrong! Buffer need to be StaticBuffer");
1763 
1764  // SrcDesc and src_slice_origin_idx are known at compile-time
1765  constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
1766  constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
1767  constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
1768  constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{});
1769 
1770  // scalar per access on each dim
1771  constexpr auto dst_scalar_per_access = generate_sequence(
1772  detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
1773 
1774  constexpr auto dst_scalar_step_in_vector =
1775  generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
1776 
1777  using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
1778  DimAccessOrder,
1779  remove_cv_t<decltype(dst_scalar_per_access)>>;
1780 
1781  static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
1782  "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
1783 
1784  constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
1785 
1786  if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value)
1787  {
1788  static_for<0, num_access, 1>{}([&](auto idx_1d) {
1789  typename vector_type_maker<SrcData, DstScalarPerVector / PackedSize>::type
1790  src_tmp_vector;
1791 
1792  constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
1793 
1794  // copy data from src_buf into src_tmp_vector
1795  static_for<0, DstScalarPerVector / PackedSize, 1>{}([&](auto i) {
1796  constexpr index_t src_offset = src_desc.CalculateOffset(
1797  src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
1798 
1799  src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset>{}];
1800  });
1801 
1802  // copy data from src_tmp_vector to dst_tmp_vector (cast data from SrcData to
1803  // DstData)
1804  typename vector_type_maker<DstData, DstScalarPerVector>::type dst_tmp_vector;
1805 
1806  constexpr index_t pack_size = 8;
1807 
1808  static_assert(DstScalarPerVector % pack_size == 0, "");
1809 
1810  using src_v_t = typename vector_type_maker_t<SrcData, pack_size / PackedSize>::type;
1811  using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
1812 
1813  static_for<0, DstScalarPerVector / pack_size, 1>{}([&](auto i) {
1814  ck::tensor_operation::element_wise::PassThroughPack8{}(
1815  dst_tmp_vector.template AsType<dst_v_t>()(i),
1816  src_tmp_vector.template AsType<src_v_t>()[i]);
1817  });
1818 
1819  // copy data from dst_tmp_vector into dst_buf
1820  static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
1821  constexpr index_t dst_offset = dst_desc.CalculateOffset(
1822  dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
1823 
1824  dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
1825  });
1826  });
1827  }
1828  else
1829  {
1830  static_for<0, num_access, 1>{}([&](auto idx_1d) {
1831  constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
1832 
1833  // copy data from src_buf into dst_buf
1834  static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
1835  constexpr index_t src_offset = src_desc.CalculateOffset(
1836  src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
1837 
1838  constexpr index_t dst_offset = dst_desc.CalculateOffset(
1839  dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
1840 
1841  DstData v;
1842 
1843  // apply element-wise operation
1844  element_op_(v, src_buf[Number<src_offset>{}]);
1845 
1846  // apply type convert
1847  dst_buf(Number<dst_offset>{}) = v;
1848  });
1849  });
1850  }
1851  }
1852 
1853  ElementwiseOperation element_op_;
1854 };
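The pk_i4_t path above relies on two signed 4-bit values sharing one byte, which is why PackedSize is 2 and the source vector is addressed in element pairs. A minimal standalone sketch of that storage convention (an illustration under that assumption, not CK's actual pk_i4_t implementation):

#include <cstdint>

// Low nibble of a packed-int4 byte, sign-extended to 8 bits.
constexpr int8_t unpack_lo_i4(uint8_t byte)
{
    const int n = byte & 0xF;
    return static_cast<int8_t>(n >= 8 ? n - 16 : n);
}

// High nibble of a packed-int4 byte, sign-extended to 8 bits.
constexpr int8_t unpack_hi_i4(uint8_t byte)
{
    const int n = byte >> 4;
    return static_cast<int8_t>(n >= 8 ? n - 16 : n);
}

static_assert(unpack_lo_i4(0xF2) == 2, "low nibble of 0xF2 holds +2");
static_assert(unpack_hi_i4(0xF2) == -1, "high nibble of 0xF2 holds -1");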
1855 
1856 // Specialized for gfx11
1857 // A single wave32 is composed of two rows of 16 lanes
1858 // Data exchange is allowed between these two rows
1859 // Each RowLane's Dst buffer is filled from two Src buffers:
1860 // SrcA: the thread buffer held by this RowLane on this row
1861 // SrcB: the thread buffer held by this RowLane's counterpart on the other row
1862 template <typename SrcData,
1863  typename DstData,
1864  typename SrcDesc,
1865  typename DstDesc,
1866  typename ElementwiseOperation,
1867  typename SliceLengths,
1868  typename DimAccessOrder,
1869  index_t DstVectorDim,
1870  index_t DstScalarPerVector,
1871  uint32_t LowEightRowlaneIdx,
1872  uint32_t HighEightRowLaneIdx,
1873  bool IntraRowSwizzlePerm,
1874  typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
1875  bool>::type = false>
1876 struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
1877 {
1878  static constexpr index_t nDim = SliceLengths::Size();
1879 
1880  using Index = MultiIndex<nDim>;
1881 
1882  __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow(const Index& src_idx)
1883  {
1884  static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
1885  "wrong! Desc need to known at compile-time");
1886 
1887  static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
1888  "wrong! Not divisible");
1889  ignore = src_idx;
1890  }
1891 
1892  template <typename SrcSliceOriginIdx,
1893  typename DstSliceOriginIdx,
1894  typename SrcBuffer,
1895  typename DstBuffer>
1896  __device__ void Run(const SrcDesc&,
1897  const SrcSliceOriginIdx&,
1898  const SrcBuffer& src_buf,
1899  const DstDesc&,
1900  const DstSliceOriginIdx&,
1901  DstBuffer& dst_buf) const
1902  {
1903  static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
1904  "wrong! Desc need to known at compile-time");
1905 
1908  "wrong! SliceOrigin need to known at compile-time");
1909 
1910  static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(),
1911  "wrong! Buffer need to be StaticBuffer");
1912 
1913  // SrcDesc and src_slice_origin_idx are known at compile-time
1914  constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
1915  constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
1916  constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
1917  constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{});
1918 
1919  // scalar per access on each dim
1920  constexpr auto dst_scalar_per_access = generate_sequence(
1921  detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
1922 
1923  constexpr auto dst_scalar_step_in_vector =
1924  generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
1925 
1926  using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
1927  DimAccessOrder,
1928  remove_cv_t<decltype(dst_scalar_per_access)>>;
1929 
1930  static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
1931  "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
1932 
1933  constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
1934 
1935  static_for<0, num_access, 1>{}([&](auto idx_1d) {
1936  constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
1937 
1938  // copy data from src_buf into dst_buf
1939  static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
1940  // note: offset computation may fail to be constexpr if src_desc contains a merge transform
1941  constexpr index_t src_offset = src_desc.CalculateOffset(
1942  src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
1943 
1944  constexpr index_t dst_offset = dst_desc.CalculateOffset(
1945  dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
1946 
1947  SrcData v_this_row, v_theother_row;
1948  // int temporary: the permlane intrinsics operate on 32-bit values
1949  int temp = 0;
1950 
1951  // apply element-wise operation
1952  element_op_(v_this_row, src_buf[Number<src_offset>{}]);
1953 
1954  // apply intra-row permute.
1955  if constexpr(IntraRowSwizzlePerm)
1956  {
1957  temp = __builtin_amdgcn_permlane16(
1958  temp, type_convert_sp<int>(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0);
1959  v_this_row = type_convert_sp<SrcData>(temp);
1960  }
1961 
1962  // apply inter-row permute.
1963  temp = __builtin_amdgcn_permlanex16(temp,
1964  type_convert_sp<int>(v_this_row),
1965  LowEightRowlaneIdx,
1966  HighEightRowLaneIdx,
1967  1,
1968  0);
1969  v_theother_row = type_convert_sp<SrcData>(temp);
1970 
1971  if(get_thread_local_1d_id() % 32 < 16)
1972  {
1973  // apply type convert
1974  dst_buf(Number<dst_offset>{}) = type_convert_sp<DstData>(v_this_row);
1975  dst_buf(Number<dst_offset + DstScalarPerVector>{}) =
1976  type_convert_sp<DstData>(v_theother_row);
1977  }
1978  else
1979  {
1980  // apply type convert
1981  dst_buf(Number<dst_offset + DstScalarPerVector>{}) =
1982  type_convert_sp<DstData>(v_this_row);
1983  dst_buf(Number<dst_offset>{}) = type_convert_sp<DstData>(v_theother_row);
1984  }
1985  });
1986  });
1987  }
1988  ElementwiseOperation element_op_{};
1989 };
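For orientation, a host-side model of the inter-row exchange that the permlanex16-based Run() above performs, ignoring the additional nibble-selector permutation: every lane keeps its own value and also receives the value held by the matching lane of the other 16-lane row, which is how one destination buffer is filled from both source rows. The sketch below is hypothetical and standalone, not part of this header:

#include <array>

// Host-side model: a wave32 viewed as two rows of 16 lanes. Each lane's
// "other row" value comes from the lane at the same position in the
// opposite row.
template <typename T>
std::array<T, 32> exchange_rows(const std::array<T, 32>& src)
{
    std::array<T, 32> other{};
    for(int lane = 0; lane < 32; ++lane)
    {
        const int partner = lane < 16 ? lane + 16 : lane - 16;
        other[lane] = src[partner];
    }
    return other;
}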
1990 
1991 // Specialized for gfx12
1992 template <typename SrcData,
1993  typename DstData,
1994  typename SrcDesc,
1995  typename DstDesc,
1996  typename ElementwiseOperation,
1997  typename SliceLengths,
1998  typename DimAccessOrder,
1999  index_t DstVectorDim,
2000  index_t DstScalarPerVector,
2001  bool IntraRowSwizzlePerm,
2002  typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
2003  bool>::type = false>
2004 struct ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow
2005 {
2006  static constexpr index_t nDim = SliceLengths::Size();
2007 
2008  using Index = MultiIndex<nDim>;
2009 
2010  __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow(const Index& src_idx)
2011  {
2012  static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
2013  "wrong! Desc need to known at compile-time");
2014 
2015  static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
2016  "wrong! Not divisible");
2017  ignore = src_idx;
2018  }
2019 
2020  template <typename SrcSliceOriginIdx,
2021  typename DstSliceOriginIdx,
2022  typename SrcBuffer,
2023  typename DstBuffer>
2024  __device__ void Run(const SrcDesc&,
2025  const SrcSliceOriginIdx&,
2026  const SrcBuffer& src_buf,
2027  const DstDesc&,
2028  const DstSliceOriginIdx&,
2029  DstBuffer& dst_buf) const
2030  {
2031  static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
2032  "wrong! Desc need to known at compile-time");
2033 
2036  "wrong! SliceOrigin need to known at compile-time");
2037 
2038  static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(),
2039  "wrong! Buffer need to be StaticBuffer");
2040 
2041  // SrcDesc and src_slice_origin_idx are known at compile-time
2042  constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
2043  constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
2044  constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
2045  constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{});
2046 
2047  // scalar per access on each dim
2048  constexpr auto dst_scalar_per_access = generate_sequence(
2049  detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
2050 
2051  constexpr auto dst_scalar_step_in_vector =
2052  generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
2053 
2054  using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
2055  DimAccessOrder,
2056  remove_cv_t<decltype(dst_scalar_per_access)>>;
2057 
2058  static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
2059  "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
2060 
2061  constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
2062 
2063  static_for<0, num_access, 1>{}([&](auto idx_1d) {
2064  constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
2065 
2066  // copy data from src_buf into dst_buf
2067  static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
2068  // note: offset computation may fail to be constexpr if src_desc contains a merge transform
2069  constexpr index_t src_offset = src_desc.CalculateOffset(
2070  src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
2071 
2072  constexpr index_t dst_offset = dst_desc.CalculateOffset(
2073  dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
2074 
2075  SrcData v_this_row;
2076  // int temporary: the permlane intrinsic operates on 32-bit values
2077  int temp = 0;
2078 
2079  // apply element-wise operation
2080  element_op_(v_this_row, src_buf[Number<src_offset>{}]);
2081 
2082  // apply intra-row permute.
2083  if constexpr(IntraRowSwizzlePerm)
2084  {
2085  temp = __builtin_amdgcn_permlane16(
2086  temp, type_convert_sp<int>(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0);
2087  v_this_row = type_convert_sp<SrcData>(temp);
2088  }
2089 
2090  // apply type convert
2091  dst_buf(Number<dst_offset>{}) = type_convert_sp<DstData>(v_this_row);
2092  });
2093  });
2094  }
2095  ElementwiseOperation element_op_{};
2096 };
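The selector constants 0xb3a29180 / 0xf7e6d5c4 passed to __builtin_amdgcn_permlane16 above pack one 4-bit source-lane select per destination lane (the first selector covers lanes 0-7, the second lanes 8-15, least-significant nibble first). Under that reading, which should be checked against the ISA documentation, the pair interleaves the lower and upper eight lanes of a 16-lane row. A standalone sketch of the decoding:

#include <cstdint>

// Decode the 4-bit source-lane select for one destination lane of a
// permlane16-style selector pair (lanes 0-7 in `lo`, lanes 8-15 in `hi`).
constexpr unsigned src_lane_for(uint32_t lo, uint32_t hi, unsigned dst_lane)
{
    const uint32_t sel = dst_lane < 8 ? lo : hi;
    return (sel >> (4 * (dst_lane % 8))) & 0xF;
}

// With the selectors used above, destination lanes read source lanes
// 0, 8, 1, 9, ..., 7, 15: an interleave of the row's two 8-lane halves.
static_assert(src_lane_for(0xb3a29180u, 0xf7e6d5c4u, 1) == 8, "lane 1 reads lane 8");
static_assert(src_lane_for(0xb3a29180u, 0xf7e6d5c4u, 2) == 1, "lane 2 reads lane 1");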
2097 
2098 } // namespace ck