Skip to content

Commit 57cd76d

Browse files
TomNicholasshoyer
authored andcommitted
Bugfix/reduce no axis (#2769)
* New test for reduce func which takes no axes * Fixed axis logic * Recorded fix in what's new * Added intermediate variable
1 parent cd8e370 commit 57cd76d

File tree

3 files changed

+28
-4
lines changed

3 files changed

+28
-4
lines changed

doc/whats-new.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ Bug fixes
9494
- Masking data arrays with :py:meth:`xarray.DataArray.where` now returns an
9595
array with the name of the original masked array (:issue:`2748` and :issue:`2457`).
9696
By `Yohai Bar-Sinai <https://github.com/yohai>`_.
97+
- Fixed error when trying to reduce a DataArray using a function which does not
98+
require an axis argument. (:issue:`2768`)
99+
By `Tom Nicholas <http://github.com/TomNicholas>`_.
100+
97101
- Per `CF conventions
98102
<http://cfconventions.org/cf-conventions/cf-conventions.html#calendar>`_,
99103
specifying ``'standard'`` as the calendar type in

xarray/core/variable.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1361,8 +1361,11 @@ def reduce(self, func, dim=None, axis=None,
13611361

13621362
if dim is not None:
13631363
axis = self.get_axis_num(dim)
1364-
data = func(self.data if allow_lazy else self.values,
1365-
axis=axis, **kwargs)
1364+
input_data = self.data if allow_lazy else self.values
1365+
if axis is not None:
1366+
data = func(input_data, axis=axis, **kwargs)
1367+
else:
1368+
data = func(input_data, **kwargs)
13661369

13671370
if getattr(data, 'shape', ()) == self.shape:
13681371
dims = self.dims

xarray/tests/test_dataset.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3633,11 +3633,28 @@ def mean_only_one_axis(x, axis):
36333633
actual = ds.reduce(mean_only_one_axis, 'y')
36343634
assert_identical(expected, actual)
36353635

3636-
with raises_regex(TypeError, 'non-integer axis'):
3636+
with raises_regex(TypeError, "missing 1 required positional argument: "
3637+
"'axis'"):
36373638
ds.reduce(mean_only_one_axis)
36383639

36393640
with raises_regex(TypeError, 'non-integer axis'):
3640-
ds.reduce(mean_only_one_axis, ['x', 'y'])
3641+
ds.reduce(mean_only_one_axis, axis=['x', 'y'])
3642+
3643+
def test_reduce_no_axis(self):
3644+
3645+
def total_sum(x):
3646+
return np.sum(x.flatten())
3647+
3648+
ds = Dataset({'a': (['x', 'y'], [[0, 1, 2, 3, 4]])})
3649+
expected = Dataset({'a': ((), 10)})
3650+
actual = ds.reduce(total_sum)
3651+
assert_identical(expected, actual)
3652+
3653+
with raises_regex(TypeError, "unexpected keyword argument 'axis'"):
3654+
ds.reduce(total_sum, axis=0)
3655+
3656+
with raises_regex(TypeError, "unexpected keyword argument 'axis'"):
3657+
ds.reduce(total_sum, dim='x')
36413658

36423659
def test_quantile(self):
36433660

0 commit comments

Comments
 (0)