Skip to content

Commit ea06c6f

Browse files
Update __array__ signatures with copy (#9529)
* Update __array__ with copy * Update common.py * Update indexing.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * copy only available from np2 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Raise if copy=false * Update groupby.py * Update test_namedarray.py * Update pyproject.toml --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 52f13d4 commit ea06c6f

10 files changed

+65
-31
lines changed

pyproject.toml

-1
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,6 @@ filterwarnings = [
323323
"default:Using a non-tuple sequence for multidimensional indexing is deprecated:FutureWarning",
324324
"default:Duplicate dimension names present:UserWarning:xarray.namedarray.core",
325325
"default:::xarray.tests.test_strategies", # TODO: remove once we know how to deal with a changed signature in protocols
326-
"ignore:__array__ implementation doesn't accept a copy keyword, so passing copy=False failed.",
327326
]
328327

329328
log_cli_level = "INFO"

xarray/core/common.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def __complex__(self: Any) -> complex:
163163
return complex(self.values)
164164

165165
def __array__(
166-
self: Any, dtype: DTypeLike | None = None, copy: bool | None = None
166+
self: Any, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
167167
) -> np.ndarray:
168168
if not copy:
169169
if np.lib.NumpyVersion(np.__version__) >= "2.0.0":

xarray/core/datatree.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
from xarray.core.dataset import calculate_dimensions
5656

5757
if TYPE_CHECKING:
58+
import numpy as np
5859
import pandas as pd
5960

6061
from xarray.core.datatree_io import T_DataTreeNetcdfEngine, T_DataTreeNetcdfTypes
@@ -737,7 +738,9 @@ def __bool__(self) -> bool:
737738
def __iter__(self) -> Iterator[str]:
738739
return itertools.chain(self._data_variables, self._children) # type: ignore[arg-type]
739740

740-
def __array__(self, dtype=None, copy=None):
741+
def __array__(
742+
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
743+
) -> np.ndarray:
741744
raise TypeError(
742745
"cannot directly convert a DataTree into a "
743746
"numpy array. Instead, create an xarray.DataArray "

xarray/core/groupby.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,11 @@ def values(self) -> range:
193193
def data(self) -> range:
194194
return range(self.size)
195195

196-
def __array__(self) -> np.ndarray:
196+
def __array__(
197+
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
198+
) -> np.ndarray:
199+
if copy is False:
200+
raise NotImplementedError(f"An array copy is necessary, got {copy = }.")
197201
return np.arange(self.size)
198202

199203
@property

xarray/core/indexing.py

+27-16
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import numpy as np
1515
import pandas as pd
16+
from packaging.version import Version
1617

1718
from xarray.core import duck_array_ops
1819
from xarray.core.nputils import NumpyVIndexAdapter
@@ -505,9 +506,14 @@ class ExplicitlyIndexed:
505506

506507
__slots__ = ()
507508

508-
def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray:
509+
def __array__(
510+
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
511+
) -> np.ndarray:
509512
# Leave casting to an array up to the underlying array type.
510-
return np.asarray(self.get_duck_array(), dtype=dtype)
513+
if Version(np.__version__) >= Version("2.0.0"):
514+
return np.asarray(self.get_duck_array(), dtype=dtype, copy=copy)
515+
else:
516+
return np.asarray(self.get_duck_array(), dtype=dtype)
511517

512518
def get_duck_array(self):
513519
return self.array
@@ -520,11 +526,6 @@ def get_duck_array(self):
520526
key = BasicIndexer((slice(None),) * self.ndim)
521527
return self[key]
522528

523-
def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray:
524-
# This is necessary because we apply the indexing key in self.get_duck_array()
525-
# Note this is the base class for all lazy indexing classes
526-
return np.asarray(self.get_duck_array(), dtype=dtype)
527-
528529
def _oindex_get(self, indexer: OuterIndexer):
529530
raise NotImplementedError(
530531
f"{self.__class__.__name__}._oindex_get method should be overridden"
@@ -570,8 +571,13 @@ def __init__(self, array, indexer_cls: type[ExplicitIndexer] = BasicIndexer):
570571
self.array = as_indexable(array)
571572
self.indexer_cls = indexer_cls
572573

573-
def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray:
574-
return np.asarray(self.get_duck_array(), dtype=dtype)
574+
def __array__(
575+
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
576+
) -> np.ndarray:
577+
if Version(np.__version__) >= Version("2.0.0"):
578+
return np.asarray(self.get_duck_array(), dtype=dtype, copy=copy)
579+
else:
580+
return np.asarray(self.get_duck_array(), dtype=dtype)
575581

576582
def get_duck_array(self):
577583
return self.array.get_duck_array()
@@ -830,9 +836,6 @@ def __init__(self, array):
830836
def _ensure_cached(self):
831837
self.array = as_indexable(self.array.get_duck_array())
832838

833-
def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray:
834-
return np.asarray(self.get_duck_array(), dtype=dtype)
835-
836839
def get_duck_array(self):
837840
self._ensure_cached()
838841
return self.array.get_duck_array()
@@ -1674,15 +1677,21 @@ def __init__(self, array: pd.Index, dtype: DTypeLike = None):
16741677
def dtype(self) -> np.dtype:
16751678
return self._dtype
16761679

1677-
def __array__(self, dtype: DTypeLike = None) -> np.ndarray:
1680+
def __array__(
1681+
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
1682+
) -> np.ndarray:
16781683
if dtype is None:
16791684
dtype = self.dtype
16801685
array = self.array
16811686
if isinstance(array, pd.PeriodIndex):
16821687
with suppress(AttributeError):
16831688
# this might not be public API
16841689
array = array.astype("object")
1685-
return np.asarray(array.values, dtype=dtype)
1690+
1691+
if Version(np.__version__) >= Version("2.0.0"):
1692+
return np.asarray(array.values, dtype=dtype, copy=copy)
1693+
else:
1694+
return np.asarray(array.values, dtype=dtype)
16861695

16871696
def get_duck_array(self) -> np.ndarray:
16881697
return np.asarray(self)
@@ -1831,15 +1840,17 @@ def __init__(
18311840
super().__init__(array, dtype)
18321841
self.level = level
18331842

1834-
def __array__(self, dtype: DTypeLike = None) -> np.ndarray:
1843+
def __array__(
1844+
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
1845+
) -> np.ndarray:
18351846
if dtype is None:
18361847
dtype = self.dtype
18371848
if self.level is not None:
18381849
return np.asarray(
18391850
self.array.get_level_values(self.level).values, dtype=dtype
18401851
)
18411852
else:
1842-
return super().__array__(dtype)
1853+
return super().__array__(dtype, copy=copy)
18431854

18441855
def _convert_scalar(self, item):
18451856
if isinstance(item, tuple) and self.level is not None:

xarray/namedarray/_typing.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -153,15 +153,15 @@ def __getitem__(
153153

154154
@overload
155155
def __array__(
156-
self, dtype: None = ..., /, *, copy: None | bool = ...
156+
self, dtype: None = ..., /, *, copy: bool | None = ...
157157
) -> np.ndarray[Any, _DType_co]: ...
158158
@overload
159159
def __array__(
160-
self, dtype: _DType, /, *, copy: None | bool = ...
160+
self, dtype: _DType, /, *, copy: bool | None = ...
161161
) -> np.ndarray[Any, _DType]: ...
162162

163163
def __array__(
164-
self, dtype: _DType | None = ..., /, *, copy: None | bool = ...
164+
self, dtype: _DType | None = ..., /, *, copy: bool | None = ...
165165
) -> np.ndarray[Any, _DType] | np.ndarray[Any, _DType_co]: ...
166166

167167
# TODO: Should return the same subclass but with a new dtype generic.

xarray/tests/arrays.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ def __init__(self, array):
2424
def get_duck_array(self):
2525
raise UnexpectedDataAccess("Tried accessing data")
2626

27-
def __array__(self, dtype: np.typing.DTypeLike = None):
27+
def __array__(
28+
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
29+
) -> np.ndarray:
2830
raise UnexpectedDataAccess("Tried accessing data")
2931

3032
def __getitem__(self, key):
@@ -49,7 +51,9 @@ def __init__(self, array: np.ndarray):
4951
def __getitem__(self, key):
5052
return type(self)(self.array[key])
5153

52-
def __array__(self, dtype: np.typing.DTypeLike = None):
54+
def __array__(
55+
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
56+
) -> np.ndarray:
5357
raise UnexpectedDataAccess("Tried accessing data")
5458

5559
def __array_namespace__(self):
@@ -140,7 +144,9 @@ def __repr__(self: Any) -> str:
140144
def get_duck_array(self):
141145
raise UnexpectedDataAccess("Tried accessing data")
142146

143-
def __array__(self, dtype: np.typing.DTypeLike = None):
147+
def __array__(
148+
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
149+
) -> np.ndarray:
144150
raise UnexpectedDataAccess("Tried accessing data")
145151

146152
def __getitem__(self, key) -> "ConcatenatableArray":

xarray/tests/test_assertions.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -173,9 +173,11 @@ def dims(self):
173173
warnings.warn("warning in test", stacklevel=2)
174174
return super().dims
175175

176-
def __array__(self, dtype=None, copy=None):
176+
def __array__(
177+
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
178+
) -> np.ndarray:
177179
warnings.warn("warning in test", stacklevel=2)
178-
return super().__array__()
180+
return super().__array__(dtype, copy=copy)
179181

180182
a = WarningVariable("x", [1])
181183
b = WarningVariable("x", [2])

xarray/tests/test_formatting.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -942,7 +942,9 @@ def test_lazy_array_wont_compute() -> None:
942942
from xarray.core.indexing import LazilyIndexedArray
943943

944944
class LazilyIndexedArrayNotComputable(LazilyIndexedArray):
945-
def __array__(self, dtype=None, copy=None):
945+
def __array__(
946+
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
947+
) -> np.ndarray:
946948
raise NotImplementedError("Computing this array is not possible.")
947949

948950
arr = LazilyIndexedArrayNotComputable(np.array([1, 2]))

xarray/tests/test_namedarray.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import numpy as np
1010
import pytest
11+
from packaging.version import Version
1112

1213
from xarray.core.indexing import ExplicitlyIndexed
1314
from xarray.namedarray._typing import (
@@ -53,8 +54,14 @@ def shape(self) -> _Shape:
5354
class CustomArray(
5455
CustomArrayBase[_ShapeType_co, _DType_co], Generic[_ShapeType_co, _DType_co]
5556
):
56-
def __array__(self) -> np.ndarray[Any, np.dtype[np.generic]]:
57-
return np.array(self.array)
57+
def __array__(
58+
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
59+
) -> np.ndarray[Any, np.dtype[np.generic]]:
60+
61+
if Version(np.__version__) >= Version("2.0.0"):
62+
return np.asarray(self.array, dtype=dtype, copy=copy)
63+
else:
64+
return np.asarray(self.array, dtype=dtype)
5865

5966

6067
class CustomArrayIndexable(

0 commit comments

Comments
 (0)