diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 38b4540cc4e..562d30dd6c7 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -1,11 +1,22 @@ -import collections.abc from collections import OrderedDict from contextlib import contextmanager -from typing import Any, Hashable, Mapping, Iterator, Union, TYPE_CHECKING +from typing import ( + TYPE_CHECKING, + Any, + Hashable, + Mapping, + Iterator, + Union, + Set, + Tuple, + Sequence, + cast, +) import pandas as pd from . import formatting, indexing +from .indexes import Indexes from .merge import ( expand_and_merge_variables, merge_coords, @@ -23,49 +34,58 @@ _THIS_ARRAY = ReprObject("") -class AbstractCoordinates(collections.abc.Mapping): - def __getitem__(self, key): - raise NotImplementedError +class AbstractCoordinates(Mapping[Hashable, "DataArray"]): + _data = None # type: Union["DataArray", "Dataset"] - def __setitem__(self, key, value): + def __getitem__(self, key: Hashable) -> "DataArray": + raise NotImplementedError() + + def __setitem__(self, key: Hashable, value: Any) -> None: self.update({key: value}) @property - def indexes(self): + def _names(self) -> Set[Hashable]: + raise NotImplementedError() + + @property + def dims(self) -> Union[Mapping[Hashable, int], Tuple[Hashable, ...]]: + raise NotImplementedError() + + @property + def indexes(self) -> Indexes: return self._data.indexes @property def variables(self): - raise NotImplementedError + raise NotImplementedError() def _update_coords(self, coords): - raise NotImplementedError + raise NotImplementedError() - def __iter__(self): + def __iter__(self) -> Iterator["Hashable"]: # needs to be in the same order as the dataset variables for k in self.variables: if k in self._names: yield k - def __len__(self): + def __len__(self) -> int: return len(self._names) - def __contains__(self, key): + def __contains__(self, key: Hashable) -> bool: return key in self._names - def __repr__(self): + def __repr__(self) -> str: return formatting.coords_repr(self) - @property - def dims(self): - return self._data.dims + def to_dataset(self) -> "Dataset": + raise NotImplementedError() - def to_index(self, ordered_dims=None): + def to_index(self, ordered_dims: Sequence[Hashable] = None) -> pd.Index: """Convert all index coordinates into a :py:class:`pandas.Index`. Parameters ---------- - ordered_dims : sequence, optional + ordered_dims : sequence of hashable, optional Possibly reordered version of this object's dimensions indicating the order in which dimensions should appear on the result. @@ -77,7 +97,7 @@ def to_index(self, ordered_dims=None): than more dimension. """ if ordered_dims is None: - ordered_dims = self.dims + ordered_dims = list(self.dims) elif set(ordered_dims) != set(self.dims): raise ValueError( "ordered_dims must match dims, but does not: " @@ -94,7 +114,7 @@ def to_index(self, ordered_dims=None): names = list(ordered_dims) return pd.MultiIndex.from_product(indexes, names=names) - def update(self, other): + def update(self, other: Mapping[Hashable, Any]) -> None: other_vars = getattr(other, "variables", other) coords = merge_coords( [self.variables, other_vars], priority_arg=1, indexes=self.indexes @@ -127,7 +147,7 @@ def _merge_inplace(self, other): yield self._update_coords(variables) - def merge(self, other): + def merge(self, other: "AbstractCoordinates") -> "Dataset": """Merge two sets of coordinates to create a new Dataset The method implements the logic used for joining coordinates in the @@ -167,32 +187,38 @@ class DatasetCoordinates(AbstractCoordinates): objects. """ - def __init__(self, dataset): + _data = None # type: Dataset + + def __init__(self, dataset: "Dataset"): self._data = dataset @property - def _names(self): + def _names(self) -> Set[Hashable]: return self._data._coord_names @property - def variables(self): + def dims(self) -> Mapping[Hashable, int]: + return self._data.dims + + @property + def variables(self) -> Mapping[Hashable, Variable]: return Frozen( OrderedDict( (k, v) for k, v in self._data.variables.items() if k in self._names ) ) - def __getitem__(self, key): + def __getitem__(self, key: Hashable) -> "DataArray": if key in self._data.data_vars: raise KeyError(key) - return self._data[key] + return cast("DataArray", self._data[key]) - def to_dataset(self): + def to_dataset(self) -> "Dataset": """Convert these coordinates into a new Dataset """ return self._data._copy_listed(self._names) - def _update_coords(self, coords): + def _update_coords(self, coords: Mapping[Hashable, Any]) -> None: from .dataset import calculate_dimensions variables = self._data._variables.copy() @@ -210,7 +236,7 @@ def _update_coords(self, coords): self._data._dims = dims self._data._indexes = None - def __delitem__(self, key): + def __delitem__(self, key: Hashable) -> None: if key in self: del self._data[key] else: @@ -232,17 +258,23 @@ class DataArrayCoordinates(AbstractCoordinates): dimensions and the values given by corresponding DataArray objects. """ - def __init__(self, dataarray): + _data = None # type: DataArray + + def __init__(self, dataarray: "DataArray"): self._data = dataarray @property - def _names(self): + def dims(self) -> Tuple[Hashable, ...]: + return self._data.dims + + @property + def _names(self) -> Set[Hashable]: return set(self._data._coords) - def __getitem__(self, key): + def __getitem__(self, key: Hashable) -> "DataArray": return self._data._getitem_coord(key) - def _update_coords(self, coords): + def _update_coords(self, coords) -> None: from .dataset import calculate_dimensions coords_plus_data = coords.copy() @@ -259,19 +291,15 @@ def _update_coords(self, coords): def variables(self): return Frozen(self._data._coords) - def _to_dataset(self, shallow_copy=True): + def to_dataset(self) -> "Dataset": from .dataset import Dataset coords = OrderedDict( - (k, v.copy(deep=False) if shallow_copy else v) - for k, v in self._data._coords.items() + (k, v.copy(deep=False)) for k, v in self._data._coords.items() ) return Dataset._from_vars_and_coord_names(coords, set(coords)) - def to_dataset(self): - return self._to_dataset() - - def __delitem__(self, key): + def __delitem__(self, key: Hashable) -> None: del self._data._coords[key] def _ipython_key_completions_(self): @@ -300,9 +328,10 @@ def __len__(self) -> int: return len(self._data._level_coords) -def assert_coordinate_consistent(obj, coords): - """ Maeke sure the dimension coordinate of obj is - consistent with coords. +def assert_coordinate_consistent( + obj: Union["DataArray", "Dataset"], coords: Mapping[Hashable, Variable] +) -> None: + """Make sure the dimension coordinate of obj is consistent with coords. obj: DataArray or Dataset coords: Dict-like of variables @@ -320,17 +349,20 @@ def assert_coordinate_consistent(obj, coords): def remap_label_indexers( - obj, indexers=None, method=None, tolerance=None, **indexers_kwargs -): - """ - Remap **indexers from obj.coords. - If indexer is an instance of DataArray and it has coordinate, then this - coordinate will be attached to pos_indexers. + obj: Union["DataArray", "Dataset"], + indexers: Mapping[Hashable, Any] = None, + method: str = None, + tolerance=None, + **indexers_kwargs: Any +) -> Tuple[dict, dict]: # TODO more precise return type after annotations in indexing + """Remap indexers from obj.coords. + If indexer is an instance of DataArray and it has coordinate, then this coordinate + will be attached to pos_indexers. Returns ------- pos_indexers: Same type of indexers. - np.ndarray or Variable or DataArra + np.ndarray or Variable or DataArray new_indexes: mapping of new dimensional-coordinate. """ from .dataarray import DataArray diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 33be8d96e91..72e000ec609 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -175,7 +175,7 @@ def __setitem__(self, key, value) -> None: labels = indexing.expanded_indexer(key, self.data_array.ndim) key = dict(zip(self.data_array.dims, labels)) - pos_indexers, _ = remap_label_indexers(self.data_array, **key) + pos_indexers, _ = remap_label_indexers(self.data_array, key) self.data_array[pos_indexers] = value diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 618d70e06e9..4250bf9564e 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -335,7 +335,7 @@ def as_dataset(obj: Any) -> "Dataset": return obj -class DataVariables(Mapping[Hashable, "Union[DataArray, Dataset]"]): +class DataVariables(Mapping[Hashable, "DataArray"]): def __init__(self, dataset: "Dataset"): self._dataset = dataset @@ -349,14 +349,13 @@ def __iter__(self) -> Iterator[Hashable]: def __len__(self) -> int: return len(self._dataset._variables) - len(self._dataset._coord_names) - def __contains__(self, key) -> bool: + def __contains__(self, key: Hashable) -> bool: return key in self._dataset._variables and key not in self._dataset._coord_names - def __getitem__(self, key) -> "Union[DataArray, Dataset]": + def __getitem__(self, key: Hashable) -> "DataArray": if key not in self._dataset._coord_names: - return self._dataset[key] - else: - raise KeyError(key) + return cast("DataArray", self._dataset[key]) + raise KeyError(key) def __repr__(self) -> str: return formatting.data_vars_repr(self) @@ -1317,7 +1316,7 @@ def identical(self, other: "Dataset") -> bool: return False @property - def indexes(self) -> "Mapping[Any, pd.Index]": + def indexes(self) -> Indexes: """Mapping of pandas.Index objects used for label based indexing """ if self._indexes is None: