Skip to content

Consolidate validation of expected_groups #193

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 37 additions & 12 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
if TYPE_CHECKING:
import dask.array.Array as DaskArray

T_ExpectedGroups = Union[Sequence, np.ndarray, pd.Index]
T_Expect = Union[Sequence, np.ndarray, pd.Index, None]
T_ExpectTuple = tuple[T_Expect, ...]
T_ExpectedGroups = Union[T_Expect, T_ExpectTuple]
T_ExpectedGroupsOpt = Union[T_ExpectedGroups, None]
T_Func = Union[str, Callable]
T_Funcs = Union[T_Func, Sequence[T_Func]]
Expand Down Expand Up @@ -1476,7 +1478,7 @@ def _assert_by_is_aligned(shape, by):


def _convert_expected_groups_to_index(
expected_groups: T_ExpectedGroups, isbin: Sequence[bool], sort: bool
expected_groups: T_ExpectTuple, isbin: Sequence[bool], sort: bool
) -> tuple[pd.Index | None, ...]:
out: list[pd.Index | None] = []
for ex, isbin_ in zip(expected_groups, isbin):
Expand Down Expand Up @@ -1543,6 +1545,36 @@ def _factorize_multiple(by, expected_groups, any_by_dask, reindex):
return (group_idx,), final_groups, grp_shape


def _validate_expected_groups(nby: int, expected_groups: T_ExpectedGroupsOpt) -> T_ExpectTuple:

if expected_groups is None:
return (None,) * nby

if nby == 1 and not isinstance(expected_groups, tuple):
return (np.asarray(expected_groups),)

if nby > 1 and not isinstance(expected_groups, tuple): # TODO: test for list
raise ValueError(
"When grouping by multiple variables, expected_groups must be a tuple "
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@keewis how is this for a better error message ;) ?

"of either arrays or objects convertible to an array (like lists). "
"For example `expected_groups=(np.array([1, 2, 3]), ['a', 'b', 'c'])`."
f"Received a {type(expected_groups).__name__} instead. "
"When grouping by a single variable, you can pass an array or something "
"convertible to an array for convenience: `expected_groups=['a', 'b', 'c']`."
)

if TYPE_CHECKING:
assert isinstance(expected_groups, tuple)

if len(expected_groups) != nby:
raise ValueError(
f"Must have same number of `expected_groups` (received {len(expected_groups)}) "
f" and variables to group by (received {nby})."
)

return expected_groups


def groupby_reduce(
array: np.ndarray | DaskArray,
*by: np.ndarray | DaskArray,
Expand Down Expand Up @@ -1679,24 +1711,17 @@ def groupby_reduce(
isbins = isbin
else:
isbins = (isbin,) * nby
if expected_groups is None:
expected_groups = (None,) * nby

_assert_by_is_aligned(array.shape, bys)

expected_groups = _validate_expected_groups(nby, expected_groups)

for idx, (expect, is_dask) in enumerate(zip(expected_groups, by_is_dask)):
if is_dask and (reindex or nby > 1) and expect is None:
raise ValueError(
f"`expected_groups` for array {idx} in `by` cannot be None since it is a dask.array."
)

if nby == 1 and not isinstance(expected_groups, tuple):
expected_groups = (np.asarray(expected_groups),)
elif len(expected_groups) != nby:
raise ValueError(
f"Must have same number of `expected_groups` (received {len(expected_groups)}) "
f" and variables to group by (received {nby})."
)

# We convert to pd.Index since that lets us know if we are binning or not
# (pd.IntervalIndex or not)
expected_groups = _convert_expected_groups_to_index(expected_groups, isbins, sort)
Expand Down
11 changes: 3 additions & 8 deletions flox/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .core import (
_convert_expected_groups_to_index,
_get_expected_groups,
_validate_expected_groups,
groupby_reduce,
rechunk_for_blockwise as rechunk_array_for_blockwise,
rechunk_for_cohorts as rechunk_array_for_cohorts,
Expand Down Expand Up @@ -216,16 +217,10 @@ def xarray_reduce(
else:
isbins = (isbin,) * nby

if expected_groups is None:
expected_groups = (None,) * nby
if isinstance(expected_groups, (np.ndarray, list)): # TODO: test for list
if nby == 1:
expected_groups = (expected_groups,)
else:
raise ValueError("Needs better message.")
expected_groups = _validate_expected_groups(nby, expected_groups)

if not sort:
raise NotImplementedError
raise NotImplementedError("sort must be True for xarray_reduce")

# eventually drop the variables we are grouping by
maybe_drop = [b for b in by if isinstance(b, Hashable)]
Expand Down
20 changes: 19 additions & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,8 +1019,26 @@ def test_multiple_groupers(chunk, by1, by2, expected_groups) -> None:
assert_equal(expected, actual)


@pytest.mark.parametrize(
"expected_groups",
(
[None, None, None],
(None,),
),
)
def test_validate_expected_groups(expected_groups):
with pytest.raises(ValueError):
groupby_reduce(
np.ones((10,)),
np.ones((10,)),
np.ones((10,)),
expected_groups=expected_groups,
func="mean",
)


@requires_dask
def test_multiple_groupers_errors() -> None:
def test_validate_expected_groups_not_none_dask() -> None:
with pytest.raises(ValueError):
groupby_reduce(
dask.array.ones((5, 2)),
Expand Down
14 changes: 12 additions & 2 deletions tests/test_xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,12 +168,22 @@ def test_xarray_reduce_multiple_groupers_2(pass_expected_groups, chunk, engine):


@requires_dask
def test_dask_groupers_error():
@pytest.mark.parametrize(
"expected_groups",
(None, (None, None), [[1, 2], [1, 2]]),
)
def test_validate_expected_groups(expected_groups):
da = xr.DataArray(
[1.0, 2.0], dims="x", coords={"labels": ("x", [1, 2]), "labels2": ("x", [1, 2])}
)
with pytest.raises(ValueError):
xarray_reduce(da.chunk({"x": 2, "z": 1}), "labels", "labels2", func="count")
xarray_reduce(
da.chunk({"x": 1}),
"labels",
"labels2",
func="count",
expected_groups=expected_groups,
)


@requires_dask
Expand Down