Skip to content

Commit e778c6b

Browse files
fix bug in trendline in the case of missing values (plotly#2357)
* fix bug in trendline in the case of missing values
* paint it black
* added statsmodels to dependencies for CI
* version for py2
* Update packages/python/plotly/plotly/express/_core.py
  Co-Authored-By: Nicolas Kruchten <[email protected]>
* extended test to lowess, and more precise check of attribute length
* Update packages/python/plotly/plotly/tests/test_core/test_px/test_trendline.py
  Co-authored-by: Nicolas Kruchten <[email protected]>
1 parent ad0dd30 commit e778c6b

File tree

4 files changed

+26
-4
lines changed

4 files changed

+26
-4
lines changed

.circleci/create_conda_optional_env.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ if [ ! -d $HOME/miniconda/envs/circle_optional ]; then
1616
# Create environment
1717
# PYTHON_VERSION=2.7 or 3.5
1818
$HOME/miniconda/bin/conda create -n circle_optional --yes python=$PYTHON_VERSION \
19-
requests nbformat six retrying psutil pandas decorator pytest mock nose poppler xarray scikit-image ipython jupyter ipykernel ipywidgets
19+
requests nbformat six retrying psutil pandas decorator pytest mock nose poppler xarray scikit-image ipython jupyter ipykernel ipywidgets statsmodels
2020

2121
# Install orca into environment
2222
$HOME/miniconda/bin/conda install --yes -n circle_optional -c plotly plotly-orca==1.3.1

packages/python/plotly/plotly/express/_core.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -241,18 +241,25 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
241241
sorted_trace_data = trace_data.sort_values(by=args["x"])
242242
y = sorted_trace_data[args["y"]]
243243
x = sorted_trace_data[args["x"]]
244-
trace_patch["x"] = x
245244

246245
if x.dtype.type == np.datetime64:
247246
x = x.astype(int) / 10 ** 9 # convert to unix epoch seconds
248247

249248
if attr_value == "lowess":
250-
trendline = sm.nonparametric.lowess(y, x)
249+
# missing ='drop' is the default value for lowess but not for OLS (None)
250+
# we force it here in case statsmodels change their defaults
251+
trendline = sm.nonparametric.lowess(y, x, missing="drop")
252+
trace_patch["x"] = trendline[:, 0]
251253
trace_patch["y"] = trendline[:, 1]
252254
hover_header = "<b>LOWESS trendline</b><br><br>"
253255
elif attr_value == "ols":
254-
fit_results = sm.OLS(y.values, sm.add_constant(x.values)).fit()
256+
fit_results = sm.OLS(
257+
y.values, sm.add_constant(x.values), missing="drop"
258+
).fit()
255259
trace_patch["y"] = fit_results.predict()
260+
trace_patch["x"] = x[
261+
np.logical_not(np.logical_or(np.isnan(y), np.isnan(x)))
262+
]
256263
hover_header = "<b>OLS trendline</b><br>"
257264
hover_header += "%s = %g * %s + %g<br>" % (
258265
args["y"],
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import plotly.express as px
2+
import numpy as np
3+
4+
5+
def test_trendline_nan_values():
    """Check that trendlines are computed only from rows without missing values.

    Population values before ``start_date`` are replaced with NaN in the
    Oceania subset of the gapminder dataset. For both the ``"ols"`` and
    ``"lowess"`` trendline modes, every trendline trace must then start at or
    after ``start_date`` and have ``x`` and ``y`` of equal length.
    """
    start_date = 1970

    # Take an explicit copy: ``query`` returns a derived frame, and writing to
    # it through chained indexing (``df["pop"][mask] = ...``) emits
    # SettingWithCopyWarning and is a no-op under pandas copy-on-write, which
    # would leave the data without any NaN and make this test vacuous.
    df = px.data.gapminder().query("continent == 'Oceania'").copy()

    # ``Series.mask`` returns a new series with NaN wherever the condition
    # holds, so no NaN has to be written in place into the existing column.
    df["pop"] = df["pop"].mask(df["year"] < start_date)

    for mode in ["ols", "lowess"]:
        fig = px.scatter(df, x="year", y="pop", color="country", trendline=mode)
        # The test assumes traces alternate scatter / trendline per country,
        # so the odd-indexed traces are taken to be the trendlines.
        for trendline in fig["data"][1::2]:
            assert trendline.x[0] >= start_date
            assert len(trendline.x) == len(trendline.y)

packages/python/plotly/tox.ini

+1
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ deps=
5959
pytest==3.5.1
6060
pandas==0.24.2
6161
xarray==0.10.9
62+
statsmodels==0.10.2
6263
backports.tempfile==1.0
6364
optional: --editable=file:///{toxinidir}/../plotly-geo
6465
optional: numpy==1.16.5

0 commit comments

Comments
 (0)