Skip to content

Commit 8d3f5d5

Browse files
authored
Merge pull request #147 from asmeurer/array_namespace-scalars
Allow Python scalars in array_namespace
2 parents a1bb036 + 0a2160b commit 8d3f5d5

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

array_api_compat/common/_helpers.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,8 @@ def array_namespace(*xs, api_version=None, use_compat=None):
428428
Parameters
429429
----------
430430
xs: arrays
431-
one or more arrays.
431+
one or more arrays. xs can also be Python scalars (bool, int, float,
432+
complex, or None), which are ignored.
432433
433434
api_version: str
434435
The newest version of the spec that you need support for (currently
@@ -491,7 +492,9 @@ def your_function(x, y):
491492

492493
namespaces = set()
493494
for x in xs:
494-
if is_numpy_array(x):
495+
if isinstance(x, (bool, int, float, complex, type(None))):
496+
continue
497+
elif is_numpy_array(x):
495498
from .. import numpy as numpy_namespace
496499
import numpy as np
497500
if use_compat is True:

tests/test_array_namespace.py

+16
Original file line numberDiff line numberDiff line change
@@ -114,3 +114,19 @@ def test_api_version():
114114
def test_get_namespace():
115115
# Backwards compatible wrapper
116116
assert array_api_compat.get_namespace is array_api_compat.array_namespace
117+
118+
def test_python_scalars():
119+
a = torch.asarray([1, 2])
120+
xp = import_("torch", wrapper=True)
121+
122+
pytest.raises(TypeError, lambda: array_namespace(1))
123+
pytest.raises(TypeError, lambda: array_namespace(1.0))
124+
pytest.raises(TypeError, lambda: array_namespace(1j))
125+
pytest.raises(TypeError, lambda: array_namespace(True))
126+
pytest.raises(TypeError, lambda: array_namespace(None))
127+
128+
assert array_namespace(a, 1) == xp
129+
assert array_namespace(a, 1.0) == xp
130+
assert array_namespace(a, 1j) == xp
131+
assert array_namespace(a, True) == xp
132+
assert array_namespace(a, None) == xp

0 commit comments

Comments
 (0)