diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index aeb53126265..b74d47b5712 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -1,179 +1,19 @@ from __future__ import annotations import functools - -import numpy as np -import pandas as pd +import inspect from ..core.alignment import broadcast from .facetgrid import _easy_facetgrid +from .plot import _PlotMethods from .utils import ( _add_colorbar, _get_nice_quiver_magnitude, - _is_numeric, + _infer_meta_data, _process_cmap_cbar_kwargs, get_axis, - label_from_attrs, ) -# copied from seaborn -_MARKERSIZE_RANGE = np.array([18.0, 72.0]) - - -def _infer_meta_data(ds, x, y, hue, hue_style, add_guide, funcname): - dvars = set(ds.variables.keys()) - error_msg = " must be one of ({:s})".format(", ".join(dvars)) - - if x not in dvars: - raise ValueError("x" + error_msg) - - if y not in dvars: - raise ValueError("y" + error_msg) - - if hue is not None and hue not in dvars: - raise ValueError("hue" + error_msg) - - if hue: - hue_is_numeric = _is_numeric(ds[hue].values) - - if hue_style is None: - hue_style = "continuous" if hue_is_numeric else "discrete" - - if not hue_is_numeric and (hue_style == "continuous"): - raise ValueError( - f"Cannot create a colorbar for a non numeric coordinate: {hue}" - ) - - if add_guide is None or add_guide is True: - add_colorbar = True if hue_style == "continuous" else False - add_legend = True if hue_style == "discrete" else False - else: - add_colorbar = False - add_legend = False - else: - if add_guide is True and funcname not in ("quiver", "streamplot"): - raise ValueError("Cannot set add_guide when hue is None.") - add_legend = False - add_colorbar = False - - if (add_guide or add_guide is None) and funcname == "quiver": - add_quiverkey = True - if hue: - add_colorbar = True - if not hue_style: - hue_style = "continuous" - elif hue_style != "continuous": - raise ValueError( - "hue_style must be 'continuous' or None for .plot.quiver or " - ".plot.streamplot" - ) - else: - add_quiverkey = False - - if (add_guide or add_guide is None) and funcname == "streamplot": - if hue: - add_colorbar = True - if not hue_style: - hue_style = "continuous" - elif hue_style != "continuous": - raise ValueError( - "hue_style must be 'continuous' or None for .plot.quiver or " - ".plot.streamplot" - ) - - if hue_style is not None and hue_style not in ["discrete", "continuous"]: - raise ValueError("hue_style must be either None, 'discrete' or 'continuous'.") - - if hue: - hue_label = label_from_attrs(ds[hue]) - hue = ds[hue] - else: - hue_label = None - hue = None - - return { - "add_colorbar": add_colorbar, - "add_legend": add_legend, - "add_quiverkey": add_quiverkey, - "hue_label": hue_label, - "hue_style": hue_style, - "xlabel": label_from_attrs(ds[x]), - "ylabel": label_from_attrs(ds[y]), - "hue": hue, - } - - -def _infer_scatter_data(ds, x, y, hue, markersize, size_norm, size_mapping=None): - - broadcast_keys = ["x", "y"] - to_broadcast = [ds[x], ds[y]] - if hue: - to_broadcast.append(ds[hue]) - broadcast_keys.append("hue") - if markersize: - to_broadcast.append(ds[markersize]) - broadcast_keys.append("size") - - broadcasted = dict(zip(broadcast_keys, broadcast(*to_broadcast))) - - data = {"x": broadcasted["x"], "y": broadcasted["y"], "hue": None, "sizes": None} - - if hue: - data["hue"] = broadcasted["hue"] - - if markersize: - size = broadcasted["size"] - - if size_mapping is None: - size_mapping = _parse_size(size, size_norm) - - data["sizes"] = size.copy( - data=np.reshape(size_mapping.loc[size.values.ravel()].values, size.shape) - ) - - return data - - -# copied from seaborn -def _parse_size(data, norm): - - import matplotlib as mpl - - if data is None: - return None - - data = data.values.flatten() - - if not _is_numeric(data): - levels = np.unique(data) - numbers = np.arange(1, 1 + len(levels))[::-1] - else: - levels = numbers = np.sort(np.unique(data)) - - min_width, max_width = _MARKERSIZE_RANGE - # width_range = min_width, max_width - - if norm is None: - norm = mpl.colors.Normalize() - elif isinstance(norm, tuple): - norm = mpl.colors.Normalize(*norm) - elif not isinstance(norm, mpl.colors.Normalize): - err = "``size_norm`` must be None, tuple, or Normalize object." - raise ValueError(err) - - norm.clip = True - if not norm.scaled(): - norm(np.asarray(numbers)) - # limits = norm.vmin, norm.vmax - - scl = norm(numbers) - widths = np.asarray(min_width + scl * (max_width - min_width)) - if scl.mask.any(): - widths[scl.mask] = 0 - sizes = dict(zip(levels, widths)) - - return pd.Series(sizes) - class _Dataset_PlotMethods: """ @@ -479,67 +319,6 @@ def plotmethod( return newplotfunc -@_dsplot -def scatter(ds, x, y, ax, **kwargs): - """ - Scatter Dataset data variables against each other. - - Wraps :py:func:`matplotlib:matplotlib.pyplot.scatter`. - """ - - if "add_colorbar" in kwargs or "add_legend" in kwargs: - raise ValueError( - "Dataset.plot.scatter does not accept " - "'add_colorbar' or 'add_legend'. " - "Use 'add_guide' instead." - ) - - cmap_params = kwargs.pop("cmap_params") - hue = kwargs.pop("hue") - hue_style = kwargs.pop("hue_style") - markersize = kwargs.pop("markersize", None) - size_norm = kwargs.pop("size_norm", None) - size_mapping = kwargs.pop("size_mapping", None) # set by facetgrid - - # Remove `u` and `v` so they don't get passed to `ax.scatter` - kwargs.pop("u", None) - kwargs.pop("v", None) - - # need to infer size_mapping with full dataset - data = _infer_scatter_data(ds, x, y, hue, markersize, size_norm, size_mapping) - - if hue_style == "discrete": - primitive = [] - # use pd.unique instead of np.unique because that keeps the order of the labels, - # which is important to keep them in sync with the ones used in - # FacetGrid.add_legend - for label in pd.unique(data["hue"].values.ravel()): - mask = data["hue"] == label - if data["sizes"] is not None: - kwargs.update(s=data["sizes"].where(mask, drop=True).values.flatten()) - - primitive.append( - ax.scatter( - data["x"].where(mask, drop=True).values.flatten(), - data["y"].where(mask, drop=True).values.flatten(), - label=label, - **kwargs, - ) - ) - - elif hue is None or hue_style == "continuous": - if data["sizes"] is not None: - kwargs.update(s=data["sizes"].values.ravel()) - if data["hue"] is not None: - kwargs.update(c=data["hue"].values.ravel()) - - primitive = ax.scatter( - data["x"].values.ravel(), data["y"].values.ravel(), **cmap_params, **kwargs - ) - - return primitive - - @_dsplot def quiver(ds, x, y, ax, u, v, **kwargs): """Quiver plot of Dataset variables. @@ -624,3 +403,111 @@ def streamplot(ds, x, y, ax, u, v, **kwargs): # Return .lines so colorbar creation works properly return hdl.lines + + +def _attach_to_plot_class(plotfunc): + """ + Set the function to the plot class and add a common docstring. + + Use this decorator when relying on DataArray.plot methods for + creating the Dataset plot. + + TODO: Reduce code duplication. + + * The goal is to reduce code duplication by moving all Dataset + specific plots to the DataArray side and use this thin wrapper to + handle the conversion between Dataset and DataArray. + * Improve docstring handling, maybe reword the DataArray versions to + explain Datasets better. + * Consider automatically adding all _PlotMethods to + _Dataset_PlotMethods. + + Parameters + ---------- + plotfunc : function + Function that returns a finished plot primitive. + """ + # Build on the original docstring: + original_doc = getattr(_PlotMethods, plotfunc.__name__, None) + commondoc = original_doc.__doc__ + if commondoc is not None: + doc_warning = ( + f"This docstring was copied from xr.DataArray.plot.{original_doc.__name__}." + " Some inconsistencies may exist." + ) + # Add indentation so it matches the original doc: + commondoc = f"\n\n {doc_warning}\n\n {commondoc}" + else: + commondoc = "" + plotfunc.__doc__ = ( + f" {plotfunc.__doc__}\n\n" + " The `y` DataArray will be used as base," + " any other variables are added as coords.\n\n" + f"{commondoc}" + ) + + @functools.wraps(plotfunc) + def plotmethod(self, *args, **kwargs): + return plotfunc(self._ds, *args, **kwargs) + + # Add to class _PlotMethods + setattr(_Dataset_PlotMethods, plotmethod.__name__, plotmethod) + + +def _normalize_args(plotmethod, args, kwargs): + from ..core.dataarray import DataArray + + # Determine positional arguments keyword by inspecting the + # signature of the plotmethod: + locals_ = dict( + inspect.signature(getattr(DataArray().plot, plotmethod)) + .bind(*args, **kwargs) + .arguments.items() + ) + locals_.update(locals_.pop("kwargs", {})) + + return locals_ + + +def _temp_dataarray(ds, y, locals_): + """Create a temporary datarray with extra coords.""" + from ..core.dataarray import DataArray + + # Base coords: + coords = dict(ds.coords) + + # Add extra coords to the DataArray from valid kwargs, if using all + # kwargs there is a risk that we add unneccessary dataarrays as + # coords straining RAM further for example: + # ds.both and extend="both" would add ds.both to the coords: + valid_coord_kwargs = {"x", "z", "markersize", "hue", "row", "col", "u", "v"} + for k in locals_.keys() & valid_coord_kwargs: + key = locals_[k] + if ds.data_vars.get(key) is not None: + coords[key] = ds[key] + + # The dataarray has to include all the dims. Broadcast to that shape + # and add the additional coords: + _y = ds[y].broadcast_like(ds) + + return DataArray(_y, coords=coords) + + +@_attach_to_plot_class +def line(ds, x, y, *args, **kwargs): + """Line plot Dataset data variables against each other.""" + kwargs.update(x=x) + locals_ = _normalize_args("line", args, kwargs) + da = _temp_dataarray(ds, y, locals_) + + return da.plot.line(*locals_.pop("args", ()), **locals_) + + +@_attach_to_plot_class +def scatter(ds, x, y, *args, **kwargs): + """Line plot Dataset data variables against each other.""" + kwargs.update(x=x) + locals_ = _normalize_args("scatter", args, kwargs) + da = _temp_dataarray(ds, y, locals_) + + return da.plot.scatter(*locals_.pop("args", ()), **locals_) diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 01ba00b2b94..b59219457f2 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -8,8 +8,15 @@ from ..core.formatting import format_item from .utils import ( + _LINEWIDTH_RANGE, + _MARKERSIZE_RANGE, + _add_legend, + _determine_guide, _get_nice_quiver_magnitude, + _infer_meta_data, _infer_xy_labels, + _Normalize, + _parse_size, _process_cmap_cbar_kwargs, import_matplotlib_pyplot, label_from_attrs, @@ -296,6 +303,150 @@ def map_dataarray(self, func, x, y, **kwargs): return self + def map_plot1d(self, func, x, y, **kwargs): + """ + Apply a plotting function to a 2d facet's subset of the data. + + This is more convenient and less general than ``FacetGrid.map`` + + Parameters + ---------- + func : callable + A plotting function with the same signature as a 2d xarray + plotting method such as `xarray.plot.imshow` + x, y : string + Names of the coordinates to plot on x, y axes + **kwargs + additional keyword arguments to func + + Returns + ------- + self : FacetGrid object + + """ + # Copy data to allow converting categoricals to integers and storing + # them in self.data. It is not possible to copy in the init + # unfortunately as there are tests that relies on self.data being + # mutable (test_names_appear_somewhere()). Maybe something to deprecate + # not sure how much that is used outside these tests. + self.data = self.data.copy() + + if kwargs.get("cbar_ax", None) is not None: + raise ValueError("cbar_ax not supported by FacetGrid.") + + # Handle hues: + hue = kwargs.get("hue", None) + hueplt = self.data[hue] if hue else self.data + hueplt_norm = _Normalize(hueplt) + self._hue_var = hueplt + cbar_kwargs = kwargs.pop("cbar_kwargs", {}) + if not hueplt_norm.data_is_numeric: + # TODO: Ticks seems a little too hardcoded, since it will always show + # all the values. But maybe it's ok, since plotting hundreds of + # categorical data isn't that meaningful anyway. + cbar_kwargs.update(format=hueplt_norm.format, ticks=hueplt_norm.ticks) + kwargs.update(levels=hueplt_norm.levels) + if "label" not in cbar_kwargs: + cbar_kwargs["label"] = label_from_attrs(hueplt_norm.data) + cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( + func, hueplt_norm.values.to_numpy(), cbar_kwargs=cbar_kwargs, **kwargs + ) + self._cmap_extend = cmap_params.get("extend") + + # Handle sizes: + _size_r = _MARKERSIZE_RANGE if func.__name__ == "scatter" else _LINEWIDTH_RANGE + for _size in ("markersize", "linewidth"): + size = kwargs.get(_size, None) + + sizeplt = self.data[size] if size else None + sizeplt_norm = _Normalize(sizeplt, _size_r) + if size: + self.data[size] = sizeplt_norm.values + kwargs.update(**{_size: size}) + break + + # Add kwargs that are sent to the plotting function, # order is important ??? + func_kwargs = { + k: v + for k, v in kwargs.items() + if k not in {"cmap", "colors", "cbar_kwargs", "levels"} + } + func_kwargs.update(cmap_params) + func_kwargs["add_colorbar"] = False + func_kwargs["add_legend"] = False + func_kwargs["add_title"] = False + + # Subplots should have labels on the left and bottom edges only: + add_labels_ = np.zeros(self.axes.shape + (3,), dtype=bool) + add_labels_[-1, :, 0] = True # x + add_labels_[:, 0, 1] = True # y + # add_labels_[:, :, 2] = True # z + + # + if self._single_group: + full = [ + {self._single_group: x} + for x in range(0, self.data[self._single_group].size) + ] + empty = [None for x in range(self._nrow * self._ncol - len(full))] + name_dicts = full + empty + else: + rowcols = itertools.product( + range(0, self.data[self._row_var].size), + range(0, self.data[self._col_var].size), + ) + name_dicts = [{self._row_var: r, self._col_var: c} for r, c in rowcols] + name_dicts = np.array(name_dicts).reshape(self._nrow, self._ncol) + + # Plot the data for each subplot: + for i, (d, ax) in enumerate(zip(name_dicts.flat, self.axes.flat)): + func_kwargs["add_labels"] = add_labels_.ravel()[3 * i : 3 * i + 3] + # None is the sentinel value + if d is not None: + subset = self.data.isel(d) + mappable = func( + subset, + x=x, + y=y, + ax=ax, + **func_kwargs, + _is_facetgrid=True, + ) + self._mappables.append(mappable) + + # Add titles and some touch ups: + self._finalize_grid() + + add_colorbar, add_legend = _determine_guide( + hueplt_norm, + sizeplt_norm, + kwargs.get("add_colorbar", None), + kwargs.get("add_legend", None), + # kwargs.get("add_guide", None), + # kwargs.get("hue_style", None), + ) + + if add_legend: + use_legend_elements = False if func.__name__ == "hist" else True + if use_legend_elements: + self.add_legend( + use_legend_elements=use_legend_elements, + hueplt_norm=hueplt_norm if not add_colorbar else _Normalize(None), + sizeplt_norm=sizeplt_norm, + primitive=self._mappables, + ax=ax, + legend_ax=self.fig, + plotfunc=func.__name__, + ) + else: + self.add_legend(use_legend_elements=use_legend_elements) + + if add_colorbar: + # Colorbar is after legend so it correctly fits the plot: + self.add_colorbar(**cbar_kwargs) + + return self + def map_dataarray_line( self, func, x, y, hue, add_legend=True, _labels=None, **kwargs ): @@ -324,19 +475,17 @@ def map_dataarray_line( ylabel = label_from_attrs(yplt) self._hue_var = hueplt - self._hue_label = huelabel + # self._hue_label = huelabel self._finalize_grid(xlabel, ylabel) if add_legend and hueplt is not None and huelabel is not None: - self.add_legend() + self.add_legend(label=huelabel) return self def map_dataset( self, func, x=None, y=None, hue=None, hue_style=None, add_guide=None, **kwargs ): - from .dataset_plot import _infer_meta_data, _parse_size - kwargs["add_guide"] = False if kwargs.get("markersize", None): @@ -376,12 +525,13 @@ def map_dataset( self._finalize_grid(meta_data["xlabel"], meta_data["ylabel"]) if hue: - self._hue_label = meta_data.pop("hue_label", None) + hue_label = meta_data.pop("hue_label", None) + self._hue_label = hue_label if meta_data["add_legend"]: self._hue_var = meta_data["hue"] - self.add_legend() + self.add_legend(label=hue_label) elif meta_data["add_colorbar"]: - self.add_colorbar(label=self._hue_label, **cbar_kwargs) + self.add_colorbar(label=hue_label, **cbar_kwargs) if meta_data["add_quiverkey"]: self.add_quiverkey(kwargs["u"], kwargs["v"]) @@ -409,14 +559,15 @@ def _adjust_fig_for_guide(self, guide): # Calculate and set the new width of the figure so the legend fits guide_width = guide.get_window_extent(renderer).width / self.fig.dpi figure_width = self.fig.get_figwidth() - self.fig.set_figwidth(figure_width + guide_width) + total_width = figure_width + guide_width + self.fig.set_figwidth(total_width) # Draw the plot again to get the new transformations self.fig.draw(renderer) # Now calculate how much space we need on the right side guide_width = guide.get_window_extent(renderer).width / self.fig.dpi - space_needed = guide_width / (figure_width + guide_width) + 0.02 + space_needed = guide_width / (total_width) + 0.02 # margin = .01 # _space_needed = margin + space_needed right = 1 - space_needed @@ -424,14 +575,17 @@ def _adjust_fig_for_guide(self, guide): # Place the subplot axes to give space for the legend self.fig.subplots_adjust(right=right) - def add_legend(self, **kwargs): - self.figlegend = self.fig.legend( - handles=self._mappables[-1], - labels=list(self._hue_var.to_numpy()), - title=self._hue_label, - loc="center right", - **kwargs, - ) + def add_legend(self, *, label=None, use_legend_elements: bool, **kwargs): + if use_legend_elements: + self.figlegend = _add_legend(**kwargs) + else: + self.figlegend = self.fig.legend( + handles=self._mappables[-1], + labels=list(self._hue_var.to_numpy()), + title=label if label is not None else label_from_attrs(self._hue_var), + loc="center right", + **kwargs, + ) self._adjust_fig_for_guide(self.figlegend) def add_colorbar(self, **kwargs): @@ -469,39 +623,37 @@ def add_quiverkey(self, u, v, **kwargs): # self._adjust_fig_for_guide(self.quiverkey.text) return self - def set_axis_labels(self, x_var=None, y_var=None): + def set_axis_labels(self, *axlabels): """Set axis labels on the left column and bottom row of the grid.""" - if x_var is not None: - if x_var in self.data.coords: - self._x_var = x_var - self.set_xlabels(label_from_attrs(self.data[x_var])) - else: - # x_var is a string - self.set_xlabels(x_var) + from ..core.dataarray import DataArray + + for var, xyz in zip(axlabels, ["x", "y", "z"]): + if var is not None: + if isinstance(var, DataArray): + getattr(self, f"set_{xyz}labels")(label_from_attrs(var)) + else: + getattr(self, f"set_{xyz}labels")(var) - if y_var is not None: - if y_var in self.data.coords: - self._y_var = y_var - self.set_ylabels(label_from_attrs(self.data[y_var])) - else: - self.set_ylabels(y_var) return self - def set_xlabels(self, label=None, **kwargs): - """Label the x axis on the bottom row of the grid.""" + def _set_labels(self, axis, axes, label=None, **kwargs): if label is None: - label = label_from_attrs(self.data[self._x_var]) - for ax in self._bottom_axes: - ax.set_xlabel(label, **kwargs) + label = label_from_attrs(self.data[getattr(self, f"_{axis}_var")]) + for ax in axes: + getattr(ax, f"set_{axis}label")(label, **kwargs) return self + def set_xlabels(self, label=None, **kwargs): + """Label the x axis on the bottom row of the grid.""" + self._set_labels("x", self._bottom_axes, label, **kwargs) + def set_ylabels(self, label=None, **kwargs): """Label the y axis on the left column of the grid.""" - if label is None: - label = label_from_attrs(self.data[self._y_var]) - for ax in self._left_axes: - ax.set_ylabel(label, **kwargs) - return self + self._set_labels("y", self._left_axes, label, **kwargs) + + def set_zlabels(self, label=None, **kwargs): + """Label the y axis on the left column of the grid.""" + self._set_labels("z", self._left_axes, label, **kwargs) def set_titles(self, template="{coord} = {value}", maxchar=30, size=None, **kwargs): """ @@ -690,5 +842,8 @@ def _easy_facetgrid( if kind == "dataarray": return g.map_dataarray(plotfunc, x, y, **kwargs) + if kind == "plot1d": + return g.map_plot1d(plotfunc, x, y, **kwargs) + if kind == "dataset": return g.map_dataset(plotfunc, x, y, **kwargs) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index ae0adfff00b..6c94732e47e 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -9,22 +9,28 @@ from __future__ import annotations import functools +from typing import Hashable, Iterable, Sequence import numpy as np import pandas as pd from packaging.version import Version from ..core.alignment import broadcast +from ..core.concat import concat +from ..core.types import T_DataArray from .facetgrid import _easy_facetgrid from .utils import ( + _LINEWIDTH_RANGE, + _MARKERSIZE_RANGE, _add_colorbar, - _adjust_legend_subtitles, + _add_legend, _assert_valid_xy, + _determine_guide, _ensure_plottable, _infer_interval_breaks, _infer_xy_labels, - _is_numeric, - _legend_add_subtitle, + _line, + _Normalize, _process_cmap_cbar_kwargs, _rescale_imshow_rgb, _resolve_intervals_1dplot, @@ -33,214 +39,142 @@ get_axis, import_matplotlib_pyplot, label_from_attrs, - legend_elements, ) -# copied from seaborn -_MARKERSIZE_RANGE = np.array([18.0, 72.0]) +def _infer_plot_dims( + darray, dims_plot: dict, default_guesser: Iterable[str] = ("x", "hue", "size") +) -> dict: + dims_plot_exist = {k: v for k, v in dims_plot.items() if v is not None} + dims_avail = tuple(v for v in darray.dims if v not in dims_plot_exist.values()) -def _infer_scatter_metadata(darray, x, z, hue, hue_style, size): - def _determine_array(darray, name, array_style): - """Find and determine what type of array it is.""" - array = darray[name] - array_is_numeric = _is_numeric(array.values) + # If dims_plot[k] isn't defined then fill with one of the available dims: + for k, v in zip(default_guesser, dims_avail): + if dims_plot.get(k, None) is None: + dims_plot[k] = v - if array_style is None: - array_style = "continuous" if array_is_numeric else "discrete" - elif array_style not in ["discrete", "continuous"]: - raise ValueError( - f"The style '{array_style}' is not valid, " - "valid options are None, 'discrete' or 'continuous'." - ) - - array_label = label_from_attrs(array) - - return array, array_style, array_label - - # Add nice looking labels: - out = dict(ylabel=label_from_attrs(darray)) - out.update( - { - k: label_from_attrs(darray[v]) if v in darray.coords else None - for k, v in [("xlabel", x), ("zlabel", z)] - } - ) - - # Add styles and labels for the dataarrays: - for type_, a, style in [("hue", hue, hue_style), ("size", size, None)]: - tp, stl, lbl = f"{type_}", f"{type_}_style", f"{type_}_label" - if a: - out[tp], out[stl], out[lbl] = _determine_array(darray, a, style) - else: - out[tp], out[stl], out[lbl] = None, None, None - - return out - - -# copied from seaborn -def _parse_size(data, norm, width): - """ - Determine what type of data it is. Then normalize it to width. - - If the data is categorical, normalize it to numbers. - """ - plt = import_matplotlib_pyplot() + for k, v in dims_plot.items(): + _assert_valid_xy(darray, v, k) - if data is None: - return None + return dims_plot - data = data.values.ravel() - - if not _is_numeric(data): - # Data is categorical. - # Use pd.unique instead of np.unique because that keeps - # the order of the labels: - levels = pd.unique(data) - numbers = np.arange(1, 1 + len(levels)) - else: - levels = numbers = np.sort(np.unique(data)) - min_width, max_width = width - # width_range = min_width, max_width +def _infer_line_data(darray, dims_plot: dict, plotfunc_name: str = None) -> dict: + # Guess what dims to use if some of the values in plot_dims are None: + dims_plot = _infer_plot_dims(darray, dims_plot) - if norm is None: - norm = plt.Normalize() - elif isinstance(norm, tuple): - norm = plt.Normalize(*norm) - elif not isinstance(norm, plt.Normalize): - err = "``size_norm`` must be None, tuple, or Normalize object." - raise ValueError(err) + # If there are more than 1 dimension in the array than stack all the + # dimensions so the plotter can plot anything: + if darray.ndim > 1: + # When stacking dims the lines will continue connecting. For floats + # this can be solved by adding a nan element inbetween the flattening + # points: + dims_T = [] + if np.issubdtype(darray.dtype, np.floating): + for v in ["z", "x"]: + dim = dims_plot.get(v, None) + if (dim is not None) and (dim in darray.dims): + darray_nan = np.nan * darray.isel(**{dim: -1}) + darray = concat([darray, darray_nan], dim=dim) + dims_T.append(dims_plot[v]) - norm.clip = True - if not norm.scaled(): - norm(np.asarray(numbers)) - # limits = norm.vmin, norm.vmax + # Lines should never connect to the same coordinate when stacked, + # transpose to avoid this as much as possible: + darray = darray.transpose(..., *dims_T) - scl = norm(numbers) - widths = np.asarray(min_width + scl * (max_width - min_width)) - if scl.mask.any(): - widths[scl.mask] = 0 - sizes = dict(zip(levels, widths)) + # Array is now ready to be stacked: + darray = darray.stack(_stacked_dim=darray.dims) - return pd.Series(sizes) - - -def _infer_scatter_data( - darray, x, z, hue, size, size_norm, size_mapping=None, size_range=(1, 10) -): # Broadcast together all the chosen variables: - to_broadcast = dict(y=darray) - to_broadcast.update( - {k: darray[v] for k, v in dict(x=x, z=z).items() if v is not None} - ) - to_broadcast.update( - {k: darray[v] for k, v in dict(hue=hue, size=size).items() if v in darray.dims} - ) - broadcasted = dict(zip(to_broadcast.keys(), broadcast(*(to_broadcast.values())))) - - # Normalize hue and size and create lookup tables: - for type_, mapping, norm, width in [ - ("hue", None, None, [0, 1]), - ("size", size_mapping, size_norm, size_range), - ]: - broadcasted_type = broadcasted.get(type_, None) - if broadcasted_type is not None: - if mapping is None: - mapping = _parse_size(broadcasted_type, norm, width) - - broadcasted[type_] = broadcasted_type.copy( - data=np.reshape( - mapping.loc[broadcasted_type.values.ravel()].values, - broadcasted_type.shape, - ) - ) - broadcasted[f"{type_}_to_label"] = pd.Series(mapping.index, index=mapping) - - return broadcasted - - -def _infer_line_data(darray, x, y, hue): + out = dict(y=darray) + out.update({k: darray[v] for k, v in dims_plot.items() if v is not None}) + out = dict(zip(out.keys(), broadcast(*(out.values())))) - ndims = len(darray.dims) - - if x is not None and y is not None: - raise ValueError("Cannot specify both x and y kwargs for line plots.") - - if x is not None: - _assert_valid_xy(darray, x, "x") - - if y is not None: - _assert_valid_xy(darray, y, "y") - - if ndims == 1: - huename = None - hueplt = None - huelabel = "" - - if x is not None: - xplt = darray[x] - yplt = darray - - elif y is not None: - xplt = darray - yplt = darray[y] - - else: # Both x & y are None - dim = darray.dims[0] - xplt = darray[dim] - yplt = darray - - else: - if x is None and y is None and hue is None: - raise ValueError("For 2D inputs, please specify either hue, x or y.") - - if y is None: - if hue is not None: - _assert_valid_xy(darray, hue, "hue") - xname, huename = _infer_xy_labels(darray=darray, x=x, y=hue) - xplt = darray[xname] - if xplt.ndim > 1: - if huename in darray.dims: - otherindex = 1 if darray.dims.index(huename) == 0 else 0 - otherdim = darray.dims[otherindex] - yplt = darray.transpose(otherdim, huename, transpose_coords=False) - xplt = xplt.transpose(otherdim, huename, transpose_coords=False) - else: - raise ValueError( - "For 2D inputs, hue must be a dimension" - " i.e. one of " + repr(darray.dims) - ) - - else: - (xdim,) = darray[xname].dims - (huedim,) = darray[huename].dims - yplt = darray.transpose(xdim, huedim) - - else: - yname, huename = _infer_xy_labels(darray=darray, x=y, y=hue) - yplt = darray[yname] - if yplt.ndim > 1: - if huename in darray.dims: - otherindex = 1 if darray.dims.index(huename) == 0 else 0 - otherdim = darray.dims[otherindex] - xplt = darray.transpose(otherdim, huename, transpose_coords=False) - yplt = yplt.transpose(otherdim, huename, transpose_coords=False) - else: - raise ValueError( - "For 2D inputs, hue must be a dimension" - " i.e. one of " + repr(darray.dims) - ) - - else: - (ydim,) = darray[yname].dims - (huedim,) = darray[huename].dims - xplt = darray.transpose(ydim, huedim) + return out - huelabel = label_from_attrs(darray[huename]) - hueplt = darray[huename] - return xplt, yplt, hueplt, huelabel +# def _infer_line_data(darray, x, y, hue): + +# ndims = len(darray.dims) + +# if x is not None and y is not None: +# raise ValueError("Cannot specify both x and y kwargs for line plots.") + +# if x is not None: +# _assert_valid_xy(darray, x, "x") + +# if y is not None: +# _assert_valid_xy(darray, y, "y") + +# if ndims == 1: +# huename = None +# hueplt = None +# huelabel = "" + +# if x is not None: +# xplt = darray[x] +# yplt = darray + +# elif y is not None: +# xplt = darray +# yplt = darray[y] + +# else: # Both x & y are None +# dim = darray.dims[0] +# xplt = darray[dim] +# yplt = darray + +# else: +# if x is None and y is None and hue is None: +# raise ValueError("For 2D inputs, please specify either hue, x or y.") + +# if y is None: +# if hue is not None: +# _assert_valid_xy(darray, hue, "hue") +# xname, huename = _infer_xy_labels(darray=darray, x=x, y=hue) +# xplt = darray[xname] +# if xplt.ndim > 1: +# if huename in darray.dims: +# otherindex = 1 if darray.dims.index(huename) == 0 else 0 +# otherdim = darray.dims[otherindex] +# yplt = darray.transpose(otherdim, huename, transpose_coords=False) +# xplt = xplt.transpose(otherdim, huename, transpose_coords=False) +# else: +# raise ValueError( +# "For 2D inputs, hue must be a dimension" +# " i.e. one of " + repr(darray.dims) +# ) + +# else: +# (xdim,) = darray[xname].dims +# (huedim,) = darray[huename].dims +# yplt = darray.transpose(xdim, huedim) + +# else: +# yname, huename = _infer_xy_labels(darray=darray, x=y, y=hue) +# yplt = darray[yname] +# if yplt.ndim > 1: +# if huename in darray.dims: +# otherindex = 1 if darray.dims.index(huename) == 0 else 0 +# otherdim = darray.dims[otherindex] +# xplt = darray.transpose(otherdim, huename, transpose_coords=False) +# yplt = yplt.transpose(otherdim, huename, transpose_coords=False) +# else: +# raise ValueError( +# "For 2D inputs, hue must be a dimension" +# " i.e. one of " + repr(darray.dims) +# ) + +# else: +# (ydim,) = darray[yname].dims +# (huedim,) = darray[huename].dims +# xplt = darray.transpose(ydim, huedim) + +# huelabel = label_from_attrs(darray[huename]) +# hueplt = darray[huename] + +# return xplt, yplt, hueplt, huelabel +# # return dict(x=xplt, y=yplt, hue=hueplt, hue_label = huelabel, z=zplt) def plot( @@ -334,140 +268,7 @@ def plot( return plotfunc(darray, **kwargs) -# This function signature should not change so that it can use -# matplotlib format strings -def line( - darray, - *args, - row=None, - col=None, - figsize=None, - aspect=None, - size=None, - ax=None, - hue=None, - x=None, - y=None, - xincrease=None, - yincrease=None, - xscale=None, - yscale=None, - xticks=None, - yticks=None, - xlim=None, - ylim=None, - add_legend=True, - _labels=True, - **kwargs, -): - """ - Line plot of DataArray values. - - Wraps :py:func:`matplotlib:matplotlib.pyplot.plot`. - - Parameters - ---------- - darray : DataArray - Either 1D or 2D. If 2D, one of ``hue``, ``x`` or ``y`` must be provided. - figsize : tuple, optional - A tuple (width, height) of the figure in inches. - Mutually exclusive with ``size`` and ``ax``. - aspect : scalar, optional - Aspect ratio of plot, so that ``aspect * size`` gives the *width* in - inches. Only used if a ``size`` is provided. - size : scalar, optional - If provided, create a new figure for the plot with the given size: - *height* (in inches) of each plot. See also: ``aspect``. - ax : matplotlib axes object, optional - Axes on which to plot. By default, the current is used. - Mutually exclusive with ``size`` and ``figsize``. - hue : str, optional - Dimension or coordinate for which you want multiple lines plotted. - If plotting against a 2D coordinate, ``hue`` must be a dimension. - x, y : str, optional - Dimension, coordinate or multi-index level for *x*, *y* axis. - Only one of these may be specified. - The other will be used for values from the DataArray on which this - plot method is called. - xscale, yscale : {'linear', 'symlog', 'log', 'logit'}, optional - Specifies scaling for the *x*- and *y*-axis, respectively. - xticks, yticks : array-like, optional - Specify tick locations for *x*- and *y*-axis. - xlim, ylim : array-like, optional - Specify *x*- and *y*-axis limits. - xincrease : None, True, or False, optional - Should the values on the *x* axis be increasing from left to right? - if ``None``, use the default for the Matplotlib function. - yincrease : None, True, or False, optional - Should the values on the *y* axis be increasing from top to bottom? - if ``None``, use the default for the Matplotlib function. - add_legend : bool, optional - Add legend with *y* axis coordinates (2D inputs only). - *args, **kwargs : optional - Additional arguments to :py:func:`matplotlib:matplotlib.pyplot.plot`. - """ - # Handle facetgrids first - if row or col: - allargs = locals().copy() - allargs.update(allargs.pop("kwargs")) - allargs.pop("darray") - return _easy_facetgrid(darray, line, kind="line", **allargs) - - ndims = len(darray.dims) - if ndims > 2: - raise ValueError( - "Line plots are for 1- or 2-dimensional DataArrays. " - "Passed DataArray has {ndims} " - "dimensions".format(ndims=ndims) - ) - - # The allargs dict passed to _easy_facetgrid above contains args - if args == (): - args = kwargs.pop("args", ()) - else: - assert "args" not in kwargs - - ax = get_axis(figsize, size, aspect, ax) - xplt, yplt, hueplt, hue_label = _infer_line_data(darray, x, y, hue) - - # Remove pd.Intervals if contained in xplt.values and/or yplt.values. - xplt_val, yplt_val, x_suffix, y_suffix, kwargs = _resolve_intervals_1dplot( - xplt.to_numpy(), yplt.to_numpy(), kwargs - ) - xlabel = label_from_attrs(xplt, extra=x_suffix) - ylabel = label_from_attrs(yplt, extra=y_suffix) - - _ensure_plottable(xplt_val, yplt_val) - - primitive = ax.plot(xplt_val, yplt_val, *args, **kwargs) - - if _labels: - if xlabel is not None: - ax.set_xlabel(xlabel) - - if ylabel is not None: - ax.set_ylabel(ylabel) - - ax.set_title(darray._title_for_slice()) - - if darray.ndim == 2 and add_legend: - ax.legend(handles=primitive, labels=list(hueplt.to_numpy()), title=hue_label) - - # Rotate dates on xlabels - # Do this without calling autofmt_xdate so that x-axes ticks - # on other subplots (if any) are not deleted. - # https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots - if np.issubdtype(xplt.dtype, np.datetime64): - for xlabels in ax.get_xticklabels(): - xlabels.set_rotation(30) - xlabels.set_ha("right") - - _update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim) - - return primitive - - -def step(darray, *args, where="pre", drawstyle=None, ds=None, **kwargs): +def step_(darray, *args, where="pre", drawstyle=None, ds=None, **kwargs): """ Step plot of DataArray values. @@ -508,7 +309,7 @@ def step(darray, *args, where="pre", drawstyle=None, ds=None, **kwargs): return line(darray, *args, drawstyle=drawstyle, **kwargs) -def hist( +def hist_old( darray, figsize=None, size=None, @@ -566,336 +367,622 @@ def hist( return primitive -def scatter( - darray, - *args, - row=None, - col=None, - figsize=None, - aspect=None, - size=None, - ax=None, - hue=None, - hue_style=None, - x=None, - z=None, - xincrease=None, - yincrease=None, - xscale=None, - yscale=None, - xticks=None, - yticks=None, - xlim=None, - ylim=None, - add_legend=None, - add_colorbar=None, - cbar_kwargs=None, - cbar_ax=None, - vmin=None, - vmax=None, - norm=None, - infer_intervals=None, - center=None, - levels=None, - robust=None, - colors=None, - extend=None, - cmap=None, - _labels=True, - **kwargs, -): +# MUST run before any 2d plotting functions are defined since +# _plot2d decorator adds them as methods here. +class _PlotMethods: + """ + Enables use of xarray.plot functions as attributes on a DataArray. + For example, DataArray.plot.imshow """ - Scatter plot a DataArray along some coordinates. + __slots__ = ("_da",) + + def __init__(self, darray): + self._da = darray + + def __call__(self, **kwargs): + return plot(self._da, **kwargs) + + # we can't use functools.wraps here since that also modifies the name / qualname + __doc__ = __call__.__doc__ = plot.__doc__ + __call__.__wrapped__ = plot # type: ignore[attr-defined] + __call__.__annotations__ = plot.__annotations__ + + # @functools.wraps(hist) + # def hist(self, ax=None, **kwargs): + # return hist(self._da, ax=ax, **kwargs) + + # @functools.wraps(line) + # def line(self, *args, **kwargs): + # return line(self._da, *args, **kwargs) + + # @functools.wraps(step) + # def step(self, *args, **kwargs): + # return step(self._da, *args, **kwargs) + + # @functools.wraps(scatter) + # def _scatter(self, *args, **kwargs): + # return scatter(self._da, *args, **kwargs) + + +def override_signature(f): + def wrapper(func): + func.__wrapped__ = f + + return func + + return wrapper + + +def _plot1d(plotfunc): + """ + Decorator for common 1d plotting logic. + + Also adds the 1d plot method to class _PlotMethods. + """ + commondoc = """ Parameters ---------- darray : DataArray - Dataarray to plot. - x, y : str - Variable names for x, y axis. - hue: str, optional - Variable by which to color scattered points - hue_style: str, optional - Can be either 'discrete' (legend) or 'continuous' (color bar). - markersize: str, optional - scatter only. Variable by which to vary size of scattered points. - size_norm: optional - Either None or 'Norm' instance to normalize the 'markersize' variable. - add_guide: bool, optional - Add a guide that depends on hue_style - - for "discrete", build a legend. - This is the default for non-numeric `hue` variables. - - for "continuous", build a colorbar - row : str, optional - If passed, make row faceted plots on this dimension name - col : str, optional - If passed, make column faceted plots on this dimension name - col_wrap : int, optional - Use together with ``col`` to wrap faceted plots - ax : matplotlib axes object, optional - If None, uses the current axis. Not applicable when using facets. - subplot_kws : dict, optional - Dictionary of keyword arguments for matplotlib subplots. Only applies - to FacetGrid plotting. + Must be 2 dimensional, unless creating faceted plots + x : string, optional + Coordinate for x axis. If None use darray.dims[1] + y : string, optional + Coordinate for y axis. If None use darray.dims[0] + hue : string, optional + Dimension or coordinate for which you want multiple lines plotted. + figsize : tuple, optional + A tuple (width, height) of the figure in inches. + Mutually exclusive with ``size`` and ``ax``. aspect : scalar, optional Aspect ratio of plot, so that ``aspect * size`` gives the width in inches. Only used if a ``size`` is provided. size : scalar, optional If provided, create a new figure for the plot with the given size. Height (in inches) of each plot. See also: ``aspect``. - norm : ``matplotlib.colors.Normalize`` instance, optional - If the ``norm`` has vmin or vmax specified, the corresponding kwarg - must be None. - vmin, vmax : float, optional - Values to anchor the colormap, otherwise they are inferred from the - data and other keyword arguments. When a diverging dataset is inferred, - setting one of these values will fix the other by symmetry around - ``center``. Setting both values prevents use of a diverging colormap. - If discrete levels are provided as an explicit list, both of these - values are ignored. - cmap : str or colormap, optional - The mapping from data values to color space. Either a - matplotlib colormap name or object. If not provided, this will - be either ``viridis`` (if the function infers a sequential - dataset) or ``RdBu_r`` (if the function infers a diverging - dataset). When `Seaborn` is installed, ``cmap`` may also be a - `seaborn` color palette. If ``cmap`` is seaborn color palette - and the plot type is not ``contour`` or ``contourf``, ``levels`` - must also be specified. - colors : color-like or list of color-like, optional - A single color or a list of colors. If the plot type is not ``contour`` - or ``contourf``, the ``levels`` argument is required. - center : float, optional - The value at which to center the colormap. Passing this value implies - use of a diverging colormap. Setting it to ``False`` prevents use of a - diverging colormap. - robust : bool, optional - If True and ``vmin`` or ``vmax`` are absent, the colormap range is - computed with 2nd and 98th percentiles instead of the extreme values. - extend : {"neither", "both", "min", "max"}, optional - How to draw arrows extending the colorbar beyond its limits. If not - provided, extend is inferred from vmin, vmax and the data limits. - levels : int or list-like object, optional - Split the colormap (cmap) into discrete color intervals. If an integer - is provided, "nice" levels are chosen based on the data range: this can - imply that the final number of levels is not exactly the expected one. - Setting ``vmin`` and/or ``vmax`` with ``levels=N`` is equivalent to - setting ``levels=np.linspace(vmin, vmax, N)``. + ax : matplotlib.axes.Axes, optional + Axis on which to plot this figure. By default, use the current axis. + Mutually exclusive with ``size`` and ``figsize``. + row : string, optional + If passed, make row faceted plots on this dimension name + col : string, optional + If passed, make column faceted plots on this dimension name + col_wrap : int, optional + Use together with ``col`` to wrap faceted plots + xscale, yscale : 'linear', 'symlog', 'log', 'logit', optional + Specifies scaling for the x- and y-axes respectively + xticks, yticks : Specify tick locations for x- and y-axes + xlim, ylim : Specify x- and y-axes limits + xincrease : None, True, or False, optional + Should the values on the x axes be increasing from left to right? + if None, use the default for the matplotlib function. + yincrease : None, True, or False, optional + Should the values on the y axes be increasing from top to bottom? + if None, use the default for the matplotlib function. + add_labels : bool, optional + Use xarray metadata to label axes + subplot_kws : dict, optional + Dictionary of keyword arguments for matplotlib subplots. Only used + for FacetGrid plots. **kwargs : optional - Additional keyword arguments to matplotlib + Additional arguments to wrapped matplotlib function + + Returns + ------- + artist : + The same type of primitive artist that the wrapped matplotlib + function returns """ - plt = import_matplotlib_pyplot() - # Handle facetgrids first - if row or col: - allargs = locals().copy() - allargs.update(allargs.pop("kwargs")) - allargs.pop("darray") - subplot_kws = dict(projection="3d") if z is not None else None - return _easy_facetgrid( - darray, scatter, kind="dataarray", subplot_kws=subplot_kws, **allargs + # Build on the original docstring + plotfunc.__doc__ = f"{plotfunc.__doc__}\n{commondoc}" + + # plotfunc and newplotfunc have different signatures: + # - plotfunc: (x, y, z, ax, **kwargs) + # - newplotfunc: (darray, *args, x, y, **kwargs) + # where plotfunc accepts numpy arrays, while newplotfunc accepts a DataArray + # and variable names. newplotfunc also explicitly lists most kwargs, so we + # need to shorten it + def signature(darray, *args, x, **kwargs): + pass + + @override_signature(signature) + @functools.wraps(plotfunc) + def newplotfunc( + darray, + *args, + x: Hashable = None, + y: Hashable = None, + z: Hashable = None, + hue: Hashable = None, + hue_style=None, + markersize: Hashable = None, + linewidth: Hashable = None, + figsize=None, + size=None, + aspect=None, + ax=None, + row: Hashable = None, + col: Hashable = None, + col_wrap=None, + xincrease=True, + yincrease=True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool = True, + add_title: bool = True, + subplot_kws: dict | None = None, + xscale=None, + yscale=None, + xticks=None, + yticks=None, + xlim=None, + ylim=None, + cmap=None, + vmin=None, + vmax=None, + norm=None, + extend=None, + levels=None, + **kwargs, + ): + # All 1d plots in xarray share this function signature. + # Method signature below should be consistent. + + if subplot_kws is None: + subplot_kws = dict() + + # Handle facetgrids first + if row or col: + if z is not None: + subplot_kws.update(projection="3d") + + allargs = locals().copy() + allargs.update(allargs.pop("kwargs")) + allargs.pop("darray") + allargs["plotfunc"] = globals()[plotfunc.__name__] + + return _easy_facetgrid(darray, kind="plot1d", **allargs) + + # The allargs dict passed to _easy_facetgrid above contains args + if args == (): + args = kwargs.pop("args", ()) + else: + assert "args" not in kwargs + + if markersize is not None: + size_ = markersize + size_r = _MARKERSIZE_RANGE + else: + size_ = linewidth + size_r = _LINEWIDTH_RANGE + + _is_facetgrid = kwargs.pop("_is_facetgrid", False) + + plts = _infer_line_data( + darray, dict(x=x, z=z, hue=hue, size=size_), plotfunc.__name__ ) + xplt = plts.pop("x", None) + yplt = plts.pop("y", None) + zplt = plts.pop("z", None) + kwargs.update(zplt=zplt) + hueplt = plts.pop("hue", None) + sizeplt = plts.pop("size", None) + + hueplt_norm = _Normalize(hueplt) + kwargs.update(hueplt=hueplt_norm.values) + sizeplt_norm = _Normalize(sizeplt, size_r, _is_facetgrid) + kwargs.update(sizeplt=sizeplt_norm.values) + cmap_params_subset = kwargs.pop("cmap_params_subset", {}) + cbar_kwargs = kwargs.pop("cbar_kwargs", {}) + + if hueplt_norm.data is not None: + if not hueplt_norm.data_is_numeric: + # Map hue values back to its original value: + cbar_kwargs.update(format=hueplt_norm.format, ticks=hueplt_norm.ticks) + levels = kwargs.get("levels", hueplt_norm.levels) + + cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( + plotfunc, + hueplt_norm.values.data, + **locals(), + ) - # Further - _is_facetgrid = kwargs.pop("_is_facetgrid", False) - if _is_facetgrid: - # Why do I need to pop these here? - kwargs.pop("y", None) - kwargs.pop("args", None) - kwargs.pop("add_labels", None) - - _sizes = kwargs.pop("markersize", kwargs.pop("linewidth", None)) - size_norm = kwargs.pop("size_norm", None) - size_mapping = kwargs.pop("size_mapping", None) # set by facetgrid - cmap_params = kwargs.pop("cmap_params", None) - - figsize = kwargs.pop("figsize", None) - subplot_kws = dict() - if z is not None and ax is None: - # TODO: Importing Axes3D is not necessary in matplotlib >= 3.2. - # Remove when minimum requirement of matplotlib is 3.2: - from mpl_toolkits.mplot3d import Axes3D # type: ignore # noqa - - subplot_kws.update(projection="3d") - ax = get_axis(figsize, size, aspect, ax, **subplot_kws) - # Using 30, 30 minimizes rotation of the plot. Making it easier to - # build on your intuition from 2D plots: - if Version(plt.matplotlib.__version__) < Version("3.5.0"): - ax.view_init(azim=30, elev=30) + # subset that can be passed to scatter, hist2d + if not cmap_params_subset: + ckw = {vv: cmap_params[vv] for vv in ("vmin", "vmax", "norm", "cmap")} + cmap_params_subset.update(**ckw) + + if z is not None and ax is None: + subplot_kws.update(projection="3d") + ax = get_axis(figsize, size, aspect, ax, **subplot_kws) + # Using 30, 30 minimizes rotation of the plot. Making it easier to + # build on your intuition from 2D plots: + plt = import_matplotlib_pyplot() + if Version(plt.matplotlib.__version__) < Version("3.5.0"): + ax.view_init(azim=30, elev=30) + else: + # https://github.com/matplotlib/matplotlib/pull/19873 + ax.view_init(azim=30, elev=30, vertical_axis="y") else: - # https://github.com/matplotlib/matplotlib/pull/19873 - ax.view_init(azim=30, elev=30, vertical_axis="y") - else: - ax = get_axis(figsize, size, aspect, ax, **subplot_kws) + ax = get_axis(figsize, size, aspect, ax, **subplot_kws) - _data = _infer_scatter_metadata(darray, x, z, hue, hue_style, _sizes) + primitive = plotfunc( + xplt, + yplt, + *args, + ax=ax, + add_labels=add_labels, + **cmap_params_subset, + **kwargs, + ) - add_guide = kwargs.pop("add_guide", None) - if add_legend is not None: - pass - elif add_guide is None or add_guide is True: - add_legend = True if _data["hue_style"] == "discrete" else False - elif add_legend is None: - add_legend = False + if np.any(add_labels) and add_title: + ax.set_title(darray._title_for_slice()) - if add_colorbar is not None: - pass - elif add_guide is None or add_guide is True: - add_colorbar = True if _data["hue_style"] == "continuous" else False - else: - add_colorbar = False - - # need to infer size_mapping with full dataset - _data.update( - _infer_scatter_data( - darray, - x, - z, - hue, - _sizes, - size_norm, - size_mapping, - _MARKERSIZE_RANGE, + add_colorbar_, add_legend_ = _determine_guide( + hueplt_norm, + sizeplt_norm, + add_colorbar, + add_legend, + plotfunc_name=plotfunc.__name__, ) - ) - cmap_params_subset = {} - if _data["hue"] is not None: - kwargs.update(c=_data["hue"].values.ravel()) - cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( - scatter, _data["hue"].values, **locals() + if add_colorbar_: + if "label" not in cbar_kwargs: + cbar_kwargs["label"] = label_from_attrs(hueplt_norm.data) + + _add_colorbar( + primitive, ax, kwargs.get("cbar_ax", None), cbar_kwargs, cmap_params + ) + + if add_legend_: + if plotfunc.__name__ == "hist": + ax.legend( + handles=primitive[-1], + labels=list(hueplt_norm.values.to_numpy()), + title=label_from_attrs(hueplt_norm.data), + ) + elif plotfunc.__name__ in ["scatter", "line"]: + _add_legend( + hueplt_norm + if add_legend or not add_colorbar_ + else _Normalize(None), + sizeplt_norm, + primitive, + ax=ax, + legend_ax=ax, + plotfunc=plotfunc.__name__, + ) + else: + ax.legend( + handles=primitive, + labels=list(hueplt_norm.values.to_numpy()), + title=label_from_attrs(hueplt_norm.data), + ) + + _update_axes( + ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim ) - # subset that can be passed to scatter, hist2d - cmap_params_subset = { - vv: cmap_params[vv] for vv in ["vmin", "vmax", "norm", "cmap"] - } + return primitive + + # For use as DataArray.plot.plotmethod + @functools.wraps(newplotfunc) + def plotmethod( + _PlotMethods_obj, + *args, + x: Hashable = None, + y: Hashable = None, + z: Hashable = None, + hue: Hashable = None, + hue_style=None, + markersize: Hashable = None, + linewidth: Hashable = None, + figsize=None, + size=None, + aspect=None, + ax=None, + row: Hashable = None, + col: Hashable = None, + col_wrap=None, + xincrease=True, + yincrease=True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | None = True, + subplot_kws=None, + xscale=None, + yscale=None, + xticks=None, + yticks=None, + xlim=None, + ylim=None, + cmap=None, + vmin=None, + vmax=None, + norm=None, + extend=None, + levels=None, + **kwargs, + ): + """ + The method should have the same signature as the function. + + This just makes the method work on Plotmethods objects, + and passes all the other arguments straight through. + """ + allargs = locals().copy() + allargs["darray"] = _PlotMethods_obj._da + allargs.update(kwargs) + for arg in ["_PlotMethods_obj", "newplotfunc", "kwargs"]: + del allargs[arg] + return newplotfunc(**allargs) + + # Add to class _PlotMethods + setattr(_PlotMethods, plotmethod.__name__, plotmethod) + + return newplotfunc + + +def _add_labels( + add_labels: bool | Iterable[bool], + darrays: Sequence[T_DataArray], + suffixes: Iterable[str], + rotate_labels: Iterable[bool], + ax, +): + # Set x, y, z labels: + xyz = ("x", "y", "z") + add_labels = [add_labels] * len(xyz) if isinstance(add_labels, bool) else add_labels + for i, (add_label, darray, suffix, rotate_label) in enumerate( + zip(add_labels, darrays, suffixes, rotate_labels) + ): + if darray is None: + continue + + lbl = xyz[i] + if add_label: + label = label_from_attrs(darray, extra=suffix) + if label is not None: + getattr(ax, f"set_{lbl}label")(label) + + if rotate_label and np.issubdtype(darray.dtype, np.datetime64): + # Rotate dates on xlabels + # Do this without calling autofmt_xdate so that x-axes ticks + # on other subplots (if any) are not deleted. + # https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots + for labels in getattr(ax, f"get_{lbl}ticklabels")(): + labels.set_rotation(30) + labels.set_ha("right") + + +# # This function signature should not change so that it can use +# # matplotlib format strings +# @_plot1d +# def line2d(xplt, yplt, *args, ax, add_labels=True, **kwargs): +# """ +# Line plot of DataArray index against values +# Wraps :func:`matplotlib:matplotlib.pyplot.plot` +# """ +# plt = import_matplotlib_pyplot() + +# zplt = kwargs.pop("zplt", None) +# hueplt = kwargs.pop("hueplt", None) +# sizeplt = kwargs.pop("sizeplt", None) + +# vmin = kwargs.pop("vmin", None) +# vmax = kwargs.pop("vmax", None) +# kwargs["clim"] = [vmin, vmax] +# # norm = kwargs["norm"] = kwargs.pop( +# # "norm", plt.matplotlib.colors.Normalize(vmin=vmin, vmax=vmax) +# # ) + +# # if hueplt is not None: +# # ScalarMap = plt.cm.ScalarMappable(norm=norm, cmap=kwargs.get("cmap", None)) +# # kwargs.update(colors=ScalarMap.to_rgba(hueplt.to_numpy().ravel())) +# # kwargs.update(colors=hueplt.to_numpy().ravel()) + +# # if sizeplt is not None: +# # kwargs.update(linewidths=sizeplt.to_numpy().ravel()) + +# # Remove pd.Intervals if contained in xplt.values and/or yplt.values. +# xplt_val, yplt_val, x_suffix, y_suffix, kwargs = _resolve_intervals_1dplot( +# xplt.to_numpy(), yplt.to_numpy(), kwargs +# ) +# _ensure_plottable(xplt_val, yplt_val) + +# primitive = ax.plot(xplt_val, yplt_val, *args, **kwargs) + +# _add_labels(add_labels, (xplt, yplt), (x_suffix, y_suffix), (True, False), ax) + +# return primitive + + +def _line_(xplt, yplt, *args, ax, add_labels=True, **kwargs): + plt = import_matplotlib_pyplot() + + zplt = kwargs.pop("zplt", None) + hueplt = kwargs.pop("hueplt", None) + sizeplt = kwargs.pop("sizeplt", None) - if _data["size"] is not None: - kwargs.update(s=_data["size"].values.ravel()) + cmap = kwargs.pop("cmap", None) + vmin = kwargs.pop("vmin", None) + vmax = kwargs.pop("vmax", None) + norm = kwargs.pop("norm", None) + + c = hueplt.to_numpy() if hueplt is not None else None + s = sizeplt.to_numpy() if sizeplt is not None else None + zplt_val = zplt.to_numpy() if zplt is not None else None + + # Remove pd.Intervals if contained in xplt.values and/or yplt.values. + xplt_val, yplt_val, x_suffix, y_suffix, kwargs = _resolve_intervals_1dplot( + xplt.to_numpy(), yplt.to_numpy(), kwargs + ) + z_suffix = "" # TODO: to _resolve_intervals? + _ensure_plottable(xplt_val, yplt_val) if Version(plt.matplotlib.__version__) < Version("3.5.0"): # Plot the data. 3d plots has the z value in upward direction # instead of y. To make jumping between 2d and 3d easy and intuitive # switch the order so that z is shown in the depthwise direction: - axis_order = ["x", "z", "y"] + # axis_order = dict(x="x", y="z", z="y") + axis_order = ["x", "y", "z"] + to_plot, to_labels, to_suffix, i = {}, {}, {}, 0 + for arr, arr_val, suffix in zip( + [xplt, zplt, yplt], + [xplt_val, zplt_val, yplt_val], + (x_suffix, z_suffix, y_suffix), + ): + if arr is not None: + to_plot[axis_order[i]] = arr_val + to_labels[axis_order[i]] = arr + to_suffix[axis_order[i]] = suffix + i += 1 + # to_plot = dict(x=xplt_val, y=zplt_val, z=yplt_val) + # to_labels = dict(x=xplt, y=zplt, z=yplt) else: # Switching axis order not needed in 3.5.0, can also simplify the code # that uses axis_order: # https://github.com/matplotlib/matplotlib/pull/19873 + # axis_order = dict(x="x", y="y", z="z") axis_order = ["x", "y", "z"] - - primitive = ax.scatter( - *[ - _data[v].values.ravel() - for v in axis_order - if _data.get(v, None) is not None - ], - **cmap_params_subset, + to_plot, to_labels, to_suffix, i = {}, {}, {}, 0 + for arr, arr_val, suffix in zip( + [xplt, yplt, zplt], + [xplt_val, yplt_val, zplt_val], + (x_suffix, z_suffix, y_suffix), + ): + if arr is not None: + to_plot[axis_order[i]] = arr_val + to_labels[axis_order[i]] = arr + to_suffix[axis_order[i]] = suffix + i += 1 + + primitive = _line( + ax, + **to_plot, + s=s, + c=c, + cmap=cmap, + norm=norm, + vmin=vmin, + vmax=vmax, **kwargs, ) # Set x, y, z labels: - i = 0 - set_label = [ax.set_xlabel, ax.set_ylabel, getattr(ax, "set_zlabel", None)] - for v in axis_order: - if _data.get(f"{v}label", None) is not None: - set_label[i](_data[f"{v}label"]) - i += 1 - - if add_legend: - - def to_label(data, key, x): - """Map prop values back to its original values.""" - if key in data: - # Use reindex to be less sensitive to float errors. - # Return as numpy array since legend_elements - # seems to require that: - return data[key].reindex(x, method="nearest").to_numpy() - else: - return x - - handles, labels = [], [] - for subtitle, prop, func in [ - ( - _data["hue_label"], - "colors", - functools.partial(to_label, _data, "hue_to_label"), - ), - ( - _data["size_label"], - "sizes", - functools.partial(to_label, _data, "size_to_label"), - ), - ]: - if subtitle: - # Get legend handles and labels that displays the - # values correctly. Order might be different because - # legend_elements uses np.unique instead of pd.unique, - # FacetGrid.add_legend might have troubles with this: - hdl, lbl = legend_elements(primitive, prop, num="auto", func=func) - hdl, lbl = _legend_add_subtitle(hdl, lbl, subtitle, ax.scatter) - handles += hdl - labels += lbl - legend = ax.legend(handles, labels, framealpha=0.5) - _adjust_legend_subtitles(legend) - - if add_colorbar and _data["hue_label"]: - if _data["hue_style"] == "discrete": - raise NotImplementedError("Cannot create a colorbar for non numerics.") - cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs - if "label" not in cbar_kwargs: - cbar_kwargs["label"] = _data["hue_label"] - _add_colorbar(primitive, ax, cbar_ax, cbar_kwargs, cmap_params) + _add_labels( + add_labels, to_labels.values(), to_suffix.values(), (True, False, False), ax + ) return primitive -# MUST run before any 2d plotting functions are defined since -# _plot2d decorator adds them as methods here. -class _PlotMethods: +# This function signature should not change so that it can use +# matplotlib format strings +@_plot1d +def line(xplt, yplt, *args, ax, add_labels=True, **kwargs): """ - Enables use of xarray.plot functions as attributes on a DataArray. - For example, DataArray.plot.imshow + Line plot of DataArray index against values + Wraps :func:`matplotlib:matplotlib.collections.LineCollection` """ + return _line_(xplt, yplt, *args, ax=ax, add_labels=add_labels, **kwargs) - __slots__ = ("_da",) - def __init__(self, darray): - self._da = darray +@_plot1d +def step(xplt, yplt, *args, ax, add_labels=True, **kwargs): + """ + Step plot of DataArray index against values + Wraps :func:`matplotlib:matplotlib.collections.LineCollection` + """ + kwargs.pop("drawstyle", None) + where = kwargs.pop("where", "pre") + kwargs.update(drawstyle="steps-" + where) + return _line_(xplt, yplt, *args, ax=ax, add_labels=add_labels, **kwargs) - def __call__(self, **kwargs): - return plot(self._da, **kwargs) - # we can't use functools.wraps here since that also modifies the name / qualname - __doc__ = __call__.__doc__ = plot.__doc__ - __call__.__wrapped__ = plot # type: ignore[attr-defined] - __call__.__annotations__ = plot.__annotations__ +@_plot1d +def hist(xplt, yplt, *args, ax, add_labels=True, **kwargs): + """ + Histogram of DataArray. - @functools.wraps(hist) - def hist(self, ax=None, **kwargs): - return hist(self._da, ax=ax, **kwargs) + Wraps :py:func:`matplotlib:matplotlib.pyplot.hist`. - @functools.wraps(line) - def line(self, *args, **kwargs): - return line(self._da, *args, **kwargs) + Plots *N*-dimensional arrays by first flattening the array. + """ + # plt = import_matplotlib_pyplot() - @functools.wraps(step) - def step(self, *args, **kwargs): - return step(self._da, *args, **kwargs) + zplt = kwargs.pop("zplt", None) + kwargs.pop("hueplt", None) + kwargs.pop("sizeplt", None) - @functools.wraps(scatter) - def _scatter(self, *args, **kwargs): - return scatter(self._da, *args, **kwargs) + kwargs.pop("vmin", None) + kwargs.pop("vmax", None) + kwargs.pop("norm", None) + kwargs.pop("cmap", None) + no_nan = np.ravel(yplt.to_numpy()) + no_nan = no_nan[pd.notnull(no_nan)] -def override_signature(f): - def wrapper(func): - func.__wrapped__ = f + # counts, bins = np.histogram(no_nan) + # n, bins, primitive = ax.hist(bins[:-1], bins, weights=counts, **kwargs) + n, bins, primitive = ax.hist(no_nan, **kwargs) - return func + _add_labels(add_labels, [yplt, xplt, zplt], ("", "", ""), (True, False, False), ax) - return wrapper + return primitive + + +# This function signature should not change so that it can use +# matplotlib format strings +@_plot1d +def scatter(xplt, yplt, *args, ax, add_labels=True, **kwargs): + plt = import_matplotlib_pyplot() + + zplt = kwargs.pop("zplt", None) + hueplt = kwargs.pop("hueplt", None) + sizeplt = kwargs.pop("sizeplt", None) + + # Add a white border to make it easier seeing overlapping markers: + kwargs.update(edgecolors=kwargs.pop("edgecolors", "w")) + + if hueplt is not None: + kwargs.update(c=hueplt.to_numpy().ravel()) + + if sizeplt is not None: + kwargs.update(s=sizeplt.to_numpy().ravel()) + + if Version(plt.matplotlib.__version__) < Version("3.5.0"): + # Plot the data. 3d plots has the z value in upward direction + # instead of y. To make jumping between 2d and 3d easy and intuitive + # switch the order so that z is shown in the depthwise direction: + axis_order = ["x", "z", "y"] + else: + # Switching axis order not needed in 3.5.0, can also simplify the code + # that uses axis_order: + # https://github.com/matplotlib/matplotlib/pull/19873 + axis_order = ["x", "y", "z"] + + plts = dict(x=xplt, y=yplt, z=zplt) + primitive = ax.scatter( + *[ + plts[v].to_numpy().ravel() + for v in axis_order + if plts.get(v, None) is not None + ], + **kwargs, + ) + + # Set x, y, z labels: + plts_ = [] + for v in axis_order: + arr = plts.get(f"{v}", None) + if arr is not None: + plts_.append(arr) + _add_labels(add_labels, plts_, ("", "", ""), (True, False, False), ax) + + return primitive def _plot2d(plotfunc): diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index aef21f0be42..10594bdfdb8 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -5,7 +5,7 @@ import warnings from datetime import datetime from inspect import getfullargspec -from typing import Any, Iterable, Mapping +from typing import Any, Iterable, Mapping, Sequence import numpy as np import pandas as pd @@ -30,6 +30,10 @@ ROBUST_PERCENTILE = 2.0 +# copied from seaborn +_MARKERSIZE_RANGE = np.array([18.0, 72.0]) +_LINEWIDTH_RANGE = np.array([1.5, 6.0]) + def import_matplotlib_pyplot(): """import pyplot""" @@ -393,6 +397,7 @@ def _infer_xy_labels(darray, x, y, imshow=False, rgb=None): return x, y +# TODO: Can by used to more than x or y, rename? def _assert_valid_xy(darray, xy, name): """ make sure x and y passed to plotting functions are valid @@ -407,9 +412,9 @@ def _assert_valid_xy(darray, xy, name): valid_xy = (set(darray.dims) | set(darray.coords)) - multiindex_dims - if xy not in valid_xy: + if (xy is not None) and (xy not in valid_xy): valid_xy_str = "', '".join(sorted(valid_xy)) - raise ValueError(f"{name} must be one of None, '{valid_xy_str}'") + raise ValueError(f"{name} must be one of None, '{valid_xy_str}', got '{xy}'.") def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs): @@ -530,8 +535,8 @@ def _interval_to_double_bound_points(xarray, yarray): xarray1 = np.array([x.left for x in xarray]) xarray2 = np.array([x.right for x in xarray]) - xarray = list(itertools.chain.from_iterable(zip(xarray1, xarray2))) - yarray = list(itertools.chain.from_iterable(zip(yarray, yarray))) + xarray = np.array(list(itertools.chain.from_iterable(zip(xarray1, xarray2)))) + yarray = np.array(list(itertools.chain.from_iterable(zip(yarray, yarray)))) return xarray, yarray @@ -1005,7 +1010,10 @@ def _get_color_and_size(value): return self.cmap(self.norm(value)), _size elif prop == "sizes": - arr = self.get_sizes() + if isinstance(self, mpl.collections.LineCollection): + arr = self.get_linewidths() + else: + arr = self.get_sizes() _color = kwargs.pop("color", "k") def _get_color_and_size(value): @@ -1094,22 +1102,29 @@ def _get_color_and_size(value): for val, lab in zip(values, label_values): color, size = _get_color_and_size(val) - h = mlines.Line2D( - [0], [0], ls="", color=color, ms=size, marker=self.get_paths()[0], **kw - ) + + if isinstance(self, mpl.collections.PathCollection): + kw.update(linestyle="", marker=self.get_paths()[0], markersize=size) + elif isinstance(self, mpl.collections.LineCollection): + kw.update(linestyle=self.get_linestyle()[0], linewidth=size) + + h = mlines.Line2D([0], [0], color=color, **kw) + handles.append(h) labels.append(fmt(lab)) return handles, labels -def _legend_add_subtitle(handles, labels, text, func): +def _legend_add_subtitle(handles, labels, text, ax): """Add a subtitle to legend handles.""" + plt = import_matplotlib_pyplot() + if text and len(handles) > 1: # Create a blank handle that's not visible, the # invisibillity will be used to discern which are subtitles # or not: - blank_handle = func([], [], label=text) + blank_handle = plt.Line2D([], [], label=text) blank_handle.set_visible(False) # Subtitles are shown first: @@ -1126,8 +1141,13 @@ def _adjust_legend_subtitles(legend): # Legend title not in rcParams until 3.0 font_size = plt.rcParams.get("legend.title_fontsize", None) hpackers = legend.findobj(plt.matplotlib.offsetbox.VPacker)[0].get_children() + hpackers = [v for v in hpackers if isinstance(v, plt.matplotlib.offsetbox.HPacker)] for hpack in hpackers: - draw_area, text_area = hpack.get_children() + areas = hpack.get_children() + if len(areas) < 2: + continue + draw_area, text_area = areas + handles = draw_area.get_children() # Assume that all artists that are not visible are @@ -1141,3 +1161,582 @@ def _adjust_legend_subtitles(legend): # The sutbtitles should have the same font size # as normal legend titles: text.set_size(font_size) + + +# %% +class _Normalize(Sequence): + """ + Normalize numerical or categorical values to numerical values. + + The class includes helper methods that simplifies transforming to + and from normalized values. + + Parameters + ---------- + data : DataArray + DataArray to normalize. + width : Sequence of two numbers, optional + Normalize the data to theses min and max values. + The default is None. + """ + + __slots__ = ( + "_data", + "_data_is_numeric", + "_width", + "_unique", + "_unique_index", + "_unique_inverse", + "plt", + ) + + def __init__(self, data, width=None, _is_facetgrid=False): + self._data = data + self._width = width if not _is_facetgrid else None + self.plt = import_matplotlib_pyplot() + + pint_array_type = DuckArrayModule("pint").type + to_unique = data.to_numpy() if isinstance(self._type, pint_array_type) else data + unique, unique_inverse = np.unique(to_unique, return_inverse=True) + self._unique = unique + self._unique_index = np.arange(0, unique.size) + if data is not None: + self._unique_inverse = data.copy(data=unique_inverse.reshape(data.shape)) + self._data_is_numeric = _is_numeric(data) + else: + self._unique_inverse = unique_inverse + self._data_is_numeric = False + + def __repr__(self): + with np.printoptions(precision=4, suppress=True, threshold=5): + return ( + f"<_Normalize(data, width={self._width})>\n" + f"{self._unique} -> {self.values_unique}" + ) + + def __len__(self): + return len(self._unique) + + def __getitem__(self, key): + return self._unique[key] + + @property + def _type(self): + data = self.data + return data.data if data is not None else data + + @property + def data(self): + return self._data + + @property + def data_is_numeric(self) -> bool: + """ + Check if data is numeric. + + Examples + -------- + >>> a = xr.DataArray(["b", "a", "a", "b", "c"]) + >>> _Normalize(a).data_is_numeric + False + """ + return self._data_is_numeric + + def _calc_widths(self, y): + if self._width is None or y is None: + return y + + x0, x1 = self._width + + k = (y - np.min(y)) / (np.max(y) - np.min(y)) + widths = x0 + k * (x1 - x0) + + return widths + + def _indexes_centered(self, x): + """ + Offset indexes to make sure being in the center of self.levels. + ["a", "b", "c"] -> [1, 3, 5] + """ + if self.data is None: + return None + else: + return x * 2 + 1 + + @property + def values(self): + """ + Return a normalized number array for the unique levels. + + Examples + -------- + >>> a = xr.DataArray(["b", "a", "a", "b", "c"]) + >>> _Normalize(a).values + + array([3, 1, 1, 3, 5]) + Dimensions without coordinates: dim_0 + + >>> _Normalize(a, width=[18, 72]).values + + array([45., 18., 18., 45., 72.]) + Dimensions without coordinates: dim_0 + + >>> a = xr.DataArray([0.5, 0, 0, 0.5, 2, 3]) + >>> _Normalize(a).values + + array([0.5, 0. , 0. , 0.5, 2. , 3. ]) + Dimensions without coordinates: dim_0 + + >>> _Normalize(a, width=[18, 72]).values + + array([27., 18., 18., 27., 54., 72.]) + Dimensions without coordinates: dim_0 + """ + return self._calc_widths( + self.data + if self.data_is_numeric + else self._indexes_centered(self._unique_inverse) + ) + + def _integers(self): + """ + Return integers. + ["a", "b", "c"] -> [1, 3, 5] + """ + return self._indexes_centered(self._unique_index) + + @property + def values_unique(self): + """ + Return unique values. + + Examples + -------- + >>> a = xr.DataArray(["b", "a", "a", "b", "c"]) + >>> _Normalize(a).values_unique + array([1, 3, 5]) + >>> a = xr.DataArray([2, 1, 1, 2, 3]) + >>> _Normalize(a).values_unique + array([1, 2, 3]) + >>> _Normalize(a, width=[18, 72]).values_unique + array([18., 45., 72.]) + """ + return ( + self._integers() + if not self.data_is_numeric + else self._calc_widths(self._unique) + ) + + @property + def ticks(self): + """ + Return ticks for plt.colorbar if the data is not numeric. + + Examples + -------- + >>> a = xr.DataArray(["b", "a", "a", "b", "c"]) + >>> _Normalize(a).ticks + array([1, 3, 5]) + """ + return self._integers() if not self.data_is_numeric else None + + @property + def levels(self): + """ + Return discrete levels that will evenly bound self.values. + ["a", "b", "c"] -> [0, 2, 4, 6] + + Examples + -------- + >>> a = xr.DataArray(["b", "a", "a", "b", "c"]) + >>> _Normalize(a).levels + array([0, 2, 4, 6]) + """ + return np.append(self._unique_index, np.max(self._unique_index) + 1) * 2 + + @property + def _lookup(self) -> pd.Series: + return pd.Series(dict(zip(self.values_unique, self._unique))) + + def _lookup_arr(self, x) -> np.ndarray: + # Use reindex to be less sensitive to float errors. reindex only + # works with sorted index. + # Return as numpy array since legend_elements + # seems to require that: + return self._lookup.sort_index().reindex(x, method="nearest").to_numpy() + + @property + def format(self): + """ + Return a FuncFormatter that maps self.values elements back to + the original value as a string. Useful with plt.colorbar. + + Examples + -------- + >>> a = xr.DataArray([0.5, 0, 0, 0.5, 2, 3]) + >>> aa = _Normalize(a, width=[0, 1]) + >>> aa._lookup + 0.000000 0.0 + 0.166667 0.5 + 0.666667 2.0 + 1.000000 3.0 + dtype: float64 + >>> aa.format(1) + '3.0' + """ + return self.plt.FuncFormatter(lambda x, pos=None: f"{self._lookup_arr([x])[0]}") + + @property + def func(self): + """ + Return a lambda function that maps self.values elements back to + the original value as a numpy array. Useful with ax.legend_elements. + + Examples + -------- + >>> a = xr.DataArray([0.5, 0, 0, 0.5, 2, 3]) + >>> aa = _Normalize(a, width=[0, 1]) + >>> aa._lookup + 0.000000 0.0 + 0.166667 0.5 + 0.666667 2.0 + 1.000000 3.0 + dtype: float64 + >>> aa.func([0.16, 1]) + array([0.5, 3. ]) + """ + return lambda x, pos=None: self._lookup_arr(x) + + +def _determine_guide( + hueplt_norm, + sizeplt_norm, + add_colorbar=None, + add_legend=None, + plotfunc_name: str = None, +): + if plotfunc_name == "hist": + return False, False + + if (add_colorbar) and hueplt_norm.data is None: + raise KeyError("Cannot create a colorbar when hue is None.") + if add_colorbar is None: + if hueplt_norm.data is not None: + add_colorbar = True + else: + add_colorbar = False + + if (add_legend) and hueplt_norm.data is None and sizeplt_norm.data is None: + raise KeyError("Cannot create a legend when hue and markersize is None.") + if add_legend is None: + if ( + not add_colorbar + and (hueplt_norm.data is not None and hueplt_norm.data_is_numeric is False) + or sizeplt_norm.data is not None + ): + add_legend = True + else: + add_legend = False + + return add_colorbar, add_legend + + +def _add_legend( + hueplt_norm: _Normalize, + sizeplt_norm: _Normalize, + primitive, + ax, + legend_ax, + plotfunc: str, +): + + primitive = primitive if isinstance(primitive, list) else [primitive] + + handles, labels = [], [] + for huesizeplt, prop in [ + (hueplt_norm, "colors"), + (sizeplt_norm, "sizes"), + ]: + if huesizeplt.data is not None: + # Get legend handles and labels that displays the + # values correctly. Order might be different because + # legend_elements uses np.unique instead of pd.unique, + # FacetGrid.add_legend might have troubles with this: + hdl, lbl = [], [] + for p in primitive: + hdl_, lbl_ = legend_elements(p, prop, num="auto", func=huesizeplt.func) + hdl += hdl_ + lbl += lbl_ + + # Only save unique values: + u, ind = np.unique(lbl, return_index=True) + ind = np.argsort(ind) + lbl = u[ind].tolist() + hdl = np.array(hdl)[ind].tolist() + + # Add a subtitle: + hdl, lbl = _legend_add_subtitle( + hdl, lbl, label_from_attrs(huesizeplt.data), ax + ) + handles += hdl + labels += lbl + legend = legend_ax.legend(handles, labels, framealpha=0.5) + _adjust_legend_subtitles(legend) + + return legend + + +def _infer_meta_data(ds, x, y, hue, hue_style, add_guide, funcname): + dvars = set(ds.variables.keys()) + + error_msg = f" must be one of ({', '.join(dvars)})" + + if x not in dvars: + raise ValueError("x" + error_msg + f", got {x}") + + if y not in dvars: + raise ValueError("y" + error_msg + f", got {y}") + + if hue is not None and hue not in dvars: + raise ValueError("hue" + error_msg + f", got {hue}") + + if hue: + hue_is_numeric = _is_numeric(ds[hue].values) + + if hue_style is None: + hue_style = "continuous" if hue_is_numeric else "discrete" + + if not hue_is_numeric and (hue_style == "continuous"): + raise ValueError( + f"Cannot create a colorbar for a non numeric coordinate: {hue}" + ) + + if add_guide is None or add_guide is True: + add_colorbar = True if hue_style == "continuous" else False + add_legend = True if hue_style == "discrete" else False + else: + add_colorbar = False + add_legend = False + else: + if add_guide is True and funcname not in ("quiver", "streamplot"): + raise ValueError("Cannot set add_guide when hue is None.") + add_legend = False + add_colorbar = False + + if (add_guide or add_guide is None) and funcname == "quiver": + add_quiverkey = True + if hue: + add_colorbar = True + if not hue_style: + hue_style = "continuous" + elif hue_style != "continuous": + raise ValueError( + "hue_style must be 'continuous' or None for .plot.quiver or " + ".plot.streamplot" + ) + else: + add_quiverkey = False + + if (add_guide or add_guide is None) and funcname == "streamplot": + if hue: + add_colorbar = True + if not hue_style: + hue_style = "continuous" + elif hue_style != "continuous": + raise ValueError( + "hue_style must be 'continuous' or None for .plot.quiver or " + ".plot.streamplot" + ) + + if hue_style is not None and hue_style not in ["discrete", "continuous"]: + raise ValueError("hue_style must be either None, 'discrete' or 'continuous'.") + + if hue: + hue_label = label_from_attrs(ds[hue]) + hue = ds[hue] + else: + hue_label = None + hue = None + + return { + "add_colorbar": add_colorbar, + "add_legend": add_legend, + "add_quiverkey": add_quiverkey, + "hue_label": hue_label, + "hue_style": hue_style, + "xlabel": label_from_attrs(ds[x]), + "ylabel": label_from_attrs(ds[y]), + "hue": hue, + } + + +# copied from seaborn +def _parse_size(data, norm, width): + """ + Determine what type of data it is. Then normalize it to width. + + If the data is categorical, normalize it to numbers. + """ + plt = import_matplotlib_pyplot() + + if data is None: + return None + + data = data.values.ravel() + + if not _is_numeric(data): + # Data is categorical. + # Use pd.unique instead of np.unique because that keeps + # the order of the labels: + levels = pd.unique(data) + numbers = np.arange(1, 1 + len(levels)) + else: + levels = numbers = np.sort(np.unique(data)) + + min_width, max_width = width + # width_range = min_width, max_width + + if norm is None: + norm = plt.Normalize() + elif isinstance(norm, tuple): + norm = plt.Normalize(*norm) + elif not isinstance(norm, plt.Normalize): + err = "``size_norm`` must be None, tuple, or Normalize object." + raise ValueError(err) + + norm.clip = True + if not norm.scaled(): + norm(np.asarray(numbers)) + # limits = norm.vmin, norm.vmax + + scl = norm(numbers) + widths = np.asarray(min_width + scl * (max_width - min_width)) + if scl.mask.any(): + widths[scl.mask] = 0 + sizes = dict(zip(levels, widths)) + + return pd.Series(sizes) + return pd.Series(sizes) + + +def _line( + self, + x, + y, + s=None, + c=None, + linestyle=None, + cmap=None, + norm=None, + vmin=None, + vmax=None, + alpha=None, + linewidths=None, + *, + edgecolors=None, + plotnonfinite=False, + **kwargs, +): + """ + ax.scatter-like wrapper for LineCollection. + + This function helps the handling of datetimes since Linecollection doesn't + support it directly, just like PatchCollection doesn't either. + + """ + plt = import_matplotlib_pyplot() + rcParams = plt.matplotlib.rcParams + + # Handle z inputs: + z = kwargs.pop("z", None) + if z is not None: + from mpl_toolkits.mplot3d.art3d import Line3DCollection + + LineCollection_ = Line3DCollection + add_collection_ = self.add_collection3d + auto_scale = self.auto_scale_xyz + auto_scale_args = (x, y, z, self.has_data()) + else: + LineCollection_ = plt.matplotlib.collections.LineCollection + add_collection_ = self.add_collection + auto_scale = self._request_autoscale_view + auto_scale_args = tuple() + + # Process **kwargs to handle aliases, conflicts with explicit kwargs: + x, y = self._process_unit_info([("x", x), ("y", y)], kwargs) + + if s is None: + s = np.array([rcParams["lines.linewidth"]]) + # s = np.ma.ravel(s) + if len(s) not in (1, x.size) or ( + not np.issubdtype(s.dtype, np.floating) + and not np.issubdtype(s.dtype, np.integer) + ): + raise ValueError( + "s must be a scalar, " "or float array-like with the same size as x and y" + ) + + edgecolors or kwargs.get("edgecolor", None) + c, colors, edgecolors = self._parse_scatter_color_args( + c, + edgecolors, + kwargs, + x.size, + get_next_color_func=self._get_patches_for_fill.get_next_color, + ) + + # load default linestyle from rcParams + if linestyle is None: + linestyle = rcParams["lines.linestyle"] + + drawstyle = kwargs.pop("drawstyle", "default") + if drawstyle == "default": + # Draw linear lines: + xyz = list(v for v in (x, y, z) if v is not None) + else: + # Draw stepwise lines: + from matplotlib.cbook import STEP_LOOKUP_MAP + + step_func = STEP_LOOKUP_MAP[drawstyle] + xyz = step_func(*tuple(v for v in (x, y, z) if v is not None)) + + # Create steps by repeating all elements, then roll the last array by 1: + # Might be scary duplicating number of elements? + # xyz = list(np.repeat(v, 2) for v in (x, y, z) if v is not None) + # c = np.repeat(c, 2) # TODO: Off by one? + # s = np.repeat(s, 2) + # if drawstyle == "steps-pre": + # xyz[-1][:-1] = xyz[-1][1:] + # elif drawstyle == "steps-post": + # xyz[-1][1:] = xyz[-1][:-1] + # else: + # raise NotImplementedError( + # f"Allowed values are: 'default', 'steps-pre', 'steps-post', got {drawstyle}." + # ) + + # Broadcast arrays to correct format: + # https://stackoverflow.com/questions/42215777/matplotlib-line-color-in-3d + points = np.stack(np.broadcast_arrays(*xyz), axis=-1).reshape(-1, 1, len(xyz)) + segments = np.concatenate([points[:-1], points[1:]], axis=1) + + collection = LineCollection_( + segments, + linewidths=s, + linestyles="solid", + ) + # collection.set_transform(plt.matplotlib.transforms.IdentityTransform()) + collection.update(kwargs) + + if colors is None: + collection.set_array(c) + collection.set_cmap(cmap) + collection.set_norm(norm) + collection._scale_norm(norm, vmin, vmax) + + add_collection_(collection) + + # self._request_autoscale_view() + # self.autoscale_view() + auto_scale(*auto_scale_args) + + return collection diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 27e48c27ae2..19d0c76bffd 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -782,20 +782,30 @@ def test_step_with_where(self, where): def test_coord_with_interval_step(self): """Test step plot with intervals.""" bins = [-1, 0, 1, 2] - self.darray.groupby_bins("dim_0", bins).mean(...).plot.step() - assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2) + lc = self.darray.groupby_bins("dim_0", bins).mean(...).plot.step() + expected = (len(bins) - 1) * 2 + actual = sum(v.shape[0] for v in lc.get_segments() if v.shape[0] > 1) + assert expected == actual def test_coord_with_interval_step_x(self): """Test step plot with intervals explicitly on x axis.""" bins = [-1, 0, 1, 2] - self.darray.groupby_bins("dim_0", bins).mean(...).plot.step(x="dim_0_bins") - assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2) + lc = self.darray.groupby_bins("dim_0", bins).mean(...).plot.step(x="dim_0_bins") + expected = (len(bins) - 1) * 2 + actual = sum(v.shape[0] for v in lc.get_segments() if v.shape[0] > 1) + assert expected == actual def test_coord_with_interval_step_y(self): """Test step plot with intervals explicitly on y axis.""" bins = [-1, 0, 1, 2] - self.darray.groupby_bins("dim_0", bins).mean(...).plot.step(y="dim_0_bins") - assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2) + arr = self.darray.groupby_bins("dim_0", bins).mean(...) + lc = arr.plot.step(y="dim_0_bins") + # TODO: Test and make sure data is plotted on the correct axis: + x = np.array([v[0, 0] for v in lc.get_segments() if v.shape[0] > 1]) + y = np.array([v[1, 1] for v in lc.get_segments() if v.shape[0] > 1]) + expected = len(bins) - 1 + actual = sum(v.shape[0] for v in lc.get_segments() if v.shape[0] > 1) + assert expected == actual class TestPlotHistogram(PlotTestCase): @@ -827,7 +837,7 @@ def test_can_pass_in_axis(self): def test_primitive_returned(self): h = self.darray.plot.hist() - assert isinstance(h[-1][0], mpl.patches.Rectangle) + assert isinstance(h[0], mpl.patches.Rectangle) @pytest.mark.slow def test_plot_nans(self): @@ -2292,10 +2302,18 @@ def setUp(self): self.darray = xr.tutorial.scatter_example_dataset() def test_legend_labels(self): - fg = self.darray.A.plot.line(col="x", row="w", hue="z") + fg = self.darray.A.plot.line(col="x", row="w", hue="z", linewidth="z") all_legend_labels = [t.get_text() for t in fg.figlegend.texts] # labels in legend should be ['0', '1', '2', '3'] - assert sorted(all_legend_labels) == ["0", "1", "2", "3"] + # assert sorted(all_legend_labels) == ["0", "1", "2", "3", "z [zunits]"] + actual = [ + "z [zunits]", + "$\\mathdefault{0}$", + "$\\mathdefault{1}$", + "$\\mathdefault{2}$", + "$\\mathdefault{3}$", + ] + assert all_legend_labels == actual @pytest.mark.filterwarnings("ignore:tight_layout cannot") @@ -2322,14 +2340,14 @@ def test_facetgrid_shape(self): g = self.darray.plot(row="col", col="row", hue="hue") assert g.axes.shape == (len(self.darray.col), len(self.darray.row)) - def test_unnamed_args(self): - g = self.darray.plot.line("o--", row="row", col="col", hue="hue") - lines = [ - q for q in g.axes.flat[0].get_children() if isinstance(q, mpl.lines.Line2D) - ] - # passing 'o--' as argument should set marker and linestyle - assert lines[0].get_marker() == "o" - assert lines[0].get_linestyle() == "--" + # def test_unnamed_args(self): + # g = self.darray.plot.line("o--", row="row", col="col", hue="hue") + # lines = [ + # q for q in g.axes.flat[0].get_children() if isinstance(q, mpl.lines.Line2D) + # ] + # # passing 'o--' as argument should set marker and linestyle + # assert lines[0].get_marker() == "o" + # assert lines[0].get_linestyle() == "--" def test_default_labels(self): g = self.darray.plot(row="row", col="col", hue="hue") @@ -2371,10 +2389,10 @@ def test_figsize_and_size(self): with pytest.raises(ValueError): self.darray.plot.line(row="row", col="col", x="x", size=3, figsize=4) - def test_wrong_num_of_dimensions(self): - with pytest.raises(ValueError): - self.darray.plot(row="row", hue="hue") - self.darray.plot.line(row="row", hue="hue") + # def test_wrong_num_of_dimensions(self): + # with pytest.raises(ValueError): + # self.darray.plot(row="row", hue="hue") + # # self.darray.plot.line(row="row", hue="hue") @requires_matplotlib @@ -2436,6 +2454,29 @@ def test_facetgrid(self): with pytest.raises(ValueError, match=r"Please provide scale"): self.ds.plot.quiver(x="x", y="y", u="u", v="v", row="row", col="col") + @pytest.mark.parametrize( + "add_guide, hue_style, legend, colorbar", + [ + (None, None, False, True), + (False, None, False, False), + (True, None, False, True), + (True, "continuous", False, True), + ], + ) + def test_add_guide(self, add_guide, hue_style, legend, colorbar): + + meta_data = _infer_meta_data( + self.ds, + x="x", + y="y", + hue="mag", + hue_style=hue_style, + add_guide=add_guide, + funcname="quiver", + ) + assert meta_data["add_legend"] is legend + assert meta_data["add_colorbar"] is colorbar + @requires_matplotlib class TestDatasetStreamplotPlots(PlotTestCase): @@ -2520,31 +2561,6 @@ def test_accessor(self): assert Dataset.plot is _Dataset_PlotMethods assert isinstance(self.ds.plot, _Dataset_PlotMethods) - @pytest.mark.parametrize( - "add_guide, hue_style, legend, colorbar", - [ - (None, None, False, True), - (False, None, False, False), - (True, None, False, True), - (True, "continuous", False, True), - (False, "discrete", False, False), - (True, "discrete", True, False), - ], - ) - def test_add_guide(self, add_guide, hue_style, legend, colorbar): - - meta_data = _infer_meta_data( - self.ds, - x="A", - y="B", - hue="hue", - hue_style=hue_style, - add_guide=add_guide, - funcname="scatter", - ) - assert meta_data["add_legend"] is legend - assert meta_data["add_colorbar"] is colorbar - def test_facetgrid_shape(self): g = self.ds.plot.scatter(x="A", y="B", row="row", col="col") assert g.axes.shape == (len(self.ds.row), len(self.ds.col)) @@ -2576,18 +2592,17 @@ def test_figsize_and_size(self): self.ds.plot.scatter(x="A", y="B", row="row", size=3, figsize=4) @pytest.mark.parametrize( - "x, y, hue_style, add_guide", + "x, y, hue, add_legend, add_colorbar, error_type", [ - ("A", "B", "something", True), - ("A", "B", "discrete", True), - ("A", "B", None, True), - ("A", "The Spanish Inquisition", None, None), - ("The Spanish Inquisition", "B", None, True), + ("A", "The Spanish Inquisition", None, None, None, KeyError), + ("The Spanish Inquisition", "B", None, None, True, ValueError), ], ) - def test_bad_args(self, x, y, hue_style, add_guide): - with pytest.raises(ValueError): - self.ds.plot.scatter(x, y, hue_style=hue_style, add_guide=add_guide) + def test_bad_args(self, x, y, hue, add_legend, add_colorbar, error_type): + with pytest.raises(error_type): + self.ds.plot.scatter( + x=x, y=y, hue=hue, add_legend=add_legend, add_colorbar=add_colorbar + ) @pytest.mark.xfail(reason="datetime,timedelta hue variable not supported.") @pytest.mark.parametrize("hue_style", ["discrete", "continuous"]) @@ -2602,53 +2617,60 @@ def test_datetime_hue(self, hue_style): def test_facetgrid_hue_style(self): # Can't move this to pytest.mark.parametrize because py37-bare-minimum # doesn't have matplotlib. - for hue_style, map_type in ( - ("discrete", list), - ("continuous", mpl.collections.PathCollection), - ): + for hue_style in ("discrete", "continuous"): g = self.ds.plot.scatter( x="A", y="B", row="row", col="col", hue="hue", hue_style=hue_style ) - # for 'discrete' a list is appended to _mappables - # for 'continuous', should be single PathCollection - assert isinstance(g._mappables[-1], map_type) + # 'discrete' and 'continuous', should be single PathCollection + assert isinstance(g._mappables[-1], mpl.collections.PathCollection) @pytest.mark.parametrize( "x, y, hue, markersize", [("A", "B", "x", "col"), ("x", "row", "A", "B")] ) def test_scatter(self, x, y, hue, markersize): - self.ds.plot.scatter(x, y, hue=hue, markersize=markersize) + self.ds.plot.scatter(x=x, y=y, hue=hue, markersize=markersize) - with pytest.raises(ValueError, match=r"u, v"): - self.ds.plot.scatter(x, y, u="col", v="row") + # with pytest.raises(ValueError, match=r"u, v"): + # self.ds.plot.scatter(x, y, u="col", v="row") def test_non_numeric_legend(self): ds2 = self.ds.copy() ds2["hue"] = ["a", "b", "c", "d"] - lines = ds2.plot.scatter(x="A", y="B", hue="hue") + pc = ds2.plot.scatter(x="A", y="B", hue="hue") # should make a discrete legend - assert lines[0].axes.legend_ is not None - # and raise an error if explicitly not allowed to do so - with pytest.raises(ValueError): - ds2.plot.scatter(x="A", y="B", hue="hue", hue_style="continuous") + assert pc.axes.legend_ is not None + # # and raise an error if explicitly not allowed to do so + # with pytest.raises(ValueError): + # ds2.plot.scatter(x="A", y="B", hue="hue", hue_style="continuous") def test_legend_labels(self): # regression test for #4126: incorrect legend labels ds2 = self.ds.copy() ds2["hue"] = ["a", "a", "b", "b"] - lines = ds2.plot.scatter(x="A", y="B", hue="hue") - assert [t.get_text() for t in lines[0].axes.get_legend().texts] == ["a", "b"] + pc = ds2.plot.scatter(x="A", y="B", hue="hue") + actual = [t.get_text() for t in pc.axes.get_legend().texts] + expected = [ + "col [colunits]", + "$\\mathdefault{0}$", + "$\\mathdefault{1}$", + "$\\mathdefault{2}$", + "$\\mathdefault{3}$", + ] + assert actual == expected def test_legend_labels_facetgrid(self): ds2 = self.ds.copy() ds2["hue"] = ["d", "a", "c", "b"] - g = ds2.plot.scatter(x="A", y="B", hue="hue", col="col") - legend_labels = tuple(t.get_text() for t in g.figlegend.texts) - attached_labels = [ - tuple(m.get_label() for m in mappables_per_ax) - for mappables_per_ax in g._mappables - ] - assert list(set(attached_labels)) == [legend_labels] + # g = ds2.plot.scatter(x="A", y="B", hue="hue", col="col") # cateorgical colorbars work now, so hue isn't shown in legend. + g = ds2.plot.scatter(x="A", y="B", hue="hue", markersize="x", col="col") + actual = tuple(t.get_text() for t in g.figlegend.texts) + expected = ( + "x [xunits]", + "$\\mathdefault{0}$", + "$\\mathdefault{1}$", + "$\\mathdefault{2}$", + ) + assert actual == expected def test_add_legend_by_default(self): sc = self.ds.plot.scatter(x="A", y="B", hue="hue") @@ -2685,7 +2707,7 @@ def test_datetime_units(self): def test_datetime_plot1d(self): # Test that matplotlib-native datetime works: p = self.darray.plot.line() - ax = p[0].axes + ax = p.axes # Make sure only mpl converters are used, use type() so only # mpl.dates.AutoDateLocator passes and no other subclasses: @@ -2989,7 +3011,7 @@ def test_datarray_scatter(x, y, z, hue, markersize, row, col, add_legend, add_co darray = xr.DataArray(ds[y], coords=coords) with figure_context(): - darray.plot._scatter( + darray.plot.scatter( x=x, z=z, hue=hue,