diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index b3bcd096d3..27376ab037 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -2422,7 +2422,6 @@ def get_groups_and_orders(args, grouper): # figure out orders and what the single group name would be if there were one single_group_name = [] unique_cache = dict() - grp_to_idx = dict() for i, col in enumerate(grouper): if col == one_group: @@ -2440,27 +2439,28 @@ def get_groups_and_orders(args, grouper): else: orders[col] = list(OrderedDict.fromkeys(list(orders[col]) + uniques)) - grp_to_idx = {k: i for i, k in enumerate(orders)} - if len(single_group_name) == len(grouper): # we have a single group, so we can skip all group-by operations! groups = {tuple(single_group_name): df} else: - required_grouper = list(orders.keys()) + required_grouper = [group for group in orders if group in grouper] grouped = dict(df.group_by(required_grouper, drop_null_keys=True).__iter__()) - sorted_group_names = list(grouped.keys()) - for i, col in reversed(list(enumerate(required_grouper))): - sorted_group_names = sorted( - sorted_group_names, - key=lambda g: orders[col].index(g[i]) if g[i] in orders[col] else -1, - ) + sorted_group_names = sorted( + grouped.keys(), + key=lambda values: [ + orders[group].index(value) if value in orders[group] else -1 + for group, value in zip(required_grouper, values) + ], + ) # calculate the full group_names by inserting "" in the tuple index for one_group groups full_sorted_group_names = [ tuple( [ - "" if col == one_group else sub_group_names[grp_to_idx[col]] + "" + if col == one_group + else sub_group_names[required_grouper.index(col)] for col in grouper ] ) diff --git a/packages/python/plotly/plotly/tests/test_optional/test_px/test_px.py b/packages/python/plotly/plotly/tests/test_optional/test_px/test_px.py index 8d091df3ae..5ae751f666 100644 --- a/packages/python/plotly/plotly/tests/test_optional/test_px/test_px.py +++ b/packages/python/plotly/plotly/tests/test_optional/test_px/test_px.py @@ -289,6 +289,27 @@ def test_orthogonal_orderings(backend, days, times): assert_orderings(backend, days, days, times, times) +def test_category_order_with_category_as_x(backend): + # https://github.com/plotly/plotly.py/issues/4875 + tips = nw.from_native(px.data.tips(return_type=backend)) + fig = px.bar( + tips, + x="day", + y="total_bill", + color="smoker", + barmode="group", + facet_col="sex", + category_orders={ + "day": ["Thur", "Fri", "Sat", "Sun"], + "smoker": ["Yes", "No"], + "sex": ["Male", "Female"], + }, + ) + assert fig["layout"]["xaxis"]["categoryarray"] == ("Thur", "Fri", "Sat", "Sun") + for trace in fig["data"]: + assert set(trace["x"]) == {"Thur", "Fri", "Sat", "Sun"} + + def test_permissive_defaults(): msg = "'PxDefaults' object has no attribute 'should_not_work'" with pytest.raises(AttributeError, match=msg):