-
-
Notifications
You must be signed in to change notification settings - Fork 10.9k
ENH: Check for floating point exceptions in dot #28442
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Currently dot does not check for floating point exceptions, while similar functions (such as multiply and matmul) do. Add these checks for consistency. Note that certain types of floating point exceptions are still not caught: - Multiplications of 0 with nan or inf are ignored (see numpy#27902) - Integer overflows are ignored. These aren't true floating point exceptions, but the behavior is inconsistent with scalar multiply, i.e. ``` >>> np.int16(32000) * np.int16(3) <python-input-7>:1: RuntimeWarning: overflow encountered in scalar multiply np.int16(32000) * np.int16(3) np.int16(30464) >>> np.dot(np.int16(32000), np.int16(3)) np.int16(30464) ``` Fixes numpy#14925
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for looking into it, seems like a good small improvement. Could note it as a very brief "change" in the relase note, although I don't think it mattre much.
} | ||
|
||
return 0; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please use PyUFunc_GiveFloatingpointErrors
, you can find the rest of the pattern elsewhere (e.g. in scalar math).
@@ -1008,11 +1022,17 @@ PyArray_MatrixProduct2(PyObject *op1, PyObject *op2, PyArrayObject* out) | |||
return NULL; | |||
} | |||
|
|||
npy_clear_floatstatus(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would prefer pairing this with the GIL release (i.e. very targeted around the place of interest).
That probably means moving it into cblas_matrixproduct
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I moved this to cblas_matrixproduct
but I'm not sure if it's quite what you want. Do you want this immediately before (or after?) the GIL release, or just closer to where the actual work is being done?
return cblas_matrixproduct(typenum, ap1, ap2, out); | ||
PyObject *res = cblas_matrixproduct(typenum, ap1, ap2, out); | ||
if (check_fperr("dot") < 0) { | ||
return NULL; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just to note, this is also incorrect of course: You would have to throw away res
here, but you are leaking it.
I'm not sure what's causing the CI failure, but it doesn't seem to be related to my PR. Or am I missing something? |
I restarted the s390x build. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, LGTM! A small nit that we should extend the test a bit, I think.
Not super important here, but you could add a change
or improvement
release note if you like (see docs/release/upcoming_changes/README.rst
).
(Unrelated... I am half wondering if we couldn't fold most of this to use the matmul
ufunc as a whole or just its loops. That would also have the information that no floating point errors can happen. But I am not worried about it much, matmuls on integers isn't all that common.)
numpy/_core/tests/test_multiarray.py
Outdated
@@ -3345,6 +3345,29 @@ def test_dot(self): | |||
a.dot(b=b, out=c) | |||
assert_equal(c, np.dot(a, b)) | |||
|
|||
@pytest.mark.skipif(IS_WASM, reason="no wasm fp exception support") | |||
def test_dot_errstate(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Small nit: Could you parametrize this to include half
and/or longdouble
(just because they currently take the other code paths and can set the warnings).
The point is that double/float should effectively always take the first branch, that way we should cover both error paths.
I can look into this, if it would be of interest. I picked this bug as a good way to get started with numpy development, and I was thinking of applying similar fixes to some of the other functions in the multiarray module (assuming they have similar bugs, which I haven't looked into yet). Do you have a preference on what would be more useful? |
Thanks for the follow-up, let's give this a shot! |
* Fixes scipygh-22720. * This patch adds a few floating point exception handling shims to compensate for the upstream change for `np.dot` at numpy/numpy#28442, merged five days ago. This allows the Scipy testsuite to pass again locally on x86_64 Linux against NumPy `main`.
Note: required a small number of shims downstream, but not too bad I don't think. |
* BUG: Check for floating point exceptions in dot Currently dot does not check for floating point exceptions, while similar functions (such as multiply and matmul) do. Add these checks for consistency. Note that certain types of floating point exceptions are still not caught: - Multiplications of 0 with nan or inf are ignored (see numpy#27902) - Integer overflows are ignored. These aren't true floating point exceptions, but the behavior is inconsistent with scalar multiply, i.e. ``` >>> np.int16(32000) * np.int16(3) <python-input-7>:1: RuntimeWarning: overflow encountered in scalar multiply np.int16(32000) * np.int16(3) np.int16(30464) >>> np.dot(np.int16(32000), np.int16(3)) np.int16(30464) ``` Fixes numpy#14925 * pr comments * parametrize test, add release note * use generic typenames
* Fixes scipygh-22720. * This patch adds a few floating point exception handling shims to compensate for the upstream change for `np.dot` at numpy/numpy#28442, merged five days ago. This allows the Scipy testsuite to pass again locally on x86_64 Linux against NumPy `main`.
Currently dot does not check for floating point exceptions, while similar functions (such as multiply and matmul) do. Add these checks for consistency.
Note that certain types of floating point exceptions are still not caught:
np.dot
result changes based on shape when the input containsnan
#27902)Fixes #14925