diff --git a/CHANGELOG.md b/CHANGELOG.md index 6ba94433ee5..091912b8726 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,10 @@ All notable changes to this project will be documented in this file. This project adheres to [Semantic Versioning](http://semver.org/). +### Updated + +- Updated plotly.py to use base64 encoding of arrays in plotly JSON to improve performance. + ## [5.24.1] - 2024-09-12 ### Updated diff --git a/packages/python/plotly/_plotly_utils/utils.py b/packages/python/plotly/_plotly_utils/utils.py index e8a32e0c8ae..88f1a4000db 100644 --- a/packages/python/plotly/_plotly_utils/utils.py +++ b/packages/python/plotly/_plotly_utils/utils.py @@ -1,3 +1,4 @@ +import base64 import decimal import json as _json import sys @@ -5,7 +6,110 @@ from functools import reduce from _plotly_utils.optional_imports import get_module -from _plotly_utils.basevalidators import ImageUriValidator +from _plotly_utils.basevalidators import ( + ImageUriValidator, + copy_to_readonly_numpy_array, + is_homogeneous_array, +) + + +int8min = -128 +int8max = 127 +int16min = -32768 +int16max = 32767 +int32min = -2147483648 +int32max = 2147483647 + +uint8max = 255 +uint16max = 65535 +uint32max = 4294967295 + +plotlyjsShortTypes = { + "int8": "i1", + "uint8": "u1", + "int16": "i2", + "uint16": "u2", + "int32": "i4", + "uint32": "u4", + "float32": "f4", + "float64": "f8", +} + + +def to_typed_array_spec(v): + """ + Convert numpy array to plotly.js typed array spec + If not possible return the original value + """ + v = copy_to_readonly_numpy_array(v) + + np = get_module("numpy", should_load=False) + if not np or not isinstance(v, np.ndarray): + return v + + dtype = str(v.dtype) + + # convert default Big Ints until we could support them in plotly.js + if dtype == "int64": + max = v.max() + min = v.min() + if max <= int8max and min >= int8min: + v = v.astype("int8") + elif max <= int16max and min >= int16min: + v = v.astype("int16") + elif max <= int32max and min >= int32min: + v = v.astype("int32") + else: + return v + + elif dtype == "uint64": + max = v.max() + min = v.min() + if max <= uint8max and min >= 0: + v = v.astype("uint8") + elif max <= uint16max and min >= 0: + v = v.astype("uint16") + elif max <= uint32max and min >= 0: + v = v.astype("uint32") + else: + return v + + dtype = str(v.dtype) + + if dtype in plotlyjsShortTypes: + arrObj = { + "dtype": plotlyjsShortTypes[dtype], + "bdata": base64.b64encode(v).decode("ascii"), + } + + if v.ndim > 1: + arrObj["shape"] = str(v.shape)[1:-1] + + return arrObj + + return v + + +def is_skipped_key(key): + """ + Return whether the key is skipped for conversion to the typed array spec + """ + skipped_keys = ["geojson", "layer", "layers", "range"] + return any(skipped_key == key for skipped_key in skipped_keys) + + +def convert_to_base64(obj): + if isinstance(obj, dict): + for key, value in obj.items(): + if is_skipped_key(key): + continue + elif is_homogeneous_array(value): + obj[key] = to_typed_array_spec(value) + else: + convert_to_base64(value) + elif isinstance(obj, list) or isinstance(obj, tuple): + for value in obj: + convert_to_base64(value) def cumsum(x): diff --git a/packages/python/plotly/plotly/basedatatypes.py b/packages/python/plotly/plotly/basedatatypes.py index 21b4cb1f312..0fe26c91473 100644 --- a/packages/python/plotly/plotly/basedatatypes.py +++ b/packages/python/plotly/plotly/basedatatypes.py @@ -15,6 +15,7 @@ display_string_positions, chomp_empty_strings, find_closest_string, + convert_to_base64, ) from _plotly_utils.exceptions import PlotlyKeyError from .optional_imports import get_module @@ -3310,6 +3311,9 @@ def to_dict(self): if frames: res["frames"] = frames + # Add base64 conversion before sending to the front-end + convert_to_base64(res) + return res def to_plotly_json(self): diff --git a/packages/python/plotly/plotly/tests/test_optional/test_graph_objs/test_skipped_b64_keys.py b/packages/python/plotly/plotly/tests/test_optional/test_graph_objs/test_skipped_b64_keys.py new file mode 100644 index 00000000000..ee85785644b --- /dev/null +++ b/packages/python/plotly/plotly/tests/test_optional/test_graph_objs/test_skipped_b64_keys.py @@ -0,0 +1,84 @@ +import json +from unittest import TestCase +import numpy as np +from plotly.tests.test_optional.optional_utils import NumpyTestUtilsMixin +import plotly.graph_objs as go + + +class TestShouldNotUseBase64InUnsupportedKeys(NumpyTestUtilsMixin, TestCase): + def test_np_geojson(self): + normal_coordinates = [ + [ + [-87, 35], + [-87, 30], + [-85, 30], + [-85, 35], + ] + ] + + numpy_coordinates = np.array(normal_coordinates) + + data = [ + { + "type": "choropleth", + "locations": ["AL"], + "featureidkey": "properties.id", + "z": np.array([10]), + "geojson": { + "type": "Feature", + "properties": {"id": "AL"}, + "geometry": {"type": "Polygon", "coordinates": numpy_coordinates}, + }, + } + ] + + fig = go.Figure(data=data) + + assert ( + json.loads(fig.to_json())["data"][0]["geojson"]["geometry"]["coordinates"] + == normal_coordinates + ) + + def test_np_layers(self): + layout = { + "mapbox": { + "layers": [ + { + "sourcetype": "geojson", + "type": "line", + "line": {"dash": np.array([2.5, 1])}, + "source": { + "type": "FeatureCollection", + "features": [ + { + "type": "Feature", + "geometry": { + "type": "LineString", + "coordinates": np.array( + [[0.25, 52], [0.75, 50]] + ), + }, + } + ], + }, + }, + ], + "center": {"lon": 0.5, "lat": 51}, + }, + } + data = [{"type": "scattermapbox"}] + + fig = go.Figure(data=data, layout=layout) + + assert (fig.layout["mapbox"]["layers"][0]["line"]["dash"] == (2.5, 1)).all() + + assert json.loads(fig.to_json())["layout"]["mapbox"]["layers"][0]["source"][ + "features" + ][0]["geometry"]["coordinates"] == [[0.25, 52], [0.75, 50]] + + def test_np_range(self): + layout = {"xaxis": {"range": np.array([0, 1])}} + + fig = go.Figure(data=[{"type": "scatter"}], layout=layout) + + assert json.loads(fig.to_json())["layout"]["xaxis"]["range"] == [0, 1] diff --git a/packages/python/plotly/plotly/tests/test_optional/test_px/test_px_functions.py b/packages/python/plotly/plotly/tests/test_optional/test_px/test_px_functions.py index ec27441d6c1..e34dd0d20bd 100644 --- a/packages/python/plotly/plotly/tests/test_optional/test_px/test_px_functions.py +++ b/packages/python/plotly/plotly/tests/test_optional/test_px/test_px_functions.py @@ -25,7 +25,7 @@ def _compare_figures(go_trace, px_fig): def test_pie_like_px(): # Pie labels = ["Oxygen", "Hydrogen", "Carbon_Dioxide", "Nitrogen"] - values = [4500, 2500, 1053, 500] + values = np.array([4500, 2500, 1053, 500]) fig = px.pie(names=labels, values=values) trace = go.Pie(labels=labels, values=values) @@ -33,7 +33,7 @@ def test_pie_like_px(): labels = ["Eve", "Cain", "Seth", "Enos", "Noam", "Abel", "Awan", "Enoch", "Azura"] parents = ["", "Eve", "Eve", "Seth", "Seth", "Eve", "Eve", "Awan", "Eve"] - values = [10, 14, 12, 10, 2, 6, 6, 4, 4] + values = np.array([10, 14, 12, 10, 2, 6, 6, 4, 4]) # Sunburst fig = px.sunburst(names=labels, parents=parents, values=values) trace = go.Sunburst(labels=labels, parents=parents, values=values) @@ -45,7 +45,7 @@ def test_pie_like_px(): # Funnel x = ["A", "B", "C"] - y = [3, 2, 1] + y = np.array([3, 2, 1]) fig = px.funnel(y=y, x=x) trace = go.Funnel(y=y, x=x) _compare_figures(trace, fig) diff --git a/packages/python/plotly/plotly/tests/test_optional/test_utils/test_utils.py b/packages/python/plotly/plotly/tests/test_optional/test_utils/test_utils.py index cf32e1bdff8..9fa18966406 100644 --- a/packages/python/plotly/plotly/tests/test_optional/test_utils/test_utils.py +++ b/packages/python/plotly/plotly/tests/test_optional/test_utils/test_utils.py @@ -372,38 +372,6 @@ def test_invalid_encode_exception(self): with self.assertRaises(TypeError): _json.dumps({"a": {1}}, cls=utils.PlotlyJSONEncoder) - def test_fast_track_finite_arrays(self): - # if NaN or Infinity is found in the json dump - # of a figure, it is decoded and re-encoded to replace these values - # with null. This test checks that NaN and Infinity values are - # indeed converted to null, and that the encoding of figures - # without inf or nan is faster (because we can avoid decoding - # and reencoding). - z = np.random.randn(100, 100) - x = np.arange(100.0) - fig_1 = go.Figure(go.Heatmap(z=z, x=x)) - t1 = time() - json_str_1 = _json.dumps(fig_1, cls=utils.PlotlyJSONEncoder) - t2 = time() - x[0] = np.nan - x[1] = np.inf - fig_2 = go.Figure(go.Heatmap(z=z, x=x)) - t3 = time() - json_str_2 = _json.dumps(fig_2, cls=utils.PlotlyJSONEncoder) - t4 = time() - assert t2 - t1 < t4 - t3 - assert "null" in json_str_2 - assert "NaN" not in json_str_2 - assert "Infinity" not in json_str_2 - x = np.arange(100.0) - fig_3 = go.Figure(go.Heatmap(z=z, x=x)) - fig_3.update_layout(title_text="Infinity") - t5 = time() - json_str_3 = _json.dumps(fig_3, cls=utils.PlotlyJSONEncoder) - t6 = time() - assert t2 - t1 < t6 - t5 - assert "Infinity" in json_str_3 - class TestNumpyIntegerBaseType(TestCase): def test_numpy_integer_import(self):