From 8d71cc58c5da81aa6ddd6fb96b4447a270fbdfa9 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 14 Feb 2021 12:13:43 +0100 Subject: [PATCH 01/46] add dataarray scatter --- xarray/plot/plot.py | 368 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 368 insertions(+) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 75fa786ecc5..ad5605f1871 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -26,8 +26,175 @@ get_axis, import_matplotlib_pyplot, label_from_attrs, + _is_numeric, ) +from ..core.alignment import broadcast + +# copied from seaborn +_MARKERSIZE_RANGE = np.array([18.0, 72.0]) + + +def _infer_meta_data(darray, x, y, hue, hue_style, size, add_guide): + def _hue_calc(darray, hue, hue_style, add_guide): + """Something.""" + hue_is_numeric = _is_numeric(darray[hue].values) + + if hue_style is None: + hue_style = "continuous" if hue_is_numeric else "discrete" + elif hue_style not in ["discrete", "continuous"]: + raise ValueError( + "hue_style must be either None, 'discrete' or 'continuous'." 
+ ) + + if not hue_is_numeric and (hue_style == "continuous"): + raise ValueError( + f"Cannot create a colorbar for a non numeric coordinate: {hue}" + ) + + hue_label = label_from_attrs(darray[hue]) + hue = darray[hue] + + # Handle colorbar and legend: + if add_guide is None or add_guide is True: + add_colorbar = True if hue_style == "continuous" else False + add_legend = True if hue_style == "discrete" else False + else: + add_colorbar = False + add_legend = False + + return hue, hue_style, hue_label, add_colorbar, add_legend + + if x is not None and y is not None: + raise ValueError("Cannot specify both x and y kwargs for line plots.") + + if x is not None: + xplt = darray[x] + yplt = darray + elif y is not None: + xplt = darray + yplt = darray[y] + xlabel = label_from_attrs(xplt) + ylabel = label_from_attrs(yplt) + + if hue: + hue, hue_style, hue_label, add_colorbar, add_legend = _hue_calc( + darray, hue, hue_style, add_guide + ) + else: + # Try finding a hue: + _, hue = _infer_xy_labels(darray=yplt, x=xplt.name, y=hue) + + if hue: + hue, hue_style, hue_label, add_colorbar, add_legend = _hue_calc( + darray, hue, hue_style, add_guide + ) + else: + if add_guide is True: + raise ValueError("Cannot set add_guide when hue is None.") + add_legend = False + add_colorbar = False + + hue_label = None + hue = None + + if size: + size_label = label_from_attrs(darray[size]) + size = darray[size] + else: + size_label = None + size = None + + return dict( + add_colorbar=add_colorbar, + add_legend=add_legend, + hue_label=hue_label, + hue_style=hue_style, + xlabel=xlabel, + ylabel=ylabel, + hue=hue, + size=size, + size_label=size_label, + ) + + +# copied from seaborn +def _parse_size(data, norm, width): + + import matplotlib as mpl + + if data is None: + return None + + data = data.values.flatten() + + if not _is_numeric(data): + levels = np.unique(data) + numbers = np.arange(1, 1 + len(levels))[::-1] + else: + levels = numbers = np.sort(np.unique(data)) + + min_width, 
max_width = width + # width_range = min_width, max_width + + if norm is None: + norm = mpl.colors.Normalize() + elif isinstance(norm, tuple): + norm = mpl.colors.Normalize(*norm) + elif not isinstance(norm, mpl.colors.Normalize): + err = "``size_norm`` must be None, tuple, or Normalize object." + raise ValueError(err) + + norm.clip = True + if not norm.scaled(): + norm(np.asarray(numbers)) + # limits = norm.vmin, norm.vmax + + scl = norm(numbers) + widths = np.asarray(min_width + scl * (max_width - min_width)) + if scl.mask.any(): + widths[scl.mask] = 0 + sizes = dict(zip(levels, widths)) + + return pd.Series(sizes) + + +def _infer_scatter_data( + darray, x, y, hue, size, size_norm, size_mapping=None, size_range=(1, 10) +): + if x is not None and y is not None: + raise ValueError("Cannot specify both x and y kwargs for scatter plots.") + + # Broadcast together all the chosen variables: + if x is not None: + to_broadcast = dict(x=darray[x], y=darray) + elif y is not None: + to_broadcast = dict(x=darray, y=darray[y]) + to_broadcast.update( + { + key: darray[value] + for key, value in dict(hue=hue, size=size).items() + if value in darray.dims + } + ) + broadcasted = dict(zip(to_broadcast.keys(), broadcast(*(to_broadcast.values())))) + + data = dict(x=broadcasted["x"], y=broadcasted["y"], sizes=None) + data.update(hue=broadcasted.get("hue", None)) + + if size: + size = broadcasted["size"] + + if size_mapping is None: + size_mapping = _parse_size(size, size_norm, size_range) + + data["sizes"] = size.copy( + data=np.reshape(size_mapping.loc[size.values.ravel()].values, size.shape) + ) + data["sizes_to_labels"] = pd.Series(size_mapping.index, index=size_mapping) + + return data + def _infer_line_data(darray, x, y, hue): @@ -427,6 +594,204 @@ def hist( return primitive +def scatter( + darray, + *args, + row=None, + col=None, + figsize=None, + aspect=None, + size=None, + ax=None, + hue=None, + hue_style=None, + x=None, + y=None, + xincrease=None, + yincrease=None, + 
xscale=None, + yscale=None, + xticks=None, + yticks=None, + xlim=None, + ylim=None, + add_legend=True, + _labels=True, + **kwargs, +): + # Handle facetgrids first + if row or col: + allargs = locals().copy() + allargs.update(allargs.pop("kwargs")) + allargs.pop("darray") + return _easy_facetgrid(darray, scatter, kind="dataarray", **allargs) + + # _is_facetgrid = kwargs.pop("_is_facetgrid", False) + _sizes = kwargs.pop("markersize", kwargs.pop("linewidth", None)) + add_guide = kwargs.pop("add_guide", None) + size_norm = kwargs.pop("size_norm", None) + size_mapping = kwargs.pop("size_mapping", None) # set by facetgrid + cbar_ax = kwargs.pop("cbar_ax", None) + + figsize = kwargs.pop("figsize", None) + ax = get_axis(figsize, size, aspect, ax) + + _data = _infer_meta_data(darray, x, y, hue, hue_style, _sizes, add_guide) + + # need to infer size_mapping with full dataset + _data.update( + _infer_scatter_data( + darray, + x, + y, + _data["hue_label"], + _sizes, + size_norm, + size_mapping, + _MARKERSIZE_RANGE, + ) + ) + + # Plot the data: + if _data["hue_style"] == "discrete": + # Plot discrete data. ax.scatter only supports numerical values + # in colors and sizes. Use a for loop to work around this issue. + + primitive = [] + # use pd.unique instead of np.unique because that keeps the order of the labels, + # which is important to keep them in sync with the ones used in + # FacetGrid.add_legend + for label in pd.unique(_data["hue"].values.ravel()): + mask = _data["hue"] == label + if _data["sizes"] is not None: + kwargs.update(s=_data["sizes"].where(mask, drop=True).values.ravel()) + + primitive.append( + ax.scatter( + _data["x"].where(mask, drop=True).values.ravel(), + _data["y"].where(mask, drop=True).values.ravel(), + label=label, + **kwargs, + ) + ) + elif _data["hue_label"] is None or _data["hue_style"] == "continuous": + # ax.scatter suppoerts numerical values in colors and sizes. + # So no need for for loops. 
+ + cmap_params_subset = {} + if _data["hue"] is not None: + kwargs.update(c=_data["hue"].values.ravel()) + + cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( + scatter, darray[_data["hue_label"]].values, **locals() + ) + + # subset that can be passed to scatter, hist2d + cmap_params_subset = { + vv: cmap_params[vv] for vv in ["vmin", "vmax", "norm", "cmap"] + } + + if _data["sizes"] is not None: + kwargs.update(s=_data["sizes"].values.ravel()) + + primitive = ax.scatter( + _data["x"].values.ravel(), + _data["y"].values.ravel(), + **cmap_params_subset, + **kwargs, + ) + + # Set x and y labels: + if _data["xlabel"]: + ax.set_xlabel(_data["xlabel"]) + if _data["ylabel"]: + ax.set_ylabel(_data["ylabel"]) + + def _legend_elements_from_list(primitives, prop, **kwargs): + """ + Get unique legend elements from a list of pathcollections. + + Getting multiple pathcollections happens when adding multiple + scatters to the same plot. + """ + import warnings + + handles = np.array([], dtype=object) + labels = np.array([], dtype=str) + + for i, pc in enumerate(primitives): + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + # Get legend elements, suppress empty data warnings + # because it will be handled later: + hdl, lbl = pc.legend_elements(prop=prop, **kwargs) + handles = np.append(handles, hdl) + labels = np.append(labels, lbl) + + # Duplicate legend entries is not necessary, therefore return + # unique labels: + unique_indices = np.sort(np.unique(labels, return_index=True)[1]) + handles = handles[unique_indices] + labels = labels[unique_indices] + + return [handles, labels] + + def _add_legend(primitives, prop, func, ax, title, loc): + """Add legend to axes.""" + # Get handles and labels to use in the legend: + handles, labels = _legend_elements_from_list( + primitives, prop, num="auto", alpha=0.6, func=func, + ) + + # title has to be a required check as otherwise the legend may + # display values that are not related to the data, such as 
the + # markersize value: + if title and len(handles) > 1: + # The normal case where a prop has been defined and + # legend_elements finds results: + legend = ax.legend(handles, labels, title=title, loc=loc) + ax.add_artist(legend) + elif title and len(primitives) > 1: + # For caases when + legend = ax.legend(handles=primitives, title=title, loc=loc) + ax.add_artist(legend) + + if _data["hue_style"] == "discrete": + primitives = primitive + else: + primitives = [primitive] + + if _data["add_legend"] and _data["hue_label"]: + _add_legend( + primitives, + prop="colors", + func=lambda x: x, + ax=ax, + title=_data["hue_label"], + loc="upper right", + ) + + if _data["add_legend"] and _data["size_label"]: + _add_legend( + primitives, + prop="sizes", + func=lambda x: _data["sizes_to_labels"][x] + if "sizes_to_labels" in _data + else x, + ax=ax, + title=_data["size_label"], + loc="upper left", + ) + + if _data["add_colorbar"] and _data["hue_label"]: + cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs + if "label" not in cbar_kwargs: + cbar_kwargs["label"] = _data["hue_label"] + _add_colorbar(primitive, ax, cbar_ax, cbar_kwargs, cmap_params) + + return primitive + + # MUST run before any 2d plotting functions are defined since # _plot2d decorator adds them as methods here. 
class _PlotMethods: @@ -460,6 +825,9 @@ def line(self, *args, **kwargs): def step(self, *args, **kwargs): return step(self._da, *args, **kwargs) + @functools.wraps(scatter) + def scatter(self, *args, **kwargs): + return scatter(self._da, *args, **kwargs) def override_signature(f): def wrapper(func): From 2dff1ce6624d31f95d07e18e90bb9b5a249166ad Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 14 Feb 2021 12:16:42 +0100 Subject: [PATCH 02/46] Update plot.py --- xarray/plot/plot.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index ad5605f1871..3e4a77bedc5 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -829,6 +829,7 @@ def step(self, *args, **kwargs): def scatter(self, *args, **kwargs): return scatter(self._da, *args, **kwargs) + def override_signature(f): def wrapper(func): func.__wrapped__ = f From 6837a382449e467b7605a345c75cc0e50a5777b6 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 14 Feb 2021 12:31:30 +0100 Subject: [PATCH 03/46] Update plot.py --- xarray/plot/plot.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 3e4a77bedc5..198c514ede8 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -11,6 +11,7 @@ import numpy as np import pandas as pd +from ..core.alignment import broadcast from .facetgrid import _easy_facetgrid from .utils import ( _add_colorbar, @@ -18,6 +19,7 @@ _ensure_plottable, _infer_interval_breaks, _infer_xy_labels, + _is_numeric, _process_cmap_cbar_kwargs, _rescale_imshow_rgb, _resolve_intervals_1dplot, @@ -26,11 +28,8 @@ get_axis, import_matplotlib_pyplot, label_from_attrs, - _is_numeric, ) -from ..core.alignment import broadcast - # copied from seaborn _MARKERSIZE_RANGE = np.array([18.0, 72.0]) From 1364e6389a604cc32ca57aa1098c8dfe2f4030e4 Mon Sep 17 00:00:00 2001 From: Illviljan 
<14371165+Illviljan@users.noreply.github.com> Date: Sun, 14 Feb 2021 12:35:31 +0100 Subject: [PATCH 04/46] Update plot.py --- xarray/plot/plot.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 198c514ede8..0d31fe4ddb0 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -739,7 +739,11 @@ def _add_legend(primitives, prop, func, ax, title, loc): """Add legend to axes.""" # Get handles and labels to use in the legend: handles, labels = _legend_elements_from_list( - primitives, prop, num="auto", alpha=0.6, func=func, + primitives, + prop, + num="auto", + alpha=0.6, + func=func, ) # title has to be a required check as otherwise the legend may From 66850f2ebc9dca648310b72ae84bc1e3245a32f9 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 14 Feb 2021 13:16:10 +0100 Subject: [PATCH 05/46] copy doc --- xarray/plot/plot.py | 82 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 0d31fe4ddb0..a71b045ac81 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -618,6 +618,88 @@ def scatter( _labels=True, **kwargs, ): + """ + Scatter plot a DataArray along some coordinates. + + Parameters + ---------- + + darray : DataArray + Dataarray to plot. + x, y : str + Variable names for x, y axis. + hue: str, optional + Variable by which to color scattered points + hue_style: str, optional + Can be either 'discrete' (legend) or 'continuous' (color bar). + markersize: str, optional + scatter only. Variable by which to vary size of scattered points. + size_norm: optional + Either None or 'Norm' instance to normalize the 'markersize' variable. + add_guide: bool, optional + Add a guide that depends on hue_style + - for "discrete", build a legend. + This is the default for non-numeric `hue` variables. 
+ - for "continuous", build a colorbar + row : str, optional + If passed, make row faceted plots on this dimension name + col : str, optional + If passed, make column faceted plots on this dimension name + col_wrap : int, optional + Use together with ``col`` to wrap faceted plots + ax : matplotlib axes object, optional + If None, uses the current axis. Not applicable when using facets. + subplot_kws : dict, optional + Dictionary of keyword arguments for matplotlib subplots. Only applies + to FacetGrid plotting. + aspect : scalar, optional + Aspect ratio of plot, so that ``aspect * size`` gives the width in + inches. Only used if a ``size`` is provided. + size : scalar, optional + If provided, create a new figure for the plot with the given size. + Height (in inches) of each plot. See also: ``aspect``. + norm : ``matplotlib.colors.Normalize`` instance, optional + If the ``norm`` has vmin or vmax specified, the corresponding kwarg + must be None. + vmin, vmax : float, optional + Values to anchor the colormap, otherwise they are inferred from the + data and other keyword arguments. When a diverging dataset is inferred, + setting one of these values will fix the other by symmetry around + ``center``. Setting both values prevents use of a diverging colormap. + If discrete levels are provided as an explicit list, both of these + values are ignored. + cmap : str or colormap, optional + The mapping from data values to color space. Either a + matplotlib colormap name or object. If not provided, this will + be either ``viridis`` (if the function infers a sequential + dataset) or ``RdBu_r`` (if the function infers a diverging + dataset). When `Seaborn` is installed, ``cmap`` may also be a + `seaborn` color palette. If ``cmap`` is seaborn color palette + and the plot type is not ``contour`` or ``contourf``, ``levels`` + must also be specified. + colors : color-like or list of color-like, optional + A single color or a list of colors. 
If the plot type is not ``contour`` + or ``contourf``, the ``levels`` argument is required. + center : float, optional + The value at which to center the colormap. Passing this value implies + use of a diverging colormap. Setting it to ``False`` prevents use of a + diverging colormap. + robust : bool, optional + If True and ``vmin`` or ``vmax`` are absent, the colormap range is + computed with 2nd and 98th percentiles instead of the extreme values. + extend : {"neither", "both", "min", "max"}, optional + How to draw arrows extending the colorbar beyond its limits. If not + provided, extend is inferred from vmin, vmax and the data limits. + levels : int or list-like object, optional + Split the colormap (cmap) into discrete color intervals. If an integer + is provided, "nice" levels are chosen based on the data range: this can + imply that the final number of levels is not exactly the expected one. + Setting ``vmin`` and/or ``vmax`` with ``levels=N`` is equivalent to + setting ``levels=np.linspace(vmin, vmax, N)``. 
+ **kwargs : optional + Additional keyword arguments to matplotlib + """ + # Handle facetgrids first if row or col: allargs = locals().copy() From 487770afd1861538fd605a42f1437bb8ad531abc Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 19 Feb 2021 22:00:33 +0100 Subject: [PATCH 06/46] Update plot.py --- xarray/plot/plot.py | 73 ++++++++++++++++++++------------------------- 1 file changed, 33 insertions(+), 40 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index a71b045ac81..234e96963de 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -34,8 +34,8 @@ _MARKERSIZE_RANGE = np.array([18.0, 72.0]) -def _infer_meta_data(darray, x, y, hue, hue_style, size, add_guide): - def _hue_calc(darray, hue, hue_style, add_guide): +def _infer_meta_data(darray, x, y, hue, hue_style, size): + def _hue_calc(darray, hue, hue_style): """Something.""" hue_is_numeric = _is_numeric(darray[hue].values) @@ -54,15 +54,7 @@ def _hue_calc(darray, hue, hue_style, add_guide): hue_label = label_from_attrs(darray[hue]) hue = darray[hue] - # Handle colorbar and legend: - if add_guide is None or add_guide is True: - add_colorbar = True if hue_style == "continuous" else False - add_legend = True if hue_style == "discrete" else False - else: - add_colorbar = False - add_legend = False - - return hue, hue_style, hue_label, add_colorbar, add_legend + return hue, hue_style, hue_label if x is not None and y is not None: raise ValueError("Cannot specify both x and y kwargs for line plots.") @@ -77,23 +69,14 @@ def _hue_calc(darray, hue, hue_style, add_guide): ylabel = label_from_attrs(yplt) if hue: - hue, hue_style, hue_label, add_colorbar, add_legend = _hue_calc( - darray, hue, hue_style, add_guide - ) + hue, hue_style, hue_label = _hue_calc(darray, hue, hue_style) else: # Try finding a hue: _, hue = _infer_xy_labels(darray=yplt, x=xplt.name, y=hue) if hue: - hue, hue_style, hue_label, add_colorbar, add_legend = _hue_calc( - 
darray, hue, hue_style, add_guide - ) + hue, hue_style, hue_label = _hue_calc(darray, hue, hue_style) else: - if add_guide is True: - raise ValueError("Cannot set add_guide when hue is None.") - add_legend = False - add_colorbar = False - hue_label = None hue = None @@ -105,8 +88,6 @@ def _hue_calc(darray, hue, hue_style, add_guide): size = None return dict( - add_colorbar=add_colorbar, - add_legend=add_legend, hue_label=hue_label, hue_style=hue_style, xlabel=xlabel, @@ -614,7 +595,8 @@ def scatter( yticks=None, xlim=None, ylim=None, - add_legend=True, + add_legend=None, + add_colorbar=None, _labels=True, **kwargs, ): @@ -623,7 +605,6 @@ def scatter( Parameters ---------- - darray : DataArray Dataarray to plot. x, y : str @@ -699,7 +680,6 @@ def scatter( **kwargs : optional Additional keyword arguments to matplotlib """ - # Handle facetgrids first if row or col: allargs = locals().copy() @@ -709,15 +689,30 @@ def scatter( # _is_facetgrid = kwargs.pop("_is_facetgrid", False) _sizes = kwargs.pop("markersize", kwargs.pop("linewidth", None)) - add_guide = kwargs.pop("add_guide", None) size_norm = kwargs.pop("size_norm", None) size_mapping = kwargs.pop("size_mapping", None) # set by facetgrid cbar_ax = kwargs.pop("cbar_ax", None) + cbar_kwargs = kwargs.pop("cbar_kwargs", None) figsize = kwargs.pop("figsize", None) ax = get_axis(figsize, size, aspect, ax) - _data = _infer_meta_data(darray, x, y, hue, hue_style, _sizes, add_guide) + _data = _infer_meta_data(darray, x, y, hue, hue_style, _sizes) + + add_guide = kwargs.pop("add_guide", None) + if add_legend: + pass + elif add_guide is None or add_guide is True: + add_legend = True if hue_style == "discrete" else False + elif add_legend is None: + add_legend = False + + if add_colorbar: + pass + elif add_guide is None or add_guide is True: + add_colorbar = True if hue_style == "continuous" else False + else: + add_colorbar = False # need to infer size_mapping with full dataset _data.update( @@ -817,15 +812,11 @@ def 
_legend_elements_from_list(primitives, prop, **kwargs): return [handles, labels] - def _add_legend(primitives, prop, func, ax, title, loc): + def _add_legend(primitives, prop, func, ax, title, **kwargs): """Add legend to axes.""" # Get handles and labels to use in the legend: handles, labels = _legend_elements_from_list( - primitives, - prop, - num="auto", - alpha=0.6, - func=func, + primitives, prop, num="auto", alpha=0.6, func=func, ) # title has to be a required check as otherwise the legend may @@ -834,11 +825,11 @@ def _add_legend(primitives, prop, func, ax, title, loc): if title and len(handles) > 1: # The normal case where a prop has been defined and # legend_elements finds results: - legend = ax.legend(handles, labels, title=title, loc=loc) + legend = ax.legend(handles, labels, title=title, **kwargs) ax.add_artist(legend) elif title and len(primitives) > 1: # For caases when - legend = ax.legend(handles=primitives, title=title, loc=loc) + legend = ax.legend(handles=primitives, title=title, **kwargs) ax.add_artist(legend) if _data["hue_style"] == "discrete": @@ -846,7 +837,7 @@ def _add_legend(primitives, prop, func, ax, title, loc): else: primitives = [primitive] - if _data["add_legend"] and _data["hue_label"]: + if add_legend and _data["hue_label"]: _add_legend( primitives, prop="colors", @@ -856,7 +847,7 @@ def _add_legend(primitives, prop, func, ax, title, loc): loc="upper right", ) - if _data["add_legend"] and _data["size_label"]: + if add_legend and _data["size_label"]: _add_legend( primitives, prop="sizes", @@ -868,7 +859,9 @@ def _add_legend(primitives, prop, func, ax, title, loc): loc="upper left", ) - if _data["add_colorbar"] and _data["hue_label"]: + if add_colorbar and _data["hue_label"]: + if _data["hue_style"] == "discrete": + raise NotImplementedError("Cannot create a colorbar for non numerics.") cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs if "label" not in cbar_kwargs: cbar_kwargs["label"] = _data["hue_label"] From 
88ce1c81b5cd10f19bc53d6374b873087d14b1d9 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 20 Feb 2021 20:16:47 +0100 Subject: [PATCH 07/46] merge hue and size legends --- xarray/plot/plot.py | 111 +++++++++++++++++++++++++++----------------- 1 file changed, 69 insertions(+), 42 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 234e96963de..879deb5ff01 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -693,6 +693,7 @@ def scatter( size_mapping = kwargs.pop("size_mapping", None) # set by facetgrid cbar_ax = kwargs.pop("cbar_ax", None) cbar_kwargs = kwargs.pop("cbar_kwargs", None) + cmap_params = kwargs.pop("cmap_params", None) figsize = kwargs.pop("figsize", None) ax = get_axis(figsize, size, aspect, ax) @@ -703,14 +704,14 @@ def scatter( if add_legend: pass elif add_guide is None or add_guide is True: - add_legend = True if hue_style == "discrete" else False + add_legend = True if _data["hue_style"] == "discrete" else False elif add_legend is None: add_legend = False if add_colorbar: pass elif add_guide is None or add_guide is True: - add_colorbar = True if hue_style == "continuous" else False + add_colorbar = True if _data["hue_style"] == "continuous" else False else: add_colorbar = False @@ -750,6 +751,8 @@ def scatter( **kwargs, ) ) + primitives = primitive + elif _data["hue_label"] is None or _data["hue_style"] == "continuous": # ax.scatter suppoerts numerical values in colors and sizes. # So no need for for loops. 
@@ -777,6 +780,8 @@ def scatter( **kwargs, ) + primitives = [primitive] + # Set x and y labels: if _data["xlabel"]: ax.set_xlabel(_data["xlabel"]) @@ -812,52 +817,74 @@ def _legend_elements_from_list(primitives, prop, **kwargs): return [handles, labels] - def _add_legend(primitives, prop, func, ax, title, **kwargs): - """Add legend to axes.""" - # Get handles and labels to use in the legend: - handles, labels = _legend_elements_from_list( - primitives, prop, num="auto", alpha=0.6, func=func, - ) + def _legend_add_subtitle(handles, labels, text, func): + """Add a subtitle to legend handles.""" + if text and len(handles) > 1: + # Create a blank handle that's not visible, the + # invisibillity will be used to discern which are subtitles + # or not: + blank_handle = func([], [], label=text) + blank_handle.set_visible(False) - # title has to be a required check as otherwise the legend may - # display values that are not related to the data, such as the - # markersize value: - if title and len(handles) > 1: + # Subtitles are shown first: + handles = np.insert(handles, 0, blank_handle) + labels = np.insert(labels, 0, text) + + return handles, labels + + def _adjust_legend_subtitles(legend): + """Make invisible-handle "subtitles" entries look more like titles.""" + plt = import_matplotlib_pyplot() + + # Legend title not in rcParams until 3.0 + font_size = plt.rcParams.get("legend.title_fontsize", None) + hpackers = legend.findobj(plt.matplotlib.offsetbox.VPacker)[0].get_children() + for hpack in hpackers: + draw_area, text_area = hpack.get_children() + handles = draw_area.get_children() + + # Assume that all artists that are not visible are + # subtitles: + if not all(artist.get_visible() for artist in handles): + # Remove the dummy marker which will bring the text + # more to the center: + draw_area.set_width(0) + for text in text_area.get_children(): + if font_size is not None: + # The sutbtitles should have the same font size + # as normal legend titles: + 
text.set_size(font_size) + + def _add_legend(primitives, handles, labels, ax, **kwargs): + if len(handles) > 1: # The normal case where a prop has been defined and # legend_elements finds results: - legend = ax.legend(handles, labels, title=title, **kwargs) - ax.add_artist(legend) - elif title and len(primitives) > 1: - # For caases when - legend = ax.legend(handles=primitives, title=title, **kwargs) - ax.add_artist(legend) + legend = ax.legend(handles, labels, framealpha=0.5, **kwargs) + elif len(primitives) > 1: + # When no handles have been found use the primitives instead. + # Example: hue and sizes have the same string: + legend = ax.legend(handles=primitives, framealpha=0.5, **kwargs) - if _data["hue_style"] == "discrete": - primitives = primitive - else: - primitives = [primitive] + return legend - if add_legend and _data["hue_label"]: - _add_legend( - primitives, - prop="colors", - func=lambda x: x, - ax=ax, - title=_data["hue_label"], - loc="upper right", - ) + if add_legend: + handles, labels = np.array([]), np.array([]) + if _data["hue_label"] and (_data["hue_label"] == _data["size_label"]): + _add_legend(primitives, handles, labels, ax, title=_data["hue_label"]) + else: + for hue_lbl, prop, func in [ + (_data["hue_label"], "colors", lambda x: x), + (_data["size_label"], "sizes", lambda x: _data["sizes_to_labels"][x]), + ]: + if hue_lbl: + hdl, lbl = _legend_elements_from_list( + primitives, prop, num="auto", func=func, + ) + hdl, lbl = _legend_add_subtitle(hdl, lbl, hue_lbl, ax.scatter) + handles, labels = np.append(handles, hdl), np.append(labels, lbl) - if add_legend and _data["size_label"]: - _add_legend( - primitives, - prop="sizes", - func=lambda x: _data["sizes_to_labels"][x] - if "sizes_to_labels" in _data - else x, - ax=ax, - title=_data["size_label"], - loc="upper left", - ) + legend = _add_legend(primitives, handles, labels, ax) + _adjust_legend_subtitles(legend) if add_colorbar and _data["hue_label"]: if _data["hue_style"] == "discrete": From 
5215c7a6c17f78dd10307f4b57f22552ff4b2c5a Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 20 Feb 2021 20:31:48 +0100 Subject: [PATCH 08/46] Update plot.py --- xarray/plot/plot.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 879deb5ff01..299e5b78799 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -35,8 +35,8 @@ def _infer_meta_data(darray, x, y, hue, hue_style, size): - def _hue_calc(darray, hue, hue_style): - """Something.""" + def _determine_hue(darray, hue, hue_style): + """Find and determine what type of hue it is.""" hue_is_numeric = _is_numeric(darray[hue].values) if hue_style is None: @@ -69,13 +69,13 @@ def _hue_calc(darray, hue, hue_style): ylabel = label_from_attrs(yplt) if hue: - hue, hue_style, hue_label = _hue_calc(darray, hue, hue_style) + hue, hue_style, hue_label = _determine_hue(darray, hue, hue_style) else: # Try finding a hue: _, hue = _infer_xy_labels(darray=yplt, x=xplt.name, y=hue) if hue: - hue, hue_style, hue_label = _hue_calc(darray, hue, hue_style) + hue, hue_style, hue_label = _determine_hue(darray, hue, hue_style) else: hue_label = None hue = None @@ -100,8 +100,8 @@ def _hue_calc(darray, hue, hue_style): # copied from seaborn def _parse_size(data, norm, width): - - import matplotlib as mpl + """Parse sizes.""" + plt = import_matplotlib_pyplot if data is None: return None @@ -118,10 +118,10 @@ def _parse_size(data, norm, width): # width_range = min_width, max_width if norm is None: - norm = mpl.colors.Normalize() + norm = plt.colors.Normalize() elif isinstance(norm, tuple): - norm = mpl.colors.Normalize(*norm) - elif not isinstance(norm, mpl.colors.Normalize): + norm = plt.colors.Normalize(*norm) + elif not isinstance(norm, plt.colors.Normalize): err = "``size_norm`` must be None, tuple, or Normalize object." 
raise ValueError(err) From bada616b06c7a90f3a5b3f0174f09534ddd75cad Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 20 Feb 2021 22:08:41 +0100 Subject: [PATCH 09/46] add test --- xarray/plot/plot.py | 6 +++--- xarray/tests/test_plot.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 299e5b78799..a5032264765 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -721,7 +721,7 @@ def scatter( darray, x, y, - _data["hue_label"], + _data["hue"].name, _sizes, size_norm, size_mapping, @@ -753,7 +753,7 @@ def scatter( ) primitives = primitive - elif _data["hue_label"] is None or _data["hue_style"] == "continuous": + elif _data["hue_style"] == "continuous": # ax.scatter suppoerts numerical values in colors and sizes. # So no need for for loops. @@ -762,7 +762,7 @@ def scatter( kwargs.update(c=_data["hue"].values.ravel()) cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( - scatter, darray[_data["hue_label"]].values, **locals() + scatter, _data["hue"].values, **locals() ) # subset that can be passed to scatter, hist2d diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 705b2d5e2e7..6a4b1a4429c 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2590,3 +2590,36 @@ def test_get_axis_cartopy(): with figure_context(): ax = get_axis(**kwargs) assert isinstance(ax, ctpy.mpl.geoaxes.GeoAxesSubplot) + + +@requires_matplotlib +@pytest.mark.parametrize( + "x, y, hue, markersize, add_legend, add_colorbar", + [ + ("A", "B", None, None, None, None), + ("B", "A", "w", None, True, None), + ("A", "B", "y", "z", True, True), + ], +) +def test_datarray_scatter(x, y, hue, markersize, add_legend, add_colorbar): + """Test datarray scatter. 
Merge with TestPlot1D eventually.""" + ds = xr.tutorial.scatter_example_dataset() + + extra_coords = [v for v in [x, hue, markersize] if v is not None] + + # Base coords: + coords = dict(ds.coords) + + # Add extra coords to the DataArray: + coords.update({v: ds[v] for v in extra_coords}) + + darray = xr.DataArray(ds[y], coords=coords) + + with figure_context(): + darray.plot.scatter( + x=x, + hue=hue, + markersize=markersize, + add_legend=add_legend, + add_colorbar=add_colorbar, + ) From ccc1b3575666678f9e075830bd5bbc2069b9acb1 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 20 Feb 2021 22:26:27 +0100 Subject: [PATCH 10/46] Update plot.py --- xarray/plot/plot.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index a5032264765..989cc9e21fc 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -101,7 +101,7 @@ def _determine_hue(darray, hue, hue_style): # copied from seaborn def _parse_size(data, norm, width): """Parse sizes.""" - plt = import_matplotlib_pyplot + plt = import_matplotlib_pyplot() if data is None: return None @@ -118,10 +118,10 @@ def _parse_size(data, norm, width): # width_range = min_width, max_width if norm is None: - norm = plt.colors.Normalize() + norm = plt.Normalize() elif isinstance(norm, tuple): - norm = plt.colors.Normalize(*norm) - elif not isinstance(norm, plt.colors.Normalize): + norm = plt.Normalize(*norm) + elif not isinstance(norm, plt.Normalize): err = "``size_norm`` must be None, tuple, or Normalize object." 
raise ValueError(err) From fccc1d7f271afddadf943d62126e4f99ae3fcdb6 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 21 Feb 2021 14:43:32 +0100 Subject: [PATCH 11/46] add 3d support --- xarray/plot/plot.py | 282 +++++++++++++++++++++++++------------------- 1 file changed, 160 insertions(+), 122 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 989cc9e21fc..19cf00c5c6c 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -34,67 +34,56 @@ _MARKERSIZE_RANGE = np.array([18.0, 72.0]) -def _infer_meta_data(darray, x, y, hue, hue_style, size): - def _determine_hue(darray, hue, hue_style): - """Find and determine what type of hue it is.""" - hue_is_numeric = _is_numeric(darray[hue].values) - - if hue_style is None: - hue_style = "continuous" if hue_is_numeric else "discrete" - elif hue_style not in ["discrete", "continuous"]: +def _infer_meta_data(darray, x, z, hue, hue_style, size): + def _determine_array(darray, name, array_style): + """Find and determine what type of array it is.""" + array = darray[name] + array_is_numeric = _is_numeric(array.values) + + if array_style is None: + array_style = "continuous" if array_is_numeric else "discrete" + elif array_style not in ["discrete", "continuous"]: raise ValueError( "hue_style must be either None, 'discrete' or 'continuous'." 
) - if not hue_is_numeric and (hue_style == "continuous"): + if not array_is_numeric and (array_style == "continuous"): raise ValueError( - f"Cannot create a colorbar for a non numeric coordinate: {hue}" + f"Cannot create a colorbar for a non numeric coordinate: {name}" ) - hue_label = label_from_attrs(darray[hue]) - hue = darray[hue] + array_label = label_from_attrs(array) - return hue, hue_style, hue_label + return array, array_style, array_label - if x is not None and y is not None: - raise ValueError("Cannot specify both x and y kwargs for line plots.") + # if x is not None and y is not None: + # raise ValueError("Cannot specify both x and y kwargs for line plots.") - if x is not None: - xplt = darray[x] - yplt = darray - elif y is not None: - xplt = darray - yplt = darray[y] - xlabel = label_from_attrs(xplt) - ylabel = label_from_attrs(yplt) + label = dict(y=label_from_attrs(darray)) + label.update( + {k: v if v in darray.coords else None for k, v in [("x", x), ("z", z)]} + ) if hue: - hue, hue_style, hue_label = _determine_hue(darray, hue, hue_style) + hue, hue_style, hue_label = _determine_array(darray, hue, hue_style) else: - # Try finding a hue: - _, hue = _infer_xy_labels(darray=yplt, x=xplt.name, y=hue) - - if hue: - hue, hue_style, hue_label = _determine_hue(darray, hue, hue_style) - else: - hue_label = None - hue = None + hue, hue_style, hue_label = None, None, None if size: - size_label = label_from_attrs(darray[size]) - size = darray[size] + size, size_style, size_label = _determine_array(darray, size, None) else: - size_label = None - size = None + size, size_style, size_label = None, None, None return dict( + xlabel=label["x"], + ylabel=label["y"], + zlabel=label["z"], + hue=hue, hue_label=hue_label, hue_style=hue_style, - xlabel=xlabel, - ylabel=ylabel, - hue=hue, size=size, size_label=size_label, + size_style=size_style, ) @@ -140,40 +129,43 @@ def _parse_size(data, norm, width): def _infer_scatter_data( - darray, x, y, hue, size, size_norm, 
size_mapping=None, size_range=(1, 10) + darray, x, z, hue, size, size_norm, size_mapping=None, size_range=(1, 10) ): - if x is not None and y is not None: - raise ValueError("Cannot specify both x and y kwargs for scatter plots.") + # if x is not None and y is not None: + # raise ValueError("Cannot specify both x and y kwargs for scatter plots.") # Broadcast together all the chosen variables: - if x is not None: - to_broadcast = dict(x=darray[x], y=darray) - elif y is not None: - to_broadcast = dict(x=darray, y=darray[y]) + to_broadcast = dict(y=darray) + to_broadcast.update( + {k: darray[v] for k, v in dict(x=x, z=z).items() if v is not None} + ) to_broadcast.update( { key: darray[value] - for key, value in dict(hue=hue, size=size).items() + for key, value in dict(hue=hue, sizes=size).items() if value in darray.dims } ) broadcasted = dict(zip(to_broadcast.keys(), broadcast(*(to_broadcast.values())))) - - data = dict(x=broadcasted["x"], y=broadcasted["y"], sizes=None) - data.update(hue=broadcasted.get("hue", None)) + broadcasted.update( + hue=broadcasted.pop("hue", None), sizes=broadcasted.pop("sizes", None) + ) if size: - size = broadcasted["size"] - if size_mapping is None: - size_mapping = _parse_size(size, size_norm, size_range) + size_mapping = _parse_size(broadcasted["sizes"], size_norm, size_range) - data["sizes"] = size.copy( - data=np.reshape(size_mapping.loc[size.values.ravel()].values, size.shape) + broadcasted["sizes"] = broadcasted["sizes"].copy( + data=np.reshape( + size_mapping.loc[broadcasted["sizes"].values.ravel()].values, + broadcasted["sizes"].shape, + ) + ) + broadcasted["sizes_to_labels"] = pd.Series( + size_mapping.index, index=size_mapping ) - data["sizes_to_labels"] = pd.Series(size_mapping.index, index=size_mapping) - return data + return broadcasted def _infer_line_data(darray, x, y, hue): @@ -586,7 +578,7 @@ def scatter( hue=None, hue_style=None, x=None, - y=None, + z=None, xincrease=None, yincrease=None, xscale=None, @@ -696,19 +688,25 
@@ def scatter( cmap_params = kwargs.pop("cmap_params", None) figsize = kwargs.pop("figsize", None) - ax = get_axis(figsize, size, aspect, ax) + if z is None: + ax = get_axis(figsize, size, aspect, ax) + else: + try: + ax.set_zlabel + except AttributeError as e: + raise AttributeError("3D projection not set on axes.") from e - _data = _infer_meta_data(darray, x, y, hue, hue_style, _sizes) + _data = _infer_meta_data(darray, x, z, hue, hue_style, _sizes) add_guide = kwargs.pop("add_guide", None) - if add_legend: + if add_legend is not None: pass elif add_guide is None or add_guide is True: add_legend = True if _data["hue_style"] == "discrete" else False elif add_legend is None: add_legend = False - if add_colorbar: + if add_colorbar is not None: pass elif add_guide is None or add_guide is True: add_colorbar = True if _data["hue_style"] == "continuous" else False @@ -718,42 +716,13 @@ def scatter( # need to infer size_mapping with full dataset _data.update( _infer_scatter_data( - darray, - x, - y, - _data["hue"].name, - _sizes, - size_norm, - size_mapping, - _MARKERSIZE_RANGE, + darray, x, z, hue, _sizes, size_norm, size_mapping, _MARKERSIZE_RANGE, ) ) # Plot the data: - if _data["hue_style"] == "discrete": - # Plot discrete data. ax.scatter only supports numerical values - # in colors and sizes. Use a for loop to work around this issue. 
- - primitive = [] - # use pd.unique instead of np.unique because that keeps the order of the labels, - # which is important to keep them in sync with the ones used in - # FacetGrid.add_legend - for label in pd.unique(_data["hue"].values.ravel()): - mask = _data["hue"] == label - if _data["sizes"] is not None: - kwargs.update(s=_data["sizes"].where(mask, drop=True).values.ravel()) - - primitive.append( - ax.scatter( - _data["x"].where(mask, drop=True).values.ravel(), - _data["y"].where(mask, drop=True).values.ravel(), - label=label, - **kwargs, - ) - ) - primitives = primitive - - elif _data["hue_style"] == "continuous": + axis_order = ["x", "z", "y"] + if _data["hue_style"] is None or _data["hue_style"] == "continuous": # ax.scatter suppoerts numerical values in colors and sizes. # So no need for for loops. @@ -774,19 +743,50 @@ def scatter( kwargs.update(s=_data["sizes"].values.ravel()) primitive = ax.scatter( - _data["x"].values.ravel(), - _data["y"].values.ravel(), + *[ + _data[v].values.ravel() + for v in axis_order + if _data.get(v, None) is not None + ], **cmap_params_subset, **kwargs, ) primitives = [primitive] - # Set x and y labels: - if _data["xlabel"]: - ax.set_xlabel(_data["xlabel"]) - if _data["ylabel"]: - ax.set_ylabel(_data["ylabel"]) + elif _data["hue_style"] == "discrete": + # Plot discrete data. ax.scatter only supports numerical values + # in colors and sizes. Use a for loop to work around this issue. 
+ + primitive = [] + # use pd.unique instead of np.unique because that keeps the order of the labels, + # which is important to keep them in sync with the ones used in + # FacetGrid.add_legend + for label in pd.unique(_data["hue"].values.ravel()): + mask = _data["hue"] == label + if _data["sizes"] is not None: + kwargs.update(s=_data["sizes"].where(mask, drop=True).values.ravel()) + + primitive.append( + ax.scatter( + *[ + _data[v].where(mask, drop=True).values.ravel() + for v in axis_order + if _data.get(v, None) is not None + ], + label=label, + **kwargs, + ) + ) + primitives = primitive + + # Set x, y, z labels: + i = 0 + set_label = [ax.set_xlabel, ax.set_ylabel, getattr(ax, "set_zlabel", [])] + for v in axis_order: + if _data.get(f"{v}label", None) is not None: + set_label[i](_data[f"{v}label"]) + i += 1 def _legend_elements_from_list(primitives, prop, **kwargs): """ @@ -856,36 +856,74 @@ def _adjust_legend_subtitles(legend): text.set_size(font_size) def _add_legend(primitives, handles, labels, ax, **kwargs): + # Title is used as backup: + title = kwargs.pop("title", None) if len(handles) > 1: # The normal case where a prop has been defined and # legend_elements finds results: - legend = ax.legend(handles, labels, framealpha=0.5, **kwargs) + return ax.legend(handles, labels, framealpha=0.5, **kwargs) elif len(primitives) > 1: # When no handles have been found use the primitives instead. 
- # Example: hue and sizes have the same string: - legend = ax.legend(handles=primitives, framealpha=0.5, **kwargs) - - return legend - - if add_legend: - handles, labels = np.array([]), np.array([]) - if _data["hue_label"] and (_data["hue_label"] == _data["size_label"]): - _add_legend(primitives, handles, labels, ax, title=_data["hue_label"]) + # Example: + # * hue and sizes having the same string + # * non-numerics in colors or sizes + return ax.legend(handles=primitives, framealpha=0.5, title=title, **kwargs) else: - for hue_lbl, prop, func in [ - (_data["hue_label"], "colors", lambda x: x), - (_data["size_label"], "sizes", lambda x: _data["sizes_to_labels"][x]), - ]: - if hue_lbl: - hdl, lbl = _legend_elements_from_list( - primitives, prop, num="auto", func=func, - ) - hdl, lbl = _legend_add_subtitle(hdl, lbl, hue_lbl, ax.scatter) - handles, labels = np.append(handles, hdl), np.append(labels, lbl) + return None - legend = _add_legend(primitives, handles, labels, ax) + if add_legend: + handles, labels = [], [] + for hue_lbl, prop, func in [ + (_data["hue_label"], "colors", lambda x: x), + ( + _data["size_label"], + "sizes", + lambda x: _data["sizes_to_labels"][x] + if "sizes_to_labels" in _data + else x, + ), + ]: + # if hue_lbl: + hdl, lbl = _legend_elements_from_list( + primitives, prop, num="auto", func=func, + ) + hdl, lbl = _legend_add_subtitle(hdl, lbl, hue_lbl, ax.scatter) + # handles, labels = ax.get_legend_handles_labels() + handles.append(hdl) + labels.append(lbl) + legend = _add_legend( + primitives, + np.concatenate(handles), + np.concatenate(labels), + ax, + title=_data["hue_label"], + ) + if legend is not None: _adjust_legend_subtitles(legend) + # else: + # hdl, lbl = [], [] + # handles.append(hdl) + # labels.append(lbl) + + # if _data["hue_label"] is not None and _data["size_label"] is not None: + # legend = _add_legend(primitives, np.append(*handles), np.append(*labels), ax) + # if legend is not None: + # _adjust_legend_subtitles(legend) + # 
else: + # _add_legend(primitives, [], [], ax, title=_data["hue_label"]) + + # if (_data["hue_label"] == _data["size_label"]) or ( + # all(len(h) < 2 for h in handles) and len(primitives) > 1 + # ): + # _add_legend(primitives, [], [], ax, title=_data["hue_label"]) + # elif len(primitives) < 2 and len(handles) < 2: + # pass + # else: + # legend = _add_legend(primitives, np.append(*handles), np.append(*labels), ax) + # if legend is not None: + # _adjust_legend_subtitles(legend) + if add_colorbar and _data["hue_label"]: if _data["hue_style"] == "discrete": raise NotImplementedError("Cannot create a colorbar for non numerics.") From 059ffc080296d7d3f7b341a58e77bcd21d851631 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 21 Feb 2021 16:31:12 +0100 Subject: [PATCH 12/46] fix labels --- xarray/plot/plot.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 19cf00c5c6c..aae46e7b97b 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -61,7 +61,10 @@ def _determine_array(darray, name, array_style): label = dict(y=label_from_attrs(darray)) label.update( - {k: v if v in darray.coords else None for k, v in [("x", x), ("z", z)]} + { + k: label_from_attrs(darray[v]) if v in darray.coords else None + for k, v in [("x", x), ("z", z)] + } ) if hue: From f281634716db98e839275d9afcdd15fcaa89fc7d Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 22 Feb 2021 21:08:08 +0100 Subject: [PATCH 13/46] Update plot.py --- xarray/plot/plot.py | 166 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 166 insertions(+) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index aae46e7b97b..89ea49d78d0 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -34,6 +34,172 @@ _MARKERSIZE_RANGE = np.array([18.0, 72.0]) +# Copied from matplotlib, tweaked so func can return strings. 
+# https://github.com/matplotlib/matplotlib/issues/19555 +def legend_elements( + self, prop="colors", num="auto", fmt=None, func=lambda x: x, **kwargs +): + """ + Create legend handles and labels for a PathCollection. + + Each legend handle is a `.Line2D` representing the Path that was drawn, + and each label is a string what each Path represents. + + This is useful for obtaining a legend for a `~.Axes.scatter` plot; + e.g.:: + + scatter = plt.scatter([1, 2, 3], [4, 5, 6], c=[7, 2, 3]) + plt.legend(*scatter.legend_elements()) + + creates three legend elements, one for each color with the numerical + values passed to *c* as the labels. + + Also see the :ref:`automatedlegendcreation` example. + + + Parameters + ---------- + prop : {"colors", "sizes"}, default: "colors" + If "colors", the legend handles will show the different colors of + the collection. If "sizes", the legend will show the different + sizes. To set both, use *kwargs* to directly edit the `.Line2D` + properties. + num : int, None, "auto" (default), array-like, or `~.ticker.Locator` + Target number of elements to create. + If None, use all unique elements of the mappable array. If an + integer, target to use *num* elements in the normed range. + If *"auto"*, try to determine which option better suits the nature + of the data. + The number of created elements may slightly deviate from *num* due + to a `~.ticker.Locator` being used to find useful locations. + If a list or array, use exactly those elements for the legend. + Finally, a `~.ticker.Locator` can be provided. + fmt : str, `~matplotlib.ticker.Formatter`, or None (default) + The format or formatter to use for the labels. If a string must be + a valid input for a `~.StrMethodFormatter`. If None (the default), + use a `~.ScalarFormatter`. + func : function, default: ``lambda x: x`` + Function to calculate the labels. 
Often the size (or color) + argument to `~.Axes.scatter` will have been pre-processed by the + user using a function ``s = f(x)`` to make the markers visible; + e.g. ``size = np.log10(x)``. Providing the inverse of this + function here allows that pre-processing to be inverted, so that + the legend labels have the correct values; e.g. ``func = lambda + x: 10**x``. + **kwargs + Allowed keyword arguments are *color* and *size*. E.g. it may be + useful to set the color of the markers if *prop="sizes"* is used; + similarly to set the size of the markers if *prop="colors"* is + used. Any further parameters are passed onto the `.Line2D` + instance. This may be useful to e.g. specify a different + *markeredgecolor* or *alpha* for the legend handles. + + Returns + ------- + handles : list of `.Line2D` + Visual representation of each element of the legend. + labels : list of str + The string labels for elements of the legend. + """ + import matplotlib as mpl + + mlines = mpl.lines + + handles = [] + labels = [] + hasarray = self.get_array() is not None + if fmt is None: + fmt = mpl.ticker.ScalarFormatter(useOffset=False, useMathText=True) + elif isinstance(fmt, str): + fmt = mpl.ticker.StrMethodFormatter(fmt) + fmt.create_dummy_axis() + + if prop == "colors": + if not hasarray: + warnings.warn( + "Collection without array used. Make sure to " + "specify the values to be colormapped via the " + "`c` argument." + ) + return handles, labels + u = np.unique(self.get_array()) + size = kwargs.pop("size", mpl.rcParams["lines.markersize"]) + elif prop == "sizes": + u = np.unique(self.get_sizes()) + color = kwargs.pop("color", "k") + else: + raise ValueError( + "Valid values for `prop` are 'colors' or " + f"'sizes'. You supplied '{prop}' instead." 
+ ) + + func_value = np.asarray(func(u)) + func_is_numeric = np.issubdtype(func_value.dtype, np.number) + if func_is_numeric: + fmt.set_bounds(min(func_value), max(func_value)) + + if num == "auto": + num = 9 + if len(u) <= num: + num = None + if num is None: + values = u + label_values = func(values) + elif not func_is_numeric: + # Values are not numerical so instead of interpolating + # just choose evenly distributed indexes instead: + def which_idxs(m, n): + out = np.rint(np.linspace(1, n, min(m, n)) - 1) + return out.astype(int) + + label_values = func(u) + cond = which_idxs(num, len(label_values)) + values = u[cond] + label_values = label_values[cond] + else: + if prop == "colors": + arr = self.get_array() + elif prop == "sizes": + arr = self.get_sizes() + if isinstance(num, mpl.ticker.Locator): + loc = num + elif np.iterable(num): + loc = mpl.ticker.FixedLocator(num) + else: + num = int(num) + loc = mpl.ticker.MaxNLocator( + nbins=num, min_n_ticks=num - 1, steps=[1, 2, 2.5, 3, 5, 6, 8, 10] + ) + label_values = loc.tick_values(func(arr).min(), func(arr).max()) + cond = (label_values >= func(arr).min()) & (label_values <= func(arr).max()) + label_values = label_values[cond] + yarr = np.linspace(arr.min(), arr.max(), 256) + xarr = func(yarr) + ix = np.argsort(xarr) + values = np.interp(label_values, xarr[ix], yarr[ix]) + + kw = dict(markeredgewidth=self.get_linewidths()[0], alpha=self.get_alpha()) + kw.update(kwargs) + + for val, lab in zip(values, label_values): + if prop == "colors": + color = self.cmap(self.norm(val)) + elif prop == "sizes": + size = np.sqrt(val) + if np.isclose(size, 0.0): + continue + h = mlines.Line2D( + [0], [0], ls="", color=color, ms=size, marker=self.get_paths()[0], **kw + ) + handles.append(h) + if hasattr(fmt, "set_locs"): + fmt.set_locs(label_values) + l = fmt(lab) + labels.append(l) + + return handles, labels + + def _infer_meta_data(darray, x, z, hue, hue_style, size): def _determine_array(darray, name, array_style): """Find and 
determine what type of array it is.""" From 99eb5c1b52e858f8903eaa9d5f7a8984bf1e5ac5 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 22 Feb 2021 22:54:19 +0100 Subject: [PATCH 14/46] remove ax.scatter loop Remove the scatter for loop when using discrete values. --- xarray/plot/plot.py | 202 ++++++++++---------------------------------- 1 file changed, 46 insertions(+), 156 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 89ea49d78d0..6b58d44ba87 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -854,16 +854,14 @@ def scatter( size_mapping = kwargs.pop("size_mapping", None) # set by facetgrid cbar_ax = kwargs.pop("cbar_ax", None) cbar_kwargs = kwargs.pop("cbar_kwargs", None) + cmap = kwargs.pop("cmap", None) cmap_params = kwargs.pop("cmap_params", None) figsize = kwargs.pop("figsize", None) - if z is None: - ax = get_axis(figsize, size, aspect, ax) - else: - try: - ax.set_zlabel - except AttributeError as e: - raise AttributeError("3D projection not set on axes.") from e + subplot_kws = dict() + if z is not None and ax is None: + subplot_kws.update(projection="3d") + ax = get_axis(figsize, size, aspect, ax, **subplot_kws) _data = _infer_meta_data(darray, x, z, hue, hue_style, _sizes) @@ -891,63 +889,33 @@ def scatter( # Plot the data: axis_order = ["x", "z", "y"] - if _data["hue_style"] is None or _data["hue_style"] == "continuous": - # ax.scatter suppoerts numerical values in colors and sizes. - # So no need for for loops. 
+ cmap_params_subset = {} + if _data["hue"] is not None: + kwargs.update(c=_data["colors"].values.ravel()) - cmap_params_subset = {} - if _data["hue"] is not None: - kwargs.update(c=_data["hue"].values.ravel()) + if cmap is None and _data["hue_style"] == "discrete": + cmap = "tab10" + cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( + scatter, _data["colors"].values, **locals() + ) - cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( - scatter, _data["hue"].values, **locals() - ) + # subset that can be passed to scatter, hist2d + cmap_params_subset = { + vv: cmap_params[vv] for vv in ["vmin", "vmax", "norm", "cmap"] + } - # subset that can be passed to scatter, hist2d - cmap_params_subset = { - vv: cmap_params[vv] for vv in ["vmin", "vmax", "norm", "cmap"] - } - - if _data["sizes"] is not None: - kwargs.update(s=_data["sizes"].values.ravel()) - - primitive = ax.scatter( - *[ - _data[v].values.ravel() - for v in axis_order - if _data.get(v, None) is not None - ], - **cmap_params_subset, - **kwargs, - ) + if _data["sizes"] is not None: + kwargs.update(s=_data["sizes"].values.ravel()) - primitives = [primitive] - - elif _data["hue_style"] == "discrete": - # Plot discrete data. ax.scatter only supports numerical values - # in colors and sizes. Use a for loop to work around this issue. 
- - primitive = [] - # use pd.unique instead of np.unique because that keeps the order of the labels, - # which is important to keep them in sync with the ones used in - # FacetGrid.add_legend - for label in pd.unique(_data["hue"].values.ravel()): - mask = _data["hue"] == label - if _data["sizes"] is not None: - kwargs.update(s=_data["sizes"].where(mask, drop=True).values.ravel()) - - primitive.append( - ax.scatter( - *[ - _data[v].where(mask, drop=True).values.ravel() - for v in axis_order - if _data.get(v, None) is not None - ], - label=label, - **kwargs, - ) - ) - primitives = primitive + primitive = ax.scatter( + *[ + _data[v].values.ravel() + for v in axis_order + if _data.get(v, None) is not None + ], + **cmap_params_subset, + **kwargs, + ) # Set x, y, z labels: i = 0 @@ -957,35 +925,6 @@ def scatter( set_label[i](_data[f"{v}label"]) i += 1 - def _legend_elements_from_list(primitives, prop, **kwargs): - """ - Get unique legend elements from a list of pathcollections. - - Getting multiple pathcollections happens when adding multiple - scatters to the same plot. 
- """ - import warnings - - handles = np.array([], dtype=object) - labels = np.array([], dtype=str) - - for i, pc in enumerate(primitives): - with warnings.catch_warnings(record=True): - warnings.simplefilter("always") - # Get legend elements, suppress empty data warnings - # because it will be handled later: - hdl, lbl = pc.legend_elements(prop=prop, **kwargs) - handles = np.append(handles, hdl) - labels = np.append(labels, lbl) - - # Duplicate legend entries is not necessary, therefore return - # unique labels: - unique_indices = np.sort(np.unique(labels, return_index=True)[1]) - handles = handles[unique_indices] - labels = labels[unique_indices] - - return [handles, labels] - def _legend_add_subtitle(handles, labels, text, func): """Add a subtitle to legend handles.""" if text and len(handles) > 1: @@ -996,8 +935,8 @@ def _legend_add_subtitle(handles, labels, text, func): blank_handle.set_visible(False) # Subtitles are shown first: - handles = np.insert(handles, 0, blank_handle) - labels = np.insert(labels, 0, text) + handles = [blank_handle] + handles + labels = [text] + labels return handles, labels @@ -1024,74 +963,25 @@ def _adjust_legend_subtitles(legend): # as normal legend titles: text.set_size(font_size) - def _add_legend(primitives, handles, labels, ax, **kwargs): - # Title is used as backup: - title = kwargs.pop("title", None) - if len(handles) > 1: - # The normal case where a prop has been defined and - # legend_elements finds results: - return ax.legend(handles, labels, framealpha=0.5, **kwargs) - elif len(primitives) > 1: - # When no handles have been found use the primitives instead. 
- # Example: - # * hue and sizes having the same string - # * non-numerics in colors or sizes - return ax.legend(handles=primitives, framealpha=0.5, title=title, **kwargs) - else: - return None - if add_legend: + + def to_label(d, key): + return lambda x: d[key][x] if key in d else x + handles, labels = [], [] - for hue_lbl, prop, func in [ - (_data["hue_label"], "colors", lambda x: x), - ( - _data["size_label"], - "sizes", - lambda x: _data["sizes_to_labels"][x] - if "sizes_to_labels" in _data - else x, - ), + for subtitle, prop, func in [ + (_data["hue_label"], "colors", to_label(_data, "colors_to_labels")), + (_data["size_label"], "sizes", to_label(_data, "sizes_to_labels")), ]: - # if hue_lbl: - hdl, lbl = _legend_elements_from_list( - primitives, prop, num="auto", func=func, - ) - hdl, lbl = _legend_add_subtitle(hdl, lbl, hue_lbl, ax.scatter) - # handles, labels = ax.get_legend_handles_labels() - handles.append(hdl) - labels.append(lbl) - legend = _add_legend( - primitives, - np.concatenate(handles), - np.concatenate(labels), - ax, - title=_data["hue_label"], - ) - if legend is not None: - _adjust_legend_subtitles(legend) - - # else: - # hdl, lbl = [], [] - # handles.append(hdl) - # labels.append(lbl) - - # if _data["hue_label"] is not None and _data["size_label"] is not None: - # legend = _add_legend(primitives, np.append(*handles), np.append(*labels), ax) - # if legend is not None: - # _adjust_legend_subtitles(legend) - # else: - # _add_legend(primitives, [], [], ax, title=_data["hue_label"]) - - # if (_data["hue_label"] == _data["size_label"]) or ( - # all(len(h) < 2 for h in handles) and len(primitives) > 1 - # ): - # _add_legend(primitives, [], [], ax, title=_data["hue_label"]) - # elif len(primitives) < 2 and len(handles) < 2: - # pass - # else: - # legend = _add_legend(primitives, np.append(*handles), np.append(*labels), ax) - # if legend is not None: - # _adjust_legend_subtitles(legend) + if subtitle: + hdl, lbl = legend_elements( + primitive, prop, 
num="auto", func=func, fmt="{x}" + ) + hdl, lbl = _legend_add_subtitle(hdl, lbl, subtitle, ax.scatter) + handles += hdl + labels += lbl + legend = ax.legend(handles, labels, framealpha=0.5) + _adjust_legend_subtitles(legend) if add_colorbar and _data["hue_label"]: if _data["hue_style"] == "discrete": From 8fd4765fffe03bee10a10b191321bee2633cbeae Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 22 Feb 2021 22:56:14 +0100 Subject: [PATCH 15/46] Update plot.py --- xarray/plot/plot.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 6b58d44ba87..5fe8d4917be 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -102,6 +102,7 @@ def legend_elements( The string labels for elements of the legend. """ import matplotlib as mpl + import warnings mlines = mpl.lines From 6eceabe9a9a0926cf67ff7945eb3a44638a94988 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 23 Feb 2021 08:02:27 +0100 Subject: [PATCH 16/46] Update plot.py --- xarray/plot/plot.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 5fe8d4917be..23ad6757889 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -195,8 +195,7 @@ def which_idxs(m, n): handles.append(h) if hasattr(fmt, "set_locs"): fmt.set_locs(label_values) - l = fmt(lab) - labels.append(l) + labels.append(fmt(lab)) return handles, labels From 1a5c2da2c7c13ddf6f3ddd6152fb262385f17dd6 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 23 Feb 2021 20:20:01 +0100 Subject: [PATCH 17/46] use label mapper only when necessary --- xarray/plot/plot.py | 36 +++++++++++++++++++++++++++++------- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 23ad6757889..1a6d8188b8d 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ 
-966,17 +966,39 @@ def _adjust_legend_subtitles(legend): if add_legend: def to_label(d, key): - return lambda x: d[key][x] if key in d else x + """Map prop values back to it's original values.""" + + def _to_label(x): + if key in d: + # Use reindex to be less sensitive to float errors: + return d[key].reindex(x, method="nearest") + else: + return x + + return _to_label handles, labels = [], [] - for subtitle, prop, func in [ - (_data["hue_label"], "colors", to_label(_data, "colors_to_labels")), - (_data["size_label"], "sizes", to_label(_data, "sizes_to_labels")), + for subtitle, prop, func, style in [ + ( + _data["hue_label"], + "colors", + to_label(_data, "colors_to_labels"), + _data["hue_style"], + ), + ( + _data["size_label"], + "sizes", + to_label(_data, "sizes_to_labels"), + _data["size_style"], + ), ]: if subtitle: - hdl, lbl = legend_elements( - primitive, prop, num="auto", func=func, fmt="{x}" - ) + if style == "discrete": + hdl, lbl = legend_elements( + primitive, prop, num="auto", func=func, fmt="{x}" + ) + else: + hdl, lbl = legend_elements(primitive, prop, num="auto") hdl, lbl = _legend_add_subtitle(hdl, lbl, subtitle, ax.scatter) handles += hdl labels += lbl From 34d1e74ddb5401d5d1ab50ed1bd42922fd356c65 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 23 Feb 2021 20:25:33 +0100 Subject: [PATCH 18/46] Update plot.py --- xarray/plot/plot.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 1a6d8188b8d..9fa9b7252b0 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -101,9 +101,10 @@ def legend_elements( labels : list of str The string labels for elements of the legend. 
""" - import matplotlib as mpl import warnings + import matplotlib as mpl + mlines = mpl.lines handles = [] From fb0d7ca35f0c008e97293ede87ba9930b9965423 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 23 Feb 2021 21:02:28 +0100 Subject: [PATCH 19/46] Update plot.py --- xarray/plot/plot.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 9fa9b7252b0..fe9f6a6715d 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -301,9 +301,6 @@ def _parse_size(data, norm, width): def _infer_scatter_data( darray, x, z, hue, size, size_norm, size_mapping=None, size_range=(1, 10) ): - # if x is not None and y is not None: - # raise ValueError("Cannot specify both x and y kwargs for scatter plots.") - # Broadcast together all the chosen variables: to_broadcast = dict(y=darray) to_broadcast.update( @@ -321,6 +318,20 @@ def _infer_scatter_data( hue=broadcasted.pop("hue", None), sizes=broadcasted.pop("sizes", None) ) + if hue: + # if hue_mapping is None: + hue_mapping = _parse_size(broadcasted["hue"], None, [0, 1]) + + broadcasted["colors"] = broadcasted["hue"].copy( + data=np.reshape( + hue_mapping.loc[broadcasted["hue"].values.ravel()].values, + broadcasted["hue"].shape, + ) + ) + broadcasted["colors_to_labels"] = pd.Series( + hue_mapping.index, index=hue_mapping + ) + if size: if size_mapping is None: size_mapping = _parse_size(broadcasted["sizes"], size_norm, size_range) From a65a6d9c9c27b23ce4d92962dd62e20b1b5d8ec8 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 23 Feb 2021 22:41:49 +0100 Subject: [PATCH 20/46] func should return np.arrays --- xarray/plot/plot.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index fe9f6a6715d..7fb2c851f8e 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ 
-978,12 +978,14 @@ def _adjust_legend_subtitles(legend): if add_legend: def to_label(d, key): - """Map prop values back to it's original values.""" + """Map prop values back to its original values.""" def _to_label(x): if key in d: - # Use reindex to be less sensitive to float errors: - return d[key].reindex(x, method="nearest") + # Use reindex to be less sensitive to float errors. + # Return as numpy array since legend_elements + # seems to require that: + return d[key].reindex(x, method="nearest").to_numpy() else: return x @@ -1005,12 +1007,15 @@ def _to_label(x): ), ]: if subtitle: - if style == "discrete": - hdl, lbl = legend_elements( - primitive, prop, num="auto", func=func, fmt="{x}" - ) - else: - hdl, lbl = legend_elements(primitive, prop, num="auto") + # Floats are handled nicely by the defaults but strings + # needs special format to bypass numerical operations: + fmt = "{x}" if style == "discrete" else None + + # Get legend handles and labels that displays the + # values correctly: + hdl, lbl = legend_elements( + primitive, prop, num="auto", func=func, fmt=fmt + ) hdl, lbl = _legend_add_subtitle(hdl, lbl, subtitle, ax.scatter) handles += hdl labels += lbl From b510bdcd878cab2da0c97977c82c33f4d388b69b Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 24 Feb 2021 19:07:21 +0100 Subject: [PATCH 21/46] better defaults in legend_elements --- xarray/plot/plot.py | 38 +++++++++++++------------------------- 1 file changed, 13 insertions(+), 25 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 7fb2c851f8e..fba9520d68d 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -110,11 +110,6 @@ def legend_elements( handles = [] labels = [] hasarray = self.get_array() is not None - if fmt is None: - fmt = mpl.ticker.ScalarFormatter(useOffset=False, useMathText=True) - elif isinstance(fmt, str): - fmt = mpl.ticker.StrMethodFormatter(fmt) - fmt.create_dummy_axis() if prop == "colors": if not 
hasarray: @@ -137,6 +132,15 @@ def legend_elements( func_value = np.asarray(func(u)) func_is_numeric = np.issubdtype(func_value.dtype, np.number) + + if fmt is None: + if func_is_numeric: + fmt = mpl.ticker.ScalarFormatter(useOffset=False, useMathText=True) + else: + fmt = mpl.ticker.StrMethodFormatter("{x}") + elif isinstance(fmt, str): + fmt = mpl.ticker.StrMethodFormatter(fmt) + fmt.create_dummy_axis() if func_is_numeric: fmt.set_bounds(min(func_value), max(func_value)) @@ -992,30 +996,14 @@ def _to_label(x): return _to_label handles, labels = [], [] - for subtitle, prop, func, style in [ - ( - _data["hue_label"], - "colors", - to_label(_data, "colors_to_labels"), - _data["hue_style"], - ), - ( - _data["size_label"], - "sizes", - to_label(_data, "sizes_to_labels"), - _data["size_style"], - ), + for subtitle, prop, func in [ + (_data["hue_label"], "colors", to_label(_data, "colors_to_labels")), + (_data["size_label"], "sizes", to_label(_data, "sizes_to_labels")), ]: if subtitle: - # Floats are handled nicely by the defaults but strings - # needs special format to bypass numerical operations: - fmt = "{x}" if style == "discrete" else None - # Get legend handles and labels that displays the # values correctly: - hdl, lbl = legend_elements( - primitive, prop, num="auto", func=func, fmt=fmt - ) + hdl, lbl = legend_elements(primitive, prop, num="auto", func=func) hdl, lbl = _legend_add_subtitle(hdl, lbl, subtitle, ax.scatter) handles += hdl labels += lbl From 6814c6b0916086c48a8cb36d7fd0255814b76ad7 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 27 Feb 2021 09:04:35 +0100 Subject: [PATCH 22/46] use pd.unique to retain order --- xarray/plot/plot.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index fba9520d68d..0cce6f719dd 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -263,7 +263,11 @@ def _determine_array(darray, name, 
array_style): # copied from seaborn def _parse_size(data, norm, width): - """Parse sizes.""" + """ + Determine what type of data it is. Then normalize is it to width. + + If the data is categorical, normalize it to numbers. + """ plt = import_matplotlib_pyplot() if data is None: @@ -272,8 +276,11 @@ def _parse_size(data, norm, width): data = data.values.flatten() if not _is_numeric(data): - levels = np.unique(data) - numbers = np.arange(1, 1 + len(levels))[::-1] + # Data is categorical. + # Use pd.unique instead of np.unique because that keeps + # the order of the labels: + levels = pd.unique(data) + numbers = np.arange(1, 1 + len(levels)) else: levels = numbers = np.sort(np.unique(data)) From 7a9e4d5f233f23e760a5dee4c4d4ce5512faa6f7 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 21 Mar 2021 09:40:27 +0100 Subject: [PATCH 23/46] update legend_:elements --- xarray/plot/plot.py | 133 +++++++++++++++++++++++++------------------- 1 file changed, 75 insertions(+), 58 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 0cce6f719dd..433bc167b59 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -109,97 +109,114 @@ def legend_elements( handles = [] labels = [] - hasarray = self.get_array() is not None if prop == "colors": - if not hasarray: + arr = self.get_array() + if arr is None: warnings.warn( "Collection without array used. Make sure to " "specify the values to be colormapped via the " "`c` argument." 
) return handles, labels - u = np.unique(self.get_array()) - size = kwargs.pop("size", mpl.rcParams["lines.markersize"]) + _size = kwargs.pop("size", mpl.rcParams["lines.markersize"]) + + def _get_color_and_size(value): + return self.cmap(self.norm(value)), _size + elif prop == "sizes": - u = np.unique(self.get_sizes()) - color = kwargs.pop("color", "k") + arr = self.get_sizes() + _color = kwargs.pop("color", "k") + + def _get_color_and_size(value): + return _color, np.sqrt(value) + else: raise ValueError( "Valid values for `prop` are 'colors' or " f"'sizes'. You supplied '{prop}' instead." ) - func_value = np.asarray(func(u)) - func_is_numeric = np.issubdtype(func_value.dtype, np.number) + # Get the unique values and their labels: + values = np.unique(arr) + label_values = np.asarray(func(values)) + label_values_are_numeric = np.issubdtype(label_values.dtype, np.number) - if fmt is None: - if func_is_numeric: - fmt = mpl.ticker.ScalarFormatter(useOffset=False, useMathText=True) - else: - fmt = mpl.ticker.StrMethodFormatter("{x}") + # Handle the label format: + if fmt is None and label_values_are_numeric: + fmt = mpl.ticker.ScalarFormatter(useOffset=False, useMathText=True) + elif fmt is None and not label_values_are_numeric: + fmt = mpl.ticker.StrMethodFormatter("{x}") elif isinstance(fmt, str): fmt = mpl.ticker.StrMethodFormatter(fmt) fmt.create_dummy_axis() - if func_is_numeric: - fmt.set_bounds(min(func_value), max(func_value)) if num == "auto": num = 9 - if len(u) <= num: + if len(values) <= num: num = None - if num is None: - values = u - label_values = func(values) - elif not func_is_numeric: - # Values are not numerical so instead of interpolating - # just choose evenly distributed indexes instead: - def which_idxs(m, n): - out = np.rint(np.linspace(1, n, min(m, n)) - 1) - return out.astype(int) - - label_values = func(u) - cond = which_idxs(num, len(label_values)) - values = u[cond] - label_values = label_values[cond] - else: - if prop == "colors": - arr = 
self.get_array() - elif prop == "sizes": - arr = self.get_sizes() - if isinstance(num, mpl.ticker.Locator): - loc = num - elif np.iterable(num): - loc = mpl.ticker.FixedLocator(num) - else: - num = int(num) - loc = mpl.ticker.MaxNLocator( - nbins=num, min_n_ticks=num - 1, steps=[1, 2, 2.5, 3, 5, 6, 8, 10] + + if label_values_are_numeric: + label_values_min = label_values.min() + label_values_max = label_values.max() + fmt.set_bounds(label_values_min, label_values_max) + + if num is not None: + # Labels are numerical but larger than the target + # number of elements, reduce to target using matplotlibs + # ticker classes: + if isinstance(num, mpl.ticker.Locator): + loc = num + elif np.iterable(num): + loc = mpl.ticker.FixedLocator(num) + else: + num = int(num) + loc = mpl.ticker.MaxNLocator( + nbins=num, min_n_ticks=num - 1, steps=[1, 2, 2.5, 3, 5, 6, 8, 10] + ) + + # Get nicely spaced label_values: + label_values = loc.tick_values(label_values_min, label_values_max) + + # Remove extrapolated label_values: + cond = (label_values >= label_values_min) & ( + label_values <= label_values_max ) - label_values = loc.tick_values(func(arr).min(), func(arr).max()) - cond = (label_values >= func(arr).min()) & (label_values <= func(arr).max()) - label_values = label_values[cond] - yarr = np.linspace(arr.min(), arr.max(), 256) - xarr = func(yarr) - ix = np.argsort(xarr) - values = np.interp(label_values, xarr[ix], yarr[ix]) + label_values = label_values[cond] + + # Get the corresponding values by creating a linear interpolant + # with small step size: + values_interp = np.linspace(values.min(), values.max(), 256) + label_values_interp = func(values_interp) + ix = np.argsort(label_values_interp) + values = np.interp(label_values, label_values_interp[ix], values_interp[ix]) + elif num is not None and not label_values_are_numeric: + # Labels are not numerical so modifying label_values is not + # possible, instead filter the array with nicely distributed + # indexes: + if type(num) 
== int: + loc = mpl.ticker.LinearLocator(num) + else: + raise ValueError("`num` only supports integers for non-numeric labels.") + + ind = loc.tick_values(0, len(label_values) - 1).astype(int) + label_values = label_values[ind] + values = values[ind] + + # Some formatters requires set_locs: + if hasattr(fmt, "set_locs"): + fmt.set_locs(label_values) + # Default settings for handles, add or override with kwargs: kw = dict(markeredgewidth=self.get_linewidths()[0], alpha=self.get_alpha()) kw.update(kwargs) for val, lab in zip(values, label_values): - if prop == "colors": - color = self.cmap(self.norm(val)) - elif prop == "sizes": - size = np.sqrt(val) - if np.isclose(size, 0.0): - continue + color, size = _get_color_and_size(val) h = mlines.Line2D( [0], [0], ls="", color=color, ms=size, marker=self.get_paths()[0], **kw ) handles.append(h) - if hasattr(fmt, "set_locs"): - fmt.set_locs(label_values) labels.append(fmt(lab)) return handles, labels From e91aabaa991150caed943a63972ec700cd7a113d Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 21 Mar 2021 09:47:08 +0100 Subject: [PATCH 24/46] Update plot.py --- xarray/plot/plot.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 433bc167b59..473dd9cb500 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -923,7 +923,14 @@ def scatter( # need to infer size_mapping with full dataset _data.update( _infer_scatter_data( - darray, x, z, hue, _sizes, size_norm, size_mapping, _MARKERSIZE_RANGE, + darray, + x, + z, + hue, + _sizes, + size_norm, + size_mapping, + _MARKERSIZE_RANGE, ) ) From 70918755172d43a6bc8006c81c5102b4b20c4279 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 21 Mar 2021 13:33:23 +0100 Subject: [PATCH 25/46] move discrete color to OPTIONS --- xarray/core/options.py | 5 +++++ xarray/plot/plot.py | 5 +++-- 2 files changed, 8 insertions(+), 
2 deletions(-) diff --git a/xarray/core/options.py b/xarray/core/options.py index 129698903c4..46445b75325 100644 --- a/xarray/core/options.py +++ b/xarray/core/options.py @@ -1,6 +1,7 @@ import warnings ARITHMETIC_JOIN = "arithmetic_join" +CMAP_DISCRETE = "cmap_discrete" CMAP_DIVERGENT = "cmap_divergent" CMAP_SEQUENTIAL = "cmap_sequential" DISPLAY_MAX_ROWS = "display_max_rows" @@ -14,6 +15,7 @@ OPTIONS = { ARITHMETIC_JOIN: "inner", + CMAP_DISCRETE: "tab10", CMAP_DIVERGENT: "RdBu_r", CMAP_SEQUENTIAL: "viridis", DISPLAY_MAX_ROWS: 12, @@ -95,6 +97,9 @@ class set_options: - ``warn_for_unclosed_files``: whether or not to issue a warning when unclosed files are deallocated (default False). This is mostly useful for debugging. + - ``cmap_discrete``: colormap to use for discrete data plots. + Default: ``tab10``. If string, must be matplotlib built-in colormap. + Can also be a Colormap object (e.g. mpl.cm.magma) - ``cmap_sequential``: colormap to use for nondivergent data plots. Default: ``viridis``. If string, must be matplotlib built-in colormap. Can also be a Colormap object (e.g. 
mpl.cm.magma) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 473dd9cb500..23f1da0b666 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -12,6 +12,7 @@ import pandas as pd from ..core.alignment import broadcast +from ..core.options import OPTIONS from .facetgrid import _easy_facetgrid from .utils import ( _add_colorbar, @@ -882,13 +883,13 @@ def scatter( Additional keyword arguments to matplotlib """ # Handle facetgrids first + _is_facetgrid = kwargs.pop("_is_facetgrid", False) if row or col: allargs = locals().copy() allargs.update(allargs.pop("kwargs")) allargs.pop("darray") return _easy_facetgrid(darray, scatter, kind="dataarray", **allargs) - # _is_facetgrid = kwargs.pop("_is_facetgrid", False) _sizes = kwargs.pop("markersize", kwargs.pop("linewidth", None)) size_norm = kwargs.pop("size_norm", None) size_mapping = kwargs.pop("size_mapping", None) # set by facetgrid @@ -941,7 +942,7 @@ def scatter( kwargs.update(c=_data["colors"].values.ravel()) if cmap is None and _data["hue_style"] == "discrete": - cmap = "tab10" + cmap = OPTIONS["cmap_discrete"] cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( scatter, _data["colors"].values, **locals() ) From f304a02b8fed99babfedd55decbfa3e170e8b7e1 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 21 Mar 2021 15:52:10 +0100 Subject: [PATCH 26/46] support facetgrid --- xarray/plot/plot.py | 45 ++++++++++++++++++++++++++++++++++++--------- 1 file changed, 36 insertions(+), 9 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 23f1da0b666..830e5b7200d 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -797,8 +797,21 @@ def scatter( yticks=None, xlim=None, ylim=None, + add_guide=None, add_legend=None, add_colorbar=None, + cbar_kwargs=None, + cbar_ax=None, + vmin=None, + vmax=None, + norm=None, + infer_intervals=None, + center=None, + levels=None, + robust=None, + colors=None, + extend=None, + cmap=None, _labels=True, 
**kwargs, ): @@ -883,19 +896,29 @@ def scatter( Additional keyword arguments to matplotlib """ # Handle facetgrids first - _is_facetgrid = kwargs.pop("_is_facetgrid", False) if row or col: allargs = locals().copy() allargs.update(allargs.pop("kwargs")) allargs.pop("darray") - return _easy_facetgrid(darray, scatter, kind="dataarray", **allargs) + subplot_kws = dict(projection="3d") if z is not None else {} + return _easy_facetgrid( + darray, scatter, kind="dataarray", subplot_kws=subplot_kws, **allargs + ) + + # Further + _is_facetgrid = kwargs.pop("_is_facetgrid", False) + if _is_facetgrid: + # Why do I need to pop these here? + kwargs.pop("y", None) + kwargs.pop("args", None) + kwargs.pop("add_labels", None) _sizes = kwargs.pop("markersize", kwargs.pop("linewidth", None)) size_norm = kwargs.pop("size_norm", None) size_mapping = kwargs.pop("size_mapping", None) # set by facetgrid - cbar_ax = kwargs.pop("cbar_ax", None) - cbar_kwargs = kwargs.pop("cbar_kwargs", None) - cmap = kwargs.pop("cmap", None) + # cbar_ax = kwargs.pop("cbar_ax", None) + # cbar_kwargs = kwargs.pop("cbar_kwargs", None) + # cmap = kwargs.pop("cmap", None) cmap_params = kwargs.pop("cmap_params", None) figsize = kwargs.pop("figsize", None) @@ -906,7 +929,7 @@ def scatter( _data = _infer_meta_data(darray, x, z, hue, hue_style, _sizes) - add_guide = kwargs.pop("add_guide", None) + # add_guide = kwargs.pop("add_guide", None) if add_legend is not None: pass elif add_guide is None or add_guide is True: @@ -935,7 +958,9 @@ def scatter( ) ) - # Plot the data: + # Plot the data. 3d plots has the z value in upward direction + # instead of y. 
To make jumping between 2d and 3d easy and intuitive + # switch the order so that z is shown in the depthwise direction: axis_order = ["x", "z", "y"] cmap_params_subset = {} if _data["hue"] is not None: @@ -967,7 +992,7 @@ def scatter( # Set x, y, z labels: i = 0 - set_label = [ax.set_xlabel, ax.set_ylabel, getattr(ax, "set_zlabel", [])] + set_label = [ax.set_xlabel, ax.set_ylabel, getattr(ax, "set_zlabel", None)] for v in axis_order: if _data.get(f"{v}label", None) is not None: set_label[i](_data[f"{v}label"]) @@ -1034,7 +1059,9 @@ def _to_label(x): ]: if subtitle: # Get legend handles and labels that displays the - # values correctly: + # values correctly. Order might be different because + # legend_elements uses np.unique instead of pd.unique, + # FacetGrid.add_legend might have troubles with this: hdl, lbl = legend_elements(primitive, prop, num="auto", func=func) hdl, lbl = _legend_add_subtitle(hdl, lbl, subtitle, ax.scatter) handles += hdl From 808115d55b0d5cb012ab7e836d4221690a1e2c29 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 21 Mar 2021 15:53:22 +0100 Subject: [PATCH 27/46] Update plot.py --- xarray/plot/plot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 830e5b7200d..f26714959b5 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -900,7 +900,7 @@ def scatter( allargs = locals().copy() allargs.update(allargs.pop("kwargs")) allargs.pop("darray") - subplot_kws = dict(projection="3d") if z is not None else {} + subplot_kws = dict(projection="3d") if z is not None else None return _easy_facetgrid( darray, scatter, kind="dataarray", subplot_kws=subplot_kws, **allargs ) From ade494a0b98bf6747a3a9f9985d65b30f4be1609 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 21 Mar 2021 19:26:11 +0100 Subject: [PATCH 28/46] add tests for z and facetgrids --- xarray/plot/plot.py | 3 +-- 
xarray/tests/test_plot.py | 15 ++++++++++----- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index f26714959b5..9ff470cdefd 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -797,7 +797,6 @@ def scatter( yticks=None, xlim=None, ylim=None, - add_guide=None, add_legend=None, add_colorbar=None, cbar_kwargs=None, @@ -929,7 +928,7 @@ def scatter( _data = _infer_meta_data(darray, x, z, hue, hue_style, _sizes) - # add_guide = kwargs.pop("add_guide", None) + add_guide = kwargs.pop("add_guide", None) if add_legend is not None: pass elif add_guide is None or add_guide is True: diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 6a4b1a4429c..039c711002a 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2594,14 +2594,18 @@ def test_get_axis_cartopy(): @requires_matplotlib @pytest.mark.parametrize( - "x, y, hue, markersize, add_legend, add_colorbar", + "x, y, z, hue, markersize, row, col, add_legend, add_colorbar", [ - ("A", "B", None, None, None, None), - ("B", "A", "w", None, True, None), - ("A", "B", "y", "z", True, True), + ("A", "B", None, None, None, None, None, None, None), + ("B", "A", None, "w", None, None, None, True, None), + ("A", "B", None, "y", "x", None, None, True, True), + ("A", "B", "z", None, None, None, None, None, None), + ("B", "A", "z", "w", None, None, None, True, None), + ("A", "B", "z", "y", "x", None, None, True, True), + ("A", "B", "z", "y", "x", "w", None, True, True), ], ) -def test_datarray_scatter(x, y, hue, markersize, add_legend, add_colorbar): +def test_datarray_scatter(x, y, z, hue, markersize, row, col, add_legend, add_colorbar): """Test datarray scatter. 
Merge with TestPlot1D eventually.""" ds = xr.tutorial.scatter_example_dataset() @@ -2618,6 +2622,7 @@ def test_datarray_scatter(x, y, hue, markersize, add_legend, add_colorbar): with figure_context(): darray.plot.scatter( x=x, + z=z, hue=hue, markersize=markersize, add_legend=add_legend, From 246338a362742a3fd42baff132be49c351fdf7dd Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 23 Mar 2021 18:02:15 +0100 Subject: [PATCH 29/46] test increasing min dependency --- ci/requirements/py37-min-all-deps.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/requirements/py37-min-all-deps.yml b/ci/requirements/py37-min-all-deps.yml index 7d04f431935..6eb4e938de8 100644 --- a/ci/requirements/py37-min-all-deps.yml +++ b/ci/requirements/py37-min-all-deps.yml @@ -23,7 +23,7 @@ dependencies: - hypothesis - iris=2.4 - lxml=4.5 # Optional dep of pydap - - matplotlib-base=3.1 + - matplotlib-base=3.2 - nc-time-axis=1.2 # netcdf follows a 1.major.minor[.patch] convention (see https://github.com/Unidata/netcdf4-python/issues/1090) # bumping the netCDF4 version is currently blocked by #4491 From 4e6c41d1c0c26cc509812bbdf6740fb2ec67e2db Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 23 Mar 2021 20:55:40 +0100 Subject: [PATCH 30/46] reduce code duplication --- xarray/plot/plot.py | 108 +++++++++++++++----------------------------- 1 file changed, 37 insertions(+), 71 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 9ff470cdefd..3cd557df4bd 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -233,50 +233,32 @@ def _determine_array(darray, name, array_style): array_style = "continuous" if array_is_numeric else "discrete" elif array_style not in ["discrete", "continuous"]: raise ValueError( - "hue_style must be either None, 'discrete' or 'continuous'." 
- ) - - if not array_is_numeric and (array_style == "continuous"): - raise ValueError( - f"Cannot create a colorbar for a non numeric coordinate: {name}" + f"The style '{array_style}' is not valid, " + "valid options are None, 'discrete' or 'continuous'." ) array_label = label_from_attrs(array) return array, array_style, array_label - # if x is not None and y is not None: - # raise ValueError("Cannot specify both x and y kwargs for line plots.") - - label = dict(y=label_from_attrs(darray)) - label.update( + # Add nice looking labels: + out = dict(ylabel=label_from_attrs(darray)) + out.update( { k: label_from_attrs(darray[v]) if v in darray.coords else None - for k, v in [("x", x), ("z", z)] + for k, v in [("xlabel", x), ("zlabel", z)] } ) - if hue: - hue, hue_style, hue_label = _determine_array(darray, hue, hue_style) - else: - hue, hue_style, hue_label = None, None, None + # Add styles and labels for the dataarrays: + for type_, a, style, in [("hue", hue, hue_style), ("size", size, None)]: + tp, stl, lbl = f"{type_}", f"{type_}_style", f"{type_}_label" + if a: + out[tp], out[stl], out[lbl] = _determine_array(darray, a, style) + else: + out[tp], out[stl], out[lbl] = None, None, None - if size: - size, size_style, size_label = _determine_array(darray, size, None) - else: - size, size_style, size_label = None, None, None - - return dict( - xlabel=label["x"], - ylabel=label["y"], - zlabel=label["z"], - hue=hue, - hue_label=hue_label, - hue_style=hue_style, - size=size, - size_label=size_label, - size_style=size_style, - ) + return out # copied from seaborn @@ -338,42 +320,29 @@ def _infer_scatter_data( to_broadcast.update( { key: darray[value] - for key, value in dict(hue=hue, sizes=size).items() + for key, value in dict(hue=hue, size=size).items() if value in darray.dims } ) broadcasted = dict(zip(to_broadcast.keys(), broadcast(*(to_broadcast.values())))) - broadcasted.update( - hue=broadcasted.pop("hue", None), sizes=broadcasted.pop("sizes", None) - ) - if hue: - # 
if hue_mapping is None: - hue_mapping = _parse_size(broadcasted["hue"], None, [0, 1]) - - broadcasted["colors"] = broadcasted["hue"].copy( - data=np.reshape( - hue_mapping.loc[broadcasted["hue"].values.ravel()].values, - broadcasted["hue"].shape, - ) - ) - broadcasted["colors_to_labels"] = pd.Series( - hue_mapping.index, index=hue_mapping - ) - - if size: - if size_mapping is None: - size_mapping = _parse_size(broadcasted["sizes"], size_norm, size_range) - - broadcasted["sizes"] = broadcasted["sizes"].copy( - data=np.reshape( - size_mapping.loc[broadcasted["sizes"].values.ravel()].values, - broadcasted["sizes"].shape, + # Normalize hue and size and create lookup tables: + for type_, mapping, norm, width in [ + ("hue", None, None, [0, 1]), + ("size", size_mapping, size_norm, size_range), + ]: + broadcasted_type = broadcasted.get(type_, None) + if broadcasted_type is not None: + if mapping is None: + mapping = _parse_size(broadcasted_type, norm, width) + + broadcasted[type_] = broadcasted_type.copy( + data=np.reshape( + mapping.loc[broadcasted_type.values.ravel()].values, + broadcasted_type.shape, + ) ) - ) - broadcasted["sizes_to_labels"] = pd.Series( - size_mapping.index, index=size_mapping - ) + broadcasted[f"{type_}_to_label"] = pd.Series(mapping.index, index=mapping) return broadcasted @@ -915,9 +884,6 @@ def scatter( _sizes = kwargs.pop("markersize", kwargs.pop("linewidth", None)) size_norm = kwargs.pop("size_norm", None) size_mapping = kwargs.pop("size_mapping", None) # set by facetgrid - # cbar_ax = kwargs.pop("cbar_ax", None) - # cbar_kwargs = kwargs.pop("cbar_kwargs", None) - # cmap = kwargs.pop("cmap", None) cmap_params = kwargs.pop("cmap_params", None) figsize = kwargs.pop("figsize", None) @@ -963,12 +929,12 @@ def scatter( axis_order = ["x", "z", "y"] cmap_params_subset = {} if _data["hue"] is not None: - kwargs.update(c=_data["colors"].values.ravel()) + kwargs.update(c=_data["hue"].values.ravel()) if cmap is None and _data["hue_style"] == "discrete": 
cmap = OPTIONS["cmap_discrete"] cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( - scatter, _data["colors"].values, **locals() + scatter, _data["hue"].values, **locals() ) # subset that can be passed to scatter, hist2d @@ -976,8 +942,8 @@ def scatter( vv: cmap_params[vv] for vv in ["vmin", "vmax", "norm", "cmap"] } - if _data["sizes"] is not None: - kwargs.update(s=_data["sizes"].values.ravel()) + if _data["size"] is not None: + kwargs.update(s=_data["size"].values.ravel()) primitive = ax.scatter( *[ @@ -1053,8 +1019,8 @@ def _to_label(x): handles, labels = [], [] for subtitle, prop, func in [ - (_data["hue_label"], "colors", to_label(_data, "colors_to_labels")), - (_data["size_label"], "sizes", to_label(_data, "sizes_to_labels")), + (_data["hue_label"], "colors", to_label(_data, "hue_to_label")), + (_data["size_label"], "sizes", to_label(_data, "size_to_label")), ]: if subtitle: # Get legend handles and labels that displays the From 0cddfca425382752a23607f53c6246a228d0c07a Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 23 Mar 2021 21:03:29 +0100 Subject: [PATCH 31/46] Update plot.py --- xarray/plot/plot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 3cd557df4bd..d8feef808f1 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -251,7 +251,7 @@ def _determine_array(darray, name, array_style): ) # Add styles and labels for the dataarrays: - for type_, a, style, in [("hue", hue, hue_style), ("size", size, None)]: + for type_, a, style in [("hue", hue, hue_style), ("size", size, None)]: tp, stl, lbl = f"{type_}", f"{type_}_style", f"{type_}_label" if a: out[tp], out[stl], out[lbl] = _determine_array(darray, a, style) From 6ee18da43425e8b3b1153e882846f9ccbee9ba05 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 26 Mar 2021 18:15:56 +0100 Subject: [PATCH 32/46] get cmap from 
_process_cmap_cbar_kwargs --- xarray/core/options.py | 5 ----- xarray/plot/plot.py | 3 --- 2 files changed, 8 deletions(-) diff --git a/xarray/core/options.py b/xarray/core/options.py index 46445b75325..129698903c4 100644 --- a/xarray/core/options.py +++ b/xarray/core/options.py @@ -1,7 +1,6 @@ import warnings ARITHMETIC_JOIN = "arithmetic_join" -CMAP_DISCRETE = "cmap_discrete" CMAP_DIVERGENT = "cmap_divergent" CMAP_SEQUENTIAL = "cmap_sequential" DISPLAY_MAX_ROWS = "display_max_rows" @@ -15,7 +14,6 @@ OPTIONS = { ARITHMETIC_JOIN: "inner", - CMAP_DISCRETE: "tab10", CMAP_DIVERGENT: "RdBu_r", CMAP_SEQUENTIAL: "viridis", DISPLAY_MAX_ROWS: 12, @@ -97,9 +95,6 @@ class set_options: - ``warn_for_unclosed_files``: whether or not to issue a warning when unclosed files are deallocated (default False). This is mostly useful for debugging. - - ``cmap_discrete``: colormap to use for discrete data plots. - Default: ``tab10``. If string, must be matplotlib built-in colormap. - Can also be a Colormap object (e.g. mpl.cm.magma) - ``cmap_sequential``: colormap to use for nondivergent data plots. Default: ``viridis``. If string, must be matplotlib built-in colormap. Can also be a Colormap object (e.g. 
mpl.cm.magma) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index d8feef808f1..72e5467f2ba 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -930,9 +930,6 @@ def scatter( cmap_params_subset = {} if _data["hue"] is not None: kwargs.update(c=_data["hue"].values.ravel()) - - if cmap is None and _data["hue_style"] == "discrete": - cmap = OPTIONS["cmap_discrete"] cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( scatter, _data["hue"].values, **locals() ) From 2a100403bedb4b68f6daf194f342f9bbbc06b382 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 26 Mar 2021 19:20:43 +0100 Subject: [PATCH 33/46] move funcs to utils --- xarray/plot/plot.py | 238 ++----------------------------------------- xarray/plot/utils.py | 228 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 235 insertions(+), 231 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 72e5467f2ba..f97867706f7 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -12,15 +12,16 @@ import pandas as pd from ..core.alignment import broadcast -from ..core.options import OPTIONS from .facetgrid import _easy_facetgrid from .utils import ( _add_colorbar, + _adjust_legend_subtitles, _assert_valid_xy, _ensure_plottable, _infer_interval_breaks, _infer_xy_labels, _is_numeric, + _legend_add_subtitle, _process_cmap_cbar_kwargs, _rescale_imshow_rgb, _resolve_intervals_1dplot, @@ -29,200 +30,13 @@ get_axis, import_matplotlib_pyplot, label_from_attrs, + legend_elements, ) # copied from seaborn _MARKERSIZE_RANGE = np.array([18.0, 72.0]) -# Copied from matplotlib, tweaked so func can return strings. -# https://github.com/matplotlib/matplotlib/issues/19555 -def legend_elements( - self, prop="colors", num="auto", fmt=None, func=lambda x: x, **kwargs -): - """ - Create legend handles and labels for a PathCollection. 
- - Each legend handle is a `.Line2D` representing the Path that was drawn, - and each label is a string what each Path represents. - - This is useful for obtaining a legend for a `~.Axes.scatter` plot; - e.g.:: - - scatter = plt.scatter([1, 2, 3], [4, 5, 6], c=[7, 2, 3]) - plt.legend(*scatter.legend_elements()) - - creates three legend elements, one for each color with the numerical - values passed to *c* as the labels. - - Also see the :ref:`automatedlegendcreation` example. - - - Parameters - ---------- - prop : {"colors", "sizes"}, default: "colors" - If "colors", the legend handles will show the different colors of - the collection. If "sizes", the legend will show the different - sizes. To set both, use *kwargs* to directly edit the `.Line2D` - properties. - num : int, None, "auto" (default), array-like, or `~.ticker.Locator` - Target number of elements to create. - If None, use all unique elements of the mappable array. If an - integer, target to use *num* elements in the normed range. - If *"auto"*, try to determine which option better suits the nature - of the data. - The number of created elements may slightly deviate from *num* due - to a `~.ticker.Locator` being used to find useful locations. - If a list or array, use exactly those elements for the legend. - Finally, a `~.ticker.Locator` can be provided. - fmt : str, `~matplotlib.ticker.Formatter`, or None (default) - The format or formatter to use for the labels. If a string must be - a valid input for a `~.StrMethodFormatter`. If None (the default), - use a `~.ScalarFormatter`. - func : function, default: ``lambda x: x`` - Function to calculate the labels. Often the size (or color) - argument to `~.Axes.scatter` will have been pre-processed by the - user using a function ``s = f(x)`` to make the markers visible; - e.g. ``size = np.log10(x)``. Providing the inverse of this - function here allows that pre-processing to be inverted, so that - the legend labels have the correct values; e.g. 
``func = lambda - x: 10**x``. - **kwargs - Allowed keyword arguments are *color* and *size*. E.g. it may be - useful to set the color of the markers if *prop="sizes"* is used; - similarly to set the size of the markers if *prop="colors"* is - used. Any further parameters are passed onto the `.Line2D` - instance. This may be useful to e.g. specify a different - *markeredgecolor* or *alpha* for the legend handles. - - Returns - ------- - handles : list of `.Line2D` - Visual representation of each element of the legend. - labels : list of str - The string labels for elements of the legend. - """ - import warnings - - import matplotlib as mpl - - mlines = mpl.lines - - handles = [] - labels = [] - - if prop == "colors": - arr = self.get_array() - if arr is None: - warnings.warn( - "Collection without array used. Make sure to " - "specify the values to be colormapped via the " - "`c` argument." - ) - return handles, labels - _size = kwargs.pop("size", mpl.rcParams["lines.markersize"]) - - def _get_color_and_size(value): - return self.cmap(self.norm(value)), _size - - elif prop == "sizes": - arr = self.get_sizes() - _color = kwargs.pop("color", "k") - - def _get_color_and_size(value): - return _color, np.sqrt(value) - - else: - raise ValueError( - "Valid values for `prop` are 'colors' or " - f"'sizes'. You supplied '{prop}' instead." 
- ) - - # Get the unique values and their labels: - values = np.unique(arr) - label_values = np.asarray(func(values)) - label_values_are_numeric = np.issubdtype(label_values.dtype, np.number) - - # Handle the label format: - if fmt is None and label_values_are_numeric: - fmt = mpl.ticker.ScalarFormatter(useOffset=False, useMathText=True) - elif fmt is None and not label_values_are_numeric: - fmt = mpl.ticker.StrMethodFormatter("{x}") - elif isinstance(fmt, str): - fmt = mpl.ticker.StrMethodFormatter(fmt) - fmt.create_dummy_axis() - - if num == "auto": - num = 9 - if len(values) <= num: - num = None - - if label_values_are_numeric: - label_values_min = label_values.min() - label_values_max = label_values.max() - fmt.set_bounds(label_values_min, label_values_max) - - if num is not None: - # Labels are numerical but larger than the target - # number of elements, reduce to target using matplotlibs - # ticker classes: - if isinstance(num, mpl.ticker.Locator): - loc = num - elif np.iterable(num): - loc = mpl.ticker.FixedLocator(num) - else: - num = int(num) - loc = mpl.ticker.MaxNLocator( - nbins=num, min_n_ticks=num - 1, steps=[1, 2, 2.5, 3, 5, 6, 8, 10] - ) - - # Get nicely spaced label_values: - label_values = loc.tick_values(label_values_min, label_values_max) - - # Remove extrapolated label_values: - cond = (label_values >= label_values_min) & ( - label_values <= label_values_max - ) - label_values = label_values[cond] - - # Get the corresponding values by creating a linear interpolant - # with small step size: - values_interp = np.linspace(values.min(), values.max(), 256) - label_values_interp = func(values_interp) - ix = np.argsort(label_values_interp) - values = np.interp(label_values, label_values_interp[ix], values_interp[ix]) - elif num is not None and not label_values_are_numeric: - # Labels are not numerical so modifying label_values is not - # possible, instead filter the array with nicely distributed - # indexes: - if type(num) == int: - loc = 
mpl.ticker.LinearLocator(num) - else: - raise ValueError("`num` only supports integers for non-numeric labels.") - - ind = loc.tick_values(0, len(label_values) - 1).astype(int) - label_values = label_values[ind] - values = values[ind] - - # Some formatters requires set_locs: - if hasattr(fmt, "set_locs"): - fmt.set_locs(label_values) - - # Default settings for handles, add or override with kwargs: - kw = dict(markeredgewidth=self.get_linewidths()[0], alpha=self.get_alpha()) - kw.update(kwargs) - - for val, lab in zip(values, label_values): - color, size = _get_color_and_size(val) - h = mlines.Line2D( - [0], [0], ls="", color=color, ms=size, marker=self.get_paths()[0], **kw - ) - handles.append(h) - labels.append(fmt(lab)) - - return handles, labels - - def _infer_meta_data(darray, x, z, hue, hue_style, size): def _determine_array(darray, name, array_style): """Find and determine what type of array it is.""" @@ -923,10 +737,6 @@ def scatter( ) ) - # Plot the data. 3d plots has the z value in upward direction - # instead of y. To make jumping between 2d and 3d easy and intuitive - # switch the order so that z is shown in the depthwise direction: - axis_order = ["x", "z", "y"] cmap_params_subset = {} if _data["hue"] is not None: kwargs.update(c=_data["hue"].values.ravel()) @@ -942,6 +752,10 @@ def scatter( if _data["size"] is not None: kwargs.update(s=_data["size"].values.ravel()) + # Plot the data. 3d plots has the z value in upward direction + # instead of y. 
To make jumping between 2d and 3d easy and intuitive + # switch the order so that z is shown in the depthwise direction: + axis_order = ["x", "z", "y"] primitive = ax.scatter( *[ _data[v].values.ravel() @@ -960,44 +774,6 @@ def scatter( set_label[i](_data[f"{v}label"]) i += 1 - def _legend_add_subtitle(handles, labels, text, func): - """Add a subtitle to legend handles.""" - if text and len(handles) > 1: - # Create a blank handle that's not visible, the - # invisibillity will be used to discern which are subtitles - # or not: - blank_handle = func([], [], label=text) - blank_handle.set_visible(False) - - # Subtitles are shown first: - handles = [blank_handle] + handles - labels = [text] + labels - - return handles, labels - - def _adjust_legend_subtitles(legend): - """Make invisible-handle "subtitles" entries look more like titles.""" - plt = import_matplotlib_pyplot() - - # Legend title not in rcParams until 3.0 - font_size = plt.rcParams.get("legend.title_fontsize", None) - hpackers = legend.findobj(plt.matplotlib.offsetbox.VPacker)[0].get_children() - for hpack in hpackers: - draw_area, text_area = hpack.get_children() - handles = draw_area.get_children() - - # Assume that all artists that are not visible are - # subtitles: - if not all(artist.get_visible() for artist in handles): - # Remove the dummy marker which will bring the text - # more to the center: - draw_area.set_width(0) - for text in text_area.get_children(): - if font_size is not None: - # The sutbtitles should have the same font size - # as normal legend titles: - text.set_size(font_size) - if add_legend: def to_label(d, key): diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index a83bc28e273..8d341d12d15 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -853,3 +853,231 @@ def _get_nice_quiver_magnitude(u, v): mean = np.mean(np.hypot(u.values, v.values)) magnitude = ticker.tick_values(0, mean)[-2] return magnitude + + +# Copied from matplotlib, tweaked so func can return 
strings. +# https://github.com/matplotlib/matplotlib/issues/19555 +def legend_elements( + self, prop="colors", num="auto", fmt=None, func=lambda x: x, **kwargs +): + """ + Create legend handles and labels for a PathCollection. + + Each legend handle is a `.Line2D` representing the Path that was drawn, + and each label is a string what each Path represents. + + This is useful for obtaining a legend for a `~.Axes.scatter` plot; + e.g.:: + + scatter = plt.scatter([1, 2, 3], [4, 5, 6], c=[7, 2, 3]) + plt.legend(*scatter.legend_elements()) + + creates three legend elements, one for each color with the numerical + values passed to *c* as the labels. + + Also see the :ref:`automatedlegendcreation` example. + + + Parameters + ---------- + prop : {"colors", "sizes"}, default: "colors" + If "colors", the legend handles will show the different colors of + the collection. If "sizes", the legend will show the different + sizes. To set both, use *kwargs* to directly edit the `.Line2D` + properties. + num : int, None, "auto" (default), array-like, or `~.ticker.Locator` + Target number of elements to create. + If None, use all unique elements of the mappable array. If an + integer, target to use *num* elements in the normed range. + If *"auto"*, try to determine which option better suits the nature + of the data. + The number of created elements may slightly deviate from *num* due + to a `~.ticker.Locator` being used to find useful locations. + If a list or array, use exactly those elements for the legend. + Finally, a `~.ticker.Locator` can be provided. + fmt : str, `~matplotlib.ticker.Formatter`, or None (default) + The format or formatter to use for the labels. If a string must be + a valid input for a `~.StrMethodFormatter`. If None (the default), + use a `~.ScalarFormatter`. + func : function, default: ``lambda x: x`` + Function to calculate the labels. 
Often the size (or color) + argument to `~.Axes.scatter` will have been pre-processed by the + user using a function ``s = f(x)`` to make the markers visible; + e.g. ``size = np.log10(x)``. Providing the inverse of this + function here allows that pre-processing to be inverted, so that + the legend labels have the correct values; e.g. ``func = lambda + x: 10**x``. + **kwargs + Allowed keyword arguments are *color* and *size*. E.g. it may be + useful to set the color of the markers if *prop="sizes"* is used; + similarly to set the size of the markers if *prop="colors"* is + used. Any further parameters are passed onto the `.Line2D` + instance. This may be useful to e.g. specify a different + *markeredgecolor* or *alpha* for the legend handles. + + Returns + ------- + handles : list of `.Line2D` + Visual representation of each element of the legend. + labels : list of str + The string labels for elements of the legend. + """ + import warnings + + import matplotlib as mpl + + mlines = mpl.lines + + handles = [] + labels = [] + + if prop == "colors": + arr = self.get_array() + if arr is None: + warnings.warn( + "Collection without array used. Make sure to " + "specify the values to be colormapped via the " + "`c` argument." + ) + return handles, labels + _size = kwargs.pop("size", mpl.rcParams["lines.markersize"]) + + def _get_color_and_size(value): + return self.cmap(self.norm(value)), _size + + elif prop == "sizes": + arr = self.get_sizes() + _color = kwargs.pop("color", "k") + + def _get_color_and_size(value): + return _color, np.sqrt(value) + + else: + raise ValueError( + "Valid values for `prop` are 'colors' or " + f"'sizes'. You supplied '{prop}' instead." 
+ ) + + # Get the unique values and their labels: + values = np.unique(arr) + label_values = np.asarray(func(values)) + label_values_are_numeric = np.issubdtype(label_values.dtype, np.number) + + # Handle the label format: + if fmt is None and label_values_are_numeric: + fmt = mpl.ticker.ScalarFormatter(useOffset=False, useMathText=True) + elif fmt is None and not label_values_are_numeric: + fmt = mpl.ticker.StrMethodFormatter("{x}") + elif isinstance(fmt, str): + fmt = mpl.ticker.StrMethodFormatter(fmt) + fmt.create_dummy_axis() + + if num == "auto": + num = 9 + if len(values) <= num: + num = None + + if label_values_are_numeric: + label_values_min = label_values.min() + label_values_max = label_values.max() + fmt.set_bounds(label_values_min, label_values_max) + + if num is not None: + # Labels are numerical but larger than the target + # number of elements, reduce to target using matplotlibs + # ticker classes: + if isinstance(num, mpl.ticker.Locator): + loc = num + elif np.iterable(num): + loc = mpl.ticker.FixedLocator(num) + else: + num = int(num) + loc = mpl.ticker.MaxNLocator( + nbins=num, min_n_ticks=num - 1, steps=[1, 2, 2.5, 3, 5, 6, 8, 10] + ) + + # Get nicely spaced label_values: + label_values = loc.tick_values(label_values_min, label_values_max) + + # Remove extrapolated label_values: + cond = (label_values >= label_values_min) & ( + label_values <= label_values_max + ) + label_values = label_values[cond] + + # Get the corresponding values by creating a linear interpolant + # with small step size: + values_interp = np.linspace(values.min(), values.max(), 256) + label_values_interp = func(values_interp) + ix = np.argsort(label_values_interp) + values = np.interp(label_values, label_values_interp[ix], values_interp[ix]) + elif num is not None and not label_values_are_numeric: + # Labels are not numerical so modifying label_values is not + # possible, instead filter the array with nicely distributed + # indexes: + if type(num) == int: + loc = 
mpl.ticker.LinearLocator(num) + else: + raise ValueError("`num` only supports integers for non-numeric labels.") + + ind = loc.tick_values(0, len(label_values) - 1).astype(int) + label_values = label_values[ind] + values = values[ind] + + # Some formatters requires set_locs: + if hasattr(fmt, "set_locs"): + fmt.set_locs(label_values) + + # Default settings for handles, add or override with kwargs: + kw = dict(markeredgewidth=self.get_linewidths()[0], alpha=self.get_alpha()) + kw.update(kwargs) + + for val, lab in zip(values, label_values): + color, size = _get_color_and_size(val) + h = mlines.Line2D( + [0], [0], ls="", color=color, ms=size, marker=self.get_paths()[0], **kw + ) + handles.append(h) + labels.append(fmt(lab)) + + return handles, labels + + +def _legend_add_subtitle(handles, labels, text, func): + """Add a subtitle to legend handles.""" + if text and len(handles) > 1: + # Create a blank handle that's not visible, the + # invisibillity will be used to discern which are subtitles + # or not: + blank_handle = func([], [], label=text) + blank_handle.set_visible(False) + + # Subtitles are shown first: + handles = [blank_handle] + handles + labels = [text] + labels + + return handles, labels + + +def _adjust_legend_subtitles(legend): + """Make invisible-handle "subtitles" entries look more like titles.""" + plt = import_matplotlib_pyplot() + + # Legend title not in rcParams until 3.0 + font_size = plt.rcParams.get("legend.title_fontsize", None) + hpackers = legend.findobj(plt.matplotlib.offsetbox.VPacker)[0].get_children() + for hpack in hpackers: + draw_area, text_area = hpack.get_children() + handles = draw_area.get_children() + + # Assume that all artists that are not visible are + # subtitles: + if not all(artist.get_visible() for artist in handles): + # Remove the dummy marker which will bring the text + # more to the center: + draw_area.set_width(0) + for text in text_area.get_children(): + if font_size is not None: + # The sutbtitles should have the 
same font size + # as normal legend titles: + text.set_size(font_size) From 7017d95a673a2aecc1d7efd09e7620b81bc0f8cf Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 26 Mar 2021 19:47:20 +0100 Subject: [PATCH 34/46] simplify with functools.partial --- xarray/plot/plot.py | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index f97867706f7..76dd5679c95 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -776,24 +776,28 @@ def scatter( if add_legend: - def to_label(d, key): + def to_label(x, data, key): """Map prop values back to its original values.""" - - def _to_label(x): - if key in d: - # Use reindex to be less sensitive to float errors. - # Return as numpy array since legend_elements - # seems to require that: - return d[key].reindex(x, method="nearest").to_numpy() - else: - return x - - return _to_label + if key in data: + # Use reindex to be less sensitive to float errors. 
+ # Return as numpy array since legend_elements + # seems to require that: + return data[key].reindex(x, method="nearest").to_numpy() + else: + return x handles, labels = [], [] for subtitle, prop, func in [ - (_data["hue_label"], "colors", to_label(_data, "hue_to_label")), - (_data["size_label"], "sizes", to_label(_data, "size_to_label")), + ( + _data["hue_label"], + "colors", + functools.partial(to_label, data=_data, key="hue_to_label"), + ), + ( + _data["size_label"], + "sizes", + functools.partial(to_label, data=_data, key="size_to_label"), + ), ]: if subtitle: # Get legend handles and labels that displays the From 343b30ca29678f8ecab6d938a53e6ad90a217915 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 26 Mar 2021 20:16:10 +0100 Subject: [PATCH 35/46] simplify some more --- xarray/plot/plot.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 76dd5679c95..f368bfedf61 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -132,11 +132,7 @@ def _infer_scatter_data( {k: darray[v] for k, v in dict(x=x, z=z).items() if v is not None} ) to_broadcast.update( - { - key: darray[value] - for key, value in dict(hue=hue, size=size).items() - if value in darray.dims - } + {k: darray[v] for k, v in dict(hue=hue, size=size).items() if v in darray.dims} ) broadcasted = dict(zip(to_broadcast.keys(), broadcast(*(to_broadcast.values())))) @@ -776,7 +772,7 @@ def scatter( if add_legend: - def to_label(x, data, key): + def to_label(data, key, x): """Map prop values back to its original values.""" if key in data: # Use reindex to be less sensitive to float errors. 
@@ -791,12 +787,12 @@ def to_label(x, data, key): ( _data["hue_label"], "colors", - functools.partial(to_label, data=_data, key="hue_to_label"), + functools.partial(to_label, _data, "hue_to_label"), ), ( _data["size_label"], "sizes", - functools.partial(to_label, data=_data, key="size_to_label"), + functools.partial(to_label, _data, "size_to_label"), ), ]: if subtitle: From 131aecafd6bf691ae1a0130e9e5e5686daec566d Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 27 Mar 2021 11:41:37 +0100 Subject: [PATCH 36/46] activate 3d plotting by importing The Axes3D import should be removed when min req matplotlib >= 3.2. --- ci/requirements/py37-min-all-deps.yml | 2 +- xarray/plot/plot.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/ci/requirements/py37-min-all-deps.yml b/ci/requirements/py37-min-all-deps.yml index 6eb4e938de8..7d04f431935 100644 --- a/ci/requirements/py37-min-all-deps.yml +++ b/ci/requirements/py37-min-all-deps.yml @@ -23,7 +23,7 @@ dependencies: - hypothesis - iris=2.4 - lxml=4.5 # Optional dep of pydap - - matplotlib-base=3.2 + - matplotlib-base=3.1 - nc-time-axis=1.2 # netcdf follows a 1.major.minor[.patch] convention (see https://github.com/Unidata/netcdf4-python/issues/1090) # bumping the netCDF4 version is currently blocked by #4491 diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index f368bfedf61..8265555dd86 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -78,7 +78,7 @@ def _determine_array(darray, name, array_style): # copied from seaborn def _parse_size(data, norm, width): """ - Determine what type of data it is. Then normalize is it to width. + Determine what type of data it is. Then normalize it to width. If the data is categorical, normalize it to numbers.
""" @@ -87,7 +87,7 @@ def _parse_size(data, norm, width): if data is None: return None - data = data.values.flatten() + data = data.values.ravel() if not _is_numeric(data): # Data is categorical. @@ -699,6 +699,10 @@ def scatter( figsize = kwargs.pop("figsize", None) subplot_kws = dict() if z is not None and ax is None: + # TODO: Importing Axes3D is not necessary in matplotlib >= 3.2. + # Remove when minimum requirement of matplotlib is 3.2: + from mpl_toolkits.mplot3d import Axes3D # NOQA + subplot_kws.update(projection="3d") ax = get_axis(figsize, size, aspect, ax, **subplot_kws) From 965d610a49f198f069fec091aa36c6d3030916a0 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 27 Mar 2021 11:57:28 +0100 Subject: [PATCH 37/46] Update plot.py --- xarray/plot/plot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 8265555dd86..da0ac4118e2 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -701,7 +701,7 @@ def scatter( if z is not None and ax is None: # TODO: Importing Axes3D is not necessary in matplotlib >= 3.2. 
# Remove when minimum requirement of matplotlib is 3.2: - from mpl_toolkits.mplot3d import Axes3D # NOQA + from mpl_toolkits.mplot3d import Axes3D # type: ignore # noqa subplot_kws.update(projection="3d") ax = get_axis(figsize, size, aspect, ax, **subplot_kws) From 1d2a61c35e1011c588a2812b12bf59373230ac30 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 1 Apr 2021 19:10:11 +0200 Subject: [PATCH 38/46] Update test_plot.py --- xarray/tests/test_plot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index d230a3efa1d..6804fff9c18 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2654,6 +2654,7 @@ def test_get_axis_cartopy(): with figure_context(): ax = get_axis(**kwargs) assert isinstance(ax, ctpy.mpl.geoaxes.GeoAxesSubplot) + assert isinstance(ax, cartopy.mpl.geoaxes.GeoAxesSubplot) @requires_matplotlib @@ -2692,4 +2693,3 @@ def test_datarray_scatter(x, y, z, hue, markersize, row, col, add_legend, add_co add_legend=add_legend, add_colorbar=add_colorbar, ) - assert isinstance(ax, cartopy.mpl.geoaxes.GeoAxesSubplot) From 4fff2749365fb9b100f6f3198a23d40aaa5e571a Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 1 Apr 2021 19:16:09 +0200 Subject: [PATCH 39/46] Update test_plot.py --- xarray/tests/test_plot.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 6804fff9c18..255ef2e5a30 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2653,7 +2653,6 @@ def test_get_axis_cartopy(): kwargs = {"projection": cartopy.crs.PlateCarree()} with figure_context(): ax = get_axis(**kwargs) - assert isinstance(ax, ctpy.mpl.geoaxes.GeoAxesSubplot) assert isinstance(ax, cartopy.mpl.geoaxes.GeoAxesSubplot) From 1626174be9b95ae4de09a4f16bc7d01e7755b6fb Mon Sep 17 00:00:00 2001 From: Illviljan 
<14371165+Illviljan@users.noreply.github.com> Date: Tue, 20 Apr 2021 21:27:56 +0200 Subject: [PATCH 40/46] suggestions from review --- xarray/plot/plot.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 5463d168283..ebf7e26becc 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -37,7 +37,7 @@ _MARKERSIZE_RANGE = np.array([18.0, 72.0]) -def _infer_meta_data(darray, x, z, hue, hue_style, size): +def _infer_scatter_metadata(darray, x, z, hue, hue_style, size): def _determine_array(darray, name, array_style): """Find and determine what type of array it is.""" array = darray[name] @@ -708,7 +708,7 @@ def scatter( subplot_kws.update(projection="3d") ax = get_axis(figsize, size, aspect, ax, **subplot_kws) - _data = _infer_meta_data(darray, x, z, hue, hue_style, _sizes) + _data = _infer_scatter_metadata(darray, x, z, hue, hue_style, _sizes) add_guide = kwargs.pop("add_guide", None) if add_legend is not None: From f8de23ab7ea05a18662fe540263c421080ec104f Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 18 May 2021 21:14:19 +0200 Subject: [PATCH 41/46] view_init(30, 30) --- xarray/plot/plot.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 5314741f12d..3781b7f6a8b 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -706,7 +706,12 @@ def scatter( from mpl_toolkits.mplot3d import Axes3D # type: ignore # noqa subplot_kws.update(projection="3d") - ax = get_axis(figsize, size, aspect, ax, **subplot_kws) + ax = get_axis(figsize, size, aspect, ax, **subplot_kws) + # Using 30, 30 minimizes rotation of the plot. 
Making it easier to + # build on your intuition from 2D plots: + ax.view_init(azim=30, elev=30) + else: + ax = get_axis(figsize, size, aspect, ax, **subplot_kws) _data = _infer_scatter_metadata(darray, x, z, hue, hue_style, _sizes) From fda0310e967f9a1bd97ddf46ddc90255490e33b9 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 28 May 2021 07:41:14 +0200 Subject: [PATCH 42/46] Add support for new stuff in matplotlib master --- xarray/plot/plot.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 3781b7f6a8b..1beb369059d 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -6,6 +6,7 @@ DataArray.plot._____ Dataset.plot._____ """ +from distutils.version import LooseVersion import functools import numpy as np @@ -675,6 +676,8 @@ def scatter( **kwargs : optional Additional keyword arguments to matplotlib """ + plt = import_matplotlib_pyplot() + # Handle facetgrids first if row or col: allargs = locals().copy() @@ -709,7 +712,11 @@ def scatter( ax = get_axis(figsize, size, aspect, ax, **subplot_kws) # Using 30, 30 minimizes rotation of the plot. Making it easier to # build on your intuition from 2D plots: - ax.view_init(azim=30, elev=30) + if LooseVersion(plt.matplotlib.__version__) < "3.5.0": + ax.view_init(azim=30, elev=30) + else: + # https://github.com/matplotlib/matplotlib/pull/19873 + ax.view_init(azim=30, elev=30, vertical_axis="y") else: ax = get_axis(figsize, size, aspect, ax, **subplot_kws) @@ -759,10 +766,17 @@ def scatter( if _data["size"] is not None: kwargs.update(s=_data["size"].values.ravel()) - # Plot the data. 3d plots has the z value in upward direction - # instead of y. To make jumping between 2d and 3d easy and intuitive - # switch the order so that z is shown in the depthwise direction: - axis_order = ["x", "z", "y"] + if LooseVersion(plt.matplotlib.__version__) < "3.5.0": + # Plot the data. 
3d plots has the z value in upward direction + # instead of y. To make jumping between 2d and 3d easy and intuitive + # switch the order so that z is shown in the depthwise direction: + axis_order = ["x", "z", "y"] + else: + # Switching axis order not needed in 3.5.0, can also simplify the code + # that uses axis_order: + # https://github.com/matplotlib/matplotlib/pull/19873 + axis_order = ["x", "y", "z"] + primitive = ax.scatter( *[ _data[v].values.ravel() From 6eb190b06ae1417b6a86480aad9c253f65219209 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 28 May 2021 07:43:44 +0200 Subject: [PATCH 43/46] Update plot.py --- xarray/plot/plot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 1beb369059d..ef5bd8904ea 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -6,8 +6,8 @@ DataArray.plot._____ Dataset.plot._____ """ -from distutils.version import LooseVersion import functools +from distutils.version import LooseVersion import numpy as np import pandas as pd From 63b064fd3a8123f6deb507c27003b735837b0e00 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 24 Jun 2021 18:33:06 +0200 Subject: [PATCH 44/46] fix merge errors --- xarray/tests/test_plot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 77cf6826a02..b34f31f2648 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2781,7 +2781,7 @@ def test_get_axis_cartopy(): ax = get_axis(**kwargs) assert isinstance(ax, cartopy.mpl.geoaxes.GeoAxesSubplot) - +@requires_matplotlib def test_maybe_gca(): with figure_context(): From 530c58087025f7fa2f589cb5480a0f00aa8def76 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 24 Jun 2021 18:33:56 +0200 Subject: [PATCH 45/46] Update test_plot.py --- 
xarray/tests/test_plot.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index b34f31f2648..d24b12c39d3 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2781,6 +2781,7 @@ def test_get_axis_cartopy(): ax = get_axis(**kwargs) assert isinstance(ax, cartopy.mpl.geoaxes.GeoAxesSubplot) + @requires_matplotlib def test_maybe_gca(): From 85e4ecdb5fcb282b72853174152d8ae457516250 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 30 Jun 2021 17:24:14 +0200 Subject: [PATCH 46/46] privatize scatter plot for now. --- xarray/plot/plot.py | 2 +- xarray/tests/test_plot.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 718d9c6a82c..e5032c3729a 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -883,7 +883,7 @@ def step(self, *args, **kwargs): return step(self._da, *args, **kwargs) @functools.wraps(scatter) - def scatter(self, *args, **kwargs): + def _scatter(self, *args, **kwargs): return scatter(self._da, *args, **kwargs) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index d24b12c39d3..c7f363bbab2 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2838,7 +2838,7 @@ def test_datarray_scatter(x, y, z, hue, markersize, row, col, add_legend, add_co darray = xr.DataArray(ds[y], coords=coords) with figure_context(): - darray.plot.scatter( + darray.plot._scatter( x=x, z=z, hue=hue,