Skip to content

Commit 47e0b38

Browse files
authored
Force reindex to be bool always (#176)
* Force reindex to be bool always. Closes #155. It turns out we weren't using the more efficient simple_combine with map-reduce in all cases, because do_simple_combine was None when reindex was None. Now the default for map-reduce is reindex=True when (expected_groups is not None) or (expected_groups is None and by_is_dask is False).
1 parent 6897240 commit 47e0b38

File tree

2 files changed

+72
-20
lines changed

2 files changed

+72
-20
lines changed

flox/core.py

+20-11
Original file line numberDiff line numberDiff line change
@@ -1388,27 +1388,35 @@ def dask_groupby_agg(
13881388
return (result, groups)
13891389

13901390

1391-
def _validate_reindex(reindex: bool | None, func, method: T_Method, expected_groups) -> bool | None:
1391+
def _validate_reindex(
1392+
reindex: bool | None, func, method: T_Method, expected_groups, by_is_dask: bool
1393+
) -> bool:
13921394
if reindex is True:
13931395
if _is_arg_reduction(func):
13941396
raise NotImplementedError
13951397
if method == "blockwise":
13961398
raise NotImplementedError
13971399

1398-
if method == "blockwise" or _is_arg_reduction(func):
1399-
reindex = False
1400+
if reindex is None:
1401+
if method == "blockwise" or _is_arg_reduction(func):
1402+
reindex = False
14001403

1401-
if reindex is None and expected_groups is not None:
1402-
reindex = True
1404+
elif expected_groups is not None:
1405+
reindex = True
1406+
1407+
elif method in ["split-reduce", "cohorts"]:
1408+
reindex = True
1409+
1410+
elif method == "map-reduce":
1411+
if expected_groups is None and by_is_dask:
1412+
reindex = False
1413+
else:
1414+
reindex = True
14031415

14041416
if method in ["split-reduce", "cohorts"] and reindex is False:
14051417
raise NotImplementedError
14061418

1407-
if method in ["split-reduce", "cohorts"] and reindex is None:
1408-
reindex = True
1409-
1410-
# TODO: Should reindex be a bool-only at this point? Would've been nice but
1411-
# None's are relied on after this function as well.
1419+
assert isinstance(reindex, bool)
14121420
return reindex
14131421

14141422

@@ -1597,7 +1605,6 @@ def groupby_reduce(
15971605
"argreductions not supported for engine='flox' yet."
15981606
"Try engine='numpy' or engine='numba' instead."
15991607
)
1600-
reindex = _validate_reindex(reindex, func, method, expected_groups)
16011608

16021609
bys = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by)
16031610
nby = len(bys)
@@ -1606,6 +1613,8 @@ def groupby_reduce(
16061613
if method in ["split-reduce", "cohorts"] and by_is_dask:
16071614
raise ValueError(f"method={method!r} can only be used when grouping by numpy arrays.")
16081615

1616+
reindex = _validate_reindex(reindex, func, method, expected_groups, by_is_dask)
1617+
16091618
if not is_duck_array(array):
16101619
array = np.asarray(array)
16111620
is_bool_array = np.issubdtype(array.dtype, bool)

tests/test_core.py

+52-9
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from functools import reduce
3+
from functools import partial, reduce
44
from typing import TYPE_CHECKING
55

66
import numpy as np
@@ -13,6 +13,7 @@
1313
_convert_expected_groups_to_index,
1414
_get_optimal_chunks_for_groups,
1515
_normalize_indexes,
16+
_validate_reindex,
1617
factorize_,
1718
find_group_cohorts,
1819
groupby_reduce,
@@ -221,14 +222,26 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
221222
if not has_dask:
222223
continue
223224
for method in ["map-reduce", "cohorts", "split-reduce"]:
224-
if "arg" in func and method != "map-reduce":
225-
continue
226-
actual, *groups = groupby_reduce(array, *by, method=method, **flox_kwargs)
227-
for actual_group, expect in zip(groups, expected_groups):
228-
assert_equal(actual_group, expect, tolerance)
229-
if "arg" in func:
230-
assert actual.dtype.kind == "i"
231-
assert_equal(actual, expected, tolerance)
225+
if method == "map-reduce":
226+
reindexes = [True, False, None]
227+
else:
228+
reindexes = [None]
229+
for reindex in reindexes:
230+
call = partial(
231+
groupby_reduce, array, *by, method=method, reindex=reindex, **flox_kwargs
232+
)
233+
if "arg" in func:
234+
if method != "map-reduce" or reindex is True:
235+
with pytest.raises(NotImplementedError):
236+
call()
237+
continue
238+
239+
actual, *groups = call()
240+
for actual_group, expect in zip(groups, expected_groups):
241+
assert_equal(actual_group, expect, tolerance)
242+
if "arg" in func:
243+
assert actual.dtype.kind == "i"
244+
assert_equal(actual, expected, tolerance)
232245

233246

234247
@requires_dask
@@ -1125,3 +1138,33 @@ def test_subset_block_2d(flatblocks, expectidx):
11251138
subset = subset_to_blocks(array, flatblocks)
11261139
assert len(subset.dask.layers) == 2
11271140
assert_equal(subset, array.compute()[expectidx])
1141+
1142+
1143+
@pytest.mark.parametrize("method", ["map-reduce", "cohorts"])
1144+
@pytest.mark.parametrize(
1145+
"expected, reindex, func, expected_groups, by_is_dask",
1146+
[
1147+
# argmax only False
1148+
[False, None, "argmax", None, False],
1149+
# True when by is numpy but expected is None
1150+
[True, None, "sum", None, False],
1151+
# False when by is dask but expected is None
1152+
[False, None, "sum", None, True],
1153+
# if expected_groups then always True
1154+
[True, None, "sum", [1, 2, 3], False],
1155+
[True, None, "sum", ([1], [2]), False],
1156+
[True, None, "sum", ([1], [2]), True],
1157+
[True, None, "sum", ([1], None), False],
1158+
[True, None, "sum", ([1], None), True],
1159+
],
1160+
)
1161+
def test_validate_reindex(expected, reindex, func, method, expected_groups, by_is_dask):
1162+
if by_is_dask and method == "cohorts":
1163+
# This should error elsewhere
1164+
pytest.skip()
1165+
call = partial(_validate_reindex, reindex, func, method, expected_groups, by_is_dask)
1166+
if "arg" in func and method == "cohorts":
1167+
with pytest.raises(NotImplementedError):
1168+
call()
1169+
else:
1170+
assert call() == expected

0 commit comments

Comments
 (0)