Skip to content

Sunburst/treemap path #2006

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 34 commits into from
Jan 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
4ac1efb
proof of concept
emmanuelle Dec 13, 2019
c619e28
first version
emmanuelle Dec 16, 2019
10668b6
tests
emmanuelle Dec 16, 2019
edfcced
black
emmanuelle Dec 16, 2019
1f3b8da
added test with missing values
emmanuelle Dec 18, 2019
8cb9d99
examples for sunburst tutorial
emmanuelle Dec 18, 2019
cd500a5
added type check and corresponding test
emmanuelle Dec 18, 2019
c233220
corrected bug
emmanuelle Dec 18, 2019
edefabf
treemap branchvalues
emmanuelle Dec 18, 2019
41c8d30
Merge branch 'master' into sunburst-path
emmanuelle Jan 17, 2020
2952fe6
path is now from root to leaves
emmanuelle Jan 17, 2020
c6b7243
removed EPS hack
emmanuelle Jan 18, 2020
be3b622
working version for continuous color
emmanuelle Jan 20, 2020
7f2920b
new tests and more readable code, also added hover support
emmanuelle Jan 20, 2020
8519302
updated docs
emmanuelle Jan 20, 2020
437bbd7
removed named agg which is valid only starting from pandas 0.25
emmanuelle Jan 20, 2020
fb9d992
version hopefully compatible with older pandas
emmanuelle Jan 20, 2020
a57b027
still debugging
emmanuelle Jan 21, 2020
bf8da4b
do not use lambdas
emmanuelle Jan 21, 2020
9e23890
removed redundant else
emmanuelle Jan 21, 2020
f67602f
discrete color
emmanuelle Jan 22, 2020
6b6a105
always add a count column when no values column is passed
emmanuelle Jan 22, 2020
9996731
removed if which is not required any more
emmanuelle Jan 22, 2020
f3e7e27
nicer labels with /
emmanuelle Jan 22, 2020
8cd227a
simplified code
emmanuelle Jan 22, 2020
8b66c90
better id labels
emmanuelle Jan 22, 2020
19b81ac
discrete colors
emmanuelle Jan 22, 2020
ba6ec19
raise ValueError for non-leaves with None
emmanuelle Jan 22, 2020
c0cbce0
other check
emmanuelle Jan 22, 2020
57503b4
discrete color other comes first
emmanuelle Jan 22, 2020
0ab2afd
fixed tests
emmanuelle Jan 22, 2020
0d86998
hover
emmanuelle Jan 22, 2020
d63d4bd
fixed pandas API pb
emmanuelle Jan 22, 2020
9b217f8
pandas stuff
emmanuelle Jan 22, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions doc/python/sunburst-charts.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,53 @@ fig =px.sunburst(
fig.show()
```

### Sunburst of a rectangular DataFrame with plotly.express

Hierarchical data are often stored as a rectangular dataframe, with different columns corresponding to different levels of the hierarchy. `px.sunburst` can take a `path` parameter corresponding to a list of columns. Note that `id` and `parent` should not be provided if `path` is given.

```python
import plotly.express as px
df = px.data.tips()
fig = px.sunburst(df, path=['day', 'time', 'sex'], values='total_bill')
fig.show()
```

### Sunburst of a rectangular DataFrame with continuous color argument in px.sunburst

If a `color` argument is passed, the color of a node is computed as the average of the color values of its children, weighted by their values.

```python
import plotly.express as px
import numpy as np
df = px.data.gapminder().query("year == 2007")
fig = px.sunburst(df, path=['continent', 'country'], values='pop',
color='lifeExp', hover_data=['iso_alpha'],
color_continuous_scale='RdBu',
color_continuous_midpoint=np.average(df['lifeExp'], weights=df['pop']))
fig.show()
```

### Rectangular data with missing values

If the dataset is not fully rectangular, missing values should be supplied as `None`. Note that the parents of `None` entries must be a leaf, i.e. it cannot have other children than `None` (otherwise a `ValueError` is raised).

```python
import plotly.express as px
import pandas as pd
vendors = ["A", "B", "C", "D", None, "E", "F", "G", "H", None]
sectors = ["Tech", "Tech", "Finance", "Finance", "Other",
"Tech", "Tech", "Finance", "Finance", "Other"]
regions = ["North", "North", "North", "North", "North",
"South", "South", "South", "South", "South"]
sales = [1, 3, 2, 4, 1, 2, 2, 1, 4, 1]
df = pd.DataFrame(
dict(vendors=vendors, sectors=sectors, regions=regions, sales=sales)
)
print(df)
fig = px.sunburst(df, path=['regions', 'sectors', 'vendors'], values='sales')
fig.show()
```

### Basic Sunburst Plot with go.Sunburst

If Plotly Express does not provide a good starting point, it is also possible to use the more generic `go.Sunburst` function from `plotly.graph_objects`.
Expand Down
46 changes: 46 additions & 0 deletions doc/python/treemaps.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,52 @@ fig = px.treemap(
fig.show()
```

### Treemap of a rectangular DataFrame with plotly.express

Hierarchical data are often stored as a rectangular dataframe, with different columns corresponding to different levels of the hierarchy. `px.treemap` can take a `path` parameter corresponding to a list of columns. Note that `id` and `parent` should not be provided if `path` is given.

```python
import plotly.express as px
df = px.data.tips()
fig = px.treemap(df, path=['day', 'time', 'sex'], values='total_bill')
fig.show()
```

### Treemap of a rectangular DataFrame with continuous color argument in px.treemap

If a `color` argument is passed, the color of a node is computed as the average of the color values of its children, weighted by their values.

```python
import plotly.express as px
import numpy as np
df = px.data.gapminder().query("year == 2007")
fig = px.treemap(df, path=['continent', 'country'], values='pop',
color='lifeExp', hover_data=['iso_alpha'],
color_continuous_scale='RdBu',
color_continuous_midpoint=np.average(df['lifeExp'], weights=df['pop']))
fig.show()
```

### Rectangular data with missing values

If the dataset is not fully rectangular, missing values should be supplied as `None`.

```python
import plotly.express as px
import pandas as pd
vendors = ["A", "B", "C", "D", None, "E", "F", "G", "H", None]
sectors = ["Tech", "Tech", "Finance", "Finance", "Other",
"Tech", "Tech", "Finance", "Finance", "Other"]
regions = ["North", "North", "North", "North", "North",
"South", "South", "South", "South", "South"]
sales = [1, 3, 2, 4, 1, 2, 2, 1, 4, 1]
df = pd.DataFrame(
dict(vendors=vendors, sectors=sectors, regions=regions, sales=sales)
)
print(df)
fig = px.treemap(df, path=['regions', 'sectors', 'vendors'], values='sales')
fig.show()
```
### Basic Treemap with go.Treemap

If Plotly Express does not provide a good starting point, it is also possible to use the more generic `go.Treemap` function from `plotly.graph_objects`.
Expand Down
16 changes: 16 additions & 0 deletions packages/python/plotly/plotly/express/_chart_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1269,6 +1269,7 @@ def sunburst(
names=None,
values=None,
parents=None,
path=None,
ids=None,
color=None,
color_continuous_scale=None,
Expand All @@ -1295,6 +1296,13 @@ def sunburst(
layout_patch = {"sunburstcolorway": color_discrete_sequence}
else:
layout_patch = {}
if path is not None and (ids is not None or parents is not None):
raise ValueError(
"Either `path` should be provided, or `ids` and `parents`."
"These parameters are mutually exclusive and cannot be passed together."
)
if path is not None and branchvalues is None:
branchvalues = "total"
return make_figure(
args=locals(),
constructor=go.Sunburst,
Expand All @@ -1312,6 +1320,7 @@ def treemap(
values=None,
parents=None,
ids=None,
path=None,
color=None,
color_continuous_scale=None,
range_color=None,
Expand All @@ -1337,6 +1346,13 @@ def treemap(
layout_patch = {"treemapcolorway": color_discrete_sequence}
else:
layout_patch = {}
if path is not None and (ids is not None or parents is not None):
raise ValueError(
"Either `path` should be provided, or `ids` and `parents`."
"These parameters are mutually exclusive and cannot be passed together."
)
if path is not None and branchvalues is None:
branchvalues = "total"
return make_figure(
args=locals(),
constructor=go.Treemap,
Expand Down
147 changes: 145 additions & 2 deletions packages/python/plotly/plotly/express/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1007,6 +1007,147 @@ def build_dataframe(args, attrables, array_attrables):
return args


def _check_dataframe_all_leaves(df):
df_sorted = df.sort_values(by=list(df.columns))
null_mask = df_sorted.isnull()
null_indices = np.nonzero(null_mask.any(axis=1).values)[0]
for null_row_index in null_indices:
row = null_mask.iloc[null_row_index]
indices = np.nonzero(row.values)[0]
if not row[indices[0] :].all():
raise ValueError(
"None entries cannot have not-None children",
df_sorted.iloc[null_row_index],
)
df_sorted[null_mask] = ""
row_strings = list(df_sorted.apply(lambda x: "".join(x), axis=1))
for i, row in enumerate(row_strings[:-1]):
if row_strings[i + 1] in row and (i + 1) in null_indices:
raise ValueError(
"Non-leaves rows are not permitted in the dataframe \n",
df_sorted.iloc[i + 1],
"is not a leaf.",
)


def process_dataframe_hierarchy(args):
"""
Build dataframe for sunburst or treemap when the path argument is provided.
"""
df = args["data_frame"]
path = args["path"][::-1]
_check_dataframe_all_leaves(df[path[::-1]])
discrete_color = False

if args["color"] and args["color"] in path:
series_to_copy = df[args["color"]]
args["color"] = str(args["color"]) + "additional_col_for_px"
df[args["color"]] = series_to_copy
if args["hover_data"]:
for col_name in args["hover_data"]:
if col_name == args["color"]:
series_to_copy = df[col_name]
new_col_name = str(args["color"]) + "additional_col_for_hover"
df[new_col_name] = series_to_copy
args["color"] = new_col_name
elif col_name in path:
series_to_copy = df[col_name]
new_col_name = col_name + "additional_col_for_hover"
path = [new_col_name if x == col_name else x for x in path]
df[new_col_name] = series_to_copy
# ------------ Define aggregation functions --------------------------------
def aggfunc_discrete(x):
uniques = x.unique()
if len(uniques) == 1:
return uniques[0]
else:
return "(?)"

agg_f = {}
aggfunc_color = None
if args["values"]:
try:
df[args["values"]] = pd.to_numeric(df[args["values"]])
except ValueError:
raise ValueError(
"Column `%s` of `df` could not be converted to a numerical data type."
% args["values"]
)

if args["color"]:
if args["color"] == args["values"]:
aggfunc_color = "sum"
count_colname = args["values"]
else:
# we need a count column for the first groupby and the weighted mean of color
# trick to be sure the col name is unused: take the sum of existing names
count_colname = (
"count"
if "count" not in df.columns
else "".join([str(el) for el in list(df.columns)])
)
# we can modify df because it's a copy of the px argument
df[count_colname] = 1
args["values"] = count_colname
agg_f[count_colname] = "sum"

if args["color"]:
if df[args["color"]].dtype.kind not in "bifc":
aggfunc_color = aggfunc_discrete
discrete_color = True
elif not aggfunc_color:

def aggfunc_continuous(x):
return np.average(x, weights=df.loc[x.index, count_colname])

aggfunc_color = aggfunc_continuous
agg_f[args["color"]] = aggfunc_color

# Other columns (for color, hover_data, custom_data etc.)
cols = list(set(df.columns).difference(path))
for col in cols: # for hover_data, custom_data etc.
if col not in agg_f:
agg_f[col] = aggfunc_discrete
# ----------------------------------------------------------------------------

df_all_trees = pd.DataFrame(columns=["labels", "parent", "id"] + cols)
# Set column type here (useful for continuous vs discrete colorscale)
for col in cols:
df_all_trees[col] = df_all_trees[col].astype(df[col].dtype)
for i, level in enumerate(path):
df_tree = pd.DataFrame(columns=df_all_trees.columns)
dfg = df.groupby(path[i:]).agg(agg_f)
dfg = dfg.reset_index()
# Path label massaging
df_tree["labels"] = dfg[level].copy().astype(str)
df_tree["parent"] = ""
df_tree["id"] = dfg[level].copy().astype(str)
if i < len(path) - 1:
j = i + 1
while j < len(path):
df_tree["parent"] = (
dfg[path[j]].copy().astype(str) + "/" + df_tree["parent"]
)
df_tree["id"] = dfg[path[j]].copy().astype(str) + "/" + df_tree["id"]
j += 1

df_tree["parent"] = df_tree["parent"].str.rstrip("/")
if cols:
df_tree[cols] = dfg[cols]
df_all_trees = df_all_trees.append(df_tree, ignore_index=True)

if args["color"] and discrete_color:
df_all_trees = df_all_trees.sort_values(by=args["color"])

# Now modify arguments
args["data_frame"] = df_all_trees
args["path"] = None
args["ids"] = "id"
args["names"] = "labels"
args["parents"] = "parent"
return args


def infer_config(args, constructor, trace_patch):
# Declare all supported attributes, across all plot types
attrables = (
Expand All @@ -1015,9 +1156,9 @@ def infer_config(args, constructor, trace_patch):
+ ["names", "values", "parents", "ids"]
+ ["error_x", "error_x_minus"]
+ ["error_y", "error_y_minus", "error_z", "error_z_minus"]
+ ["lat", "lon", "locations", "animation_group"]
+ ["lat", "lon", "locations", "animation_group", "path"]
)
array_attrables = ["dimensions", "custom_data", "hover_data"]
array_attrables = ["dimensions", "custom_data", "hover_data", "path"]
group_attrables = ["animation_frame", "facet_row", "facet_col", "line_group"]
all_attrables = attrables + group_attrables + ["color"]
group_attrs = ["symbol", "line_dash"]
Expand All @@ -1026,6 +1167,8 @@ def infer_config(args, constructor, trace_patch):
all_attrables += [group_attr]

args = build_dataframe(args, all_attrables, array_attrables)
if constructor in [go.Treemap, go.Sunburst] and args["path"] is not None:
args = process_dataframe_hierarchy(args)

attrs = [k for k in attrables if k in args]
grouped_attrs = []
Expand Down
6 changes: 6 additions & 0 deletions packages/python/plotly/plotly/express/_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@
colref_desc,
"Values from this column or array_like are used to set ids of sectors",
],
path=[
colref_list_type,
colref_list_desc,
"List of columns names or columns of a rectangular dataframe defining the hierarchy of sectors, from root to leaves.",
"An error is raised if path AND ids or parents is passed",
],
lat=[
colref_type,
colref_desc,
Expand Down
Loading