Skip to content

Commit 540300d

Browse files
committed
Allow RGB[A] dim for imshow to be in any order
Includes new `rgb` keyword to tell imshow about that dimension, and much error handling in inference.
1 parent ad00933 commit 540300d

File tree

4 files changed

+95
-16
lines changed

4 files changed

+95
-16
lines changed

xarray/plot/facetgrid.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,9 @@ def map_dataarray(self, func, x, y, **kwargs):
239239
func_kwargs.update({'add_colorbar': False, 'add_labels': False})
240240

241241
# Get x, y labels for the first subplot
242-
x, y = _infer_xy_labels(darray=self.data.loc[self.name_dicts.flat[0]],
243-
x=x, y=y, imshow=func.__name__ == 'imshow')
242+
x, y = _infer_xy_labels(
243+
darray=self.data.loc[self.name_dicts.flat[0]], x=x, y=y,
244+
imshow=func.__name__ == 'imshow', rgb=kwargs.get('rgb', None))
244245

245246
for d, ax in zip(self.name_dicts.flat, self.axes.flat):
246247
# None is the sentinel value

xarray/plot/plot.py

+19-7
Original file line numberDiff line numberDiff line change
@@ -468,13 +468,13 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
468468
"Use colors keyword instead.",
469469
DeprecationWarning, stacklevel=3)
470470

471-
xlab, ylab = _infer_xy_labels(darray=darray, x=x, y=y,
472-
imshow=plotfunc.__name__ == 'imshow')
471+
rgb = kwargs.pop('rgb', None)
472+
xlab, ylab = _infer_xy_labels(
473+
darray=darray, x=x, y=y, imshow=imshow_rgb, rgb=rgb)
473474

474475
# better to pass the ndarrays directly to plotting functions
475476
xval = darray[xlab].values
476477
yval = darray[ylab].values
477-
zval = darray.to_masked_array(copy=False)
478478

479479
# check if we need to broadcast one dimension
480480
if xval.ndim < yval.ndim:
@@ -485,8 +485,19 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
485485

486486
# May need to transpose for correct x, y labels
487487
# xlab may be the name of a coord, we have to check for dim names
488-
if darray[xlab].dims[-1] == darray.dims[0]:
489-
zval = zval.T
488+
if imshow_rgb:
489+
# For RGB[A] images, matplotlib requires the color dimension
490+
# to be last. In Xarray the order should be unimportant, so
491+
# we transpose to (y, x, color) to make this work.
492+
yx_dims = (ylab, xlab)
493+
dims = yx_dims + tuple(d for d in darray.dims if d not in yx_dims)
494+
if dims != darray.dims:
495+
darray = darray.transpose(*dims)
496+
elif darray[xlab].dims[-1] == darray.dims[0]:
497+
darray = darray.transpose()
498+
499+
# Pass the data as a masked ndarray too
500+
zval = darray.to_masked_array(copy=False)
490501

491502
_ensure_plottable(xval, yval)
492503

@@ -591,8 +602,9 @@ def imshow(x, y, z, ax, **kwargs):
591602
Wraps :func:`matplotlib:matplotlib.pyplot.imshow`
592603
593604
While other plot methods require the DataArray to be strictly
594-
two-dimensional, ``imshow`` also accepts a 3D array where the third
595-
dimension can be interpreted as RGB or RGBA color channels.
605+
two-dimensional, ``imshow`` also accepts a 3D array where some
606+
dimension can be interpreted as RGB or RGBA color channels and
607+
allows this dimension to be specified via the kwarg ``rgb=``.
596608
In this case, ``robust=True`` will saturate the image in the
597609
usual way, consistenly between all bands and facets.
598610

xarray/plot/utils.py

+58-7
Original file line numberDiff line numberDiff line change
@@ -254,21 +254,72 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None,
254254
levels=levels, norm=norm)
255255

256256

257-
def _infer_xy_labels(darray, x, y, imshow=False):
257+
def _infer_xy_labels_3d(darray, x, y, rgb):
258+
"""
259+
Determine x and y labels for showing RGB images.
260+
261+
Attempts to infer which dimension is RGB/RGBA by size and order of dims.
262+
263+
"""
264+
# Start by detecting and reporting invalid combinations of arguments
265+
assert darray.ndim == 3
266+
not_none = [a for a in (x, y, rgb) if a is not None]
267+
if len(set(not_none)) < len(not_none):
268+
raise ValueError('Dimensions passed as x, y, and rgb must be unique.')
269+
for label in not_none:
270+
if label not in darray.dims:
271+
raise ValueError('%r is not a dimension' % (label,))
272+
273+
# Then calculate rgb dimension if certain and check validity
274+
could_be_color = [label for label in darray.dims
275+
if darray[label].size in (3, 4) and label not in (x, y)]
276+
if rgb is None and not could_be_color:
277+
raise ValueError(
278+
'A 3-dimensional array was passed to imshow(), but there is no '
279+
'dimension that could be color. At least one dimension must be '
280+
'of size 3 (RGB) or 4 (RGBA), and not given as x or y.')
281+
if rgb is None and len(could_be_color) == 1:
282+
rgb = could_be_color[0]
283+
if rgb is not None and darray[rgb].size not in (3, 4):
284+
raise ValueError('Cannot interpret dim %r of size %s as RGB or RGBA.'
285+
% (rgb, darray[rgb].size))
286+
287+
# If rgb dimension is still unknown, there must be two or three dimensions
288+
# in could_be_color. We therefore warn, and use a heuristic to break ties.
289+
if rgb is None:
290+
assert len(could_be_color) in (2, 3)
291+
if darray.dims[-1] in could_be_color:
292+
rgb = darray.dims[-1]
293+
warnings.warn(
294+
'Several dimensions of this array could be colors. Xarray '
295+
'will use the last dimension (%r) to match '
296+
'matplotlib.pyplot.imshow. You can pass names of x, y, '
297+
'and/or rgb dimensions to override this guess.' % rgb)
298+
else:
299+
rgb = darray.dims[0]
300+
warnings.warn(
301+
'%r has been selected as the color dimension, but %r would '
302+
'also be valid. Pass names of x, y, and/or rgb dimensions to '
303+
'override this guess.' % darray.dims[:2])
304+
assert rgb is not None
305+
306+
# Finally, we pick out the red slice and delegate to the 2D version:
307+
return _infer_xy_labels(darray.isel(**{rgb: 0}).squeeze(), x, y)
308+
309+
310+
def _infer_xy_labels(darray, x, y, imshow=False, rgb=None):
258311
"""
259312
Determine x and y labels. For use in _plot2d
260313
261314
darray must be a 2 dimensional data array, or 3d for imshow only.
262315
"""
316+
if imshow and darray.ndim == 3:
317+
return _infer_xy_labels_3d(darray, x, y, rgb)
263318

264319
if x is None and y is None:
265320
if darray.ndim != 2:
266-
if not imshow:
267-
raise ValueError('DataArray must be 2d')
268-
elif darray.ndim != 3 or darray.shape[2] not in (3, 4):
269-
raise ValueError('DataArray for imshow must be 2d, MxNx3 for '
270-
'RGB image, or MxNx4 for RGBA image.')
271-
y, x, *_ = darray.dims
321+
raise ValueError('DataArray must be 2d')
322+
y, x = darray.dims
272323
elif x is None:
273324
if y not in darray.dims:
274325
raise ValueError('y must be a dimension name if x is not supplied')

xarray/tests/test_plot.py

+15
Original file line numberDiff line numberDiff line change
@@ -1040,6 +1040,21 @@ def test_plot_rgb_faceted(self):
10401040
).plot.imshow(row='a', col='b')
10411041
self.assertEqual(0, len(find_possible_colorbars()))
10421042

1043+
def test_plot_rgba_image_transposed(self):
1044+
# We can handle the color axis being in any position
1045+
DataArray(
1046+
easy_array((4, 10, 15), start=0),
1047+
dims=['band', 'y', 'x'],
1048+
).plot.imshow()
1049+
1050+
def test_warns_ambigious_dim(self):
1051+
arr = DataArray(easy_array((3, 3, 3)), dims=['y', 'x', 'band'])
1052+
with pytest.warns(UserWarning):
1053+
arr.plot.imshow()
1054+
# but doesn't warn if dimensions specified
1055+
arr.plot.imshow(rgb='band')
1056+
arr.plot.imshow(x='x', y='y')
1057+
10431058

10441059
class TestFacetGrid(PlotTestCase):
10451060

0 commit comments

Comments
 (0)