Skip to content

Commit ff6793d

Browse files
Illviljanpre-commit-ci[bot]headtr1ck
authored
Switch to T_DataArray in .coords (#7285)
* Switch to T_DataArray in .coords * Update coordinates.py * Update coordinates.py * mypy understands the type from items better apparanetly * Update coordinates.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * resolve DataArrayCoords generic type * fix import * Test adding __class_getitem__ * Update coordinates.py * Test adding a _data slot. * Adding class_getitem seems to work. * test mypy on 3.8 * Update ci-additional.yaml Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Michael Niklas <[email protected]>
1 parent 5344ccb commit ff6793d

File tree

2 files changed

+24
-11
lines changed

2 files changed

+24
-11
lines changed

xarray/core/coordinates.py

+22-9
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import warnings
44
from contextlib import contextmanager
5-
from typing import TYPE_CHECKING, Any, Hashable, Iterator, Mapping, Sequence, cast
5+
from typing import TYPE_CHECKING, Any, Hashable, Iterator, List, Mapping, Sequence
66

77
import numpy as np
88
import pandas as pd
@@ -14,18 +14,27 @@
1414
from .variable import Variable, calculate_dimensions
1515

1616
if TYPE_CHECKING:
17+
from .common import DataWithCoords
1718
from .dataarray import DataArray
1819
from .dataset import Dataset
20+
from .types import T_DataArray
1921

2022
# Used as the key corresponding to a DataArray's variable when converting
2123
# arbitrary DataArray objects to datasets
2224
_THIS_ARRAY = ReprObject("<this-array>")
2325

26+
# TODO: Remove when min python version >= 3.9:
27+
GenericAlias = type(List[int])
2428

25-
class Coordinates(Mapping[Hashable, "DataArray"]):
26-
__slots__ = ()
2729

28-
def __getitem__(self, key: Hashable) -> DataArray:
30+
class Coordinates(Mapping[Hashable, "T_DataArray"]):
31+
_data: DataWithCoords
32+
__slots__ = ("_data",)
33+
34+
# TODO: Remove when min python version >= 3.9:
35+
__class_getitem__ = classmethod(GenericAlias)
36+
37+
def __getitem__(self, key: Hashable) -> T_DataArray:
2938
raise NotImplementedError()
3039

3140
def __setitem__(self, key: Hashable, value: Any) -> None:
@@ -238,6 +247,8 @@ class DatasetCoordinates(Coordinates):
238247
objects.
239248
"""
240249

250+
_data: Dataset
251+
241252
__slots__ = ("_data",)
242253

243254
def __init__(self, dataset: Dataset):
@@ -278,7 +289,7 @@ def variables(self) -> Mapping[Hashable, Variable]:
278289
def __getitem__(self, key: Hashable) -> DataArray:
279290
if key in self._data.data_vars:
280291
raise KeyError(key)
281-
return cast("DataArray", self._data[key])
292+
return self._data[key]
282293

283294
def to_dataset(self) -> Dataset:
284295
"""Convert these coordinates into a new Dataset"""
@@ -334,16 +345,18 @@ def _ipython_key_completions_(self):
334345
]
335346

336347

337-
class DataArrayCoordinates(Coordinates):
348+
class DataArrayCoordinates(Coordinates["T_DataArray"]):
338349
"""Dictionary like container for DataArray coordinates.
339350
340351
Essentially a dict with keys given by the array's
341352
dimensions and the values given by corresponding DataArray objects.
342353
"""
343354

355+
_data: T_DataArray
356+
344357
__slots__ = ("_data",)
345358

346-
def __init__(self, dataarray: DataArray):
359+
def __init__(self, dataarray: T_DataArray) -> None:
347360
self._data = dataarray
348361

349362
@property
@@ -366,7 +379,7 @@ def dtypes(self) -> Frozen[Hashable, np.dtype]:
366379
def _names(self) -> set[Hashable]:
367380
return set(self._data._coords)
368381

369-
def __getitem__(self, key: Hashable) -> DataArray:
382+
def __getitem__(self, key: Hashable) -> T_DataArray:
370383
return self._data._getitem_coord(key)
371384

372385
def _update_coords(
@@ -452,7 +465,7 @@ def drop_coords(
452465

453466

454467
def assert_coordinate_consistent(
455-
obj: DataArray | Dataset, coords: Mapping[Any, Variable]
468+
obj: T_DataArray | Dataset, coords: Mapping[Any, Variable]
456469
) -> None:
457470
"""Make sure the dimension coordinate of obj is consistent with coords.
458471

xarray/core/dataarray.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -3993,8 +3993,8 @@ def to_dict(self, data: bool = True, encoding: bool = False) -> dict[str, Any]:
39933993
"""
39943994
d = self.variable.to_dict(data=data)
39953995
d.update({"coords": {}, "name": self.name})
3996-
for k in self.coords:
3997-
d["coords"][k] = self.coords[k].variable.to_dict(data=data)
3996+
for k, coord in self.coords.items():
3997+
d["coords"][k] = coord.variable.to_dict(data=data)
39983998
if encoding:
39993999
d["encoding"] = dict(self.encoding)
40004000
return d

0 commit comments

Comments
 (0)