
Commit 3ea081b

Avoid broadcasting by variables against each other (#186)
* Avoid broadcasting by variables against each other
  Save some time spent factorizing unnecessary values.
  TODO:
  - update cohorts detection
* Bring back chunking.
* Cleanup
* Update flox/core.py
* Avoid broadcast in factorize_multiple
* Fix test
* rename to any_by_dask
* bugfix + add tests
* Rename `by_is_dask` in test
* Rework tests.
* Fix benchmarks
* cleanup
* Test error.
* Support autodetection of groups for numpy labels when grouping by multiple variables.
* Fix tests
* Fix more tests
* Fix dtypes for windows
1 parent 51fb6e9 commit 3ea081b
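The core idea, as a minimal standalone sketch (plain numpy/pandas with made-up inputs, not flox's exact code): factorize each `by` variable in its own, possibly size-1, shape and combine the integer codes, so broadcasting only ever happens on the small code arrays.

```python
# Minimal sketch of the idea (not flox's exact code): factorize each
# grouper in its own shape, then combine the integer codes. Broadcasting
# happens implicitly inside ravel_multi_index, so the original `by`
# variables are never materialized against each other.
import numpy as np
import pandas as pd

by1 = np.array([[0], [1]])           # shape (2, 1)
by2 = np.array([[10, 20, 10]])       # shape (1, 3)

codes1, groups1 = pd.factorize(by1.ravel())
codes2, groups2 = pd.factorize(by2.ravel())

group_idx = np.ravel_multi_index(
    (codes1.reshape(by1.shape), codes2.reshape(by2.shape)),
    (len(groups1), len(groups2)),
)
print(group_idx.shape)  # (2, 3) -- broadcast happened only on the codes
```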

4 files changed, +147 -64 lines changed


asv_bench/benchmarks/cohorts.py

+1 -1
@@ -93,7 +93,7 @@ def setup(self, *args, **kwargs):
         ret = flox.core._factorize_multiple(
             by,
             expected_groups=(pd.Index(np.arange(1, 13)), pd.Index(np.arange(1, 25))),
-            by_is_dask=False,
+            any_by_dask=False,
             reindex=False,
         )
         # Add one so the rechunk code is simpler and makes sense

flox/core.py

+43 -17
@@ -176,6 +176,7 @@ def find_group_cohorts(labels, chunks, merge: bool = True):
     axis = range(-labels.ndim, 0)
     # Easier to create a dask array and use the .blocks property
     array = dask.array.ones(tuple(sum(c) for c in chunks), chunks=chunks)
+    labels = np.broadcast_to(labels, array.shape[-labels.ndim :])

     # Iterate over each block and create a new block of same shape with "chunk number"
     shape = tuple(array.blocks.shape[ax] for ax in axis)
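Note on the added `np.broadcast_to`: since `by` may now carry size-one dimensions, the labels are expanded to the chunked array's trailing shape before the blockwise iteration. `np.broadcast_to` returns a read-only view, so this is essentially free; a quick illustration:

```python
# np.broadcast_to returns a view, so expanding labels to the array's
# trailing shape does not copy any data.
import numpy as np

labels = np.array([0, 1, 1, 2])
expanded = np.broadcast_to(labels, (3, 4))
print(expanded.strides[0])            # 0 -> all rows share the same data
print(expanded.base is not None)      # True -> still a view on `labels`
```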
@@ -479,7 +480,7 @@ def factorize_(
         idx, groups = pd.factorize(flat, sort=sort)

         found_groups.append(np.array(groups))
-        factorized.append(idx)
+        factorized.append(idx.reshape(groupvar.shape))

     grp_shape = tuple(len(grp) for grp in found_groups)
     ngroups = math.prod(grp_shape)
@@ -489,20 +490,18 @@ def factorize_(
         # Restore these after the raveling
         nan_by_mask = reduce(np.logical_or, [(f == -1) for f in factorized])
         group_idx[nan_by_mask] = -1
-        group_idx = group_idx.reshape(by[0].shape)
     else:
         group_idx = factorized[0]

     if fastpath:
-        return group_idx.reshape(by[0].shape), found_groups, grp_shape
+        return group_idx, found_groups, grp_shape

     if np.isscalar(axis) and groupvar.ndim > 1:
         # Not reducing along all dimensions of by
         # this is OK because for 3D by and axis=(1,2),
         # we collapse to a 2D by and axis=-1
         offset_group = True
         group_idx, size = offset_labels(group_idx.reshape(by[0].shape), ngroups)
-        group_idx = group_idx.reshape(-1)
     else:
         size = ngroups
         offset_group = False
@@ -647,6 +646,8 @@ def chunk_reduce(
     else:
         nax = by.ndim

+    assert by.ndim <= array.ndim
+
     final_array_shape = array.shape[:-nax] + (1,) * (nax - 1)
     final_groups_shape = (1,) * (nax - 1)
@@ -667,9 +668,17 @@ def chunk_reduce(
     )
     groups = groups[0]

+    if isinstance(axis, Sequence):
+        needs_broadcast = any(
+            group_idx.shape[ax] != array.shape[ax] and group_idx.shape[ax] == 1
+            for ax in range(-len(axis), 0)
+        )
+        if needs_broadcast:
+            group_idx = np.broadcast_to(group_idx, array.shape[-by.ndim :])
     # always reshape to 1D along group dimensions
     newshape = array.shape[: array.ndim - by.ndim] + (math.prod(array.shape[-by.ndim :]),)
     array = array.reshape(newshape)
+    group_idx = group_idx.reshape(-1)

     assert group_idx.ndim == 1
     empty = np.all(props.nanmask)
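Why the guarded broadcast is needed, sketched in isolation with made-up shapes: `group_idx` now retains size-one dimensions from `by`, and flattening it directly would no longer line up with the raveled array:

```python
# Sketch: group_idx keeps a size-1 dim from `by`; broadcast against the
# array's trailing shape before flattening so ravel order matches.
import numpy as np

group_idx = np.array([[0, 1, 2]])      # shape (1, 3), from a size-1 by dim
array = np.arange(12).reshape(4, 3)    # shape (4, 3)

group_idx = np.broadcast_to(group_idx, array.shape).reshape(-1)
print(group_idx)  # [0 1 2 0 1 2 0 1 2 0 1 2] -- aligned with array.reshape(-1)
```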
@@ -1219,7 +1228,9 @@ def dask_groupby_agg(
         # chunk numpy arrays like the input array
         # This removes an extra rechunk-merge layer that would be
         # added otherwise
-        by = dask.array.from_array(by, chunks=tuple(array.chunks[ax] for ax in range(-by.ndim, 0)))
+        chunks = tuple(array.chunks[ax] if by.shape[ax] != 1 else (1,) for ax in range(-by.ndim, 0))
+
+        by = dask.array.from_array(by, chunks=chunks)
     _, (array, by) = dask.array.unify_chunks(array, inds, by, inds[-by.ndim :])

     # preprocess the array: for argreductions, this zips the index together with the array block
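A standalone illustration of the new chunking rule (shapes invented): a size-one `by` dimension gets a single chunk of size 1 instead of inheriting the array's chunks along that axis:

```python
# Illustration with made-up shapes: size-1 by dims get a single chunk.
import dask.array as da
import numpy as np

array = da.ones((4, 6), chunks=(2, 3))
by = np.ones((1, 6))  # size-1 along the first grouped dim

chunks = tuple(
    array.chunks[ax] if by.shape[ax] != 1 else (1,)
    for ax in range(-by.ndim, 0)
)
print(chunks)  # ((1,), (3, 3))
by_dask = da.from_array(by, chunks=chunks)
```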
@@ -1429,7 +1440,7 @@ def _extract_result(result_dict: FinalResultsDict, key) -> np.ndarray:


 def _validate_reindex(
-    reindex: bool | None, func, method: T_Method, expected_groups, by_is_dask: bool
+    reindex: bool | None, func, method: T_Method, expected_groups, any_by_dask: bool
 ) -> bool:
     if reindex is True:
         if _is_arg_reduction(func):
@@ -1447,7 +1458,7 @@ def _validate_reindex(
             reindex = False

     elif method == "map-reduce":
-        if expected_groups is None and by_is_dask:
+        if expected_groups is None and any_by_dask:
             reindex = False
         else:
             reindex = True
@@ -1457,8 +1468,9 @@ def _validate_reindex(


 def _assert_by_is_aligned(shape, by):
+    assert all(b.ndim == by[0].ndim for b in by[1:])
     for idx, b in enumerate(by):
-        if shape[-b.ndim :] != b.shape:
+        if not all(j in [i, 1] for i, j in zip(shape[-b.ndim :], b.shape)):
             raise ValueError(
                 "`array` and `by` arrays must be aligned "
                 "i.e. array.shape[-by.ndim :] == by.shape. "
@@ -1495,26 +1507,34 @@ def _lazy_factorize_wrapper(*by, **kwargs):
     return group_idx


-def _factorize_multiple(by, expected_groups, by_is_dask, reindex):
+def _factorize_multiple(by, expected_groups, any_by_dask, reindex):
     kwargs = dict(
         expected_groups=expected_groups,
         axis=None,  # always None, we offset later if necessary.
         fastpath=True,
         reindex=reindex,
     )
-    if by_is_dask:
+    if any_by_dask:
         import dask.array

+        # unifying chunks will make sure all arrays in `by` are dask arrays
+        # with compatible chunks, even if there was originally a numpy array
+        inds = tuple(range(by[0].ndim))
+        chunks, by_ = dask.array.unify_chunks(*itertools.chain(*zip(by, (inds,) * len(by))))
+
         group_idx = dask.array.map_blocks(
             _lazy_factorize_wrapper,
-            *np.broadcast_arrays(*by),
+            *by_,
+            chunks=tuple(chunks.values()),
             meta=np.array((), dtype=np.int64),
             **kwargs,
         )
         found_groups = tuple(
             None if is_duck_dask_array(b) else pd.unique(b.reshape(-1)) for b in by
         )
-        grp_shape = tuple(len(e) for e in expected_groups)
+        grp_shape = tuple(
+            len(e) if e is not None else len(f) for e, f in zip(expected_groups, found_groups)
+        )
     else:
         group_idx, found_groups, grp_shape = factorize_(by, **kwargs)
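The `unify_chunks` call interleaves each `by` array with one shared index tuple, the tuple-index form of `da.unify_chunks(a, "ij", b, "ij")`; per the comment in the hunk, it also converts any numpy members of `by` to dask arrays. A sketch with invented shapes:

```python
# Sketch of the call pattern: interleave (array, inds) pairs, mirroring
# da.unify_chunks(a, "ij", b, "ij"). Numpy arrays come back as dask arrays.
import itertools
import dask.array as da
import numpy as np

by = (da.ones((4, 6), chunks=(2, 3)), np.zeros((4, 6)))  # mixed dask/numpy
inds = tuple(range(by[0].ndim))  # (0, 1)

chunkss, by_ = da.unify_chunks(*itertools.chain(*zip(by, (inds,) * len(by))))
print(tuple(chunkss.values()))           # ((2, 2), (3, 3)) -- shared chunking
print([type(b).__name__ for b in by_])   # ['Array', 'Array']
```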

@@ -1644,15 +1664,16 @@ def groupby_reduce(

     bys = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by)
     nby = len(bys)
-    by_is_dask = any(is_duck_dask_array(b) for b in bys)
+    by_is_dask = tuple(is_duck_dask_array(b) for b in bys)
+    any_by_dask = any(by_is_dask)

-    if method in ["split-reduce", "cohorts"] and by_is_dask:
+    if method in ["split-reduce", "cohorts"] and any_by_dask:
         raise ValueError(f"method={method!r} can only be used when grouping by numpy arrays.")

     if method == "split-reduce":
         method = "cohorts"

-    reindex = _validate_reindex(reindex, func, method, expected_groups, by_is_dask)
+    reindex = _validate_reindex(reindex, func, method, expected_groups, any_by_dask)

     if not is_duck_array(array):
         array = np.asarray(array)
@@ -1667,6 +1688,11 @@ def groupby_reduce(
         expected_groups = (None,) * nby

     _assert_by_is_aligned(array.shape, bys)
+    for idx, (expect, is_dask) in enumerate(zip(expected_groups, by_is_dask)):
+        if is_dask and (reindex or nby > 1) and expect is None:
+            raise ValueError(
+                f"`expected_groups` for array {idx} in `by` cannot be None since it is a dask.array."
+            )

     if nby == 1 and not isinstance(expected_groups, tuple):
         expected_groups = (np.asarray(expected_groups),)
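A hypothetical call (invented shapes and values) showing the new check: a dask `by` used with reindexing or alongside other groupers must come with explicit `expected_groups`, since its labels cannot be inspected lazily:

```python
# Hypothetical example: a dask `by` with nby > 1 must come with explicit
# expected_groups, or the new check raises ValueError.
import dask.array as da
import numpy as np
import pandas as pd
from flox.core import groupby_reduce

array = da.ones((10,), chunks=5)
by_dask = da.from_array(np.tile([1, 2], 5), chunks=5)
by_numpy = np.repeat([0, 1], 5)

result, *groups = groupby_reduce(
    array,
    by_dask,
    by_numpy,
    func="sum",
    expected_groups=(pd.Index([1, 2]), pd.Index([0, 1])),
)
```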
@@ -1686,7 +1712,7 @@ def groupby_reduce(
     )
     if factorize_early:
         bys, final_groups, grp_shape = _factorize_multiple(
-            bys, expected_groups, by_is_dask=by_is_dask, reindex=reindex
+            bys, expected_groups, any_by_dask=any_by_dask, reindex=reindex
         )
         expected_groups = (pd.RangeIndex(math.prod(grp_shape)),)

@@ -1709,7 +1735,7 @@ def groupby_reduce(

     # TODO: make sure expected_groups is unique
     if nax == 1 and by_.ndim > 1 and expected_groups is None:
-        if not by_is_dask:
+        if not any_by_dask:
             expected_groups = _get_expected_groups(by_, sort)
         else:
             # When we reduce along all axes, we are guaranteed to see all

flox/xarray.py

+39 -34
@@ -26,20 +26,6 @@
 Dims = Union[str, Iterable[Hashable], None]


-def _get_input_core_dims(group_names, dim, ds, grouper_dims):
-    input_core_dims = [[], []]
-    for g in group_names:
-        if g in dim:
-            continue
-        if g in ds.dims:
-            input_core_dims[0].extend([g])
-        if g in grouper_dims:
-            input_core_dims[1].extend([g])
-    input_core_dims[0].extend(dim)
-    input_core_dims[1].extend(dim)
-    return input_core_dims
-
-
 def _restore_dim_order(result, obj, by):
     def lookup_order(dimension):
         if dimension == by.name and by.ndim == 1:
@@ -54,6 +40,26 @@ def lookup_order(dimension):
     return result.transpose(*new_order)


+def _broadcast_size_one_dims(*arrays, core_dims):
+    """Broadcast by adding size-1 dimensions in the right place.
+
+    Workaround because apply_ufunc doesn't support this yet.
+    https://github.com/pydata/xarray/issues/3032#issuecomment-503337637
+
+    Specialized to the groupby problem.
+    """
+    array_dims = set(core_dims[0])
+    broadcasted = [arrays[0]]
+    for dims, array in zip(core_dims[1:], arrays[1:]):
+        assert set(dims).issubset(array_dims)
+        order = [dims.index(d) for d in core_dims[0] if d in dims]
+        array = array.transpose(*order)
+        axis = [core_dims[0].index(d) for d in core_dims[0] if d not in dims]
+        broadcasted.append(np.expand_dims(array, axis))
+
+    return broadcasted
+
+
 def xarray_reduce(
     obj: T_Dataset | T_DataArray,
     *by: T_DataArray | Hashable,
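A small usage sketch for the helper above, assuming the `_broadcast_size_one_dims` definition is in scope and with invented dimension names: a `by` missing one of the array's core dims gets that dim inserted with size 1 rather than being tiled:

```python
# Usage sketch for the helper above (dimension names are invented).
import numpy as np

core_dims = [["x", "y"], ["y"]]   # array core dims, then each by's dims
array = np.zeros((2, 3))          # dims ("x", "y")
by = np.array([10, 20, 30])       # dims ("y",)

arr, by_b = _broadcast_size_one_dims(array, by, core_dims=core_dims)
print(by_b.shape)  # (1, 3): "x" was inserted as a size-1 dim, not tiled
```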
@@ -255,20 +261,11 @@ def xarray_reduce(
     elif dim is not None:
         dim_tuple = _atleast_1d(dim)
     else:
-        dim_tuple = tuple()
+        dim_tuple = tuple(grouper_dims)

-    # broadcast all variables against each other along all dimensions in `by` variables
-    # don't exclude `dim` because it need not be a dimension in any of the `by` variables!
-    # in the case where dim is Ellipsis, and by.ndim < obj.ndim
-    # then we also broadcast `by` to all `obj.dims`
-    # TODO: avoid this broadcasting
+    # broadcast to make sure grouper dimensions are present in the array.
     exclude_dims = tuple(d for d in ds.dims if d not in grouper_dims and d not in dim_tuple)
-    ds_broad, *by_broad = xr.broadcast(ds, *by_da, exclude=exclude_dims)
-
-    # all members of by_broad have the same dimensions
-    # so we just pull by_broad[0].dims if dim is None
-    if not dim_tuple:
-        dim_tuple = tuple(by_broad[0].dims)
+    ds_broad = xr.broadcast(ds, *by_da, exclude=exclude_dims)[0]

     if any(d not in grouper_dims and d not in obj.dims for d in dim_tuple):
         raise ValueError(f"Cannot reduce over absent dimensions {dim}.")
@@ -298,7 +295,7 @@ def xarray_reduce(
         expected_groups = list(expected_groups)
     group_names: tuple[Any, ...] = ()
     group_sizes: dict[Any, int] = {}
-    for idx, (b_, expect, isbin_) in enumerate(zip(by_broad, expected_groups, isbins)):
+    for idx, (b_, expect, isbin_) in enumerate(zip(by_da, expected_groups, isbins)):
         group_name = b_.name if not isbin_ else f"{b_.name}_bins"
         group_names += (group_name,)
@@ -326,7 +323,10 @@ def xarray_reduce(
         # This will never be reached
         raise ValueError("expect_index cannot be None")

-    def wrapper(array, *by, func, skipna, **kwargs):
+    def wrapper(array, *by, func, skipna, core_dims, **kwargs):
+
+        array, *by = _broadcast_size_one_dims(array, *by, core_dims=core_dims)
+
         # Handle skipna here because I need to know dtype to make a good default choice.
         # We cannot handle this easily for xarray Datasets in xarray_reduce
         if skipna and func in ["all", "any", "count"]:
@@ -374,17 +374,21 @@ def wrapper(array, *by, func, skipna, **kwargs):
         if is_missing_dim:
             missing_dim[k] = v

-    input_core_dims = _get_input_core_dims(group_names, dim_tuple, ds_broad, grouper_dims)
-    input_core_dims += [input_core_dims[-1]] * (nby - 1)
+    # dim_tuple contains dimensions we are reducing over. These need to be the last
+    # core dimensions to be synchronized with axis.
+    input_core_dims = [[d for d in grouper_dims if d not in dim_tuple] + list(dim_tuple)]
+    input_core_dims += [list(b.dims) for b in by_da]

+    output_core_dims = [d for d in input_core_dims[0] if d not in dim_tuple]
+    output_core_dims.extend(group_names)
     actual = xr.apply_ufunc(
         wrapper,
         ds_broad.drop_vars(tuple(missing_dim)).transpose(..., *grouper_dims),
-        *by_broad,
+        *by_da,
         input_core_dims=input_core_dims,
         # for xarray's test_groupby_duplicate_coordinate_labels
         exclude_dims=set(dim_tuple),
-        output_core_dims=[group_names],
+        output_core_dims=[output_core_dims],
         dask="allowed",
         dask_gufunc_kwargs=dict(
             output_sizes=group_sizes, output_dtypes=[dtype] if dtype is not None else None
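A worked example of the new core-dimension bookkeeping, with invented names, showing why the reduced dims are placed last (they must line up with `axis` inside the wrapper):

```python
# Worked example with invented names: grouper_dims = ("x", "y") and we
# reduce over "y" only; reduced dims go last so they line up with axis.
grouper_dims = ("x", "y")
dim_tuple = ("y",)
group_names = ("labels",)
by_dims = [("y",)]  # dims of the single by variable

input_core_dims = [[d for d in grouper_dims if d not in dim_tuple] + list(dim_tuple)]
input_core_dims += [list(d) for d in by_dims]
output_core_dims = [d for d in input_core_dims[0] if d not in dim_tuple]
output_core_dims.extend(group_names)

print(input_core_dims)   # [['x', 'y'], ['y']]
print(output_core_dims)  # ['x', 'labels']
```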
@@ -404,6 +408,7 @@ def wrapper(array, *by, func, skipna, **kwargs):
             "isbin": isbins,
             "finalize_kwargs": finalize_kwargs,
             "dtype": dtype,
+            "core_dims": input_core_dims,
         },
     )

@@ -413,7 +418,7 @@ def wrapper(array, *by, func, skipna, **kwargs):
         if all(d not in ds_broad[var].dims for d in dim_tuple):
             actual[var] = ds_broad[var]

-    for name, expect, by_ in zip(group_names, expected_groups, by_broad):
+    for name, expect, by_ in zip(group_names, expected_groups, by_da):
         # Can't remove this till xarray handles IntervalIndex
         if isinstance(expect, pd.IntervalIndex):
             expect = expect.to_numpy()
@@ -443,7 +448,7 @@ def wrapper(array, *by, func, skipna, **kwargs):
             template = obj

         if actual[var].ndim > 1:
-            actual[var] = _restore_dim_order(actual[var], template, by_broad[0])
+            actual[var] = _restore_dim_order(actual[var], template, by_da[0])

     if missing_dim:
         for k, v in missing_dim.items():
