Skip to content

Commit a0c71c1

Browse files
authored
Faster unstacking (#4746)
* Significantly improve unstacking performance * Hack to get sparse tests passing * Use the existing unstack function for dask & sparse * Add whatsnew * Require numpy 1.17 for new unstack * Also special case pint * Revert "Also special case pint" This reverts commit b33aded. * Only run fast unstack on numpy arrays * Update asvs for unstacking * Update whatsnew
1 parent d555172 commit a0c71c1

File tree

4 files changed

+153
-12
lines changed

4 files changed

+153
-12
lines changed

asv_bench/benchmarks/unstacking.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,23 @@
77

88
class Unstacking:
99
def setup(self):
10-
data = np.random.RandomState(0).randn(1, 1000, 500)
11-
self.ds = xr.DataArray(data).stack(flat_dim=["dim_1", "dim_2"])
10+
data = np.random.RandomState(0).randn(500, 1000)
11+
self.da_full = xr.DataArray(data, dims=list("ab")).stack(flat_dim=[...])
12+
self.da_missing = self.da_full[:-1]
13+
self.df_missing = self.da_missing.to_pandas()
1214

1315
def time_unstack_fast(self):
14-
self.ds.unstack("flat_dim")
16+
self.da_full.unstack("flat_dim")
1517

1618
def time_unstack_slow(self):
17-
self.ds[:, ::-1].unstack("flat_dim")
19+
self.da_missing.unstack("flat_dim")
20+
21+
def time_unstack_pandas_slow(self):
22+
self.df_missing.unstack()
1823

1924

2025
class UnstackingDask(Unstacking):
2126
def setup(self, *args, **kwargs):
2227
requires_dask()
2328
super().setup(**kwargs)
24-
self.ds = self.ds.chunk({"flat_dim": 50})
29+
self.da_full = self.da_full.chunk({"flat_dim": 50})

doc/whats-new.rst

+6-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ What's New
1717
1818
.. _whats-new.0.16.3:
1919

20-
v0.16.3 (unreleased)
20+
v0.17.0 (unreleased)
2121
--------------------
2222

2323
Breaking changes
@@ -45,6 +45,11 @@ Breaking changes
4545

4646
New Features
4747
~~~~~~~~~~~~
48+
- Significantly higher ``unstack`` performance on numpy-backed arrays which
49+
contain missing values; 8x faster in our benchmark, and 2x faster than pandas.
50+
(:pull:`4746`);
51+
By `Maximilian Roos <https://github.com/max-sixty>`_.
52+
4853
- Performance improvement when constructing DataArrays. Significantly speeds up repr for Datasets with large number of variables.
4954
By `Deepak Cherian <https://github.com/dcherian>`_.
5055
- :py:meth:`DataArray.swap_dims` & :py:meth:`Dataset.swap_dims` now accept dims

xarray/core/dataset.py

+72-3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import sys
55
import warnings
66
from collections import defaultdict
7+
from distutils.version import LooseVersion
78
from html import escape
89
from numbers import Number
910
from operator import methodcaller
@@ -79,7 +80,7 @@
7980
)
8081
from .missing import get_clean_interp_index
8182
from .options import OPTIONS, _get_keep_attrs
82-
from .pycompat import is_duck_dask_array
83+
from .pycompat import is_duck_dask_array, sparse_array_type
8384
from .utils import (
8485
Default,
8586
Frozen,
@@ -3715,7 +3716,40 @@ def ensure_stackable(val):
37153716

37163717
return data_array
37173718

3718-
def _unstack_once(self, dim: Hashable, fill_value, sparse) -> "Dataset":
3719+
def _unstack_once(self, dim: Hashable, fill_value) -> "Dataset":
3720+
index = self.get_index(dim)
3721+
index = remove_unused_levels_categories(index)
3722+
3723+
variables: Dict[Hashable, Variable] = {}
3724+
indexes = {k: v for k, v in self.indexes.items() if k != dim}
3725+
3726+
for name, var in self.variables.items():
3727+
if name != dim:
3728+
if dim in var.dims:
3729+
if isinstance(fill_value, Mapping):
3730+
fill_value_ = fill_value[name]
3731+
else:
3732+
fill_value_ = fill_value
3733+
3734+
variables[name] = var._unstack_once(
3735+
index=index, dim=dim, fill_value=fill_value_
3736+
)
3737+
else:
3738+
variables[name] = var
3739+
3740+
for name, lev in zip(index.names, index.levels):
3741+
variables[name] = IndexVariable(name, lev)
3742+
indexes[name] = lev
3743+
3744+
coord_names = set(self._coord_names) - {dim} | set(index.names)
3745+
3746+
return self._replace_with_new_dims(
3747+
variables, coord_names=coord_names, indexes=indexes
3748+
)
3749+
3750+
def _unstack_full_reindex(
3751+
self, dim: Hashable, fill_value, sparse: bool
3752+
) -> "Dataset":
37193753
index = self.get_index(dim)
37203754
index = remove_unused_levels_categories(index)
37213755
full_idx = pd.MultiIndex.from_product(index.levels, names=index.names)
@@ -3812,7 +3846,38 @@ def unstack(
38123846

38133847
result = self.copy(deep=False)
38143848
for dim in dims:
3815-
result = result._unstack_once(dim, fill_value, sparse)
3849+
3850+
if (
3851+
# Dask arrays don't support assignment by index, which the fast unstack
3852+
# function requires.
3853+
# https://github.com/pydata/xarray/pull/4746#issuecomment-753282125
3854+
any(is_duck_dask_array(v.data) for v in self.variables.values())
3855+
# Sparse doesn't currently support this (though we could special-case
3856+
# it)
3857+
# https://github.com/pydata/sparse/issues/422
3858+
or any(
3859+
isinstance(v.data, sparse_array_type)
3860+
for v in self.variables.values()
3861+
)
3862+
or sparse
3863+
# numpy `full_like` only added the `shape` argument in 1.17
3864+
or LooseVersion(np.__version__) < LooseVersion("1.17")
3865+
# Until https://github.com/pydata/xarray/pull/4751 is resolved,
3866+
# we check explicitly whether it's a numpy array. Once that is
3867+
# resolved, explicitly exclude pint arrays.
3868+
# # pint doesn't implement `np.full_like` in a way that's
3869+
# # currently compatible.
3870+
# # https://github.com/pydata/xarray/pull/4746#issuecomment-753425173
3871+
# # or any(
3872+
# # isinstance(v.data, pint_array_type) for v in self.variables.values()
3873+
# # )
3874+
or any(
3875+
not isinstance(v.data, np.ndarray) for v in self.variables.values()
3876+
)
3877+
):
3878+
result = result._unstack_full_reindex(dim, fill_value, sparse)
3879+
else:
3880+
result = result._unstack_once(dim, fill_value)
38163881
return result
38173882

38183883
def update(self, other: "CoercibleMapping") -> "Dataset":
@@ -4982,6 +5047,10 @@ def _set_numpy_data_from_dataframe(
49825047
self[name] = (dims, values)
49835048
return
49845049

5050+
# NB: similar, more general logic, now exists in
5051+
# variable.unstack_once; we could consider combining them at some
5052+
# point.
5053+
49855054
shape = tuple(lev.size for lev in idx.levels)
49865055
indexer = tuple(idx.codes)
49875056

xarray/core/variable.py

+65-3
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
Any,
1111
Dict,
1212
Hashable,
13+
List,
1314
Mapping,
1415
Optional,
1516
Sequence,
@@ -1488,7 +1489,7 @@ def set_dims(self, dims, shape=None):
14881489
)
14891490
return expanded_var.transpose(*dims)
14901491

1491-
def _stack_once(self, dims, new_dim):
1492+
def _stack_once(self, dims: List[Hashable], new_dim: Hashable):
14921493
if not set(dims) <= set(self.dims):
14931494
raise ValueError("invalid existing dimensions: %s" % dims)
14941495

@@ -1544,7 +1545,15 @@ def stack(self, dimensions=None, **dimensions_kwargs):
15441545
result = result._stack_once(dims, new_dim)
15451546
return result
15461547

1547-
def _unstack_once(self, dims, old_dim):
1548+
def _unstack_once_full(
1549+
self, dims: Mapping[Hashable, int], old_dim: Hashable
1550+
) -> "Variable":
1551+
"""
1552+
Unstacks the variable without needing an index.
1553+
1554+
Unlike `_unstack_once`, this function requires the existing dimension to
1555+
contain the full product of the new dimensions.
1556+
"""
15481557
new_dim_names = tuple(dims.keys())
15491558
new_dim_sizes = tuple(dims.values())
15501559

@@ -1573,13 +1582,64 @@ def _unstack_once(self, dims, old_dim):
15731582

15741583
return Variable(new_dims, new_data, self._attrs, self._encoding, fastpath=True)
15751584

1585+
def _unstack_once(
1586+
self,
1587+
index: pd.MultiIndex,
1588+
dim: Hashable,
1589+
fill_value=dtypes.NA,
1590+
) -> "Variable":
1591+
"""
1592+
Unstacks this variable given an index to unstack and the name of the
1593+
dimension to which the index refers.
1594+
"""
1595+
1596+
reordered = self.transpose(..., dim)
1597+
1598+
new_dim_sizes = [lev.size for lev in index.levels]
1599+
new_dim_names = index.names
1600+
indexer = index.codes
1601+
1602+
# Potentially we could replace `len(other_dims)` with just `-1`
1603+
other_dims = [d for d in self.dims if d != dim]
1604+
new_shape = list(reordered.shape[: len(other_dims)]) + new_dim_sizes
1605+
new_dims = reordered.dims[: len(other_dims)] + new_dim_names
1606+
1607+
if fill_value is dtypes.NA:
1608+
is_missing_values = np.prod(new_shape) > np.prod(self.shape)
1609+
if is_missing_values:
1610+
dtype, fill_value = dtypes.maybe_promote(self.dtype)
1611+
else:
1612+
dtype = self.dtype
1613+
fill_value = dtypes.get_fill_value(dtype)
1614+
else:
1615+
dtype = self.dtype
1616+
1617+
# Currently fails on sparse due to https://github.com/pydata/sparse/issues/422
1618+
data = np.full_like(
1619+
self.data,
1620+
fill_value=fill_value,
1621+
shape=new_shape,
1622+
dtype=dtype,
1623+
)
1624+
1625+
# Indexer is a list of lists of locations. Each list is the locations
1626+
# on the new dimension. This is robust to the data being sparse; in that
1627+
# case the destinations will be NaN / zero.
1628+
data[(..., *indexer)] = reordered
1629+
1630+
return self._replace(dims=new_dims, data=data)
1631+
15761632
def unstack(self, dimensions=None, **dimensions_kwargs):
15771633
"""
15781634
Unstack an existing dimension into multiple new dimensions.
15791635
15801636
New dimensions will be added at the end, and the order of the data
15811637
along each new dimension will be in contiguous (C) order.
15821638
1639+
Note that unlike ``DataArray.unstack`` and ``Dataset.unstack``, this
1640+
method requires the existing dimension to contain the full product of
1641+
the new dimensions.
1642+
15831643
Parameters
15841644
----------
15851645
dimensions : mapping of hashable to mapping of hashable to int
@@ -1598,11 +1658,13 @@ def unstack(self, dimensions=None, **dimensions_kwargs):
15981658
See also
15991659
--------
16001660
Variable.stack
1661+
DataArray.unstack
1662+
Dataset.unstack
16011663
"""
16021664
dimensions = either_dict_or_kwargs(dimensions, dimensions_kwargs, "unstack")
16031665
result = self
16041666
for old_dim, dims in dimensions.items():
1605-
result = result._unstack_once(dims, old_dim)
1667+
result = result._unstack_once_full(dims, old_dim)
16061668
return result
16071669

16081670
def fillna(self, value):

0 commit comments

Comments
 (0)