Skip to content

Commit 9ff932a

Browse files
Require explicitly defining optional dimensions such as hue and markersize (#7277)
* Prioritize mpl kwargs when hue/size isn't defined. * Update dataarray_plot.py * rename vars for clarity * Handle int coords * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update dataarray_plot.py * Move funcs to utils and use in facetgrid, fix int coords in facetgrid * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update dataarray_plot.py * Update utils.py * Update utils.py * Update facetgrid.py * typing fixes * Only guess x-axis. * fix tests * rename function to a better name. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update whats-new.rst --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 7683442 commit 9ff932a

File tree

5 files changed

+179
-74
lines changed

5 files changed

+179
-74
lines changed

doc/whats-new.rst

+6-4
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,12 @@ Deprecations
3434

3535
Bug fixes
3636
~~~~~~~~~
37-
37+
- Require explicitly defining optional dimensions such as hue
38+
and markersize for scatter plots. (:issue:`7314`, :pull:`7277`).
39+
By `Jimmy Westling <https://github.com/illviljan>`_.
40+
- Fix matplotlib raising a UserWarning when plotting a scatter plot
41+
with an unfilled marker (:issue:`7313`, :pull:`7318`).
42+
By `Jimmy Westling <https://github.com/illviljan>`_.
3843

3944
Documentation
4045
~~~~~~~~~~~~~
@@ -194,9 +199,6 @@ Bug fixes
194199
By `Michael Niklas <https://github.com/headtr1ck>`_.
195200
- Fix static typing of :py:meth:`xr.polyval` (:issue:`7312`, :pull:`7315`).
196201
By `Michael Niklas <https://github.com/headtr1ck>`_.
197-
- Fix matplotlib raising a UserWarning when plotting a scatter plot
198-
with an unfilled marker (:issue:`7313`, :pull:`7318`).
199-
By `Jimmy Westling <https://github.com/illviljan>`_.
200202
- Fix multiple reads on fsspec S3 files by resetting file pointer to 0 when reading file streams (:issue:`6813`, :pull:`7304`).
201203
By `David Hoese <https://github.com/djhoese>`_ and `Wei Ji Leong <https://github.com/weiji14>`_.
202204
- Fix :py:meth:`Dataset.assign_coords` resetting all dimension coordinates to default (pandas) index (:issue:`7346`, :pull:`7347`).

xarray/plot/dataarray_plot.py

+51-46
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
_assert_valid_xy,
2020
_determine_guide,
2121
_ensure_plottable,
22+
_guess_coords_to_plot,
2223
_infer_interval_breaks,
2324
_infer_xy_labels,
2425
_Normalize,
@@ -142,48 +143,45 @@ def _infer_line_data(
142143
return xplt, yplt, hueplt, huelabel
143144

144145

145-
def _infer_plot_dims(
146-
darray: DataArray,
147-
dims_plot: MutableMapping[str, Hashable],
148-
default_guess: Iterable[str] = ("x", "hue", "size"),
149-
) -> MutableMapping[str, Hashable]:
146+
def _prepare_plot1d_data(
147+
darray: T_DataArray,
148+
coords_to_plot: MutableMapping[str, Hashable],
149+
plotfunc_name: str | None = None,
150+
_is_facetgrid: bool = False,
151+
) -> dict[str, T_DataArray]:
150152
"""
151-
Guess what dims to plot if some of the values in dims_plot are None which
152-
happens when the user has not defined all available ways of visualizing
153-
the data.
153+
Prepare data for usage with plt.scatter.
154154
155155
Parameters
156156
----------
157-
darray : DataArray
158-
The DataArray to check.
159-
dims_plot : T_DimsPlot
160-
Dims defined by the user to plot.
161-
default_guess : Iterable[str], optional
162-
Default values and order to retrieve dims if values in dims_plot is
163-
missing, default: ("x", "hue", "size").
164-
"""
165-
dims_plot_exist = {k: v for k, v in dims_plot.items() if v is not None}
166-
dims_avail = tuple(v for v in darray.dims if v not in dims_plot_exist.values())
167-
168-
# If dims_plot[k] isn't defined then fill with one of the available dims:
169-
for k, v in zip(default_guess, dims_avail):
170-
if dims_plot.get(k, None) is None:
171-
dims_plot[k] = v
172-
173-
for k, v in dims_plot.items():
174-
_assert_valid_xy(darray, v, k)
175-
176-
return dims_plot
177-
157+
darray : T_DataArray
158+
Base DataArray.
159+
coords_to_plot : MutableMapping[str, Hashable]
160+
Coords that will be plotted.
161+
plotfunc_name : str | None
162+
Name of the plotting function that will be used.
178163
179-
def _infer_line_data2(
180-
darray: T_DataArray,
181-
dims_plot: MutableMapping[str, Hashable],
182-
plotfunc_name: None | str = None,
183-
) -> dict[str, T_DataArray]:
184-
# Guess what dims to use if some of the values in plot_dims are None:
185-
dims_plot = _infer_plot_dims(darray, dims_plot)
164+
Returns
165+
-------
166+
plts : dict[str, T_DataArray]
167+
Dict of DataArrays that will be sent to matplotlib.
186168
169+
Examples
170+
--------
171+
>>> # Make sure int coords are plotted:
172+
>>> a = xr.DataArray(
173+
... data=[1, 2],
174+
... coords={1: ("x", [0, 1], {"units": "s"})},
175+
... dims=("x",),
176+
... name="a",
177+
... )
178+
>>> plts = xr.plot.dataarray_plot._prepare_plot1d_data(
179+
... a, coords_to_plot={"x": 1, "z": None, "hue": None, "size": None}
180+
... )
181+
>>> # Check which coords to plot:
182+
>>> print({k: v.name for k, v in plts.items()})
183+
{'y': 'a', 'x': 1}
184+
"""
187185
# If there is more than 1 dimension in the array then stack all the
188186
# dimensions so the plotter can plot anything:
189187
if darray.ndim > 1:
@@ -193,11 +191,11 @@ def _infer_line_data2(
193191
dims_T = []
194192
if np.issubdtype(darray.dtype, np.floating):
195193
for v in ["z", "x"]:
196-
dim = dims_plot.get(v, None)
194+
dim = coords_to_plot.get(v, None)
197195
if (dim is not None) and (dim in darray.dims):
198196
darray_nan = np.nan * darray.isel({dim: -1})
199197
darray = concat([darray, darray_nan], dim=dim)
200-
dims_T.append(dims_plot[v])
198+
dims_T.append(coords_to_plot[v])
201199

202200
# Lines should never connect to the same coordinate when stacked,
203201
# transpose to avoid this as much as possible:
@@ -207,11 +205,13 @@ def _infer_line_data2(
207205
darray = darray.stack(_stacked_dim=darray.dims)
208206

209207
# Broadcast together all the chosen variables:
210-
out = dict(y=darray)
211-
out.update({k: darray[v] for k, v in dims_plot.items() if v is not None})
212-
out = dict(zip(out.keys(), broadcast(*(out.values()))))
208+
plts = dict(y=darray)
209+
plts.update(
210+
{k: darray.coords[v] for k, v in coords_to_plot.items() if v is not None}
211+
)
212+
plts = dict(zip(plts.keys(), broadcast(*(plts.values()))))
213213

214-
return out
214+
return plts
215215

216216

217217
# return type is Any due to the many different possibilities
@@ -938,15 +938,20 @@ def newplotfunc(
938938
_is_facetgrid = kwargs.pop("_is_facetgrid", False)
939939

940940
if plotfunc.__name__ == "scatter":
941-
size_ = markersize
941+
size_ = kwargs.pop("_size", markersize)
942942
size_r = _MARKERSIZE_RANGE
943943
else:
944-
size_ = linewidth
944+
size_ = kwargs.pop("_size", linewidth)
945945
size_r = _LINEWIDTH_RANGE
946946

947947
# Get data to plot:
948-
dims_plot = dict(x=x, z=z, hue=hue, size=size_)
949-
plts = _infer_line_data2(darray, dims_plot, plotfunc.__name__)
948+
coords_to_plot: MutableMapping[str, Hashable | None] = dict(
949+
x=x, z=z, hue=hue, size=size_
950+
)
951+
if not _is_facetgrid:
952+
# Guess what coords to use if some of the values in coords_to_plot are None:
953+
coords_to_plot = _guess_coords_to_plot(darray, coords_to_plot, kwargs)
954+
plts = _prepare_plot1d_data(darray, coords_to_plot, plotfunc.__name__)
950955
xplt = plts.pop("x", None)
951956
yplt = plts.pop("y", None)
952957
zplt = plts.pop("z", None)

xarray/plot/facetgrid.py

+29-14
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import functools
44
import itertools
55
import warnings
6-
from collections.abc import Hashable, Iterable
6+
from collections.abc import Hashable, Iterable, MutableMapping
77
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, TypeVar, cast
88

99
import numpy as np
@@ -16,6 +16,7 @@
1616
_add_legend,
1717
_determine_guide,
1818
_get_nice_quiver_magnitude,
19+
_guess_coords_to_plot,
1920
_infer_xy_labels,
2021
_Normalize,
2122
_parse_size,
@@ -383,6 +384,11 @@ def map_plot1d(
383384
func: Callable,
384385
x: Hashable | None,
385386
y: Hashable | None,
387+
*,
388+
z: Hashable | None = None,
389+
hue: Hashable | None = None,
390+
markersize: Hashable | None = None,
391+
linewidth: Hashable | None = None,
386392
**kwargs: Any,
387393
) -> T_FacetGrid:
388394
"""
@@ -415,13 +421,25 @@ def map_plot1d(
415421
if kwargs.get("cbar_ax", None) is not None:
416422
raise ValueError("cbar_ax not supported by FacetGrid.")
417423

424+
if func.__name__ == "scatter":
425+
size_ = kwargs.pop("_size", markersize)
426+
size_r = _MARKERSIZE_RANGE
427+
else:
428+
size_ = kwargs.pop("_size", linewidth)
429+
size_r = _LINEWIDTH_RANGE
430+
431+
# Guess what coords to use if some of the values in coords_to_plot are None:
432+
coords_to_plot: MutableMapping[str, Hashable | None] = dict(
433+
x=x, z=z, hue=hue, size=size_
434+
)
435+
coords_to_plot = _guess_coords_to_plot(self.data, coords_to_plot, kwargs)
436+
418437
# Handle hues:
419-
hue = kwargs.get("hue", None)
420-
hueplt = self.data[hue] if hue else self.data
438+
hue = coords_to_plot["hue"]
439+
hueplt = self.data.coords[hue] if hue else None # TODO: _infer_line_data2 ?
421440
hueplt_norm = _Normalize(hueplt)
422441
self._hue_var = hueplt
423442
cbar_kwargs = kwargs.pop("cbar_kwargs", {})
424-
425443
if hueplt_norm.data is not None:
426444
if not hueplt_norm.data_is_numeric:
427445
# TODO: Ticks seems a little too hardcoded, since it will always
@@ -441,16 +459,11 @@ def map_plot1d(
441459
cmap_params = {}
442460

443461
# Handle sizes:
444-
_size_r = _MARKERSIZE_RANGE if func.__name__ == "scatter" else _LINEWIDTH_RANGE
445-
for _size in ("markersize", "linewidth"):
446-
size = kwargs.get(_size, None)
447-
448-
sizeplt = self.data[size] if size else None
449-
sizeplt_norm = _Normalize(data=sizeplt, width=_size_r)
450-
if size:
451-
self.data[size] = sizeplt_norm.values
452-
kwargs.update(**{_size: size})
453-
break
462+
size_ = coords_to_plot["size"]
463+
sizeplt = self.data.coords[size_] if size_ else None
464+
sizeplt_norm = _Normalize(data=sizeplt, width=size_r)
465+
if sizeplt_norm.data is not None:
466+
self.data[size_] = sizeplt_norm.values
454467

455468
# Add kwargs that are sent to the plotting function, # order is important ???
456469
func_kwargs = {
@@ -504,6 +517,8 @@ def map_plot1d(
504517
x=x,
505518
y=y,
506519
ax=ax,
520+
hue=hue,
521+
_size=size_,
507522
**func_kwargs,
508523
_is_facetgrid=True,
509524
)

xarray/plot/utils.py

+90-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import itertools
44
import textwrap
55
import warnings
6-
from collections.abc import Hashable, Iterable, Mapping, Sequence
6+
from collections.abc import Hashable, Iterable, Mapping, MutableMapping, Sequence
77
from datetime import datetime
88
from inspect import getfullargspec
99
from typing import TYPE_CHECKING, Any, Callable, overload
@@ -1735,3 +1735,92 @@ def _add_legend(
17351735
_adjust_legend_subtitles(legend)
17361736

17371737
return legend
1738+
1739+
1740+
def _guess_coords_to_plot(
1741+
darray: DataArray,
1742+
coords_to_plot: MutableMapping[str, Hashable | None],
1743+
kwargs: dict,
1744+
default_guess: tuple[str, ...] = ("x",),
1745+
# TODO: Can this be normalized, plt.cbook.normalize_kwargs?
1746+
ignore_guess_kwargs: tuple[tuple[str, ...], ...] = ((),),
1747+
) -> MutableMapping[str, Hashable]:
1748+
"""
1749+
Guess what coords to plot if some of the values in coords_to_plot are None which
1750+
happens when the user has not defined all available ways of visualizing
1751+
the data.
1752+
1753+
Parameters
1754+
----------
1755+
darray : DataArray
1756+
The DataArray to check for available coords.
1757+
coords_to_plot : MutableMapping[str, Hashable | None]
1758+
Coords defined by the user to plot.
1759+
kwargs : dict
1760+
Extra kwargs that will be sent to matplotlib.
1761+
default_guess : tuple[str, ...], optional
1762+
Default values and order to retrieve coords if values in coords_to_plot are
1763+
missing, default: ("x",).
1764+
ignore_guess_kwargs : tuple[tuple[str, ...], ...]
1765+
Matplotlib arguments to ignore.
1766+
1767+
Examples
1768+
--------
1769+
>>> ds = xr.tutorial.scatter_example_dataset(seed=42)
1770+
>>> # Only guess x by default:
1771+
>>> xr.plot.utils._guess_coords_to_plot(
1772+
... ds.A,
1773+
... coords_to_plot={"x": None, "z": None, "hue": None, "size": None},
1774+
... kwargs={},
1775+
... )
1776+
{'x': 'x', 'z': None, 'hue': None, 'size': None}
1777+
1778+
>>> # Guess all plot dims with other default values:
1779+
>>> xr.plot.utils._guess_coords_to_plot(
1780+
... ds.A,
1781+
... coords_to_plot={"x": None, "z": None, "hue": None, "size": None},
1782+
... kwargs={},
1783+
... default_guess=("x", "hue", "size"),
1784+
... ignore_guess_kwargs=((), ("c", "color"), ("s",)),
1785+
... )
1786+
{'x': 'x', 'z': None, 'hue': 'y', 'size': 'z'}
1787+
1788+
>>> # Don't guess `size`, since the matplotlib kwarg `s` has been defined:
1789+
>>> xr.plot.utils._guess_coords_to_plot(
1790+
... ds.A,
1791+
... coords_to_plot={"x": None, "z": None, "hue": None, "size": None},
1792+
... kwargs={"s": 5},
1793+
... default_guess=("x", "hue", "size"),
1794+
... ignore_guess_kwargs=((), ("c", "color"), ("s",)),
1795+
... )
1796+
{'x': 'x', 'z': None, 'hue': 'y', 'size': None}
1797+
1798+
>>> # Prioritize `size` over `s`:
1799+
>>> xr.plot.utils._guess_coords_to_plot(
1800+
... ds.A,
1801+
... coords_to_plot={"x": None, "z": None, "hue": None, "size": "x"},
1802+
... kwargs={"s": 5},
1803+
... default_guess=("x", "hue", "size"),
1804+
... ignore_guess_kwargs=((), ("c", "color"), ("s",)),
1805+
... )
1806+
{'x': 'y', 'z': None, 'hue': 'z', 'size': 'x'}
1807+
"""
1808+
coords_to_plot_exist = {k: v for k, v in coords_to_plot.items() if v is not None}
1809+
available_coords = tuple(
1810+
k for k in darray.coords.keys() if k not in coords_to_plot_exist.values()
1811+
)
1812+
1813+
# If coords_to_plot[k] isn't defined then fill with one of the available coords, unless
1814+
# one of the related mpl kwargs has been used. This should have similar behaviour as
1815+
# * plt.plot(x, y) -> Multiple lines with different colors if y is 2d.
1816+
# * plt.plot(x, y, color="red") -> Multiple red lines if y is 2d.
1817+
for k, dim, ign_kws in zip(default_guess, available_coords, ignore_guess_kwargs):
1818+
if coords_to_plot.get(k, None) is None and all(
1819+
kwargs.get(ign_kw, None) is None for ign_kw in ign_kws
1820+
):
1821+
coords_to_plot[k] = dim
1822+
1823+
for k, dim in coords_to_plot.items():
1824+
_assert_valid_xy(darray, dim, k)
1825+
1826+
return coords_to_plot

xarray/tests/test_plot.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -2717,23 +2717,17 @@ def test_scatter(
27172717
def test_non_numeric_legend(self) -> None:
27182718
ds2 = self.ds.copy()
27192719
ds2["hue"] = ["a", "b", "c", "d"]
2720-
pc = ds2.plot.scatter(x="A", y="B", hue="hue")
2720+
pc = ds2.plot.scatter(x="A", y="B", markersize="hue")
27212721
# should make a discrete legend
27222722
assert pc.axes.legend_ is not None
27232723

27242724
def test_legend_labels(self) -> None:
27252725
# regression test for #4126: incorrect legend labels
27262726
ds2 = self.ds.copy()
27272727
ds2["hue"] = ["a", "a", "b", "b"]
2728-
pc = ds2.plot.scatter(x="A", y="B", hue="hue")
2728+
pc = ds2.plot.scatter(x="A", y="B", markersize="hue")
27292729
actual = [t.get_text() for t in pc.axes.get_legend().texts]
2730-
expected = [
2731-
"col [colunits]",
2732-
"$\\mathdefault{0}$",
2733-
"$\\mathdefault{1}$",
2734-
"$\\mathdefault{2}$",
2735-
"$\\mathdefault{3}$",
2736-
]
2730+
expected = ["hue", "a", "b"]
27372731
assert actual == expected
27382732

27392733
def test_legend_labels_facetgrid(self) -> None:

0 commit comments

Comments
 (0)