Skip to content

Commit a551c7f

Browse files
committed
fix multi-index selection regression
See pydata#5691
1 parent 1bb61d9 commit a551c7f

File tree

2 files changed

+29
-7
lines changed

2 files changed

+29
-7
lines changed

xarray/core/indexes.py

+15-7
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,15 @@ def _is_nested_tuple(possible_tuple):
129129
)
130130

131131

132+
def normalize_label(value, extract_scalar=False):
133+
if getattr(value, "ndim", 1) <= 1:
134+
value = _asarray_tuplesafe(value)
135+
if extract_scalar:
136+
# see https://github.com/pydata/xarray/pull/4292 for details
137+
value = value[()] if value.dtype.kind in "mM" else value.item()
138+
return value
139+
140+
132141
def get_indexer_nd(index, labels, method=None, tolerance=None):
133142
"""Wrapper around :meth:`pandas.Index.get_indexer` supporting n-dimensional
134143
labels
@@ -207,14 +216,9 @@ def query(self, labels, method=None, tolerance=None):
207216
"a dimension that does not have a MultiIndex"
208217
)
209218
else:
210-
label = (
211-
label
212-
if getattr(label, "ndim", 1) > 1 # vectorized-indexing
213-
else _asarray_tuplesafe(label)
214-
)
219+
label = normalize_label(label)
215220
if label.ndim == 0:
216-
# see https://github.com/pydata/xarray/pull/4292 for details
217-
label_value = label[()] if label.dtype.kind in "mM" else label.item()
221+
label_value = normalize_label(label, extract_scalar=True)
218222
if isinstance(self.index, pd.CategoricalIndex):
219223
if method is not None:
220224
raise ValueError(
@@ -336,6 +340,10 @@ def query(self, labels, method=None, tolerance=None):
336340
# label(s) given for multi-index level(s)
337341
if all([lbl in self.index.names for lbl in labels]):
338342
is_nested_vals = _is_nested_tuple(tuple(labels.values()))
343+
labels = {
344+
k: normalize_label(v, extract_scalar=True) for k, v in labels.items()
345+
}
346+
339347
if len(labels) == self.index.nlevels and not is_nested_vals:
340348
indexer = self.index.get_loc(tuple(labels[k] for k in self.index.names))
341349
else:

xarray/tests/test_dataarray.py

+14
Original file line numberDiff line numberDiff line change
@@ -1006,6 +1006,20 @@ def test_sel_float(self):
10061006
assert_equal(expected_scalar, actual_scalar)
10071007
assert_equal(expected_16, actual_16)
10081008

1009+
def test_sel_float_multiindex(self):
1010+
# regression test https://github.com/pydata/xarray/issues/5691
1011+
midx = pd.MultiIndex.from_arrays(
1012+
[["a", "a", "b", "b"], [0.1, 0.2, 0.3, 0.4]], names=["lvl1", "lvl2"]
1013+
)
1014+
da = xr.DataArray([1, 2, 3, 4], coords={"x": midx}, dims="x")
1015+
1016+
actual = da.sel(lvl1="a", lvl2=0.1)
1017+
expected = da.isel(x=0)
1018+
1019+
assert_equal(actual, expected)
1020+
1021+
# TODO: test multi-index created from coordinates, one with dtype=float32
1022+
10091023
def test_sel_no_index(self):
10101024
array = DataArray(np.arange(10), dims="x")
10111025
assert_identical(array[0], array.sel(x=0))

0 commit comments

Comments
 (0)