diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 4dc7ed7310f..3e473476197 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -36,6 +36,9 @@ Deprecations Bug fixes ~~~~~~~~~ +- :py:meth:`~xarray.Dataset.to_stacked_array` now uses dimensions in order of appearance. + This fixes the issue where using :py:meth:`~xarray.Dataset.transpose` before :py:meth:`~xarray.Dataset.to_stacked_array` + had no effect. (Mentioned in :issue:`9921`) Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index bf2858c1b18..b7b058c5057 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -5246,7 +5246,13 @@ def to_stacked_array( """ from xarray.structure.concat import concat - stacking_dims = tuple(dim for dim in self.dims if dim not in sample_dims) + # add stacking dims by order of appearance + stacking_dims_list: list[Hashable] = [] + for da in self.data_vars.values(): + for dim in da.dims: + if dim not in sample_dims and dim not in stacking_dims_list: + stacking_dims_list.append(dim) + stacking_dims = tuple(stacking_dims_list) for key, da in self.data_vars.items(): missing_sample_dims = set(sample_dims) - set(da.dims) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index b273b7d1a0d..c1310bc7e1d 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -4098,6 +4098,33 @@ def test_to_stacked_array_preserves_dtype(self) -> None: expected_stacked_variable, ) + def test_to_stacked_array_transposed(self) -> None: + # test that to_stacked_array uses updated dim order after transposition + ds = xr.Dataset( + data_vars=dict( + v1=(["d1", "d2"], np.arange(6).reshape((2, 3))), + ), + coords=dict( + d1=(["d1"], np.arange(2)), + d2=(["d2"], np.arange(3)), + ), + ) + da = ds.to_stacked_array( + new_dim="new_dim", + sample_dims=[], + variable_dim="variable", + ) + dsT = ds.transpose() + daT = dsT.to_stacked_array( + new_dim="new_dim", + sample_dims=[], + variable_dim="variable", + ) + v1 = np.arange(6) + v1T = np.arange(6).reshape((2, 3)).T.flatten() + np.testing.assert_equal(da.to_numpy(), v1) + np.testing.assert_equal(daT.to_numpy(), v1T) + def test_update(self) -> None: data = create_test_data(seed=0) expected = data.copy()