From 6640d8f426299e15f2ba17bd9a25efd469550cbc Mon Sep 17 00:00:00 2001 From: jvdd <boebievdd@gmail.com> Date: Thu, 9 Jun 2022 10:47:33 +0200 Subject: [PATCH 1/2] :recycle: check for all same groups --- packages/python/plotly/plotly/express/_core.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 5dce0b75391..03b053567ec 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1904,7 +1904,7 @@ def infer_config(args, constructor, trace_patch, layout_patch): return trace_specs, grouped_mappings, sizeref, show_colorbar -def get_orderings(args, grouper, grouped): +def get_orderings(args, grouper, grouped, all_same_group): """ `orders` is the user-supplied ordering with the remaining data-frame-supplied ordering appended if the column is used for grouping. It includes anything the user @@ -1917,7 +1917,7 @@ def get_orderings(args, grouper, grouped): """ orders = {} if "category_orders" not in args else args["category_orders"].copy() - if _all_one_group(grouper): + if all_same_group: sorted_group_names = [("",) * len(grouper)] return orders, sorted_group_names @@ -1944,10 +1944,12 @@ def get_orderings(args, grouper, grouped): return orders, sorted_group_names -def _all_one_group(grouper): - for g in grouper: +def _all_same_group(args, grouper): + for g in set(grouper): if g != one_group: - return False + arr = args["data_frame"][g].values + if not (arr[0] == arr).all(axis=0): + return False return True @@ -1968,10 +1970,11 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None): ) grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group] grouped = None - if not _all_one_group(grouper): + all_same_group = _all_same_group(args, grouper) + if not all_same_group: grouped = args["data_frame"].groupby(grouper, sort=False) - orders, sorted_group_names = get_orderings(args, grouper, grouped) + orders, sorted_group_names = get_orderings(args, grouper, grouped, all_same_group) col_labels = [] row_labels = [] From 73b3c7c25d0d07499900035f94db4412a31be088 Mon Sep 17 00:00:00 2001 From: jvdd <boebievdd@gmail.com> Date: Thu, 9 Jun 2022 16:40:24 +0200 Subject: [PATCH 2/2] :bug: make orders and sorted_group_names backwards compatible --- packages/python/plotly/plotly/express/_core.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 03b053567ec..8638c00fc47 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1916,10 +1916,17 @@ def get_orderings(args, grouper, grouped, all_same_group): of a single dimension-group """ orders = {} if "category_orders" not in args else args["category_orders"].copy() + sorted_group_names = [] if all_same_group: - sorted_group_names = [("",) * len(grouper)] - return orders, sorted_group_names + for col in grouper: + if col != one_group: + single_val = args["data_frame"][col].iloc[0] + sorted_group_names.append(single_val) + orders[col] = [single_val] + else: + sorted_group_names.append("") + return orders, [tuple(sorted_group_names)] for col in grouper: if col != one_group: @@ -1929,7 +1936,6 @@ def get_orderings(args, grouper, grouped, all_same_group): else: orders[col] = list(OrderedDict.fromkeys(list(orders[col]) + uniques)) - sorted_group_names = [] for group_name in grouped.groups: if len(grouper) == 1: group_name = (group_name,)