Skip to content

Commit 6c6f09e

Browse files
committed
update test_map_index_queries
1 parent e750883 commit 6c6f09e

File tree

1 file changed

+91
-52
lines changed

1 file changed

+91
-52
lines changed

xarray/tests/test_indexing.py

+91-52
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import itertools
2+
from typing import Any, Dict, cast
23

34
import numpy as np
45
import pandas as pd
@@ -7,6 +8,7 @@
78
from xarray import DataArray, Dataset, Variable
89
from xarray.core import indexing, nputils
910
from xarray.core.indexes import PandasIndex, PandasMultiIndex
11+
from xarray.core.types import T_Xarray
1012

1113
from . import IndexerMaker, ReturnItem, assert_array_equal
1214

@@ -86,109 +88,146 @@ def test_group_indexers_by_index(self) -> None:
8688
indexing.group_indexers_by_index(data, {"z": 1}, {"method": "nearest"})
8789

8890
def test_map_index_queries(self) -> None:
91+
def create_query_results(
92+
x_indexer,
93+
x_index,
94+
index_vars,
95+
other_vars,
96+
drop_coords,
97+
drop_indexes,
98+
rename_dims,
99+
):
100+
dim_indexers = {"x": x_indexer}
101+
indexes = {k: x_index for k in index_vars}
102+
variables = {}
103+
variables.update(index_vars)
104+
variables.update(other_vars)
105+
106+
return indexing.QueryResult(
107+
dim_indexers=dim_indexers,
108+
indexes=indexes,
109+
variables=variables,
110+
drop_coords=drop_coords,
111+
drop_indexes=drop_indexes,
112+
rename_dims=rename_dims,
113+
)
114+
89115
def test_indexer(
90-
data,
91-
x,
92-
expected_pos,
93-
expected_idx=None,
94-
expected_vars=None,
95-
expected_drop=None,
96-
expected_rename_dims=None,
116+
data: T_Xarray,
117+
x: Any,
118+
expected: indexing.QueryResult,
97119
) -> None:
98-
if expected_vars is None:
99-
expected_vars = {}
100-
if expected_idx is None:
101-
expected_idx = {}
102-
else:
103-
expected_idx = {k: expected_idx for k in expected_vars}
104-
if expected_drop is None:
105-
expected_drop = []
106-
if expected_rename_dims is None:
107-
expected_rename_dims = {}
108-
109120
results = indexing.map_index_queries(data, {"x": x})
110121

111-
assert_array_equal(results.dim_indexers.get("x"), expected_pos)
122+
assert results.dim_indexers.keys() == expected.dim_indexers.keys()
123+
assert_array_equal(results.dim_indexers["x"], expected.dim_indexers["x"])
112124

113-
assert results.indexes.keys() == expected_idx.keys()
125+
assert results.indexes.keys() == expected.indexes.keys()
114126
for k in results.indexes:
115-
assert results.indexes[k].equals(expected_idx[k])
127+
assert results.indexes[k].equals(expected.indexes[k])
116128

117-
assert results.variables.keys() == expected_vars.keys()
129+
assert results.variables.keys() == expected.variables.keys()
118130
for k in results.variables:
119-
assert_array_equal(results.variables[k], expected_vars[k])
131+
assert_array_equal(results.variables[k], expected.variables[k])
120132

121-
assert set(results.drop_coords) == set(expected_drop)
122-
assert results.rename_dims == expected_rename_dims
133+
assert set(results.drop_coords) == set(expected.drop_coords)
134+
assert set(results.drop_indexes) == set(expected.drop_indexes)
135+
assert results.rename_dims == expected.rename_dims
123136

124137
data = Dataset({"x": ("x", [1, 2, 3])})
125138
mindex = pd.MultiIndex.from_product(
126139
[["a", "b"], [1, 2], [-1, -2]], names=("one", "two", "three")
127140
)
128141
mdata = DataArray(range(8), [("x", mindex)])
129142

130-
test_indexer(data, 1, 0)
131-
test_indexer(data, np.int32(1), 0)
132-
test_indexer(data, Variable([], 1), 0)
133-
test_indexer(mdata, ("a", 1, -1), 0)
134-
test_indexer(
135-
mdata,
136-
("a", 1),
143+
test_indexer(data, 1, indexing.QueryResult({"x": 0}))
144+
test_indexer(data, np.int32(1), indexing.QueryResult({"x": 0}))
145+
test_indexer(data, Variable([], 1), indexing.QueryResult({"x": 0}))
146+
test_indexer(mdata, ("a", 1, -1), indexing.QueryResult({"x": 0}))
147+
148+
expected = create_query_results(
137149
[True, True, False, False, False, False, False, False],
138150
*PandasIndex.from_pandas_index(pd.Index([-1, -2]), "three"),
139-
["x", "one", "two"],
151+
{"one": Variable((), "a"), "two": Variable((), 1)},
152+
["x"],
153+
["one", "two"],
140154
{"x": "three"},
141155
)
142-
test_indexer(
143-
mdata,
144-
"a",
156+
test_indexer(mdata, ("a", 1), expected)
157+
158+
expected = create_query_results(
145159
slice(0, 4, None),
146160
*PandasMultiIndex.from_pandas_index(
147161
pd.MultiIndex.from_product([[1, 2], [-1, -2]], names=("two", "three")),
148162
"x",
149163
),
164+
{"one": Variable((), "a")},
165+
[],
150166
["one"],
167+
{},
151168
)
152-
test_indexer(
153-
mdata,
154-
("a",),
169+
test_indexer(mdata, "a", expected)
170+
171+
expected = create_query_results(
155172
[True, True, True, True, False, False, False, False],
156173
*PandasMultiIndex.from_pandas_index(
157174
pd.MultiIndex.from_product([[1, 2], [-1, -2]], names=("two", "three")),
158175
"x",
159176
),
177+
{"one": Variable((), "a")},
178+
[],
160179
["one"],
180+
{},
181+
)
182+
test_indexer(mdata, ("a",), expected)
183+
184+
test_indexer(
185+
mdata, [("a", 1, -1), ("b", 2, -2)], indexing.QueryResult({"x": [0, 7]})
186+
)
187+
test_indexer(
188+
mdata, slice("a", "b"), indexing.QueryResult({"x": slice(0, 8, None)})
161189
)
162-
test_indexer(mdata, [("a", 1, -1), ("b", 2, -2)], [0, 7])
163-
test_indexer(mdata, slice("a", "b"), slice(0, 8, None))
164-
test_indexer(mdata, slice(("a", 1), ("b", 1)), slice(0, 6, None))
165-
test_indexer(mdata, {"one": "a", "two": 1, "three": -1}, 0)
166190
test_indexer(
167191
mdata,
168-
{"one": "a", "two": 1},
192+
slice(("a", 1), ("b", 1)),
193+
indexing.QueryResult({"x": slice(0, 6, None)}),
194+
)
195+
test_indexer(
196+
mdata, {"one": "a", "two": 1, "three": -1}, indexing.QueryResult({"x": 0})
197+
)
198+
199+
expected = create_query_results(
169200
[True, True, False, False, False, False, False, False],
170201
*PandasIndex.from_pandas_index(pd.Index([-1, -2]), "three"),
171-
["x", "one", "two"],
202+
{"one": Variable((), "a"), "two": Variable((), 1)},
203+
["x"],
204+
["one", "two"],
172205
{"x": "three"},
173206
)
174-
test_indexer(
175-
mdata,
176-
{"one": "a", "three": -1},
207+
test_indexer(mdata, {"one": "a", "two": 1}, expected)
208+
209+
expected = create_query_results(
177210
[True, False, True, False, False, False, False, False],
178211
*PandasIndex.from_pandas_index(pd.Index([1, 2]), "two"),
179-
["x", "one", "three"],
212+
{"one": Variable((), "a"), "three": Variable((), -1)},
213+
["x"],
214+
["one", "three"],
180215
{"x": "two"},
181216
)
182-
test_indexer(
183-
mdata,
184-
{"one": "a"},
217+
test_indexer(mdata, {"one": "a", "three": -1}, expected)
218+
219+
expected = create_query_results(
185220
[True, True, True, True, False, False, False, False],
186221
*PandasMultiIndex.from_pandas_index(
187222
pd.MultiIndex.from_product([[1, 2], [-1, -2]], names=("two", "three")),
188223
"x",
189224
),
225+
{"one": Variable((), "a")},
226+
[],
190227
["one"],
228+
{},
191229
)
230+
test_indexer(mdata, {"one": "a"}, expected)
192231

193232
def test_read_only_view(self) -> None:
194233

0 commit comments

Comments
 (0)