Skip to content

Commit 93ec054

Browse files
Merge pull request #1934 from plotly/px_input_bugifx
fix a PX input bug when using data frame indices
2 parents f5af9cb + 44d03bb commit 93ec054

File tree

2 files changed

+23
-4
lines changed

2 files changed

+23
-4
lines changed

packages/python/plotly/plotly/express/_core.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -954,7 +954,7 @@ def build_dataframe(args, attrables, array_attrables):
954954
)
955955
)
956956
col_name = str(argument)
957-
df_output[col_name] = df_input[argument]
957+
df_output[col_name] = df_input[argument].values
958958
# ----------------- argument is a column / array / list.... -------
959959
else:
960960
is_index = isinstance(argument, pd.RangeIndex)
@@ -989,7 +989,10 @@ def build_dataframe(args, attrables, array_attrables):
989989
"length of previous arguments %s is %d"
990990
% (field, len(argument), str(list(df_output.columns)), length)
991991
)
992-
df_output[str(col_name)] = argument
992+
if hasattr(argument, "values"):
993+
df_output[str(col_name)] = argument.values
994+
else:
995+
df_output[str(col_name)] = np.array(argument)
993996

994997
# Finally, update argument with column name now that column exists
995998
if field_name not in array_attrables:

packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
import numpy as np
33
import pandas as pd
44
import pytest
5-
import plotly.graph_objects as go
6-
import plotly
75
from plotly.express._core import build_dataframe
86
from pandas.util.testing import assert_frame_equal
97

@@ -234,6 +232,24 @@ def test_build_df_with_index():
234232
assert_frame_equal(tips.reset_index()[out["data_frame"].columns], out["data_frame"])
235233

236234

235+
def test_non_matching_index():
236+
df = pd.DataFrame(dict(y=[1, 2, 3]), index=["a", "b", "c"])
237+
238+
expected = pd.DataFrame(dict(x=["a", "b", "c"], y=[1, 2, 3]))
239+
240+
args = dict(data_frame=df, x=df.index, y="y")
241+
out = build_dataframe(args, all_attrables, array_attrables)
242+
assert_frame_equal(expected, out["data_frame"])
243+
244+
args = dict(data_frame=None, x=df.index, y=df.y)
245+
out = build_dataframe(args, all_attrables, array_attrables)
246+
assert_frame_equal(expected, out["data_frame"])
247+
248+
args = dict(data_frame=None, x=["a", "b", "c"], y=df.y)
249+
out = build_dataframe(args, all_attrables, array_attrables)
250+
assert_frame_equal(expected, out["data_frame"])
251+
252+
237253
def test_splom_case():
238254
iris = px.data.iris()
239255
fig = px.scatter_matrix(iris)

0 commit comments

Comments
 (0)