Skip to content

Map enum int's into Enums redux #293

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 35 commits into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
735d473
Add issue templates
Gobot1234 Oct 19, 2020
7ca7506
Update to the newer issue templates
Gobot1234 Oct 26, 2021
7dc936d
Re-implement Enum to be faster along with being an open set
Gobot1234 Nov 19, 2021
079d53f
Extra stuff
Gobot1234 Nov 19, 2021
9f0e388
Update __str__
Gobot1234 Nov 29, 2021
25c037a
Remove issue templates
Gobot1234 Nov 29, 2021
8eb8325
Final few things
Gobot1234 Nov 29, 2021
0f503df
Remove PR templates
Gobot1234 Nov 29, 2021
04e6ab6
Merge branch 'master' into enum
Gobot1234 Nov 29, 2021
67cc0cd
Update poetry.lock
Gobot1234 Nov 29, 2021
c7d407e
Fix 3.9 type usage
Gobot1234 Nov 29, 2021
9f8668e
Fix 585 usage
Gobot1234 Dec 24, 2021
6d6468c
Merge branch 'enum' of https://github.com/Gobot1234/python-betterprot…
Gobot1234 Dec 24, 2021
cd86e94
Lock again
Gobot1234 Dec 24, 2021
549a1ff
Merge branch 'master' into enum
Gobot1234 Dec 24, 2021
a6d6616
Fix 585 stuff
Gobot1234 Dec 24, 2021
28b8e9b
Merge branch 'enum' of https://github.com/Gobot1234/python-betterprot…
Gobot1234 Dec 24, 2021
27bc065
Fix regression
Gobot1234 Dec 24, 2021
02a96c9
Fix another regression
Gobot1234 Dec 24, 2021
adb2d77
Merge branch 'master' into enum
Gobot1234 Apr 20, 2022
6b0331b
Relock and move Enum into if not TYPE_CHECKING for mypy
Gobot1234 Apr 20, 2022
0b4ecbd
IDK why the lock isn't ok
Gobot1234 Apr 21, 2022
432945f
Update black and relock?
Gobot1234 Apr 21, 2022
3e23469
Relock
Gobot1234 Apr 23, 2022
30249b1
Remove annotation and parsing from int
Gobot1234 Apr 23, 2022
a74c210
Relock again
Gobot1234 Apr 23, 2022
72696d3
Merge branch 'master' into enum
Gobot1234 Aug 31, 2022
bd249c4
Relock and make typecheckers happier
Gobot1234 Aug 31, 2022
74b66b1
Remove unused TypeVar
Gobot1234 Aug 31, 2022
1a7e1db
Format
Gobot1234 Aug 31, 2022
70e54bb
Add isort to poe format
Gobot1234 Dec 2, 2022
752dcd7
Merge master
Gobot1234 Aug 17, 2023
0582bb9
Fix a couple of things
Gobot1234 Aug 17, 2023
6d93d46
Fix tests to pin to <pydantic v2
Gobot1234 Aug 17, 2023
3a1d5ba
Merge branch 'master' into enum
Gobot1234 Aug 17, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
897 changes: 417 additions & 480 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ importlib-metadata = { version = ">=1.6.0", python = "<3.8" }
jinja2 = { version = ">=3.0.3", optional = true }
python-dateutil = "^2.8"
isort = {version = "^5.11.5", optional = true}
typing-extensions = "^4.7.1"

[tool.poetry.dev-dependencies]
asv = "^0.4.2"
Expand All @@ -37,7 +38,7 @@ sphinx-rtd-theme = "0.5.0"
tomlkit = "^0.7.0"
tox = "^3.15.1"
pre-commit = "^2.17.0"
pydantic = ">=1.8.0"
pydantic = ">=1.8.0,<2"


[tool.poetry.scripts]
Expand Down
38 changes: 8 additions & 30 deletions src/betterproto/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import dataclasses
import enum
import enum as builtin_enum
import json
import math
import struct
Expand Down Expand Up @@ -43,7 +43,8 @@
safe_snake_case,
snake_case,
)
from .grpc.grpclib_client import ServiceStub
from .enum import Enum as Enum
from .grpc.grpclib_client import ServiceStub as ServiceStub


# Proto 3 data types
Expand Down Expand Up @@ -136,7 +137,7 @@ def datetime_default_gen() -> datetime:
NAN = "NaN"


class Casing(enum.Enum):
class Casing(builtin_enum.Enum):
"""Casing constants for serialization."""

CAMEL = camel_case #: A camelCase sterilization function.
Expand Down Expand Up @@ -305,32 +306,6 @@ def map_field(
)


class Enum(enum.IntEnum):
"""
The base class for protobuf enumerations, all generated enumerations will inherit
from this. Bases :class:`enum.IntEnum`.
"""

@classmethod
def from_string(cls, name: str) -> "Enum":
"""Return the value which corresponds to the string name.

Parameters
-----------
name: :class:`str`
The name of the enum member to get

Raises
-------
:exc:`ValueError`
The member was not found in the Enum.
"""
try:
return cls._member_map_[name] # type: ignore
except KeyError as e:
raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e


def _pack_fmt(proto_type: str) -> str:
"""Returns a little-endian format string for reading/writing binary."""
return {
Expand Down Expand Up @@ -930,7 +905,7 @@ def _get_field_default_gen(cls, field: dataclasses.Field) -> Any:
return t
elif issubclass(t, Enum):
# Enums always default to zero.
return int
return t.try_value
elif t is datetime:
# Offsets are relative to 1970-01-01T00:00:00Z
return datetime_default_gen
Expand All @@ -955,6 +930,9 @@ def _postprocess_single(
elif meta.proto_type == TYPE_BOOL:
# Booleans use a varint encoding, so convert it to true/false.
value = value > 0
elif meta.proto_type == TYPE_ENUM:
# Convert enum ints to python enum instances
value = self._betterproto.cls_by_field[field_name].try_value(value)
elif wire_type in (WIRE_FIXED_32, WIRE_FIXED_64):
fmt = _pack_fmt(meta.proto_type)
value = struct.unpack(fmt, value)[0]
Expand Down
199 changes: 199 additions & 0 deletions src/betterproto/enum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
from __future__ import annotations

import sys
from enum import (
EnumMeta,
IntEnum,
)
from types import MappingProxyType
from typing import (
TYPE_CHECKING,
Any,
Dict,
Optional,
Tuple,
)


if TYPE_CHECKING:
from collections.abc import (
Generator,
Mapping,
)

from typing_extensions import (
Never,
Self,
)


def _is_descriptor(obj: object) -> bool:
return (
hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__")
)


class EnumType(EnumMeta if TYPE_CHECKING else type):
_value_map_: Mapping[int, Enum]
_member_map_: Mapping[str, Enum]

def __new__(
mcs, name: str, bases: Tuple[type, ...], namespace: Dict[str, Any]
) -> Self:
value_map = {}
member_map = {}

new_mcs = type(
f"{name}Type",
tuple(
dict.fromkeys(
[base.__class__ for base in bases if base.__class__ is not type]
+ [EnumType, type]
)
), # reorder the bases so EnumType and type are last to avoid conflicts
{"_value_map_": value_map, "_member_map_": member_map},
)

members = {
name: value
for name, value in namespace.items()
if not _is_descriptor(value) and name[0] != "_"
}

cls = type.__new__(
new_mcs,
name,
bases,
{key: value for key, value in namespace.items() if key not in members},
)
# this allows us to disallow member access from other members as
# members become proper class variables

for name, value in members.items():
if _is_descriptor(value) or name[0] == "_":
continue

member = value_map.get(value)
if member is None:
member = cls.__new__(cls, name=name, value=value) # type: ignore
value_map[value] = member
member_map[name] = member
type.__setattr__(new_mcs, name, member)

return cls

if not TYPE_CHECKING:

def __call__(cls, value: int) -> Enum:
try:
return cls._value_map_[value]
except (KeyError, TypeError):
raise ValueError(f"{value!r} is not a valid {cls.__name__}") from None

def __iter__(cls) -> Generator[Enum, None, None]:
yield from cls._member_map_.values()

if sys.version_info >= (3, 8): # 3.8 added __reversed__ to dict_values

def __reversed__(cls) -> Generator[Enum, None, None]:
yield from reversed(cls._member_map_.values())

else:

def __reversed__(cls) -> Generator[Enum, None, None]:
yield from reversed(tuple(cls._member_map_.values()))

def __getitem__(cls, key: str) -> Enum:
return cls._member_map_[key]

@property
def __members__(cls) -> MappingProxyType[str, Enum]:
return MappingProxyType(cls._member_map_)

def __repr__(cls) -> str:
return f"<enum {cls.__name__!r}>"

def __len__(cls) -> int:
return len(cls._member_map_)

def __setattr__(cls, name: str, value: Any) -> Never:
raise AttributeError(f"{cls.__name__}: cannot reassign Enum members.")

def __delattr__(cls, name: str) -> Never:
raise AttributeError(f"{cls.__name__}: cannot delete Enum members.")

def __contains__(cls, member: object) -> bool:
return isinstance(member, cls) and member.name in cls._member_map_


class Enum(IntEnum if TYPE_CHECKING else int, metaclass=EnumType):
"""
The base class for protobuf enumerations, all generated enumerations will
inherit from this. Emulates `enum.IntEnum`.
"""

name: Optional[str]
value: int

if not TYPE_CHECKING:

def __new__(cls, *, name: Optional[str], value: int) -> Self:
self = super().__new__(cls, value)
super().__setattr__(self, "name", name)
super().__setattr__(self, "value", value)
return self

def __str__(self) -> str:
return self.name or "None"

def __repr__(self) -> str:
return f"{self.__class__.__name__}.{self.name}"

def __setattr__(self, key: str, value: Any) -> Never:
raise AttributeError(
f"{self.__class__.__name__} Cannot reassign a member's attributes."
)

def __delattr__(self, item: Any) -> Never:
raise AttributeError(
f"{self.__class__.__name__} Cannot delete a member's attributes."
)

@classmethod
def try_value(cls, value: int = 0) -> Self:
"""Return the value which corresponds to the value.

Parameters
-----------
value: :class:`int`
The value of the enum member to get.

Returns
-------
:class:`Enum`
The corresponding member or a new instance of the enum if
``value`` isn't actually a member.
"""
try:
return cls._value_map_[value]
except (KeyError, TypeError):
return cls.__new__(cls, name=None, value=value)

@classmethod
def from_string(cls, name: str) -> Self:
"""Return the value which corresponds to the string name.

Parameters
-----------
name: :class:`str`
The name of the enum member to get.

Raises
-------
:exc:`ValueError`
The member was not found in the Enum.
"""
try:
return cls._member_map_[name]
except KeyError as e:
raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e
20 changes: 20 additions & 0 deletions tests/inputs/enum/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,23 @@ def enum_generator():
yield Choice.THREE

assert Test(choices=enum_generator()).to_dict()["choices"] == ["ONE", "THREE"]


def test_enum_mapped_on_parse():
# test default value
b = Test().parse(bytes(Test()))
assert b.choice.name == Choice.ZERO.name
assert b.choices == []

# test non default value
a = Test().parse(bytes(Test(choice=Choice.ONE)))
assert a.choice.name == Choice.ONE.name
assert b.choices == []

# test repeated
c = Test().parse(bytes(Test(choices=[Choice.THREE, Choice.FOUR])))
assert c.choices[0].name == Choice.THREE.name
assert c.choices[1].name == Choice.FOUR.name

# bonus: defaults after empty init are also mapped
assert Test().choice.name == Choice.ZERO.name
79 changes: 79 additions & 0 deletions tests/test_enum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from typing import (
Optional,
Tuple,
)

import pytest

import betterproto


class Colour(betterproto.Enum):
RED = 1
GREEN = 2
BLUE = 3


PURPLE = Colour.__new__(Colour, name=None, value=4)


@pytest.mark.parametrize(
"member, str_value",
[
(Colour.RED, "RED"),
(Colour.GREEN, "GREEN"),
(Colour.BLUE, "BLUE"),
],
)
def test_str(member: Colour, str_value: str) -> None:
assert str(member) == str_value


@pytest.mark.parametrize(
"member, repr_value",
[
(Colour.RED, "Colour.RED"),
(Colour.GREEN, "Colour.GREEN"),
(Colour.BLUE, "Colour.BLUE"),
],
)
def test_repr(member: Colour, repr_value: str) -> None:
assert repr(member) == repr_value


@pytest.mark.parametrize(
"member, values",
[
(Colour.RED, ("RED", 1)),
(Colour.GREEN, ("GREEN", 2)),
(Colour.BLUE, ("BLUE", 3)),
(PURPLE, (None, 4)),
],
)
def test_name_values(member: Colour, values: Tuple[Optional[str], int]) -> None:
assert (member.name, member.value) == values


@pytest.mark.parametrize(
"member, input_str",
[
(Colour.RED, "RED"),
(Colour.GREEN, "GREEN"),
(Colour.BLUE, "BLUE"),
],
)
def test_from_string(member: Colour, input_str: str) -> None:
assert Colour.from_string(input_str) == member


@pytest.mark.parametrize(
"member, input_int",
[
(Colour.RED, 1),
(Colour.GREEN, 2),
(Colour.BLUE, 3),
(PURPLE, 4),
],
)
def test_try_value(member: Colour, input_int: int) -> None:
assert Colour.try_value(input_int) == member