Skip to content

Commit 42f535d

Browse files
committed
Better tests for default copy param
1 parent 96105b3 commit 42f535d

File tree

1 file changed

+44
-26
lines changed

1 file changed

+44
-26
lines changed

Diff for: tests/test_at.py

+44-26
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from collections.abc import Callable, Generator
44
from contextlib import contextmanager
55
from types import ModuleType
6-
from typing import Any, cast
6+
from typing import cast
77

88
import numpy as np
99
import pytest
@@ -23,12 +23,13 @@
2323
]
2424

2525

26-
def at_op( # type: ignore[no-any-explicit]
26+
def at_op(
2727
x: Array,
2828
idx: Index,
2929
op: _AtOp,
3030
y: Array | object,
31-
**kwargs: Any, # Test the default copy=None
31+
copy: bool | None = None,
32+
xp: ModuleType | None = None,
3233
) -> Array:
3334
"""
3435
Wrapper around at(x, idx).op(y, copy=copy, xp=xp).
@@ -39,30 +40,33 @@ def at_op( # type: ignore[no-any-explicit]
3940
which is not a common use case.
4041
"""
4142
if isinstance(idx, (slice | tuple)):
42-
return _at_op(x, None, pickle.dumps(idx), op, y, **kwargs)
43-
return _at_op(x, idx, None, op, y, **kwargs)
43+
return _at_op(x, None, pickle.dumps(idx), op, y, copy=copy, xp=xp)
44+
return _at_op(x, idx, None, op, y, copy=copy, xp=xp)
4445

4546

46-
def _at_op( # type: ignore[no-any-explicit]
47+
def _at_op(
4748
x: Array,
4849
idx: Index | None,
4950
idx_pickle: bytes | None,
5051
op: _AtOp,
5152
y: Array | object,
52-
**kwargs: Any,
53+
copy: bool | None,
54+
xp: ModuleType | None = None,
5355
) -> Array:
5456
"""jitted helper of at_op"""
5557
if idx_pickle:
5658
idx = pickle.loads(idx_pickle)
5759
meth = cast(Callable[..., Array], getattr(at(x, idx), op.value)) # type: ignore[no-any-explicit]
58-
return meth(y, **kwargs)
60+
return meth(y, copy=copy, xp=xp)
5961

6062

6163
lazy_xp_function(_at_op, static_argnames=("op", "idx_pickle", "copy", "xp"))
6264

6365

6466
@contextmanager
65-
def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
67+
def assert_copy(
68+
array: Array, copy: bool | None, expect_copy: bool | None = None
69+
) -> Generator[None, None, None]:
6670
if copy is False and not is_writeable_array(array):
6771
with pytest.raises((TypeError, ValueError)):
6872
yield
@@ -72,28 +76,21 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
7276
array_orig = xp.asarray(array, copy=True)
7377
yield
7478

75-
if copy is True:
79+
if expect_copy is None:
80+
expect_copy = copy
81+
82+
if expect_copy:
7683
# Original has not been modified
7784
xp_assert_equal(array, array_orig)
78-
elif copy is False:
85+
elif expect_copy is False:
7986
# Original has been modified
8087
with pytest.raises(AssertionError):
8188
xp_assert_equal(array, array_orig)
8289
# Test nothing for copy=None. Dask changes behaviour depending on
8390
# whether it's a special case of a bool mask with scalar RHS or not.
8491

8592

86-
@pytest.mark.parametrize(
87-
("kwargs", "expect_copy"),
88-
[
89-
pytest.param({"copy": True}, True, id="copy=True"),
90-
pytest.param({"copy": False}, False, id="copy=False"),
91-
# Behavior is backend-specific
92-
pytest.param({"copy": None}, None, id="copy=None"),
93-
# Test that the copy parameter defaults to None
94-
pytest.param({}, None, id="no copy kwarg"),
95-
],
96-
)
93+
@pytest.mark.parametrize("copy", [False, True, None])
9794
@pytest.mark.parametrize(
9895
("op", "y", "expect_list"),
9996
[
@@ -130,8 +127,7 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
130127
)
131128
def test_update_ops(
132129
xp: ModuleType,
133-
kwargs: dict[str, bool | None],
134-
expect_copy: bool | None,
130+
copy: bool | None,
135131
op: _AtOp,
136132
y: float,
137133
expect_list: list[float],
@@ -156,12 +152,34 @@ def test_update_ops(
156152
if y_ndim == 1:
157153
y = xp.asarray([y, y])
158154

159-
with assert_copy(x, expect_copy):
160-
z = at_op(x, idx, op, y, **kwargs)
155+
with assert_copy(x, copy):
156+
z = at_op(x, idx, op, y, copy=copy)
161157
assert isinstance(z, type(x))
162158
xp_assert_equal(z, xp.asarray(expect))
163159

164160

161+
@pytest.mark.parametrize("op", list(_AtOp))
162+
def test_copy_default(xp: ModuleType, library: Backend, op: _AtOp):
163+
"""
164+
Test that the default copy behaviour is False for writeable arrays
165+
and True for read-only ones.
166+
"""
167+
x = xp.asarray([1.0, 10.0, 20.0])
168+
expect_copy = not is_writeable_array(x)
169+
meth = cast(Callable[..., Array], getattr(at(x)[:2], op.value)) # type: ignore[no-any-explicit]
170+
with assert_copy(x, None, expect_copy):
171+
_ = meth(2.0)
172+
173+
x = xp.asarray([1.0, 10.0, 20.0])
174+
# Dask's default copy value is True for bool masks,
175+
# even if the arrays are writeable.
176+
expect_copy = not is_writeable_array(x) or library is Backend.DASK
177+
idx = xp.asarray([True, True, False])
178+
meth = cast(Callable[..., Array], getattr(at(x, idx), op.value)) # type: ignore[no-any-explicit]
179+
with assert_copy(x, None, expect_copy):
180+
_ = meth(2.0)
181+
182+
165183
def test_copy_invalid():
166184
a = np.asarray([1, 2, 3])
167185
with pytest.raises(ValueError, match="copy"):

0 commit comments

Comments
 (0)