From 74726e01c18500f9b31f5b8528ec4736f8205de5 Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Mon, 14 Jun 2021 15:32:56 -0400 Subject: [PATCH 1/7] PX val_map now respects category_orders --- packages/python/plotly/plotly/express/_core.py | 18 +++++++++--------- .../tests/test_optional/test_px/test_px.py | 17 +++++++---------- 2 files changed, 16 insertions(+), 19 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index fe362c7e1a1..cca68c8dd9a 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1838,6 +1838,7 @@ def get_orderings(args, grouper, grouped): if col not in orders: orders[col] = list(uniques) else: + orders[col] = list(orders[col]) for val in uniques: if val not in orders[col]: orders[col].append(val) @@ -1849,7 +1850,6 @@ def get_orderings(args, grouper, grouped): group_names, key=lambda g: orders[col].index(g[i]) if g[i] in orders[col] else -1, ) - return orders, group_names, group_values @@ -1877,17 +1877,19 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None): col_labels = [] row_labels = [] - for m in grouped_mappings: - if m.grouper: + if m.grouper not in sorted_group_values: + m.val_map[""] = m.sequence[0] + else: + sorted_values = orders[m.grouper] if m.facet == "col": prefix = get_label(args, args["facet_col"]) + "=" - col_labels = [prefix + str(s) for s in sorted_group_values[m.grouper]] + col_labels = [prefix + str(s) for s in sorted_values] if m.facet == "row": prefix = get_label(args, args["facet_row"]) + "=" - row_labels = [prefix + str(s) for s in sorted_group_values[m.grouper]] - for val in sorted_group_values[m.grouper]: - if val not in m.val_map: + row_labels = [prefix + str(s) for s in sorted_values] + for val in sorted_values: + if val not in m.val_map: # always False if it's an IdentityMap m.val_map[val] = m.sequence[len(m.val_map) % len(m.sequence)] subplot_type = _subplot_type_for_trace_type(constructor().type) @@ -1943,8 +1945,6 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None): for i, m in enumerate(grouped_mappings): val = group_name[i] - if val not in m.val_map: - m.val_map[val] = m.sequence[len(m.val_map) % len(m.sequence)] try: m.updater(trace, m.val_map[val]) # covers most cases except ValueError: diff --git a/packages/python/plotly/plotly/tests/test_optional/test_px/test_px.py b/packages/python/plotly/plotly/tests/test_optional/test_px/test_px.py index 03f10722794..81e3bcd8a29 100644 --- a/packages/python/plotly/plotly/tests/test_optional/test_px/test_px.py +++ b/packages/python/plotly/plotly/tests/test_optional/test_px/test_px.py @@ -209,8 +209,8 @@ def test_px_defaults(): def assert_orderings(days_order, days_check, times_order, times_check): - symbol_sequence = ["circle", "diamond", "square", "cross"] - color_sequence = ["red", "blue"] + symbol_sequence = ["circle", "diamond", "square", "cross", "circle", "diamond"] + color_sequence = ["red", "blue", "red", "blue", "red", "blue", "red", "blue"] fig = px.scatter( px.data.tips(), x="total_bill", @@ -229,7 +229,7 @@ def assert_orderings(days_order, days_check, times_order, times_check): assert days_check[col] in trace.hovertemplate for row in range(len(times_check)): - for trace in fig.select_traces(row=2 - row): + for trace in fig.select_traces(row=len(times_check) - row): assert times_check[row] in trace.hovertemplate for trace in fig.data: @@ -241,13 +241,10 @@ def assert_orderings(days_order, days_check, times_order, times_check): assert trace.marker.color == color_sequence[i] -def test_noisy_orthogonal_orderings(): - assert_orderings( - ["x", "Sun", "Sat", "y", "Fri", "z"], # add extra noise, missing Thur - ["Sun", "Sat", "Fri", "Thur"], # Thur is at the back - ["a", "Lunch", "b"], # add extra noise, missing Dinner - ["Lunch", "Dinner"], # Dinner is at the back - ) +@pytest.mark.parametrize("days", permutations(["Sun", "Sat", "Fri", "x"])) +@pytest.mark.parametrize("times", permutations(["Lunch", "x"])) +def test_orthogonal_and_missing_orderings(days, times): + assert_orderings(days, list(days) + ["Thur"], times, list(times) + ["Dinner"]) @pytest.mark.parametrize("days", permutations(["Sun", "Sat", "Fri", "Thur"])) From 2b8b1c8d8919e059552f1ded0a5162e3153d89d8 Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Tue, 15 Jun 2021 11:56:24 -0400 Subject: [PATCH 2/7] force range to be breadth of category_orders --- packages/python/plotly/plotly/express/_core.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index cca68c8dd9a..8e0521047b0 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -552,8 +552,10 @@ def set_cartesian_axis_opts(args, axis, letter, orders): axis["categoryarray"] = ( orders[args[letter]] if isinstance(axis, go.layout.XAxis) - else list(reversed(orders[args[letter]])) + else list(reversed(orders[args[letter]])) # top down for Y axis ) + if "range" not in axis: + axis["range"] = [-0.5, len(orders[args[letter]]) - 0.5] def configure_cartesian_marginal_axes(args, fig, orders): From 90efcfca08d2cffa3219a2155e4da1aeef606d95 Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Tue, 15 Jun 2021 20:48:59 -0400 Subject: [PATCH 3/7] directly compute nrows/ncols --- .../python/plotly/plotly/express/_core.py | 24 +++++++++++++------ 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 8e0521047b0..8dacb09e5e5 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1779,7 +1779,7 @@ def infer_config(args, constructor, trace_patch, layout_patch): else args["geojson"].__geo_interface__ ) - # Compute marginal attribute + # Compute marginal attribute: copy to appropriate marginal_* if "marginal" in args: position = "marginal_x" if args["orientation"] == "v" else "marginal_y" other_position = "marginal_x" if args["orientation"] == "h" else "marginal_y" @@ -1879,6 +1879,7 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None): col_labels = [] row_labels = [] + nrows = ncols = 1 for m in grouped_mappings: if m.grouper not in sorted_group_values: m.val_map[""] = m.sequence[0] @@ -1887,9 +1888,11 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None): if m.facet == "col": prefix = get_label(args, args["facet_col"]) + "=" col_labels = [prefix + str(s) for s in sorted_values] + ncols = len(col_labels) if m.facet == "row": prefix = get_label(args, args["facet_row"]) + "=" row_labels = [prefix + str(s) for s in sorted_values] + nrows = len(row_labels) for val in sorted_values: if val not in m.val_map: # always False if it's an IdentityMap m.val_map[val] = m.sequence[len(m.val_map) % len(m.sequence)] @@ -1899,8 +1902,8 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None): trace_names_by_frame = {} frames = OrderedDict() trendline_rows = [] - nrows = ncols = 1 trace_name_labels = None + facet_col_wrap = args.get("facet_col_wrap", 0) for group_name in sorted_group_names: group = grouped.get_group(group_name if len(group_name) > 1 else group_name[0]) mapping_labels = OrderedDict() @@ -1981,14 +1984,13 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None): row = m.val_map[val] else: if ( - bool(args.get("marginal_x", False)) - and trace_spec.marginal != "x" + bool(args.get("marginal_x", False)) # there is a marginal + and trace_spec.marginal != "x" # and we're not it ): row = 2 else: row = 1 - facet_col_wrap = args.get("facet_col_wrap", 0) # Find col for trace, handling facet_col and marginal_y if m.facet == "col": col = m.val_map[val] @@ -2001,11 +2003,9 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None): else: col = 1 - nrows = max(nrows, row) if row > 1: trace._subplot_row = row - ncols = max(ncols, col) if col > 1: trace._subplot_col = col if ( @@ -2064,6 +2064,16 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None): ): layout_patch["legend"]["itemsizing"] = "constant" + if facet_col_wrap: + nrows = 1 + ncols // facet_col_wrap + ncols = ncols if ncols < facet_col_wrap else facet_col_wrap + + if args.get("marginal_x"): + nrows += 1 + + if args.get("marginal_y"): + ncols += 1 + fig = init_figure( args, subplot_type, frame_list, nrows, ncols, col_labels, row_labels ) From 75137acd672aa0d8131b24013dd8e8d75b0d34cd Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Tue, 15 Jun 2021 21:53:44 -0400 Subject: [PATCH 4/7] directly compute nrows/ncols --- packages/python/plotly/plotly/express/_core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 8dacb09e5e5..9b344de867a 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -2065,8 +2065,8 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None): layout_patch["legend"]["itemsizing"] = "constant" if facet_col_wrap: - nrows = 1 + ncols // facet_col_wrap - ncols = ncols if ncols < facet_col_wrap else facet_col_wrap + nrows = math.ceil(ncols / facet_col_wrap) + ncols = min(ncols, facet_col_wrap) if args.get("marginal_x"): nrows += 1 From fed63d637fec49eba815b304507eb87a0abf11af Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Tue, 15 Jun 2021 22:00:11 -0400 Subject: [PATCH 5/7] standardize arg check --- packages/python/plotly/plotly/express/_core.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 9b344de867a..dd20cc02cca 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1984,7 +1984,7 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None): row = m.val_map[val] else: if ( - bool(args.get("marginal_x", False)) # there is a marginal + args.get("marginal_x", None) is not None # there is a marginal and trace_spec.marginal != "x" # and we're not it ): row = 2 @@ -2068,10 +2068,10 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None): nrows = math.ceil(ncols / facet_col_wrap) ncols = min(ncols, facet_col_wrap) - if args.get("marginal_x"): + if args.get("marginal_x", None) is not None: nrows += 1 - if args.get("marginal_y"): + if args.get("marginal_y", None) is not None: ncols += 1 fig = init_figure( @@ -2118,7 +2118,7 @@ def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_la # Build column_widths/row_heights if subplot_type == "xy": - if bool(args.get("marginal_x", False)): + if args.get("marginal_x", None) is not None: if args["marginal_x"] == "histogram" or ("color" in args and args["color"]): main_size = 0.74 else: @@ -2131,7 +2131,7 @@ def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_la else: vertical_spacing = args.get("facet_row_spacing", None) or 0.03 - if bool(args.get("marginal_y", False)): + if args.get("marginal_y", None) is not None: if args["marginal_y"] == "histogram" or ("color" in args and args["color"]): main_size = 0.74 else: From 6a1ec9d6c18670364aa6cfa68f8f51a821ea253e Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Tue, 15 Jun 2021 22:05:51 -0400 Subject: [PATCH 6/7] standardize arg check --- .../python/plotly/plotly/express/_core.py | 52 +++++++++---------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index dd20cc02cca..52bc6970ad0 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -400,10 +400,10 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref): if hover_is_dict and not attr_value[col]: continue if col in [ - args.get("x", None), - args.get("y", None), - args.get("z", None), - args.get("base", None), + args.get("x"), + args.get("y"), + args.get("z"), + args.get("base"), ]: continue try: @@ -1286,8 +1286,8 @@ def build_dataframe(args, constructor): # now we handle special cases like wide-mode or x-xor-y specification # by rearranging args to tee things up for process_args_into_dataframe to work - no_x = args.get("x", None) is None - no_y = args.get("y", None) is None + no_x = args.get("x") is None + no_y = args.get("y") is None wide_x = False if no_x else _is_col_list(df_input, args["x"]) wide_y = False if no_y else _is_col_list(df_input, args["y"]) @@ -1314,9 +1314,9 @@ def build_dataframe(args, constructor): if var_name in [None, "value", "index"] or var_name in df_input: var_name = "variable" if constructor == go.Funnel: - wide_orientation = args.get("orientation", None) or "h" + wide_orientation = args.get("orientation") or "h" else: - wide_orientation = args.get("orientation", None) or "v" + wide_orientation = args.get("orientation") or "v" args["orientation"] = wide_orientation args["wide_cross"] = None elif wide_x != wide_y: @@ -1347,7 +1347,7 @@ def build_dataframe(args, constructor): if constructor in [go.Scatter, go.Bar, go.Funnel] + hist2d_types: if not wide_mode and (no_x != no_y): for ax in ["x", "y"]: - if args.get(ax, None) is None: + if args.get(ax) is None: args[ax] = df_input.index if df_provided else Range() if constructor == go.Bar: missing_bar_dim = ax @@ -1371,7 +1371,7 @@ def build_dataframe(args, constructor): ) no_color = False - if type(args.get("color", None)) == str and args["color"] == NO_COLOR: + if type(args.get("color")) == str and args["color"] == NO_COLOR: no_color = True args["color"] = None # now that things have been prepped, we do the systematic rewriting of `args` @@ -1787,17 +1787,17 @@ def infer_config(args, constructor, trace_patch, layout_patch): args[other_position] = None # If both marginals and faceting are specified, faceting wins - if args.get("facet_col", None) is not None and args.get("marginal_y", None): + if args.get("facet_col") is not None and args.get("marginal_y") is not None: args["marginal_y"] = None - if args.get("facet_row", None) is not None and args.get("marginal_x", None): + if args.get("facet_row") is not None and args.get("marginal_x") is not None: args["marginal_x"] = None # facet_col_wrap only works if no marginals or row faceting is used if ( - args.get("marginal_x", None) is not None - or args.get("marginal_y", None) is not None - or args.get("facet_row", None) is not None + args.get("marginal_x") is not None + or args.get("marginal_y") is not None + or args.get("facet_row") is not None ): args["facet_col_wrap"] = 0 @@ -1984,7 +1984,7 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None): row = m.val_map[val] else: if ( - args.get("marginal_x", None) is not None # there is a marginal + args.get("marginal_x") is not None # there is a marginal and trace_spec.marginal != "x" # and we're not it ): row = 2 @@ -2068,10 +2068,10 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None): nrows = math.ceil(ncols / facet_col_wrap) ncols = min(ncols, facet_col_wrap) - if args.get("marginal_x", None) is not None: + if args.get("marginal_x") is not None: nrows += 1 - if args.get("marginal_y", None) is not None: + if args.get("marginal_y") is not None: ncols += 1 fig = init_figure( @@ -2118,7 +2118,7 @@ def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_la # Build column_widths/row_heights if subplot_type == "xy": - if args.get("marginal_x", None) is not None: + if args.get("marginal_x") is not None: if args["marginal_x"] == "histogram" or ("color" in args and args["color"]): main_size = 0.74 else: @@ -2127,11 +2127,11 @@ def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_la row_heights = [main_size] * (nrows - 1) + [1 - main_size] vertical_spacing = 0.01 elif facet_col_wrap: - vertical_spacing = args.get("facet_row_spacing", None) or 0.07 + vertical_spacing = args.get("facet_row_spacing") or 0.07 else: - vertical_spacing = args.get("facet_row_spacing", None) or 0.03 + vertical_spacing = args.get("facet_row_spacing") or 0.03 - if args.get("marginal_y", None) is not None: + if args.get("marginal_y") is not None: if args["marginal_y"] == "histogram" or ("color" in args and args["color"]): main_size = 0.74 else: @@ -2140,7 +2140,7 @@ def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_la column_widths = [main_size] * (ncols - 1) + [1 - main_size] horizontal_spacing = 0.005 else: - horizontal_spacing = args.get("facet_col_spacing", None) or 0.02 + horizontal_spacing = args.get("facet_col_spacing") or 0.02 else: # Other subplot types: # 'scene', 'geo', 'polar', 'ternary', 'mapbox', 'domain', None @@ -2148,10 +2148,10 @@ def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_la # We can customize subplot spacing per type once we enable faceting # for all plot types if facet_col_wrap: - vertical_spacing = args.get("facet_row_spacing", None) or 0.07 + vertical_spacing = args.get("facet_row_spacing") or 0.07 else: - vertical_spacing = args.get("facet_row_spacing", None) or 0.03 - horizontal_spacing = args.get("facet_col_spacing", None) or 0.02 + vertical_spacing = args.get("facet_row_spacing") or 0.03 + horizontal_spacing = args.get("facet_col_spacing") or 0.02 if facet_col_wrap: subplot_labels = [None] * nrows * ncols From b076a38a5d186f1fed679be239831d5f898ffb5f Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Wed, 16 Jun 2021 08:55:52 -0400 Subject: [PATCH 7/7] simplify get_orderings --- .../python/plotly/plotly/express/_core.py | 58 +++++++++---------- 1 file changed, 27 insertions(+), 31 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 52bc6970ad0..d9cc19cc061 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1816,43 +1816,41 @@ def infer_config(args, constructor, trace_patch, layout_patch): def get_orderings(args, grouper, grouped): """ - `orders` is the user-supplied ordering (with the remaining data-frame-supplied - ordering appended if the column is used for grouping). It includes anything the user - gave, for any variable, including values not present in the dataset. It is used - downstream to set e.g. `categoryarray` for cartesian axes - - `group_names` is the set of groups, ordered by the order above - - `group_values` is a subset of `orders` in both keys and values. It contains a key - for every grouped mapping and its values are the sorted *data* values for these - mappings. + `orders` is the user-supplied ordering with the remaining data-frame-supplied + ordering appended if the column is used for grouping. It includes anything the user + gave, for any variable, including values not present in the dataset. It's a dict + where the keys are e.g. "x" or "color" + + `sorted_group_names` is the set of groups, ordered by the order above. It's a list + of tuples like [("value1", ""), ("value2", "")] where each tuple contains the name + of a single dimension-group """ + orders = {} if "category_orders" not in args else args["category_orders"].copy() - group_names = [] - group_values = {} + for col in grouper: + if col != one_group: + uniques = args["data_frame"][col].unique() + if col not in orders: + orders[col] = list(uniques) + else: + orders[col] = list(orders[col]) + for val in uniques: + if val not in orders[col]: + orders[col].append(val) + + sorted_group_names = [] for group_name in grouped.groups: if len(grouper) == 1: group_name = (group_name,) - group_names.append(group_name) - for col in grouper: - if col != one_group: - uniques = args["data_frame"][col].unique() - if col not in orders: - orders[col] = list(uniques) - else: - orders[col] = list(orders[col]) - for val in uniques: - if val not in orders[col]: - orders[col].append(val) - group_values[col] = sorted(uniques, key=orders[col].index) + sorted_group_names.append(group_name) for i, col in reversed(list(enumerate(grouper))): if col != one_group: - group_names = sorted( - group_names, + sorted_group_names = sorted( + sorted_group_names, key=lambda g: orders[col].index(g[i]) if g[i] in orders[col] else -1, ) - return orders, group_names, group_values + return orders, sorted_group_names def make_figure(args, constructor, trace_patch=None, layout_patch=None): @@ -1873,15 +1871,13 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None): grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group] grouped = args["data_frame"].groupby(grouper, sort=False) - orders, sorted_group_names, sorted_group_values = get_orderings( - args, grouper, grouped - ) + orders, sorted_group_names = get_orderings(args, grouper, grouped) col_labels = [] row_labels = [] nrows = ncols = 1 for m in grouped_mappings: - if m.grouper not in sorted_group_values: + if m.grouper not in orders: m.val_map[""] = m.sequence[0] else: sorted_values = orders[m.grouper]