
Commit a683790

Authored by max-sixty, pre-commit-ci[bot], and dcherian

Use numbagg for ffill by default (#8389)

* Use `numbagg` for `ffill`
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Use duck_array_ops for numbagg version, test import is lazy
* Update xarray/core/duck_array_ops.py (Co-authored-by: Deepak Cherian <[email protected]>)
* Update xarray/core/nputils.py (Co-authored-by: Deepak Cherian <[email protected]>)
* Update xarray/core/rolling_exp.py (Co-authored-by: Deepak Cherian <[email protected]>)
* Update xarray/core/nputils.py (Co-authored-by: Deepak Cherian <[email protected]>)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Deepak Cherian <[email protected]>

1 parent: dc66f0d

11 files changed: +132 -83 lines

doc/whats-new.rst (+4)

@@ -55,6 +55,10 @@ Documentation
 Internal Changes
 ~~~~~~~~~~~~~~~~
 
+- :py:meth:`DataArray.bfill` & :py:meth:`DataArray.ffill` now use numbagg by
+  default, which is up to 5x faster where parallelization is possible. (:pull:`8339`)
+  By `Maximilian Roos <https://github.com/max-sixty>`_.
+
 .. _whats-new.2023.11.0:
 
 v2023.11.0 (Nov 16, 2023)
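
For reference, a minimal way to exercise the new default described in this entry (assuming numbagg and/or bottleneck are installed; `use_numbagg` and `use_bottleneck` are the existing xarray options):

import numpy as np
import xarray as xr

da = xr.DataArray([1.0, np.nan, np.nan, 4.0, np.nan], dims="x")

# numbagg is preferred when installed and enabled (the new default) ...
with xr.set_options(use_numbagg=True):
    print(da.ffill("x").values)  # [1. 1. 1. 4. 4.]

# ... and bottleneck remains available as a fallback.
with xr.set_options(use_numbagg=False, use_bottleneck=True):
    print(da.bfill("x").values)  # [1. 4. 4. 4. nan]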

xarray/backends/zarr.py (+2 -2)

@@ -177,8 +177,8 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks):
     # DESIGN CHOICE: do not allow multiple dask chunks on a single zarr chunk
     # this avoids the need to get involved in zarr synchronization / locking
     # From zarr docs:
-    # "If each worker in a parallel computation is writing to a separate
-    # region of the array, and if region boundaries are perfectly aligned
+    # "If each worker in a parallel computation is writing to a
+    # separate region of the array, and if region boundaries are perfectly aligned
     # with chunk boundaries, then no synchronization is required."
     # TODO: incorporate synchronizer to allow writes from multiple dask
     # threads

xarray/core/dask_array_ops.py (+3 -2)

@@ -59,10 +59,11 @@ def push(array, n, axis):
     """
     Dask-aware bottleneck.push
     """
-    import bottleneck
     import dask.array as da
     import numpy as np
 
+    from xarray.core.duck_array_ops import _push
+
     def _fill_with_last_one(a, b):
         # cumreduction apply the push func over all the blocks first so, the only missing part is filling
         # the missing values using the last data of the previous chunk
@@ -85,7 +86,7 @@ def _fill_with_last_one(a, b):
 
     # The method parameter makes that the tests for python 3.7 fails.
     return da.reductions.cumreduction(
-        func=bottleneck.push,
+        func=_push,
         binop=_fill_with_last_one,
         ident=np.nan,
         x=array,
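
To illustrate what `cumreduction` plus `_fill_with_last_one` accomplish here, a rough NumPy-only sketch of filling two "chunks" independently and then patching the boundary (the helper name `forward_fill_block` is illustrative, not xarray or dask API):

import numpy as np

def forward_fill_block(block):
    # Forward-fill one 1-D block: take the value at the last valid index so far.
    idx = np.where(~np.isnan(block), np.arange(block.size), 0)
    np.maximum.accumulate(idx, out=idx)
    return block[idx]

left = np.array([1.0, np.nan, 3.0, np.nan])
right = np.array([np.nan, np.nan, 7.0, np.nan])

left_filled = forward_fill_block(left)    # [1. 1. 3. 3.]
right_filled = forward_fill_block(right)  # [nan nan 7. 7.]

# Boundary patch, the role _fill_with_last_one plays across dask chunks:
# leading NaNs in the right block take the last value carried from the left.
patched = np.where(np.isnan(right_filled), left_filled[-1], right_filled)
print(patched)  # [3. 3. 7. 7.]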

xarray/core/duck_array_ops.py (+37 -4)

@@ -31,8 +31,10 @@
 from numpy import concatenate as _concatenate
 from numpy.core.multiarray import normalize_axis_index  # type: ignore[attr-defined]
 from numpy.lib.stride_tricks import sliding_window_view  # noqa
+from packaging.version import Version
 
-from xarray.core import dask_array_ops, dtypes, nputils
+from xarray.core import dask_array_ops, dtypes, nputils, pycompat
+from xarray.core.options import OPTIONS
 from xarray.core.parallelcompat import get_chunked_array_type, is_chunked_array
 from xarray.core.pycompat import array_type, is_duck_dask_array
 from xarray.core.utils import is_duck_array, module_available
@@ -688,13 +690,44 @@ def least_squares(lhs, rhs, rcond=None, skipna=False):
     return nputils.least_squares(lhs, rhs, rcond=rcond, skipna=skipna)
 
 
-def push(array, n, axis):
-    from bottleneck import push
+def _push(array, n: int | None = None, axis: int = -1):
+    """
+    Use either bottleneck or numbagg depending on options & what's available
+    """
+
+    if not OPTIONS["use_bottleneck"] and not OPTIONS["use_numbagg"]:
+        raise RuntimeError(
+            "ffill & bfill requires bottleneck or numbagg to be enabled."
+            " Call `xr.set_options(use_bottleneck=True)` or `xr.set_options(use_numbagg=True)` to enable one."
+        )
+    if OPTIONS["use_numbagg"] and module_available("numbagg"):
+        import numbagg
+
+        if pycompat.mod_version("numbagg") < Version("0.6.2"):
+            warnings.warn(
+                f"numbagg >= 0.6.2 is required for bfill & ffill; {pycompat.mod_version('numbagg')} is installed. We'll attempt with bottleneck instead."
+            )
+        else:
+            return numbagg.ffill(array, limit=n, axis=axis)
+
+    # work around for bottleneck 178
+    limit = n if n is not None else array.shape[axis]
+
+    import bottleneck as bn
+
+    return bn.push(array, limit, axis)
 
+
+def push(array, n, axis):
+    if not OPTIONS["use_bottleneck"] and not OPTIONS["use_numbagg"]:
+        raise RuntimeError(
+            "ffill & bfill requires bottleneck or numbagg to be enabled."
+            " Call `xr.set_options(use_bottleneck=True)` or `xr.set_options(use_numbagg=True)` to enable one."
+        )
    if is_duck_dask_array(array):
         return dask_array_ops.push(array, n, axis)
     else:
-        return push(array, n, axis)
+        return _push(array, n, axis)
 
 
 def _first_last_wrapper(array, *, axis, op, keepdims):
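
A quick check of the consolidated error path added to `push` above (a minimal sketch; with a NumPy-backed array the error surfaces eagerly):

import numpy as np
import xarray as xr

da = xr.DataArray([1.0, np.nan, 3.0], dims="x")

# With both backends disabled, ffill/bfill now fail fast with one consistent
# message from duck_array_ops.push rather than erroring inside missing.py.
with xr.set_options(use_bottleneck=False, use_numbagg=False):
    try:
        da.ffill("x")
    except RuntimeError as err:
        print(err)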

xarray/core/missing.py (+1 -11)

@@ -14,7 +14,7 @@
 from xarray.core.common import _contains_datetime_like_objects, ones_like
 from xarray.core.computation import apply_ufunc
 from xarray.core.duck_array_ops import datetime_to_numeric, push, timedelta_to_numeric
-from xarray.core.options import OPTIONS, _get_keep_attrs
+from xarray.core.options import _get_keep_attrs
 from xarray.core.parallelcompat import get_chunked_array_type, is_chunked_array
 from xarray.core.types import Interp1dOptions, InterpOptions
 from xarray.core.utils import OrderedSet, is_scalar
@@ -413,11 +413,6 @@ def _bfill(arr, n=None, axis=-1):
 
 def ffill(arr, dim=None, limit=None):
     """forward fill missing values"""
-    if not OPTIONS["use_bottleneck"]:
-        raise RuntimeError(
-            "ffill requires bottleneck to be enabled."
-            " Call `xr.set_options(use_bottleneck=True)` to enable it."
-        )
 
     axis = arr.get_axis_num(dim)
 
@@ -436,11 +431,6 @@ def ffill(arr, dim=None, limit=None):
 
 def bfill(arr, dim=None, limit=None):
     """backfill missing values"""
-    if not OPTIONS["use_bottleneck"]:
-        raise RuntimeError(
-            "bfill requires bottleneck to be enabled."
-            " Call `xr.set_options(use_bottleneck=True)` to enable it."
-        )
 
     axis = arr.get_axis_num(dim)
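
Moving the option check out of `ffill`/`bfill` does not change their `limit` behaviour; for example (assuming either backend is installed):

import numpy as np
import xarray as xr

da = xr.DataArray([1.0, np.nan, np.nan, np.nan], dims="x")

# limit still caps how far a value is propagated, whichever backend runs.
print(da.ffill("x", limit=1).values)  # [ 1.  1. nan nan]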

xarray/core/nputils.py (+16 -18)

@@ -1,12 +1,16 @@
 from __future__ import annotations
 
 import warnings
+from typing import Callable
 
 import numpy as np
 import pandas as pd
 from numpy.core.multiarray import normalize_axis_index  # type: ignore[attr-defined]
 from packaging.version import Version
 
+from xarray.core import pycompat
+from xarray.core.utils import module_available
+
 # remove once numpy 2.0 is the oldest supported version
 try:
     from numpy.exceptions import RankWarning  # type: ignore[attr-defined,unused-ignore]
@@ -25,15 +29,6 @@
     bn = np
     _BOTTLENECK_AVAILABLE = False
 
-try:
-    import numbagg
-
-    _HAS_NUMBAGG = Version(numbagg.__version__) >= Version("0.5.0")
-except ImportError:
-    # use numpy methods instead
-    numbagg = np  # type: ignore
-    _HAS_NUMBAGG = False
-
 
 def _select_along_axis(values, idx, axis):
     other_ind = np.ix_(*[np.arange(s) for s in idx.shape])
@@ -171,29 +166,32 @@ def __setitem__(self, key, value):
         self._array[key] = np.moveaxis(value, vindex_positions, mixed_positions)
 
 
-def _create_method(name, npmodule=np):
+def _create_method(name, npmodule=np) -> Callable:
     def f(values, axis=None, **kwargs):
         dtype = kwargs.get("dtype", None)
         bn_func = getattr(bn, name, None)
-        nba_func = getattr(numbagg, name, None)
 
         if (
-            _HAS_NUMBAGG
+            module_available("numbagg")
+            and pycompat.mod_version("numbagg") >= Version("0.5.0")
            and OPTIONS["use_numbagg"]
             and isinstance(values, np.ndarray)
-            and nba_func is not None
             # numbagg uses ddof=1 only, but numpy uses ddof=0 by default
             and (("var" in name or "std" in name) and kwargs.get("ddof", 0) == 1)
             # TODO: bool?
             and values.dtype.kind in "uifc"
             # and values.dtype.isnative
             and (dtype is None or np.dtype(dtype) == values.dtype)
         ):
-            # numbagg does not take care dtype, ddof
-            kwargs.pop("dtype", None)
-            kwargs.pop("ddof", None)
-            result = nba_func(values, axis=axis, **kwargs)
-        elif (
+            import numbagg
+
+            nba_func = getattr(numbagg, name, None)
+            if nba_func is not None:
+                # numbagg does not take care dtype, ddof
+                kwargs.pop("dtype", None)
+                kwargs.pop("ddof", None)
+                return nba_func(values, axis=axis, **kwargs)
+        if (
             _BOTTLENECK_AVAILABLE
             and OPTIONS["use_bottleneck"]
             and isinstance(values, np.ndarray)
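
The dispatch in `_create_method` only considers the numbagg branch for `std`/`var` when `ddof=1` (numbagg's fixed value); a small, hedged demonstration of which calls are eligible (assuming numbagg is installed and `use_numbagg` is enabled; the routing comment below is an assumption about the reduction path, not a guarantee):

import numpy as np
import xarray as xr

da = xr.DataArray(np.random.randn(4, 1000), dims=("y", "x"))

with xr.set_options(use_numbagg=True):
    eligible = da.std("x", ddof=1)  # may be served by numbagg's nanstd
    fallback = da.std("x")          # ddof=0, falls through to bottleneck/numpy

print(float(eligible[0]), float(fallback[0]))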

xarray/core/pycompat.py (+4 -1)

@@ -12,7 +12,7 @@
 integer_types = (int, np.integer)
 
 if TYPE_CHECKING:
-    ModType = Literal["dask", "pint", "cupy", "sparse", "cubed"]
+    ModType = Literal["dask", "pint", "cupy", "sparse", "cubed", "numbagg"]
     DuckArrayTypes = tuple[type[Any], ...]  # TODO: improve this? maybe Generic
 
 
@@ -47,6 +47,9 @@ def __init__(self, mod: ModType) -> None:
             duck_array_type = (duck_array_module.SparseArray,)
         elif mod == "cubed":
            duck_array_type = (duck_array_module.Array,)
+        # Not a duck array module, but using this system regardless, to get lazy imports
+        elif mod == "numbagg":
+            duck_array_type = ()
         else:
             raise NotImplementedError
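
The new "numbagg" entry lets version checks stay lazy: the module is only imported, and its version read, when it is actually installed. The internal helpers `module_available` and `pycompat.mod_version` (private xarray API, shown only for illustration) are then combined like this:

from packaging.version import Version

from xarray.core import pycompat
from xarray.core.utils import module_available

# Lazy check: nothing is imported unless numbagg is present.
if module_available("numbagg") and pycompat.mod_version("numbagg") >= Version("0.6.2"):
    print("numbagg fast path available for ffill/bfill")
else:
    print("falling back to bottleneck (or numpy)")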

xarray/core/rolling_exp.py (+26 -24)

@@ -6,18 +6,12 @@
 import numpy as np
 from packaging.version import Version
 
+from xarray.core import pycompat
 from xarray.core.computation import apply_ufunc
 from xarray.core.options import _get_keep_attrs
 from xarray.core.pdcompat import count_not_none
 from xarray.core.types import T_DataWithCoords
-
-try:
-    import numbagg
-    from numbagg import move_exp_nanmean, move_exp_nansum
-
-    _NUMBAGG_VERSION: Version | None = Version(numbagg.__version__)
-except ImportError:
-    _NUMBAGG_VERSION = None
+from xarray.core.utils import module_available
 
 
 def _get_alpha(
@@ -83,17 +77,17 @@ def __init__(
         window_type: str = "span",
         min_weight: float = 0.0,
     ):
-        if _NUMBAGG_VERSION is None:
+        if not module_available("numbagg"):
             raise ImportError(
                 "numbagg >= 0.2.1 is required for rolling_exp but currently numbagg is not installed"
             )
-        elif _NUMBAGG_VERSION < Version("0.2.1"):
+        elif pycompat.mod_version("numbagg") < Version("0.2.1"):
             raise ImportError(
-                f"numbagg >= 0.2.1 is required for rolling_exp but currently version {_NUMBAGG_VERSION} is installed"
+                f"numbagg >= 0.2.1 is required for rolling_exp but currently version {pycompat.mod_version('numbagg')} is installed"
             )
-        elif _NUMBAGG_VERSION < Version("0.3.1") and min_weight > 0:
+        elif pycompat.mod_version("numbagg") < Version("0.3.1") and min_weight > 0:
             raise ImportError(
-                f"numbagg >= 0.3.1 is required for `min_weight > 0` within `.rolling_exp` but currently version {_NUMBAGG_VERSION} is installed"
+                f"numbagg >= 0.3.1 is required for `min_weight > 0` within `.rolling_exp` but currently version {pycompat.mod_version('numbagg')} is installed"
             )
 
         self.obj: T_DataWithCoords = obj
@@ -127,13 +121,15 @@ def mean(self, keep_attrs: bool | None = None) -> T_DataWithCoords:
         Dimensions without coordinates: x
         """
 
+        import numbagg
+
         if keep_attrs is None:
             keep_attrs = _get_keep_attrs(default=True)
 
         dim_order = self.obj.dims
 
         return apply_ufunc(
-            move_exp_nanmean,
+            numbagg.move_exp_nanmean,
             self.obj,
             input_core_dims=[[self.dim]],
             kwargs=self.kwargs,
@@ -163,13 +159,15 @@ def sum(self, keep_attrs: bool | None = None) -> T_DataWithCoords:
         Dimensions without coordinates: x
         """
 
+        import numbagg
+
         if keep_attrs is None:
             keep_attrs = _get_keep_attrs(default=True)
 
         dim_order = self.obj.dims
 
         return apply_ufunc(
-            move_exp_nansum,
+            numbagg.move_exp_nansum,
             self.obj,
             input_core_dims=[[self.dim]],
             kwargs=self.kwargs,
@@ -194,10 +192,12 @@ def std(self) -> T_DataWithCoords:
         Dimensions without coordinates: x
         """
 
-        if _NUMBAGG_VERSION is None or _NUMBAGG_VERSION < Version("0.4.0"):
+        if pycompat.mod_version("numbagg") < Version("0.4.0"):
             raise ImportError(
-                f"numbagg >= 0.4.0 is required for rolling_exp().std(), currently {_NUMBAGG_VERSION} is installed"
+                f"numbagg >= 0.4.0 is required for rolling_exp().std(), currently {pycompat.mod_version('numbagg')} is installed"
             )
+        import numbagg
+
         dim_order = self.obj.dims
 
         return apply_ufunc(
@@ -225,12 +225,12 @@ def var(self) -> T_DataWithCoords:
         array([ nan, 0. , 0.46153846, 0.18461538, 0.06446281])
         Dimensions without coordinates: x
         """
-
-        if _NUMBAGG_VERSION is None or _NUMBAGG_VERSION < Version("0.4.0"):
+        if pycompat.mod_version("numbagg") < Version("0.4.0"):
             raise ImportError(
-                f"numbagg >= 0.4.0 is required for rolling_exp().var(), currently {_NUMBAGG_VERSION} is installed"
+                f"numbagg >= 0.4.0 is required for rolling_exp().std(), currently {pycompat.mod_version('numbagg')} is installed"
             )
         dim_order = self.obj.dims
+        import numbagg
 
         return apply_ufunc(
             numbagg.move_exp_nanvar,
@@ -258,11 +258,12 @@ def cov(self, other: T_DataWithCoords) -> T_DataWithCoords:
         Dimensions without coordinates: x
         """
 
-        if _NUMBAGG_VERSION is None or _NUMBAGG_VERSION < Version("0.4.0"):
+        if pycompat.mod_version("numbagg") < Version("0.4.0"):
             raise ImportError(
-                f"numbagg >= 0.4.0 is required for rolling_exp().cov(), currently {_NUMBAGG_VERSION} is installed"
+                f"numbagg >= 0.4.0 is required for rolling_exp().std(), currently {pycompat.mod_version('numbagg')} is installed"
            )
         dim_order = self.obj.dims
+        import numbagg
 
         return apply_ufunc(
             numbagg.move_exp_nancov,
@@ -291,11 +292,12 @@ def corr(self, other: T_DataWithCoords) -> T_DataWithCoords:
         Dimensions without coordinates: x
         """
 
-        if _NUMBAGG_VERSION is None or _NUMBAGG_VERSION < Version("0.4.0"):
+        if pycompat.mod_version("numbagg") < Version("0.4.0"):
             raise ImportError(
-                f"numbagg >= 0.4.0 is required for rolling_exp().cov(), currently {_NUMBAGG_VERSION} is installed"
+                f"numbagg >= 0.4.0 is required for rolling_exp().std(), currently {pycompat.mod_version('numbagg')} is installed"
             )
         dim_order = self.obj.dims
+        import numbagg
 
         return apply_ufunc(
             numbagg.move_exp_nancorr,
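
Since `rolling_exp` now resolves numbagg lazily at call time rather than via a module-level `from numbagg import ...`, the user-facing API is unchanged; a minimal usage example (requires numbagg >= 0.2.1 to be installed):

import xarray as xr

da = xr.DataArray([1.0, 1.0, 2.0, 2.0, 2.0], dims="x")

# numbagg.move_exp_nanmean is looked up inside mean(), not at import time.
print(da.rolling_exp(x=3, window_type="span").mean().values)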

xarray/tests/__init__.py (+6 -1)

@@ -53,7 +53,8 @@ def _importorskip(
         mod = importlib.import_module(modname)
         has = True
         if minversion is not None:
-            if Version(mod.__version__) < Version(minversion):
+            v = getattr(mod, "__version__", "999")
+            if Version(v) < Version(minversion):
                 raise ImportError("Minimum version not satisfied")
     except ImportError:
         has = False
@@ -96,6 +97,10 @@ def _importorskip(
 requires_scipy_or_netCDF4 = pytest.mark.skipif(
     not has_scipy_or_netCDF4, reason="requires scipy or netCDF4"
 )
+has_numbagg_or_bottleneck = has_numbagg or has_bottleneck
+requires_numbagg_or_bottleneck = pytest.mark.skipif(
+    not has_scipy_or_netCDF4, reason="requires scipy or netCDF4"
+)
 # _importorskip does not work for development versions
 has_pandas_version_two = Version(pd.__version__).major >= 2
 requires_pandas_version_two = pytest.mark.skipif(
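
The `getattr(mod, "__version__", "999")` fallback lets `_importorskip` cope with modules that expose no `__version__` attribute. A short, hedged sketch of the helper's existing usage pattern (the variable names below are illustrative, not definitions added by this commit):

from xarray.tests import _importorskip

# Returns (bool, pytest mark); with the getattr fallback this no longer
# raises AttributeError for modules lacking __version__.
has_numbagg, requires_numbagg = _importorskip("numbagg")
has_recent_numbagg, requires_recent_numbagg = _importorskip("numbagg", minversion="0.6.2")
print(has_numbagg, has_recent_numbagg)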
