Skip to content

Commit 5f57f5c

Browse files
committed
Merge branch 'main' into Illviljan-dataset_line_plot
2 parents fe5bece + 92cb751 commit 5f57f5c

15 files changed

+428
-49
lines changed

ci/requirements/environment-windows.yml

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ dependencies:
1010
- cftime
1111
- dask
1212
- distributed
13+
- fsspec!=2021.7.0
1314
- h5netcdf
1415
- h5py
1516
- hdf5

ci/requirements/environment.yml

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ dependencies:
1212
- cftime
1313
- dask
1414
- distributed
15+
- fsspec!=2021.7.0
1516
- h5netcdf
1617
- h5py
1718
- hdf5

doc/api.rst

+3
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,7 @@ Dataset methods
686686
open_zarr
687687
Dataset.to_netcdf
688688
Dataset.to_pandas
689+
Dataset.as_numpy
689690
Dataset.to_zarr
690691
save_mfdataset
691692
Dataset.to_array
@@ -716,6 +717,8 @@ DataArray methods
716717
DataArray.to_pandas
717718
DataArray.to_series
718719
DataArray.to_dataframe
720+
DataArray.to_numpy
721+
DataArray.as_numpy
719722
DataArray.to_index
720723
DataArray.to_masked_array
721724
DataArray.to_cdms2

doc/whats-new.rst

+4
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ New Features
5656
- Allow removal of the coordinate attribute ``coordinates`` on variables by setting ``.attrs['coordinates']= None``
5757
(:issue:`5510`).
5858
By `Elle Smith <https://github.com/ellesmith88>`_.
59+
- Added :py:meth:`DataArray.to_numpy`, :py:meth:`DataArray.as_numpy`, and :py:meth:`Dataset.as_numpy`. (:pull:`5568`).
60+
By `Tom Nicholas <https://github.com/TomNicholas>`_.
61+
- Units in plot labels are now automatically inferred from wrapped :py:meth:`pint.Quantity` arrays. (:pull:`5561`).
62+
By `Tom Nicholas <https://github.com/TomNicholas>`_.
5963

6064
Breaking changes
6165
~~~~~~~~~~~~~~~~

xarray/core/dataarray.py

+48-6
Original file line numberDiff line numberDiff line change
@@ -426,12 +426,12 @@ def __init__(
426426
self._close = None
427427

428428
def _replace(
429-
self,
429+
self: T_DataArray,
430430
variable: Variable = None,
431431
coords=None,
432432
name: Union[Hashable, None, Default] = _default,
433433
indexes=None,
434-
) -> "DataArray":
434+
) -> T_DataArray:
435435
if variable is None:
436436
variable = self.variable
437437
if coords is None:
@@ -623,7 +623,16 @@ def __len__(self) -> int:
623623

624624
@property
625625
def data(self) -> Any:
626-
"""The array's data as a dask or numpy array"""
626+
"""
627+
The DataArray's data as an array. The underlying array type
628+
(e.g. dask, sparse, pint) is preserved.
629+
630+
See Also
631+
--------
632+
DataArray.to_numpy
633+
DataArray.as_numpy
634+
DataArray.values
635+
"""
627636
return self.variable.data
628637

629638
@data.setter
@@ -632,13 +641,46 @@ def data(self, value: Any) -> None:
632641

633642
@property
634643
def values(self) -> np.ndarray:
635-
"""The array's data as a numpy.ndarray"""
644+
"""
645+
The array's data as a numpy.ndarray.
646+
647+
If the array's data is not a numpy.ndarray this will attempt to convert
648+
it naively using np.array(), which will raise an error if the array
649+
type does not support coercion like this (e.g. cupy).
650+
"""
636651
return self.variable.values
637652

638653
@values.setter
639654
def values(self, value: Any) -> None:
640655
self.variable.values = value
641656

657+
def to_numpy(self) -> np.ndarray:
658+
"""
659+
Coerces wrapped data to numpy and returns a numpy.ndarray.
660+
661+
See also
662+
--------
663+
DataArray.as_numpy : Same but returns the surrounding DataArray instead.
664+
Dataset.as_numpy
665+
DataArray.values
666+
DataArray.data
667+
"""
668+
return self.variable.to_numpy()
669+
670+
def as_numpy(self: T_DataArray) -> T_DataArray:
671+
"""
672+
Coerces wrapped data and coordinates into numpy arrays, returning a DataArray.
673+
674+
See also
675+
--------
676+
DataArray.to_numpy : Same but returns only the data as a numpy.ndarray object.
677+
Dataset.as_numpy : Converts all variables in a Dataset.
678+
DataArray.values
679+
DataArray.data
680+
"""
681+
coords = {k: v.as_numpy() for k, v in self._coords.items()}
682+
return self._replace(self.variable.as_numpy(), coords, indexes=self._indexes)
683+
642684
@property
643685
def _in_memory(self) -> bool:
644686
return self.variable._in_memory
@@ -931,7 +973,7 @@ def persist(self, **kwargs) -> "DataArray":
931973
ds = self._to_temp_dataset().persist(**kwargs)
932974
return self._from_temp_dataset(ds)
933975

934-
def copy(self, deep: bool = True, data: Any = None) -> "DataArray":
976+
def copy(self: T_DataArray, deep: bool = True, data: Any = None) -> T_DataArray:
935977
"""Returns a copy of this array.
936978
937979
If `deep=True`, a deep copy is made of the data array.
@@ -2742,7 +2784,7 @@ def to_masked_array(self, copy: bool = True) -> np.ma.MaskedArray:
27422784
result : MaskedArray
27432785
Masked where invalid values (nan or inf) occur.
27442786
"""
2745-
values = self.values # only compute lazy arrays once
2787+
values = self.to_numpy() # only compute lazy arrays once
27462788
isnull = pd.isnull(values)
27472789
return np.ma.MaskedArray(data=values, mask=isnull, copy=copy)
27482790

xarray/core/dataset.py

+12
Original file line numberDiff line numberDiff line change
@@ -1323,6 +1323,18 @@ def copy(self, deep: bool = False, data: Mapping = None) -> "Dataset":
13231323

13241324
return self._replace(variables, attrs=attrs)
13251325

1326+
def as_numpy(self: "Dataset") -> "Dataset":
1327+
"""
1328+
Coerces wrapped data and coordinates into numpy arrays, returning a Dataset.
1329+
1330+
See also
1331+
--------
1332+
DataArray.as_numpy
1333+
DataArray.to_numpy : Returns only the data as a numpy.ndarray object.
1334+
"""
1335+
numpy_variables = {k: v.as_numpy() for k, v in self.variables.items()}
1336+
return self._replace(variables=numpy_variables)
1337+
13261338
@property
13271339
def _level_coords(self) -> Dict[str, Hashable]:
13281340
"""Return a mapping of all MultiIndex levels and their corresponding

xarray/core/pycompat.py

+46-30
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,63 @@
11
from distutils.version import LooseVersion
2+
from importlib import import_module
23

34
import numpy as np
45

56
from .utils import is_duck_array
67

78
integer_types = (int, np.integer)
89

9-
try:
10-
import dask
11-
import dask.array
12-
from dask.base import is_dask_collection
1310

14-
dask_version = LooseVersion(dask.__version__)
11+
class DuckArrayModule:
12+
"""
13+
Solely for internal isinstance and version checks.
1514
16-
# solely for isinstance checks
17-
dask_array_type = (dask.array.Array,)
15+
Motivated by having to only import pint when required (as pint currently imports xarray)
16+
https://github.com/pydata/xarray/pull/5561#discussion_r664815718
17+
"""
1818

19-
def is_duck_dask_array(x):
20-
return is_duck_array(x) and is_dask_collection(x)
19+
def __init__(self, mod):
20+
try:
21+
duck_array_module = import_module(mod)
22+
duck_array_version = LooseVersion(duck_array_module.__version__)
23+
24+
if mod == "dask":
25+
duck_array_type = (import_module("dask.array").Array,)
26+
elif mod == "pint":
27+
duck_array_type = (duck_array_module.Quantity,)
28+
elif mod == "cupy":
29+
duck_array_type = (duck_array_module.ndarray,)
30+
elif mod == "sparse":
31+
duck_array_type = (duck_array_module.SparseArray,)
32+
else:
33+
raise NotImplementedError
34+
35+
except ImportError: # pragma: no cover
36+
duck_array_module = None
37+
duck_array_version = LooseVersion("0.0.0")
38+
duck_array_type = ()
2139

40+
self.module = duck_array_module
41+
self.version = duck_array_version
42+
self.type = duck_array_type
43+
self.available = duck_array_module is not None
2244

23-
except ImportError: # pragma: no cover
24-
dask_version = LooseVersion("0.0.0")
25-
dask_array_type = ()
26-
is_duck_dask_array = lambda _: False
27-
is_dask_collection = lambda _: False
2845

29-
try:
30-
# solely for isinstance checks
31-
import sparse
46+
def is_duck_dask_array(x):
47+
if DuckArrayModule("dask").available:
48+
from dask.base import is_dask_collection
49+
50+
return is_duck_array(x) and is_dask_collection(x)
51+
else:
52+
return False
53+
3254

33-
sparse_version = LooseVersion(sparse.__version__)
34-
sparse_array_type = (sparse.SparseArray,)
35-
except ImportError: # pragma: no cover
36-
sparse_version = LooseVersion("0.0.0")
37-
sparse_array_type = ()
55+
dsk = DuckArrayModule("dask")
56+
dask_version = dsk.version
57+
dask_array_type = dsk.type
3858

39-
try:
40-
# solely for isinstance checks
41-
import cupy
59+
sp = DuckArrayModule("sparse")
60+
sparse_array_type = sp.type
61+
sparse_version = sp.version
4262

43-
cupy_version = LooseVersion(cupy.__version__)
44-
cupy_array_type = (cupy.ndarray,)
45-
except ImportError: # pragma: no cover
46-
cupy_version = LooseVersion("0.0.0")
47-
cupy_array_type = ()
63+
cupy_array_type = DuckArrayModule("cupy").type

xarray/core/variable.py

+27-1
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,12 @@
2929
from .indexing import BasicIndexer, OuterIndexer, VectorizedIndexer, as_indexable
3030
from .options import _get_keep_attrs
3131
from .pycompat import (
32+
DuckArrayModule,
3233
cupy_array_type,
3334
dask_array_type,
3435
integer_types,
3536
is_duck_dask_array,
37+
sparse_array_type,
3638
)
3739
from .utils import (
3840
NdimSizeLenMixin,
@@ -259,7 +261,7 @@ def _as_array_or_item(data):
259261
260262
TODO: remove this (replace with np.asarray) once these issues are fixed
261263
"""
262-
data = data.get() if isinstance(data, cupy_array_type) else np.asarray(data)
264+
data = np.asarray(data)
263265
if data.ndim == 0:
264266
if data.dtype.kind == "M":
265267
data = np.datetime64(data, "ns")
@@ -1069,6 +1071,30 @@ def chunk(self, chunks={}, name=None, lock=False):
10691071

10701072
return self._replace(data=data)
10711073

1074+
def to_numpy(self) -> np.ndarray:
1075+
"""Coerces wrapped data to numpy and returns a numpy.ndarray"""
1076+
# TODO an entrypoint so array libraries can choose coercion method?
1077+
data = self.data
1078+
1079+
# TODO first attempt to call .to_numpy() once some libraries implement it
1080+
if isinstance(data, dask_array_type):
1081+
data = data.compute()
1082+
if isinstance(data, cupy_array_type):
1083+
data = data.get()
1084+
# pint has to be imported dynamically as pint imports xarray
1085+
pint_array_type = DuckArrayModule("pint").type
1086+
if isinstance(data, pint_array_type):
1087+
data = data.magnitude
1088+
if isinstance(data, sparse_array_type):
1089+
data = data.todense()
1090+
data = np.asarray(data)
1091+
1092+
return data
1093+
1094+
def as_numpy(self: VariableType) -> VariableType:
1095+
"""Coerces wrapped data into a numpy array, returning a Variable."""
1096+
return self._replace(data=self.to_numpy())
1097+
10721098
def _as_sparse(self, sparse_format=_default, fill_value=dtypes.NA):
10731099
"""
10741100
use sparse-array as backend.

xarray/plot/plot.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,7 @@ def line(
430430

431431
# Remove pd.Intervals if contained in xplt.values and/or yplt.values.
432432
xplt_val, yplt_val, x_suffix, y_suffix, kwargs = _resolve_intervals_1dplot(
433-
xplt.values, yplt.values, kwargs
433+
xplt.to_numpy(), yplt.to_numpy(), kwargs
434434
)
435435
xlabel = label_from_attrs(xplt, extra=x_suffix)
436436
ylabel = label_from_attrs(yplt, extra=y_suffix)
@@ -449,7 +449,7 @@ def line(
449449
ax.set_title(darray._title_for_slice())
450450

451451
if darray.ndim == 2 and add_legend:
452-
ax.legend(handles=primitive, labels=list(hueplt.values), title=hue_label)
452+
ax.legend(handles=primitive, labels=list(hueplt.to_numpy()), title=hue_label)
453453

454454
# Rotate dates on xlabels
455455
# Do this without calling autofmt_xdate so that x-axes ticks
@@ -551,7 +551,7 @@ def hist(
551551
"""
552552
ax = get_axis(figsize, size, aspect, ax)
553553

554-
no_nan = np.ravel(darray.values)
554+
no_nan = np.ravel(darray.to_numpy())
555555
no_nan = no_nan[pd.notnull(no_nan)]
556556

557557
primitive = ax.hist(no_nan, **kwargs)
@@ -1153,8 +1153,8 @@ def newplotfunc(
11531153
dims = (yval.dims[0], xval.dims[0])
11541154

11551155
# better to pass the ndarrays directly to plotting functions
1156-
xval = xval.values
1157-
yval = yval.values
1156+
xval = xval.to_numpy()
1157+
yval = yval.to_numpy()
11581158

11591159
# May need to transpose for correct x, y labels
11601160
# xlab may be the name of a coord, we have to check for dim names

xarray/plot/utils.py

+15-6
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import pandas as pd
1010

1111
from ..core.options import OPTIONS
12+
from ..core.pycompat import DuckArrayModule
1213
from ..core.utils import is_scalar
1314

1415
try:
@@ -474,12 +475,20 @@ def label_from_attrs(da, extra=""):
474475
else:
475476
name = ""
476477

477-
if da.attrs.get("units"):
478-
units = " [{}]".format(da.attrs["units"])
479-
elif da.attrs.get("unit"):
480-
units = " [{}]".format(da.attrs["unit"])
478+
def _get_units_from_attrs(da):
479+
if da.attrs.get("units"):
480+
units = " [{}]".format(da.attrs["units"])
481+
elif da.attrs.get("unit"):
482+
units = " [{}]".format(da.attrs["unit"])
483+
else:
484+
units = ""
485+
return units
486+
487+
pint_array_type = DuckArrayModule("pint").type
488+
if isinstance(da.data, pint_array_type):
489+
units = " [{}]".format(str(da.data.units))
481490
else:
482-
units = ""
491+
units = _get_units_from_attrs(da)
483492

484493
return "\n".join(textwrap.wrap(name + extra + units, 30))
485494

@@ -896,7 +905,7 @@ def _get_nice_quiver_magnitude(u, v):
896905
import matplotlib as mpl
897906

898907
ticker = mpl.ticker.MaxNLocator(3)
899-
mean = np.mean(np.hypot(u.values, v.values))
908+
mean = np.mean(np.hypot(u.to_numpy(), v.to_numpy()))
900909
magnitude = ticker.tick_values(0, mean)[-2]
901910
return magnitude
902911

xarray/tests/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def LooseVersion(vstring):
8383
has_numbagg, requires_numbagg = _importorskip("numbagg")
8484
has_seaborn, requires_seaborn = _importorskip("seaborn")
8585
has_sparse, requires_sparse = _importorskip("sparse")
86+
has_cupy, requires_cupy = _importorskip("cupy")
8687
has_cartopy, requires_cartopy = _importorskip("cartopy")
8788
# Need Pint 0.15 for __dask_tokenize__ tests for Quantity wrapped Dask Arrays
8889
has_pint_0_15, requires_pint_0_15 = _importorskip("pint", minversion="0.15")

0 commit comments

Comments
 (0)