From 44d03bbb5986983e3577ac04059194f4ba77ced4 Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Fri, 29 Nov 2019 22:11:51 -0500 Subject: [PATCH] fix a PX input bug when using data frame indices --- .../python/plotly/plotly/express/_core.py | 7 +++++-- .../tests/test_core/test_px/test_px_input.py | 20 +++++++++++++++++-- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 923e6ea3dfc..459168ea02f 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -954,7 +954,7 @@ def build_dataframe(args, attrables, array_attrables): ) ) col_name = str(argument) - df_output[col_name] = df_input[argument] + df_output[col_name] = df_input[argument].values # ----------------- argument is a column / array / list.... ------- else: is_index = isinstance(argument, pd.RangeIndex) @@ -989,7 +989,10 @@ def build_dataframe(args, attrables, array_attrables): "length of previous arguments %s is %d" % (field, len(argument), str(list(df_output.columns)), length) ) - df_output[str(col_name)] = argument + if hasattr(argument, "values"): + df_output[str(col_name)] = argument.values + else: + df_output[str(col_name)] = np.array(argument) # Finally, update argument with column name now that column exists if field_name not in array_attrables: diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index 89e3a921c9e..8062dd78fc3 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -2,8 +2,6 @@ import numpy as np import pandas as pd import pytest -import plotly.graph_objects as go -import plotly from plotly.express._core import build_dataframe from pandas.util.testing import assert_frame_equal @@ -234,6 +232,24 @@ def test_build_df_with_index(): assert_frame_equal(tips.reset_index()[out["data_frame"].columns], out["data_frame"]) +def test_non_matching_index(): + df = pd.DataFrame(dict(y=[1, 2, 3]), index=["a", "b", "c"]) + + expected = pd.DataFrame(dict(x=["a", "b", "c"], y=[1, 2, 3])) + + args = dict(data_frame=df, x=df.index, y="y") + out = build_dataframe(args, all_attrables, array_attrables) + assert_frame_equal(expected, out["data_frame"]) + + args = dict(data_frame=None, x=df.index, y=df.y) + out = build_dataframe(args, all_attrables, array_attrables) + assert_frame_equal(expected, out["data_frame"]) + + args = dict(data_frame=None, x=["a", "b", "c"], y=df.y) + out = build_dataframe(args, all_attrables, array_attrables) + assert_frame_equal(expected, out["data_frame"]) + + def test_splom_case(): iris = px.data.iris() fig = px.scatter_matrix(iris)