Skip to content

Sunburst/treemap path #2006

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 34 commits into from
Jan 22, 2020
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
4ac1efb
proof of concept
emmanuelle Dec 13, 2019
c619e28
first version
emmanuelle Dec 16, 2019
10668b6
tests
emmanuelle Dec 16, 2019
edfcced
black
emmanuelle Dec 16, 2019
1f3b8da
added test with missing values
emmanuelle Dec 18, 2019
8cb9d99
examples for sunburst tutorial
emmanuelle Dec 18, 2019
cd500a5
added type check and corresponding test
emmanuelle Dec 18, 2019
c233220
corrected bug
emmanuelle Dec 18, 2019
edefabf
treemap branchvalues
emmanuelle Dec 18, 2019
41c8d30
Merge branch 'master' into sunburst-path
emmanuelle Jan 17, 2020
2952fe6
path is now from root to leaves
emmanuelle Jan 17, 2020
c6b7243
removed EPS hack
emmanuelle Jan 18, 2020
be3b622
working version for continuous color
emmanuelle Jan 20, 2020
7f2920b
new tests and more readable code, also added hover support
emmanuelle Jan 20, 2020
8519302
updated docs
emmanuelle Jan 20, 2020
437bbd7
removed named agg which is valid only starting from pandas 0.25
emmanuelle Jan 20, 2020
fb9d992
version hopefully compatible with older pandas
emmanuelle Jan 20, 2020
a57b027
still debugging
emmanuelle Jan 21, 2020
bf8da4b
do not use lambdas
emmanuelle Jan 21, 2020
9e23890
removed redundant else
emmanuelle Jan 21, 2020
f67602f
discrete color
emmanuelle Jan 22, 2020
6b6a105
always add a count column when no values column is passed
emmanuelle Jan 22, 2020
9996731
removed if which is not required any more
emmanuelle Jan 22, 2020
f3e7e27
nicer labels with /
emmanuelle Jan 22, 2020
8cd227a
simplified code
emmanuelle Jan 22, 2020
8b66c90
better id labels
emmanuelle Jan 22, 2020
19b81ac
discrete colors
emmanuelle Jan 22, 2020
ba6ec19
raise ValueError for non-leaves with None
emmanuelle Jan 22, 2020
c0cbce0
other check
emmanuelle Jan 22, 2020
57503b4
discrete color other comes first
emmanuelle Jan 22, 2020
0ab2afd
fixed tests
emmanuelle Jan 22, 2020
0d86998
hover
emmanuelle Jan 22, 2020
d63d4bd
fixed pandas API pb
emmanuelle Jan 22, 2020
9b217f8
pandas stuff
emmanuelle Jan 22, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions packages/python/plotly/plotly/express/_chart_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1252,6 +1252,7 @@ def sunburst(
names=None,
values=None,
parents=None,
path=None,
ids=None,
color=None,
color_continuous_scale=None,
Expand All @@ -1278,6 +1279,13 @@ def sunburst(
layout_patch = {"sunburstcolorway": color_discrete_sequence}
else:
layout_patch = {}
if path is not None and (ids is not None or parents is not None):
raise ValueError(
"Either `path` should be provided, or `ids` and `parents`."
"These parameters are mutually exclusive and cannot be passed together."
)
if path is not None and branchvalues is None:
branchvalues = "total"
return make_figure(
args=locals(),
constructor=go.Sunburst,
Expand All @@ -1295,6 +1303,7 @@ def treemap(
values=None,
parents=None,
ids=None,
path=None,
color=None,
color_continuous_scale=None,
range_color=None,
Expand All @@ -1320,6 +1329,12 @@ def treemap(
layout_patch = {"treemapcolorway": color_discrete_sequence}
else:
layout_patch = {}
if path is not None and (ids is not None or parents is not None):
raise ValueError(
"Either `path` should be provided, or `ids` and `parents`."
"These parameters are mutually exclusive and cannot be passed together."
)

return make_figure(
args=locals(),
constructor=go.Treemap,
Expand Down
67 changes: 66 additions & 1 deletion packages/python/plotly/plotly/express/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,6 +890,9 @@ def build_dataframe(args, attrables, array_attrables):
else:
df_output[df_input.columns] = df_input[df_input.columns]

# This should be improved + tested - HACK
if "path" in args and args["path"] is not None:
df_output[args["path"]] = df_input[args["path"]]
# Loop over possible arguments
for field_name in attrables:
# Massaging variables
Expand Down Expand Up @@ -1007,6 +1010,66 @@ def build_dataframe(args, attrables, array_attrables):
return args


def process_dataframe_hierarchy(args):
"""
Build dataframe for sunburst or treemap when the path argument is provided.
"""
df = args["data_frame"]
path = args["path"]
# Other columns (for color, hover_data, custom_data etc.)
cols = list(set(df.columns).difference(path))
df_all_trees = pd.DataFrame(columns=["labels", "parent", "id"] + cols)
# Set column type here (useful for continuous vs discrete colorscale)
for col in cols:
df_all_trees[col] = df_all_trees[col].astype(df[col].dtype)
for i, level in enumerate(path):
df_tree = pd.DataFrame(columns=df_all_trees.columns)
dfg = df.groupby(path[i:]).sum(numerical_only=True)
dfg = dfg.reset_index()
df_tree["labels"] = dfg[level].copy().astype(str)
df_tree["parent"] = ""
df_tree["id"] = dfg[level].copy().astype(str)
if i < len(path) - 1:
j = i + 1
while j < len(path):
df_tree["parent"] += dfg[path[j]].copy().astype(str)
df_tree["id"] += dfg[path[j]].copy().astype(str)
j += 1
else:
df_tree["parent"] = "total"

if i == 0 and cols:
df_tree[cols] = dfg[cols]
elif cols:
for col in cols:
df_tree[col] = "n/a"
if args["values"]:
df_tree[args["values"]] = dfg[args["values"]]
df_all_trees = df_all_trees.append(df_tree, ignore_index=True)

# Root node
total_dict = {
"labels": "total",
"id": "total",
"parent": "",
}
for col in cols:
if not col == args["values"]:
total_dict[col] = "n/a"
if col == args["values"]:
total_dict[col] = df[col].sum()
total = pd.Series(total_dict)

df_all_trees = df_all_trees.append(total, ignore_index=True)
# Now modify arguments
args["data_frame"] = df_all_trees
args["path"] = None
args["ids"] = "id"
args["names"] = "labels"
args["parents"] = "parent"
return args


def infer_config(args, constructor, trace_patch):
# Declare all supported attributes, across all plot types
attrables = (
Expand All @@ -1017,7 +1080,7 @@ def infer_config(args, constructor, trace_patch):
+ ["error_y", "error_y_minus", "error_z", "error_z_minus"]
+ ["lat", "lon", "locations", "animation_group"]
)
array_attrables = ["dimensions", "custom_data", "hover_data"]
array_attrables = ["dimensions", "custom_data", "hover_data", "path"]
group_attrables = ["animation_frame", "facet_row", "facet_col", "line_group"]
all_attrables = attrables + group_attrables + ["color"]
group_attrs = ["symbol", "line_dash"]
Expand All @@ -1026,6 +1089,8 @@ def infer_config(args, constructor, trace_patch):
all_attrables += [group_attr]

args = build_dataframe(args, all_attrables, array_attrables)
if constructor in [go.Treemap, go.Sunburst] and args["path"] is not None:
args = process_dataframe_hierarchy(args)

attrs = [k for k in attrables if k in args]
grouped_attrs = []
Expand Down
1 change: 1 addition & 0 deletions packages/python/plotly/plotly/express/_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
colref_desc,
"Values from this column or array_like are used to set ids of sectors",
],
path=[colref_type, colref_desc],
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

needs work

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hehe yes :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oops yes. The PR is not completely ready but I'll still appreciate your thoughts about the API and how it's implemented, then I will polish if this is the right direction, unless there is some refactoring to do first.

lat=[
colref_type,
colref_desc,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import plotly.graph_objects as go
from numpy.testing import assert_array_equal
import numpy as np
import pandas as pd


def _compare_figures(go_trace, px_fig):
Expand Down Expand Up @@ -111,6 +112,42 @@ def test_sunburst_treemap_colorscales():
assert list(fig.layout[colorway]) == color_seq


def test_sunburst_treemap_with_path():
vendors = ["A", "B", "C", "D", "E", "F", "G", "H"]
sectors = [
"Tech",
"Tech",
"Finance",
"Finance",
"Tech",
"Tech",
"Finance",
"Finance",
]
regions = ["North", "North", "North", "North", "South", "South", "South", "South"]
values = [1, 3, 2, 4, 2, 2, 1, 4]
df = pd.DataFrame(
dict(vendors=vendors, sectors=sectors, regions=regions, values=values)
)
# No values
fig = px.sunburst(df, path=["vendors", "sectors", "regions"])
assert fig.data[0].branchvalues == "total"
# Values passed
fig = px.sunburst(df, path=["vendors", "sectors", "regions"], values="values")
assert fig.data[0].branchvalues == "total"
assert fig.data[0].values[-1] == np.sum(values)
# Values passed
fig = px.sunburst(df, path=["vendors", "sectors", "regions"], values="values")
assert fig.data[0].branchvalues == "total"
assert fig.data[0].values[-1] == np.sum(values)
# Continuous colorscale
fig = px.sunburst(
df, path=["vendors", "sectors", "regions"], values="values", color="values"
)
assert "coloraxis" in fig.data[0].marker
assert np.all(np.array(fig.data[0].marker.colors) == np.array(fig.data[0].values))


def test_pie_funnelarea_colorscale():
labels = ["A", "B", "C", "D"]
values = [3, 2, 1, 4]
Expand Down