Skip to content

Commit f3116ac

Browse files
committed
wip: deeper refactoring of label-based selection (sel)
Created QueryResult and MergedQueryResults classes for convenience.
1 parent 64b71c9 commit f3116ac

File tree

6 files changed

+224
-120
lines changed

6 files changed

+224
-120
lines changed

xarray/core/alignment.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ def _override_indexes(objects, all_indexes, exclude):
6969
dim: all_indexes[dim][0] for dim in obj.xindexes if dim not in exclude
7070
}
7171

72-
# TODO: benbovy - explicit indexes: not refactored yet (dirty fix)
73-
objects[idx + 1] = obj._overwrite_indexes(new_indexes, {}, [])
72+
# TODO: benbovy - explicit indexes: not refactored yet!
73+
objects[idx + 1] = obj._overwrite_indexes(new_indexes)
7474

7575
return objects
7676

xarray/core/coordinates.py

+8-12
Original file line numberDiff line numberDiff line change
@@ -398,9 +398,7 @@ def remap_label_indexers(
398398
method: str = None,
399399
tolerance=None,
400400
**indexers_kwargs: Any,
401-
) -> Tuple[
402-
dict, dict, dict, list
403-
]: # TODO more precise return type after annotations in indexing
401+
) -> Any:
404402
"""Remap indexers from obj.coords.
405403
If an indexer is an instance of DataArray and it has a coordinate, then this coordinate
406404
will be attached to pos_indexers.
@@ -421,23 +419,21 @@ def remap_label_indexers(
421419
for k, v in indexers.items()
422420
}
423421

424-
(
425-
pos_indexers,
426-
new_indexes,
427-
new_variables,
428-
drop_variables,
429-
) = indexing.remap_label_indexers(
422+
query_results = indexing.remap_label_indexers(
430423
obj, v_indexers, method=method, tolerance=tolerance
431424
)
432425

433426
# attach indexer's coordinate to pos_indexers
434427
for k, v in indexers.items():
428+
dim_indexer = query_results.dim_indexers.get(k, None)
435429
if isinstance(v, Variable):
436-
pos_indexers[k] = Variable(v.dims, pos_indexers[k])
430+
query_results.dim_indexers[k] = Variable(v.dims, dim_indexer)
437431
elif isinstance(v, DataArray):
438432
# drop coordinates found in indexers since .sel() already
439433
# ensures alignments
440434
coords = {k: var for k, var in v._coords.items() if k not in indexers}
441-
pos_indexers[k] = DataArray(pos_indexers[k], coords=coords, dims=v.dims)
435+
query_results.dim_indexers[k] = DataArray(
436+
dim_indexer, coords=coords, dims=v.dims
437+
)
442438

443-
return pos_indexers, new_indexes, new_variables, drop_variables
439+
return query_results

xarray/core/dataarray.py

+9-13
Original file line numberDiff line numberDiff line change
@@ -463,37 +463,33 @@ def _replace_maybe_drop_dims(
463463
def _overwrite_indexes(
464464
self,
465465
indexes: Mapping[Hashable, Index],
466-
coords: Mapping[Hashable, Variable],
467-
drop_coords: List[Hashable],
466+
coords: Optional[Mapping[Hashable, Variable]] = None,
467+
drop_coords: Optional[List[Hashable]] = None,
468+
rename_dims: Optional[Mapping[Hashable, Hashable]] = None,
468469
) -> "DataArray":
469470
"""Maybe replace indexes and their corresponding coordinates."""
470471
if not indexes:
471472
return self
472473

473-
assert indexes.keys() == coords.keys()
474+
if coords is None:
475+
coords = {}
476+
if drop_coords is None:
477+
drop_coords = []
474478

475479
new_variable = self.variable.copy()
476480
new_coords = self._coords.copy()
477481
new_indexes = dict(self.xindexes)
478-
dims_dict = {}
479482

480483
for name in indexes:
481-
# new coordinate variables may have renamed dimensions (e.g., level
482-
# name of a multi-index converted to a single index)
483-
old_vs_new_dims = zip(self._coords[name].dims, coords[name].dims)
484-
for old_dim, new_dim in old_vs_new_dims:
485-
if old_dim != new_dim:
486-
dims_dict[old_dim] = new_dim
487-
488484
new_coords[name] = coords[name]
489485
new_indexes[name] = indexes[name]
490486

491487
for name in drop_coords:
492488
new_coords.pop(name)
493489
new_indexes.pop(name)
494490

495-
if dims_dict:
496-
new_variable.dims = [dims_dict.get(d, d) for d in new_variable.dims]
491+
if rename_dims:
492+
new_variable.dims = [rename_dims.get(d, d) for d in new_variable.dims]
497493

498494
return self._replace(
499495
variable=new_variable, coords=new_coords, indexes=new_indexes

xarray/core/dataset.py

+13-19
Original file line numberDiff line numberDiff line change
@@ -1166,30 +1166,24 @@ def _replace_vars_and_dims(
11661166
def _overwrite_indexes(
11671167
self,
11681168
indexes: Mapping[Hashable, Index],
1169-
variables: Mapping[Hashable, Variable],
1170-
drop_variables: List[Hashable],
1169+
variables: Optional[Mapping[Hashable, Variable]] = None,
1170+
drop_variables: Optional[List[Hashable]] = None,
1171+
rename_dims: Optional[Mapping[Hashable, Hashable]] = None,
11711172
) -> "Dataset":
11721173
"""Maybe replace indexes and their corresponding index variables."""
11731174
if not indexes:
11741175
return self
11751176

1176-
assert indexes.keys() == variables.keys()
1177+
if variables is None:
1178+
variables = {}
1179+
if drop_variables is None:
1180+
drop_variables = []
11771181

11781182
new_variables = self._variables.copy()
11791183
new_coord_names = self._coord_names.copy()
11801184
new_indexes = dict(self.xindexes)
1181-
dims_dict = {}
11821185

11831186
for name in indexes:
1184-
# new coordinate variables may have renamed dimensions (e.g., level
1185-
# name of a multi-index converted to a single index)
1186-
# TODO: instead of infer renamed dimensions from the coordinates,
1187-
# should we require explicitly providing it from Index.query?
1188-
old_vs_new_dims = zip(self._variables[name].dims, variables[name].dims)
1189-
for old_dim, new_dim in old_vs_new_dims:
1190-
if old_dim != new_dim:
1191-
dims_dict[old_dim] = new_dim
1192-
11931187
new_variables[name] = variables[name]
11941188
new_indexes[name] = indexes[name]
11951189

@@ -1202,10 +1196,10 @@ def _overwrite_indexes(
12021196
variables=new_variables, coord_names=new_coord_names, indexes=new_indexes
12031197
)
12041198

1205-
if dims_dict:
1199+
if rename_dims:
12061200
# skip renaming indexes: they should already have the right name(s)
1207-
dims = replaced._rename_dims(dims_dict)
1208-
new_variables, new_coord_names = replaced._rename_vars({}, dims_dict)
1201+
dims = replaced._rename_dims(rename_dims)
1202+
new_variables, new_coord_names = replaced._rename_vars({}, rename_dims)
12091203
return replaced._replace(
12101204
variables=new_variables, coord_names=new_coord_names, dims=dims
12111205
)
@@ -2480,12 +2474,12 @@ def sel(
24802474
DataArray.sel
24812475
"""
24822476
indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "sel")
2483-
pos_indexers, new_indexes, new_variables, drop_variables = remap_label_indexers(
2477+
query_results = remap_label_indexers(
24842478
self, indexers=indexers, method=method, tolerance=tolerance
24852479
)
24862480

2487-
result = self.isel(indexers=pos_indexers, drop=drop)
2488-
return result._overwrite_indexes(new_indexes, new_variables, drop_variables)
2481+
result = self.isel(indexers=query_results.dim_indexers, drop=drop)
2482+
return result._overwrite_indexes(*query_results.to_tuple()[1:])
24892483

24902484
def head(
24912485
self,

xarray/core/indexes.py

+19-5
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import pandas as pd
1717

1818
from . import formatting, utils
19-
from .indexing import PandasIndexingAdapter, PandasMultiIndexingAdapter
19+
from .indexing import PandasIndexingAdapter, PandasMultiIndexingAdapter, QueryResult
2020
from .utils import is_dict_like, is_scalar
2121

2222
if TYPE_CHECKING:
@@ -47,7 +47,7 @@ def to_pandas_index(self) -> pd.Index:
4747

4848
def query(
4949
self, labels: Dict[Hashable, Any], **kwargs
50-
) -> Tuple[Mapping[str, Any], Optional[IndexWithVars]]: # pragma: no cover
50+
) -> QueryResult: # pragma: no cover
5151
raise NotImplementedError()
5252

5353
def equals(self, other): # pragma: no cover
@@ -243,7 +243,7 @@ def query(self, labels, method=None, tolerance=None):
243243
if np.any(indexer < 0):
244244
raise KeyError(f"not all values found in index {coord_name!r}")
245245

246-
return {self.dim: indexer}, None
246+
return QueryResult({self.dim: indexer})
247247

248248
def equals(self, other):
249249
return self.index.equals(other.index)
@@ -425,13 +425,27 @@ def query(self, labels, method=None, tolerance=None):
425425
new_index, new_vars = PandasMultiIndex.from_pandas_index(
426426
new_index, self.dim
427427
)
428+
dims_dict = {}
429+
drop_coords = set(self.index.names) - set(new_index.index.names)
428430
else:
429431
new_index, new_vars = PandasIndex.from_pandas_index(
430432
new_index, new_index.name
431433
)
432-
return {self.dim: indexer}, (new_index, new_vars)
434+
dims_dict = {self.dim: new_index.index.name}
435+
drop_coords = set(self.index.names) - {new_index.index.name} | {
436+
self.dim
437+
}
438+
439+
return QueryResult(
440+
{self.dim: indexer},
441+
index=new_index,
442+
index_vars=new_vars,
443+
drop_coords=list(drop_coords),
444+
rename_dims=dims_dict,
445+
)
446+
433447
else:
434-
return {self.dim: indexer}, None
448+
return QueryResult({self.dim: indexer})
435449

436450

437451
def remove_unused_levels_categories(index: pd.Index) -> pd.Index:

0 commit comments

Comments
 (0)