4 #ifndef CK_THREADWISE_GEMM_DLOPS_V3_HPP 
    5 #define CK_THREADWISE_GEMM_DLOPS_V3_HPP 
   18 template <
typename FloatA,
 
   21           typename AThreadDesc_E1_K_E2,
 
   22           typename BThreadDesc_E1_N_Ho_Wo_E2,
 
   23           typename CThreadDesc_K_N_Ho_Wo,
 
   24           typename enable_if<AThreadDesc_E1_K_E2::IsKnownAtCompileTime() &&
 
   25                                  BThreadDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() &&
 
   26                                  CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(),
 
   31     template <
typename ABuffer,
 
   37     __device__ 
static void Run(
const ABuffer& a_buf,
 
   45         static_assert(AThreadDesc_E1_K_E2::IsKnownAtCompileTime() &&
 
   46                           BThreadDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() &&
 
   47                           CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(),
 
   48                       "wrong! Desc should be known at compile-time");
 
   53                       "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
 
   59             "wrong! inconsistent type");
 
   66         constexpr 
auto E1 = AThreadDesc_E1_K_E2{}.GetLength(I0);
 
   67         constexpr 
auto K  = AThreadDesc_E1_K_E2{}.GetLength(I1);
 
   68         constexpr 
auto E2 = AThreadDesc_E1_K_E2{}.GetLength(I2);
 
   70         constexpr 
auto Ho = BThreadDesc_E1_N_Ho_Wo_E2{}.GetLength(I2);
 
   71         constexpr 
auto Wo = BThreadDesc_E1_N_Ho_Wo_E2{}.GetLength(I3);
 
   77         if constexpr((Ho % 2 == 0) && (Wo % 2 == 0))
 
   79             constexpr 
auto SubHW = 2;
 
   86                                 constexpr 
index_t a_offset = AThreadDesc_E1_K_E2{}.CalculateOffset(
 
   90                                     BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
 
   94                                     BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
 
   95                                         b_origin_idx + 
make_tuple(e1, 0, h, w + 1, e2));
 
   98                                     BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
 
   99                                         b_origin_idx + 
make_tuple(e1, 0, h + 1, w, e2));
 
  102                                     BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
 
  103                                         b_origin_idx + 
make_tuple(e1, 0, h + 1, w + 1, e2));
 
  106                                     CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx +
 
  110                                     CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
 
  114                                     CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
 
  118                                     CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
 
  119                                         c_origin_idx + 
make_tuple(k, 0, h + 1, w + 1));
 
  144                                 constexpr 
index_t a_offset = AThreadDesc_E1_K_E2{}.CalculateOffset(
 
  148                                     BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
 
  152                                     CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx +
 
__device__ void amd_assembly_outer_product_1x4(float a, float b0, float b1, float b2, float b3, float &c0, float &c1, float &c2, float &c3)
Definition: amd_inline_asm.hpp:106
 
__host__ constexpr __device__ auto to_multi_index(const T &x)
Definition: array_multi_index.hpp:28
 
std::enable_if< B, T > enable_if
Definition: enable_if.hpp:24
 
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
 
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
 
int32_t index_t
Definition: ck.hpp:299
 
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1697
 
ck::ThreadwiseGemmDlops_km_kn_mn_v3
Definition: threadwise_gemm_dlops_v3.hpp:29
 
static __device__ void Run(const ABuffer &a_buf, AOriginIdx, const BBuffer &b_buf, BOriginIdx, CBuffer &c_buf, COriginIdx)
Definition: threadwise_gemm_dlops_v3.hpp:37
 
ck::integral_constant
Definition: integral_constant.hpp:20
 
ck::is_known_at_compile_time
Definition: is_known_at_compile_time.hpp:14
 
ck::static_for
Definition: functional2.hpp:33