#include "loops_utils.h"
#include "loops.h"

#include <type_traits>
#include <vector>

#include <hwy/highway.h>
#include "simd/simd.hpp"

namespace {
using namespace np::simd;

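/*
 * Computes the absolute value (modulus) of a complex number |a + j*b| as
 * larger * sqrt(1 + (smaller/larger)^2), where larger/smaller are the
 * max/min of |a| and |b|. Scaling by the larger magnitude avoids spurious
 * overflow/underflow in the intermediate square.
 */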
template <typename T> struct OpCabs {
#if NPY_HWY
    template <typename V, typename = std::enable_if_t<kSupportLane<T>>>
    HWY_INLINE HWY_ATTR auto operator()(const V& a, const V& b) const {
        V inf, nan;
        if constexpr (std::is_same_v<T, float>) {
            inf = Set<T>(NPY_INFINITYF);
            nan = Set<T>(NPY_NANF);
        }
        else {
            inf = Set<T>(NPY_INFINITY);
            nan = Set<T>(NPY_NAN);
        }
        auto re = hn::Abs(a), im = hn::Abs(b);
        /*
         * If real or imag = INF, then convert it to inf + j*inf
         * Handles: inf + j*nan, nan + j*inf
         */
        auto re_infmask = hn::IsInf(re), im_infmask = hn::IsInf(im);
        im = hn::IfThenElse(re_infmask, inf, im);
        re = hn::IfThenElse(im_infmask, inf, re);
        /*
         * If real or imag = NAN, then convert it to nan + j*nan
         * Handles: x + j*nan, nan + j*x
         */
        auto re_nanmask = hn::IsNaN(re), im_nanmask = hn::IsNaN(im);
        im = hn::IfThenElse(re_nanmask, nan, im);
        re = hn::IfThenElse(im_nanmask, nan, re);

        auto larger = hn::Max(re, im), smaller = hn::Min(re, im);
        /*
         * Calculate div_mask to prevent 0./0. (larger == 0) and inf/inf
         * (smaller == inf) operations in the masked division below.
         */
        auto zeromask = hn::Eq(larger, Set<T>(static_cast<T>(0)));
        auto infmask = hn::IsInf(smaller);
        auto div_mask = hn::ExclusiveNeither(zeromask, infmask);

        auto ratio = hn::MaskedDiv(div_mask, smaller, larger);
        auto hypot = hn::Sqrt(hn::MulAdd(ratio, ratio, Set<T>(static_cast<T>(1))));
        return hn::Mul(hypot, larger);
    }
#endif

    // Scalar fallback; npy_hypot handles the inf/nan special cases itself.
    NPY_INLINE T operator()(T a, T b) const {
        if constexpr (std::is_same_v<T, float>) {
            return npy_hypotf(a, b);
        } else {
            return npy_hypot(a, b);
        }
    }
};

#if NPY_HWY
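/*
 * Gather up to `n` lanes from `src` at a stride of `ssrc` elements,
 * padding any remaining lanes with `val`, and return them as a vector.
 */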
template <typename T>
HWY_INLINE HWY_ATTR auto LoadWithStride(const T* src, npy_intp ssrc, size_t n = Lanes<T>(), T val = 0) {
    HWY_LANES_CONSTEXPR size_t lanes = Lanes<T>();
    std::vector<T> temp(lanes, val);
    for (size_t ii = 0; ii < lanes && ii < n; ++ii) {
        temp[ii] = src[ii * ssrc];
    }
    return LoadU(temp.data());
}

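/*
 * Scatter the first `n` lanes of `vec` to `dst` at a stride of `sdst`
 * elements; lanes past `n` are not written.
 */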
template <typename T>
HWY_INLINE HWY_ATTR void StoreWithStride(Vec<T> vec, T* dst, npy_intp sdst, size_t n = Lanes<T>()) {
    HWY_LANES_CONSTEXPR size_t lanes = Lanes<T>();
    std::vector<T> temp(lanes);
    StoreU(vec, temp.data());
    for (size_t ii = 0; ii < lanes && ii < n; ++ii) {
        dst[ii * sdst] = temp[ii];
    }
}
#endif // NPY_HWY

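/*
 * Inner loop for complex absolute value: each input element is a
 * (real, imag) pair of T, each output element is a single T.
 */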
template <typename T>
HWY_INLINE HWY_ATTR void
unary_complex(char **args, npy_intp const *dimensions, npy_intp const *steps)
{
    const OpCabs<T> op_func;
    const char *src = args[0]; char *dst = args[1];
    const npy_intp src_step = steps[0];
    const npy_intp dst_step = steps[1];
    npy_intp len = dimensions[0];

#if NPY_HWY
    if constexpr (kSupportLane<T>) {
        if (!is_mem_overlap(src, src_step, dst, dst_step, len) && alignof(T) == sizeof(T) &&
            src_step % sizeof(T) == 0 && dst_step % sizeof(T) == 0) {
            const int lsize = sizeof(T);
            const npy_intp ssrc = src_step / lsize;
            const npy_intp sdst = dst_step / lsize;

            const int vstep = Lanes<T>();
            const int wstep = vstep * 2;

            const T* src_T = reinterpret_cast<const T*>(src);
            T* dst_T = reinterpret_cast<T*>(dst);

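            // Fast path: contiguous complex input (ssrc == 2) and contiguous
            // real output (sdst == 1); de-interleave the real/imag parts
            // with a single LoadInterleaved2 per iteration.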
            if (ssrc == 2 && sdst == 1) {
                for (; len >= vstep; len -= vstep, src_T += wstep, dst_T += vstep) {
                    Vec<T> re, im;
                    hn::LoadInterleaved2(_Tag<T>(), src_T, re, im);
                    auto r = op_func(re, im);
                    StoreU(r, dst_T);
                }
            }
            else {
                for (; len >= vstep; len -= vstep, src_T += ssrc*vstep, dst_T += sdst*vstep) {
                    auto re = LoadWithStride(src_T, ssrc);
                    auto im = LoadWithStride(src_T + 1, ssrc);
                    auto r = op_func(re, im);
                    StoreWithStride(r, dst_T, sdst);
                }
            }
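            // Tail: fewer than vstep elements remain; load them with zero
            // padding and store only the first `len` lanes.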
            if (len > 0) {
                auto re = LoadWithStride(src_T, ssrc, len);
                auto im = LoadWithStride(src_T + 1, ssrc, len);
                auto r = op_func(re, im);
                StoreWithStride(r, dst_T, sdst, len);
            }
            // clear the float status flags
            npy_clear_floatstatus_barrier((char*)&len);
            return;
        }
    }
#endif

    // fallback to scalar implementation
    for (; len > 0; --len, src += src_step, dst += dst_step) {
        const T src0 = *reinterpret_cast<const T*>(src);
        const T src1 = *(reinterpret_cast<const T*>(src) + 1);
        *reinterpret_cast<T*>(dst) = op_func(src0, src1);
    }
}

} // anonymous namespace

/*******************************************************************************
 ** Defining ufunc inner functions
 *******************************************************************************/
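// CFLOAT/CDOUBLE absolute ufunc entry points; NPY_CPU_DISPATCH_CURFX expands
// each name with the current target suffix so every build target gets its
// own symbol.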
NPY_NO_EXPORT void NPY_CPU_DISPATCH_CURFX(CFLOAT_absolute)
(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func))
{
    unary_complex<npy_float>(args, dimensions, steps);
}
NPY_NO_EXPORT void NPY_CPU_DISPATCH_CURFX(CDOUBLE_absolute)
(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func))
{
    unary_complex<npy_double>(args, dimensions, steps);
}