Skip to content

Commit 92cb751

Browse files
Plots get labels from pint arrays (#5561)
* test labels come from pint units * values demotes pint arrays before returning * plot labels look for pint units first * pre-commit * added to_numpy() and as_numpy() methods * remove special-casing of cupy arrays in .values in favour of using .to_numpy() * .values -> .to_numpy() * lint * Fix mypy (I think?) * added Dataset.as_numpy() * improved docstrings * add what's new * add to API docs * linting * fix failures by only importing pint when needed * refactor pycompat into class * pycompat import changes applied to plotting code * what's new * compute instead of load * added tests * fixed sparse test * tests and fixes for ds.as_numpy() * fix sparse tests * fix linting * tests for Variable * test IndexVariable too * use numpy.asarray to avoid a copy * also convert coords * Force tests again after #5600 Co-authored-by: Maximilian Roos <[email protected]>
1 parent c5ee050 commit 92cb751

File tree

6 files changed

+63
-13
lines changed

6 files changed

+63
-13
lines changed

doc/whats-new.rst

+2
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ New Features
5858
By `Elle Smith <https://github.com/ellesmith88>`_.
5959
- Added :py:meth:`DataArray.to_numpy`, :py:meth:`DataArray.as_numpy`, and :py:meth:`Dataset.as_numpy`. (:pull:`5568`).
6060
By `Tom Nicholas <https://github.com/TomNicholas>`_.
61+
- Units in plot labels are now automatically inferred from wrapped :py:meth:`pint.Quantity` arrays. (:pull:`5561`).
62+
By `Tom Nicholas <https://github.com/TomNicholas>`_.
6163

6264
Breaking changes
6365
~~~~~~~~~~~~~~~~

xarray/core/dataarray.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2784,7 +2784,7 @@ def to_masked_array(self, copy: bool = True) -> np.ma.MaskedArray:
27842784
result : MaskedArray
27852785
Masked where invalid values (nan or inf) occur.
27862786
"""
2787-
values = self.values # only compute lazy arrays once
2787+
values = self.to_numpy() # only compute lazy arrays once
27882788
isnull = pd.isnull(values)
27892789
return np.ma.MaskedArray(data=values, mask=isnull, copy=copy)
27902790

xarray/core/variable.py

+1
Original file line numberDiff line numberDiff line change
@@ -1075,6 +1075,7 @@ def to_numpy(self) -> np.ndarray:
10751075
"""Coerces wrapped data to numpy and returns a numpy.ndarray"""
10761076
# TODO an entrypoint so array libraries can choose coercion method?
10771077
data = self.data
1078+
10781079
# TODO first attempt to call .to_numpy() once some libraries implement it
10791080
if isinstance(data, dask_array_type):
10801081
data = data.compute()

xarray/plot/plot.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,7 @@ def line(
430430

431431
# Remove pd.Intervals if contained in xplt.values and/or yplt.values.
432432
xplt_val, yplt_val, x_suffix, y_suffix, kwargs = _resolve_intervals_1dplot(
433-
xplt.values, yplt.values, kwargs
433+
xplt.to_numpy(), yplt.to_numpy(), kwargs
434434
)
435435
xlabel = label_from_attrs(xplt, extra=x_suffix)
436436
ylabel = label_from_attrs(yplt, extra=y_suffix)
@@ -449,7 +449,7 @@ def line(
449449
ax.set_title(darray._title_for_slice())
450450

451451
if darray.ndim == 2 and add_legend:
452-
ax.legend(handles=primitive, labels=list(hueplt.values), title=hue_label)
452+
ax.legend(handles=primitive, labels=list(hueplt.to_numpy()), title=hue_label)
453453

454454
# Rotate dates on xlabels
455455
# Do this without calling autofmt_xdate so that x-axes ticks
@@ -551,7 +551,7 @@ def hist(
551551
"""
552552
ax = get_axis(figsize, size, aspect, ax)
553553

554-
no_nan = np.ravel(darray.values)
554+
no_nan = np.ravel(darray.to_numpy())
555555
no_nan = no_nan[pd.notnull(no_nan)]
556556

557557
primitive = ax.hist(no_nan, **kwargs)
@@ -1153,8 +1153,8 @@ def newplotfunc(
11531153
dims = (yval.dims[0], xval.dims[0])
11541154

11551155
# better to pass the ndarrays directly to plotting functions
1156-
xval = xval.values
1157-
yval = yval.values
1156+
xval = xval.to_numpy()
1157+
yval = yval.to_numpy()
11581158

11591159
# May need to transpose for correct x, y labels
11601160
# xlab may be the name of a coord, we have to check for dim names

xarray/plot/utils.py

+15-6
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import pandas as pd
1010

1111
from ..core.options import OPTIONS
12+
from ..core.pycompat import DuckArrayModule
1213
from ..core.utils import is_scalar
1314

1415
try:
@@ -474,12 +475,20 @@ def label_from_attrs(da, extra=""):
474475
else:
475476
name = ""
476477

477-
if da.attrs.get("units"):
478-
units = " [{}]".format(da.attrs["units"])
479-
elif da.attrs.get("unit"):
480-
units = " [{}]".format(da.attrs["unit"])
478+
def _get_units_from_attrs(da):
479+
if da.attrs.get("units"):
480+
units = " [{}]".format(da.attrs["units"])
481+
elif da.attrs.get("unit"):
482+
units = " [{}]".format(da.attrs["unit"])
483+
else:
484+
units = ""
485+
return units
486+
487+
pint_array_type = DuckArrayModule("pint").type
488+
if isinstance(da.data, pint_array_type):
489+
units = " [{}]".format(str(da.data.units))
481490
else:
482-
units = ""
491+
units = _get_units_from_attrs(da)
483492

484493
return "\n".join(textwrap.wrap(name + extra + units, 30))
485494

@@ -896,7 +905,7 @@ def _get_nice_quiver_magnitude(u, v):
896905
import matplotlib as mpl
897906

898907
ticker = mpl.ticker.MaxNLocator(3)
899-
mean = np.mean(np.hypot(u.values, v.values))
908+
mean = np.mean(np.hypot(u.to_numpy(), v.to_numpy()))
900909
magnitude = ticker.tick_values(0, mean)[-2]
901910
return magnitude
902911

xarray/tests/test_units.py

+39-1
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,22 @@
55
import pandas as pd
66
import pytest
77

8+
try:
9+
import matplotlib.pyplot as plt
10+
except ImportError:
11+
pass
12+
813
import xarray as xr
914
from xarray.core import dtypes, duck_array_ops
1015

11-
from . import assert_allclose, assert_duckarray_allclose, assert_equal, assert_identical
16+
from . import (
17+
assert_allclose,
18+
assert_duckarray_allclose,
19+
assert_equal,
20+
assert_identical,
21+
requires_matplotlib,
22+
)
23+
from .test_plot import PlotTestCase
1224
from .test_variable import _PAD_XR_NP_ARGS
1325

1426
pint = pytest.importorskip("pint")
@@ -5564,3 +5576,29 @@ def test_merge(self, variant, unit, error, dtype):
55645576

55655577
assert_units_equal(expected, actual)
55665578
assert_equal(expected, actual)
5579+
5580+
5581+
@requires_matplotlib
5582+
class TestPlots(PlotTestCase):
5583+
def test_units_in_line_plot_labels(self):
5584+
arr = np.linspace(1, 10, 3) * unit_registry.Pa
5585+
# TODO make coord a Quantity once unit-aware indexes supported
5586+
x_coord = xr.DataArray(
5587+
np.linspace(1, 3, 3), dims="x", attrs={"units": "meters"}
5588+
)
5589+
da = xr.DataArray(data=arr, dims="x", coords={"x": x_coord}, name="pressure")
5590+
5591+
da.plot.line()
5592+
5593+
ax = plt.gca()
5594+
assert ax.get_ylabel() == "pressure [pascal]"
5595+
assert ax.get_xlabel() == "x [meters]"
5596+
5597+
def test_units_in_2d_plot_labels(self):
5598+
arr = np.ones((2, 3)) * unit_registry.Pa
5599+
da = xr.DataArray(data=arr, dims=["x", "y"], name="pressure")
5600+
5601+
fig, (ax, cax) = plt.subplots(1, 2)
5602+
ax = da.plot.contourf(ax=ax, cbar_ax=cax, add_colorbar=True)
5603+
5604+
assert cax.get_ylabel() == "pressure [pascal]"

0 commit comments

Comments
 (0)