diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index e94a79d3954..cc938318780 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1012,6 +1012,7 @@ def build_dataframe(args, attrables, array_attrables): def _check_dataframe_all_leaves(df): df_sorted = df.sort_values(by=list(df.columns)) null_mask = df_sorted.isnull() + df_sorted = df_sorted.astype(str) null_indices = np.nonzero(null_mask.any(axis=1).values)[0] for null_row_index in null_indices: row = null_mask.iloc[null_row_index] @@ -1043,8 +1044,9 @@ def process_dataframe_hierarchy(args): if args["color"] and args["color"] in path: series_to_copy = df[args["color"]] - args["color"] = str(args["color"]) + "additional_col_for_px" - df[args["color"]] = series_to_copy + new_col_name = args["color"] + "additional_col_for_color" + path = [new_col_name if x == args["color"] else x for x in path] + df[new_col_name] = series_to_copy if args["hover_data"]: for col_name in args["hover_data"]: if col_name == args["color"]: @@ -1147,6 +1149,11 @@ def aggfunc_continuous(x): args["ids"] = "id" args["names"] = "labels" args["parents"] = "parent" + if args["color"]: + if not args["hover_data"]: + args["hover_data"] = [args["color"]] + else: + args["hover_data"].append(args["color"]) return args diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py index 0393931af11..0fc38c94d4d 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py @@ -209,14 +209,22 @@ def test_sunburst_treemap_with_path_color(): # Hover info df["hover"] = [el.lower() for el in vendors] fig = px.sunburst(df, path=path, color="calls", hover_data=["hover"]) - custom = fig.data[0].customdata.ravel() - assert np.all(custom[:8] == df["hover"]) - assert np.all(custom[8:] == "(?)") + custom = fig.data[0].customdata + assert np.all(custom[:8, 0] == df["hover"]) + assert np.all(custom[8:, 0] == "(?)") + assert np.all(custom[:8, 1] == df["calls"]) # Discrete color fig = px.sunburst(df, path=path, color="vendors") assert len(np.unique(fig.data[0].marker.colors)) == 9 + # Numerical column in path + df["regions"] = df["regions"].map({"North": 1, "South": 2}) + path = ["total", "regions", "sectors", "vendors"] + fig = px.sunburst(df, path=path, values="values", color="calls") + colors = fig.data[0].marker.colors + assert np.all(np.array(colors[:8]) == np.array(calls)) + def test_sunburst_treemap_with_path_non_rectangular(): vendors = ["A", "B", "C", "D", None, "E", "F", "G", "H", None]