@@ -443,10 +443,17 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
443
443
# Decide on a default for the colorbar before facetgrids
444
444
if add_colorbar is None :
445
445
add_colorbar = plotfunc .__name__ != 'contour'
446
+ imshow_rgb = (
447
+ plotfunc .__name__ == 'imshow' and
448
+ darray .ndim == (3 + (row is not None ) + (col is not None )))
449
+ if imshow_rgb :
450
+ # Don't add a colorbar when showing an image with explicit colors
451
+ add_colorbar = False
446
452
447
453
# Handle facetgrids first
448
454
if row or col :
449
455
allargs = locals ().copy ()
456
+ allargs .pop ('imshow_rgb' )
450
457
allargs .update (allargs .pop ('kwargs' ))
451
458
452
459
# Need the decorated plotting function
@@ -470,12 +477,19 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
470
477
"Use colors keyword instead." ,
471
478
DeprecationWarning , stacklevel = 3 )
472
479
473
- xlab , ylab = _infer_xy_labels (darray = darray , x = x , y = y )
480
+ rgb = kwargs .pop ('rgb' , None )
481
+ xlab , ylab = _infer_xy_labels (
482
+ darray = darray , x = x , y = y , imshow = imshow_rgb , rgb = rgb )
483
+
484
+ if rgb is not None and plotfunc .__name__ != 'imshow' :
485
+ raise ValueError ('The "rgb" keyword is only valid for imshow()' )
486
+ elif rgb is not None and not imshow_rgb :
487
+ raise ValueError ('The "rgb" keyword is only valid for imshow()'
488
+ 'with a three-dimensional array (per facet)' )
474
489
475
490
# better to pass the ndarrays directly to plotting functions
476
491
xval = darray [xlab ].values
477
492
yval = darray [ylab ].values
478
- zval = darray .to_masked_array (copy = False )
479
493
480
494
# check if we need to broadcast one dimension
481
495
if xval .ndim < yval .ndim :
@@ -486,8 +500,19 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
486
500
487
501
# May need to transpose for correct x, y labels
488
502
# xlab may be the name of a coord, we have to check for dim names
489
- if darray [xlab ].dims [- 1 ] == darray .dims [0 ]:
490
- zval = zval .T
503
+ if imshow_rgb :
504
+ # For RGB[A] images, matplotlib requires the color dimension
505
+ # to be last. In Xarray the order should be unimportant, so
506
+ # we transpose to (y, x, color) to make this work.
507
+ yx_dims = (ylab , xlab )
508
+ dims = yx_dims + tuple (d for d in darray .dims if d not in yx_dims )
509
+ if dims != darray .dims :
510
+ darray = darray .transpose (* dims )
511
+ elif darray [xlab ].dims [- 1 ] == darray .dims [0 ]:
512
+ darray = darray .transpose ()
513
+
514
+ # Pass the data as a masked ndarray too
515
+ zval = darray .to_masked_array (copy = False )
491
516
492
517
_ensure_plottable (xval , yval )
493
518
@@ -595,6 +620,11 @@ def imshow(x, y, z, ax, **kwargs):
595
620
596
621
Wraps :func:`matplotlib:matplotlib.pyplot.imshow`
597
622
623
+ While other plot methods require the DataArray to be strictly
624
+ two-dimensional, ``imshow`` also accepts a 3D array where some
625
+ dimension can be interpreted as RGB or RGBA color channels and
626
+ allows this dimension to be specified via the kwarg ``rgb=``.
627
+
598
628
.. note::
599
629
This function needs uniformly spaced coordinates to
600
630
properly label the axes. Call DataArray.plot() to check.
@@ -632,6 +662,15 @@ def imshow(x, y, z, ax, **kwargs):
632
662
# Allow user to override these defaults
633
663
defaults .update (kwargs )
634
664
665
+ if z .ndim == 3 :
666
+ # matplotlib imshow uses black for missing data, but Xarray makes
667
+ # missing data transparent. We therefore add an alpha channel if
668
+ # there isn't one, and set it to transparent where data is masked.
669
+ if z .shape [- 1 ] == 3 :
670
+ z = np .ma .concatenate ((z , np .ma .ones (z .shape [:2 ] + (1 ,))), 2 )
671
+ z = z .copy ()
672
+ z [np .any (z .mask , axis = - 1 ), - 1 ] = 0
673
+
635
674
primitive = ax .imshow (z , ** defaults )
636
675
637
676
return primitive
0 commit comments