|
18 | 18 |
|
19 | 19 |
|
20 | 20 | class TestConcatDataset:
|
21 |
| - def test_concat(self): |
22 |
| - # TODO: simplify and split this test case |
23 |
| - |
24 |
| - # drop the third dimension to keep things relatively understandable |
25 |
| - data = create_test_data() |
26 |
| - for k in list(data.variables): |
27 |
| - if "dim3" in data[k].dims: |
28 |
| - del data[k] |
29 |
| - |
30 |
| - split_data = [data.isel(dim1=slice(3)), data.isel(dim1=slice(3, None))] |
31 |
| - assert_identical(data, concat(split_data, "dim1")) |
32 |
| - |
33 |
| - def rectify_dim_order(dataset): |
34 |
| - # return a new dataset with all variable dimensions transposed into |
35 |
| - # the order in which they are found in `data` |
36 |
| - return Dataset( |
37 |
| - {k: v.transpose(*data[k].dims) for k, v in dataset.data_vars.items()}, |
38 |
| - dataset.coords, |
39 |
| - attrs=dataset.attrs, |
40 |
| - ) |
41 |
| - |
42 |
| - for dim in ["dim1", "dim2"]: |
43 |
| - datasets = [g for _, g in data.groupby(dim, squeeze=False)] |
44 |
| - assert_identical(data, concat(datasets, dim)) |
| 21 | + @pytest.fixture(autouse=True) |
| 22 | + def setUp(self): |
| 23 | + self.data = create_test_data().drop_dims("dim3") |
| 24 | + |
| 25 | + def rectify_dim_order(self, dataset): |
| 26 | + # return a new dataset with all variable dimensions transposed into |
| 27 | + # the order in which they are found in `data` |
| 28 | + return Dataset( |
| 29 | + {k: v.transpose(*self.data[k].dims) for k, v in dataset.data_vars.items()}, |
| 30 | + dataset.coords, |
| 31 | + attrs=dataset.attrs, |
| 32 | + ) |
45 | 33 |
|
46 |
| - dim = "dim2" |
47 |
| - assert_identical(data, concat(datasets, data[dim])) |
48 |
| - assert_identical(data, concat(datasets, data[dim], coords="minimal")) |
| 34 | + @pytest.mark.parametrize("coords", ["different", "minimal"]) |
| 35 | + @pytest.mark.parametrize("dim", ["dim1", "dim2"]) |
| 36 | + def test_concat_simple(self, dim, coords): |
| 37 | + datasets = [g for _, g in self.data.groupby(dim, squeeze=False)] |
| 38 | + assert_identical(self.data, concat(datasets, dim, coords=coords)) |
49 | 39 |
|
50 |
| - datasets = [g for _, g in data.groupby(dim, squeeze=True)] |
51 |
| - concat_over = [k for k, v in data.coords.items() if dim in v.dims and k != dim] |
52 |
| - actual = concat(datasets, data[dim], coords=concat_over) |
53 |
| - assert_identical(data, rectify_dim_order(actual)) |
54 |
| - |
55 |
| - actual = concat(datasets, data[dim], coords="different") |
56 |
| - assert_identical(data, rectify_dim_order(actual)) |
| 40 | + datasets = [g for _, g in self.data.groupby(dim, squeeze=True)] |
| 41 | + concat_over = [ |
| 42 | + k for k, v in self.data.coords.items() if dim in v.dims and k != dim |
| 43 | + ] |
| 44 | + actual = concat(datasets, self.data[dim], coords=concat_over) |
| 45 | + assert_identical(self.data, self.rectify_dim_order(actual)) |
57 | 46 |
|
| 47 | + @pytest.mark.parametrize("coords", ["different", "minimal", "all"]) |
| 48 | + @pytest.mark.parametrize("dim", ["dim1", "dim2"]) |
| 49 | + def test_concat_coords_kwarg(self, dim, coords): |
| 50 | + data = self.data.copy(deep=True) |
58 | 51 | # make sure the coords argument behaves as expected
|
59 | 52 | data.coords["extra"] = ("dim4", np.arange(3))
|
60 |
| - for dim in ["dim1", "dim2"]: |
61 |
| - datasets = [g for _, g in data.groupby(dim, squeeze=True)] |
62 |
| - actual = concat(datasets, data[dim], coords="all") |
| 53 | + datasets = [g for _, g in data.groupby(dim, squeeze=True)] |
| 54 | + |
| 55 | + actual = concat(datasets, data[dim], coords=coords) |
| 56 | + if coords == "all": |
63 | 57 | expected = np.array([data["extra"].values for _ in range(data.dims[dim])])
|
64 | 58 | assert_array_equal(actual["extra"].values, expected)
|
65 | 59 |
|
66 |
| - actual = concat(datasets, data[dim], coords="different") |
67 |
| - assert_equal(data["extra"], actual["extra"]) |
68 |
| - actual = concat(datasets, data[dim], coords="minimal") |
| 60 | + else: |
69 | 61 | assert_equal(data["extra"], actual["extra"])
|
70 | 62 |
|
| 63 | + def test_concat(self): |
| 64 | + |
| 65 | + data = self.data |
| 66 | + split_data = [data.isel(dim1=slice(3)), data.isel(dim1=slice(3, None))] |
| 67 | + assert_identical(data, concat(split_data, "dim1")) |
| 68 | + |
| 69 | + def test_concat_dim_precedence(self): |
71 | 70 | # verify that the dim argument takes precedence over
|
72 | 71 | # concatenating dataset variables of the same name
|
73 |
| - dim = (2 * data["dim1"]).rename("dim1") |
74 |
| - datasets = [g for _, g in data.groupby("dim1", squeeze=False)] |
75 |
| - expected = data.copy() |
| 72 | + dim = (2 * self.data["dim1"]).rename("dim1") |
| 73 | + datasets = [g for _, g in self.data.groupby("dim1", squeeze=False)] |
| 74 | + expected = self.data.copy() |
76 | 75 | expected["dim1"] = dim
|
77 | 76 | assert_identical(expected, concat(datasets, dim))
|
78 | 77 |
|
|
0 commit comments