-
-
Notifications
You must be signed in to change notification settings - Fork 10.9k
ENH: Improve performance of np.linalg._linalg._commonType #28686
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@eendebakpt - I'm wondering if this isn't just adding complexity for something where it might pay to just look better at what is being done. E.g., I'm rather confused why one cannot use promote_types
and cut this short a bit. A quick test on det
shows that the following passes all tests,
diff --git a/numpy/linalg/_linalg.py b/numpy/linalg/_linalg.py
index e181e1a5d8..67e36debb8 100644
--- a/numpy/linalg/_linalg.py
+++ b/numpy/linalg/_linalg.py
@@ -31,7 +31,7 @@
reciprocal, overrides, diagonal as _core_diagonal, trace as _core_trace,
cross as _core_cross, outer as _core_outer, tensordot as _core_tensordot,
matmul as _core_matmul, matrix_transpose as _core_matrix_transpose,
- transpose as _core_transpose, vecdot as _core_vecdot,
+ promote_types, transpose as _core_transpose, vecdot as _core_vecdot,
)
from numpy._globals import _NoValue
from numpy.lib._twodim_base_impl import triu, eye
@@ -2367,10 +2367,9 @@ def det(a):
"""
a = asarray(a)
_assert_stacked_square(a)
- t, result_t = _commonType(a)
- signature = 'D->D' if isComplexType(t) else 'd->d'
- r = _umath_linalg.det(a, signature=signature)
- r = r.astype(result_t, copy=False)
+ r = _umath_linalg.det(a, dtype=promote_types(a.dtype, double))
+ if r.dtype != a.dtype:
+ r = r.astype(promote_types(a.dtype, single), copy=False)
return r
With that, the test from your script on a 2x2 matrix
x22 = np.arange(4.).reshape( (2,2) ) + np.eye(2)
%timeit np.linalg.det(x22)
7.09 -> 4.48 us
The only possible downside is that this does not raise an error on f2
input -- but why should it anyway?
p.s. For det
at least, there is no need for _assert_stacked_square
either - the gufunc
will already check that there are at least 2 dimensions and that the last two are equal.
How about something like this: _DTYPE_RANK = dict(zip(map(dtype, "fdFD"), range(4)))
max_rank = -1
for dtype in dtypes:
if dtype.num < 11: # <: integer | bool
continue
if (rank := _DTYPE_RANK.get(dtype)) is None:
raise TypeError(...)
if rank == 3: # no need to go on
return cdouble, cdouble
if rank > max_rank:
max_rank = rank
if max_rank > 1:
return cdouble, (csingle, cdouble)[max_rank - 2]
else:
return double, (single, double)[max_rank] I didn't test it, but I expect this to be quite a bit faster (and it might even be correct, too). Anyway, even if not correct, I'm sure you get the idea. |
Ideally, we don't rely on implementation details like type numbers... Also, no real reason to exclude user dtypes that know how to convert to double, etc. Using |
The The main performance gain is from the |
@mhvk To avoid the copy on the scalar we can also check in main...eendebakpt:numpy:astype The advantage over the At this moment I have no strong preference for either options, so any arguments in either direction are welcome. |
@eendebakpt - I think your patch makes sense in principle, but is perhaps a bit orthogonal to the goals here? At least, I wrote the |
@mhvk Your patch looks good, I might end up refactoring this PR in that way. It would be nice to also refactor the other methods calling |
True, but |
We improve performance of
_commonType
which benefitsnp.linalg.det
and several other methods for small size arrays.isComplexType
so that the value can be reused by the calling methods.dtype
instead of the arrays we can use a cache.Benchmark:
Test script