Skip to content

Commit 968443c

Browse files
Illviljanmathause
andauthored
Add helper for setting axis limits in facetgrid (#7046)
* Add helper for setting axis limits in facetgrid * z argument isn't available yet * Update facetgrid.py * Update xarray/plot/facetgrid.py Co-authored-by: Mathias Hauser <[email protected]> * Update facetgrid.py * use float only Co-authored-by: Mathias Hauser <[email protected]>
1 parent bda0a2f commit 968443c

File tree

1 file changed

+69
-1
lines changed

1 file changed

+69
-1
lines changed

xarray/plot/facetgrid.py

+69-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import functools
44
import itertools
55
import warnings
6-
from typing import Iterable
6+
from typing import Callable, Iterable
77

88
import numpy as np
99

@@ -471,6 +471,74 @@ def add_quiverkey(self, u, v, **kwargs):
471471
# self._adjust_fig_for_guide(self.quiverkey.text)
472472
return self
473473

474+
def _get_largest_lims(self) -> dict[str, tuple[float, float]]:
475+
"""
476+
Get largest limits in the facetgrid.
477+
478+
Returns
479+
-------
480+
lims_largest : dict[str, tuple[int | float, int | float]]
481+
Dictionary with the largest limits along each axis.
482+
483+
Examples
484+
--------
485+
>>> ds = xr.tutorial.scatter_example_dataset(seed=42)
486+
>>> fg = ds.plot.scatter("A", "B", hue="y", row="x", col="w")
487+
>>> round(fg._get_largest_lims()["x"][0], 3)
488+
-0.334
489+
"""
490+
lims_largest: dict[str, tuple[float, float]] = dict(
491+
x=(np.inf, -np.inf), y=(np.inf, -np.inf), z=(np.inf, -np.inf)
492+
)
493+
for k in ("x", "y", "z"):
494+
# Find the plot with the largest xlim values:
495+
l0, l1 = lims_largest[k]
496+
for ax in self.axes.flat:
497+
get_lim: None | Callable[[], tuple[float, float]] = getattr(
498+
ax, f"get_{k}lim", None
499+
)
500+
if get_lim:
501+
l0_new, l1_new = get_lim()
502+
l0, l1 = (min(l0, l0_new), max(l1, l1_new))
503+
lims_largest[k] = (l0, l1)
504+
505+
return lims_largest
506+
507+
def _set_lims(
508+
self,
509+
x: None | tuple[float, float] = None,
510+
y: None | tuple[float, float] = None,
511+
z: None | tuple[float, float] = None,
512+
) -> None:
513+
"""
514+
Set the same limits for all the subplots in the facetgrid.
515+
516+
Parameters
517+
----------
518+
x : None | tuple[int | float, int | float]
519+
x axis limits.
520+
y : None | tuple[int | float, int | float]
521+
y axis limits.
522+
z : None | tuple[int | float, int | float]
523+
z axis limits.
524+
525+
Examples
526+
--------
527+
>>> ds = xr.tutorial.scatter_example_dataset(seed=42)
528+
>>> fg = ds.plot.scatter("A", "B", hue="y", row="x", col="w")
529+
>>> fg._set_lims(x=(-0.3, 0.3), y=(0, 2), z=(0, 4))
530+
>>> fg.axes[0, 0].get_xlim(), fg.axes[0, 0].get_ylim()
531+
((-0.3, 0.3), (0.0, 2.0))
532+
"""
533+
lims_largest = self._get_largest_lims()
534+
535+
# Set limits:
536+
for ax in self.axes.flat:
537+
for (k, v), vv in zip(lims_largest.items(), (x, y, z)):
538+
set_lim = getattr(ax, f"set_{k}lim", None)
539+
if set_lim:
540+
set_lim(v if vv is None else vv)
541+
474542
def set_axis_labels(self, *axlabels):
475543
"""Set axis labels on the left column and bottom row of the grid."""
476544
from ..core.dataarray import DataArray

0 commit comments

Comments
 (0)