
Commit cdfcf37

Faster unstacking to sparse (#5577)
1 parent f086728 commit cdfcf37

File tree

  asv_bench/asv.conf.json             +1
  asv_bench/benchmarks/__init__.py    +7
  asv_bench/benchmarks/unstacking.py  +36 -1
  doc/whats-new.rst                   +6
  xarray/core/dataset.py              +6 -4
  xarray/core/variable.py             +32 -12
  xarray/tests/test_dataset.py        +29 -1

7 files changed: +117 -18


asv_bench/asv.conf.json  (+1)

@@ -65,6 +65,7 @@
         "bottleneck": [""],
         "dask": [""],
         "distributed": [""],
+        "sparse": [""]
     },

asv_bench/benchmarks/__init__.py  (+7)

@@ -22,6 +22,13 @@ def requires_dask():
         raise NotImplementedError()


+def requires_sparse():
+    try:
+        import sparse  # noqa: F401
+    except ImportError:
+        raise NotImplementedError()
+
+
 def randn(shape, frac_nan=None, chunks=None, seed=0):
     rng = np.random.RandomState(seed)
     if chunks is None:

asv_bench/benchmarks/unstacking.py  (+36 -1)

@@ -1,8 +1,9 @@
 import numpy as np
+import pandas as pd

 import xarray as xr

-from . import requires_dask
+from . import requires_dask, requires_sparse


 class Unstacking:

@@ -27,3 +28,37 @@ def setup(self, *args, **kwargs):
         requires_dask()
         super().setup(**kwargs)
         self.da_full = self.da_full.chunk({"flat_dim": 25})
+
+
+class UnstackingSparse(Unstacking):
+    def setup(self, *args, **kwargs):
+        requires_sparse()
+
+        import sparse
+
+        data = sparse.random((500, 1000), random_state=0, fill_value=0)
+        self.da_full = xr.DataArray(data, dims=list("ab")).stack(flat_dim=[...])
+        self.da_missing = self.da_full[:-1]
+
+        mindex = pd.MultiIndex.from_arrays([np.arange(100), np.arange(100)])
+        self.da_eye_2d = xr.DataArray(np.ones((100,)), dims="z", coords={"z": mindex})
+        self.da_eye_3d = xr.DataArray(
+            np.ones((100, 50)),
+            dims=("z", "foo"),
+            coords={"z": mindex, "foo": np.arange(50)},
+        )
+
+    def time_unstack_to_sparse_2d(self):
+        self.da_eye_2d.unstack(sparse=True)
+
+    def time_unstack_to_sparse_3d(self):
+        self.da_eye_3d.unstack(sparse=True)
+
+    def peakmem_unstack_to_sparse_2d(self):
+        self.da_eye_2d.unstack(sparse=True)
+
+    def peakmem_unstack_to_sparse_3d(self):
+        self.da_eye_3d.unstack(sparse=True)
+
+    def time_unstack_pandas_slow(self):
+        pass

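For a quick interactive check outside of asv, the 2-D benchmark case above can be reproduced directly. This is a minimal sketch, not part of the commit, assuming numpy, pandas, xarray, and the optional sparse package are installed:

import timeit

import numpy as np
import pandas as pd
import xarray as xr

# Same setup as UnstackingSparse.setup for the 2-D case: a length-100
# MultiIndex whose unstacked result has one non-fill value per row/column.
mindex = pd.MultiIndex.from_arrays([np.arange(100), np.arange(100)])
da_eye_2d = xr.DataArray(np.ones((100,)), dims="z", coords={"z": mindex})

# Time the operation the time_unstack_to_sparse_2d benchmark measures.
print(timeit.timeit(lambda: da_eye_2d.unstack(sparse=True), number=10))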
doc/whats-new.rst  (+6)

@@ -33,6 +33,12 @@ Deprecations
 ~~~~~~~~~~~~


+Performance
+~~~~~~~~~~~
+
+- Significantly faster unstacking to a ``sparse`` array. :pull:`5577`
+  By `Deepak Cherian <https://github.com/dcherian>`_.
+
 Bug fixes
 ~~~~~~~~~
 - :py:func:`xr.map_blocks` and :py:func:`xr.corr` now work when dask is not installed (:issue:`3391`, :issue:`5715`, :pull:`5731`).

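As context for the entry above, a minimal usage sketch of the feature being sped up, unstacking a MultiIndexed dimension directly to a sparse-backed result (assumes the optional sparse package is installed; not taken from the commit itself):

import numpy as np
import pandas as pd
import xarray as xr

mindex = pd.MultiIndex.from_arrays([np.arange(3), np.arange(3)], names=["a", "b"])
da = xr.DataArray(np.ones(3), dims="z", coords={"z": mindex})

# sparse=True builds a sparse COO result instead of materializing the
# mostly-fill_value dense array.
unstacked = da.unstack("z", sparse=True, fill_value=0)
print(type(unstacked.data))    # a sparse.COO array
print(unstacked.data.density)  # well below 1.0 for this diagonal pattern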
xarray/core/dataset.py  (+6 -4)

@@ -4045,7 +4045,9 @@ def ensure_stackable(val):

         return data_array

-    def _unstack_once(self, dim: Hashable, fill_value) -> "Dataset":
+    def _unstack_once(
+        self, dim: Hashable, fill_value, sparse: bool = False
+    ) -> "Dataset":
         index = self.get_index(dim)
         index = remove_unused_levels_categories(index)

@@ -4061,7 +4063,7 @@ def _unstack_once(self, dim: Hashable, fill_value) -> "Dataset":
                         fill_value_ = fill_value

                     variables[name] = var._unstack_once(
-                        index=index, dim=dim, fill_value=fill_value_
+                        index=index, dim=dim, fill_value=fill_value_, sparse=sparse
                     )
                 else:
                     variables[name] = var

@@ -4195,7 +4197,7 @@ def unstack(
             # Once that is resolved, explicitly exclude pint arrays.
             # pint doesn't implement `np.full_like` in a way that's
             # currently compatible.
-            needs_full_reindex = sparse or any(
+            needs_full_reindex = any(
                 is_duck_dask_array(v.data)
                 or isinstance(v.data, sparse_array_type)
                 or not isinstance(v.data, np.ndarray)

@@ -4206,7 +4208,7 @@ def unstack(
             if needs_full_reindex:
                 result = result._unstack_full_reindex(dim, fill_value, sparse)
             else:
-                result = result._unstack_once(dim, fill_value)
+                result = result._unstack_once(dim, fill_value, sparse)
         return result

     def update(self, other: "CoercibleMapping") -> "Dataset":

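The change above means that requesting sparse output no longer forces the slower _unstack_full_reindex path by itself; roughly, only non-numpy inputs (dask, sparse, or other duck arrays) still do. A simplified, hypothetical sketch of that decision, not xarray's exact internals:

import numpy as np

def choose_unstack_path(var_data, sparse_output: bool) -> str:
    # Simplified stand-in for the dispatch in Dataset.unstack: the fast path
    # handles plain numpy inputs and can now build a sparse result itself,
    # so `sparse_output` no longer factors into the choice.
    needs_full_reindex = not isinstance(var_data, np.ndarray)
    return "_unstack_full_reindex" if needs_full_reindex else "_unstack_once"

print(choose_unstack_path(np.ones(4), sparse_output=True))  # _unstack_once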
xarray/core/variable.py  (+32 -12)

@@ -1631,6 +1631,7 @@ def _unstack_once(
         index: pd.MultiIndex,
         dim: Hashable,
         fill_value=dtypes.NA,
+        sparse: bool = False,
     ) -> "Variable":
         """
         Unstacks this variable given an index to unstack and the name of the

@@ -1658,19 +1659,38 @@ def _unstack_once(
         else:
             dtype = self.dtype

-        data = np.full_like(
-            self.data,
-            fill_value=fill_value,
-            shape=new_shape,
-            dtype=dtype,
-        )
+        if sparse:
+            # unstacking a dense multitindexed array to a sparse array
+            from sparse import COO
+
+            codes = zip(*index.codes)
+            if reordered.ndim == 1:
+                indexes = codes
+            else:
+                sizes = itertools.product(*[range(s) for s in reordered.shape[:-1]])
+                tuple_indexes = itertools.product(sizes, codes)
+                indexes = map(lambda x: list(itertools.chain(*x)), tuple_indexes)  # type: ignore
+
+            data = COO(
+                coords=np.array(list(indexes)).T,
+                data=self.data.astype(dtype).ravel(),
+                fill_value=fill_value,
+                shape=new_shape,
+                sorted=index.is_monotonic_increasing,
+            )
+
+        else:
+            data = np.full_like(
+                self.data,
+                fill_value=fill_value,
+                shape=new_shape,
+                dtype=dtype,
+            )

-        # Indexer is a list of lists of locations. Each list is the locations
-        # on the new dimension. This is robust to the data being sparse; in that
-        # case the destinations will be NaN / zero.
-        # sparse doesn't support item assigment,
-        # https://github.com/pydata/sparse/issues/114
-        data[(..., *indexer)] = reordered
+            # Indexer is a list of lists of locations. Each list is the locations
+            # on the new dimension. This is robust to the data being sparse; in that
+            # case the destinations will be NaN / zero.
+            data[(..., *indexer)] = reordered

         return self._replace(dims=new_dims, data=data)

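The core of the speed-up is visible above: for a sparse result, the MultiIndex codes already are the coordinates of the non-fill values, so the COO array can be built directly instead of allocating an output and assigning into it (which sparse does not support). A standalone sketch of that idea for the 1-D case, with hypothetical example data rather than xarray's internal variables:

import numpy as np
import pandas as pd
import sparse

# Hypothetical stacked data: three values indexed by a pandas MultiIndex.
index = pd.MultiIndex.from_arrays([[0, 1, 2], [0, 1, 2]], names=["a", "b"])
values = np.array([10.0, 20.0, 30.0])

# Each code array gives, per element, its position along one unstacked
# dimension, so the stacked codes are exactly the COO coordinates.
coords = np.asarray(index.codes)                    # shape (ndim, nnz)
new_shape = tuple(len(lev) for lev in index.levels) # (3, 3)

unstacked = sparse.COO(coords=coords, data=values, shape=new_shape, fill_value=np.nan)
print(unstacked.todense())  # values on the diagonal, NaN elsewhere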
xarray/tests/test_dataset.py  (+29 -1)

@@ -28,7 +28,7 @@
 from xarray.core import dtypes, indexing, utils
 from xarray.core.common import duck_array_ops, full_like
 from xarray.core.indexes import Index
-from xarray.core.pycompat import integer_types
+from xarray.core.pycompat import integer_types, sparse_array_type
 from xarray.core.utils import is_scalar

 from . import (

@@ -3085,14 +3085,42 @@ def test_unstack_sparse(self):
         # test fill_value
         actual = ds.unstack("index", sparse=True)
         expected = ds.unstack("index")
+        assert isinstance(actual["var"].data, sparse_array_type)
         assert actual["var"].variable._to_dense().equals(expected["var"].variable)
         assert actual["var"].data.density < 1.0

         actual = ds["var"].unstack("index", sparse=True)
         expected = ds["var"].unstack("index")
+        assert isinstance(actual.data, sparse_array_type)
         assert actual.variable._to_dense().equals(expected.variable)
         assert actual.data.density < 1.0

+        mindex = pd.MultiIndex.from_arrays(
+            [np.arange(3), np.arange(3)], names=["a", "b"]
+        )
+        ds_eye = Dataset(
+            {"var": (("z", "foo", "bar"), np.ones((3, 4, 5)))},
+            coords={"z": mindex, "foo": np.arange(4), "bar": np.arange(5)},
+        )
+        actual = ds_eye.unstack(sparse=True, fill_value=0)
+        assert isinstance(actual["var"].data, sparse_array_type)
+        expected = xr.Dataset(
+            {
+                "var": (
+                    ("foo", "bar", "a", "b"),
+                    np.broadcast_to(np.eye(3, 3), (4, 5, 3, 3)),
+                )
+            },
+            coords={
+                "foo": np.arange(4),
+                "bar": np.arange(5),
+                "a": np.arange(3),
+                "b": np.arange(3),
+            },
+        )
+        actual["var"].data = actual["var"].data.todense()
+        assert_equal(expected, actual)
+
     def test_stack_unstack_fast(self):
         ds = Dataset(
             {
