@@ -74,7 +74,7 @@ def test_assert_close_tolerance(xp: ModuleType):
74
74
@param_assert_equal_close
75
75
@pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "index by sparse array" )
76
76
def test_assert_close_equal_none_shape (xp : ModuleType , func : Callable [..., None ]): # type: ignore[explicit-any]
77
- """On dask and other lazy backends, test that a shape with NaN's or None's
77
+ """On Dask and other lazy backends, test that a shape with NaN's or None's
78
78
can be compared to a real shape.
79
79
"""
80
80
a = xp .asarray ([1 , 2 ])
@@ -99,18 +99,18 @@ def test_assert_close_equal_none_shape(xp: ModuleType, func: Callable[..., None]
99
99
100
100
101
101
def good_lazy (x : Array ) -> Array :
102
- """A function that behaves well in dask and jax.jit"""
102
+ """A function that behaves well in Dask and jax.jit"""
103
103
return x * 2.0
104
104
105
105
106
106
def non_materializable (x : Array ) -> Array :
107
107
"""
108
108
This function materializes the input array, so it will fail when wrapped in jax.jit
109
- and it will trigger an expensive computation in dask .
109
+ and it will trigger an expensive computation in Dask .
110
110
"""
111
111
xp = array_namespace (x )
112
112
# Crashes inside jax.jit
113
- # On dask , this triggers two computations of the whole graph
113
+ # On Dask , this triggers two computations of the whole graph
114
114
if xp .any (x < 0.0 ) or xp .any (x > 10.0 ):
115
115
msg = "Values must be in the [0, 10] range"
116
116
raise ValueError (msg )
@@ -217,20 +217,20 @@ def test_lazy_xp_function_static_params(xp: ModuleType, func: Callable[..., Arra
217
217
erf = None
218
218
219
219
220
- @pytest .mark .filterwarnings ("ignore:__array_wrap__:DeprecationWarning" ) # torch
220
+ @pytest .mark .filterwarnings ("ignore:__array_wrap__:DeprecationWarning" ) # PyTorch
221
221
def test_lazy_xp_function_cython_ufuncs (xp : ModuleType , library : Backend ):
222
222
pytest .importorskip ("scipy" )
223
223
assert erf is not None
224
224
x = xp .asarray ([6.0 , 7.0 ])
225
225
if library in (Backend .ARRAY_API_STRICT , Backend .JAX ):
226
- # array-api-strict arrays are auto-converted to numpy
226
+ # array-api-strict arrays are auto-converted to NumPy
227
227
# which results in an assertion error for mismatched namespaces
228
- # eager jax arrays are auto-converted to numpy in eager jax
228
+ # eager JAX arrays are auto-converted to NumPy in eager JAX
229
229
# and fail in jax.jit (which lazy_xp_function tests here)
230
230
with pytest .raises ((TypeError , AssertionError )):
231
231
xp_assert_equal (cast (Array , erf (x )), xp .asarray ([1.0 , 1.0 ]))
232
232
else :
233
- # cupy, dask and sparse define __array_ufunc__ and dispatch accordingly
233
+ # CuPy, Dask and sparse define __array_ufunc__ and dispatch accordingly
234
234
# note that when sparse reduces to scalar it returns a np.generic, which
235
235
# would make xp_assert_equal fail.
236
236
xp_assert_equal (cast (Array , erf (x )), xp .asarray ([1.0 , 1.0 ]))
@@ -271,7 +271,7 @@ def test_lazy_xp_function_eagerly_raises(da: ModuleType):
271
271
272
272
def f (x : Array ) -> Array :
273
273
xp = array_namespace (x )
274
- # Crash in jax.jit and trigger compute() on dask
274
+ # Crash in jax.jit and trigger compute() on Dask
275
275
if not xp .all (x ):
276
276
msg = "Values must be non-zero"
277
277
raise ValueError (msg )
0 commit comments