|
3 | 3 | import functools
|
4 | 4 | import itertools
|
5 | 5 | import warnings
|
6 |
| -from typing import Iterable |
| 6 | +from typing import Callable, Iterable |
7 | 7 |
|
8 | 8 | import numpy as np
|
9 | 9 |
|
@@ -471,6 +471,74 @@ def add_quiverkey(self, u, v, **kwargs):
|
471 | 471 | # self._adjust_fig_for_guide(self.quiverkey.text)
|
472 | 472 | return self
|
473 | 473 |
|
| 474 | + def _get_largest_lims(self) -> dict[str, tuple[float, float]]: |
| 475 | + """ |
| 476 | + Get largest limits in the facetgrid. |
| 477 | +
|
| 478 | + Returns |
| 479 | + ------- |
| 480 | + lims_largest : dict[str, tuple[int | float, int | float]] |
| 481 | + Dictionary with the largest limits along each axis. |
| 482 | +
|
| 483 | + Examples |
| 484 | + -------- |
| 485 | + >>> ds = xr.tutorial.scatter_example_dataset(seed=42) |
| 486 | + >>> fg = ds.plot.scatter("A", "B", hue="y", row="x", col="w") |
| 487 | + >>> round(fg._get_largest_lims()["x"][0], 3) |
| 488 | + -0.334 |
| 489 | + """ |
| 490 | + lims_largest: dict[str, tuple[float, float]] = dict( |
| 491 | + x=(np.inf, -np.inf), y=(np.inf, -np.inf), z=(np.inf, -np.inf) |
| 492 | + ) |
| 493 | + for k in ("x", "y", "z"): |
| 494 | + # Find the plot with the largest xlim values: |
| 495 | + l0, l1 = lims_largest[k] |
| 496 | + for ax in self.axes.flat: |
| 497 | + get_lim: None | Callable[[], tuple[float, float]] = getattr( |
| 498 | + ax, f"get_{k}lim", None |
| 499 | + ) |
| 500 | + if get_lim: |
| 501 | + l0_new, l1_new = get_lim() |
| 502 | + l0, l1 = (min(l0, l0_new), max(l1, l1_new)) |
| 503 | + lims_largest[k] = (l0, l1) |
| 504 | + |
| 505 | + return lims_largest |
| 506 | + |
| 507 | + def _set_lims( |
| 508 | + self, |
| 509 | + x: None | tuple[float, float] = None, |
| 510 | + y: None | tuple[float, float] = None, |
| 511 | + z: None | tuple[float, float] = None, |
| 512 | + ) -> None: |
| 513 | + """ |
| 514 | + Set the same limits for all the subplots in the facetgrid. |
| 515 | +
|
| 516 | + Parameters |
| 517 | + ---------- |
| 518 | + x : None | tuple[int | float, int | float] |
| 519 | + x axis limits. |
| 520 | + y : None | tuple[int | float, int | float] |
| 521 | + y axis limits. |
| 522 | + z : None | tuple[int | float, int | float] |
| 523 | + z axis limits. |
| 524 | +
|
| 525 | + Examples |
| 526 | + -------- |
| 527 | + >>> ds = xr.tutorial.scatter_example_dataset(seed=42) |
| 528 | + >>> fg = ds.plot.scatter("A", "B", hue="y", row="x", col="w") |
| 529 | + >>> fg._set_lims(x=(-0.3, 0.3), y=(0, 2), z=(0, 4)) |
| 530 | + >>> fg.axes[0, 0].get_xlim(), fg.axes[0, 0].get_ylim() |
| 531 | + ((-0.3, 0.3), (0.0, 2.0)) |
| 532 | + """ |
| 533 | + lims_largest = self._get_largest_lims() |
| 534 | + |
| 535 | + # Set limits: |
| 536 | + for ax in self.axes.flat: |
| 537 | + for (k, v), vv in zip(lims_largest.items(), (x, y, z)): |
| 538 | + set_lim = getattr(ax, f"set_{k}lim", None) |
| 539 | + if set_lim: |
| 540 | + set_lim(v if vv is None else vv) |
| 541 | + |
474 | 542 | def set_axis_labels(self, *axlabels):
|
475 | 543 | """Set axis labels on the left column and bottom row of the grid."""
|
476 | 544 | from ..core.dataarray import DataArray
|
|
0 commit comments