Skip to content

Commit b649846

Browse files
authored
Propagate indexes in DataArray binary operations. (#3481)
* Propagate indexes in DataArray binary operations. Works by propagating indexes in DataArray._replace. xref #2227. Tests pass! * remove commented code. * fix roll
1 parent 46c4931 commit b649846

File tree

6 files changed

+30
-3
lines changed

6 files changed

+30
-3
lines changed

xarray/core/dataarray.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -386,14 +386,15 @@ def _replace(
386386
variable: Variable = None,
387387
coords=None,
388388
name: Union[Hashable, None, Default] = _default,
389+
indexes=None,
389390
) -> "DataArray":
390391
if variable is None:
391392
variable = self.variable
392393
if coords is None:
393394
coords = self._coords
394395
if name is _default:
395396
name = self.name
396-
return type(self)(variable, coords, name=name, fastpath=True)
397+
return type(self)(variable, coords, name=name, fastpath=True, indexes=indexes)
397398

398399
def _replace_maybe_drop_dims(
399400
self, variable: Variable, name: Union[Hashable, None, Default] = _default
@@ -440,7 +441,8 @@ def _from_temp_dataset(
440441
) -> "DataArray":
441442
variable = dataset._variables.pop(_THIS_ARRAY)
442443
coords = dataset._variables
443-
return self._replace(variable, coords, name)
444+
indexes = dataset._indexes
445+
return self._replace(variable, coords, name, indexes=indexes)
444446

445447
def _to_dataset_split(self, dim: Hashable) -> Dataset:
446448
def subset(dim, label):
@@ -2506,7 +2508,7 @@ def func(self, other):
25062508
coords, indexes = self.coords._merge_raw(other_coords)
25072509
name = self._result_name(other)
25082510

2509-
return self._replace(variable, coords, name)
2511+
return self._replace(variable, coords, name, indexes=indexes)
25102512

25112513
return func
25122514

xarray/core/dataset.py

+2
Original file line numberDiff line numberDiff line change
@@ -4891,6 +4891,8 @@ def roll(self, shifts=None, roll_coords=None, **shifts_kwargs):
48914891
(dim,) = self.variables[k].dims
48924892
if dim in shifts:
48934893
indexes[k] = roll_index(v, shifts[dim])
4894+
else:
4895+
indexes[k] = v
48944896
else:
48954897
indexes = dict(self.indexes)
48964898

xarray/core/groupby.py

+1
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,7 @@ def _maybe_unstack(self, obj):
529529
for dim in self._inserted_dims:
530530
if dim in obj.coords:
531531
del obj.coords[dim]
532+
del obj.indexes[dim]
532533
return obj
533534

534535
def fillna(self, value):

xarray/core/indexes.py

+3
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ def __contains__(self, key):
3535
def __getitem__(self, key):
3636
return self._indexes[key]
3737

38+
def __delitem__(self, key):
39+
del self._indexes[key]
40+
3841
def __repr__(self):
3942
return formatting.indexes_repr(self)
4043

xarray/tests/test_dataarray.py

+11
Original file line numberDiff line numberDiff line change
@@ -3953,6 +3953,17 @@ def test_matmul(self):
39533953
expected = da.dot(da)
39543954
assert_identical(result, expected)
39553955

3956+
def test_binary_op_propagate_indexes(self):
3957+
# regression test for GH2227
3958+
self.dv["x"] = np.arange(self.dv.sizes["x"])
3959+
expected = self.dv.indexes["x"]
3960+
3961+
actual = (self.dv * 10).indexes["x"]
3962+
assert expected is actual
3963+
3964+
actual = (self.dv > 10).indexes["x"]
3965+
assert expected is actual
3966+
39563967
def test_binary_op_join_setting(self):
39573968
dim = "x"
39583969
align_type = "outer"

xarray/tests/test_dataset.py

+8
Original file line numberDiff line numberDiff line change
@@ -4951,6 +4951,14 @@ def test_filter_by_attrs(self):
49514951
)
49524952
assert not bool(new_ds.data_vars)
49534953

4954+
def test_binary_op_propagate_indexes(self):
4955+
ds = Dataset(
4956+
{"d1": DataArray([1, 2, 3], dims=["x"], coords={"x": [10, 20, 30]})}
4957+
)
4958+
expected = ds.indexes["x"]
4959+
actual = (ds * 2).indexes["x"]
4960+
assert expected is actual
4961+
49544962
def test_binary_op_join_setting(self):
49554963
# arithmetic_join applies to data array coordinates
49564964
missing_2 = xr.Dataset({"x": [0, 1]})

0 commit comments

Comments
 (0)