@@ -839,6 +839,19 @@ def size(x: Array) -> int | None:
839
839
return None if math .isnan (out ) else out
840
840
841
841
842
+ @cache
843
+ def _is_writeable_cls (cls : type ) -> bool | None :
844
+ if (
845
+ _issubclass_fast (cls , "numpy" , "generic" )
846
+ or _issubclass_fast (cls , "jax" , "Array" )
847
+ or _issubclass_fast (cls , "sparse" , "SparseArray" )
848
+ ):
849
+ return False
850
+ if _is_array_api_cls (cls ):
851
+ return True
852
+ return None
853
+
854
+
842
855
def is_writeable_array (x : object ) -> bool :
843
856
"""
844
857
Return False if ``x.__setitem__`` is expected to raise; True otherwise.
@@ -849,11 +862,32 @@ def is_writeable_array(x: object) -> bool:
849
862
As there is no standard way to check if an array is writeable without actually
850
863
writing to it, this function blindly returns True for all unknown array types.
851
864
"""
852
- if is_numpy_array (x ):
865
+ cls = type (x )
866
+ if _issubclass_fast (cls , "numpy" , "ndarray" ):
853
867
return x .flags .writeable
854
- if is_jax_array (x ) or is_pydata_sparse_array (x ):
868
+ res = _is_writeable_cls (cls )
869
+ if res is not None :
870
+ return res
871
+ return hasattr (x , '__array_namespace__' )
872
+
873
+
874
+ @cache
875
+ def _is_lazy_cls (cls : type ) -> bool | None :
876
+ if (
877
+ _issubclass_fast (cls , "numpy" , "ndarray" )
878
+ or _issubclass_fast (cls , "numpy" , "generic" )
879
+ or _issubclass_fast (cls , "cupy" , "ndarray" )
880
+ or _issubclass_fast (cls , "torch" , "Tensor" )
881
+ or _issubclass_fast (cls , "sparse" , "SparseArray" )
882
+ ):
855
883
return False
856
- return is_array_api_obj (x )
884
+ if (
885
+ _issubclass_fast (cls , "jax" , "Array" )
886
+ or _issubclass_fast (cls , "dask.array" , "Array" )
887
+ or _issubclass_fast (cls , "ndonnx" , "Array" )
888
+ ):
889
+ return True
890
+ return None
857
891
858
892
859
893
def is_lazy_array (x : object ) -> bool :
@@ -869,14 +903,6 @@ def is_lazy_array(x: object) -> bool:
869
903
This function errs on the side of caution for array types that may or may not be
870
904
lazy, e.g. JAX arrays, by always returning True for them.
871
905
"""
872
- if (
873
- is_numpy_array (x )
874
- or is_cupy_array (x )
875
- or is_torch_array (x )
876
- or is_pydata_sparse_array (x )
877
- ):
878
- return False
879
-
880
906
# **JAX note:** while it is possible to determine if you're inside or outside
881
907
# jax.jit by testing the subclass of a jax.Array object, as well as testing bool()
882
908
# as we do below for unknown arrays, this is not recommended by JAX best practices.
@@ -886,10 +912,13 @@ def is_lazy_array(x: object) -> bool:
886
912
# compatibility, is highly detrimental to performance as the whole graph will end
887
913
# up being computed multiple times.
888
914
889
- if is_jax_array (x ) or is_dask_array (x ) or is_ndonnx_array (x ):
890
- return True
915
+ # Note: skipping reclassification of JAX zero gradient arrays, as one will
916
+ # exclusively get them once they leave a jax.grad JIT context.
917
+ res = _is_lazy_cls (type (x ))
918
+ if res is not None :
919
+ return res
891
920
892
- if not is_array_api_obj ( x ):
921
+ if not hasattr ( x , "__array_namespace__" ):
893
922
return False
894
923
895
924
# Unknown Array API compatible object. Note that this test may have dire consequences
0 commit comments