Skip to content

Commit d27bff4

Browse files
committed
Add suggested changes to make apply_ufunc keep attrs
1 parent 8974453 commit d27bff4

File tree

2 files changed

+6
-10
lines changed

2 files changed

+6
-10
lines changed

xarray/core/computation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,7 @@ def apply_dict_of_variables_vfunc(
433433
join="inner",
434434
fill_value=None,
435435
on_missing_core_dim: MissingCoreDimOptions = "raise",
436+
keep_attrs="override",
436437
):
437438
"""Apply a variable level function over dicts of DataArray, DataArray,
438439
Variable and ndarray objects.
@@ -445,7 +446,7 @@ def apply_dict_of_variables_vfunc(
445446
for name, variable_args in zip(names, grouped_by_name):
446447
core_dim_present = _check_core_dims(signature, variable_args, name)
447448
if core_dim_present is True:
448-
result_vars[name] = func(*variable_args)
449+
result_vars[name] = func(*variable_args, keep_attrs=keep_attrs)
449450
else:
450451
if on_missing_core_dim == "raise":
451452
raise ValueError(core_dim_present)
@@ -522,6 +523,7 @@ def apply_dataset_vfunc(
522523
join=dataset_join,
523524
fill_value=fill_value,
524525
on_missing_core_dim=on_missing_core_dim,
526+
keep_attrs=keep_attrs,
525527
)
526528

527529
out: Dataset | tuple[Dataset, ...]

xarray/tests/test_rolling.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -772,18 +772,13 @@ def test_rolling_exp_keep_attrs(self, ds) -> None:
772772
# discard attrs
773773
result = ds.rolling_exp(time=10).mean(keep_attrs=False)
774774
assert result.attrs == {}
775-
# TODO: from #8114 — this arguably should be empty, but `apply_ufunc` doesn't do
776-
# that at the moment. We should change in `apply_func` rather than
777-
# special-case it here.
778-
#
779-
# assert result.z1.attrs == {}
775+
assert result.z1.attrs == {}
780776

781777
# test discard attrs using global option
782778
with set_options(keep_attrs=False):
783779
result = ds.rolling_exp(time=10).mean()
784780
assert result.attrs == {}
785-
# See above
786-
# assert result.z1.attrs == {}
781+
assert result.z1.attrs == {}
787782

788783
# keyword takes precedence over global option
789784
with set_options(keep_attrs=False):
@@ -794,8 +789,7 @@ def test_rolling_exp_keep_attrs(self, ds) -> None:
794789
with set_options(keep_attrs=True):
795790
result = ds.rolling_exp(time=10).mean(keep_attrs=False)
796791
assert result.attrs == {}
797-
# See above
798-
# assert result.z1.attrs == {}
792+
assert result.z1.attrs == {}
799793

800794
with pytest.warns(
801795
UserWarning,

0 commit comments

Comments
 (0)