Skip to content

Commit 41b1b8c

Browse files
authored
Allow rank to run on dask arrays (#8475)
1 parent cb14f2f commit 41b1b8c

File tree

2 files changed

+28
-19
lines changed

2 files changed

+28
-19
lines changed

xarray/core/variable.py

+12-15
Original file line numberDiff line numberDiff line change
@@ -2063,6 +2063,7 @@ def rank(self, dim, pct=False):
20632063
--------
20642064
Dataset.rank, DataArray.rank
20652065
"""
2066+
# This could / should arguably be implemented at the DataArray & Dataset level
20662067
if not OPTIONS["use_bottleneck"]:
20672068
raise RuntimeError(
20682069
"rank requires bottleneck to be enabled."
@@ -2071,24 +2072,20 @@ def rank(self, dim, pct=False):
20712072

20722073
import bottleneck as bn
20732074

2074-
data = self.data
2075-
2076-
if is_duck_dask_array(data):
2077-
raise TypeError(
2078-
"rank does not work for arrays stored as dask "
2079-
"arrays. Load the data via .compute() or .load() "
2080-
"prior to calling this method."
2081-
)
2082-
elif not isinstance(data, np.ndarray):
2083-
raise TypeError(f"rank is not implemented for {type(data)} objects.")
2084-
2085-
axis = self.get_axis_num(dim)
20862075
func = bn.nanrankdata if self.dtype.kind == "f" else bn.rankdata
2087-
ranked = func(data, axis=axis)
2076+
ranked = xr.apply_ufunc(
2077+
func,
2078+
self,
2079+
input_core_dims=[[dim]],
2080+
output_core_dims=[[dim]],
2081+
dask="parallelized",
2082+
kwargs=dict(axis=-1),
2083+
).transpose(*self.dims)
2084+
20882085
if pct:
2089-
count = np.sum(~np.isnan(data), axis=axis, keepdims=True)
2086+
count = self.notnull().sum(dim)
20902087
ranked /= count
2091-
return Variable(self.dims, ranked)
2088+
return ranked
20922089

20932090
def rolling_window(
20942091
self, dim, window, window_dim, center=False, fill_value=dtypes.NA

xarray/tests/test_variable.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -1878,9 +1878,20 @@ def test_quantile_out_of_bounds(self, q):
18781878

18791879
@requires_dask
18801880
@requires_bottleneck
1881-
def test_rank_dask_raises(self):
1882-
v = Variable(["x"], [3.0, 1.0, np.nan, 2.0, 4.0]).chunk(2)
1883-
with pytest.raises(TypeError, match=r"arrays stored as dask"):
1881+
def test_rank_dask(self):
1882+
# Instead of a single test here, we could parameterize the other tests for both
1883+
# numpy and dask arrays. But this is sufficient.
1884+
v = Variable(
1885+
["x", "y"], [[30.0, 1.0, np.nan, 20.0, 4.0], [30.0, 1.0, np.nan, 20.0, 4.0]]
1886+
).chunk(x=1)
1887+
expected = Variable(
1888+
["x", "y"], [[4.0, 1.0, np.nan, 3.0, 2.0], [4.0, 1.0, np.nan, 3.0, 2.0]]
1889+
)
1890+
assert_equal(v.rank("y").compute(), expected)
1891+
1892+
with pytest.raises(
1893+
ValueError, match=r" with dask='parallelized' consists of multiple chunks"
1894+
):
18841895
v.rank("x")
18851896

18861897
def test_rank_use_bottleneck(self):
@@ -1912,7 +1923,8 @@ def test_rank(self):
19121923
v_expect = Variable(["x"], [0.75, 0.25, np.nan, 0.5, 1.0])
19131924
assert_equal(v.rank("x", pct=True), v_expect)
19141925
# invalid dim
1915-
with pytest.raises(ValueError, match=r"not found"):
1926+
with pytest.raises(ValueError):
1927+
# apply_ufunc error message isn't great here — `ValueError: tuple.index(x): x not in tuple`
19161928
v.rank("y")
19171929

19181930
def test_big_endian_reduce(self):

0 commit comments

Comments
 (0)