Skip to content

Commit 96967d7

Browse files
authored
Imshow (#1855)
1 parent 643f58e commit 96967d7

File tree

3 files changed

+235
-0
lines changed

3 files changed

+235
-0
lines changed

Diff for: packages/python/plotly/plotly/express/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141
density_heatmap,
4242
)
4343

44+
from ._imshow import imshow
45+
4446
from ._core import ( # noqa: F401
4547
set_mapbox_access_token,
4648
defaults,
@@ -75,6 +77,7 @@
7577
"strip",
7678
"histogram",
7779
"choropleth",
80+
"imshow",
7881
"data",
7982
"colors",
8083
"set_mapbox_access_token",

Diff for: packages/python/plotly/plotly/express/_imshow.py

+134
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
import plotly.graph_objs as go
2+
import numpy as np # is it fine to depend on np here?
3+
4+
_float_types = []
5+
6+
# Adapted from skimage.util.dtype
7+
_integer_types = (
8+
np.byte,
9+
np.ubyte, # 8 bits
10+
np.short,
11+
np.ushort, # 16 bits
12+
np.intc,
13+
np.uintc, # 16 or 32 or 64 bits
14+
np.int_,
15+
np.uint, # 32 or 64 bits
16+
np.longlong,
17+
np.ulonglong,
18+
) # 64 bits
19+
_integer_ranges = {t: (np.iinfo(t).min, np.iinfo(t).max) for t in _integer_types}
20+
21+
22+
def _vectorize_zvalue(z):
23+
if z is None:
24+
return z
25+
elif np.isscalar(z):
26+
return [z] * 3 + [1]
27+
elif len(z) == 1:
28+
return list(z) * 3 + [1]
29+
elif len(z) == 3:
30+
return list(z) + [1]
31+
elif len(z) == 4:
32+
return z
33+
else:
34+
raise ValueError(
35+
"zmax can be a scalar, or an iterable of length 1, 3 or 4. "
36+
"A value of %s was passed for zmax." % str(z)
37+
)
38+
39+
40+
def _infer_zmax_from_type(img):
41+
dt = img.dtype.type
42+
rtol = 1.05
43+
if dt in _integer_types:
44+
return _integer_ranges[dt][1]
45+
else:
46+
im_max = img[np.isfinite(img)].max()
47+
if im_max <= 1 * rtol:
48+
return 1
49+
elif im_max <= 255 * rtol:
50+
return 255
51+
elif im_max <= 65535 * rtol:
52+
return 65535
53+
else:
54+
return 2 ** 32
55+
56+
57+
def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None):
58+
"""
59+
Display an image, i.e. data on a 2D regular raster.
60+
61+
Parameters
62+
----------
63+
64+
img: array-like image
65+
The image data. Supported array shapes are
66+
67+
- (M, N): an image with scalar data. The data is visualized
68+
using a colormap.
69+
- (M, N, 3): an image with RGB values.
70+
- (M, N, 4): an image with RGBA values, i.e. including transparency.
71+
72+
zmin, zmax : scalar or iterable, optional
73+
zmin and zmax define the scalar range that the colormap covers. By default,
74+
zmin and zmax correspond to the min and max values of the datatype for integer
75+
datatypes (ie [0-255] for uint8 images, [0, 65535] for uint16 images, etc.). For
76+
a multichannel image of floats, the max of the image is computed and zmax is the
77+
smallest power of 256 (1, 255, 65535) greater than this max value,
78+
with a 5% tolerance. For a single-channel image, the max of the image is used.
79+
80+
origin : str, 'upper' or 'lower' (default 'upper')
81+
position of the [0, 0] pixel of the image array, in the upper left or lower left
82+
corner. The convention 'upper' is typically used for matrices and images.
83+
84+
colorscale : str
85+
colormap used to map scalar data to colors (for a 2D image). This parameter is not used for
86+
RGB or RGBA images.
87+
88+
Returns
89+
-------
90+
fig : graph_objects.Figure containing the displayed image
91+
92+
See also
93+
--------
94+
95+
plotly.graph_objects.Image : image trace
96+
plotly.graph_objects.Heatmap : heatmap trace
97+
98+
Notes
99+
-----
100+
101+
In order to update and customize the returned figure, use
102+
`go.Figure.update_traces` or `go.Figure.update_layout`.
103+
"""
104+
img = np.asanyarray(img)
105+
# Cast bools to uint8 (also one byte)
106+
if img.dtype == np.bool:
107+
img = 255 * img.astype(np.uint8)
108+
109+
# For 2d data, use Heatmap trace
110+
if img.ndim == 2:
111+
if colorscale is None:
112+
colorscale = "gray"
113+
trace = go.Heatmap(z=img, zmin=zmin, zmax=zmax, colorscale=colorscale)
114+
autorange = True if origin == "lower" else "reversed"
115+
layout = dict(
116+
xaxis=dict(scaleanchor="y", constrain="domain"),
117+
yaxis=dict(autorange=autorange, constrain="domain"),
118+
)
119+
# For 2D+RGB data, use Image trace
120+
elif img.ndim == 3 and img.shape[-1] in [3, 4]:
121+
if zmax is None and img.dtype is not np.uint8:
122+
zmax = _infer_zmax_from_type(img)
123+
zmin, zmax = _vectorize_zvalue(zmin), _vectorize_zvalue(zmax)
124+
trace = go.Image(z=img, zmin=zmin, zmax=zmax)
125+
layout = {}
126+
if origin == "lower":
127+
layout["yaxis"] = dict(autorange=True)
128+
else:
129+
raise ValueError(
130+
"px.imshow only accepts 2D grayscale, RGB or RGBA images. "
131+
"An image of shape %s was provided" % str(img.shape)
132+
)
133+
fig = go.Figure(data=trace, layout=layout)
134+
return fig
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import plotly.express as px
2+
import numpy as np
3+
import pytest
4+
5+
img_rgb = np.array([[[255, 0, 0], [0, 255, 0], [0, 0, 255]]], dtype=np.uint8)
6+
img_gray = np.arange(100).reshape((10, 10))
7+
8+
9+
def test_rgb_uint8():
10+
fig = px.imshow(img_rgb)
11+
assert fig.data[0]["zmax"] == (255, 255, 255, 1)
12+
13+
14+
def test_vmax():
15+
for zmax in [
16+
100,
17+
[100],
18+
(100,),
19+
[100, 100, 100],
20+
(100, 100, 100),
21+
(100, 100, 100, 1),
22+
]:
23+
fig = px.imshow(img_rgb, zmax=zmax)
24+
assert fig.data[0]["zmax"] == (100, 100, 100, 1)
25+
26+
27+
def test_automatic_zmax_from_dtype():
28+
dtypes_dict = {
29+
np.uint8: 2 ** 8 - 1,
30+
np.uint16: 2 ** 16 - 1,
31+
np.float: 1,
32+
np.bool: 255,
33+
}
34+
for key, val in dtypes_dict.items():
35+
img = np.array([0, 1], dtype=key)
36+
img = np.dstack((img,) * 3)
37+
fig = px.imshow(img)
38+
assert fig.data[0]["zmax"] == (val, val, val, 1)
39+
40+
41+
def test_origin():
42+
for img in [img_rgb, img_gray]:
43+
fig = px.imshow(img, origin="lower")
44+
assert fig.layout.yaxis.autorange == True
45+
fig = px.imshow(img_rgb)
46+
assert fig.layout.yaxis.autorange is None
47+
fig = px.imshow(img_gray)
48+
assert fig.layout.yaxis.autorange == "reversed"
49+
50+
51+
def test_colorscale():
52+
fig = px.imshow(img_gray)
53+
assert fig.data[0].colorscale[0] == (0.0, "rgb(0, 0, 0)")
54+
fig = px.imshow(img_gray, colorscale="Viridis")
55+
assert fig.data[0].colorscale[0] == (0.0, "#440154")
56+
57+
58+
def test_wrong_dimensions():
59+
imgs = [1, np.ones((5,) * 3), np.ones((5,) * 4)]
60+
for img in imgs:
61+
with pytest.raises(ValueError) as err_msg:
62+
fig = px.imshow(img)
63+
64+
65+
def test_nan_inf_data():
66+
imgs = [np.ones((20, 20)), 255 * np.ones((20, 20), dtype=np.uint8)]
67+
zmaxs = [1, 255]
68+
for zmax, img in zip(zmaxs, imgs):
69+
img[0] = 0
70+
img[10:12] = np.nan
71+
# the case of 2d/heatmap is handled gracefully by the JS trace but I don't know how to check it
72+
fig = px.imshow(np.dstack((img,) * 3))
73+
assert fig.data[0]["zmax"] == (zmax, zmax, zmax, 1)
74+
75+
76+
def test_zmax_floats():
77+
# RGB
78+
imgs = [
79+
np.ones((5, 5, 3)),
80+
1.02 * np.ones((5, 5, 3)),
81+
2 * np.ones((5, 5, 3)),
82+
1000 * np.ones((5, 5, 3)),
83+
]
84+
zmaxs = [1, 1, 255, 65535]
85+
for zmax, img in zip(zmaxs, imgs):
86+
fig = px.imshow(img)
87+
assert fig.data[0]["zmax"] == (zmax, zmax, zmax, 1)
88+
# single-channel
89+
imgs = [
90+
np.ones((5, 5)),
91+
1.02 * np.ones((5, 5)),
92+
2 * np.ones((5, 5)),
93+
1000 * np.ones((5, 5)),
94+
]
95+
for zmax, img in zip(zmaxs, imgs):
96+
fig = px.imshow(img)
97+
print(fig.data[0]["zmax"], zmax)
98+
assert fig.data[0]["zmax"] == None

0 commit comments

Comments
 (0)