Skip to content

Commit 9bd29d6

Browse files
authored
Merge pull request #127 from predict-idlab/fix_max_samples
✨ fix + test for #124
2 parents 2d10991 + eb5d8f3 commit 9bd29d6

File tree

2 files changed

+43
-9
lines changed

2 files changed

+43
-9
lines changed

plotly_resampler/figure_resampler/figure_resampler_interface.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -794,6 +794,23 @@ def _construct_hf_data_dict(
794794
"hovertext": dc.hovertext,
795795
}
796796

797+
@staticmethod
798+
def _add_trace_to_add_traces_kwargs(kwargs: dict) -> dict:
799+
"""Convert the `add_trace` kwargs to the `add_traces` kwargs."""
800+
# The keywords that need to be converted to a list
801+
convert_keywords = ["row", "col", "secondary_y"]
802+
803+
updated_kwargs = {} # The updated kwargs (from `add_trace` to `add_traces`)
804+
for keyword in convert_keywords:
805+
value = kwargs.pop(keyword, None)
806+
if value is not None:
807+
updated_kwargs[f"{keyword}s"] = [value]
808+
else:
809+
updated_kwargs[f"{keyword}s"] = None
810+
811+
return {**kwargs, **updated_kwargs}
812+
813+
797814
def add_trace(
798815
self,
799816
trace: Union[BaseTraceType, dict],
@@ -955,17 +972,22 @@ def add_trace(
955972
# Hence, you first downsample the trace.
956973
trace = self._check_update_trace_data(trace)
957974
assert trace is not None
958-
return super(self._figure_class, self).add_trace(trace, **trace_kwargs)
975+
return super(AbstractFigureAggregator, self).add_traces(
976+
[trace], **self._add_trace_to_add_traces_kwargs(trace_kwargs)
977+
)
959978
else:
960979
self._print(f"[i] NOT resampling {trace['name']} - len={n_samples}")
961980
# TODO: can be made more generic
962981
trace.x = dc.x
963982
trace.y = dc.y
964983
trace.text = dc.text
965984
trace.hovertext = dc.hovertext
966-
return super(self._figure_class, self).add_trace(trace, **trace_kwargs)
967-
968-
return super(self._figure_class, self).add_trace(trace, **trace_kwargs)
985+
return super(AbstractFigureAggregator, self).add_traces(
986+
[trace], **self._add_trace_to_add_traces_kwargs(trace_kwargs)
987+
)
988+
return super(AbstractFigureAggregator, self).add_traces(
989+
[trace], **self._add_trace_to_add_traces_kwargs(trace_kwargs)
990+
)
969991

970992
def add_traces(
971993
self,
@@ -1152,8 +1174,8 @@ def replace(self, figure: go.Figure, convert_existing_traces: bool = True):
11521174
)
11531175

11541176
def construct_update_data(
1155-
self,
1156-
relayout_data: dict
1177+
self,
1178+
relayout_data: dict,
11571179
) -> Union[List[dict], dash.no_update]:
11581180
"""Construct the to-be-updated front-end data, based on the layout change.
11591181

tests/test_figure_resampler.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,18 @@ def test_add_trace_not_resampling(float_series):
102102
)
103103

104104

105+
def test_max_n_samples(float_series):
106+
s = float_series[:5000]
107+
108+
fig = FigureResampler()
109+
fig.add_trace(
110+
go.Scattergl(name="test"), hf_x=s.index, hf_y=s, max_n_samples=len(s) + 1
111+
)
112+
# make sure that there is not hf_data
113+
assert len(fig.hf_data) == 0
114+
assert len(fig.data[0]["x"]) == len(s)
115+
116+
105117
def test_add_scatter_trace_no_data():
106118
fig = FigureResampler(default_n_shown_samples=1000)
107119

@@ -1045,7 +1057,7 @@ def test_fr_update_layout_axes_range(driver):
10451057
y=np.arange(nb_datapoints)
10461058
)
10471059

1048-
def check_data(fr: FigureResampler, min_v=0, max_v=nb_datapoints-1):
1060+
def check_data(fr: FigureResampler, min_v=0, max_v=nb_datapoints - 1):
10491061
# closure for n_shown and nb_datapoints
10501062
assert len(fr.data[0]["y"]) == min(n_shown, nb_datapoints)
10511063
assert len(fr.data[0]["x"]) == min(n_shown, nb_datapoints)
@@ -1121,7 +1133,7 @@ def check_data(fr: FigureResampler, min_v=0, max_v=nb_datapoints-1):
11211133
finally:
11221134
proc.terminate()
11231135
f_pr.stop_server()
1124-
1136+
11251137

11261138
def test_fr_update_layout_axes_range_no_update(driver):
11271139
nb_datapoints = 2_000
@@ -1133,7 +1145,7 @@ def test_fr_update_layout_axes_range_no_update(driver):
11331145
y=np.arange(nb_datapoints)
11341146
)
11351147

1136-
def check_data(fr: FigureResampler, min_v=0, max_v=nb_datapoints-1):
1148+
def check_data(fr: FigureResampler, min_v=0, max_v=nb_datapoints - 1):
11371149
# closure for n_shown and nb_datapoints
11381150
assert len(fr.data[0]["y"]) == min(n_shown, nb_datapoints)
11391151
assert len(fr.data[0]["x"]) == min(n_shown, nb_datapoints)

0 commit comments

Comments
 (0)