diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 25133df8a29..3df670aa8c4 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -97,6 +97,9 @@ Bug fixes By `Justus Magin `_. - Fixed errors emitted by ``mypy --strict`` in modules that import xarray. (:issue:`3695`) by `Guido Imperiale `_. +- Fix plotting of binned coordinates on the y axis in :py:meth:`DataArray.plot` + (line) and :py:meth:`DataArray.plot.step` plots (:issue:`#3571`, + :pull:`3685`) by `Julien Seguinot _`. Documentation ~~~~~~~~~~~~~ diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index d38c9765352..b4802f6194b 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -17,13 +17,11 @@ _ensure_plottable, _infer_interval_breaks, _infer_xy_labels, - _interval_to_double_bound_points, - _interval_to_mid_points, _process_cmap_cbar_kwargs, _rescale_imshow_rgb, + _resolve_intervals_1dplot, _resolve_intervals_2dplot, _update_axes, - _valid_other_type, get_axis, import_matplotlib_pyplot, label_from_attrs, @@ -296,29 +294,10 @@ def line( ax = get_axis(figsize, size, aspect, ax) xplt, yplt, hueplt, xlabel, ylabel, hue_label = _infer_line_data(darray, x, y, hue) - # Remove pd.Intervals if contained in xplt.values. - if _valid_other_type(xplt.values, [pd.Interval]): - # Is it a step plot? (see matplotlib.Axes.step) - if kwargs.get("linestyle", "").startswith("steps-"): - xplt_val, yplt_val = _interval_to_double_bound_points( - xplt.values, yplt.values - ) - # Remove steps-* to be sure that matplotlib is not confused - kwargs["linestyle"] = ( - kwargs["linestyle"] - .replace("steps-pre", "") - .replace("steps-post", "") - .replace("steps-mid", "") - ) - if kwargs["linestyle"] == "": - del kwargs["linestyle"] - else: - xplt_val = _interval_to_mid_points(xplt.values) - yplt_val = yplt.values - xlabel += "_center" - else: - xplt_val = xplt.values - yplt_val = yplt.values + # Remove pd.Intervals if contained in xplt.values and/or yplt.values. + xplt_val, yplt_val, xlabel, ylabel, kwargs = _resolve_intervals_1dplot( + xplt.values, yplt.values, xlabel, ylabel, kwargs + ) _ensure_plottable(xplt_val, yplt_val) @@ -367,7 +346,7 @@ def step(darray, *args, where="pre", linestyle=None, ls=None, **kwargs): every *x* position, i.e. the interval ``[x[i], x[i+1])`` has the value ``y[i]``. - 'mid': Steps occur half-way between the *x* positions. - Note that this parameter is ignored if the x coordinate consists of + Note that this parameter is ignored if one coordinate consists of :py:func:`pandas.Interval` values, e.g. as a result of :py:func:`xarray.Dataset.groupby_bins`. In this case, the actual boundaries of the interval are used. diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 3b739197fea..c900dfeff3e 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -453,6 +453,42 @@ def _interval_to_double_bound_points(xarray, yarray): return xarray, yarray +def _resolve_intervals_1dplot(xval, yval, xlabel, ylabel, kwargs): + """ + Helper function to replace the values of x and/or y coordinate arrays + containing pd.Interval with their mid-points or - for step plots - double + points which double the length. + """ + + # Is it a step plot? (see matplotlib.Axes.step) + if kwargs.get("linestyle", "").startswith("steps-"): + + # Convert intervals to double points + if _valid_other_type(np.array([xval, yval]), [pd.Interval]): + raise TypeError("Can't step plot intervals against intervals.") + if _valid_other_type(xval, [pd.Interval]): + xval, yval = _interval_to_double_bound_points(xval, yval) + if _valid_other_type(yval, [pd.Interval]): + yval, xval = _interval_to_double_bound_points(yval, xval) + + # Remove steps-* to be sure that matplotlib is not confused + del kwargs["linestyle"] + + # Is it another kind of plot? + else: + + # Convert intervals to mid points and adjust labels + if _valid_other_type(xval, [pd.Interval]): + xval = _interval_to_mid_points(xval) + xlabel += "_center" + if _valid_other_type(yval, [pd.Interval]): + yval = _interval_to_mid_points(yval) + ylabel += "_center" + + # return converted arguments + return xval, yval, xlabel, ylabel, kwargs + + def _resolve_intervals_2dplot(val, func_name): """ Helper function to replace the values of a coordinate array containing diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index a5402d88f3e..71cb119f0d6 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -426,9 +426,25 @@ def test_convenient_facetgrid_4d(self): d.plot(x="x", y="y", col="columns", ax=plt.gca()) def test_coord_with_interval(self): + """Test line plot with intervals.""" bins = [-1, 0, 1, 2] self.darray.groupby_bins("dim_0", bins).mean(...).plot() + def test_coord_with_interval_x(self): + """Test line plot with intervals explicitly on x axis.""" + bins = [-1, 0, 1, 2] + self.darray.groupby_bins("dim_0", bins).mean(...).plot(x="dim_0_bins") + + def test_coord_with_interval_y(self): + """Test line plot with intervals explicitly on y axis.""" + bins = [-1, 0, 1, 2] + self.darray.groupby_bins("dim_0", bins).mean(...).plot(y="dim_0_bins") + + def test_coord_with_interval_xy(self): + """Test line plot with intervals on both x and y axes.""" + bins = [-1, 0, 1, 2] + self.darray.groupby_bins("dim_0", bins).mean(...).dim_0_bins.plot() + class TestPlot1D(PlotTestCase): @pytest.fixture(autouse=True) @@ -511,10 +527,23 @@ def test_step(self): self.darray[0, 0].plot.step() def test_coord_with_interval_step(self): + """Test step plot with intervals.""" bins = [-1, 0, 1, 2] self.darray.groupby_bins("dim_0", bins).mean(...).plot.step() assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2) + def test_coord_with_interval_step_x(self): + """Test step plot with intervals explicitly on x axis.""" + bins = [-1, 0, 1, 2] + self.darray.groupby_bins("dim_0", bins).mean(...).plot.step(x="dim_0_bins") + assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2) + + def test_coord_with_interval_step_y(self): + """Test step plot with intervals explicitly on y axis.""" + bins = [-1, 0, 1, 2] + self.darray.groupby_bins("dim_0", bins).mean(...).plot.step(y="dim_0_bins") + assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2) + class TestPlotHistogram(PlotTestCase): @pytest.fixture(autouse=True)