Skip to content

Commit 5348d47

Browse files
committed
Force reindex to be bool always
Closes #155. Turns out we weren't using the more efficient simple_combine with map_reduce in all cases, because do_simple_combine was None when reindex was None. Now the default for map-reduce is: (1) reindex=True when (expected_groups is not None) or (expected_groups is None and by_is_dask is False).
1 parent 6897240 commit 5348d47

File tree

2 files changed

+33
-17
lines changed

2 files changed

+33
-17
lines changed

flox/core.py

+20-11
Original file line numberDiff line numberDiff line change
@@ -1388,27 +1388,35 @@ def dask_groupby_agg(
13881388
return (result, groups)
13891389

13901390

1391-
def _validate_reindex(reindex: bool | None, func, method: T_Method, expected_groups) -> bool | None:
1391+
def _validate_reindex(
1392+
reindex: bool | None, func, method: T_Method, expected_groups, by_is_dask: bool
1393+
) -> bool:
13921394
if reindex is True:
13931395
if _is_arg_reduction(func):
13941396
raise NotImplementedError
13951397
if method == "blockwise":
13961398
raise NotImplementedError
13971399

1398-
if method == "blockwise" or _is_arg_reduction(func):
1399-
reindex = False
1400+
if reindex is None:
1401+
if method == "blockwise" or _is_arg_reduction(func):
1402+
reindex = False
14001403

1401-
if reindex is None and expected_groups is not None:
1402-
reindex = True
1404+
elif expected_groups is not None:
1405+
reindex = True
1406+
1407+
elif method in ["split-reduce", "cohorts"]:
1408+
reindex = True
1409+
1410+
elif method == "map-reduce":
1411+
if expected_groups is None and by_is_dask:
1412+
reindex = False
1413+
else:
1414+
reindex = True
14031415

14041416
if method in ["split-reduce", "cohorts"] and reindex is False:
14051417
raise NotImplementedError
14061418

1407-
if method in ["split-reduce", "cohorts"] and reindex is None:
1408-
reindex = True
1409-
1410-
# TODO: Should reindex be a bool-only at this point? Would've been nice but
1411-
# None's are relied on after this function as well.
1419+
assert isinstance(reindex, bool)
14121420
return reindex
14131421

14141422

@@ -1597,7 +1605,6 @@ def groupby_reduce(
15971605
"argreductions not supported for engine='flox' yet."
15981606
"Try engine='numpy' or engine='numba' instead."
15991607
)
1600-
reindex = _validate_reindex(reindex, func, method, expected_groups)
16011608

16021609
bys = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by)
16031610
nby = len(bys)
@@ -1606,6 +1613,8 @@ def groupby_reduce(
16061613
if method in ["split-reduce", "cohorts"] and by_is_dask:
16071614
raise ValueError(f"method={method!r} can only be used when grouping by numpy arrays.")
16081615

1616+
reindex = _validate_reindex(reindex, func, method, expected_groups, by_is_dask)
1617+
16091618
if not is_duck_array(array):
16101619
array = np.asarray(array)
16111620
is_bool_array = np.issubdtype(array.dtype, bool)

tests/test_core.py

+13-6
Original file line numberDiff line numberDiff line change
@@ -223,12 +223,19 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
223223
for method in ["map-reduce", "cohorts", "split-reduce"]:
224224
if "arg" in func and method != "map-reduce":
225225
continue
226-
actual, *groups = groupby_reduce(array, *by, method=method, **flox_kwargs)
227-
for actual_group, expect in zip(groups, expected_groups):
228-
assert_equal(actual_group, expect, tolerance)
229-
if "arg" in func:
230-
assert actual.dtype.kind == "i"
231-
assert_equal(actual, expected, tolerance)
226+
if method == "map-reduce":
227+
reindexes = [True, False, None]
228+
else:
229+
reindexes = [None]
230+
for reindex in reindexes:
231+
actual, *groups = groupby_reduce(
232+
array, *by, method=method, reindex=reindex, **flox_kwargs
233+
)
234+
for actual_group, expect in zip(groups, expected_groups):
235+
assert_equal(actual_group, expect, tolerance)
236+
if "arg" in func:
237+
assert actual.dtype.kind == "i"
238+
assert_equal(actual, expected, tolerance)
232239

233240

234241
@requires_dask

0 commit comments

Comments
 (0)