Skip to content

Commit 96a9101

Browse files
authored
Correct bug when color and values correspond to the same color in treemap or sunburst (#2591)
* solve aggregation bug when color and values correspond to the same color * Update packages/python/plotly/plotly/express/_core.py * put back deleted code * added test * code improvement * changelog entry + improve name of new column
1 parent f8f202a commit 96a9101

File tree

4 files changed

+32
-11
lines changed

4 files changed

+32
-11
lines changed

Diff for: CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ This project adheres to [Semantic Versioning](http://semver.org/).
99
- Fixed special cases with `px.sunburst` and `px.treemap` with `path` input ([#2524](https://github.com/plotly/plotly.py/pull/2524))
1010
- Fixed bug in `hover_data` argument of `px` functions, when the column name is changed with labels and `hover_data` is a dictionary setting up a specific format for the hover data ([#2544](https://github.com/plotly/plotly.py/pull/2544)).
1111
- Made the Plotly Express `trendline` argument more robust and made it work with datetime `x` values ([#2554](https://github.com/plotly/plotly.py/pull/2554))
12+
- Fixed bug in `px.sunburst` and `px.treemap`: when the `color` and `values`
13+
arguments correspond to the same column, a different aggregation function has
14+
to be used for the two arguments ([#2591](https://github.com/plotly/plotly.py/pull/2591))
1215
- Plotly Express wide mode now accepts mixed integer and float columns ([#2598](https://github.com/plotly/plotly.py/pull/2598))
1316
- Plotly Express `range_(x|y)` should not impact the unlinked range of marginal subplots ([#2600](https://github.com/plotly/plotly.py/pull/2600))
1417
- `px.line` now sets `line_group=<variable>` in wide mode by default ([#2599](https://github.com/plotly/plotly.py/pull/2599))

Diff for: packages/python/plotly/plotly/express/_core.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -474,11 +474,10 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
474474
# We need to invert the mapping here
475475
k_args = invert_label(args, k)
476476
if k_args in args["hover_data"]:
477-
if args["hover_data"][k_args][0]:
478-
if isinstance(args["hover_data"][k_args][0], str):
479-
mapping_labels_copy[k] = v.replace(
480-
"}", "%s}" % args["hover_data"][k_args][0]
481-
)
477+
formatter = args["hover_data"][k_args][0]
478+
if formatter:
479+
if isinstance(formatter, str):
480+
mapping_labels_copy[k] = v.replace("}", "%s}" % formatter)
482481
else:
483482
_ = mapping_labels_copy.pop(k)
484483
hover_lines = [k + "=" + v for k, v in mapping_labels_copy.items()]
@@ -1507,7 +1506,9 @@ def aggfunc_discrete(x):
15071506

15081507
if args["color"]:
15091508
if args["color"] == args["values"]:
1510-
aggfunc_color = "sum"
1509+
new_value_col_name = args["values"] + "_sum"
1510+
df[new_value_col_name] = df[args["values"]]
1511+
args["values"] = new_value_col_name
15111512
count_colname = args["values"]
15121513
else:
15131514
# we need a count column for the first groupby and the weighted mean of color
@@ -1526,7 +1527,7 @@ def aggfunc_discrete(x):
15261527
if not _is_continuous(df, args["color"]):
15271528
aggfunc_color = aggfunc_discrete
15281529
discrete_color = True
1529-
elif not aggfunc_color:
1530+
else:
15301531

15311532
def aggfunc_continuous(x):
15321533
return np.average(x, weights=df.loc[x.index, count_colname])
@@ -1584,6 +1585,9 @@ def aggfunc_continuous(x):
15841585
if args["color"]:
15851586
if not args["hover_data"]:
15861587
args["hover_data"] = [args["color"]]
1588+
elif isinstance(args["hover_data"], dict):
1589+
if not args["hover_data"].get(args["color"]):
1590+
args["hover_data"][args["color"]] = (True, None)
15871591
else:
15881592
args["hover_data"].append(args["color"])
15891593
return args

Diff for: packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,6 @@ def test_sunburst_treemap_with_path():
149149
fig = px.sunburst(df, path=path, values="values")
150150
assert fig.data[0].branchvalues == "total"
151151
assert fig.data[0].values[-1] == np.sum(values)
152-
# Continuous colorscale
153-
fig = px.sunburst(df, path=path, values="values", color="values")
154-
assert "coloraxis" in fig.data[0].marker
155-
assert np.all(np.array(fig.data[0].marker.colors) == np.array(fig.data[0].values))
156152
# Error when values cannot be converted to numerical data type
157153
df["values"] = ["1 000", "3 000", "2", "4", "2", "2", "1 000", "4 000"]
158154
msg = "Column `values` of `df` could not be converted to a numerical data type."
@@ -162,6 +158,12 @@ def test_sunburst_treemap_with_path():
162158
path = [df.total, "regions", df.sectors, "vendors"]
163159
fig = px.sunburst(df, path=path)
164160
assert fig.data[0].branchvalues == "total"
161+
# Continuous colorscale
162+
df["values"] = 1
163+
fig = px.sunburst(df, path=path, values="values", color="values")
164+
assert "coloraxis" in fig.data[0].marker
165+
assert np.all(np.array(fig.data[0].marker.colors) == 1)
166+
assert fig.data[0].values[-1] == 8
165167

166168

167169
def test_sunburst_treemap_with_path_and_hover():

Diff for: packages/python/plotly/plotly/tests/test_core/test_px/test_px_hover.py

+12
Original file line numberDiff line numberDiff line change
@@ -151,3 +151,15 @@ def test_fail_wrong_column():
151151
"Ambiguous input: values for 'c' appear both in hover_data and data_frame"
152152
in str(err_msg.value)
153153
)
154+
155+
156+
def test_sunburst_hoverdict_color():
157+
df = px.data.gapminder().query("year == 2007")
158+
fig = px.sunburst(
159+
df,
160+
path=["continent", "country"],
161+
values="pop",
162+
color="lifeExp",
163+
hover_data={"pop": ":,"},
164+
)
165+
assert "color" in fig.data[0].hovertemplate

0 commit comments

Comments
 (0)