Skip to content

Commit 0e936eb

Browse files
Gobot1234cetanu
authored andcommitted
Map enum int's into Enums redux (#293)
Re-implement Enum to be faster along with being an open set --------- Co-authored-by: ydylla <[email protected]>
1 parent 0b0345f commit 0e936eb

File tree

6 files changed

+725
-511
lines changed

6 files changed

+725
-511
lines changed

poetry.lock

+417-480
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ importlib-metadata = { version = ">=1.6.0", python = "<3.8" }
1919
jinja2 = { version = ">=3.0.3", optional = true }
2020
python-dateutil = "^2.8"
2121
isort = {version = "^5.11.5", optional = true}
22+
typing-extensions = "^4.7.1"
2223

2324
[tool.poetry.dev-dependencies]
2425
asv = "^0.4.2"
@@ -37,7 +38,7 @@ sphinx-rtd-theme = "0.5.0"
3738
tomlkit = "^0.7.0"
3839
tox = "^3.15.1"
3940
pre-commit = "^2.17.0"
40-
pydantic = ">=1.8.0"
41+
pydantic = ">=1.8.0,<2"
4142

4243

4344
[tool.poetry.scripts]

src/betterproto/__init__.py

+8-30
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import dataclasses
2-
import enum
2+
import enum as builtin_enum
33
import json
44
import math
55
import struct
@@ -45,7 +45,8 @@
4545
safe_snake_case,
4646
snake_case,
4747
)
48-
from .grpc.grpclib_client import ServiceStub
48+
from .enum import Enum as Enum
49+
from .grpc.grpclib_client import ServiceStub as ServiceStub
4950

5051

5152
if TYPE_CHECKING:
@@ -140,7 +141,7 @@ def datetime_default_gen() -> datetime:
140141
NAN = "NaN"
141142

142143

143-
class Casing(enum.Enum):
144+
class Casing(builtin_enum.Enum):
144145
"""Casing constants for serialization."""
145146

146147
CAMEL = camel_case #: A camelCase sterilization function.
@@ -309,32 +310,6 @@ def map_field(
309310
)
310311

311312

312-
class Enum(enum.IntEnum):
313-
"""
314-
The base class for protobuf enumerations, all generated enumerations will inherit
315-
from this. Bases :class:`enum.IntEnum`.
316-
"""
317-
318-
@classmethod
319-
def from_string(cls, name: str) -> "Enum":
320-
"""Return the value which corresponds to the string name.
321-
322-
Parameters
323-
-----------
324-
name: :class:`str`
325-
The name of the enum member to get
326-
327-
Raises
328-
-------
329-
:exc:`ValueError`
330-
The member was not found in the Enum.
331-
"""
332-
try:
333-
return cls._member_map_[name] # type: ignore
334-
except KeyError as e:
335-
raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e
336-
337-
338313
def _pack_fmt(proto_type: str) -> str:
339314
"""Returns a little-endian format string for reading/writing binary."""
340315
return {
@@ -1185,7 +1160,7 @@ def _get_field_default_gen(cls, field: dataclasses.Field) -> Any:
11851160
return t
11861161
elif issubclass(t, Enum):
11871162
# Enums always default to zero.
1188-
return int
1163+
return t.try_value
11891164
elif t is datetime:
11901165
# Offsets are relative to 1970-01-01T00:00:00Z
11911166
return datetime_default_gen
@@ -1210,6 +1185,9 @@ def _postprocess_single(
12101185
elif meta.proto_type == TYPE_BOOL:
12111186
# Booleans use a varint encoding, so convert it to true/false.
12121187
value = value > 0
1188+
elif meta.proto_type == TYPE_ENUM:
1189+
# Convert enum ints to python enum instances
1190+
value = self._betterproto.cls_by_field[field_name].try_value(value)
12131191
elif wire_type in (WIRE_FIXED_32, WIRE_FIXED_64):
12141192
fmt = _pack_fmt(meta.proto_type)
12151193
value = struct.unpack(fmt, value)[0]

src/betterproto/enum.py

+199
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
from __future__ import annotations
2+
3+
import sys
4+
from enum import (
5+
EnumMeta,
6+
IntEnum,
7+
)
8+
from types import MappingProxyType
9+
from typing import (
10+
TYPE_CHECKING,
11+
Any,
12+
Dict,
13+
Optional,
14+
Tuple,
15+
)
16+
17+
18+
if TYPE_CHECKING:
19+
from collections.abc import (
20+
Generator,
21+
Mapping,
22+
)
23+
24+
from typing_extensions import (
25+
Never,
26+
Self,
27+
)
28+
29+
30+
def _is_descriptor(obj: object) -> bool:
31+
return (
32+
hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__")
33+
)
34+
35+
36+
class EnumType(EnumMeta if TYPE_CHECKING else type):
37+
_value_map_: Mapping[int, Enum]
38+
_member_map_: Mapping[str, Enum]
39+
40+
def __new__(
41+
mcs, name: str, bases: Tuple[type, ...], namespace: Dict[str, Any]
42+
) -> Self:
43+
value_map = {}
44+
member_map = {}
45+
46+
new_mcs = type(
47+
f"{name}Type",
48+
tuple(
49+
dict.fromkeys(
50+
[base.__class__ for base in bases if base.__class__ is not type]
51+
+ [EnumType, type]
52+
)
53+
), # reorder the bases so EnumType and type are last to avoid conflicts
54+
{"_value_map_": value_map, "_member_map_": member_map},
55+
)
56+
57+
members = {
58+
name: value
59+
for name, value in namespace.items()
60+
if not _is_descriptor(value) and name[0] != "_"
61+
}
62+
63+
cls = type.__new__(
64+
new_mcs,
65+
name,
66+
bases,
67+
{key: value for key, value in namespace.items() if key not in members},
68+
)
69+
# this allows us to disallow member access from other members as
70+
# members become proper class variables
71+
72+
for name, value in members.items():
73+
if _is_descriptor(value) or name[0] == "_":
74+
continue
75+
76+
member = value_map.get(value)
77+
if member is None:
78+
member = cls.__new__(cls, name=name, value=value) # type: ignore
79+
value_map[value] = member
80+
member_map[name] = member
81+
type.__setattr__(new_mcs, name, member)
82+
83+
return cls
84+
85+
if not TYPE_CHECKING:
86+
87+
def __call__(cls, value: int) -> Enum:
88+
try:
89+
return cls._value_map_[value]
90+
except (KeyError, TypeError):
91+
raise ValueError(f"{value!r} is not a valid {cls.__name__}") from None
92+
93+
def __iter__(cls) -> Generator[Enum, None, None]:
94+
yield from cls._member_map_.values()
95+
96+
if sys.version_info >= (3, 8): # 3.8 added __reversed__ to dict_values
97+
98+
def __reversed__(cls) -> Generator[Enum, None, None]:
99+
yield from reversed(cls._member_map_.values())
100+
101+
else:
102+
103+
def __reversed__(cls) -> Generator[Enum, None, None]:
104+
yield from reversed(tuple(cls._member_map_.values()))
105+
106+
def __getitem__(cls, key: str) -> Enum:
107+
return cls._member_map_[key]
108+
109+
@property
110+
def __members__(cls) -> MappingProxyType[str, Enum]:
111+
return MappingProxyType(cls._member_map_)
112+
113+
def __repr__(cls) -> str:
114+
return f"<enum {cls.__name__!r}>"
115+
116+
def __len__(cls) -> int:
117+
return len(cls._member_map_)
118+
119+
def __setattr__(cls, name: str, value: Any) -> Never:
120+
raise AttributeError(f"{cls.__name__}: cannot reassign Enum members.")
121+
122+
def __delattr__(cls, name: str) -> Never:
123+
raise AttributeError(f"{cls.__name__}: cannot delete Enum members.")
124+
125+
def __contains__(cls, member: object) -> bool:
126+
return isinstance(member, cls) and member.name in cls._member_map_
127+
128+
129+
class Enum(IntEnum if TYPE_CHECKING else int, metaclass=EnumType):
130+
"""
131+
The base class for protobuf enumerations, all generated enumerations will
132+
inherit from this. Emulates `enum.IntEnum`.
133+
"""
134+
135+
name: Optional[str]
136+
value: int
137+
138+
if not TYPE_CHECKING:
139+
140+
def __new__(cls, *, name: Optional[str], value: int) -> Self:
141+
self = super().__new__(cls, value)
142+
super().__setattr__(self, "name", name)
143+
super().__setattr__(self, "value", value)
144+
return self
145+
146+
def __str__(self) -> str:
147+
return self.name or "None"
148+
149+
def __repr__(self) -> str:
150+
return f"{self.__class__.__name__}.{self.name}"
151+
152+
def __setattr__(self, key: str, value: Any) -> Never:
153+
raise AttributeError(
154+
f"{self.__class__.__name__} Cannot reassign a member's attributes."
155+
)
156+
157+
def __delattr__(self, item: Any) -> Never:
158+
raise AttributeError(
159+
f"{self.__class__.__name__} Cannot delete a member's attributes."
160+
)
161+
162+
@classmethod
163+
def try_value(cls, value: int = 0) -> Self:
164+
"""Return the value which corresponds to the value.
165+
166+
Parameters
167+
-----------
168+
value: :class:`int`
169+
The value of the enum member to get.
170+
171+
Returns
172+
-------
173+
:class:`Enum`
174+
The corresponding member or a new instance of the enum if
175+
``value`` isn't actually a member.
176+
"""
177+
try:
178+
return cls._value_map_[value]
179+
except (KeyError, TypeError):
180+
return cls.__new__(cls, name=None, value=value)
181+
182+
@classmethod
183+
def from_string(cls, name: str) -> Self:
184+
"""Return the value which corresponds to the string name.
185+
186+
Parameters
187+
-----------
188+
name: :class:`str`
189+
The name of the enum member to get.
190+
191+
Raises
192+
-------
193+
:exc:`ValueError`
194+
The member was not found in the Enum.
195+
"""
196+
try:
197+
return cls._member_map_[name]
198+
except KeyError as e:
199+
raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e

tests/inputs/enum/test_enum.py

+20
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,23 @@ def enum_generator():
8282
yield Choice.THREE
8383

8484
assert Test(choices=enum_generator()).to_dict()["choices"] == ["ONE", "THREE"]
85+
86+
87+
def test_enum_mapped_on_parse():
88+
# test default value
89+
b = Test().parse(bytes(Test()))
90+
assert b.choice.name == Choice.ZERO.name
91+
assert b.choices == []
92+
93+
# test non default value
94+
a = Test().parse(bytes(Test(choice=Choice.ONE)))
95+
assert a.choice.name == Choice.ONE.name
96+
assert b.choices == []
97+
98+
# test repeated
99+
c = Test().parse(bytes(Test(choices=[Choice.THREE, Choice.FOUR])))
100+
assert c.choices[0].name == Choice.THREE.name
101+
assert c.choices[1].name == Choice.FOUR.name
102+
103+
# bonus: defaults after empty init are also mapped
104+
assert Test().choice.name == Choice.ZERO.name

tests/test_enum.py

+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
from typing import (
2+
Optional,
3+
Tuple,
4+
)
5+
6+
import pytest
7+
8+
import betterproto
9+
10+
11+
class Colour(betterproto.Enum):
12+
RED = 1
13+
GREEN = 2
14+
BLUE = 3
15+
16+
17+
PURPLE = Colour.__new__(Colour, name=None, value=4)
18+
19+
20+
@pytest.mark.parametrize(
21+
"member, str_value",
22+
[
23+
(Colour.RED, "RED"),
24+
(Colour.GREEN, "GREEN"),
25+
(Colour.BLUE, "BLUE"),
26+
],
27+
)
28+
def test_str(member: Colour, str_value: str) -> None:
29+
assert str(member) == str_value
30+
31+
32+
@pytest.mark.parametrize(
33+
"member, repr_value",
34+
[
35+
(Colour.RED, "Colour.RED"),
36+
(Colour.GREEN, "Colour.GREEN"),
37+
(Colour.BLUE, "Colour.BLUE"),
38+
],
39+
)
40+
def test_repr(member: Colour, repr_value: str) -> None:
41+
assert repr(member) == repr_value
42+
43+
44+
@pytest.mark.parametrize(
45+
"member, values",
46+
[
47+
(Colour.RED, ("RED", 1)),
48+
(Colour.GREEN, ("GREEN", 2)),
49+
(Colour.BLUE, ("BLUE", 3)),
50+
(PURPLE, (None, 4)),
51+
],
52+
)
53+
def test_name_values(member: Colour, values: Tuple[Optional[str], int]) -> None:
54+
assert (member.name, member.value) == values
55+
56+
57+
@pytest.mark.parametrize(
58+
"member, input_str",
59+
[
60+
(Colour.RED, "RED"),
61+
(Colour.GREEN, "GREEN"),
62+
(Colour.BLUE, "BLUE"),
63+
],
64+
)
65+
def test_from_string(member: Colour, input_str: str) -> None:
66+
assert Colour.from_string(input_str) == member
67+
68+
69+
@pytest.mark.parametrize(
70+
"member, input_int",
71+
[
72+
(Colour.RED, 1),
73+
(Colour.GREEN, 2),
74+
(Colour.BLUE, 3),
75+
(PURPLE, 4),
76+
],
77+
)
78+
def test_try_value(member: Colour, input_int: int) -> None:
79+
assert Colour.try_value(input_int) == member

0 commit comments

Comments
 (0)