Skip to content
forked from pydata/xarray

Commit 4eefd25

Browse files
committed
Refactor concat tests.
1 parent e0bddcb commit 4eefd25

File tree

1 file changed

+42
-43
lines changed

1 file changed

+42
-43
lines changed

xarray/tests/test_concat.py

+42-43
Original file line number | Diff line number | Diff line change
@@ -18,61 +18,60 @@
1818

1919

2020
class TestConcatDataset:
21-
def test_concat(self):
22-
# TODO: simplify and split this test case
23-
24-
# drop the third dimension to keep things relatively understandable
25-
data = create_test_data()
26-
for k in list(data.variables):
27-
if "dim3" in data[k].dims:
28-
del data[k]
29-
30-
split_data = [data.isel(dim1=slice(3)), data.isel(dim1=slice(3, None))]
31-
assert_identical(data, concat(split_data, "dim1"))
32-
33-
def rectify_dim_order(dataset):
34-
# return a new dataset with all variable dimensions transposed into
35-
# the order in which they are found in `data`
36-
return Dataset(
37-
{k: v.transpose(*data[k].dims) for k, v in dataset.data_vars.items()},
38-
dataset.coords,
39-
attrs=dataset.attrs,
40-
)
41-
42-
for dim in ["dim1", "dim2"]:
43-
datasets = [g for _, g in data.groupby(dim, squeeze=False)]
44-
assert_identical(data, concat(datasets, dim))
21+
@pytest.fixture(autouse=True)
22+
def setUp(self):
23+
self.data = create_test_data().drop_dims("dim3")
24+
25+
def rectify_dim_order(self, dataset):
26+
# return a new dataset with all variable dimensions transposed into
27+
# the order in which they are found in `data`
28+
return Dataset(
29+
{k: v.transpose(*self.data[k].dims) for k, v in dataset.data_vars.items()},
30+
dataset.coords,
31+
attrs=dataset.attrs,
32+
)
4533

46-
dim = "dim2"
47-
assert_identical(data, concat(datasets, data[dim]))
48-
assert_identical(data, concat(datasets, data[dim], coords="minimal"))
34+
@pytest.mark.parametrize("coords", ["different", "minimal"])
35+
@pytest.mark.parametrize("dim", ["dim1", "dim2"])
36+
def test_concat_simple(self, dim, coords):
37+
datasets = [g for _, g in self.data.groupby(dim, squeeze=False)]
38+
assert_identical(self.data, concat(datasets, dim, coords=coords))
4939

50-
datasets = [g for _, g in data.groupby(dim, squeeze=True)]
51-
concat_over = [k for k, v in data.coords.items() if dim in v.dims and k != dim]
52-
actual = concat(datasets, data[dim], coords=concat_over)
53-
assert_identical(data, rectify_dim_order(actual))
54-
55-
actual = concat(datasets, data[dim], coords="different")
56-
assert_identical(data, rectify_dim_order(actual))
40+
datasets = [g for _, g in self.data.groupby(dim, squeeze=True)]
41+
concat_over = [
42+
k for k, v in self.data.coords.items() if dim in v.dims and k != dim
43+
]
44+
actual = concat(datasets, self.data[dim], coords=concat_over)
45+
assert_identical(self.data, self.rectify_dim_order(actual))
5746

47+
@pytest.mark.parametrize("coords", ["different", "minimal", "all"])
48+
@pytest.mark.parametrize("dim", ["dim1", "dim2"])
49+
def test_concat_coords_kwarg(self, dim, coords):
50+
data = self.data.copy(deep=True)
5851
# make sure the coords argument behaves as expected
5952
data.coords["extra"] = ("dim4", np.arange(3))
60-
for dim in ["dim1", "dim2"]:
61-
datasets = [g for _, g in data.groupby(dim, squeeze=True)]
62-
actual = concat(datasets, data[dim], coords="all")
53+
datasets = [g for _, g in data.groupby(dim, squeeze=True)]
54+
55+
actual = concat(datasets, data[dim], coords=coords)
56+
if coords == "all":
6357
expected = np.array([data["extra"].values for _ in range(data.dims[dim])])
6458
assert_array_equal(actual["extra"].values, expected)
6559

66-
actual = concat(datasets, data[dim], coords="different")
67-
assert_equal(data["extra"], actual["extra"])
68-
actual = concat(datasets, data[dim], coords="minimal")
60+
else:
6961
assert_equal(data["extra"], actual["extra"])
7062

63+
def test_concat(self):
64+
65+
data = self.data
66+
split_data = [data.isel(dim1=slice(3)), data.isel(dim1=slice(3, None))]
67+
assert_identical(data, concat(split_data, "dim1"))
68+
69+
def test_concat_dim_precedence(self):
7170
# verify that the dim argument takes precedence over
7271
# concatenating dataset variables of the same name
73-
dim = (2 * data["dim1"]).rename("dim1")
74-
datasets = [g for _, g in data.groupby("dim1", squeeze=False)]
75-
expected = data.copy()
72+
dim = (2 * self.data["dim1"]).rename("dim1")
73+
datasets = [g for _, g in self.data.groupby("dim1", squeeze=False)]
74+
expected = self.data.copy()
7675
expected["dim1"] = dim
7776
assert_identical(expected, concat(datasets, dim))
7877

0 commit comments

Comments
 (0)