Skip to content
forked from pydata/xarray

Commit d430ae0

Browse files
committed
proper fix.
1 parent 7fd69be commit d430ae0

File tree

2 files changed

+18
-7
lines changed

2 files changed

+18
-7
lines changed

xarray/plot/plot.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -672,16 +672,22 @@ def newplotfunc(
672672

673673
# check if we need to broadcast one dimension
674674
if xval.ndim < yval.ndim:
675+
dims = darray[ylab].dims
675676
if xval.shape[0] == yval.shape[0]:
676677
xval = np.broadcast_to(xval[:, np.newaxis], yval.shape)
677678
else:
678679
xval = np.broadcast_to(xval[np.newaxis, :], yval.shape)
679680

680681
elif yval.ndim < xval.ndim:
682+
dims = darray[xlab].dims
681683
if yval.shape[0] == xval.shape[0]:
682684
yval = np.broadcast_to(yval[:, np.newaxis], xval.shape)
683685
else:
684686
yval = np.broadcast_to(yval[np.newaxis, :], xval.shape)
687+
elif xval.ndim == 2:
688+
dims = darray[xlab].dims
689+
else:
690+
dims = (darray[ylab].dims[0], darray[xlab].dims[0])
685691

686692
# May need to transpose for correct x, y labels
687693
# xlab may be the name of a coord, we have to check for dim names
@@ -691,10 +697,9 @@ def newplotfunc(
691697
# we transpose to (y, x, color) to make this work.
692698
yx_dims = (ylab, xlab)
693699
dims = yx_dims + tuple(d for d in darray.dims if d not in yx_dims)
694-
if dims != darray.dims:
695-
darray = darray.transpose(*dims, transpose_coords=True)
696-
elif darray[xlab].dims[-1] == darray.dims[0]:
697-
darray = darray.transpose(transpose_coords=True)
700+
701+
if dims != darray.dims:
702+
darray = darray.transpose(*dims, transpose_coords=True)
698703

699704
# Pass the data as a masked ndarray too
700705
zval = darray.to_masked_array(copy=False)

xarray/tests/test_plot.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -2168,7 +2168,13 @@ def test_plot_transposed_nondim_coord(plotfunc):
21682168
getattr(da.plot, plotfunc)(x="zt", y="x")
21692169

21702170

2171-
def test_plot_transposes_properly():
2171+
@requires_matplotlib
2172+
@pytest.mark.parametrize("plotfunc", ["pcolormesh", "imshow"])
2173+
def test_plot_transposes_properly(plotfunc):
2174+
# test that we aren't mistakenly transposing when the 2 dimensions have equal sizes.
21722175
da = xr.DataArray([np.sin(2 * np.pi / 10 * np.arange(10))] * 10, dims=("y", "x"))
2173-
hdl = da.plot(x="x", y="y")
2174-
assert np.all(hdl.get_array() == da.to_masked_array().ravel())
2176+
hdl = getattr(da.plot, plotfunc)(x="x", y="y")
2177+
# get_array doesn't work for contour, contourf. It returns the colormap intervals.
2178+
# pcolormesh returns 1D array but imshow returns a 2D array so it is necessary
2179+
# to ravel() on the LHS
2180+
assert np.all(hdl.get_array().ravel() == da.to_masked_array().ravel())

0 commit comments

Comments
 (0)