
Commit 17b70ca

apply_ufunc: Add meta kwarg + bump dask to 2.2 (#3660)
* apply_ufunc: Set meta=np.ndarray when vectorize=True and dask="parallelized". Closes #3574
* Add meta kwarg to apply_ufunc.
* Bump minimum dask to 2.1.0
* Update distributed too
* Bump minimum dask, distributed to 2.2
* Update whats-new
* Minor.
* Fix whats-new
* Attempt numpy=1.15
* Revert "Attempt numpy=1.15". This reverts commit 2b22470.
* xfail test.
* More xfailed tests.
* Update xfail reason.
* Fix whats-new
* Add test to ensure meta is passed on to dask.
* Use skipif instead of xfail.
1 parent 27a3929 commit 17b70ca
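
For orientation, a minimal sketch of how the new meta keyword might be used; the array and the doubling function below are illustrative, not part of this commit:

    import numpy as np
    import xarray as xr

    arr = xr.DataArray(np.random.rand(4, 6), dims=("x", "y")).chunk({"x": 2})

    # With dask="parallelized", meta tells dask which array type each block will hold,
    # so dask does not have to infer it by calling the function on a dummy block.
    doubled = xr.apply_ufunc(
        lambda block: block * 2,
        arr,
        dask="parallelized",
        output_dtypes=[arr.dtype],
        meta=np.empty((0, 0), dtype=arr.dtype),
    )

    # When vectorize=True and no meta is given, this commit defaults meta to np.ndarray.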

6 files changed: +71 additions, -6 deletions

ci/requirements/py36-min-all-deps.yml

+2 -2

@@ -15,8 +15,8 @@ dependencies:
   - cfgrib=0.9
   - cftime=1.0
   - coveralls
-  - dask=1.2
-  - distributed=1.27
+  - dask=2.2
+  - distributed=2.2
   - flake8
   - h5netcdf=0.7
   - h5py=2.9 # Policy allows for 2.10, but it's a conflict-fest

doc/whats-new.rst

+6 -1

@@ -21,6 +21,7 @@ v0.15.0 (unreleased)
 
 Breaking changes
 ~~~~~~~~~~~~~~~~
+- Bumped minimum ``dask`` version to 2.2.
 - Remove ``compat`` and ``encoding`` kwargs from ``DataArray``, which
   have been deprecated since 0.12. (:pull:`3650`).
   Instead, specify the encoding when writing to disk or set
@@ -50,6 +51,8 @@ New Features
 - Added the ``count`` reduction method to both :py:class:`~core.rolling.DatasetCoarsen`
   and :py:class:`~core.rolling.DataArrayCoarsen` objects. (:pull:`3500`)
   By `Deepak Cherian <https://github.com/dcherian>`_
+- Add ``meta`` kwarg to :py:func:`~xarray.apply_ufunc`; this is passed on to
+  :py:meth:`dask.array.blockwise`. (:pull:`3660`) By `Deepak Cherian <https://github.com/dcherian>`_.
 - Add `attrs_file` option in :py:func:`~xarray.open_mfdataset` to choose the
   source file for global attributes in a multi-file dataset (:issue:`2382`,
   :pull:`3498`) by `Julien Seguinot <https://github.com/juseg>_`.
@@ -63,7 +66,9 @@ New Features
 
 Bug fixes
 ~~~~~~~~~
-
+- Applying a user-defined function that adds new dimensions using :py:func:`apply_ufunc`
+  and ``vectorize=True`` now works with ``dask > 2.0``. (:issue:`3574`, :pull:`3660`).
+  By `Deepak Cherian <https://github.com/dcherian>`_.
 - Fix :py:meth:`xarray.combine_by_coords` to allow for combining incomplete
   hypercubes of Datasets (:issue:`3648`). By `Ian Bolliger
   <https://github.com/bolliger32>`_.
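
For the vectorize bug fix above, a minimal sketch of the previously failing pattern, essentially the regression test added in xarray/tests/test_computation.py further below: a function that adds a new output dimension, applied to a chunked array with vectorize=True and dask="parallelized":

    import numpy as np
    import xarray as xr

    data = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y")).chunk({"x": 1})

    # The wrapped function adds a leading "z" dimension; on dask > 2.0 this
    # used to fail when combined with vectorize=True and dask="parallelized".
    out = xr.apply_ufunc(
        lambda arr: arr[np.newaxis, ...],
        data,
        output_core_dims=[["z"]],
        vectorize=True,
        dask="parallelized",
        output_dtypes=[float],
        output_sizes={"z": 1},
    )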

xarray/core/computation.py

+22 -1

@@ -548,6 +548,7 @@ def apply_variable_ufunc(
     output_dtypes=None,
     output_sizes=None,
     keep_attrs=False,
+    meta=None,
 ):
     """Apply a ndarray level function over Variable and/or ndarray objects.
     """
@@ -590,6 +591,7 @@ def func(*arrays):
                     signature,
                     output_dtypes,
                     output_sizes,
+                    meta,
                 )
 
         elif dask == "allowed":
@@ -648,7 +650,14 @@ def func(*arrays):
 
 
 def _apply_blockwise(
-    func, args, input_dims, output_dims, signature, output_dtypes, output_sizes=None
+    func,
+    args,
+    input_dims,
+    output_dims,
+    signature,
+    output_dtypes,
+    output_sizes=None,
+    meta=None,
 ):
     import dask.array
 
@@ -720,6 +729,7 @@ def _apply_blockwise(
         dtype=dtype,
         concatenate=True,
         new_axes=output_sizes,
+        meta=meta,
     )
 
 
@@ -761,6 +771,7 @@ def apply_ufunc(
     dask: str = "forbidden",
     output_dtypes: Sequence = None,
     output_sizes: Mapping[Any, int] = None,
+    meta: Any = None,
 ) -> Any:
     """Apply a vectorized function for unlabeled arrays on xarray objects.
 
@@ -857,6 +868,9 @@ def apply_ufunc(
         Optional mapping from dimension names to sizes for outputs. Only used
         if dask='parallelized' and new dimensions (not found on inputs) appear
         on outputs.
+    meta : optional
+        Size-0 object representing the type of array wrapped by dask array. Passed on to
+        ``dask.array.blockwise``.
 
     Returns
     -------
@@ -990,6 +1004,11 @@ def earth_mover_distance(first_samples,
         func = functools.partial(func, **kwargs)
 
     if vectorize:
+        if meta is None:
+            # set meta=np.ndarray by default for numpy vectorized functions
+            # work around dask bug computing meta with vectorized functions: GH5642
+            meta = np.ndarray
+
         if signature.all_core_dims:
             func = np.vectorize(
                 func, otypes=output_dtypes, signature=signature.to_gufunc_string()
@@ -1006,6 +1025,7 @@ def earth_mover_distance(first_samples,
         dask=dask,
         output_dtypes=output_dtypes,
         output_sizes=output_sizes,
+        meta=meta,
     )
 
     if any(isinstance(a, GroupBy) for a in args):
@@ -1020,6 +1040,7 @@ def earth_mover_distance(first_samples,
             dataset_fill_value=dataset_fill_value,
             keep_attrs=keep_attrs,
             dask=dask,
+            meta=meta,
         )
         return apply_groupby_func(this_apply, *args)
     elif any(is_dict_like(a) for a in args):
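
As context for the new docstring entry: meta is a size-0 array whose type describes what each block of the dask array will contain, and apply_ufunc simply forwards it to dask.array.blockwise. A rough sketch of that underlying dask call (an illustration, assuming a dask version where blockwise accepts meta, as this commit's minimum of 2.2 does):

    import dask.array as da
    import numpy as np

    x = da.ones((4, 4), chunks=2)

    # The size-0 meta array is never computed; its type and dtype describe the blocks,
    # so dask can skip calling the function on a dummy block to infer the output type.
    out = da.blockwise(
        np.negative, "ij", x, "ij", dtype=x.dtype, meta=np.empty((0, 0), dtype=x.dtype)
    )
    print(type(out._meta))  # numpy.ndarray, taken from meta rather than inferred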

xarray/tests/test_backends.py

+10 -2

@@ -37,7 +37,7 @@
 from xarray.core import indexing
 from xarray.core.options import set_options
 from xarray.core.pycompat import dask_array_type
-from xarray.tests import mock
+from xarray.tests import LooseVersion, mock
 
 from . import (
     arm_xfail,
@@ -76,9 +76,14 @@
     pass
 
 try:
+    import dask
     import dask.array as da
+
+    dask_version = dask.__version__
 except ImportError:
-    pass
+    # needed for xfailed tests when dask < 2.4.0
+    # remove when min dask > 2.4.0
+    dask_version = "10.0"
 
 ON_WINDOWS = sys.platform == "win32"
 
@@ -1723,6 +1728,7 @@ def test_hidden_zarr_keys(self):
         with xr.decode_cf(store):
             pass
 
+    @pytest.mark.skipif(LooseVersion(dask_version) < "2.4", reason="dask GH5334")
     def test_write_persistence_modes(self):
         original = create_test_data()
 
@@ -1787,6 +1793,7 @@ def test_encoding_kwarg_fixed_width_string(self):
     def test_dataset_caching(self):
         super().test_dataset_caching()
 
+    @pytest.mark.skipif(LooseVersion(dask_version) < "2.4", reason="dask GH5334")
     def test_append_write(self):
         ds, ds_to_append, _ = create_append_test_data()
         with self.create_zarr_target() as store_target:
@@ -1863,6 +1870,7 @@ def test_check_encoding_is_consistent_after_append(self):
             xr.concat([ds, ds_to_append], dim="time"),
         )
 
+    @pytest.mark.skipif(LooseVersion(dask_version) < "2.4", reason="dask GH5334")
     def test_append_with_new_variable(self):
 
         ds, ds_to_append, ds_with_new_var = create_append_test_data()

xarray/tests/test_computation.py

+18 -0

@@ -817,6 +817,24 @@ def test_vectorize_dask():
     assert_identical(expected, actual)
 
 
+@requires_dask
+def test_vectorize_dask_new_output_dims():
+    # regression test for GH3574
+    data_array = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y"))
+    func = lambda x: x[np.newaxis, ...]
+    expected = data_array.expand_dims("z")
+    actual = apply_ufunc(
+        func,
+        data_array.chunk({"x": 1}),
+        output_core_dims=[["z"]],
+        vectorize=True,
+        dask="parallelized",
+        output_dtypes=[float],
+        output_sizes={"z": 1},
+    ).transpose(*expected.dims)
+    assert_identical(expected, actual)
+
+
 def test_output_wrong_number():
     variable = xr.Variable("x", np.arange(10))
 

xarray/tests/test_sparse.py

+13 -0

@@ -873,3 +873,16 @@ def test_dask_token():
     t5 = dask.base.tokenize(ac + 1)
     assert t4 != t5
     assert isinstance(ac.data._meta, sparse.COO)
+
+
+@requires_dask
+def test_apply_ufunc_meta_to_blockwise():
+    da = xr.DataArray(np.zeros((2, 3)), dims=["x", "y"]).chunk({"x": 2, "y": 1})
+    sparse_meta = sparse.COO.from_numpy(np.zeros((0, 0)))
+
+    # if dask computed meta, it would be np.ndarray
+    expected = xr.apply_ufunc(
+        lambda x: x, da, dask="parallelized", output_dtypes=[da.dtype], meta=sparse_meta
+    ).data._meta
+
+    assert_sparse_equal(expected, sparse_meta)
