Composable Kernel: include/ck_tile/host/host_tensor.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <algorithm>
#include <cassert>
#include <iostream>
#include <iomanip>
#include <numeric>
#include <utility>
#include <vector>
#include <functional>
#include <fstream>

#include "ck_tile/core.hpp"
#include "ck_tile/host/joinable_thread.hpp"
#include "ck_tile/host/ranges.hpp"

namespace ck_tile {

template <typename Range>
CK_TILE_HOST std::ostream& LogRange(std::ostream& os,
                                    Range&& range,
                                    std::string delim,
                                    int precision = std::cout.precision(),
                                    int width     = 0)
{
    bool first = true;
    for(auto&& v : range)
    {
        if(first)
            first = false;
        else
            os << delim;
        os << std::setw(width) << std::setprecision(precision) << v;
    }
    return os;
}

template <typename T, typename Range>
CK_TILE_HOST std::ostream& LogRangeAsType(std::ostream& os,
                                          Range&& range,
                                          std::string delim,
                                          int precision = std::cout.precision(),
                                          int width     = 0)
{
    bool first = true;
    for(auto&& v : range)
    {
        if(first)
            first = false;
        else
            os << delim;
        os << std::setw(width) << std::setprecision(precision) << static_cast<T>(v);
    }
    return os;
}

template <typename F, typename T, std::size_t... Is>
CK_TILE_HOST auto call_f_unpack_args_impl(F f, T args, std::index_sequence<Is...>)
{
    return f(std::get<Is>(args)...);
}

template <typename F, typename T>
CK_TILE_HOST auto call_f_unpack_args(F f, T args)
{
    constexpr std::size_t N = std::tuple_size<T>{};

    return call_f_unpack_args_impl(f, args, std::make_index_sequence<N>{});
}

template <typename F, typename T, std::size_t... Is>
CK_TILE_HOST auto construct_f_unpack_args_impl(T args, std::index_sequence<Is...>)
{
    return F(std::get<Is>(args)...);
}

template <typename F, typename T>
CK_TILE_HOST auto construct_f_unpack_args(F, T args)
{
    constexpr std::size_t N = std::tuple_size<T>{};

    return construct_f_unpack_args_impl<F>(args, std::make_index_sequence<N>{});
}
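
// For example, call_f_unpack_args(f, std::make_tuple(1, 2)) invokes f(1, 2),
// and construct_f_unpack_args(F{}, std::make_tuple(1, 2)) returns F(1, 2).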

struct HostTensorDescriptor
{
    HostTensorDescriptor() = default;

    void CalculateStrides()
    {
        mStrides.clear();
        mStrides.resize(mLens.size(), 0);
        if(mStrides.empty())
            return;

        mStrides.back() = 1;
        std::partial_sum(mLens.rbegin(),
                         mLens.rend() - 1,
                         mStrides.rbegin() + 1,
                         std::multiplies<std::size_t>());
    }
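
    // Worked example (packed, row-major): for mLens = {2, 3, 4} the reverse
    // partial_sum fills mStrides back-to-front, giving mStrides = {12, 4, 1};
    // element (i0, i1, i2) then lives at offset i0*12 + i1*4 + i2*1.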

    template <typename X, typename = std::enable_if_t<std::is_convertible_v<X, std::size_t>>>
    HostTensorDescriptor(const std::initializer_list<X>& lens) : mLens(lens.begin(), lens.end())
    {
        this->CalculateStrides();
    }

    template <typename Lengths,
              typename = std::enable_if_t<
                  std::is_convertible_v<ck_tile::ranges::range_value_t<Lengths>, std::size_t>>>
    HostTensorDescriptor(const Lengths& lens) : mLens(lens.begin(), lens.end())
    {
        this->CalculateStrides();
    }

    template <typename X,
              typename Y,
              typename = std::enable_if_t<std::is_convertible_v<X, std::size_t> &&
                                          std::is_convertible_v<Y, std::size_t>>>
    HostTensorDescriptor(const std::initializer_list<X>& lens,
                         const std::initializer_list<Y>& strides)
        : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
    {
    }

    template <typename Lengths,
              typename Strides,
              typename = std::enable_if_t<
                  std::is_convertible_v<ck_tile::ranges::range_value_t<Lengths>, std::size_t> &&
                  std::is_convertible_v<ck_tile::ranges::range_value_t<Strides>, std::size_t>>>
    HostTensorDescriptor(const Lengths& lens, const Strides& strides)
        : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
    {
    }

    std::size_t get_num_of_dimension() const { return mLens.size(); }

    std::size_t get_element_size() const
    {
        assert(mLens.size() == mStrides.size());
        return std::accumulate(
            mLens.begin(), mLens.end(), std::size_t{1}, std::multiplies<std::size_t>());
    }

    std::size_t get_element_space_size() const
    {
        std::size_t space = 1;
        for(std::size_t i = 0; i < mLens.size(); ++i)
        {
            if(mLens[i] == 0)
                continue;

            space += (mLens[i] - 1) * mStrides[i];
        }
        return space;
    }
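
    // Worked example: lengths {4, 4} with strides {8, 1} (padded rows) give
    // 1 + 3*8 + 3*1 = 28 elements, just enough to address the last element.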

    std::size_t get_length(std::size_t dim) const { return mLens[dim]; }

    const std::vector<std::size_t>& get_lengths() const { return mLens; }

    std::size_t get_stride(std::size_t dim) const { return mStrides[dim]; }

    const std::vector<std::size_t>& get_strides() const { return mStrides; }

    template <typename... Is>
    std::size_t GetOffsetFromMultiIndex(Is... is) const
    {
        assert(sizeof...(Is) == this->get_num_of_dimension());
        std::initializer_list<std::size_t> iss{static_cast<std::size_t>(is)...};
        return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
    }

    std::size_t GetOffsetFromMultiIndex(const std::vector<std::size_t>& iss) const
    {
        return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
    }
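
    // Both overloads compute the inner product of index and stride: with
    // mStrides = {12, 4, 1}, GetOffsetFromMultiIndex(1, 2, 3) = 1*12 + 2*4 + 3*1 = 23.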

    friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc)
    {
        os << "dim " << desc.get_num_of_dimension() << ", ";

        os << "lengths {";
        LogRange(os, desc.get_lengths(), ", ");
        os << "}, ";

        os << "strides {";
        LogRange(os, desc.get_strides(), ", ");
        os << "}";

        return os;
    }

    private:
    std::vector<std::size_t> mLens;
    std::vector<std::size_t> mStrides;
};
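
// A minimal usage sketch of HostTensorDescriptor (illustrative only): packed
// descriptors derive strides from lengths; the two-argument form takes
// explicit strides, e.g. for a column-major layout.
//
//   HostTensorDescriptor packed({2, 3, 4});          // strides {12, 4, 1}
//   HostTensorDescriptor col_major({2, 3}, {1, 2});  // column-major 2x3
//   std::cout << packed << '\n'; // "dim 3, lengths {2, 3, 4}, strides {12, 4, 1}"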

template <typename New2Old>
CK_TILE_HOST HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old(
    const HostTensorDescriptor& a, const New2Old& new2old)
{
    std::vector<std::size_t> new_lengths(a.get_num_of_dimension());
    std::vector<std::size_t> new_strides(a.get_num_of_dimension());

    for(std::size_t i = 0; i < a.get_num_of_dimension(); i++)
    {
        new_lengths[i] = a.get_lengths()[new2old[i]];
        new_strides[i] = a.get_strides()[new2old[i]];
    }

    return HostTensorDescriptor(new_lengths, new_strides);
}
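
// For example, transposing a packed 2x3 descriptor (lengths {2, 3}, strides
// {3, 1}) with new2old = {1, 0} yields lengths {3, 2} and strides {1, 3}:
// the same memory viewed transposed, without moving any data.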

template <typename F, typename... Xs>
struct ParallelTensorFunctor
{
    F mF;
    static constexpr std::size_t NDIM = sizeof...(Xs);
    std::array<std::size_t, NDIM> mLens;
    std::array<std::size_t, NDIM> mStrides;
    std::size_t mN1d;

    ParallelTensorFunctor(F f, Xs... xs) : mF(f), mLens({static_cast<std::size_t>(xs)...})
    {
        mStrides.back() = 1;
        std::partial_sum(mLens.rbegin(),
                         mLens.rend() - 1,
                         mStrides.rbegin() + 1,
                         std::multiplies<std::size_t>());
        mN1d = mStrides[0] * mLens[0];
    }

    std::array<std::size_t, NDIM> GetNdIndices(std::size_t i) const
    {
        std::array<std::size_t, NDIM> indices;

        for(std::size_t idim = 0; idim < NDIM; ++idim)
        {
            indices[idim] = i / mStrides[idim];
            i -= indices[idim] * mStrides[idim];
        }

        return indices;
    }

    void operator()(std::size_t num_thread = 1) const
    {
        std::size_t work_per_thread = (mN1d + num_thread - 1) / num_thread;

        std::vector<joinable_thread> threads(num_thread);

        for(std::size_t it = 0; it < num_thread; ++it)
        {
            std::size_t iw_begin = it * work_per_thread;
            std::size_t iw_end   = std::min((it + 1) * work_per_thread, mN1d);

            auto f = [this, iw_begin, iw_end] {
                for(std::size_t iw = iw_begin; iw < iw_end; ++iw)
                {
                    call_f_unpack_args(this->mF, this->GetNdIndices(iw));
                }
            };
            threads[it] = joinable_thread(f);
        }
    }
};

template <typename F, typename... Xs>
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
{
    return ParallelTensorFunctor<F, Xs...>(f, xs...);
}
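
// A minimal usage sketch (illustrative only; a, b, c, m, n are hypothetical):
// the functor flattens the {m, n} index space, splits it evenly across
// num_thread joinable_thread workers, and calls f with the unpacked indices.
//
//   auto f = [&](std::size_t i, std::size_t j) { c(i, j) = a(i, j) + b(i, j); };
//   make_ParallelTensorFunctor(f, m, n)(4); // run on 4 threads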

template <typename T>
struct HostTensor
{
    using Descriptor = HostTensorDescriptor;
    using Data       = std::vector<T>;

    template <typename X>
    HostTensor(std::initializer_list<X> lens) : mDesc(lens), mData(get_element_space_size())
    {
    }

    template <typename X, typename Y>
    HostTensor(std::initializer_list<X> lens, std::initializer_list<Y> strides)
        : mDesc(lens, strides), mData(get_element_space_size())
    {
    }

    template <typename Lengths>
    HostTensor(const Lengths& lens) : mDesc(lens), mData(get_element_space_size())
    {
    }

    template <typename Lengths, typename Strides>
    HostTensor(const Lengths& lens, const Strides& strides)
        : mDesc(lens, strides), mData(get_element_space_size())
    {
    }

    HostTensor(const Descriptor& desc) : mDesc(desc), mData(get_element_space_size()) {}

    template <typename OutT>
    HostTensor<OutT> CopyAsType() const
    {
        HostTensor<OutT> ret(mDesc);
        std::transform(mData.cbegin(), mData.cend(), ret.mData.begin(), [](auto value) {
            return ck_tile::type_convert<OutT>(value);
        });
        return ret;
    }

    HostTensor()                  = delete;
    HostTensor(const HostTensor&) = default;
    HostTensor(HostTensor&&)      = default;

    ~HostTensor() = default;

    HostTensor& operator=(const HostTensor&) = default;
    HostTensor& operator=(HostTensor&&)      = default;

    template <typename FromT>
    explicit HostTensor(const HostTensor<FromT>& other) : HostTensor(other.template CopyAsType<T>())
    {
    }

    std::size_t get_length(std::size_t dim) const { return mDesc.get_length(dim); }

    decltype(auto) get_lengths() const { return mDesc.get_lengths(); }

    std::size_t get_stride(std::size_t dim) const { return mDesc.get_stride(dim); }

    decltype(auto) get_strides() const { return mDesc.get_strides(); }

    std::size_t get_num_of_dimension() const { return mDesc.get_num_of_dimension(); }

    std::size_t get_element_size() const { return mDesc.get_element_size(); }

    std::size_t get_element_space_size() const
    {
        constexpr index_t PackedSize = ck_tile::numeric_traits<remove_cvref_t<T>>::PackedSize;
        return mDesc.get_element_space_size() / PackedSize;
    }

    std::size_t get_element_space_size_in_bytes() const
    {
        return sizeof(T) * get_element_space_size();
    }

    // void SetZero() { ck_tile::ranges::fill<T>(mData, 0); }
    void SetZero()
    {
        if constexpr(std::is_same_v<T, e8m0_t>)
            std::fill(mData.begin(), mData.end(), e8m0_t{1.f});
        else
            std::fill(mData.begin(), mData.end(), 0);
    }

    template <typename F>
    void ForEach_impl(F&& f, std::vector<size_t>& idx, size_t rank)
    {
        if(rank == mDesc.get_num_of_dimension())
        {
            f(*this, idx);
            return;
        }
        // else
        for(size_t i = 0; i < mDesc.get_lengths()[rank]; i++)
        {
            idx[rank] = i;
            ForEach_impl(std::forward<F>(f), idx, rank + 1);
        }
    }

    template <typename F>
    void ForEach(F&& f)
    {
        std::vector<size_t> idx(mDesc.get_num_of_dimension(), 0);
        ForEach_impl(std::forward<F>(f), idx, size_t(0));
    }

    template <typename F>
    void ForEach_impl(const F&& f, std::vector<size_t>& idx, size_t rank) const
    {
        if(rank == mDesc.get_num_of_dimension())
        {
            f(*this, idx);
            return;
        }
        // else
        for(size_t i = 0; i < mDesc.get_lengths()[rank]; i++)
        {
            idx[rank] = i;
            ForEach_impl(std::forward<const F>(f), idx, rank + 1);
        }
    }

    template <typename F>
    void ForEach(const F&& f) const
    {
        std::vector<size_t> idx(mDesc.get_num_of_dimension(), 0);
        ForEach_impl(std::forward<const F>(f), idx, size_t(0));
    }
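
    // A minimal usage sketch (illustrative only): ForEach visits every
    // multi-index in lexicographic order, passing the tensor and the index
    // vector to the callback.
    //
    //   HostTensor<float> t({2, 3});
    //   t.ForEach([](auto& self, auto idx) { self(idx) = idx[0] * 3 + idx[1]; });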

    template <typename G>
    void GenerateTensorValue(G g, std::size_t num_thread = 1)
    {
        switch(mDesc.get_num_of_dimension())
        {
        case 1: {
            auto f = [&](auto i) { (*this)(i) = g(i); };
            make_ParallelTensorFunctor(f, mDesc.get_lengths()[0])(num_thread);
            break;
        }
        case 2: {
            auto f = [&](auto i0, auto i1) { (*this)(i0, i1) = g(i0, i1); };
            make_ParallelTensorFunctor(f, mDesc.get_lengths()[0], mDesc.get_lengths()[1])(
                num_thread);
            break;
        }
        case 3: {
            auto f = [&](auto i0, auto i1, auto i2) { (*this)(i0, i1, i2) = g(i0, i1, i2); };
            make_ParallelTensorFunctor(f,
                                       mDesc.get_lengths()[0],
                                       mDesc.get_lengths()[1],
                                       mDesc.get_lengths()[2])(num_thread);
            break;
        }
        case 4: {
            auto f = [&](auto i0, auto i1, auto i2, auto i3) {
                (*this)(i0, i1, i2, i3) = g(i0, i1, i2, i3);
            };
            make_ParallelTensorFunctor(f,
                                       mDesc.get_lengths()[0],
                                       mDesc.get_lengths()[1],
                                       mDesc.get_lengths()[2],
                                       mDesc.get_lengths()[3])(num_thread);
            break;
        }
        case 5: {
            auto f = [&](auto i0, auto i1, auto i2, auto i3, auto i4) {
                (*this)(i0, i1, i2, i3, i4) = g(i0, i1, i2, i3, i4);
            };
            make_ParallelTensorFunctor(f,
                                       mDesc.get_lengths()[0],
                                       mDesc.get_lengths()[1],
                                       mDesc.get_lengths()[2],
                                       mDesc.get_lengths()[3],
                                       mDesc.get_lengths()[4])(num_thread);
            break;
        }
        case 6: {
            auto f = [&](auto i0, auto i1, auto i2, auto i3, auto i4, auto i5) {
                (*this)(i0, i1, i2, i3, i4, i5) = g(i0, i1, i2, i3, i4, i5);
            };
            make_ParallelTensorFunctor(f,
                                       mDesc.get_lengths()[0],
                                       mDesc.get_lengths()[1],
                                       mDesc.get_lengths()[2],
                                       mDesc.get_lengths()[3],
                                       mDesc.get_lengths()[4],
                                       mDesc.get_lengths()[5])(num_thread);
            break;
        }
        default: throw std::runtime_error("unsupported dimension");
        }
    }
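
    // A minimal usage sketch (illustrative only): fill a 2-D tensor from a
    // generator of matching arity, optionally across several threads.
    //
    //   HostTensor<float> t({64, 64});
    //   t.GenerateTensorValue([](auto i, auto j) { return 0.1f * i + j; }, 4);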

    template <typename... Is>
    std::size_t GetOffsetFromMultiIndex(Is... is) const
    {
        constexpr index_t PackedSize = ck_tile::numeric_traits<remove_cvref_t<T>>::PackedSize;
        return mDesc.GetOffsetFromMultiIndex(is...) / PackedSize;
    }

    template <typename... Is>
    T& operator()(Is... is)
    {
        return mData[GetOffsetFromMultiIndex(is...)];
    }

    template <typename... Is>
    const T& operator()(Is... is) const
    {
        return mData[GetOffsetFromMultiIndex(is...)];
    }

    T& operator()(const std::vector<std::size_t>& idx)
    {
        return mData[GetOffsetFromMultiIndex(idx)];
    }

    const T& operator()(const std::vector<std::size_t>& idx) const
    {
        return mData[GetOffsetFromMultiIndex(idx)];
    }

    HostTensor<T> transpose(std::vector<size_t> axes = {}) const
    {
        if(axes.empty())
        {
            axes.resize(this->get_num_of_dimension());
            std::iota(axes.rbegin(), axes.rend(), 0);
        }
        if(axes.size() != mDesc.get_num_of_dimension())
        {
            throw std::runtime_error(
                "HostTensor::transpose(): size of axes must match tensor dimension");
        }
        std::vector<size_t> tlengths, tstrides;
        for(const auto& axis : axes)
        {
            tlengths.push_back(get_lengths()[axis]);
            tstrides.push_back(get_strides()[axis]);
        }
        HostTensor<T> ret(*this);
        ret.mDesc = HostTensorDescriptor(tlengths, tstrides);
        return ret;
    }

    HostTensor<T> transpose(std::vector<size_t> axes = {})
    {
        return const_cast<HostTensor<T> const*>(this)->transpose(axes);
    }
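
    // A minimal usage sketch (illustrative only): with no axes given, transpose
    // reverses all dimensions; only the descriptor is permuted, the copied data
    // stays in its original order.
    //
    //   HostTensor<float> a({2, 3});   // strides {3, 1}
    //   auto at = a.transpose();       // lengths {3, 2}, strides {1, 3}
    //   auto p  = a.transpose({1, 0}); // same permutation, spelled explicitly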

    typename Data::iterator begin() { return mData.begin(); }

    typename Data::iterator end() { return mData.end(); }

    typename Data::pointer data() { return mData.data(); }

    typename Data::const_iterator begin() const { return mData.begin(); }

    typename Data::const_iterator end() const { return mData.end(); }

    typename Data::const_pointer data() const { return mData.data(); }

    typename Data::size_type size() const { return mData.size(); }

    // return a slice of this tensor
    // for simplicity we just copy the data and return a new tensor
    auto slice(std::vector<size_t> s_begin, std::vector<size_t> s_end) const
    {
        assert(s_begin.size() == s_end.size());
        assert(s_begin.size() == get_num_of_dimension());

        std::vector<size_t> s_len(s_begin.size());
        std::transform(
            s_end.begin(), s_end.end(), s_begin.begin(), s_len.begin(), std::minus<size_t>{});
        HostTensor<T> sliced_tensor(s_len);

        sliced_tensor.ForEach([&](auto& self, auto idx) {
            std::vector<size_t> src_idx(idx.size());
            std::transform(
                idx.begin(), idx.end(), s_begin.begin(), src_idx.begin(), std::plus<size_t>{});
            self(idx) = operator()(src_idx);
        });

        return sliced_tensor;
    }
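
    // A minimal usage sketch (illustrative only): slice copies the half-open
    // box [s_begin, s_end) into a new packed tensor.
    //
    //   HostTensor<float> t({4, 4});
    //   auto s = t.slice({1, 1}, {3, 3}); // 2x2 copy of t(1..2, 1..2)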

    template <typename U = T>
    auto AsSpan() const
    {
        constexpr std::size_t FromSize = sizeof(T);
        constexpr std::size_t ToSize   = sizeof(U);

        using Element = std::add_const_t<std::remove_reference_t<U>>;
        return ck_tile::span<Element>{reinterpret_cast<Element*>(data()),
                                      size() * FromSize / ToSize};
    }

    template <typename U = T>
    auto AsSpan()
    {
        constexpr std::size_t FromSize = sizeof(T);
        constexpr std::size_t ToSize   = sizeof(U);

        using Element = std::remove_reference_t<U>;
        return ck_tile::span<Element>{reinterpret_cast<Element*>(data()),
                                      size() * FromSize / ToSize};
    }
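
    // A minimal usage sketch (illustrative only): AsSpan reinterprets the
    // underlying storage as another element type, scaling the length by the
    // size ratio, e.g. viewing fp16 storage as raw bytes.
    //
    //   HostTensor<fp16_t> t({8});
    //   auto bytes = t.AsSpan<const uint8_t>(); // span over 16 const bytes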

    std::ostream& print_first_n(std::ostream& os, std::size_t n = 5) const
    {
        os << mDesc;
        os << "[";
        for(typename Data::size_type idx = 0; idx < std::min(n, mData.size()); ++idx)
        {
            if(0 < idx)
            {
                os << ", ";
            }
            if constexpr(std::is_same_v<T, bf16_t> || std::is_same_v<T, fp16_t> ||
                         std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>)
            {
                os << type_convert<float>(mData[idx]) << " #### ";
            }
            else if constexpr(std::is_same_v<T, ck_tile::pk_int4_t>)
            {
                auto unpacked = pk_int4_t_to_int8x2_t(mData[idx]);
                os << "pk(" << static_cast<int>(unpacked[0]) << ", "
                   << static_cast<int>(unpacked[1]) << ") #### ";
            }
            else if constexpr(std::is_same_v<T, int8_t>)
            {
                os << static_cast<int>(mData[idx]);
            }
            else
            {
                os << mData[idx];
            }
        }
        if(mData.size() > n)
        {
            os << ", ...";
        }
        os << "]";
        return os;
    }

    friend std::ostream& operator<<(std::ostream& os, const HostTensor<T>& t)
    {
        os << t.mDesc;
        os << "[";
        for(typename Data::size_type idx = 0; idx < t.mData.size(); ++idx)
        {
            if(0 < idx)
            {
                os << ", ";
            }
            if constexpr(std::is_same_v<T, bf16_t> || std::is_same_v<T, fp16_t> ||
                         std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>)
            {
                os << type_convert<float>(t.mData[idx]) << " #### ";
            }
            else if constexpr(std::is_same_v<T, ck_tile::pk_int4_t>)
            {
                auto unpacked = pk_int4_t_to_int8x2_t(t.mData[idx]);
                os << "pk(" << static_cast<int>(unpacked[0]) << ", "
                   << static_cast<int>(unpacked[1]) << ") #### ";
            }
            else
            {
                os << t.mData[idx];
            }
        }
        os << "]";
        return os;
    }

    // Read data from a file, as dtype.
    // The file can be dumped from torch (with the target tensor t) as:
    //   numpy.savetxt("f.txt", t.view(-1).numpy())
    //   numpy.savetxt("f.txt", t.cpu().view(-1).numpy())            # from cuda to cpu to save
    //   numpy.savetxt("f.txt", t.cpu().view(-1).numpy(), fmt="%d")  # save as int
    // This outputs f.txt with one value per line.
    // dtype is "float" or "int"; internally the value is cast to the real type.
    void loadtxt(std::string file_name, std::string dtype = "float")
    {
        std::ifstream file(file_name);

        if(file.is_open())
        {
            std::string line;

            index_t cnt = 0;
            while(std::getline(file, line))
            {
                if(cnt >= static_cast<index_t>(mData.size()))
                {
                    throw std::runtime_error(std::string("data read from file:") + file_name +
                                             " is too big");
                }

                if(dtype == "float")
                {
                    mData[cnt] = type_convert<T>(std::stof(line));
                }
                else if(dtype == "int" || dtype == "int32")
                {
                    mData[cnt] = type_convert<T>(std::stoi(line));
                }
                cnt++;
            }
            file.close();
            if(cnt < static_cast<index_t>(mData.size()))
            {
                std::cerr << "Warning! reading from file:" << file_name
                          << ", does not match the size of this tensor" << std::endl;
            }
        }
        else
        {
            // Throw if the file cannot be opened.
            throw std::runtime_error(std::string("unable to open file:") + file_name);
        }
    }

    // Save to a txt file that can be read from torch as:
    //   torch.from_numpy(np.loadtxt('f.txt', dtype=np.int32/np.float32...)).view([...]).contiguous()
    void savetxt(std::string file_name, std::string dtype = "float")
    {
        std::ofstream file(file_name);

        if(file.is_open())
        {
            for(auto& itm : mData)
            {
                if(dtype == "float")
                    file << type_convert<float>(itm) << std::endl;
                else if(dtype == "int")
                    file << type_convert<int>(itm) << std::endl;
                else if(dtype == "int8_t")
                    file << static_cast<int>(type_convert<ck_tile::int8_t>(itm)) << std::endl;
                else
                    // TODO: operator<< is not implemented for all custom data
                    // types; fall back to float to avoid a compile error.
                    file << type_convert<float>(itm) << std::endl;
            }
            file.close();
        }
        else
        {
            // Throw if the file cannot be opened.
            throw std::runtime_error(std::string("unable to open file:") + file_name);
        }
    }
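
    // A minimal round-trip sketch (illustrative only): dump the tensor for
    // inspection in Python/torch, then read it back; the element count in the
    // file must match this tensor's size.
    //
    //   HostTensor<fp16_t> t({2, 8});
    //   t.savetxt("f.txt", "float");
    //   t.loadtxt("f.txt", "float");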

    Descriptor mDesc;
    Data mData;
};

template <bool is_row_major>
auto host_tensor_descriptor(std::size_t row,
                            std::size_t col,
                            std::size_t stride,
                            bool_constant<is_row_major>)
{
    using namespace ck_tile::literals;

    if constexpr(is_row_major)
    {
        return HostTensorDescriptor({row, col}, {stride, 1_uz});
    }
    else
    {
        return HostTensorDescriptor({row, col}, {1_uz, stride});
    }
}
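
// A minimal usage sketch (illustrative only): the bool_constant tag selects
// the layout at compile time.
//
//   auto desc = host_tensor_descriptor(64, 32, 32, bool_constant<true>{}); // row-major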

template <bool is_row_major>
auto get_default_stride(std::size_t row,
                        std::size_t col,
                        std::size_t stride,
                        bool_constant<is_row_major>)
{
    if(stride == 0)
    {
        if constexpr(is_row_major)
        {
            return col;
        }
        else
        {
            return row;
        }
    }
    else
        return stride;
}
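
// A minimal usage sketch (illustrative only): passing stride == 0 requests the
// packed default, i.e. the length of the contiguous dimension.
//
//   auto ld = get_default_stride(64, 32, 0, bool_constant<true>{}); // 32 for row-major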
} // namespace ck_tile