Skip to content

Commit d555172

Browse files
authored
Allow swap_dims to take kwargs (#4841)
1 parent bc35548 commit d555172

File tree

5 files changed

+38
-3
lines changed

5 files changed

+38
-3
lines changed

doc/whats-new.rst

+4-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,10 @@ Breaking changes
4646
New Features
4747
~~~~~~~~~~~~
4848
- Performance improvement when constructing DataArrays. Significantly speeds up repr for Datasets with large number of variables.
49-
By `Deepak Cherian <https://github.com/dcherian>`_
49+
By `Deepak Cherian <https://github.com/dcherian>`_.
50+
- :py:meth:`DataArray.swap_dims` & :py:meth:`Dataset.swap_dims` now accept dims
51+
in the form of kwargs as well as a dict, like most similar methods.
52+
By `Maximilian Roos <https://github.com/max-sixty>`_.
5053

5154
Bug fixes
5255
~~~~~~~~~

xarray/core/dataarray.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -1699,7 +1699,9 @@ def rename(
16991699
new_name_or_name_dict = cast(Hashable, new_name_or_name_dict)
17001700
return self._replace(name=new_name_or_name_dict)
17011701

1702-
def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "DataArray":
1702+
def swap_dims(
1703+
self, dims_dict: Mapping[Hashable, Hashable] = None, **dims_kwargs
1704+
) -> "DataArray":
17031705
"""Returns a new DataArray with swapped dimensions.
17041706
17051707
Parameters
@@ -1708,6 +1710,10 @@ def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "DataArray":
17081710
Dictionary whose keys are current dimension names and whose values
17091711
are new names.
17101712
1713+
**dim_kwargs : {dim: , ...}, optional
1714+
The keyword arguments form of ``dims_dict``.
1715+
One of dims_dict or dims_kwargs must be provided.
1716+
17111717
Returns
17121718
-------
17131719
swapped : DataArray
@@ -1749,6 +1755,7 @@ def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "DataArray":
17491755
DataArray.rename
17501756
Dataset.swap_dims
17511757
"""
1758+
dims_dict = either_dict_or_kwargs(dims_dict, dims_kwargs, "swap_dims")
17521759
ds = self._to_temp_dataset().swap_dims(dims_dict)
17531760
return self._from_temp_dataset(ds)
17541761

xarray/core/dataset.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -3155,7 +3155,9 @@ def rename_vars(
31553155
)
31563156
return self._replace(variables, coord_names, dims=dims, indexes=indexes)
31573157

3158-
def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "Dataset":
3158+
def swap_dims(
3159+
self, dims_dict: Mapping[Hashable, Hashable] = None, **dims_kwargs
3160+
) -> "Dataset":
31593161
"""Returns a new object with swapped dimensions.
31603162
31613163
Parameters
@@ -3164,6 +3166,10 @@ def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "Dataset":
31643166
Dictionary whose keys are current dimension names and whose values
31653167
are new names.
31663168
3169+
**dim_kwargs : {existing_dim: new_dim, ...}, optional
3170+
The keyword arguments form of ``dims_dict``.
3171+
One of dims_dict or dims_kwargs must be provided.
3172+
31673173
Returns
31683174
-------
31693175
swapped : Dataset
@@ -3214,6 +3220,8 @@ def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "Dataset":
32143220
"""
32153221
# TODO: deprecate this method in favor of a (less confusing)
32163222
# rename_dims() method that only renames dimensions.
3223+
3224+
dims_dict = either_dict_or_kwargs(dims_dict, dims_kwargs, "swap_dims")
32173225
for k, v in dims_dict.items():
32183226
if k not in self.dims:
32193227
raise ValueError(

xarray/tests/test_dataarray.py

+10
Original file line numberDiff line numberDiff line change
@@ -1639,6 +1639,16 @@ def test_swap_dims(self):
16391639
expected.indexes[dim_name], actual.indexes[dim_name]
16401640
)
16411641

1642+
# as kwargs
1643+
array = DataArray(np.random.randn(3), {"x": list("abc")}, "x")
1644+
expected = DataArray(array.values, {"x": ("y", list("abc"))}, dims="y")
1645+
actual = array.swap_dims(x="y")
1646+
assert_identical(expected, actual)
1647+
for dim_name in set().union(expected.indexes.keys(), actual.indexes.keys()):
1648+
pd.testing.assert_index_equal(
1649+
expected.indexes[dim_name], actual.indexes[dim_name]
1650+
)
1651+
16421652
# multiindex case
16431653
idx = pd.MultiIndex.from_arrays([list("aab"), list("yzz")], names=["y1", "y2"])
16441654
array = DataArray(np.random.randn(3), {"y": ("x", idx)}, "x")

xarray/tests/test_dataset.py

+7
Original file line numberDiff line numberDiff line change
@@ -2748,6 +2748,13 @@ def test_swap_dims(self):
27482748
actual = original.swap_dims({"x": "u"})
27492749
assert_identical(expected, actual)
27502750

2751+
# as kwargs
2752+
expected = Dataset(
2753+
{"y": ("u", list("abc")), "z": 42}, coords={"x": ("u", [1, 2, 3])}
2754+
)
2755+
actual = original.swap_dims(x="u")
2756+
assert_identical(expected, actual)
2757+
27512758
# handle multiindex case
27522759
idx = pd.MultiIndex.from_arrays([list("aab"), list("yzz")], names=["y1", "y2"])
27532760
original = Dataset({"x": [1, 2, 3], "y": ("x", idx), "z": 42})

0 commit comments

Comments
 (0)