Skip to content

Add typing to plot methods #7052

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 59 commits into from
Oct 16, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
59 commits
Select commit Hold shift + click to select a range
f22ffc7
add plot methods statically and add typing to plot tests
headtr1ck Sep 18, 2022
bef6b7a
whats-new update
headtr1ck Sep 18, 2022
fd12592
fix copy-paste typo
headtr1ck Sep 18, 2022
85b59fc
correct plot signatures
headtr1ck Sep 19, 2022
5a0b8dc
add *some* typing to plot methods
headtr1ck Sep 20, 2022
7edde8f
annotate darray in plot tests
headtr1ck Sep 20, 2022
5f5fffc
correct typing of plot returns
headtr1ck Sep 22, 2022
89d033b
Merge branch 'main' into plotaccessor
headtr1ck Sep 22, 2022
60c6a70
fix plotting overloads
headtr1ck Sep 23, 2022
481565c
add correct overloads to dataset_plot
headtr1ck Sep 23, 2022
ad5e363
Merge branch 'main' into plotaccessor
headtr1ck Sep 23, 2022
27fa07f
update whats-new
headtr1ck Sep 23, 2022
47ef0ce
rename xr.plot.plot module since it shadows the xr.plot.plot method
headtr1ck Sep 25, 2022
45b154a
move accessor to its own module
headtr1ck Sep 25, 2022
e158bf3
move DSPlotAccessor to accessor module
headtr1ck Sep 25, 2022
cadc6de
fix DSPlotAccessor import
headtr1ck Sep 25, 2022
44a0317
add explanation to import statement
headtr1ck Sep 25, 2022
f68a5da
add breaking change to whats-new
headtr1ck Sep 25, 2022
39cb308
Merge branch 'main' into plotaccessor
headtr1ck Sep 25, 2022
5f38366
remove unused `rtol` argument from plot
headtr1ck Sep 25, 2022
e4f792b
make most arguments of plotmethods kwargs only
headtr1ck Sep 25, 2022
226f0e7
fix wrong return types
headtr1ck Sep 25, 2022
84a9ae9
add breaking kwarg change to whats-new
headtr1ck Sep 25, 2022
c1979f5
Merge branch 'main' into plotaccessor
headtr1ck Oct 3, 2022
0758b3b
support for aspect='auto' or 'equal
headtr1ck Oct 3, 2022
f9ce21b
typing support for Dataset FacetGrid
headtr1ck Oct 3, 2022
70c4771
deprecate positional arguments for all plot methods
headtr1ck Oct 3, 2022
f1df41c
add deprecation to whats-new
headtr1ck Oct 3, 2022
bf0ffd1
add FacetGrid generic type
headtr1ck Oct 4, 2022
c61ef0a
fix mypy 0.981 complaints
headtr1ck Oct 4, 2022
a4c7795
fix index errors in plots
headtr1ck Oct 4, 2022
8a557dd
Merge branch 'main' into plotaccessor
headtr1ck Oct 9, 2022
870d5a5
add overloads to scatter
headtr1ck Oct 9, 2022
9d0c859
deprecate scatter args
headtr1ck Oct 10, 2022
1b7e7db
add scatter to accessors and fix docstrings
headtr1ck Oct 10, 2022
ebec845
undo some breaking changes
headtr1ck Oct 11, 2022
90aded4
fix the docstrings and some typing
headtr1ck Oct 11, 2022
3da72e5
fix typing of scatter accessor funcs
headtr1ck Oct 11, 2022
8342f6c
align docstrings with signature and complete typing
headtr1ck Oct 11, 2022
9145f11
add remaining typing
headtr1ck Oct 11, 2022
2f01a17
align more docstrings
headtr1ck Oct 11, 2022
6be4352
re add ValueError for scatter plots with u, v
headtr1ck Oct 11, 2022
9d6a804
fix whats-new conflict
headtr1ck Oct 11, 2022
9686ce6
Merge branch 'main' into plotaccessor
headtr1ck Oct 12, 2022
f61c3d7
fix some typing errors
headtr1ck Oct 12, 2022
1bf0165
more typing fixes
headtr1ck Oct 12, 2022
a62a9a6
fix last mypy complaints
headtr1ck Oct 12, 2022
43b4e7e
try fixing facetgrid examples
headtr1ck Oct 12, 2022
48c9248
fix py3.8 problems
headtr1ck Oct 13, 2022
d101b8b
update plotting.rst
headtr1ck Oct 13, 2022
534c09a
update api
headtr1ck Oct 13, 2022
a0c6b14
update plot docstring
headtr1ck Oct 13, 2022
75f1425
add a tip about yincrease in imshow
headtr1ck Oct 13, 2022
0c25767
set default for x/yincrease in docstring
headtr1ck Oct 13, 2022
a514530
simplify typing
headtr1ck Oct 14, 2022
92462d9
add deprecation date as comment
headtr1ck Oct 14, 2022
8761264
Merge branch 'main' into plotaccessor
headtr1ck Oct 14, 2022
381f00f
update whats-new to new release
headtr1ck Oct 14, 2022
f621ef2
fix whats-new
headtr1ck Oct 14, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ Bug fixes
By `Deepak Cherian <https://github.com/dcherian>`_.
- ``Dataset.encoding['source']`` now exists when reading from a Path object (:issue:`5888`, :pull:`6974`)
By `Thomas Coleman <https://github.com/ColemanTom>`_.
- Add initial static typing to plot accessors (:issue:`6949`, :pull:`7052`).
By `Michael Niklas <https://github.com/headtr1ck>`_.

Documentation
~~~~~~~~~~~~~
Expand Down
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ ignore_missing_imports = True
ignore_missing_imports = True
[mypy-matplotlib.*]
ignore_missing_imports = True
[mypy-mpl_toolkits.*]
ignore_missing_imports = True
[mypy-Nio.*]
ignore_missing_imports = True
[mypy-nc_time_axis.*]
Expand Down
4 changes: 2 additions & 2 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from ..coding.calendar_ops import convert_calendar, interp_calendar
from ..coding.cftimeindex import CFTimeIndex
from ..plot.plot import _PlotMethods
from ..plot.plot import DataArrayPlotAccessor
from ..plot.utils import _get_units_from_attrs
from . import alignment, computation, dtypes, indexing, ops, utils
from ._reductions import DataArrayReductions
Expand Down Expand Up @@ -3575,7 +3575,7 @@ def _inplace_binary_op(self: T_DataArray, other: Any, f: Callable) -> T_DataArra
def _copy_attrs_from(self, other: DataArray | Dataset | Variable) -> None:
self.attrs = other.attrs

plot = utils.UncachedAccessor(_PlotMethods)
plot = utils.UncachedAccessor(DataArrayPlotAccessor)

def _title_for_slice(self, truncate: int = 50) -> str:
"""
Expand Down
4 changes: 2 additions & 2 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

from ..coding.calendar_ops import convert_calendar, interp_calendar
from ..coding.cftimeindex import CFTimeIndex, _parse_array_of_cftime_strings
from ..plot.dataset_plot import _Dataset_PlotMethods
from ..plot.dataset_plot import DatasetPlotAccessor
from . import alignment
from . import dtypes as xrdtypes
from . import duck_array_ops, formatting, formatting_html, ops, utils
Expand Down Expand Up @@ -7198,7 +7198,7 @@ def real(self: T_Dataset) -> T_Dataset:
def imag(self: T_Dataset) -> T_Dataset:
return self.map(lambda x: x.imag, keep_attrs=True)

plot = utils.UncachedAccessor(_Dataset_PlotMethods)
plot = utils.UncachedAccessor(DatasetPlotAccessor)

def filter_by_attrs(self: T_Dataset, **kwargs) -> T_Dataset:
"""Returns a ``Dataset`` with variables that match specific conditions.
Expand Down
2 changes: 2 additions & 0 deletions xarray/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@
CoarsenBoundaryOptions = Literal["exact", "trim", "pad"]
SideOptions = Literal["left", "right"]

MPLScaleOptions = Literal["linear", "symlog", "log", "logit", None]

# TODO: Wait until mypy supports recursive objects in combination with typevars
_T = TypeVar("_T")
NestedSequence = Union[
Expand Down
103 changes: 38 additions & 65 deletions xarray/plot/dataset_plot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import functools
from typing import TYPE_CHECKING, NoReturn

import numpy as np
import pandas as pd
Expand All @@ -16,6 +17,9 @@
get_axis,
)

if TYPE_CHECKING:
from ..core.dataset import Dataset


def _infer_scatter_data(ds, x, y, hue, markersize, size_norm, size_mapping=None):

Expand Down Expand Up @@ -48,22 +52,6 @@ def _infer_scatter_data(ds, x, y, hue, markersize, size_norm, size_mapping=None)
return data


class _Dataset_PlotMethods:
"""
Enables use of xarray.plot functions as attributes on a Dataset.
For example, Dataset.plot.scatter
"""

def __init__(self, dataset):
self._ds = dataset

def __call__(self, *args, **kwargs):
raise ValueError(
"Dataset.plot cannot be called directly. Use "
"an explicit plot method, e.g. ds.plot.scatter(...)"
)


def _dsplot(plotfunc):
commondoc = """
Parameters
Expand Down Expand Up @@ -299,55 +287,10 @@ def newplotfunc(

return primitive

@functools.wraps(newplotfunc)
def plotmethod(
_PlotMethods_obj,
x=None,
y=None,
u=None,
v=None,
hue=None,
hue_style=None,
col=None,
row=None,
ax=None,
figsize=None,
col_wrap=None,
sharex=True,
sharey=True,
aspect=None,
size=None,
subplot_kws=None,
add_guide=None,
cbar_kwargs=None,
cbar_ax=None,
vmin=None,
vmax=None,
norm=None,
infer_intervals=None,
center=None,
levels=None,
robust=None,
colors=None,
extend=None,
cmap=None,
**kwargs,
):
"""
The method should have the same signature as the function.

This just makes the method work on Plotmethods objects,
and passes all the other arguments straight through.
"""
allargs = locals()
allargs["ds"] = _PlotMethods_obj._ds
allargs.update(kwargs)
for arg in ["_PlotMethods_obj", "newplotfunc", "kwargs"]:
del allargs[arg]
return newplotfunc(**allargs)

# Add to class _PlotMethods
setattr(_Dataset_PlotMethods, plotmethod.__name__, plotmethod)
# we want to actually expose the signature of newplotfunc
# and not the copied **kwargs from the plotfunc which
# functools.wraps adds, so delete the wrapped attr
del newplotfunc.__wrapped__

return newplotfunc

Expand Down Expand Up @@ -497,3 +440,33 @@ def streamplot(ds, x, y, ax, u, v, **kwargs):

# Return .lines so colorbar creation works properly
return hdl.lines


class DatasetPlotAccessor:
"""
Enables use of xarray.plot functions as attributes on a Dataset.
For example, Dataset.plot.scatter
"""

_ds: Dataset

def __init__(self, dataset: Dataset) -> None:
self._ds = dataset

def __call__(self, *args, **kwargs) -> NoReturn:
raise ValueError(
"Dataset.plot cannot be called directly. Use "
"an explicit plot method, e.g. ds.plot.scatter(...)"
)

@functools.wraps(scatter)
def scatter(self, *args, **kwargs):
return scatter(self._ds, *args, **kwargs)

@functools.wraps(quiver)
def quiver(self, *args, **kwargs):
return quiver(self._ds, *args, **kwargs)

@functools.wraps(streamplot)
def streamplot(self, *args, **kwargs):
return streamplot(self._ds, *args, **kwargs)
Loading