Skip to content

Commit 3d6b49a

Browse files
authored
Merge pull request #131 from predict-idlab/orjson_werkzeug
💪🏼 making orjson serialization more robust, see #118
2 parents 38ea31c + c262153 commit 3d6b49a

File tree

2 files changed

+57
-6
lines changed

2 files changed

+57
-6
lines changed

plotly_resampler/figure_resampler/figure_resampler_interface.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -291,8 +291,10 @@ def _check_update_trace_data(
291291
s_res: pd.Series = downsampler.aggregate(
292292
hf_series, hf_trace_data["max_n_samples"]
293293
)
294-
trace["x"] = s_res.index
295-
trace["y"] = s_res.values
294+
# Also parse the data types to an orjson compatible format
295+
# Note this can be removed once orjson supports f16
296+
trace["x"] = self._parse_dtype_orjson(s_res.index)
297+
trace["y"] = self._parse_dtype_orjson(s_res.values)
296298
# todo -> first draft & not MP safe
297299

298300
agg_prefix, agg_suffix = ' <i style="color:#fc9944">~', "</i>"
@@ -700,10 +702,6 @@ def _parse_get_trace_props(
700702
except ValueError:
701703
hf_y = hf_y.astype("str")
702704

703-
# orjson encoding doesn't like to encode with uint8 & uint16 dtype
704-
if str(hf_y.dtype) in ["uint8", "uint16"]:
705-
hf_y = hf_y.astype("uint32")
706-
707705
assert len(hf_x) == len(hf_y), "x and y have different length!"
708706
else:
709707
self._print(f"trace {trace['type']} is not a high-frequency trace")
@@ -1283,6 +1281,17 @@ def construct_update_data(
12831281
layout_traces_list.append(trace_reduced)
12841282
return layout_traces_list
12851283

1284+
@staticmethod
1285+
def _parse_dtype_orjson(series: np.ndarray) -> np.ndarray:
1286+
"""Verify the orjson compatibility of the series and convert it if needed."""
1287+
# NOTE:
1288+
# * float16 and float128 aren't supported with latest orjson versions (3.8.1)
1289+
# * this method assumes that it will not get a float128 series
1290+
# -> this method can be removed if orjson supports float16
1291+
if series.dtype in [np.float16]:
1292+
return series.astype(np.float32)
1293+
return series
1294+
12861295
@staticmethod
12871296
def _re_matches(regex: re.Pattern, strings: Iterable[str]) -> List[str]:
12881297
"""Returns all the items in ``strings`` which regex.match(es) ``regex``."""

tests/test_figure_resampler.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,48 @@ def test_add_trace_not_resampling(float_series):
101101
hf_hovertext="hovertext",
102102
)
103103

104+
def test_various_dtypes(float_series):
105+
# List of dtypes supported by orjson >= 3.8
106+
valid_dtype_list = [
107+
np.bool_,
108+
# ---- uints
109+
np.uint8,
110+
np.uint16,
111+
np.uint32,
112+
np.uint64,
113+
# -------- ints
114+
np.int8,
115+
np.int16,
116+
np.int32,
117+
np.int64,
118+
# -------- floats
119+
np.float16, # currently not supported by orjson
120+
np.float32,
121+
np.float64,
122+
]
123+
for dtype in valid_dtype_list:
124+
fig = FigureResampler(go.Figure(), default_n_shown_samples=1000)
125+
# nb. datapoints > default_n_shown_samples
126+
fig.add_trace(
127+
go.Scatter(name="float_series"),
128+
hf_x=float_series.index,
129+
hf_y=float_series.astype(dtype),
130+
)
131+
fig.full_figure_for_development()
132+
133+
# List of dtypes not supported by orjson >= 3.8
134+
invalid_dtype_list = [ np.float16 ]
135+
for invalid_dtype in invalid_dtype_list:
136+
fig = FigureResampler(go.Figure(), default_n_shown_samples=1000)
137+
# nb. datapoints < default_n_shown_samples
138+
with pytest.raises(TypeError):
139+
# if this test fails -> orjson supports f16 => remove casting from code
140+
fig.add_trace(
141+
go.Scatter(name="float_series"),
142+
hf_x=float_series.index[:500],
143+
hf_y=float_series.astype(invalid_dtype)[:500],
144+
)
145+
fig.full_figure_for_development()
104146

105147
def test_max_n_samples(float_series):
106148
s = float_series[:5000]

0 commit comments

Comments
 (0)