Skip to content

Commit 9ddc03b

Browse files
committed
Validation, tests for rgb imshow
1 parent 0226d4c commit 9ddc03b

File tree

2 files changed

+30
-1
lines changed

2 files changed

+30
-1
lines changed

xarray/plot/plot.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,8 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
453453
# Convert byte-arrays to float for correct display in matplotlib
454454
if darray.dtype == np.dtype('uint8'):
455455
darray = darray / 256.0
456-
# Manually stretch colors for robust cmap
456+
# Manually stretch colors for robust cmap. We have to do this
457+
# first so faceted plots are comparable between facets.
457458
if robust:
458459
flat = darray.values.ravel(order='K')
459460
flat = flat[~np.isnan(flat)]
@@ -503,6 +504,12 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
503504
xlab, ylab = _infer_xy_labels(
504505
darray=darray, x=x, y=y, imshow=imshow_rgb, rgb=rgb)
505506

507+
if rgb is not None and plotfunc.__name__ != 'imshow':
508+
raise ValueError('The "rgb" keyword is only valid for imshow()')
509+
elif rgb is not None and not imshow_rgb:
510+
raise ValueError('The "rgb" keyword is only valid for imshow()'
511+
'with a three-dimensional array (per facet)')
512+
506513
# better to pass the ndarrays directly to plotting functions
507514
xval = darray[xlab].values
508515
yval = darray[ylab].values

xarray/tests/test_plot.py

+22
Original file line numberDiff line numberDiff line change
@@ -672,6 +672,11 @@ def test_can_plot_axis_size_one(self):
672672
if self.plotfunc.__name__ not in ('contour', 'contourf'):
673673
self.plotfunc(DataArray(np.ones((1, 1))))
674674

675+
def test_disallows_rgb_arg(self):
676+
with pytest.raises(ValueError):
677+
# Always invalid for most plots. Invalid for imshow with 2D data.
678+
self.plotfunc(DataArray(np.ones((2, 2))), rgb='not None')
679+
675680
def test_viridis_cmap(self):
676681
cmap_name = self.plotmethod(cmap='viridis').get_cmap().name
677682
self.assertEqual('viridis', cmap_name)
@@ -1071,6 +1076,13 @@ def test_plot_rgb_image(self):
10711076
).plot.imshow()
10721077
self.assertEqual(0, len(find_possible_colorbars()))
10731078

1079+
def test_plot_rgb_image_explicit(self):
1080+
DataArray(
1081+
easy_array((10, 15, 3), start=0),
1082+
dims=['y', 'x', 'band'],
1083+
).plot.imshow(y='y', x='x', rgb='band')
1084+
self.assertEqual(0, len(find_possible_colorbars()))
1085+
10741086
def test_plot_rgb_faceted(self):
10751087
DataArray(
10761088
easy_array((2, 2, 10, 15, 3), start=0),
@@ -1093,6 +1105,16 @@ def test_warns_ambigious_dim(self):
10931105
arr.plot.imshow(rgb='band')
10941106
arr.plot.imshow(x='x', y='y')
10951107

1108+
def test_rgb_errors_too_many_dims(self):
1109+
arr = DataArray(easy_array((3, 3, 3, 3)), dims=['y', 'x', 'z', 'band'])
1110+
with pytest.raises(ValueError):
1111+
arr.plot.imshow(rgb='band')
1112+
1113+
def test_rgb_errors_bad_dim_sizes(self):
1114+
arr = DataArray(easy_array((5, 5, 5)), dims=['y', 'x', 'band'])
1115+
with pytest.raises(ValueError):
1116+
arr.plot.imshow(rgb='band')
1117+
10961118

10971119
class TestFacetGrid(PlotTestCase):
10981120

0 commit comments

Comments
 (0)