Skip to content

Commit 18e34cc

Browse files
jaicherdcherian
andauthored
Fix swap_dims() index names (issue #3748) (#3752)
* Added test for GH3748 * Rename newly created index in swap_dims() to dim name if not multiindex Fixes GH3748 * Updated whats-new.rst with pull request information for swap_dims fix * Move tests for GH3748 into existing swap_dims tests + integrated new tests for GH3748 for DataArray into existing swap_dims tests + added similar tests for Dataset + added test for multiindex case Co-authored-by: Deepak Cherian <[email protected]>
1 parent b14eea2 commit 18e34cc

File tree

4 files changed

+40
-3
lines changed

4 files changed

+40
-3
lines changed

doc/whats-new.rst

+5-1
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,14 @@ New Features
3434

3535
Bug fixes
3636
~~~~~~~~~
37+
38+
- Fix :py:meth:`Dataset.swap_dims` and :py:meth:`DataArray.swap_dims` producing
39+
index with name reflecting the previous dimension name instead of the new one
40+
(:issue:`3748`, :pull:`3752`). By `Joseph K Aicher
41+
<https://github.com/jaicher>`_.
3742
- Use ``dask_array_type`` instead of ``dask_array.Array`` for type
3843
checking. (:issue:`3779`, :pull:`3787`)
3944
By `Justus Magin <https://github.com/keewis>`_.
40-
4145
- :py:func:`concat` can now handle coordinate variables only present in one of
4246
the objects to be concatenated when ``coords="different"``.
4347
By `Deepak Cherian <https://github.com/dcherian>`_.

xarray/core/dataset.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -2948,7 +2948,11 @@ def swap_dims(
29482948
if k in self.indexes:
29492949
indexes[k] = self.indexes[k]
29502950
else:
2951-
indexes[k] = var.to_index()
2951+
new_index = var.to_index()
2952+
if new_index.nlevels == 1:
2953+
# make sure index name matches dimension name
2954+
new_index = new_index.rename(k)
2955+
indexes[k] = new_index
29522956
else:
29532957
var = v.to_base_variable()
29542958
var.dims = dims

xarray/tests/test_dataarray.py

+19
Original file line numberDiff line numberDiff line change
@@ -1536,11 +1536,30 @@ def test_swap_dims(self):
15361536
expected = DataArray(array.values, {"y": list("abc")}, dims="y")
15371537
actual = array.swap_dims({"x": "y"})
15381538
assert_identical(expected, actual)
1539+
for dim_name in set().union(expected.indexes.keys(), actual.indexes.keys()):
1540+
pd.testing.assert_index_equal(
1541+
expected.indexes[dim_name], actual.indexes[dim_name]
1542+
)
15391543

15401544
array = DataArray(np.random.randn(3), {"x": list("abc")}, "x")
15411545
expected = DataArray(array.values, {"x": ("y", list("abc"))}, dims="y")
15421546
actual = array.swap_dims({"x": "y"})
15431547
assert_identical(expected, actual)
1548+
for dim_name in set().union(expected.indexes.keys(), actual.indexes.keys()):
1549+
pd.testing.assert_index_equal(
1550+
expected.indexes[dim_name], actual.indexes[dim_name]
1551+
)
1552+
1553+
# multiindex case
1554+
idx = pd.MultiIndex.from_arrays([list("aab"), list("yzz")], names=["y1", "y2"])
1555+
array = DataArray(np.random.randn(3), {"y": ("x", idx)}, "x")
1556+
expected = DataArray(array.values, {"y": idx}, "y")
1557+
actual = array.swap_dims({"x": "y"})
1558+
assert_identical(expected, actual)
1559+
for dim_name in set().union(expected.indexes.keys(), actual.indexes.keys()):
1560+
pd.testing.assert_index_equal(
1561+
expected.indexes[dim_name], actual.indexes[dim_name]
1562+
)
15441563

15451564
def test_expand_dims_error(self):
15461565
array = DataArray(

xarray/tests/test_dataset.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -2596,7 +2596,7 @@ def test_swap_dims(self):
25962596
assert_identical(expected, actual)
25972597
assert isinstance(actual.variables["y"], IndexVariable)
25982598
assert isinstance(actual.variables["x"], Variable)
2599-
assert actual.indexes["y"].equals(pd.Index(list("abc")))
2599+
pd.testing.assert_index_equal(actual.indexes["y"], expected.indexes["y"])
26002600

26012601
roundtripped = actual.swap_dims({"y": "x"})
26022602
assert_identical(original.set_coords("y"), roundtripped)
@@ -2612,6 +2612,16 @@ def test_swap_dims(self):
26122612
actual = original.swap_dims({"x": "u"})
26132613
assert_identical(expected, actual)
26142614

2615+
# handle multiindex case
2616+
idx = pd.MultiIndex.from_arrays([list("aab"), list("yzz")], names=["y1", "y2"])
2617+
original = Dataset({"x": [1, 2, 3], "y": ("x", idx), "z": 42})
2618+
expected = Dataset({"z": 42}, {"x": ("y", [1, 2, 3]), "y": idx})
2619+
actual = original.swap_dims({"x": "y"})
2620+
assert_identical(expected, actual)
2621+
assert isinstance(actual.variables["y"], IndexVariable)
2622+
assert isinstance(actual.variables["x"], Variable)
2623+
pd.testing.assert_index_equal(actual.indexes["y"], expected.indexes["y"])
2624+
26152625
def test_expand_dims_error(self):
26162626
original = Dataset(
26172627
{

0 commit comments

Comments
 (0)