Skip to content
forked from pydata/xarray

Commit 781877c

Browse files
TomNicholas, pre-commit-ci[bot], shoyer
authored
Fix DataTree.coords.__setitem__ by adding DataTreeCoordinates class (pydata#9451)
* add a DataTreeCoordinates class * passing read-only properties tests * tests for modifying in-place * WIP making the modification test pass * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * get to the delete tests * test * improve error message * implement delitem * test KeyError * subclass Coordinates instead of DatasetCoordinates * use Frozen(self._data._coord_variables) * Simplify when to raise KeyError Co-authored-by: Stephan Hoyer <[email protected]> * correct bug in suggestion * Update xarray/core/coordinates.py Co-authored-by: Stephan Hoyer <[email protected]> * simplify _update_coords by creating new node data first * update indexes correctly * passes test * update ._drop_indexed_coords * some mypy fixes * remove the apparently-unused _drop_indexed_coords method * fix import error * test that Dataset and DataArray constructors can handle being passed a DataTreeCoordinates object * test dt.coords can be passed to DataTree constructor * improve readability of inline comment * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * initial tests with inherited coords * ignore typeerror indicating dodgy inheritance * try to avoid Unbound type error * cast return value correctly * check that .coords works with inherited coords * fix data->dataset * fix return type of __getitem__ * Use .dataset instead of .to_dataset() Co-authored-by: Stephan Hoyer <[email protected]> * _check_alignment -> check_alignment * remove dict comprehension Co-authored-by: Stephan Hoyer <[email protected]> * KeyError message formatting Co-authored-by: Stephan Hoyer <[email protected]> * keep generic types for .dims and .sizes * test verifying you can't delete inherited coord * fix mypy complaint * type hint as accepting objects * update note about .dims returning all dims --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: 
Stephan Hoyer <[email protected]>
1 parent fac2c89 commit 781877c

File tree

3 files changed

+293
-30
lines changed

3 files changed

+293
-30
lines changed

xarray/core/coordinates.py

+105-12
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from xarray.core.common import DataWithCoords
3737
from xarray.core.dataarray import DataArray
3838
from xarray.core.dataset import Dataset
39+
from xarray.core.datatree import DataTree
3940

4041
# Used as the key corresponding to a DataArray's variable when converting
4142
# arbitrary DataArray objects to datasets
@@ -197,12 +198,12 @@ class Coordinates(AbstractCoordinates):
197198
198199
Coordinates are either:
199200
200-
- returned via the :py:attr:`Dataset.coords` and :py:attr:`DataArray.coords`
201-
properties
201+
- returned via the :py:attr:`Dataset.coords`, :py:attr:`DataArray.coords`,
202+
and :py:attr:`DataTree.coords` properties,
202203
- built from Pandas or other index objects
203-
(e.g., :py:meth:`Coordinates.from_pandas_multiindex`)
204+
(e.g., :py:meth:`Coordinates.from_pandas_multiindex`),
204205
- built directly from coordinate data and Xarray ``Index`` objects (beware that
205-
no consistency check is done on those inputs)
206+
no consistency check is done on those inputs),
206207
207208
Parameters
208209
----------
@@ -704,6 +705,7 @@ def _names(self) -> set[Hashable]:
704705

705706
@property
706707
def dims(self) -> Frozen[Hashable, int]:
708+
# deliberately display all dims, not just those on coordinate variables - see https://github.com/pydata/xarray/issues/9466
707709
return self._data.dims
708710

709711
@property
@@ -771,14 +773,6 @@ def _drop_coords(self, coord_names):
771773
del self._data._indexes[name]
772774
self._data._coord_names.difference_update(coord_names)
773775

774-
def _drop_indexed_coords(self, coords_to_drop: set[Hashable]) -> None:
775-
assert self._data.xindexes is not None
776-
new_coords = drop_indexed_coords(coords_to_drop, self)
777-
for name in self._data._coord_names - new_coords._names:
778-
del self._data._variables[name]
779-
self._data._indexes = dict(new_coords.xindexes)
780-
self._data._coord_names.intersection_update(new_coords._names)
781-
782776
def __delitem__(self, key: Hashable) -> None:
783777
if key in self:
784778
del self._data[key]
@@ -796,6 +790,105 @@ def _ipython_key_completions_(self):
796790
]
797791

798792

793+
class DataTreeCoordinates(Coordinates):
    """
    Dictionary-like container for the coordinates of a DataTree node
    (coordinate variables together with their indexes).

    The container holds a reference to the node it was created from. It can
    be handed directly to the :py:class:`~xarray.Dataset` and
    :py:class:`~xarray.DataArray` constructors through their `coords`
    argument, which adds both the coordinate variables and their index.
    """

    # TODO: This only needs to be a separate class from `DatasetCoordinates` because DataTree nodes store their variables differently
    # internally than how Datasets do, see https://github.com/pydata/xarray/issues/9203.

    _data: DataTree  # type: ignore[assignment]  # complaining that DataTree is not a subclass of DataWithCoords - this can be fixed by refactoring, see #9203

    __slots__ = ("_data",)

    def __init__(self, datatree: DataTree):
        # Keep a reference to the node itself so later updates act on it.
        self._data = datatree

    @property
    def _names(self) -> set[Hashable]:
        """Names of every coordinate variable exposed by the wrapped node."""
        return {*self._data._coord_variables}

    @property
    def dims(self) -> Frozen[Hashable, int]:
        """Read-only view of the wrapped node's dimensions."""
        # deliberately display all dims, not just those on coordinate variables - see https://github.com/pydata/xarray/issues/9466
        return Frozen(self._data.dims)

    @property
    def dtypes(self) -> Frozen[Hashable, np.dtype]:
        """Mapping from coordinate names to dtypes.

        Cannot be modified directly, but is updated when adding new variables.

        See Also
        --------
        Dataset.dtypes
        """
        dtype_by_name = {
            name: variable.dtype
            for name, variable in self._data._coord_variables.items()
        }
        return Frozen(dtype_by_name)

    @property
    def variables(self) -> Mapping[Hashable, Variable]:
        """Read-only mapping from coordinate names to their variables."""
        return Frozen(self._data._coord_variables)

    def __getitem__(self, key: Hashable) -> DataArray:
        """Return the coordinate named ``key``.

        Raises
        ------
        KeyError
            If ``key`` is not one of the node's coordinate variables.
        """
        if key in self._data._coord_variables:
            return self._data.dataset[key]
        raise KeyError(key)

    def to_dataset(self) -> Dataset:
        """Convert these coordinates into a new Dataset"""
        coord_names = self._names
        return self._data.dataset._copy_listed(coord_names)

    def _update_coords(
        self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
    ) -> None:
        """Replace the node's coordinate variables and indexes.

        The update is first applied to a copy of the node's own data and
        passed to ``check_alignment``; the node's attributes are assigned
        only after that call returns.
        """
        from xarray.core.datatree import check_alignment

        node = self._data

        # build the updated node data on a copy (`.to_dataset` makes a copy,
        # so nothing is modified in-place at this point)
        updated_ds = node.to_dataset(inherited=False)
        updated_ds.coords._update_coords(coords, indexes)

        # check consistency *before* modifying anything in-place
        # TODO can we clean up the signature of check_alignment to make this less awkward?
        parent = node.parent
        parent_ds = (
            None
            if parent is None
            else parent._to_dataset_view(inherited=True, rebuild_dims=False)
        )
        check_alignment(node.path, updated_ds, parent_ds, node.children)

        # assign the updated attributes onto the node
        node._node_coord_variables = dict(updated_ds.coords.variables)
        node._node_dims = updated_ds._dims
        node._node_indexes = updated_ds._indexes

    def _drop_coords(self, coord_names):
        """Delete the named coordinates and their indexes from the node."""
        # should drop indexed coordinates only
        node = self._data
        for coord_name in coord_names:
            # variable first, then its index (same order as before)
            del node._node_coord_variables[coord_name]
            del node._node_indexes[coord_name]

    def __delitem__(self, key: Hashable) -> None:
        """Delete the coordinate named ``key`` from the wrapped node.

        Raises
        ------
        KeyError
            If ``key`` is not a coordinate of this container.
        """
        if key not in self:
            raise KeyError(key)
        del self._data[key]  # type: ignore[arg-type]  # see https://github.com/pydata/xarray/issues/8836

    def _ipython_key_completions_(self):
        """Provide method for the key-autocompletions in IPython."""
        candidates = self._data._ipython_key_completions_()
        coord_vars = self._data._coord_variables
        # keep only the node's completions that name a coordinate variable
        return [name for name in candidates if name in coord_vars]
890+
891+
799892
class DataArrayCoordinates(Coordinates, Generic[T_DataArray]):
800893
"""Dictionary like container for DataArray coordinates (variables + indexes).
801894

xarray/core/datatree.py

+22-16
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from xarray.core import utils
1717
from xarray.core.alignment import align
1818
from xarray.core.common import TreeAttrAccessMixin
19-
from xarray.core.coordinates import DatasetCoordinates
19+
from xarray.core.coordinates import Coordinates, DataTreeCoordinates
2020
from xarray.core.dataarray import DataArray
2121
from xarray.core.dataset import Dataset, DataVariables
2222
from xarray.core.datatree_mapping import (
@@ -91,9 +91,11 @@ def _collect_data_and_coord_variables(
9191
return data_variables, coord_variables
9292

9393

94-
def _to_new_dataset(data: Dataset | None) -> Dataset:
94+
def _to_new_dataset(data: Dataset | Coordinates | None) -> Dataset:
9595
if isinstance(data, Dataset):
9696
ds = data.copy(deep=False)
97+
elif isinstance(data, Coordinates):
98+
ds = data.to_dataset()
9799
elif data is None:
98100
ds = Dataset()
99101
else:
@@ -125,7 +127,7 @@ def _indented(text: str) -> str:
125127
return textwrap.indent(text, prefix=" ")
126128

127129

128-
def _check_alignment(
130+
def check_alignment(
129131
path: str,
130132
node_ds: Dataset,
131133
parent_ds: Dataset | None,
@@ -151,7 +153,7 @@ def _check_alignment(
151153
for child_name, child in children.items():
152154
child_path = str(NodePath(path) / child_name)
153155
child_ds = child.to_dataset(inherited=False)
154-
_check_alignment(child_path, child_ds, base_ds, child.children)
156+
check_alignment(child_path, child_ds, base_ds, child.children)
155157

156158

157159
class DatasetView(Dataset):
@@ -417,7 +419,7 @@ class DataTree(
417419

418420
def __init__(
419421
self,
420-
dataset: Dataset | None = None,
422+
dataset: Dataset | Coordinates | None = None,
421423
children: Mapping[str, DataTree] | None = None,
422424
name: str | None = None,
423425
):
@@ -473,7 +475,7 @@ def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None:
473475
path = str(NodePath(parent.path) / name)
474476
node_ds = self.to_dataset(inherited=False)
475477
parent_ds = parent._to_dataset_view(rebuild_dims=False, inherited=True)
476-
_check_alignment(path, node_ds, parent_ds, self.children)
478+
check_alignment(path, node_ds, parent_ds, self.children)
477479

478480
@property
479481
def _coord_variables(self) -> ChainMap[Hashable, Variable]:
@@ -498,8 +500,10 @@ def _to_dataset_view(self, rebuild_dims: bool, inherited: bool) -> DatasetView:
498500
elif inherited:
499501
# Note: rebuild_dims=False with inherited=True can create
500502
# technically invalid Dataset objects because it still includes
501-
# dimensions that are only defined on parent data variables (i.e. not present on any parent coordinate variables), e.g.,
502-
# consider:
503+
# dimensions that are only defined on parent data variables
504+
# (i.e. not present on any parent coordinate variables).
505+
#
506+
# For example:
503507
# >>> tree = DataTree.from_dict(
504508
# ... {
505509
# ... "/": xr.Dataset({"foo": ("x", [1, 2])}), # x has size 2
@@ -514,11 +518,13 @@ def _to_dataset_view(self, rebuild_dims: bool, inherited: bool) -> DatasetView:
514518
# Data variables:
515519
# *empty*
516520
#
517-
# Notice the "x" dimension is still defined, even though there are no
518-
# variables or coordinates.
519-
# Normally this is not supposed to be possible in xarray's data model, but here it is useful internally for use cases where we
520-
# want to inherit everything from parents nodes, e.g., for align()
521-
# and repr().
521+
# Notice the "x" dimension is still defined, even though there are no variables
522+
# or coordinates.
523+
#
524+
# Normally this is not supposed to be possible in xarray's data model,
525+
# but here it is useful internally for use cases where we
526+
# want to inherit everything from parent nodes, e.g., for align() and repr().
527+
#
522528
# The user should never be able to see this dimension via public API.
523529
dims = dict(self._dims)
524530
else:
@@ -762,7 +768,7 @@ def _replace_node(
762768
if self.parent is not None
763769
else None
764770
)
765-
_check_alignment(self.path, ds, parent_ds, children)
771+
check_alignment(self.path, ds, parent_ds, children)
766772

767773
if data is not _default:
768774
self._set_node_data(ds)
@@ -1187,11 +1193,11 @@ def xindexes(self) -> Indexes[Index]:
11871193
)
11881194

11891195
@property
1190-
def coords(self) -> DatasetCoordinates:
1196+
def coords(self) -> DataTreeCoordinates:
11911197
"""Dictionary of xarray.DataArray objects corresponding to coordinate
11921198
variables
11931199
"""
1194-
return DatasetCoordinates(self.to_dataset())
1200+
return DataTreeCoordinates(self)
11951201

11961202
@property
11971203
def data_vars(self) -> DataVariables:

0 commit comments

Comments
 (0)