From 023389293d60bf782d9dc0c718080be073c7b0f5 Mon Sep 17 00:00:00 2001 From: Alban Farchi Date: Mon, 7 Apr 2025 09:26:02 +0200 Subject: [PATCH 1/4] Fixes dimension order in xarray.Dataset.to_stacked_array --- doc/whats-new.rst | 3 +++ xarray/core/dataset.py | 8 +++++++- xarray/tests/test_dataset.py | 27 +++++++++++++++++++++++++++ 3 files changed, 37 insertions(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 60cf2be873a..e74cb2c3bd6 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -34,6 +34,9 @@ Deprecations Bug fixes ~~~~~~~~~ +- :py:meth:`~xarray.Dataset.to_stacked_array` now uses dimensions in order of appearance. + This fixes the issue where using :py:meth:`~xarray.Dataset.transpose` before :py:meth:`~xarray.Dataset.to_stacked_array` + had no effect. (Mentioned in :issue:`9921`) Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index bf2858c1b18..b7711628726 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -5246,7 +5246,13 @@ def to_stacked_array( """ from xarray.structure.concat import concat - stacking_dims = tuple(dim for dim in self.dims if dim not in sample_dims) + # add stacking dims by order of appearance + stacking_dims = [] + for da in self.data_vars.values(): + for dim in da.dims: + if dim not in sample_dims and dim not in stacking_dims: + stacking_dims.append(dim) + stacking_dims = tuple(stacking_dims) for key, da in self.data_vars.items(): missing_sample_dims = set(sample_dims) - set(da.dims) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index b273b7d1a0d..c1310bc7e1d 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -4098,6 +4098,33 @@ def test_to_stacked_array_preserves_dtype(self) -> None: expected_stacked_variable, ) + def test_to_stacked_array_transposed(self) -> None: + # test that to_stacked_array uses updated dim order after transposition + ds = xr.Dataset( + data_vars=dict( + v1=(["d1", "d2"], np.arange(6).reshape((2, 3))), + ), + coords=dict( + d1=(["d1"], np.arange(2)), + d2=(["d2"], np.arange(3)), + ), + ) + da = ds.to_stacked_array( + new_dim="new_dim", + sample_dims=[], + variable_dim="variable", + ) + dsT = ds.transpose() + daT = dsT.to_stacked_array( + new_dim="new_dim", + sample_dims=[], + variable_dim="variable", + ) + v1 = np.arange(6) + v1T = np.arange(6).reshape((2, 3)).T.flatten() + np.testing.assert_equal(da.to_numpy(), v1) + np.testing.assert_equal(daT.to_numpy(), v1T) + def test_update(self) -> None: data = create_test_data(seed=0) expected = data.copy() From 556fb8bc36d515d53b894aefd949698115f589f6 Mon Sep 17 00:00:00 2001 From: Alban Farchi Date: Mon, 7 Apr 2025 13:20:18 +0200 Subject: [PATCH 2/4] corrected dummy variable name to satisfy mypy --- xarray/core/dataset.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index b7711628726..03e2a4a98e6 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -5247,12 +5247,12 @@ def to_stacked_array( from xarray.structure.concat import concat # add stacking dims by order of appearance - stacking_dims = [] + stacking_dims_list = [] for da in self.data_vars.values(): for dim in da.dims: - if dim not in sample_dims and dim not in stacking_dims: - stacking_dims.append(dim) - stacking_dims = tuple(stacking_dims) + if dim not in sample_dims and dim not in stacking_dims_list: + stacking_dims_list.append(dim) + stacking_dims = tuple(stacking_dims_list) for key, da in self.data_vars.items(): missing_sample_dims = set(sample_dims) - set(da.dims) From 9d1d388b1beeefd50d05d9f1e710f0ee9cfdf215 Mon Sep 17 00:00:00 2001 From: Alban Farchi Date: Mon, 7 Apr 2025 13:27:48 +0200 Subject: [PATCH 3/4] added type annotation to satisfy mypy --- xarray/core/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 03e2a4a98e6..5b64ff21e48 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -5247,7 +5247,7 @@ def to_stacked_array( from xarray.structure.concat import concat # add stacking dims by order of appearance - stacking_dims_list = [] + stacking_dims_list: list[str] = [] for da in self.data_vars.values(): for dim in da.dims: if dim not in sample_dims and dim not in stacking_dims_list: From 8f4ada462235b19f170de58f55e2a619370272b7 Mon Sep 17 00:00:00 2001 From: Alban Farchi Date: Mon, 7 Apr 2025 13:30:59 +0200 Subject: [PATCH 4/4] corrected type annotation to satisfy mypy --- xarray/core/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 5b64ff21e48..b7b058c5057 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -5247,7 +5247,7 @@ def to_stacked_array( from xarray.structure.concat import concat # add stacking dims by order of appearance - stacking_dims_list: list[str] = [] + stacking_dims_list: list[Hashable] = [] for da in self.data_vars.values(): for dim in da.dims: if dim not in sample_dims and dim not in stacking_dims_list: