Skip to content

Commit fc6b56b

Browse files
committed
is_lazy_array, is_writeable_array
1 parent 632f081 commit fc6b56b

File tree

1 file changed

+43
-14
lines changed

1 file changed

+43
-14
lines changed

array_api_compat/common/_helpers.py

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -839,6 +839,19 @@ def size(x: Array) -> int | None:
839839
return None if math.isnan(out) else out
840840

841841

842+
@cache
843+
def _is_writeable_cls(cls: type) -> bool | None:
844+
if (
845+
_issubclass_fast(cls, "numpy", "generic")
846+
or _issubclass_fast(cls, "jax", "Array")
847+
or _issubclass_fast(cls, "sparse", "SparseArray")
848+
):
849+
return False
850+
if _is_array_api_cls(cls):
851+
return True
852+
return None
853+
854+
842855
def is_writeable_array(x: object) -> bool:
843856
"""
844857
Return False if ``x.__setitem__`` is expected to raise; True otherwise.
@@ -849,11 +862,32 @@ def is_writeable_array(x: object) -> bool:
849862
As there is no standard way to check if an array is writeable without actually
850863
writing to it, this function blindly returns True for all unknown array types.
851864
"""
852-
if is_numpy_array(x):
865+
cls = type(x)
866+
if _issubclass_fast(cls, "numpy", "ndarray"):
853867
return x.flags.writeable
854-
if is_jax_array(x) or is_pydata_sparse_array(x):
868+
res = _is_writeable_cls(cls)
869+
if res is not None:
870+
return res
871+
return hasattr(x, '__array_namespace__')
872+
873+
874+
@cache
875+
def _is_lazy_cls(cls: type) -> bool | None:
876+
if (
877+
_issubclass_fast(cls, "numpy", "ndarray")
878+
or _issubclass_fast(cls, "numpy", "generic")
879+
or _issubclass_fast(cls, "cupy", "ndarray")
880+
or _issubclass_fast(cls, "torch", "Tensor")
881+
or _issubclass_fast(cls, "sparse", "SparseArray")
882+
):
855883
return False
856-
return is_array_api_obj(x)
884+
if (
885+
_issubclass_fast(cls, "jax", "Array")
886+
or _issubclass_fast(cls, "dask.array", "Array")
887+
or _issubclass_fast(cls, "ndonnx", "Array")
888+
):
889+
return True
890+
return None
857891

858892

859893
def is_lazy_array(x: object) -> bool:
@@ -869,14 +903,6 @@ def is_lazy_array(x: object) -> bool:
869903
This function errs on the side of caution for array types that may or may not be
870904
lazy, e.g. JAX arrays, by always returning True for them.
871905
"""
872-
if (
873-
is_numpy_array(x)
874-
or is_cupy_array(x)
875-
or is_torch_array(x)
876-
or is_pydata_sparse_array(x)
877-
):
878-
return False
879-
880906
# **JAX note:** while it is possible to determine if you're inside or outside
881907
# jax.jit by testing the subclass of a jax.Array object, as well as testing bool()
882908
# as we do below for unknown arrays, this is not recommended by JAX best practices.
@@ -886,10 +912,13 @@ def is_lazy_array(x: object) -> bool:
886912
# compatibility, is highly detrimental to performance as the whole graph will end
887913
# up being computed multiple times.
888914

889-
if is_jax_array(x) or is_dask_array(x) or is_ndonnx_array(x):
890-
return True
915+
# Note: skipping reclassification of JAX zero gradient arrays, as one will
916+
# exclusively get them once they leave a jax.grad JIT context.
917+
res = _is_lazy_cls(type(x))
918+
if res is not None:
919+
return res
891920

892-
if not is_array_api_obj(x):
921+
if not hasattr(x, "__array_namespace__"):
893922
return False
894923

895924
# Unknown Array API compatible object. Note that this test may have dire consequences

0 commit comments

Comments
 (0)