-
Notifications
You must be signed in to change notification settings - Fork 8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
BUG: isclose: fix multidevice for equal_nan=True
#177
BUG: isclose: fix multidevice for equal_nan=True
#177
Conversation
@lucascolley |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we already have
array-api-extra/tests/conftest.py
Lines 148 to 161 in acbdf72
@pytest.fixture | |
def device( | |
library: Backend, xp: ModuleType | |
) -> Device: # numpydoc ignore=PR01,RT01,RT03 | |
""" | |
Return a valid device for the backend. | |
Where possible, return a device that is not the default one. | |
""" | |
if library == Backend.ARRAY_API_STRICT: | |
d = xp.Device("device1") | |
assert get_device(xp.empty(0)) != d | |
return d | |
return get_device(xp.empty(0)) |
it looks like we just forgot to add a test_device
for isclose
, see e.g.
array-api-extra/tests/test_funcs.py
Lines 435 to 437 in acbdf72
def test_device(self, xp: ModuleType, device: Device): | |
x = xp.asarray([1, 2, 3], device=device) | |
assert get_device(cov(x)) == device |
757996d
to
b1e6639
Compare
Wow, totally missed that! Thanks, that's a lot cleaner. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks Thomas!
equal_nan=True
looks like there is a pre-commit failure |
Sorry, sent in a fix! |
Unfortunately the I'd be happy for you to modify the test to be just a device check with |
fc26a74
to
bdc41bd
Compare
I added a conversion to the CPU device in array-api-strict which is what scikit-learn does as well I think. |
src/array_api_extra/_lib/_testing.py
Outdated
if is_array_api_strict_namespace(xp): | ||
# __array__ doesn't work on array-api-strict device arrays | ||
# We need to convert to the CPU device first | ||
actual = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE"))) | ||
desired = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE"))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah, yes, this is a good call, thanks!
bdc41bd
to
1236e3a
Compare
Thanks for the review and for fixing typing! |
* BUG: Fix isclose multidevice * test the right way * fix pre-commit * convert to CPU in xp_assert_equal * fixes * fix tests --------- Co-authored-by: Lucas Colley <[email protected]>
xp.asarray(True)
was missing a device argument, which causes array-api-strict with a device argument to fail in where.Since as of array api 2024.12 we can pass in scalars directly to where, I've replaced asarray(True) with just True.