Skip to content

Commit 3ca6a02

Browse files
committed
fix some tests + minor tweaks
1 parent f3116ac commit 3ca6a02

File tree

4 files changed

+59
-47
lines changed

4 files changed

+59
-47
lines changed

xarray/core/indexing.py

+23-18
Original file line numberDiff line numberDiff line change
@@ -127,17 +127,18 @@ def __init__(self, query_results: List[QueryResult]):
127127
+ "Suggestion: use a multi-index for each of those dimension(s)."
128128
)
129129

130-
self.dim_indexers = {
131-
k: v for res in query_results for k, v in res.dim_indexers.items()
132-
}
133-
self.indexes = {k: v for res in query_results for k, v in res.indexes.items()}
134-
self.index_vars = {
135-
k: v for res in query_results for k, v in res.index_vars.items()
136-
}
137-
self.drop_coords = [c for res in query_results for c in res.drop_coords]
138-
self.rename_dims = {
139-
k: v for res in query_results for k, v in res.rename_dims.items()
140-
}
130+
self.dim_indexers = {}
131+
self.indexes = {}
132+
self.index_vars = {}
133+
self.drop_coords = []
134+
self.rename_dims = {}
135+
136+
for res in query_results:
137+
self.dim_indexers.update(res.dim_indexers)
138+
self.indexes.update(res.indexes)
139+
self.index_vars.update(res.index_vars)
140+
self.drop_coords += res.drop_coords
141+
self.rename_dims.update(res.rename_dims)
141142

142143
def to_tuple(self):
143144
return (
@@ -191,15 +192,14 @@ def group_indexers_by_index(
191192
unique_indexes[index_id] = index
192193
label = maybe_cast_to_coords_dtype(label, coord.dtype) # type: ignore
193194
grouped_indexers[index_id][key] = label
194-
elif coord is not None:
195+
elif key in obj.coords:
195196
raise KeyError(f"no index found for coordinate {key}")
196197
elif key not in obj.dims:
197198
raise KeyError(f"{key} is not a valid dimension or coordinate")
198199
elif len(query_kwargs):
199200
raise ValueError(
200-
"cannot supply selection options "
201-
"when the indexed dimension does not have "
202-
"an associated coordinate."
201+
f"cannot supply selection options {query_kwargs!r} for dimension {key!r}"
202+
"that has no asssociated coordinate or index"
203203
)
204204
else:
205205
# key is a dimension without coordinate
@@ -212,14 +212,19 @@ def group_indexers_by_index(
212212
def remap_label_indexers(
213213
obj: Union["DataArray", "Dataset"],
214214
indexers: Mapping[Hashable, Any],
215-
**query_kwargs,
216-
# query_kwargs: Mapping[str, Any],
217-
# **indexers_kwargs,
215+
method=None,
216+
tolerance=None,
218217
) -> MergedQueryResults:
219218
"""Execute index queries from a DataArray / Dataset and label-based indexers
220219
and return the (merged) query results.
221220
222221
"""
222+
# TODO benbovy - flexible indexes: remove when custom index options are available
223+
if method is None and tolerance is None:
224+
query_kwargs = {}
225+
else:
226+
query_kwargs = {"method": method, "tolerance": tolerance}
227+
223228
# indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "map_index_queries")
224229
indexes, grouped_indexers = group_indexers_by_index(obj, indexers, query_kwargs)
225230

xarray/tests/test_dataset.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1680,7 +1680,7 @@ def test_sel_method(self):
16801680

16811681
with pytest.raises(TypeError, match=r"``method``"):
16821682
# this should not pass silently
1683-
data.sel(method=data)
1683+
data.sel(dim2=1, method=data)
16841684

16851685
# cannot pass method if there is no associated coordinate
16861686
with pytest.raises(ValueError, match=r"cannot supply"):

xarray/tests/test_indexes.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,11 @@ def test_query_datetime(self):
8686
pd.to_datetime(["2000-01-01", "2001-01-01", "2002-01-01"]), "x"
8787
)
8888
actual = index.query({"x": "2001-01-01"})
89-
expected = ({"x": 1}, None)
90-
assert actual == expected
89+
expected_dim_indexers = {"x": 1}
90+
assert actual.dim_indexers == expected_dim_indexers
9191

9292
actual = index.query({"x": index.to_pandas_index().to_numpy()[1]})
93-
assert actual == expected
93+
assert actual.dim_indexers == expected_dim_indexers
9494

9595
def test_query_unsorted_datetime_index_raises(self):
9696
index = PandasIndex(pd.to_datetime(["2001", "2000", "2002"]), "x")
@@ -191,11 +191,11 @@ def test_query(self):
191191
index = PandasMultiIndex(
192192
pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")), "x"
193193
)
194+
194195
# test tuples inside slice are considered as scalar indexer values
195-
assert index.query({"x": slice(("a", 1), ("b", 2))}) == (
196-
{"x": slice(0, 4)},
197-
None,
198-
)
196+
actual = index.query({"x": slice(("a", 1), ("b", 2))})
197+
expected_dim_indexers = {"x": slice(0, 4)}
198+
assert actual.dim_indexers == expected_dim_indexers
199199

200200
with pytest.raises(KeyError, match=r"not all values found"):
201201
index.query({"x": [0]})

xarray/tests/test_indexing.py

+28-21
Original file line numberDiff line numberDiff line change
@@ -64,21 +64,23 @@ def test_group_indexers_by_index(self):
6464
data.coords["y2"] = ("y", [2.0, 3.0])
6565

6666
indexes, grouped_indexers = indexing.group_indexers_by_index(
67-
data, {"z": 0, "one": "a", "two": 1, "y": 0}
67+
data, {"z": 0, "one": "a", "two": 1, "y": 0}, {}
6868
)
69-
assert indexes == {"x": data.xindexes["x"], "y": data.xindexes["y"]}
70-
assert grouped_indexers == {
71-
"x": {"one": "a", "two": 1},
72-
"y": {"y": 0},
73-
None: {"z": 0},
74-
}
69+
for k in indexes:
70+
if indexes[k].equals(data.xindexes["x"]):
71+
assert grouped_indexers[k] == {"one": "a", "two": 1}
72+
elif indexes[k].equals(data.xindexes["y"]):
73+
assert grouped_indexers[k] == {"y": 0}
74+
assert grouped_indexers[None] == {"z": 0}
75+
grouped_indexers.pop(None)
76+
assert indexes.keys() == grouped_indexers.keys()
7577

7678
with pytest.raises(KeyError, match=r"no index found for coordinate y2"):
77-
indexing.group_indexers_by_index(data, {"y2": 2.0})
79+
indexing.group_indexers_by_index(data, {"y2": 2.0}, {})
7880
with pytest.raises(KeyError, match=r"w is not a valid dimension or coordinate"):
79-
indexing.group_indexers_by_index(data, {"w": "a"})
81+
indexing.group_indexers_by_index(data, {"w": "a"}, {})
8082
with pytest.raises(ValueError, match=r"cannot supply.*"):
81-
indexing.group_indexers_by_index(data, {"z": 1}, method="nearest")
83+
indexing.group_indexers_by_index(data, {"z": 1}, {"method": "nearest"})
8284

8385
def test_remap_label_indexers(self):
8486
def test_indexer(
@@ -88,6 +90,7 @@ def test_indexer(
8890
expected_idx=None,
8991
expected_vars=None,
9092
expected_drop=None,
93+
expected_rename_dims=None,
9194
):
9295
if expected_vars is None:
9396
expected_vars = {}
@@ -97,22 +100,23 @@ def test_indexer(
97100
expected_idx = {k: expected_idx for k in expected_vars}
98101
if expected_drop is None:
99102
expected_drop = []
103+
if expected_rename_dims is None:
104+
expected_rename_dims = {}
100105

101-
pos, new_idx, new_vars, drop_vars = indexing.remap_label_indexers(
102-
data, {"x": x}
103-
)
106+
results = indexing.remap_label_indexers(data, {"x": x})
104107

105-
assert_array_equal(pos.get("x"), expected_pos)
108+
assert_array_equal(results.dim_indexers.get("x"), expected_pos)
106109

107-
assert new_idx.keys() == expected_idx.keys()
108-
for k in new_idx:
109-
assert new_idx[k].equals(expected_idx[k])
110+
assert results.indexes.keys() == expected_idx.keys()
111+
for k in results.indexes:
112+
assert results.indexes[k].equals(expected_idx[k])
110113

111-
assert new_vars.keys() == expected_vars.keys()
112-
for k in new_vars:
113-
assert_array_equal(new_vars[k], expected_vars[k])
114+
assert results.index_vars.keys() == expected_vars.keys()
115+
for k in results.index_vars:
116+
assert_array_equal(results.index_vars[k], expected_vars[k])
114117

115-
assert drop_vars == expected_drop
118+
assert set(results.drop_coords) == set(expected_drop)
119+
assert results.rename_dims == expected_rename_dims
116120

117121
data = Dataset({"x": ("x", [1, 2, 3])})
118122
mindex = pd.MultiIndex.from_product(
@@ -130,6 +134,7 @@ def test_indexer(
130134
[True, True, False, False, False, False, False, False],
131135
*PandasIndex.from_pandas_index(pd.Index([-1, -2]), "three"),
132136
["x", "one", "two"],
137+
{"x": "three"},
133138
)
134139
test_indexer(
135140
mdata,
@@ -161,13 +166,15 @@ def test_indexer(
161166
[True, True, False, False, False, False, False, False],
162167
*PandasIndex.from_pandas_index(pd.Index([-1, -2]), "three"),
163168
["x", "one", "two"],
169+
{"x": "three"},
164170
)
165171
test_indexer(
166172
mdata,
167173
{"one": "a", "three": -1},
168174
[True, False, True, False, False, False, False, False],
169175
*PandasIndex.from_pandas_index(pd.Index([1, 2]), "two"),
170176
["x", "one", "three"],
177+
{"x": "two"},
171178
)
172179
test_indexer(
173180
mdata,

0 commit comments

Comments
 (0)