
Annotations for .data_vars() and .coords() #3207


Merged: 2 commits, Aug 12, 2019
132 changes: 82 additions & 50 deletions xarray/core/coordinates.py
@@ -1,11 +1,22 @@
import collections.abc
from collections import OrderedDict
from contextlib import contextmanager
from typing import Any, Hashable, Mapping, Iterator, Union, TYPE_CHECKING
from typing import (
TYPE_CHECKING,
Any,
Hashable,
Mapping,
Iterator,
Union,
Set,
Tuple,
Sequence,
cast,
)

import pandas as pd

from . import formatting, indexing
from .indexes import Indexes
from .merge import (
expand_and_merge_variables,
merge_coords,
@@ -23,49 +34,58 @@
_THIS_ARRAY = ReprObject("<this-array>")


class AbstractCoordinates(collections.abc.Mapping):
def __getitem__(self, key):
raise NotImplementedError
class AbstractCoordinates(Mapping[Hashable, "DataArray"]):
_data = None # type: Union["DataArray", "Dataset"]

def __setitem__(self, key, value):
def __getitem__(self, key: Hashable) -> "DataArray":
raise NotImplementedError()

def __setitem__(self, key: Hashable, value: Any) -> None:
self.update({key: value})

@property
def indexes(self):
def _names(self) -> Set[Hashable]:
Collaborator:
Is this new? Should this be on the abstract class?

Contributor Author:
It's a property invoked by the methods of the abstract class and defined in both implementations.

Collaborator:
Here, I'm not so sure - do all concrete classes need to define this? Generally we wouldn't define private methods on an abstract class (even if all of the current concrete classes implement it), because the implementation would be up to the concrete class?

Contributor Author:
What I basically did was turn AbstractCoordinates from a mixin (that is, a class that invokes its own methods that are defined only in a subclass) into a C++-style abstract class.
mypy, justifiably, doesn't like mixins - short of hacking your way through with # type tags.

This, in short, is how it was before:

class C:
    def g(self) -> int:
        return self.f()


class D(C):
    def f(self) -> int:
        return 123

mypy correctly states:
3: error: "C" has no attribute "f"
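
For contrast, here is a minimal sketch (toy names only, not code from this PR) of the shape the diff moves to: the base class declares f with its full signature and raises NotImplementedError, so mypy can type-check g on the base class while concrete subclasses still supply the real implementation:

class C:
    def f(self) -> int:
        # Declared on the base class so mypy knows the signature;
        # concrete subclasses are expected to override it.
        raise NotImplementedError()

    def g(self) -> int:
        return self.f()  # fine for mypy: C.f exists and returns int


class D(C):
    def f(self) -> int:
        return 123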

Collaborator:
Right - I agree this should either be an abstract class without implementations or a mixin with implementations - not a bit of each. It's a bit odd we call this class Abstract but then enforce a _names property.

That's a nice upside of mypy - we get to find those things out.

I agree we're in a reasonable local maximum here.

Member:
Yes, thanks for fixing this up!

raise NotImplementedError()

@property
def dims(self) -> Union[Mapping[Hashable, int], Tuple[Hashable, ...]]:
Collaborator:
Why the change here from return self._data.dims?

Contributor Author:
Because DataArray.dims is a Tuple[Hashable, ...] whereas Dataset.dims is a Mapping[Hashable, int].

Contributor Author:
See below for the two implementations, which are identical in code but different in signature.
The method needs to stay in the abstract class because other methods of the abstract class invoke it; this signature with NotImplementedError() guarantees that the methods of the abstract class can only use dims as the common denominator of the two output types, which is a Collection[Hashable].
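
To illustrate (a toy sketch with made-up names, not code from this PR): the base class advertises the Union return type and raises NotImplementedError, the subclasses narrow it covariantly, and methods on the base class may only rely on what both return types share, e.g. iterating over the dimension names:

from typing import Hashable, List, Mapping, Tuple, Union


class Base:
    @property
    def dims(self) -> Union[Mapping[Hashable, int], Tuple[Hashable, ...]]:
        raise NotImplementedError()

    def dim_names(self) -> List[Hashable]:
        # Only operations valid for both a Mapping and a Tuple are safe here;
        # iterating yields the dimension names in either case.
        return list(self.dims)


class ArrayLike(Base):
    @property
    def dims(self) -> Tuple[Hashable, ...]:  # narrowed, like DataArray.dims
        return ("x", "y")


class DatasetLike(Base):
    @property
    def dims(self) -> Mapping[Hashable, int]:  # narrowed, like Dataset.dims
        sizes: Mapping[Hashable, int] = {"x": 2, "y": 3}
        return sizes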

Collaborator:
OK great

raise NotImplementedError()

@property
def indexes(self) -> Indexes:
return self._data.indexes

@property
def variables(self):
raise NotImplementedError
raise NotImplementedError()

def _update_coords(self, coords):
raise NotImplementedError
raise NotImplementedError()

def __iter__(self):
def __iter__(self) -> Iterator["Hashable"]:
# needs to be in the same order as the dataset variables
for k in self.variables:
if k in self._names:
yield k

def __len__(self):
def __len__(self) -> int:
return len(self._names)

def __contains__(self, key):
def __contains__(self, key: Hashable) -> bool:
return key in self._names

def __repr__(self):
def __repr__(self) -> str:
return formatting.coords_repr(self)

@property
def dims(self):
return self._data.dims
def to_dataset(self) -> "Dataset":
raise NotImplementedError()

def to_index(self, ordered_dims=None):
def to_index(self, ordered_dims: Sequence[Hashable] = None) -> pd.Index:
"""Convert all index coordinates into a :py:class:`pandas.Index`.

Parameters
----------
ordered_dims : sequence, optional
ordered_dims : sequence of hashable, optional
Possibly reordered version of this object's dimensions indicating
the order in which dimensions should appear on the result.

@@ -77,7 +97,7 @@ def to_index(self, ordered_dims=None):
than more dimension.
"""
if ordered_dims is None:
ordered_dims = self.dims
ordered_dims = list(self.dims)
elif set(ordered_dims) != set(self.dims):
raise ValueError(
"ordered_dims must match dims, but does not: "
@@ -94,7 +114,7 @@ def to_index(self, ordered_dims=None):
names = list(ordered_dims)
return pd.MultiIndex.from_product(indexes, names=names)

def update(self, other):
def update(self, other: Mapping[Hashable, Any]) -> None:
other_vars = getattr(other, "variables", other)
coords = merge_coords(
[self.variables, other_vars], priority_arg=1, indexes=self.indexes
@@ -127,7 +147,7 @@ def _merge_inplace(self, other):
yield
self._update_coords(variables)

def merge(self, other):
def merge(self, other: "AbstractCoordinates") -> "Dataset":
"""Merge two sets of coordinates to create a new Dataset

The method implements the logic used for joining coordinates in the
@@ -167,32 +187,38 @@ class DatasetCoordinates(AbstractCoordinates):
objects.
"""

def __init__(self, dataset):
_data = None # type: Dataset

def __init__(self, dataset: "Dataset"):
self._data = dataset

@property
def _names(self):
def _names(self) -> Set[Hashable]:
return self._data._coord_names

@property
def variables(self):
def dims(self) -> Mapping[Hashable, int]:
return self._data.dims

@property
def variables(self) -> Mapping[Hashable, Variable]:
return Frozen(
OrderedDict(
(k, v) for k, v in self._data.variables.items() if k in self._names
)
)

def __getitem__(self, key):
def __getitem__(self, key: Hashable) -> "DataArray":
if key in self._data.data_vars:
raise KeyError(key)
return self._data[key]
return cast("DataArray", self._data[key])

def to_dataset(self):
def to_dataset(self) -> "Dataset":
"""Convert these coordinates into a new Dataset
"""
return self._data._copy_listed(self._names)

def _update_coords(self, coords):
def _update_coords(self, coords: Mapping[Hashable, Any]) -> None:
from .dataset import calculate_dimensions

variables = self._data._variables.copy()
@@ -210,7 +236,7 @@ def _update_coords(self, coords):
self._data._dims = dims
self._data._indexes = None

def __delitem__(self, key):
def __delitem__(self, key: Hashable) -> None:
if key in self:
del self._data[key]
else:
@@ -232,17 +258,23 @@ class DataArrayCoordinates(AbstractCoordinates):
dimensions and the values given by corresponding DataArray objects.
"""

def __init__(self, dataarray):
_data = None # type: DataArray

def __init__(self, dataarray: "DataArray"):
self._data = dataarray

@property
def _names(self):
def dims(self) -> Tuple[Hashable, ...]:
return self._data.dims

@property
def _names(self) -> Set[Hashable]:
return set(self._data._coords)

def __getitem__(self, key):
def __getitem__(self, key: Hashable) -> "DataArray":
return self._data._getitem_coord(key)

def _update_coords(self, coords):
def _update_coords(self, coords) -> None:
from .dataset import calculate_dimensions

coords_plus_data = coords.copy()
@@ -259,19 +291,15 @@ def _update_coords(self, coords):
def variables(self):
return Frozen(self._data._coords)

def _to_dataset(self, shallow_copy=True):
def to_dataset(self) -> "Dataset":
from .dataset import Dataset

coords = OrderedDict(
(k, v.copy(deep=False) if shallow_copy else v)
for k, v in self._data._coords.items()
(k, v.copy(deep=False)) for k, v in self._data._coords.items()
)
return Dataset._from_vars_and_coord_names(coords, set(coords))

def to_dataset(self):
return self._to_dataset()

def __delitem__(self, key):
def __delitem__(self, key: Hashable) -> None:
del self._data._coords[key]

def _ipython_key_completions_(self):
@@ -300,9 +328,10 @@ def __len__(self) -> int:
return len(self._data._level_coords)


def assert_coordinate_consistent(obj, coords):
""" Maeke sure the dimension coordinate of obj is
consistent with coords.
def assert_coordinate_consistent(
obj: Union["DataArray", "Dataset"], coords: Mapping[Hashable, Variable]
) -> None:
"""Make sure the dimension coordinate of obj is consistent with coords.

obj: DataArray or Dataset
coords: Dict-like of variables
@@ -320,17 +349,20 @@ def assert_coordinate_consistent(obj, coords):


def remap_label_indexers(
obj, indexers=None, method=None, tolerance=None, **indexers_kwargs
):
"""
Remap **indexers from obj.coords.
If indexer is an instance of DataArray and it has coordinate, then this
coordinate will be attached to pos_indexers.
obj: Union["DataArray", "Dataset"],
indexers: Mapping[Hashable, Any] = None,
method: str = None,
tolerance=None,
**indexers_kwargs: Any
) -> Tuple[dict, dict]: # TODO more precise return type after annotations in indexing
"""Remap indexers from obj.coords.
If indexer is an instance of DataArray and it has coordinate, then this coordinate
will be attached to pos_indexers.

Returns
-------
pos_indexers: Same type of indexers.
np.ndarray or Variable or DataArra
np.ndarray or Variable or DataArray
new_indexes: mapping of new dimensional-coordinate.
"""
from .dataarray import DataArray
2 changes: 1 addition & 1 deletion xarray/core/dataarray.py
@@ -175,7 +175,7 @@ def __setitem__(self, key, value) -> None:
labels = indexing.expanded_indexer(key, self.data_array.ndim)
key = dict(zip(self.data_array.dims, labels))

pos_indexers, _ = remap_label_indexers(self.data_array, **key)
pos_indexers, _ = remap_label_indexers(self.data_array, key)
self.data_array[pos_indexers] = value


13 changes: 6 additions & 7 deletions xarray/core/dataset.py
@@ -335,7 +335,7 @@ def as_dataset(obj: Any) -> "Dataset":
return obj


class DataVariables(Mapping[Hashable, "Union[DataArray, Dataset]"]):
class DataVariables(Mapping[Hashable, "DataArray"]):
def __init__(self, dataset: "Dataset"):
self._dataset = dataset

@@ -349,14 +349,13 @@ def __iter__(self) -> Iterator[Hashable]:
def __len__(self) -> int:
return len(self._dataset._variables) - len(self._dataset._coord_names)

def __contains__(self, key) -> bool:
def __contains__(self, key: Hashable) -> bool:
return key in self._dataset._variables and key not in self._dataset._coord_names

def __getitem__(self, key) -> "Union[DataArray, Dataset]":
def __getitem__(self, key: Hashable) -> "DataArray":
if key not in self._dataset._coord_names:
return self._dataset[key]
else:
raise KeyError(key)
return cast("DataArray", self._dataset[key])
raise KeyError(key)

def __repr__(self) -> str:
return formatting.data_vars_repr(self)
@@ -1317,7 +1316,7 @@ def identical(self, other: "Dataset") -> bool:
return False

@property
def indexes(self) -> "Mapping[Any, pd.Index]":
def indexes(self) -> Indexes:
"""Mapping of pandas.Index objects used for label based indexing
"""
if self._indexes is None: