diff --git a/flox/core.py b/flox/core.py index 42bc2f1a2..8edd7c22b 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1507,17 +1507,15 @@ def dask_groupby_agg( group_chunks: tuple[tuple[int | float, ...]] if method in ["map-reduce", "cohorts"]: - combine: Callable[..., IntermediateDict] - if do_simple_combine: - combine = partial(_simple_combine, reindex=reindex) - combine_name = "simple-combine" - else: - combine = partial(_grouped_combine, engine=engine, sort=sort) - combine_name = "grouped-combine" + combine: Callable[..., IntermediateDict] = ( + partial(_simple_combine, reindex=reindex) + if do_simple_combine + else partial(_grouped_combine, engine=engine, sort=sort) + ) tree_reduce = partial( dask.array.reductions._tree_reduce, - name=f"{name}-reduce-{method}-{combine_name}", + name=f"{name}-reduce-{method}", dtype=array.dtype, axis=axis, keepdims=True, diff --git a/tests/test_core.py b/tests/test_core.py index 880284b75..1ee24550f 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -4,12 +4,14 @@ import warnings from functools import partial, reduce from typing import TYPE_CHECKING, Callable +from unittest.mock import MagicMock, patch import numpy as np import pandas as pd import pytest from numpy_groupies.aggregate_numpy import aggregate +import flox from flox import xrutils from flox.aggregations import Aggregation, _initialize_aggregation from flox.core import ( @@ -303,6 +305,7 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): if chunks == -1: params.extend([("blockwise", None)]) + combine_error = RuntimeError("This combine should not have been called.") for method, reindex in params: call = partial( groupby_reduce, array, *by, method=method, reindex=reindex, **flox_kwargs @@ -312,13 +315,22 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): with pytest.raises(NotImplementedError): call() continue - actual, *groups = call() - if method != "blockwise": + + if method == "blockwise": + # no combine necessary + mocks = { + "_simple_combine": MagicMock(side_effect=combine_error), + "_grouped_combine": MagicMock(side_effect=combine_error), + } + else: if "arg" not in func: # make sure we use simple combine - assert any("simple-combine" in key for key in actual.dask.layers.keys()) + mocks = {"_grouped_combine": MagicMock(side_effect=combine_error)} else: - assert any("grouped-combine" in key for key in actual.dask.layers.keys()) + mocks = {"_simple_combine": MagicMock(side_effect=combine_error)} + + with patch.multiple(flox.core, **mocks): + actual, *groups = call() for actual_group, expect in zip(groups, expected_groups): assert_equal(actual_group, expect, tolerance) if "arg" in func: