Skip to content

Commit 9c9b98e

Browse files
authored
move array_to_imagestr function to be part of public API (#2879)
* move array_to_imagestr function to be part of public API * renamed function
1 parent fa9500b commit 9c9b98e

File tree

4 files changed

+78
-76
lines changed

4 files changed

+78
-76
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
from io import BytesIO
2+
import base64
3+
from .png import Writer, from_array
4+
5+
try:
6+
from PIL import Image
7+
8+
pil_imported = True
9+
except ImportError:
10+
pil_imported = False
11+
12+
13+
def image_array_to_data_uri(img, backend="pil", compression=4, ext="png"):
14+
"""Converts a numpy array of uint8 into a base64 png or jpg string.
15+
16+
Parameters
17+
----------
18+
img: ndarray of uint8
19+
array image
20+
backend: str
21+
'auto', 'pil' or 'pypng'. If 'auto', Pillow is used if installed,
22+
otherwise pypng.
23+
compression: int, between 0 and 9
24+
compression level to be passed to the backend
25+
ext: str, 'png' or 'jpg'
26+
compression format used to generate b64 string
27+
"""
28+
# PIL and pypng error messages are quite obscure so we catch invalid compression values
29+
if compression < 0 or compression > 9:
30+
raise ValueError("compression level must be between 0 and 9.")
31+
alpha = False
32+
if img.ndim == 2:
33+
mode = "L"
34+
elif img.ndim == 3 and img.shape[-1] == 3:
35+
mode = "RGB"
36+
elif img.ndim == 3 and img.shape[-1] == 4:
37+
mode = "RGBA"
38+
alpha = True
39+
else:
40+
raise ValueError("Invalid image shape")
41+
if backend == "auto":
42+
backend = "pil" if pil_imported else "pypng"
43+
if ext != "png" and backend != "pil":
44+
raise ValueError("jpg binary strings are only available with PIL backend")
45+
46+
if backend == "pypng":
47+
ndim = img.ndim
48+
sh = img.shape
49+
if ndim == 3:
50+
img = img.reshape((sh[0], sh[1] * sh[2]))
51+
w = Writer(
52+
sh[1], sh[0], greyscale=(ndim == 2), alpha=alpha, compression=compression
53+
)
54+
img_png = from_array(img, mode=mode)
55+
prefix = "data:image/png;base64,"
56+
with BytesIO() as stream:
57+
w.write(stream, img_png.rows)
58+
base64_string = prefix + base64.b64encode(stream.getvalue()).decode("utf-8")
59+
else: # pil
60+
if not pil_imported:
61+
raise ImportError(
62+
"pillow needs to be installed to use `backend='pil'. Please"
63+
"install pillow or use `backend='pypng'."
64+
)
65+
pil_img = Image.fromarray(img)
66+
if ext == "jpg" or ext == "jpeg":
67+
prefix = "data:image/jpeg;base64,"
68+
ext = "jpeg"
69+
else:
70+
prefix = "data:image/png;base64,"
71+
ext = "png"
72+
with BytesIO() as stream:
73+
pil_img.save(stream, format=ext, compress_level=compression)
74+
base64_string = prefix + base64.b64encode(stream.getvalue()).decode("utf-8")
75+
return base64_string

packages/python/plotly/plotly/express/_imshow.py

+2-75
Original file line numberDiff line numberDiff line change
@@ -1,94 +1,21 @@
11
import plotly.graph_objs as go
22
from _plotly_utils.basevalidators import ColorscaleValidator
33
from ._core import apply_default_cascade
4-
from io import BytesIO
5-
import base64
64
from .imshow_utils import rescale_intensity, _integer_ranges, _integer_types
75
import pandas as pd
8-
from .png import Writer, from_array
96
import numpy as np
7+
from plotly.utils import image_array_to_data_uri
108

119
try:
1210
import xarray
1311

1412
xarray_imported = True
1513
except ImportError:
1614
xarray_imported = False
17-
try:
18-
from PIL import Image
19-
20-
pil_imported = True
21-
except ImportError:
22-
pil_imported = False
2315

2416
_float_types = []
2517

2618

27-
def _array_to_b64str(img, backend="pil", compression=4, ext="png"):
28-
"""Converts a numpy array of uint8 into a base64 png string.
29-
30-
Parameters
31-
----------
32-
img: ndarray of uint8
33-
array image
34-
backend: str
35-
'auto', 'pil' or 'pypng'. If 'auto', Pillow is used if installed,
36-
otherwise pypng.
37-
compression: int, between 0 and 9
38-
compression level to be passed to the backend
39-
ext: str, 'png' or 'jpg'
40-
compression format used to generate b64 string
41-
"""
42-
# PIL and pypng error messages are quite obscure so we catch invalid compression values
43-
if compression < 0 or compression > 9:
44-
raise ValueError("compression level must be between 0 and 9.")
45-
alpha = False
46-
if img.ndim == 2:
47-
mode = "L"
48-
elif img.ndim == 3 and img.shape[-1] == 3:
49-
mode = "RGB"
50-
elif img.ndim == 3 and img.shape[-1] == 4:
51-
mode = "RGBA"
52-
alpha = True
53-
else:
54-
raise ValueError("Invalid image shape")
55-
if backend == "auto":
56-
backend = "pil" if pil_imported else "pypng"
57-
if ext != "png" and backend != "pil":
58-
raise ValueError("jpg binary strings are only available with PIL backend")
59-
60-
if backend == "pypng":
61-
ndim = img.ndim
62-
sh = img.shape
63-
if ndim == 3:
64-
img = img.reshape((sh[0], sh[1] * sh[2]))
65-
w = Writer(
66-
sh[1], sh[0], greyscale=(ndim == 2), alpha=alpha, compression=compression
67-
)
68-
img_png = from_array(img, mode=mode)
69-
prefix = "data:image/png;base64,"
70-
with BytesIO() as stream:
71-
w.write(stream, img_png.rows)
72-
base64_string = prefix + base64.b64encode(stream.getvalue()).decode("utf-8")
73-
else: # pil
74-
if not pil_imported:
75-
raise ImportError(
76-
"pillow needs to be installed to use `backend='pil'. Please"
77-
"install pillow or use `backend='pypng'."
78-
)
79-
pil_img = Image.fromarray(img)
80-
if ext == "jpg" or ext == "jpeg":
81-
prefix = "data:image/jpeg;base64,"
82-
ext = "jpeg"
83-
else:
84-
prefix = "data:image/png;base64,"
85-
ext = "png"
86-
with BytesIO() as stream:
87-
pil_img.save(stream, format=ext, compress_level=compression)
88-
base64_string = prefix + base64.b64encode(stream.getvalue()).decode("utf-8")
89-
return base64_string
90-
91-
9219
def _vectorize_zvalue(z, mode="max"):
9320
alpha = 255 if mode == "max" else 0
9421
if z is None:
@@ -422,7 +349,7 @@ def imshow(
422349
for ch in range(img.shape[-1])
423350
]
424351
)
425-
img_str = _array_to_b64str(
352+
img_str = image_array_to_data_uri(
426353
img_rescaled,
427354
backend=binary_backend,
428355
compression=binary_compression_level,

packages/python/plotly/plotly/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from pprint import PrettyPrinter
55

66
from _plotly_utils.utils import *
7-
7+
from _plotly_utils.data_utils import *
88

99
# Pretty printing
1010
def _list_repr_elided(v, threshold=200, edgeitems=3, indent=0, width=80):

0 commit comments

Comments
 (0)