Skip to content

Commit 35f751f

Browse files
committed
✨ fix + test for #124
1 parent 251ec82 commit 35f751f

File tree

2 files changed

+40
-9
lines changed

2 files changed

+40
-9
lines changed

plotly_resampler/figure_resampler/figure_resampler_interface.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -794,6 +794,20 @@ def _construct_hf_data_dict(
794794
"hovertext": dc.hovertext,
795795
}
796796

797+
@staticmethod
798+
def _add_trace_to_add_traces_kwargs(kwargs: dict) -> dict:
799+
"""Convert the `add_trace` kwargs to the `add_traces` kwargs."""
800+
row = kwargs.pop("row", None)
801+
row = [row] if row is not None else None
802+
803+
cols = kwargs.pop("col", None)
804+
cols = [cols] if cols is not None else None
805+
806+
secondary_ys = kwargs.pop("secondary_y", None)
807+
secondary_ys = [secondary_ys] if secondary_ys is not None else None
808+
809+
return {**kwargs, "rows": row, "cols": cols, "secondary_ys": secondary_ys}
810+
797811
def add_trace(
798812
self,
799813
trace: Union[BaseTraceType, dict],
@@ -955,17 +969,22 @@ def add_trace(
955969
# Hence, you first downsample the trace.
956970
trace = self._check_update_trace_data(trace)
957971
assert trace is not None
958-
return super(self._figure_class, self).add_trace(trace, **trace_kwargs)
972+
return super(AbstractFigureAggregator, self).add_traces(
973+
[trace], **self._add_trace_to_add_traces_kwargs(trace_kwargs)
974+
)
959975
else:
960976
self._print(f"[i] NOT resampling {trace['name']} - len={n_samples}")
961977
# TODO: can be made more generic
962978
trace.x = dc.x
963979
trace.y = dc.y
964980
trace.text = dc.text
965981
trace.hovertext = dc.hovertext
966-
return super(self._figure_class, self).add_trace(trace, **trace_kwargs)
967-
968-
return super(self._figure_class, self).add_trace(trace, **trace_kwargs)
982+
return super(AbstractFigureAggregator, self).add_traces(
983+
[trace], **self._add_trace_to_add_traces_kwargs(trace_kwargs)
984+
)
985+
return super(AbstractFigureAggregator, self).add_traces(
986+
[trace], **self._add_trace_to_add_traces_kwargs(trace_kwargs)
987+
)
969988

970989
def add_traces(
971990
self,
@@ -1152,8 +1171,8 @@ def replace(self, figure: go.Figure, convert_existing_traces: bool = True):
11521171
)
11531172

11541173
def construct_update_data(
1155-
self,
1156-
relayout_data: dict
1174+
self,
1175+
relayout_data: dict,
11571176
) -> Union[List[dict], dash.no_update]:
11581177
"""Construct the to-be-updated front-end data, based on the layout change.
11591178

tests/test_figure_resampler.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,18 @@ def test_add_trace_not_resampling(float_series):
102102
)
103103

104104

105+
def test_max_n_samples(float_series):
106+
s = float_series[:5000]
107+
108+
fig = FigureResampler()
109+
fig.add_trace(
110+
go.Scattergl(name="test"), hf_x=s.index, hf_y=s, max_n_samples=len(s) + 1
111+
)
112+
# make sure that there is not hf_data
113+
assert len(fig.hf_data) == 0
114+
assert len(fig.data[0]["x"]) == len(s)
115+
116+
105117
def test_add_scatter_trace_no_data():
106118
fig = FigureResampler(default_n_shown_samples=1000)
107119

@@ -1045,7 +1057,7 @@ def test_fr_update_layout_axes_range(driver):
10451057
y=np.arange(nb_datapoints)
10461058
)
10471059

1048-
def check_data(fr: FigureResampler, min_v=0, max_v=nb_datapoints-1):
1060+
def check_data(fr: FigureResampler, min_v=0, max_v=nb_datapoints - 1):
10491061
# closure for n_shown and nb_datapoints
10501062
assert len(fr.data[0]["y"]) == min(n_shown, nb_datapoints)
10511063
assert len(fr.data[0]["x"]) == min(n_shown, nb_datapoints)
@@ -1121,7 +1133,7 @@ def check_data(fr: FigureResampler, min_v=0, max_v=nb_datapoints-1):
11211133
finally:
11221134
proc.terminate()
11231135
f_pr.stop_server()
1124-
1136+
11251137

11261138
def test_fr_update_layout_axes_range_no_update(driver):
11271139
nb_datapoints = 2_000
@@ -1133,7 +1145,7 @@ def test_fr_update_layout_axes_range_no_update(driver):
11331145
y=np.arange(nb_datapoints)
11341146
)
11351147

1136-
def check_data(fr: FigureResampler, min_v=0, max_v=nb_datapoints-1):
1148+
def check_data(fr: FigureResampler, min_v=0, max_v=nb_datapoints - 1):
11371149
# closure for n_shown and nb_datapoints
11381150
assert len(fr.data[0]["y"]) == min(n_shown, nb_datapoints)
11391151
assert len(fr.data[0]["x"]) == min(n_shown, nb_datapoints)

0 commit comments

Comments
 (0)