Add "multi device" support #59
Having more than one device is useful during testing to allow you to find bugs related to how arrays on different devices are handled.
+1! Another question is what the default should be (technically |
I think the CPU device should be the default. That way code that exists today should keep working and the only people who notice any changes are those who use the pony device. |
This looks good so far. We need to make sure the semantics specified at https://data-apis.org/array-api/latest/design_topics/device_support.html#semantics are followed, namely, disallowing combining arrays from different devices, and making sure that if a function creates a new array based on an existing array that it uses the same device. For tests, ideally this would be tested in array-api-tests, but right now device support is not tested at all there. If you just want to add some basic tests here for now, that is fine. Finally, there is the |
I've rebooted my work. The long pause is because I went on holiday :D I think we have to limit ourselves to a fixed number of devices, otherwise we can't fulfill the requirement that the info extension can provide a list of devices. So now you can use the Slowly making progress towards the creation functions and "array combination" functions respecting the device |
Do you use |
It looks like it would be quite tricky to add device testing to |
There's no autoformatting on this repo.
I think it would have to use the It's also possible to do some basic testing using the default device, like that The annoying thing for the test suite is making sure every function everywhere is passing |
I think what we need here are just some big parameterized tests combining basic example arrays with different devices across all the APIs. For instance, there's an existing test that checks type promotion and the "no mixing devices" test could look very similar to that. |
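A minimal sketch of what such a "no mixing devices" check could look like, written without pytest so that it is self-contained. `ToyArray`, `TWO_ARG_FUNCS`, the device names, and the choice of `ValueError` are all hypothetical stand-ins, not array-api-strict's real classes or behaviour; a real test would presumably use `pytest.mark.parametrize` over the library's own two-argument functions.

```python
import itertools
import operator

DEVICES = ("CPU_DEVICE", "device1", "device2")  # hypothetical device names


class ToyArray:
    """Hypothetical stand-in for an array that remembers its device."""

    def __init__(self, values, device="CPU_DEVICE"):
        self.values = list(values)
        self.device = device


def _binary(op):
    """Build a two-argument function that enforces the device semantics."""

    def func(x1, x2, /):
        if x1.device != x2.device:
            raise ValueError(
                f"Arrays are on different devices: {x1.device!r} and {x2.device!r}"
            )
        # The result lives on the same device as the inputs.
        result = [op(a, b) for a, b in zip(x1.values, x2.values)]
        return ToyArray(result, device=x1.device)

    return func


TWO_ARG_FUNCS = {"add": _binary(operator.add), "multiply": _binary(operator.mul)}


def check_no_mixing_devices():
    """Run every function over every ordered device pair; return the case count."""
    checked = 0
    device_pairs = itertools.product(DEVICES, repeat=2)
    for (name, func), (d1, d2) in itertools.product(
        TWO_ARG_FUNCS.items(), device_pairs
    ):
        x1 = ToyArray([1, 2], device=d1)
        x2 = ToyArray([3, 4], device=d2)
        if d1 == d2:
            # Same device: the call succeeds and the device is preserved.
            assert func(x1, x2).device == d1, name
        else:
            # Different devices: the call must raise.
            try:
                func(x1, x2)
            except ValueError:
                pass
            else:
                raise AssertionError(f"{name} mixed {d1} and {d2} without error")
        checked += 1
    return checked
```

Looping over the full product of functions and device pairs mirrors the shape of a type-promotion test: one small example per combination, with the expected outcome decided only by whether the two devices match.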
-def logaddexp(x1: Array, x2: Array) -> Array:
+def logaddexp(x1: Array, x2: Array, /) -> Array:
I think this was missing/typo.
@@ -19,8 +19,16 @@
import pytest

import array_api_strict


def nargs(func):
I modified this so that it works with decorated/wrapped functions as well. `len(getfullargspec(f).args)` returns zero for functions that are decorated with the array API version decorator. From what I can tell from the Python docs, this is more or less on purpose, to preserve existing behaviour. `signature()` does the right thing for wrapped functions, but it needs slightly more explicit work to count the arguments.
I think the intention/rule is that `nargs()` counts the number of positional-only arguments, which is basically the "number of arrays you need to pass to an elementwise function". I went with the very explicit way of counting the args partly to make it easier for people from the future to understand what `nargs` is meant to do (even if it contains a bug and doesn't actually do what it is meant to do).
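To illustrate the difference described above, here is a small sketch using only the standard library. `requires_version` is a hypothetical stand-in for the array API version decorator, and this `nargs` is one possible way to count positional-only parameters, not necessarily the code in this PR.

```python
import functools
import inspect


def requires_version(func):
    """Hypothetical decorator standing in for the array API version decorator."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return wrapper


@requires_version
def logaddexp(x1, x2, /):
    """Toy function with the positional-only signature the standard uses."""
    return NotImplemented


def nargs(func):
    """Count the positional-only parameters of func, following wrappers."""
    count = 0
    # inspect.signature() follows the __wrapped__ attribute set by
    # functools.wraps, so it sees the signature of the original function.
    for param in inspect.signature(func).parameters.values():
        if param.kind == inspect.Parameter.POSITIONAL_ONLY:
            count += 1
    return count
```

With these definitions, `len(inspect.getfullargspec(logaddexp).args)` is 0, because `getfullargspec` inspects the `(*args, **kwargs)` wrapper itself, while `nargs(logaddexp)` is 2.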
@@ -91,12 +99,57 @@ def nargs(func):
    "trunc": "real numeric",
}


def test_nargs():
A short test to make sure `nargs` works but also that all of the functions that we look at have "the right signature" - I found `logaddexp` was missing that trailing `/` when working on `nargs`. So it seems useful to have an "all functions have a reasonable number of arguments" test.
That's good. The array-api test suite doesn't check for positional-only, though it probably could now that it uses inspect.signature.
The default device should continue to convert, but other arrays from other devices should error.
9245481
to
ff37de7
Compare
760c230
to
9c5436c
Compare
Does someone know more about the failure? It looks like it is not to do with the actual code but with computing the expected shape and that overflowing because the array is of dtype |
Sorry, that is from a new test that I added in the test suite. I guess I didn't catch all the corner cases. You can ignore it for now. |
In that case, I think, this is ready?! I tried to modify all the functions that return an array to take into account the |
@@ -309,6 +309,9 @@
__all__ += ["all", "any"]

from ._array_object import Device
__all__ += ["Device"]
I'd rather not add this to the `__init__.py` since it isn't part of the array API. If it is necessary to have some public APIs to create device objects we should make APIs that are more obviously array-api-strict specific (similar to the flags APIs).
@@ -625,19 +661,21 @@ def __getitem__(
        """
        Performs the operation __getitem__.
        """
        # XXX Does key have to be on the same device? Is there an exception for CPU_DEVICE?
I guess this is underspecified. I would need to test what PyTorch and others do, but I would suspect that an implicit cross-device array `key` is not something that's intended to be supported, since that still would require an implicit device transfer.
@@ -121,7 +121,7 @@ def __hash__(self):
     "integer": _integer_dtypes,
     "integer or boolean": _integer_or_boolean_dtypes,
     "boolean": _boolean_dtypes,
-    "real floating-point": _floating_dtypes,
+    "real floating-point": _real_floating_dtypes,
Good catch. Is this covered by one of the new tests?
Not explicitly. I found this because a test was failing, looked into this and found it. But also it is too long ago for me to remember what exactly it was :-/
It's hard to tell just from the diff if you missed anything. Here are all the places in the code that call
So it would be worth double checking all of those. |
As you mentioned, we should indeed be testing most of this in the test suite. However, I'm not really sure how soon that will happen. There's quite a backlog of things to do in the test suite right now, and my current priority is implementing tests for new functions added in 2023.12 or 2024.12 versions of the standard. So some very simple tests here would not hurt. The tests already have a list of two-argument functions which could be reused. |
Were you wanting to add support for devices that don't support certain dtypes, or is that something that we should add in a later pull request? |
Actually, I think we should make |
I'd do that in a separate PR. If only because this one is already quite long and hard to check by looking at the diff.
That is a good idea. I like it |
Nice. I feel much better about this after the latest commit making It looks like another PR I just merged has created a small conflict here, but other than that, I am +1 to merging this. |
Thanks for taking this over the finish line! |
I opened #70 for ideas for further work here. |
Yes that is the default device (I've not followed the scipy discussion, so I have no idea if this is a good thing?) |
Having more than one device is useful during testing to allow you to find bugs related to how arrays on different devices are handled. Closes #56
With scikit-learn we run into the frustrating situation where contributors execute tests locally and they all pass, but then they see failures on the CI because e.g. PyTorch has several devices and some things work on the CPU device but not on the CUDA/MPS device. If you have neither of those on your local machine, you can't really test this upfront, and to debug it you need to rely on the CI.
The idea of this PR is to add support for multiple devices to `array-api-strict` to make testing easier. The default device continues to be the CPU device, and for arrays that use it nothing should change. However, you can now place an array on a different device with `array_api_strict.Device("pony")` (or some other string, each string is a new device). For arrays on a device that isn't the CPU device, calls like `np.asarray(some_strict_array)` will raise an error. This mirrors how PyTorch treats arrays on the CPU and MPS device.
What isn't yet implemented in this PR is raising an error if you try to operate on arrays that are not on the same device.
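The behaviour described above can be sketched with a toy model. Everything here is an assumption made for illustration: this `Device` and `ToyArray` are not the classes from the PR, `to_host()` is a hypothetical stand-in for the implicit conversion that `np.asarray` would trigger, and the choice of `RuntimeError` is arbitrary.

```python
class Device:
    """Toy device identified by a string; equal names mean the same device."""

    def __init__(self, name="CPU_DEVICE"):
        self._name = name

    def __repr__(self):
        return f"Device({self._name!r})"

    def __eq__(self, other):
        if not isinstance(other, Device):
            return NotImplemented
        return self._name == other._name

    def __hash__(self):
        return hash(("Device", self._name))


CPU_DEVICE = Device()


class ToyArray:
    """Hypothetical array that refuses implicit transfers off its device."""

    def __init__(self, values, device=CPU_DEVICE):
        self.values = list(values)
        self.device = device

    def to_host(self):
        """Stand-in for the conversion np.asarray(some_strict_array) would do."""
        if self.device != CPU_DEVICE:
            raise RuntimeError(
                f"Cannot implicitly convert an array on {self.device!r}; "
                "only arrays on the default device convert."
            )
        return list(self.values)
```

In this sketch, `ToyArray([1, 2]).to_host()` returns the values unchanged, while `ToyArray([1, 2], device=Device("pony")).to_host()` raises, so code that silently assumes CPU-resident data fails locally rather than only on CI.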
I wanted to open this PR early, after only a small amount of effort, to get feedback on the idea before putting in the time to update all the tests, etc.