Skip to content

Commit 9d1499e

Browse files
authored
misc. fixes for Indexes with pd.Index objects (#7003)
1 parent 1f4be33 commit 9d1499e

File tree

2 files changed

+44
-9
lines changed

2 files changed

+44
-9
lines changed

xarray/core/indexes.py

+19-3
Original file line numberDiff line numberDiff line change
@@ -1092,12 +1092,13 @@ def get_unique(self) -> list[T_PandasOrXarrayIndex]:
10921092
"""Return a list of unique indexes, preserving order."""
10931093

10941094
unique_indexes: list[T_PandasOrXarrayIndex] = []
1095-
seen: set[T_PandasOrXarrayIndex] = set()
1095+
seen: set[int] = set()
10961096

10971097
for index in self._indexes.values():
1098-
if index not in seen:
1098+
index_id = id(index)
1099+
if index_id not in seen:
10991100
unique_indexes.append(index)
1100-
seen.add(index)
1101+
seen.add(index_id)
11011102

11021103
return unique_indexes
11031104

@@ -1201,9 +1202,24 @@ def copy_indexes(
12011202
"""
12021203
new_indexes = {}
12031204
new_index_vars = {}
1205+
12041206
for idx, coords in self.group_by_index():
1207+
if isinstance(idx, pd.Index):
1208+
convert_new_idx = True
1209+
dim = next(iter(coords.values())).dims[0]
1210+
if isinstance(idx, pd.MultiIndex):
1211+
idx = PandasMultiIndex(idx, dim)
1212+
else:
1213+
idx = PandasIndex(idx, dim)
1214+
else:
1215+
convert_new_idx = False
1216+
12051217
new_idx = idx.copy(deep=deep)
12061218
idx_vars = idx.create_variables(coords)
1219+
1220+
if convert_new_idx:
1221+
new_idx = cast(PandasIndex, new_idx).index
1222+
12071223
new_indexes.update({k: new_idx for k in coords})
12081224
new_index_vars.update(idx_vars)
12091225

xarray/tests/test_indexes.py

+25-6
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import xarray as xr
1111
from xarray.core.indexes import (
12+
Hashable,
1213
Index,
1314
Indexes,
1415
PandasIndex,
@@ -535,18 +536,37 @@ def test_copy(self) -> None:
535536

536537
class TestIndexes:
537538
@pytest.fixture
538-
def unique_indexes(self) -> list[PandasIndex]:
539+
def indexes_and_vars(self) -> tuple[list[PandasIndex], dict[Hashable, Variable]]:
539540
x_idx = PandasIndex(pd.Index([1, 2, 3], name="x"), "x")
540541
y_idx = PandasIndex(pd.Index([4, 5, 6], name="y"), "y")
541542
z_pd_midx = pd.MultiIndex.from_product(
542543
[["a", "b"], [1, 2]], names=["one", "two"]
543544
)
544545
z_midx = PandasMultiIndex(z_pd_midx, "z")
545546

546-
return [x_idx, y_idx, z_midx]
547+
indexes = [x_idx, y_idx, z_midx]
548+
549+
variables = {}
550+
for idx in indexes:
551+
variables.update(idx.create_variables())
552+
553+
return indexes, variables
554+
555+
@pytest.fixture(params=["pd_index", "xr_index"])
556+
def unique_indexes(
557+
self, request, indexes_and_vars
558+
) -> list[PandasIndex] | list[pd.Index]:
559+
xr_indexes, _ = indexes_and_vars
560+
561+
if request.param == "pd_index":
562+
return [idx.index for idx in xr_indexes]
563+
else:
564+
return xr_indexes
547565

548566
@pytest.fixture
549-
def indexes(self, unique_indexes) -> Indexes[Index]:
567+
def indexes(
568+
self, unique_indexes, indexes_and_vars
569+
) -> Indexes[Index] | Indexes[pd.Index]:
550570
x_idx, y_idx, z_midx = unique_indexes
551571
indexes: dict[Any, Index] = {
552572
"x": x_idx,
@@ -555,9 +575,8 @@ def indexes(self, unique_indexes) -> Indexes[Index]:
555575
"one": z_midx,
556576
"two": z_midx,
557577
}
558-
variables: dict[Any, Variable] = {}
559-
for idx in unique_indexes:
560-
variables.update(idx.create_variables())
578+
579+
_, variables = indexes_and_vars
561580

562581
return Indexes(indexes, variables)
563582

0 commit comments

Comments
 (0)