Skip to content

Avoid broadcasting by variables against each other #186

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Nov 26, 2022
2 changes: 1 addition & 1 deletion asv_bench/benchmarks/cohorts.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def setup(self, *args, **kwargs):
ret = flox.core._factorize_multiple(
by,
expected_groups=(pd.Index(np.arange(1, 13)), pd.Index(np.arange(1, 25))),
by_is_dask=False,
any_by_dask=False,
reindex=False,
)
# Add one so the rechunk code is simpler and makes sense
Expand Down
60 changes: 43 additions & 17 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def find_group_cohorts(labels, chunks, merge: bool = True):
axis = range(-labels.ndim, 0)
# Easier to create a dask array and use the .blocks property
array = dask.array.ones(tuple(sum(c) for c in chunks), chunks=chunks)
labels = np.broadcast_to(labels, array.shape[-labels.ndim :])

# Iterate over each block and create a new block of same shape with "chunk number"
shape = tuple(array.blocks.shape[ax] for ax in axis)
Expand Down Expand Up @@ -479,7 +480,7 @@ def factorize_(
idx, groups = pd.factorize(flat, sort=sort)

found_groups.append(np.array(groups))
factorized.append(idx)
factorized.append(idx.reshape(groupvar.shape))

grp_shape = tuple(len(grp) for grp in found_groups)
ngroups = math.prod(grp_shape)
Expand All @@ -489,20 +490,18 @@ def factorize_(
# Restore these after the raveling
nan_by_mask = reduce(np.logical_or, [(f == -1) for f in factorized])
group_idx[nan_by_mask] = -1
group_idx = group_idx.reshape(by[0].shape)
else:
group_idx = factorized[0]

if fastpath:
return group_idx.reshape(by[0].shape), found_groups, grp_shape
return group_idx, found_groups, grp_shape

if np.isscalar(axis) and groupvar.ndim > 1:
# Not reducing along all dimensions of by
# this is OK because for 3D by and axis=(1,2),
# we collapse to a 2D by and axis=-1
offset_group = True
group_idx, size = offset_labels(group_idx.reshape(by[0].shape), ngroups)
group_idx = group_idx.reshape(-1)
else:
size = ngroups
offset_group = False
Expand Down Expand Up @@ -647,6 +646,8 @@ def chunk_reduce(
else:
nax = by.ndim

assert by.ndim <= array.ndim

final_array_shape = array.shape[:-nax] + (1,) * (nax - 1)
final_groups_shape = (1,) * (nax - 1)

Expand All @@ -667,9 +668,17 @@ def chunk_reduce(
)
groups = groups[0]

if isinstance(axis, Sequence):
needs_broadcast = any(
group_idx.shape[ax] != array.shape[ax] and group_idx.shape[ax] == 1
for ax in range(-len(axis), 0)
)
if needs_broadcast:
group_idx = np.broadcast_to(group_idx, array.shape[-by.ndim :])
# always reshape to 1D along group dimensions
newshape = array.shape[: array.ndim - by.ndim] + (math.prod(array.shape[-by.ndim :]),)
array = array.reshape(newshape)
group_idx = group_idx.reshape(-1)

assert group_idx.ndim == 1
empty = np.all(props.nanmask)
Expand Down Expand Up @@ -1220,7 +1229,9 @@ def dask_groupby_agg(
# chunk numpy arrays like the input array
# This removes an extra rechunk-merge layer that would be
# added otherwise
by = dask.array.from_array(by, chunks=tuple(array.chunks[ax] for ax in range(-by.ndim, 0)))
chunks = tuple(array.chunks[ax] if by.shape[ax] != 1 else (1,) for ax in range(-by.ndim, 0))

by = dask.array.from_array(by, chunks=chunks)
_, (array, by) = dask.array.unify_chunks(array, inds, by, inds[-by.ndim :])

# preprocess the array: for argreductions, this zips the index together with the array block
Expand Down Expand Up @@ -1396,7 +1407,7 @@ def dask_groupby_agg(


def _validate_reindex(
reindex: bool | None, func, method: T_Method, expected_groups, by_is_dask: bool
reindex: bool | None, func, method: T_Method, expected_groups, any_by_dask: bool
) -> bool:
if reindex is True:
if _is_arg_reduction(func):
Expand All @@ -1414,7 +1425,7 @@ def _validate_reindex(
reindex = False

elif method == "map-reduce":
if expected_groups is None and by_is_dask:
if expected_groups is None and any_by_dask:
reindex = False
else:
reindex = True
Expand All @@ -1424,8 +1435,9 @@ def _validate_reindex(


def _assert_by_is_aligned(shape, by):
assert all(b.ndim == by[0].ndim for b in by[1:])
for idx, b in enumerate(by):
if shape[-b.ndim :] != b.shape:
if not all(j in [i, 1] for i, j in zip(shape[-b.ndim :], b.shape)):
raise ValueError(
"`array` and `by` arrays must be aligned "
"i.e. array.shape[-by.ndim :] == by.shape. "
Expand Down Expand Up @@ -1462,26 +1474,34 @@ def _lazy_factorize_wrapper(*by, **kwargs):
return group_idx


def _factorize_multiple(by, expected_groups, by_is_dask, reindex):
def _factorize_multiple(by, expected_groups, any_by_dask, reindex):
kwargs = dict(
expected_groups=expected_groups,
axis=None, # always None, we offset later if necessary.
fastpath=True,
reindex=reindex,
)
if by_is_dask:
if any_by_dask:
import dask.array

# unifying chunks will make sure all arrays in `by` are dask arrays
# with compatible chunks, even if there was originally a numpy array
inds = tuple(range(by[0].ndim))
chunks, by_ = dask.array.unify_chunks(*itertools.chain(*zip(by, (inds,) * len(by))))

group_idx = dask.array.map_blocks(
_lazy_factorize_wrapper,
*np.broadcast_arrays(*by),
*by_,
chunks=tuple(chunks.values()),
meta=np.array((), dtype=np.int64),
**kwargs,
)
found_groups = tuple(
None if is_duck_dask_array(b) else pd.unique(b.reshape(-1)) for b in by
)
grp_shape = tuple(len(e) for e in expected_groups)
grp_shape = tuple(
len(e) if e is not None else len(f) for e, f in zip(expected_groups, found_groups)
)
else:
group_idx, found_groups, grp_shape = factorize_(by, **kwargs)

Expand Down Expand Up @@ -1611,15 +1631,16 @@ def groupby_reduce(

bys = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by)
nby = len(bys)
by_is_dask = any(is_duck_dask_array(b) for b in bys)
by_is_dask = tuple(is_duck_dask_array(b) for b in bys)
any_by_dask = any(by_is_dask)

if method in ["split-reduce", "cohorts"] and by_is_dask:
if method in ["split-reduce", "cohorts"] and any_by_dask:
raise ValueError(f"method={method!r} can only be used when grouping by numpy arrays.")

if method == "split-reduce":
method = "cohorts"

reindex = _validate_reindex(reindex, func, method, expected_groups, by_is_dask)
reindex = _validate_reindex(reindex, func, method, expected_groups, any_by_dask)

if not is_duck_array(array):
array = np.asarray(array)
Expand All @@ -1634,6 +1655,11 @@ def groupby_reduce(
expected_groups = (None,) * nby

_assert_by_is_aligned(array.shape, bys)
for idx, (expect, is_dask) in enumerate(zip(expected_groups, by_is_dask)):
if is_dask and (reindex or nby > 1) and expect is None:
raise ValueError(
f"`expected_groups` for array {idx} in `by` cannot be None since it is a dask.array."
)

if nby == 1 and not isinstance(expected_groups, tuple):
expected_groups = (np.asarray(expected_groups),)
Expand All @@ -1653,7 +1679,7 @@ def groupby_reduce(
)
if factorize_early:
bys, final_groups, grp_shape = _factorize_multiple(
bys, expected_groups, by_is_dask=by_is_dask, reindex=reindex
bys, expected_groups, any_by_dask=any_by_dask, reindex=reindex
)
expected_groups = (pd.RangeIndex(math.prod(grp_shape)),)

Expand All @@ -1676,7 +1702,7 @@ def groupby_reduce(

# TODO: make sure expected_groups is unique
if nax == 1 and by_.ndim > 1 and expected_groups is None:
if not by_is_dask:
if not any_by_dask:
expected_groups = _get_expected_groups(by_, sort)
else:
# When we reduce along all axes, we are guaranteed to see all
Expand Down
73 changes: 39 additions & 34 deletions flox/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,6 @@
Dims = Union[str, Iterable[Hashable], None]


def _get_input_core_dims(group_names, dim, ds, grouper_dims):
input_core_dims = [[], []]
for g in group_names:
if g in dim:
continue
if g in ds.dims:
input_core_dims[0].extend([g])
if g in grouper_dims:
input_core_dims[1].extend([g])
input_core_dims[0].extend(dim)
input_core_dims[1].extend(dim)
return input_core_dims


def _restore_dim_order(result, obj, by):
def lookup_order(dimension):
if dimension == by.name and by.ndim == 1:
Expand All @@ -54,6 +40,26 @@ def lookup_order(dimension):
return result.transpose(*new_order)


def _broadcast_size_one_dims(*arrays, core_dims):
"""Broadcast by adding size-1 dimensions in the right place.

Workaround because apply_ufunc doesn't support this yet.
https://github.com/pydata/xarray/issues/3032#issuecomment-503337637

Specialized to the groupby problem.
"""
array_dims = set(core_dims[0])
broadcasted = [arrays[0]]
for dims, array in zip(core_dims[1:], arrays[1:]):
assert set(dims).issubset(array_dims)
order = [dims.index(d) for d in core_dims[0] if d in dims]
array = array.transpose(*order)
axis = [core_dims[0].index(d) for d in core_dims[0] if d not in dims]
broadcasted.append(np.expand_dims(array, axis))

return broadcasted


def xarray_reduce(
obj: T_Dataset | T_DataArray,
*by: T_DataArray | Hashable,
Expand Down Expand Up @@ -255,20 +261,11 @@ def xarray_reduce(
elif dim is not None:
dim_tuple = _atleast_1d(dim)
else:
dim_tuple = tuple()
dim_tuple = tuple(grouper_dims)

# broadcast all variables against each other along all dimensions in `by` variables
# don't exclude `dim` because it need not be a dimension in any of the `by` variables!
# in the case where dim is Ellipsis, and by.ndim < obj.ndim
# then we also broadcast `by` to all `obj.dims`
# TODO: avoid this broadcasting
# broadcast to make sure grouper dimensions are present in the array.
exclude_dims = tuple(d for d in ds.dims if d not in grouper_dims and d not in dim_tuple)
ds_broad, *by_broad = xr.broadcast(ds, *by_da, exclude=exclude_dims)

# all members of by_broad have the same dimensions
# so we just pull by_broad[0].dims if dim is None
if not dim_tuple:
dim_tuple = tuple(by_broad[0].dims)
ds_broad = xr.broadcast(ds, *by_da, exclude=exclude_dims)[0]

if any(d not in grouper_dims and d not in obj.dims for d in dim_tuple):
raise ValueError(f"Cannot reduce over absent dimensions {dim}.")
Expand Down Expand Up @@ -298,7 +295,7 @@ def xarray_reduce(
expected_groups = list(expected_groups)
group_names: tuple[Any, ...] = ()
group_sizes: dict[Any, int] = {}
for idx, (b_, expect, isbin_) in enumerate(zip(by_broad, expected_groups, isbins)):
for idx, (b_, expect, isbin_) in enumerate(zip(by_da, expected_groups, isbins)):
group_name = b_.name if not isbin_ else f"{b_.name}_bins"
group_names += (group_name,)

Expand Down Expand Up @@ -326,7 +323,10 @@ def xarray_reduce(
# This will never be reached
raise ValueError("expect_index cannot be None")

def wrapper(array, *by, func, skipna, **kwargs):
def wrapper(array, *by, func, skipna, core_dims, **kwargs):

array, *by = _broadcast_size_one_dims(array, *by, core_dims=core_dims)

# Handle skipna here because I need to know dtype to make a good default choice.
# We cannot handle this easily for xarray Datasets in xarray_reduce
if skipna and func in ["all", "any", "count"]:
Expand Down Expand Up @@ -374,17 +374,21 @@ def wrapper(array, *by, func, skipna, **kwargs):
if is_missing_dim:
missing_dim[k] = v

input_core_dims = _get_input_core_dims(group_names, dim_tuple, ds_broad, grouper_dims)
input_core_dims += [input_core_dims[-1]] * (nby - 1)
# dim_tuple contains dimensions we are reducing over. These need to be the last
# core dimensions to be synchronized with axis.
input_core_dims = [[d for d in grouper_dims if d not in dim_tuple] + list(dim_tuple)]
input_core_dims += [list(b.dims) for b in by_da]

output_core_dims = [d for d in input_core_dims[0] if d not in dim_tuple]
output_core_dims.extend(group_names)
actual = xr.apply_ufunc(
wrapper,
ds_broad.drop_vars(tuple(missing_dim)).transpose(..., *grouper_dims),
*by_broad,
*by_da,
input_core_dims=input_core_dims,
# for xarray's test_groupby_duplicate_coordinate_labels
exclude_dims=set(dim_tuple),
output_core_dims=[group_names],
output_core_dims=[output_core_dims],
dask="allowed",
dask_gufunc_kwargs=dict(
output_sizes=group_sizes, output_dtypes=[dtype] if dtype is not None else None
Expand All @@ -404,6 +408,7 @@ def wrapper(array, *by, func, skipna, **kwargs):
"isbin": isbins,
"finalize_kwargs": finalize_kwargs,
"dtype": dtype,
"core_dims": input_core_dims,
},
)

Expand All @@ -413,7 +418,7 @@ def wrapper(array, *by, func, skipna, **kwargs):
if all(d not in ds_broad[var].dims for d in dim_tuple):
actual[var] = ds_broad[var]

for name, expect, by_ in zip(group_names, expected_groups, by_broad):
for name, expect, by_ in zip(group_names, expected_groups, by_da):
# Can't remove this till xarray handles IntervalIndex
if isinstance(expect, pd.IntervalIndex):
expect = expect.to_numpy()
Expand Down Expand Up @@ -443,7 +448,7 @@ def wrapper(array, *by, func, skipna, **kwargs):
template = obj

if actual[var].ndim > 1:
actual[var] = _restore_dim_order(actual[var], template, by_broad[0])
actual[var] = _restore_dim_order(actual[var], template, by_da[0])

if missing_dim:
for k, v in missing_dim.items():
Expand Down
Loading