Skip to content
forked from pydata/xarray

Commit 6958172

Browse files
committed
Propagate indexes in DataArray binary operations.
Works by propagating indexes in DataArray._replace. xref pydata#2227. Tests pass!
1 parent 53c5199 commit 6958172

File tree

5 files changed

+30
-3
lines changed

5 files changed

+30
-3
lines changed

xarray/core/dataarray.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -386,14 +386,17 @@ def _replace(
386386
variable: Variable = None,
387387
coords=None,
388388
name: Union[Hashable, None, Default] = _default,
389+
indexes=None,
389390
) -> "DataArray":
390391
if variable is None:
391392
variable = self.variable
392393
if coords is None:
393394
coords = self._coords
394395
if name is _default:
395396
name = self.name
396-
return type(self)(variable, coords, name=name, fastpath=True)
397+
# if indexes is None:
398+
# indexes = self.indexes
399+
return type(self)(variable, coords, name=name, fastpath=True, indexes=indexes)
397400

398401
def _replace_maybe_drop_dims(
399402
self, variable: Variable, name: Union[Hashable, None, Default] = _default
@@ -440,7 +443,8 @@ def _from_temp_dataset(
440443
) -> "DataArray":
441444
variable = dataset._variables.pop(_THIS_ARRAY)
442445
coords = dataset._variables
443-
return self._replace(variable, coords, name)
446+
indexes = dataset._indexes
447+
return self._replace(variable, coords, name, indexes=indexes)
444448

445449
def _to_dataset_split(self, dim: Hashable) -> Dataset:
446450
def subset(dim, label):
@@ -2506,7 +2510,7 @@ def func(self, other):
25062510
coords, indexes = self.coords._merge_raw(other_coords)
25072511
name = self._result_name(other)
25082512

2509-
return self._replace(variable, coords, name)
2513+
return self._replace(variable, coords, name, indexes=indexes)
25102514

25112515
return func
25122516

xarray/core/groupby.py

+1
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,7 @@ def _maybe_unstack(self, obj):
529529
for dim in self._inserted_dims:
530530
if dim in obj.coords:
531531
del obj.coords[dim]
532+
del obj.indexes[dim]
532533
return obj
533534

534535
def fillna(self, value):

xarray/core/indexes.py

+3
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ def __contains__(self, key):
3535
def __getitem__(self, key):
3636
return self._indexes[key]
3737

38+
def __delitem__(self, key):
39+
del self._indexes[key]
40+
3841
def __repr__(self):
3942
return formatting.indexes_repr(self)
4043

xarray/tests/test_dataarray.py

+11
Original file line numberDiff line numberDiff line change
@@ -3953,6 +3953,17 @@ def test_matmul(self):
39533953
expected = da.dot(da)
39543954
assert_identical(result, expected)
39553955

3956+
def test_binary_op_propagate_indexes(self):
3957+
# regression test for GH2227
3958+
self.dv["x"] = np.arange(self.dv.sizes["x"])
3959+
expected = self.dv.indexes["x"]
3960+
3961+
actual = (self.dv * 10).indexes["x"]
3962+
assert expected is actual
3963+
3964+
actual = (self.dv > 10).indexes["x"]
3965+
assert expected is actual
3966+
39563967
def test_binary_op_join_setting(self):
39573968
dim = "x"
39583969
align_type = "outer"

xarray/tests/test_dataset.py

+8
Original file line numberDiff line numberDiff line change
@@ -4951,6 +4951,14 @@ def test_filter_by_attrs(self):
49514951
)
49524952
assert not bool(new_ds.data_vars)
49534953

4954+
def test_binary_op_propagate_indexes(self):
4955+
ds = Dataset(
4956+
{"d1": DataArray([1, 2, 3], dims=["x"], coords={"x": [10, 20, 30]})}
4957+
)
4958+
expected = ds.indexes["x"]
4959+
actual = (ds * 2).indexes["x"]
4960+
assert expected is actual
4961+
49544962
def test_binary_op_join_setting(self):
49554963
# arithmetic_join applies to data array coordinates
49564964
missing_2 = xr.Dataset({"x": [0, 1]})

0 commit comments

Comments
 (0)