#include "loops_utils.h"
#include "loops.h"

#include <hwy/highway.h>
#include "simd/simd.hpp"

namespace {
using namespace np::simd;
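
/*
 * Absolute value (modulus) of a complex number given as a (real, imag) pair,
 * i.e. sqrt(re^2 + im^2). The vector overload handles inf/nan inputs
 * explicitly and scales the computation to avoid spurious overflow/underflow;
 * the scalar overload defers to npy_hypot/npy_hypotf.
 */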
template <typename T>
struct OpCabs {
#if NPY_HWY
    template <typename V, typename = std::enable_if_t<kSupportLane<T>>>
    HWY_INLINE HWY_ATTR auto operator()(const V& a, const V& b) const {
        V inf, nan;
        if constexpr (std::is_same_v<T, float>) {
            inf = Set<T>(NPY_INFINITYF);
            nan = Set<T>(NPY_NANF);
        }
        else {
            inf = Set<T>(NPY_INFINITY);
            nan = Set<T>(NPY_NAN);
        }
        auto re = hn::Abs(a), im = hn::Abs(b);
        /*
         * If real or imag = INF, then convert it to inf + j*inf
         * Handles: inf + j*nan, nan + j*inf
         */
        auto re_infmask = hn::IsInf(re), im_infmask = hn::IsInf(im);
        im = hn::IfThenElse(re_infmask, inf, im);
        re = hn::IfThenElse(im_infmask, inf, re);
        /*
         * If real or imag = NAN, then convert it to nan + j*nan
         * Handles: x + j*nan, nan + j*x
         */
        auto re_nanmask = hn::IsNaN(re), im_nanmask = hn::IsNaN(im);
        im = hn::IfThenElse(re_nanmask, nan, im);
        re = hn::IfThenElse(im_nanmask, nan, re);

        auto larger = hn::Max(re, im), smaller = hn::Min(im, re);
        /*
         * Calculate div_mask to prevent 0./0. and inf/inf operations in div
         */
        auto zeromask = hn::Eq(larger, Set<T>(static_cast<T>(0)));
        auto infmask = hn::IsInf(smaller);
        auto div_mask = hn::ExclusiveNeither(zeromask, infmask);
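
        /*
         * Evaluate sqrt(re^2 + im^2) as larger * sqrt(1 + (smaller/larger)^2)
         * so that neither operand is squared directly, which could otherwise
         * overflow or underflow near the extremes of the type's range.
         */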
        auto ratio = hn::MaskedDiv(div_mask, smaller, larger);
        auto hypot = hn::Sqrt(hn::MulAdd(ratio, ratio, Set<T>(static_cast<T>(1))));
        return hn::Mul(hypot, larger);
    }
#endif

    NPY_INLINE T operator()(T a, T b) const {
        if constexpr (std::is_same_v<T, float>) {
            return npy_hypotf(a, b);
        } else {
            return npy_hypot(a, b);
        }
    }
};

#if NPY_HWY
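/*
 * Gather one vector's worth of elements from `src`, reading every `ssrc`-th
 * element. At most `n` lanes are read; the remaining lanes are filled with
 * `val`, which makes this safe for loop tails shorter than a full vector.
 */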
template <typename T>
HWY_INLINE HWY_ATTR auto LoadWithStride(const T* src, npy_intp ssrc, size_t n = Lanes<T>(), T val = 0) {
    HWY_LANES_CONSTEXPR size_t lanes = Lanes<T>();
    std::vector<T> temp(lanes, val);
    for (size_t ii = 0; ii < lanes && ii < n; ++ii) {
        temp[ii] = src[ii * ssrc];
    }
    return LoadU(temp.data());
}
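
/*
 * Scatter the first `n` lanes of `vec` to `dst`, writing every `sdst`-th
 * element. The vector is staged through a temporary buffer since the target
 * locations are not contiguous.
 */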
template <typename T>
HWY_INLINE HWY_ATTR void StoreWithStride(Vec<T> vec, T* dst, npy_intp sdst, size_t n = Lanes<T>()) {
    HWY_LANES_CONSTEXPR size_t lanes = Lanes<T>();
    std::vector<T> temp(lanes);
    StoreU(vec, temp.data());
    for (size_t ii = 0; ii < lanes && ii < n; ++ii) {
        dst[ii * sdst] = temp[ii];
    }
}
#endif  // NPY_HWY
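
/*
 * Inner loop for complex absolute value. Takes the usual ufunc arguments:
 * args[0]/args[1] are the source and destination buffers, steps[] the byte
 * strides, and dimensions[0] the element count. The SIMD path is used when
 * the buffers do not overlap and both strides are whole multiples of the
 * element size; otherwise the scalar loop at the bottom runs.
 */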
template <typename T>
HWY_INLINE HWY_ATTR void
unary_complex(char **args, npy_intp const *dimensions, npy_intp const *steps)
{
    const OpCabs<T> op_func;
    const char *src = args[0]; char *dst = args[1];
    const npy_intp src_step = steps[0];
    const npy_intp dst_step = steps[1];
    npy_intp len = dimensions[0];

#if NPY_HWY
    if constexpr (kSupportLane<T>) {
        if (!is_mem_overlap(src, src_step, dst, dst_step, len) && alignof(T) == sizeof(T) &&
            src_step % sizeof(T) == 0 && dst_step % sizeof(T) == 0) {
            const int lsize = sizeof(T);
            const npy_intp ssrc = src_step / lsize;
            const npy_intp sdst = dst_step / lsize;

            const int vstep = Lanes<T>();
            const int wstep = vstep * 2;

            const T* src_T = reinterpret_cast<const T*>(src);
            T* dst_T = reinterpret_cast<T*>(dst);
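
            /*
             * Fast path: contiguous complex input (element stride of 2) and
             * contiguous real output. LoadInterleaved2 de-interleaves the
             * real and imaginary parts directly from memory; other layouts
             * go through the strided load/store helpers.
             */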
            if (ssrc == 2 && sdst == 1) {
                for (; len >= vstep; len -= vstep, src_T += wstep, dst_T += vstep) {
                    Vec<T> re, im;
                    hn::LoadInterleaved2(_Tag<T>(), src_T, re, im);
                    auto r = op_func(re, im);
                    StoreU(r, dst_T);
                }
            }
            else {
                for (; len >= vstep; len -= vstep, src_T += ssrc*vstep, dst_T += sdst*vstep) {
                    auto re = LoadWithStride(src_T, ssrc);
                    auto im = LoadWithStride(src_T + 1, ssrc);
                    auto r = op_func(re, im);
                    StoreWithStride(r, dst_T, sdst);
                }
            }
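            /*
             * Handle the remaining (< vstep) elements with partial,
             * lane-limited loads and stores.
             */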
            if (len > 0) {
                auto re = LoadWithStride(src_T, ssrc, len);
                auto im = LoadWithStride(src_T + 1, ssrc, len);
                auto r = op_func(re, im);
                StoreWithStride(r, dst_T, sdst, len);
            }
            // clear the float status flags
            npy_clear_floatstatus_barrier((char*)&len);
            return;
        }
    }
#endif

    // fallback to scalar implementation
    for (; len > 0; --len, src += src_step, dst += dst_step) {
        const T src0 = *reinterpret_cast<const T*>(src);
        const T src1 = *(reinterpret_cast<const T*>(src) + 1);
        *reinterpret_cast<T*>(dst) = op_func(src0, src1);
    }
}

}  // anonymous namespace

/*******************************************************************************
 ** Defining ufunc inner functions
 *******************************************************************************/
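/*
 * CPU-dispatched inner loops for the absolute value of single- and
 * double-precision complex arrays; both delegate to unary_complex.
 */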
NPY_NO_EXPORT void NPY_CPU_DISPATCH_CURFX(CFLOAT_absolute)
(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func))
{
    unary_complex<npy_float>(args, dimensions, steps);
}
NPY_NO_EXPORT void NPY_CPU_DISPATCH_CURFX(CDOUBLE_absolute)
(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func))
{
    unary_complex<npy_double>(args, dimensions, steps);
}