3
3
from collections .abc import Callable , Generator
4
4
from contextlib import contextmanager
5
5
from types import ModuleType
6
- from typing import Any , cast
6
+ from typing import cast
7
7
8
8
import numpy as np
9
9
import pytest
23
23
]
24
24
25
25
26
- def at_op ( # type: ignore[no-any-explicit]
26
+ def at_op (
27
27
x : Array ,
28
28
idx : Index ,
29
29
op : _AtOp ,
30
30
y : Array | object ,
31
- ** kwargs : Any , # Test the default copy=None
31
+ copy : bool | None = None ,
32
+ xp : ModuleType | None = None ,
32
33
) -> Array :
33
34
"""
34
35
Wrapper around at(x, idx).op(y, copy=copy, xp=xp).
@@ -39,30 +40,33 @@ def at_op( # type: ignore[no-any-explicit]
39
40
which is not a common use case.
40
41
"""
41
42
if isinstance (idx , (slice | tuple )):
42
- return _at_op (x , None , pickle .dumps (idx ), op , y , ** kwargs )
43
- return _at_op (x , idx , None , op , y , ** kwargs )
43
+ return _at_op (x , None , pickle .dumps (idx ), op , y , copy = copy , xp = xp )
44
+ return _at_op (x , idx , None , op , y , copy = copy , xp = xp )
44
45
45
46
46
- def _at_op ( # type: ignore[no-any-explicit]
47
+ def _at_op (
47
48
x : Array ,
48
49
idx : Index | None ,
49
50
idx_pickle : bytes | None ,
50
51
op : _AtOp ,
51
52
y : Array | object ,
52
- ** kwargs : Any ,
53
+ copy : bool | None ,
54
+ xp : ModuleType | None = None ,
53
55
) -> Array :
54
56
"""jitted helper of at_op"""
55
57
if idx_pickle :
56
58
idx = pickle .loads (idx_pickle )
57
59
meth = cast (Callable [..., Array ], getattr (at (x , idx ), op .value )) # type: ignore[no-any-explicit]
58
- return meth (y , ** kwargs )
60
+ return meth (y , copy = copy , xp = xp )
59
61
60
62
61
63
lazy_xp_function (_at_op , static_argnames = ("op" , "idx_pickle" , "copy" , "xp" ))
62
64
63
65
64
66
@contextmanager
65
- def assert_copy (array : Array , copy : bool | None ) -> Generator [None , None , None ]:
67
+ def assert_copy (
68
+ array : Array , copy : bool | None , expect_copy : bool | None = None
69
+ ) -> Generator [None , None , None ]:
66
70
if copy is False and not is_writeable_array (array ):
67
71
with pytest .raises ((TypeError , ValueError )):
68
72
yield
@@ -72,28 +76,21 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
72
76
array_orig = xp .asarray (array , copy = True )
73
77
yield
74
78
75
- if copy is True :
79
+ if expect_copy is None :
80
+ expect_copy = copy
81
+
82
+ if expect_copy :
76
83
# Original has not been modified
77
84
xp_assert_equal (array , array_orig )
78
- elif copy is False :
85
+ elif expect_copy is False :
79
86
# Original has been modified
80
87
with pytest .raises (AssertionError ):
81
88
xp_assert_equal (array , array_orig )
82
89
# Test nothing for copy=None. Dask changes behaviour depending on
83
90
# whether it's a special case of a bool mask with scalar RHS or not.
84
91
85
92
86
- @pytest .mark .parametrize (
87
- ("kwargs" , "expect_copy" ),
88
- [
89
- pytest .param ({"copy" : True }, True , id = "copy=True" ),
90
- pytest .param ({"copy" : False }, False , id = "copy=False" ),
91
- # Behavior is backend-specific
92
- pytest .param ({"copy" : None }, None , id = "copy=None" ),
93
- # Test that the copy parameter defaults to None
94
- pytest .param ({}, None , id = "no copy kwarg" ),
95
- ],
96
- )
93
+ @pytest .mark .parametrize ("copy" , [False , True , None ])
97
94
@pytest .mark .parametrize (
98
95
("op" , "y" , "expect_list" ),
99
96
[
@@ -130,8 +127,7 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
130
127
)
131
128
def test_update_ops (
132
129
xp : ModuleType ,
133
- kwargs : dict [str , bool | None ],
134
- expect_copy : bool | None ,
130
+ copy : bool | None ,
135
131
op : _AtOp ,
136
132
y : float ,
137
133
expect_list : list [float ],
@@ -156,12 +152,34 @@ def test_update_ops(
156
152
if y_ndim == 1 :
157
153
y = xp .asarray ([y , y ])
158
154
159
- with assert_copy (x , expect_copy ):
160
- z = at_op (x , idx , op , y , ** kwargs )
155
+ with assert_copy (x , copy ):
156
+ z = at_op (x , idx , op , y , copy = copy )
161
157
assert isinstance (z , type (x ))
162
158
xp_assert_equal (z , xp .asarray (expect ))
163
159
164
160
161
+ @pytest .mark .parametrize ("op" , list (_AtOp ))
162
+ def test_copy_default (xp : ModuleType , library : Backend , op : _AtOp ):
163
+ """
164
+ Test that the default copy behaviour is False for writeable arrays
165
+ and True for read-only ones.
166
+ """
167
+ x = xp .asarray ([1.0 , 10.0 , 20.0 ])
168
+ expect_copy = not is_writeable_array (x )
169
+ meth = cast (Callable [..., Array ], getattr (at (x )[:2 ], op .value )) # type: ignore[no-any-explicit]
170
+ with assert_copy (x , None , expect_copy ):
171
+ _ = meth (2.0 )
172
+
173
+ x = xp .asarray ([1.0 , 10.0 , 20.0 ])
174
+ # Dask's default copy value is True for bool masks,
175
+ # even if the arrays are writeable.
176
+ expect_copy = not is_writeable_array (x ) or library is Backend .DASK
177
+ idx = xp .asarray ([True , True , False ])
178
+ meth = cast (Callable [..., Array ], getattr (at (x , idx ), op .value )) # type: ignore[no-any-explicit]
179
+ with assert_copy (x , None , expect_copy ):
180
+ _ = meth (2.0 )
181
+
182
+
165
183
def test_copy_invalid ():
166
184
a = np .asarray ([1 , 2 , 3 ])
167
185
with pytest .raises (ValueError , match = "copy" ):
0 commit comments