Skip to content

Commit 289f95a

Browse files
Zac-HDshoyer
authored andcommitted
Support RGB[A] arrays in plot.imshow() (#1796)
* Allow RGB plots from DataArray.plot.imshow * Allow RGB[A] dim for imshow to be in any order Includes new `rgb` keyword to tell imshow about that dimension, and much error handling in inference. * Use true RGB color for Rasterio gallery page * Add whats-new entry
1 parent 049cbdd commit 289f95a

File tree

6 files changed

+158
-13
lines changed

6 files changed

+158
-13
lines changed

doc/gallery/plot_rasterio.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,10 @@
4444
da.coords['lon'] = (('y', 'x'), lon)
4545
da.coords['lat'] = (('y', 'x'), lat)
4646

47-
# Compute a greyscale out of the rgb image
48-
greyscale = da.mean(dim='band')
49-
5047
# Plot on a map
5148
ax = plt.subplot(projection=ccrs.PlateCarree())
52-
greyscale.plot(ax=ax, x='lon', y='lat', transform=ccrs.PlateCarree(),
53-
cmap='Greys_r', add_colorbar=False)
49+
da.plot.imshow(ax=ax, x='lon', y='lat', rgb='band',
50+
transform=ccrs.PlateCarree())
5451
ax.coastlines('10m', color='r')
5552
plt.show()
5653

doc/whats-new.rst

+2
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ Enhancements
3333
By `Joe Hamman <https://github.com/jhamman>`_.
3434
- Support for using `Zarr`_ as storage layer for xarray.
3535
By `Ryan Abernathey <https://github.com/rabernat>`_.
36+
- :func:`xarray.plot.imshow` now handles RGB and RGBA images.
37+
By `Zac Hatfield-Dodds <https://github.com/Zac-HD>`_.
3638
- Experimental support for parsing ENVI metadata to coordinates and attributes
3739
in :py:func:`xarray.open_rasterio`.
3840
By `Matti Eskelinen <https://github.com/maaleske>`_.

xarray/plot/facetgrid.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,9 @@ def map_dataarray(self, func, x, y, **kwargs):
239239
func_kwargs.update({'add_colorbar': False, 'add_labels': False})
240240

241241
# Get x, y labels for the first subplot
242-
x, y = _infer_xy_labels(darray=self.data.loc[self.name_dicts.flat[0]],
243-
x=x, y=y)
242+
x, y = _infer_xy_labels(
243+
darray=self.data.loc[self.name_dicts.flat[0]], x=x, y=y,
244+
imshow=func.__name__ == 'imshow', rgb=kwargs.get('rgb', None))
244245

245246
for d, ax in zip(self.name_dicts.flat, self.axes.flat):
246247
# None is the sentinel value

xarray/plot/plot.py

+43-4
Original file line numberDiff line numberDiff line change
@@ -443,10 +443,17 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
443443
# Decide on a default for the colorbar before facetgrids
444444
if add_colorbar is None:
445445
add_colorbar = plotfunc.__name__ != 'contour'
446+
imshow_rgb = (
447+
plotfunc.__name__ == 'imshow' and
448+
darray.ndim == (3 + (row is not None) + (col is not None)))
449+
if imshow_rgb:
450+
# Don't add a colorbar when showing an image with explicit colors
451+
add_colorbar = False
446452

447453
# Handle facetgrids first
448454
if row or col:
449455
allargs = locals().copy()
456+
allargs.pop('imshow_rgb')
450457
allargs.update(allargs.pop('kwargs'))
451458

452459
# Need the decorated plotting function
@@ -470,12 +477,19 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
470477
"Use colors keyword instead.",
471478
DeprecationWarning, stacklevel=3)
472479

473-
xlab, ylab = _infer_xy_labels(darray=darray, x=x, y=y)
480+
rgb = kwargs.pop('rgb', None)
481+
xlab, ylab = _infer_xy_labels(
482+
darray=darray, x=x, y=y, imshow=imshow_rgb, rgb=rgb)
483+
484+
if rgb is not None and plotfunc.__name__ != 'imshow':
485+
raise ValueError('The "rgb" keyword is only valid for imshow()')
486+
elif rgb is not None and not imshow_rgb:
487+
raise ValueError('The "rgb" keyword is only valid for imshow()'
488+
'with a three-dimensional array (per facet)')
474489

475490
# better to pass the ndarrays directly to plotting functions
476491
xval = darray[xlab].values
477492
yval = darray[ylab].values
478-
zval = darray.to_masked_array(copy=False)
479493

480494
# check if we need to broadcast one dimension
481495
if xval.ndim < yval.ndim:
@@ -486,8 +500,19 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
486500

487501
# May need to transpose for correct x, y labels
488502
# xlab may be the name of a coord, we have to check for dim names
489-
if darray[xlab].dims[-1] == darray.dims[0]:
490-
zval = zval.T
503+
if imshow_rgb:
504+
# For RGB[A] images, matplotlib requires the color dimension
505+
# to be last. In Xarray the order should be unimportant, so
506+
# we transpose to (y, x, color) to make this work.
507+
yx_dims = (ylab, xlab)
508+
dims = yx_dims + tuple(d for d in darray.dims if d not in yx_dims)
509+
if dims != darray.dims:
510+
darray = darray.transpose(*dims)
511+
elif darray[xlab].dims[-1] == darray.dims[0]:
512+
darray = darray.transpose()
513+
514+
# Pass the data as a masked ndarray too
515+
zval = darray.to_masked_array(copy=False)
491516

492517
_ensure_plottable(xval, yval)
493518

@@ -595,6 +620,11 @@ def imshow(x, y, z, ax, **kwargs):
595620
596621
Wraps :func:`matplotlib:matplotlib.pyplot.imshow`
597622
623+
While other plot methods require the DataArray to be strictly
624+
two-dimensional, ``imshow`` also accepts a 3D array where some
625+
dimension can be interpreted as RGB or RGBA color channels and
626+
allows this dimension to be specified via the kwarg ``rgb=``.
627+
598628
.. note::
599629
This function needs uniformly spaced coordinates to
600630
properly label the axes. Call DataArray.plot() to check.
@@ -632,6 +662,15 @@ def imshow(x, y, z, ax, **kwargs):
632662
# Allow user to override these defaults
633663
defaults.update(kwargs)
634664

665+
if z.ndim == 3:
666+
# matplotlib imshow uses black for missing data, but Xarray makes
667+
# missing data transparent. We therefore add an alpha channel if
668+
# there isn't one, and set it to transparent where data is masked.
669+
if z.shape[-1] == 3:
670+
z = np.ma.concatenate((z, np.ma.ones(z.shape[:2] + (1,))), 2)
671+
z = z.copy()
672+
z[np.any(z.mask, axis=-1), -1] = 0
673+
635674
primitive = ax.imshow(z, **defaults)
636675

637676
return primitive

xarray/plot/utils.py

+55-2
Original file line numberDiff line numberDiff line change
@@ -258,12 +258,65 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None,
258258
levels=levels, norm=norm)
259259

260260

261-
def _infer_xy_labels(darray, x, y):
261+
def _infer_xy_labels_3d(darray, x, y, rgb):
262+
"""
263+
Determine x and y labels for showing RGB images.
264+
265+
Attempts to infer which dimension is RGB/RGBA by size and order of dims.
266+
267+
"""
268+
assert rgb is None or rgb != x
269+
assert rgb is None or rgb != y
270+
# Start by detecting and reporting invalid combinations of arguments
271+
assert darray.ndim == 3
272+
not_none = [a for a in (x, y, rgb) if a is not None]
273+
if len(set(not_none)) < len(not_none):
274+
raise ValueError(
275+
'Dimension names must be None or unique strings, but imshow was '
276+
'passed x=%r, y=%r, and rgb=%r.' % (x, y, rgb))
277+
for label in not_none:
278+
if label not in darray.dims:
279+
raise ValueError('%r is not a dimension' % (label,))
280+
281+
# Then calculate rgb dimension if certain and check validity
282+
could_be_color = [label for label in darray.dims
283+
if darray[label].size in (3, 4) and label not in (x, y)]
284+
if rgb is None and not could_be_color:
285+
raise ValueError(
286+
'A 3-dimensional array was passed to imshow(), but there is no '
287+
'dimension that could be color. At least one dimension must be '
288+
'of size 3 (RGB) or 4 (RGBA), and not given as x or y.')
289+
if rgb is None and len(could_be_color) == 1:
290+
rgb = could_be_color[0]
291+
if rgb is not None and darray[rgb].size not in (3, 4):
292+
raise ValueError('Cannot interpret dim %r of size %s as RGB or RGBA.'
293+
% (rgb, darray[rgb].size))
294+
295+
# If rgb dimension is still unknown, there must be two or three dimensions
296+
# in could_be_color. We therefore warn, and use a heuristic to break ties.
297+
if rgb is None:
298+
assert len(could_be_color) in (2, 3)
299+
rgb = could_be_color[-1]
300+
warnings.warn(
301+
'Several dimensions of this array could be colors. Xarray '
302+
'will use the last possible dimension (%r) to match '
303+
'matplotlib.pyplot.imshow. You can pass names of x, y, '
304+
'and/or rgb dimensions to override this guess.' % rgb)
305+
assert rgb is not None
306+
307+
# Finally, we pick out the red slice and delegate to the 2D version:
308+
return _infer_xy_labels(darray.isel(**{rgb: 0}).squeeze(), x, y)
309+
310+
311+
def _infer_xy_labels(darray, x, y, imshow=False, rgb=None):
262312
"""
263313
Determine x and y labels. For use in _plot2d
264314
265-
darray must be a 2 dimensional data array.
315+
darray must be a 2 dimensional data array, or 3d for imshow only.
266316
"""
317+
assert x is None or x != y
318+
if imshow and darray.ndim == 3:
319+
return _infer_xy_labels_3d(darray, x, y, rgb)
267320

268321
if x is None and y is None:
269322
if darray.ndim != 2:

xarray/tests/test_plot.py

+53
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,8 @@ def test_1d_raises_valueerror(self):
619619

620620
def test_3d_raises_valueerror(self):
621621
a = DataArray(easy_array((2, 3, 4)))
622+
if self.plotfunc.__name__ == 'imshow':
623+
pytest.skip()
622624
with raises_regex(ValueError, r'DataArray must be 2d'):
623625
self.plotfunc(a)
624626

@@ -670,6 +672,11 @@ def test_can_plot_axis_size_one(self):
670672
if self.plotfunc.__name__ not in ('contour', 'contourf'):
671673
self.plotfunc(DataArray(np.ones((1, 1))))
672674

675+
def test_disallows_rgb_arg(self):
676+
with pytest.raises(ValueError):
677+
# Always invalid for most plots. Invalid for imshow with 2D data.
678+
self.plotfunc(DataArray(np.ones((2, 2))), rgb='not None')
679+
673680
def test_viridis_cmap(self):
674681
cmap_name = self.plotmethod(cmap='viridis').get_cmap().name
675682
self.assertEqual('viridis', cmap_name)
@@ -1062,6 +1069,52 @@ def test_2d_coord_names(self):
10621069
with raises_regex(ValueError, 'requires 1D coordinates'):
10631070
self.plotmethod(x='x2d', y='y2d')
10641071

1072+
def test_plot_rgb_image(self):
1073+
DataArray(
1074+
easy_array((10, 15, 3), start=0),
1075+
dims=['y', 'x', 'band'],
1076+
).plot.imshow()
1077+
self.assertEqual(0, len(find_possible_colorbars()))
1078+
1079+
def test_plot_rgb_image_explicit(self):
1080+
DataArray(
1081+
easy_array((10, 15, 3), start=0),
1082+
dims=['y', 'x', 'band'],
1083+
).plot.imshow(y='y', x='x', rgb='band')
1084+
self.assertEqual(0, len(find_possible_colorbars()))
1085+
1086+
def test_plot_rgb_faceted(self):
1087+
DataArray(
1088+
easy_array((2, 2, 10, 15, 3), start=0),
1089+
dims=['a', 'b', 'y', 'x', 'band'],
1090+
).plot.imshow(row='a', col='b')
1091+
self.assertEqual(0, len(find_possible_colorbars()))
1092+
1093+
def test_plot_rgba_image_transposed(self):
1094+
# We can handle the color axis being in any position
1095+
DataArray(
1096+
easy_array((4, 10, 15), start=0),
1097+
dims=['band', 'y', 'x'],
1098+
).plot.imshow()
1099+
1100+
def test_warns_ambigious_dim(self):
1101+
arr = DataArray(easy_array((3, 3, 3)), dims=['y', 'x', 'band'])
1102+
with pytest.warns(UserWarning):
1103+
arr.plot.imshow()
1104+
# but doesn't warn if dimensions specified
1105+
arr.plot.imshow(rgb='band')
1106+
arr.plot.imshow(x='x', y='y')
1107+
1108+
def test_rgb_errors_too_many_dims(self):
1109+
arr = DataArray(easy_array((3, 3, 3, 3)), dims=['y', 'x', 'z', 'band'])
1110+
with pytest.raises(ValueError):
1111+
arr.plot.imshow(rgb='band')
1112+
1113+
def test_rgb_errors_bad_dim_sizes(self):
1114+
arr = DataArray(easy_array((5, 5, 5)), dims=['y', 'x', 'band'])
1115+
with pytest.raises(ValueError):
1116+
arr.plot.imshow(rgb='band')
1117+
10651118

10661119
class TestFacetGrid(PlotTestCase):
10671120

0 commit comments

Comments
 (0)