Composable Kernel: include/ck/library/utility/host_tensor.hpp Source File

host_tensor.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <algorithm>
7 #include <cassert>
8 #include <iostream>
9 #include <fstream>
10 #include <numeric>
11 #include <random>
12 #include <thread>
13 #include <utility>
14 #include <vector>
15 
16 #include "ck/utility/data_type.hpp"
17 #include "ck/utility/span.hpp"
19 
23 
24 template <typename Range>
25 std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim)
26 {
27  bool first = true;
28  for(auto&& v : range)
29  {
30  if(first)
31  first = false;
32  else
33  os << delim;
34  os << v;
35  }
36  return os;
37 }
38 
39 template <typename T, typename Range>
40 std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim)
41 {
42  bool first = true;
43  for(auto&& v : range)
44  {
45  if(first)
46  first = false;
47  else
48  os << delim;
49 
50  using RangeType = ck::remove_cvref_t<decltype(v)>;
51  if constexpr(std::is_same_v<RangeType, ck::f8_t> || std::is_same_v<RangeType, ck::bf8_t> ||
52  std::is_same_v<RangeType, ck::bhalf_t>)
53  {
54  os << ck::type_convert<float>(v);
55  }
56  else if constexpr(std::is_same_v<RangeType, ck::pk_i4_t> ||
57  std::is_same_v<RangeType, ck::f4x2_pk_t>)
58  {
59  const auto packed_floats = ck::type_convert<ck::float2_t>(v);
60  const ck::vector_type<float, 2> vector_of_floats{packed_floats};
61  os << vector_of_floats.template AsType<float>()[ck::Number<0>{}] << delim
62  << vector_of_floats.template AsType<float>()[ck::Number<1>{}];
63  }
64  else
65  {
66  os << static_cast<T>(v);
67  }
68  }
69  return os;
70 }
71 
72 template <typename F, typename T, std::size_t... Is>
73 auto call_f_unpack_args_impl(F f, T args, std::index_sequence<Is...>)
74 {
75  return f(std::get<Is>(args)...);
76 }
77 
78 template <typename F, typename T>
79 auto call_f_unpack_args(F f, T args)
80 {
81  constexpr std::size_t N = std::tuple_size<T>{};
82 
83  return call_f_unpack_args_impl(f, args, std::make_index_sequence<N>{});
84 }
85 
86 template <typename F, typename T, std::size_t... Is>
87 auto construct_f_unpack_args_impl(T args, std::index_sequence<Is...>)
88 {
89  return F(std::get<Is>(args)...);
90 }
91 
92 template <typename F, typename T>
93 auto construct_f_unpack_args(F, T args)
94 {
95  constexpr std::size_t N = std::tuple_size<T>{};
96 
97  return construct_f_unpack_args_impl<F>(args, std::make_index_sequence<N>{});
98 }
99 
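// Editor's usage sketch (not part of this header): these helpers forward the elements of a
// tuple-like object as individual arguments. For example,
//
//   auto add3 = [](int a, int b, int c) { return a + b + c; };
//   int r = call_f_unpack_args(add3, std::make_tuple(1, 2, 3)); // r == 6
//
// construct_f_unpack_args works the same way but invokes F's constructor instead of calling a
// function object.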
100 struct HostTensorDescriptor
101 {
102  HostTensorDescriptor() = default;
103 
104  void CalculateStrides();
105 
106  template <typename X, typename = std::enable_if_t<std::is_convertible_v<X, std::size_t>>>
107  HostTensorDescriptor(const std::initializer_list<X>& lens) : mLens(lens.begin(), lens.end())
108  {
109  this->CalculateStrides();
110  }
111 
112  HostTensorDescriptor(const std::initializer_list<ck::long_index_t>& lens)
113  : mLens(lens.begin(), lens.end())
114  {
115  this->CalculateStrides();
116  }
117 
118  template <typename Lengths,
119  typename = std::enable_if_t<
120  std::is_convertible_v<ck::ranges::range_value_t<Lengths>, std::size_t> ||
121  std::is_convertible_v<ck::ranges::range_value_t<Lengths>, ck::long_index_t>>>
122  HostTensorDescriptor(const Lengths& lens) : mLens(lens.begin(), lens.end())
123  {
124  this->CalculateStrides();
125  }
126 
127  template <typename X,
128  typename Y,
129  typename = std::enable_if_t<std::is_convertible_v<X, std::size_t> &&
130  std::is_convertible_v<Y, std::size_t>>>
131  HostTensorDescriptor(const std::initializer_list<X>& lens,
132  const std::initializer_list<Y>& strides)
133  : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
134  {
135  }
136 
137  HostTensorDescriptor(const std::initializer_list<ck::long_index_t>& lens,
138  const std::initializer_list<ck::long_index_t>& strides)
139  : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
140  {
141  }
142 
143  template <typename Lengths,
144  typename Strides,
145  typename = std::enable_if_t<
146  (std::is_convertible_v<ck::ranges::range_value_t<Lengths>, std::size_t> &&
147  std::is_convertible_v<ck::ranges::range_value_t<Strides>, std::size_t>) ||
148  (std::is_convertible_v<ck::ranges::range_value_t<Lengths>, ck::long_index_t> &&
149  std::is_convertible_v<ck::ranges::range_value_t<Strides>, ck::long_index_t>)>>
150  HostTensorDescriptor(const Lengths& lens, const Strides& strides)
151  : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
152  {
153  }
154 
155  std::size_t GetNumOfDimension() const;
156  std::size_t GetElementSize() const;
157  std::size_t GetElementSpaceSize() const;
158 
159  const std::vector<std::size_t>& GetLengths() const;
160  const std::vector<std::size_t>& GetStrides() const;
161 
162  template <typename... Is>
163  std::size_t GetOffsetFromMultiIndex(Is... is) const
164  {
165  assert(sizeof...(Is) == this->GetNumOfDimension());
166  std::initializer_list<std::size_t> iss{static_cast<std::size_t>(is)...};
167  return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
168  }
169 
170  std::size_t GetOffsetFromMultiIndex(const std::vector<std::size_t>& iss) const
171  {
172  return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
173  }
174 
175  friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc);
176 
177  private:
178  std::vector<std::size_t> mLens;
179  std::vector<std::size_t> mStrides;
180 };
181 
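// Editor's usage sketch (not part of this header): describe a 2x3x4 tensor and query a linear
// offset. CalculateStrides() is defined out of line; the strides shown assume it produces
// packed, row-major strides.
//
//   HostTensorDescriptor desc({2, 3, 4});                    // lengths {2, 3, 4}, strides {12, 4, 1}
//   std::size_t off = desc.GetOffsetFromMultiIndex(1, 2, 3); // 1*12 + 2*4 + 3*1 == 23
//   HostTensorDescriptor strided({2, 3}, {8, 1});            // lengths and explicit strides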
182 template <typename New2Old>
183 HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old(const HostTensorDescriptor& a,
184  const New2Old& new2old)
185 {
186  std::vector<std::size_t> new_lengths(a.GetNumOfDimension());
187  std::vector<std::size_t> new_strides(a.GetNumOfDimension());
188 
189  for(std::size_t i = 0; i < a.GetNumOfDimension(); i++)
190  {
191  new_lengths[i] = a.GetLengths()[new2old[i]];
192  new_strides[i] = a.GetStrides()[new2old[i]];
193  }
194 
195  return HostTensorDescriptor(new_lengths, new_strides);
196 }
197 
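// Editor's usage sketch (not part of this header): new2old maps each new dimension to the old
// dimension it came from, so {1, 0} swaps the two axes of a 2-D descriptor.
//
//   HostTensorDescriptor desc({2, 3});   // strides {3, 1} for a packed row-major layout
//   std::array<std::size_t, 2> new2old{1, 0};
//   HostTensorDescriptor t = transpose_host_tensor_descriptor_given_new2old(desc, new2old);
//   // t.GetLengths() == {3, 2}, t.GetStrides() == {1, 3}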
198 struct joinable_thread : std::thread
199 {
200  template <typename... Xs>
201  joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...)
202  {
203  }
204 
205  joinable_thread(joinable_thread&&) = default;
206  joinable_thread& operator=(joinable_thread&&) = default;
207 
208  ~joinable_thread()
209  {
210  if(this->joinable())
211  this->join();
212  }
213 };
214 
215 template <typename F, typename... Xs>
216 struct ParallelTensorFunctor
217 {
218  F mF;
219  static constexpr std::size_t NDIM = sizeof...(Xs);
220  std::array<std::size_t, NDIM> mLens;
221  std::array<std::size_t, NDIM> mStrides;
222  std::size_t mN1d;
223 
224  ParallelTensorFunctor(F f, Xs... xs) : mF(f), mLens({static_cast<std::size_t>(xs)...})
225  {
226  mStrides.back() = 1;
227  std::partial_sum(mLens.rbegin(),
228  mLens.rend() - 1,
229  mStrides.rbegin() + 1,
230  std::multiplies<std::size_t>());
231  mN1d = mStrides[0] * mLens[0];
232  }
233 
234  std::array<std::size_t, NDIM> GetNdIndices(std::size_t i) const
235  {
236  std::array<std::size_t, NDIM> indices;
237 
238  for(std::size_t idim = 0; idim < NDIM; ++idim)
239  {
240  indices[idim] = i / mStrides[idim];
241  i -= indices[idim] * mStrides[idim];
242  }
243 
244  return indices;
245  }
246 
247  void operator()(std::size_t num_thread = 1) const
248  {
249  std::size_t work_per_thread = (mN1d + num_thread - 1) / num_thread;
250 
251  std::vector<joinable_thread> threads(num_thread);
252 
253  for(std::size_t it = 0; it < num_thread; ++it)
254  {
255  std::size_t iw_begin = it * work_per_thread;
256  std::size_t iw_end = std::min((it + 1) * work_per_thread, mN1d);
257 
258  auto f = [=, *this] {
259  for(std::size_t iw = iw_begin; iw < iw_end; ++iw)
260  {
261  call_f_unpack_args(mF, GetNdIndices(iw));
262  }
263  };
264  threads[it] = joinable_thread(f);
265  }
266  }
267 };
268 
269 template <typename F, typename... Xs>
270 auto make_ParallelTensorFunctor(F f, Xs... xs)
271 {
272  return ParallelTensorFunctor<F, Xs...>(f, xs...);
273 }
274 
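// Editor's usage sketch (not part of this header): ParallelTensorFunctor linearizes an N-D
// iteration space, splits it evenly across threads, and calls f with the unpacked N-D indices
// for every point.
//
//   float buf[4][8];
//   auto fill = [&](std::size_t i, std::size_t j) { buf[i][j] = static_cast<float>(i * 8 + j); };
//   make_ParallelTensorFunctor(fill, 4, 8)(2);   // 4x8 space executed on 2 threads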
275 template <typename T>
276 struct Tensor
277 {
278  using Descriptor = HostTensorDescriptor;
279  using Data = std::vector<T>;
280 
281  template <typename X>
282  Tensor(std::initializer_list<X> lens) : mDesc(lens), mData(GetElementSpaceSize())
283  {
284  }
285 
286  template <typename X, typename Y>
287  Tensor(std::initializer_list<X> lens, std::initializer_list<Y> strides)
288  : mDesc(lens, strides), mData(GetElementSpaceSize())
289  {
290  }
291 
292  template <typename Lengths>
293  Tensor(const Lengths& lens) : mDesc(lens), mData(GetElementSpaceSize())
294  {
295  }
296 
297  template <typename Lengths, typename Strides>
298  Tensor(const Lengths& lens, const Strides& strides)
299  : mDesc(lens, strides), mData(GetElementSpaceSize())
300  {
301  }
302 
303  Tensor(const Descriptor& desc) : mDesc(desc), mData(GetElementSpaceSize()) {}
304 
305  template <typename OutT>
306  Tensor<OutT> CopyAsType() const
307  {
308  Tensor<OutT> ret(mDesc);
309 
309 
310  ck::ranges::transform(
311  mData, ret.mData.begin(), [](auto value) { return ck::type_convert<OutT>(value); });
312 
313  return ret;
314  }
315 
316  Tensor() = delete;
317  Tensor(const Tensor&) = default;
318  Tensor(Tensor&&) = default;
319 
320  ~Tensor() = default;
321 
322  Tensor& operator=(const Tensor&) = default;
323  Tensor& operator=(Tensor&&) = default;
324 
325  template <typename FromT>
326  explicit Tensor(const Tensor<FromT>& other) : Tensor(other.template CopyAsType<T>())
327  {
328  }
329  void savetxt(std::string file_name, std::string dtype = "float")
330  {
331  std::ofstream file(file_name);
332 
333  if(file.is_open())
334  {
335  for(auto& itm : mData)
336  {
337  if(dtype == "float")
338  file << ck::type_convert<float>(itm) << std::endl;
339  else if(dtype == "int")
340  file << ck::type_convert<int>(itm) << std::endl;
341  else
342  // TODO: operator<< is not implemented for all custom data
343  // types, so fall back to float to avoid a compile error.
344  file << ck::type_convert<float>(itm) << std::endl;
345  }
346  file.close();
347  }
348  else
349  {
350  // Throw an exception if the file
351  // cannot be opened.
352  throw std::runtime_error(std::string("unable to open file:") + file_name);
353  }
354  }
355  decltype(auto) GetLengths() const { return mDesc.GetLengths(); }
356 
357  decltype(auto) GetStrides() const { return mDesc.GetStrides(); }
358 
359  std::size_t GetNumOfDimension() const { return mDesc.GetNumOfDimension(); }
360 
361  std::size_t GetElementSize() const { return mDesc.GetElementSize(); }
362 
363  std::size_t GetElementSpaceSize() const
364  {
365  if constexpr(ck::is_packed_type_v<ck::remove_cvref_t<T>>)
366  {
367  return (mDesc.GetElementSpaceSize() + 1) / ck::packed_size_v<ck::remove_cvref_t<T>>;
368  }
369  else
370  {
371  return mDesc.GetElementSpaceSize();
372  }
373  }
374 
375  std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); }
376 
377  void SetZero() { ck::ranges::fill<T>(mData, T{0}); }
378 
379  template <typename F>
380  void ForEach_impl(F&& f, std::vector<size_t>& idx, size_t rank)
381  {
382  if(rank == mDesc.GetNumOfDimension())
383  {
384  f(*this, idx);
385  return;
386  }
387  // else
388  for(size_t i = 0; i < mDesc.GetLengths()[rank]; i++)
389  {
390  idx[rank] = i;
391  ForEach_impl(std::forward<F>(f), idx, rank + 1);
392  }
393  }
394 
395  template <typename F>
396  void ForEach(F&& f)
397  {
398  std::vector<size_t> idx(mDesc.GetNumOfDimension(), 0);
399  ForEach_impl(std::forward<F>(f), idx, size_t(0));
400  }
401 
402  template <typename F>
403  void ForEach_impl(const F&& f, std::vector<size_t>& idx, size_t rank) const
404  {
405  if(rank == mDesc.GetNumOfDimension())
406  {
407  f(*this, idx);
408  return;
409  }
410  // else
411  for(size_t i = 0; i < mDesc.GetLengths()[rank]; i++)
412  {
413  idx[rank] = i;
414  ForEach_impl(std::forward<const F>(f), idx, rank + 1);
415  }
416  }
417 
418  template <typename F>
419  void ForEach(const F&& f) const
420  {
421  std::vector<size_t> idx(mDesc.GetNumOfDimension(), 0);
422  ForEach_impl(std::forward<const F>(f), idx, size_t(0));
423  }
424 
425  template <typename G>
426  void GenerateTensorValue(G g, std::size_t num_thread = 1)
427  {
428  switch(mDesc.GetNumOfDimension())
429  {
430  case 1: {
431  auto f = [&](auto i) { (*this)(i) = g(i); };
432  make_ParallelTensorFunctor(f, mDesc.GetLengths()[0])(num_thread);
433  break;
434  }
435  case 2: {
436  auto f = [&](auto i0, auto i1) { (*this)(i0, i1) = g(i0, i1); };
437  make_ParallelTensorFunctor(f, mDesc.GetLengths()[0], mDesc.GetLengths()[1])(num_thread);
438  break;
439  }
440  case 3: {
441  auto f = [&](auto i0, auto i1, auto i2) { (*this)(i0, i1, i2) = g(i0, i1, i2); };
442  make_ParallelTensorFunctor(
443  f, mDesc.GetLengths()[0], mDesc.GetLengths()[1], mDesc.GetLengths()[2])(num_thread);
444  break;
445  }
446  case 4: {
447  auto f = [&](auto i0, auto i1, auto i2, auto i3) {
448  (*this)(i0, i1, i2, i3) = g(i0, i1, i2, i3);
449  };
450  make_ParallelTensorFunctor(f,
451  mDesc.GetLengths()[0],
452  mDesc.GetLengths()[1],
453  mDesc.GetLengths()[2],
454  mDesc.GetLengths()[3])(num_thread);
455  break;
456  }
457  case 5: {
458  auto f = [&](auto i0, auto i1, auto i2, auto i3, auto i4) {
459  (*this)(i0, i1, i2, i3, i4) = g(i0, i1, i2, i3, i4);
460  };
461  make_ParallelTensorFunctor(f,
462  mDesc.GetLengths()[0],
463  mDesc.GetLengths()[1],
464  mDesc.GetLengths()[2],
465  mDesc.GetLengths()[3],
466  mDesc.GetLengths()[4])(num_thread);
467  break;
468  }
469  case 6: {
470  auto f = [&](auto i0, auto i1, auto i2, auto i3, auto i4, auto i5) {
471  (*this)(i0, i1, i2, i3, i4, i5) = g(i0, i1, i2, i3, i4, i5);
472  };
473  make_ParallelTensorFunctor(f,
474  mDesc.GetLengths()[0],
475  mDesc.GetLengths()[1],
476  mDesc.GetLengths()[2],
477  mDesc.GetLengths()[3],
478  mDesc.GetLengths()[4],
479  mDesc.GetLengths()[5])(num_thread);
480  break;
481  }
482  case 12: {
483  auto f = [&](auto i0,
484  auto i1,
485  auto i2,
486  auto i3,
487  auto i4,
488  auto i5,
489  auto i6,
490  auto i7,
491  auto i8,
492  auto i9,
493  auto i10,
494  auto i11) {
495  (*this)(i0, i1, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11) =
496  g(i0, i1, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11);
497  };
498  make_ParallelTensorFunctor(f,
499  mDesc.GetLengths()[0],
500  mDesc.GetLengths()[1],
501  mDesc.GetLengths()[2],
502  mDesc.GetLengths()[3],
503  mDesc.GetLengths()[4],
504  mDesc.GetLengths()[5],
505  mDesc.GetLengths()[6],
506  mDesc.GetLengths()[7],
507  mDesc.GetLengths()[8],
508  mDesc.GetLengths()[9],
509  mDesc.GetLengths()[10],
510  mDesc.GetLengths()[11])(num_thread);
511  break;
512  }
513  default: throw std::runtime_error("unsupported dimension");
514  }
515  }
516 
517  // Generate random values with multiple threads. Guaranteed to give the same sequence with any
518  // number of threads provided.
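 // (Editor's note: determinism comes from each worker copying the generator and calling
 // discard() to skip exactly the values consumed by all preceding 64-byte blocks, so the
 // produced sequence depends only on the seed, not on the thread count.)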
519  template <typename Distribution = std::uniform_real_distribution<float>,
520  typename Mapping = ck::identity,
521  typename Generator = std::minstd_rand>
522  void GenerateTensorDistr(Distribution dis = {0.f, 1.f},
523  Mapping fn = {},
524  const Generator g = Generator(0), // default seed 0
525  std::size_t num_thread = -1)
526  {
527  using ck::math::integer_divide_ceil;
528  using ck::math::min;
529  if(num_thread == -1ULL)
530  num_thread = min(ck::get_available_cpu_cores(), 80U); // max 80 threads
531  // At least 2MB per thread
532  num_thread = min(num_thread, integer_divide_ceil(this->GetElementSpaceSize(), 0x200000));
533  constexpr std::size_t BLOCK_BYTES = 64;
534  constexpr std::size_t BLOCK_SIZE = BLOCK_BYTES / sizeof(T);
535 
536  const std::size_t num_blocks = integer_divide_ceil(this->GetElementSpaceSize(), BLOCK_SIZE);
537  const std::size_t blocks_per_thread = integer_divide_ceil(num_blocks, num_thread);
538 
539  std::vector<std::thread> threads;
540  threads.reserve(num_thread - 1);
541  const auto dst = const_cast<T*>(this->mData.data());
542  const auto element_space_size = this->GetElementSpaceSize();
543  for(int it = num_thread - 1; it >= 0; --it)
544  {
545  std::size_t ib_begin = it * blocks_per_thread;
546  std::size_t ib_end = min(ib_begin + blocks_per_thread, num_blocks);
547 
548  auto job = [=]() {
549  auto g_ = g; // copy
550  auto dis_ = dis; // copy
551  g_.discard(ib_begin * BLOCK_SIZE * ck::packed_size_v<T>);
552  auto t_fn = [&]() {
553  // As user can pass integer distribution in dis, we must ensure that the correct
554  // constructor/converter is called at all times. For f4/f6/f8 types, to ensure
555  // correct results, we convert from float to the target type. In these cases
556  // integer constructors are interpreted as direct initialization of the internal
557  // storage with binary values instead of treating integers as subset of floats.
558  if constexpr(ck::is_same_v<T, ck::f8_t> || ck::is_same_v<T, ck::bf8_t>)
559  return ck::type_convert<T>(static_cast<float>(fn(dis_(g_))));
560  else if constexpr(ck::packed_size_v<T> == 1)
561  return ck::type_convert<T>(fn(dis_(g_)));
562  else if constexpr(ck::is_same_v<T, ck::f4x2_pk_t>)
563  return ck::f4x2_pk_t{ck::type_convert<ck::f4x2_t>(
564  ck::float2_t{ck::type_convert<float>(fn(dis_(g_))),
565  ck::type_convert<float>(fn(dis_(g_)))})};
566  else if constexpr(ck::is_same_v<T, ck::f6x32_pk_t> ||
567  ck::is_same_v<T, ck::bf6x32_pk_t>)
568  {
569  return ck::type_convert<T>(
570  ck::float32_t{ck::type_convert<float>(fn(dis_(g_))),
571  ck::type_convert<float>(fn(dis_(g_))),
572  ck::type_convert<float>(fn(dis_(g_))),
573  ck::type_convert<float>(fn(dis_(g_))),
574  ck::type_convert<float>(fn(dis_(g_))),
575  ck::type_convert<float>(fn(dis_(g_))),
576  ck::type_convert<float>(fn(dis_(g_))),
577  ck::type_convert<float>(fn(dis_(g_))),
578  ck::type_convert<float>(fn(dis_(g_))),
579  ck::type_convert<float>(fn(dis_(g_))),
580  ck::type_convert<float>(fn(dis_(g_))),
581  ck::type_convert<float>(fn(dis_(g_))),
582  ck::type_convert<float>(fn(dis_(g_))),
583  ck::type_convert<float>(fn(dis_(g_))),
584  ck::type_convert<float>(fn(dis_(g_))),
585  ck::type_convert<float>(fn(dis_(g_))),
586  ck::type_convert<float>(fn(dis_(g_))),
587  ck::type_convert<float>(fn(dis_(g_))),
588  ck::type_convert<float>(fn(dis_(g_))),
589  ck::type_convert<float>(fn(dis_(g_))),
590  ck::type_convert<float>(fn(dis_(g_))),
591  ck::type_convert<float>(fn(dis_(g_))),
592  ck::type_convert<float>(fn(dis_(g_))),
593  ck::type_convert<float>(fn(dis_(g_))),
594  ck::type_convert<float>(fn(dis_(g_))),
595  ck::type_convert<float>(fn(dis_(g_))),
596  ck::type_convert<float>(fn(dis_(g_))),
597  ck::type_convert<float>(fn(dis_(g_))),
598  ck::type_convert<float>(fn(dis_(g_))),
599  ck::type_convert<float>(fn(dis_(g_))),
600  ck::type_convert<float>(fn(dis_(g_))),
601  ck::type_convert<float>(fn(dis_(g_)))});
602  }
603  else if constexpr(ck::is_same_v<T, ck::f6x16_pk_t> ||
604  ck::is_same_v<T, ck::bf6x16_pk_t>)
605  {
606  return ck::type_convert<T>(
607  ck::float16_t{ck::type_convert<float>(fn(dis_(g_))),
608  ck::type_convert<float>(fn(dis_(g_))),
609  ck::type_convert<float>(fn(dis_(g_))),
610  ck::type_convert<float>(fn(dis_(g_))),
611  ck::type_convert<float>(fn(dis_(g_))),
612  ck::type_convert<float>(fn(dis_(g_))),
613  ck::type_convert<float>(fn(dis_(g_))),
614  ck::type_convert<float>(fn(dis_(g_))),
615  ck::type_convert<float>(fn(dis_(g_))),
616  ck::type_convert<float>(fn(dis_(g_))),
617  ck::type_convert<float>(fn(dis_(g_))),
618  ck::type_convert<float>(fn(dis_(g_))),
619  ck::type_convert<float>(fn(dis_(g_))),
620  ck::type_convert<float>(fn(dis_(g_))),
621  ck::type_convert<float>(fn(dis_(g_))),
622  ck::type_convert<float>(fn(dis_(g_)))});
623  }
624  else
625  static_assert(false, "Unsupported packed size for T");
626  };
627 
628  std::size_t ib = ib_begin;
629  for(; ib < ib_end - 1; ++ib)
630  ck::static_for<0, BLOCK_SIZE, 1>{}([&](auto iw_) {
631  constexpr size_t iw = iw_.value;
632  dst[ib * BLOCK_SIZE + iw] = t_fn();
633  });
634  for(std::size_t iw = 0; iw < BLOCK_SIZE; ++iw)
635  if(ib * BLOCK_SIZE + iw < element_space_size)
636  dst[ib * BLOCK_SIZE + iw] = t_fn();
637  };
638 
639  if(it > 0)
640  threads.emplace_back(std::move(job));
641  else
642  job(); // last job run in the main thread
643  }
644  for(auto& t : threads)
645  t.join();
646  }
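 // Editor's usage sketch (not part of this header): a deterministic random fill, e.g.
 //
 //   Tensor<float> w({64, 64});
 //   w.GenerateTensorDistr(std::uniform_real_distribution<float>{-1.f, 1.f});
 //
 // produces identical contents for any thread count because of the block-wise discard above.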
647 
648  template <typename... Is>
649  std::size_t GetOffsetFromMultiIndex(Is... is) const
650  {
651  return mDesc.GetOffsetFromMultiIndex(is...) / ck::packed_size_v<ck::remove_cvref_t<T>>;
652  }
653 
654  template <typename... Is>
655  T& operator()(Is... is)
656  {
657  return mData[mDesc.GetOffsetFromMultiIndex(is...) /
658  ck::packed_size_v<ck::remove_cvref_t<T>>];
659  }
660 
661  template <typename... Is>
662  const T& operator()(Is... is) const
663  {
664  return mData[mDesc.GetOffsetFromMultiIndex(is...) /
665  ck::packed_size_v<ck::remove_cvref_t<T>>];
666  }
667 
668  T& operator()(const std::vector<std::size_t>& idx)
669  {
670  return mData[mDesc.GetOffsetFromMultiIndex(idx) / ck::packed_size_v<ck::remove_cvref_t<T>>];
671  }
672 
673  const T& operator()(const std::vector<std::size_t>& idx) const
674  {
675  return mData[mDesc.GetOffsetFromMultiIndex(idx) / ck::packed_size_v<ck::remove_cvref_t<T>>];
676  }
677 
678  typename Data::iterator begin() { return mData.begin(); }
679 
680  typename Data::iterator end() { return mData.end(); }
681 
682  typename Data::pointer data() { return mData.data(); }
683 
684  typename Data::const_iterator begin() const { return mData.begin(); }
685 
686  typename Data::const_iterator end() const { return mData.end(); }
687 
688  typename Data::const_pointer data() const { return mData.data(); }
689 
690  typename Data::size_type size() const { return mData.size(); }
691 
692  template <typename U = T>
693  auto AsSpan() const
694  {
695  constexpr std::size_t FromSize = sizeof(T);
696  constexpr std::size_t ToSize = sizeof(U);
697 
698  using Element = std::add_const_t<std::remove_reference_t<U>>;
699  return ck::span<Element>{reinterpret_cast<Element*>(data()), size() * FromSize / ToSize};
700  }
701 
702  template <typename U = T>
703  auto AsSpan()
704  {
705  constexpr std::size_t FromSize = sizeof(T);
706  constexpr std::size_t ToSize = sizeof(U);
707 
708  using Element = std::remove_reference_t<U>;
709  return ck::span<Element>{reinterpret_cast<Element*>(data()), size() * FromSize / ToSize};
710  }
711 
712  Descriptor mDesc;
713  Data mData;
714 };
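A minimal end-to-end sketch of the host-side API above (an editor's addition, not part of host_tensor.hpp; it assumes the companion utility library that defines HostTensorDescriptor's out-of-line members, such as CalculateStrides(), is linked, and that those strides are packed row-major):

#include "ck/library/utility/host_tensor.hpp"

int main()
{
    Tensor<float> a({2, 3}); // 2x3 tensor; packed row-major strides assumed to be {3, 1}

    // Fill with a deterministic pattern: a(i0, i1) = i0 * 3 + i1.
    a.GenerateTensorValue([](auto i0, auto i1) { return static_cast<float>(i0 * 3 + i1); });

    // Visit every element through its multi-index and double it in place.
    a.ForEach([](auto& self, auto idx) { self(idx) *= 2.f; });

    // Element (1, 2) sits at offset 1*3 + 2 = 5, so its value is now 2 * 5 = 10.
    return a(1, 2) == 10.f ? 0 : 1;
}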