From cc830ee0ece69b512b71c5da81efbff9cf29a643 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Wed, 4 Sep 2024 14:46:02 -0700 Subject: [PATCH 1/5] Disallow passing a DataArray as data into the DataTree constructor I don't think there's a good reason to support syntax like `DataTree(DataArray(...))`, when it's easy enough to explicitly convert with `to_dataset()`. `Dataset(DataArray(...))` currently raises an error, so this feels a bit inconsistent. If this was intended for the sake of usability, then I think supporting a dict (that gets coerced into a Dataset) in the DataTree constructor would be much helpful. We can consider adding this later. --- xarray/core/datatree.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 59984c5afa3..535ae28ddfd 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -91,17 +91,13 @@ def _collect_data_and_coord_variables( return data_variables, coord_variables -def _coerce_to_dataset(data: Dataset | DataArray | None) -> Dataset: - if isinstance(data, DataArray): - ds = data.to_dataset() - elif isinstance(data, Dataset): +def _to_new_dataset(data: Dataset | None) -> Dataset: + if isinstance(data, Dataset): ds = data.copy(deep=False) elif data is None: ds = Dataset() else: - raise TypeError( - f"data object is not an xarray Dataset, DataArray, or None, it is of type {type(data)}" - ) + raise TypeError(f"data object is not an xarray.Dataset, dict, or None: {data}") return ds @@ -422,7 +418,7 @@ class DataTree( def __init__( self, - data: Dataset | DataArray | None = None, + data: Dataset | None = None, parent: DataTree | None = None, children: Mapping[str, DataTree] | None = None, name: str | None = None, @@ -436,9 +432,8 @@ def __init__( Parameters ---------- - data : Dataset, DataArray, or None, optional - Data to store under the .ds attribute of this node. DataArrays will - be promoted to Datasets. Default is None. + data : Dataset, optional + Data to store under the .ds attribute of this node. parent : DataTree, optional Parent node to this node. Default is None. children : Mapping[str, DataTree], optional @@ -458,7 +453,7 @@ def __init__( children = {} super().__init__(name=name) - self._set_node_data(_coerce_to_dataset(data)) + self._set_node_data(_to_new_dataset(data)) self.parent = parent self.children = children @@ -553,8 +548,8 @@ def ds(self) -> DatasetView: return self._to_dataset_view(rebuild_dims=True) @ds.setter - def ds(self, data: Dataset | DataArray | None = None) -> None: - ds = _coerce_to_dataset(data) + def ds(self, data: Dataset | None = None) -> None: + ds = _to_new_dataset(data) self._replace_node(ds) def to_dataset(self, inherited: bool = True) -> Dataset: @@ -1096,8 +1091,12 @@ def from_dict( if isinstance(root_data, DataTree): obj = root_data.copy() obj.orphan() - else: + elif root_data is None or isinstance(root_data, Dataset): obj = cls(name=name, data=root_data, parent=None, children=None) + else: + raise TypeError( + f'root node data (at "/") must be a Dataset or DataTree, got {type(root_data)}' + ) def depth(item) -> int: pathstr, _ = item From 972e202f85850a765e79c799d5469d6fe7e07bc3 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sat, 7 Sep 2024 12:50:11 -0700 Subject: [PATCH 2/5] Tests for strict data arg --- xarray/tests/test_datatree.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 9a15376a1f8..bbb955edd8e 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -33,6 +33,14 @@ def test_bad_names(self): with pytest.raises(ValueError): DataTree(name="folder/data") + def test_data_arg(self): + ds = xr.Dataset({"foo": 42}) + tree: DataTree = DataTree(data=ds) + assert_identical(tree.to_dataset(), ds) + + with pytest.raises(TypeError): + DataTree(data=xr.DataArray(42, name="foo")) # type: ignore + class TestFamilyTree: def test_setparent_unnamed_child_node_fails(self): @@ -586,6 +594,11 @@ def test_insertion_order(self): # despite 'Bart' coming before 'Lisa' when sorted alphabetically assert list(reversed["Homer"].children.keys()) == ["Lisa", "Bart"] + def test_array_values(self): + data = {"foo": xr.DataArray(1, name="bar")} + with pytest.raises(TypeError): + DataTree.from_dict(data) # type: ignore + class TestDatasetView: def test_view_contents(self): From 8b2a7ba07d0284d3e3c6bc1563cb6675551a1a80 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sat, 7 Sep 2024 13:02:53 -0700 Subject: [PATCH 3/5] type error --- xarray/core/datatree.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index f3f29e5e031..3f423eaa8a6 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -1046,7 +1046,7 @@ def drop_nodes( @classmethod def from_dict( cls, - d: Mapping[str, Dataset | DataArray | DataTree | None], + d: Mapping[str, Dataset | DataTree | None], name: str | None = None, ) -> DataTree: """ @@ -1055,10 +1055,10 @@ def from_dict( Parameters ---------- d : dict-like - A mapping from path names to xarray.Dataset, xarray.DataArray, or DataTree objects. + A mapping from path names to xarray.Dataset or DataTree objects. - Path names are to be given as unix-like path. If path names containing more than one part are given, new - tree nodes will be constructed as necessary. + Path names are to be given as unix-like path. If path names containing more than one + part are given, new tree nodes will be constructed as necessary. To assign data to the root node of the tree use "/" as the path. name : Hashable | None, optional From e9574397255dbd022cacd621081e78e6f7367bd4 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sat, 7 Sep 2024 13:04:14 -0700 Subject: [PATCH 4/5] Make first from_dict arg positional only --- xarray/core/datatree.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 3f423eaa8a6..a86a62a5763 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -1047,6 +1047,7 @@ def drop_nodes( def from_dict( cls, d: Mapping[str, Dataset | DataTree | None], + /, name: str | None = None, ) -> DataTree: """ From 1b56aadd21dc1e483d79b5ecdad4ed4e3522a22c Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sat, 7 Sep 2024 13:09:15 -0700 Subject: [PATCH 5/5] don't make positional only arg --- xarray/core/datatree.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index a86a62a5763..3f423eaa8a6 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -1047,7 +1047,6 @@ def drop_nodes( def from_dict( cls, d: Mapping[str, Dataset | DataTree | None], - /, name: str | None = None, ) -> DataTree: """