Skip to content

Split up betterproto into separate modules #322

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,372 changes: 5 additions & 1,367 deletions src/betterproto/__init__.py

Large diffs are not rendered by default.

10 changes: 10 additions & 0 deletions src/betterproto/casing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import enum
import keyword
import re

__all__ = ("Casing",)

# Word delimiters and symbols that will not be preserved when re-casing.
# language=PythonRegExp
SYMBOLS = "[^a-zA-Z0-9]*"
Expand Down Expand Up @@ -136,3 +139,10 @@ def lowercase_first(value: str) -> str:
def sanitize_name(value: str) -> str:
# https://www.python.org/dev/peps/pep-0008/#descriptive-naming-styles
return f"{value}_" if keyword.iskeyword(value) else value


class Casing(enum.Enum):
"""Casing constants for serialization."""

CAMEL = camel_case #: A camelCase sterilization function.
SNAKE = snake_case #: A snake_case sterilization function.
91 changes: 91 additions & 0 deletions src/betterproto/const.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Proto 3 data types
from datetime import datetime, timezone


TYPE_ENUM = "enum"
TYPE_BOOL = "bool"
TYPE_INT32 = "int32"
TYPE_INT64 = "int64"
TYPE_UINT32 = "uint32"
TYPE_UINT64 = "uint64"
TYPE_SINT32 = "sint32"
TYPE_SINT64 = "sint64"
TYPE_FLOAT = "float"
TYPE_DOUBLE = "double"
TYPE_FIXED32 = "fixed32"
TYPE_SFIXED32 = "sfixed32"
TYPE_FIXED64 = "fixed64"
TYPE_SFIXED64 = "sfixed64"
TYPE_STRING = "string"
TYPE_BYTES = "bytes"
TYPE_MESSAGE = "message"
TYPE_MAP = "map"


# Fields that use a fixed amount of space (4 or 8 bytes)
FIXED_TYPES = [
TYPE_FLOAT,
TYPE_DOUBLE,
TYPE_FIXED32,
TYPE_SFIXED32,
TYPE_FIXED64,
TYPE_SFIXED64,
]

# Fields that are numerical 64-bit types
INT_64_TYPES = [TYPE_INT64, TYPE_UINT64, TYPE_SINT64, TYPE_FIXED64, TYPE_SFIXED64]

# Fields that are efficiently packed when
PACKED_TYPES = [
TYPE_ENUM,
TYPE_BOOL,
TYPE_INT32,
TYPE_INT64,
TYPE_UINT32,
TYPE_UINT64,
TYPE_SINT32,
TYPE_SINT64,
TYPE_FLOAT,
TYPE_DOUBLE,
TYPE_FIXED32,
TYPE_SFIXED32,
TYPE_FIXED64,
TYPE_SFIXED64,
]

# Wire types
# https://developers.google.com/protocol-buffers/docs/encoding#structure
WIRE_VARINT = 0
WIRE_FIXED_64 = 1
WIRE_LEN_DELIM = 2
WIRE_FIXED_32 = 5

# Mappings of which Proto 3 types correspond to which wire types.
WIRE_VARINT_TYPES = [
TYPE_ENUM,
TYPE_BOOL,
TYPE_INT32,
TYPE_INT64,
TYPE_UINT32,
TYPE_UINT64,
TYPE_SINT32,
TYPE_SINT64,
]

WIRE_FIXED_32_TYPES = [TYPE_FLOAT, TYPE_FIXED32, TYPE_SFIXED32]
WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64]
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]


# Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
def datetime_default_gen() -> datetime:
return datetime(1970, 1, 1, tzinfo=timezone.utc)


DATETIME_ZERO = datetime_default_gen()


# Special protobuf json doubles
INFINITY = "Infinity"
NEG_INFINITY = "-Infinity"
NAN = "NaN"
29 changes: 29 additions & 0 deletions src/betterproto/enum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import enum

__all__ = ("Enum",)


class Enum(enum.IntEnum):
"""
The base class for protobuf enumerations, all generated enumerations will inherit
from this. Bases :class:`enum.IntEnum`.
"""

@classmethod
def from_string(cls, name: str) -> "Enum":
"""Return the value which corresponds to the string name.

Parameters
-----------
name: :class:`str`
The name of the enum member to get

Raises
-------
:exc:`ValueError`
The member was not found in the Enum.
"""
try:
return cls._member_map_[name] # type: ignore
except KeyError as e:
raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e
2 changes: 2 additions & 0 deletions src/betterproto/grpc/grpclib_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
_MessageLike = Union[T, ST]
_MessageSource = Union[Iterable[ST], AsyncIterable[ST]]

__all__ = ("ServiceStub",)


class ServiceStub(ABC):
"""
Expand Down
204 changes: 204 additions & 0 deletions src/betterproto/io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
import dataclasses
from datetime import timedelta
import struct
from typing import Any, List, Union, Tuple

from typing import Generator

from .message import _Duration, _Timestamp, _get_wrapper
from .const import *


def _pack_fmt(proto_type: str) -> str:
"""Returns a little-endian format string for reading/writing binary."""
return {
TYPE_DOUBLE: "<d",
TYPE_FLOAT: "<f",
TYPE_FIXED32: "<I",
TYPE_FIXED64: "<Q",
TYPE_SFIXED32: "<i",
TYPE_SFIXED64: "<q",
}[proto_type]


def encode_varint(value: int) -> bytes:
"""Encodes a single varint value for serialization."""
b: List[int] = []

if value < 0:
value += 1 << 64

bits = value & 0x7F
value >>= 7
while value:
b.append(0x80 | bits)
bits = value & 0x7F
value >>= 7
return bytes(b + [bits])


def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes:
"""Adjusts values before serialization."""
if proto_type in (
TYPE_ENUM,
TYPE_BOOL,
TYPE_INT32,
TYPE_INT64,
TYPE_UINT32,
TYPE_UINT64,
):
return encode_varint(value)
elif proto_type in (TYPE_SINT32, TYPE_SINT64):
# Handle zig-zag encoding.
return encode_varint(value << 1 if value >= 0 else (value << 1) ^ (~0))
elif proto_type in FIXED_TYPES:
return struct.pack(_pack_fmt(proto_type), value)
elif proto_type == TYPE_STRING:
return value.encode("utf-8")
elif proto_type == TYPE_MESSAGE:
if isinstance(value, datetime):
# Convert the `datetime` to a timestamp message.
seconds = int(value.timestamp())
nanos = int(value.microsecond * 1e3)
value = _Timestamp(seconds=seconds, nanos=nanos)
elif isinstance(value, timedelta):
# Convert the `timedelta` to a duration message.
total_ms = value // timedelta(microseconds=1)
seconds = int(total_ms / 1e6)
nanos = int((total_ms % 1e6) * 1e3)
value = _Duration(seconds=seconds, nanos=nanos)
elif wraps:
if value is None:
return b""
value = _get_wrapper(wraps)(value=value)

return bytes(value)

return value


def _serialize_single(
field_number: int,
proto_type: str,
value: Any,
*,
serialize_empty: bool = False,
wraps: str = "",
) -> bytes:
"""Serializes a single field and value."""
value = _preprocess_single(proto_type, wraps, value)

output = bytearray()
if proto_type in WIRE_VARINT_TYPES:
key = encode_varint(field_number << 3)
output += key + value
elif proto_type in WIRE_FIXED_32_TYPES:
key = encode_varint((field_number << 3) | 5)
output += key + value
elif proto_type in WIRE_FIXED_64_TYPES:
key = encode_varint((field_number << 3) | 1)
output += key + value
elif proto_type in WIRE_LEN_DELIM_TYPES:
if len(value) or serialize_empty or wraps:
key = encode_varint((field_number << 3) | 2)
output += key + encode_varint(len(value)) + value
else:
raise NotImplementedError(proto_type)

return bytes(output)


def _parse_float(value: Any) -> float:
"""Parse the given value to a float

Parameters
----------
value : Any
Value to parse

Returns
-------
float
Parsed value
"""
if value == INFINITY:
return float("inf")
if value == NEG_INFINITY:
return -float("inf")
if value == NAN:
return float("nan")
return float(value)


def _dump_float(value: float) -> Union[float, str]:
"""Dump the given float to JSON

Parameters
----------
value : float
Value to dump

Returns
-------
Union[float, str]
Dumped valid, either a float or the strings
"Infinity" or "-Infinity"
"""
if value == float("inf"):
return INFINITY
if value == -float("inf"):
return NEG_INFINITY
if value == float("nan"):
return NAN
return value


def decode_varint(buffer: bytes, pos: int) -> Tuple[int, int]:
"""
Decode a single varint value from a byte buffer. Returns the value and the
new position in the buffer.
"""
result = 0
shift = 0
while 1:
b = buffer[pos]
result |= (b & 0x7F) << shift
pos += 1
if not (b & 0x80):
return result, pos
shift += 7
if shift >= 64:
raise ValueError("Too many bytes when decoding varint.")


@dataclasses.dataclass(frozen=True)
class ParsedField:
number: int
wire_type: int
value: Any
raw: bytes


def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
i = 0
while i < len(value):
start = i
num_wire, i = decode_varint(value, i)
number = num_wire >> 3
wire_type = num_wire & 0x7

decoded: Any = None
if wire_type == WIRE_VARINT:
decoded, i = decode_varint(value, i)
elif wire_type == WIRE_FIXED_64:
decoded, i = value[i : i + 8], i + 8
elif wire_type == WIRE_LEN_DELIM:
length, i = decode_varint(value, i)
decoded = value[i : i + length]
i += length
elif wire_type == WIRE_FIXED_32:
decoded, i = value[i : i + 4], i + 4

yield ParsedField(
number=number, wire_type=wire_type, value=decoded, raw=value[start:i]
)
Loading