Skip to content

Commit 3141a05

Browse files
committed
update
1 parent 63ea00e commit 3141a05

File tree

2 files changed

+48
-12
lines changed

2 files changed

+48
-12
lines changed

cf_xarray/groupers.py

Lines changed: 17 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1,27 +1,33 @@
1+
from dataclasses import dataclass
2+
13
import numpy as np
24
import pandas as pd
3-
from xarray.groupers import EncodedGroups, Grouper
5+
from xarray.groupers import EncodedGroups, UniqueGrouper
46

57

6-
class FlagGrouper(Grouper):
8+
@dataclass
9+
class FlagGrouper(UniqueGrouper):
710
def factorize(self, group) -> EncodedGroups:
8-
assert "flag_values" in group.attrs
9-
assert "flag_meanings" in group.attrs
11+
if "flag_values" not in group.attrs or "flag_meanings" not in group.attrs:
12+
raise ValueError(
13+
"FlagGrouper can only be used with flag variables that have "
14+
"`flag_values` and `flag_meanings` specified in attrs."
15+
)
1016

1117
values = np.array(group.attrs["flag_values"])
1218
full_index = pd.Index(group.attrs["flag_meanings"].split(" "))
1319

14-
if group.dtype.kind in "iu" and (np.diff(values) == 1).all():
15-
# optimize
16-
codes = group.data - values[0].astype(int)
17-
else:
18-
codes, _ = pd.factorize(group.data.ravel())
20+
self.labels = values
21+
ret = super().factorize(group)
1922

20-
codes_da = group.copy(data=codes.reshape(group.shape))
23+
codes_da = ret.codes
2124
codes_da.attrs.pop("flag_values")
2225
codes_da.attrs.pop("flag_meanings")
2326

24-
return EncodedGroups(codes=codes_da, full_index=full_index)
27+
ret.codes = codes_da
28+
ret.full_index = full_index
29+
30+
return ret
2531

2632
def reset(self):
2733
pass

cf_xarray/tests/test_groupers.py

Lines changed: 31 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,45 @@
11
import numpy as np
2+
import pytest
3+
import xarray as xr
24
from xarray.testing import assert_identical
35

46
from cf_xarray.datasets import flag_excl
57
from cf_xarray.groupers import FlagGrouper
68

79

810
def test_flag_grouper():
9-
ds = flag_excl.to_dataset().set_coords("flag_var")
11+
ds = flag_excl.to_dataset().set_coords("flag_var").copy(deep=True)
1012
ds["foo"] = ("time", np.arange(8))
1113
actual = ds.groupby(flag_var=FlagGrouper()).mean()
1214
expected = ds.groupby("flag_var").mean()
1315
expected["flag_var"] = ["flag_1", "flag_2", "flag_3"]
1416
expected["flag_var"].attrs["standard_name"] = "flag_mutual_exclusive"
1517
assert_identical(actual, expected)
18+
19+
del ds.flag_var.attrs["flag_values"]
20+
with pytest.raises(ValueError):
21+
ds.groupby(flag_var=FlagGrouper())
22+
23+
ds.flag_var.attrs["flag_values"] = [0, 1, 2]
24+
del ds.flag_var.attrs["flag_meanings"]
25+
with pytest.raises(ValueError):
26+
ds.groupby(flag_var=FlagGrouper())
27+
28+
29+
@pytest.mark.parametrize(
30+
"values",
31+
[
32+
[1, 2],
33+
[1, 2, 3], # value out of range of flag_values
34+
],
35+
)
36+
def test_flag_grouper_optimized(values):
37+
ds = xr.Dataset(
38+
{"foo": ("x", values, {"flag_values": [0, 1, 2], "flag_meanings": "a b c"})}
39+
)
40+
ret = FlagGrouper().factorize(ds.foo)
41+
expected = ds.foo
42+
expected.data[ds.foo.data > 2] = -1
43+
del ds.foo.attrs["flag_meanings"]
44+
del ds.foo.attrs["flag_values"]
45+
assert_identical(ret.codes, ds.foo)

0 commit comments

Comments
 (0)