diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index 5e514d5c1..9dea98842 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -1,1316 +1,9 @@ -import dataclasses -import enum -import inspect -import json -import math -import struct -import sys -import typing -from abc import ABC -from base64 import b64decode, b64encode -from datetime import datetime, timedelta, timezone -from dateutil.parser import isoparse -from typing import ( - Any, - Callable, - Dict, - Generator, - Iterable, - List, - Optional, - Set, - Tuple, - Type, - Union, - get_type_hints, -) - -from ._types import T from ._version import __version__ -from .casing import camel_case, safe_snake_case, snake_case -from .grpc.grpclib_client import ServiceStub - - -# Proto 3 data types -TYPE_ENUM = "enum" -TYPE_BOOL = "bool" -TYPE_INT32 = "int32" -TYPE_INT64 = "int64" -TYPE_UINT32 = "uint32" -TYPE_UINT64 = "uint64" -TYPE_SINT32 = "sint32" -TYPE_SINT64 = "sint64" -TYPE_FLOAT = "float" -TYPE_DOUBLE = "double" -TYPE_FIXED32 = "fixed32" -TYPE_SFIXED32 = "sfixed32" -TYPE_FIXED64 = "fixed64" -TYPE_SFIXED64 = "sfixed64" -TYPE_STRING = "string" -TYPE_BYTES = "bytes" -TYPE_MESSAGE = "message" -TYPE_MAP = "map" - - -# Fields that use a fixed amount of space (4 or 8 bytes) -FIXED_TYPES = [ - TYPE_FLOAT, - TYPE_DOUBLE, - TYPE_FIXED32, - TYPE_SFIXED32, - TYPE_FIXED64, - TYPE_SFIXED64, -] - -# Fields that are numerical 64-bit types -INT_64_TYPES = [TYPE_INT64, TYPE_UINT64, TYPE_SINT64, TYPE_FIXED64, TYPE_SFIXED64] - -# Fields that are efficiently packed when -PACKED_TYPES = [ - TYPE_ENUM, - TYPE_BOOL, - TYPE_INT32, - TYPE_INT64, - TYPE_UINT32, - TYPE_UINT64, - TYPE_SINT32, - TYPE_SINT64, - TYPE_FLOAT, - TYPE_DOUBLE, - TYPE_FIXED32, - TYPE_SFIXED32, - TYPE_FIXED64, - TYPE_SFIXED64, -] - -# Wire types -# https://developers.google.com/protocol-buffers/docs/encoding#structure -WIRE_VARINT = 0 -WIRE_FIXED_64 = 1 -WIRE_LEN_DELIM = 2 -WIRE_FIXED_32 = 5 - -# Mappings of which Proto 3 types correspond to which wire types. -WIRE_VARINT_TYPES = [ - TYPE_ENUM, - TYPE_BOOL, - TYPE_INT32, - TYPE_INT64, - TYPE_UINT32, - TYPE_UINT64, - TYPE_SINT32, - TYPE_SINT64, -] - -WIRE_FIXED_32_TYPES = [TYPE_FLOAT, TYPE_FIXED32, TYPE_SFIXED32] -WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64] -WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP] - - -# Protobuf datetimes start at the Unix Epoch in 1970 in UTC. -def datetime_default_gen() -> datetime: - return datetime(1970, 1, 1, tzinfo=timezone.utc) - - -DATETIME_ZERO = datetime_default_gen() - - -# Special protobuf json doubles -INFINITY = "Infinity" -NEG_INFINITY = "-Infinity" -NAN = "NaN" - - -class Casing(enum.Enum): - """Casing constants for serialization.""" - - CAMEL = camel_case #: A camelCase sterilization function. - SNAKE = snake_case #: A snake_case sterilization function. - - -PLACEHOLDER: Any = object() - - -@dataclasses.dataclass(frozen=True) -class FieldMetadata: - """Stores internal metadata used for parsing & serialization.""" - - # Protobuf field number - number: int - # Protobuf type name - proto_type: str - # Map information if the proto_type is a map - map_types: Optional[Tuple[str, str]] = None - # Groups several "one-of" fields together - group: Optional[str] = None - # Describes the wrapped type (e.g. when using google.protobuf.BoolValue) - wraps: Optional[str] = None - # Is the field optional - optional: Optional[bool] = False - - @staticmethod - def get(field: dataclasses.Field) -> "FieldMetadata": - """Returns the field metadata for a dataclass field.""" - return field.metadata["betterproto"] - - -def dataclass_field( - number: int, - proto_type: str, - *, - map_types: Optional[Tuple[str, str]] = None, - group: Optional[str] = None, - wraps: Optional[str] = None, - optional: bool = False, -) -> dataclasses.Field: - """Creates a dataclass field with attached protobuf metadata.""" - return dataclasses.field( - default=None if optional else PLACEHOLDER, - metadata={ - "betterproto": FieldMetadata( - number, proto_type, map_types, group, wraps, optional - ) - }, - ) - - -# Note: the fields below return `Any` to prevent type errors in the generated -# data classes since the types won't match with `Field` and they get swapped -# out at runtime. The generated dataclass variables are still typed correctly. - - -def enum_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any: - return dataclass_field(number, TYPE_ENUM, group=group, optional=optional) - - -def bool_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any: - return dataclass_field(number, TYPE_BOOL, group=group, optional=optional) - - -def int32_field( - number: int, group: Optional[str] = None, optional: bool = False -) -> Any: - return dataclass_field(number, TYPE_INT32, group=group, optional=optional) - - -def int64_field( - number: int, group: Optional[str] = None, optional: bool = False -) -> Any: - return dataclass_field(number, TYPE_INT64, group=group, optional=optional) - - -def uint32_field( - number: int, group: Optional[str] = None, optional: bool = False -) -> Any: - return dataclass_field(number, TYPE_UINT32, group=group, optional=optional) - - -def uint64_field( - number: int, group: Optional[str] = None, optional: bool = False -) -> Any: - return dataclass_field(number, TYPE_UINT64, group=group, optional=optional) - - -def sint32_field( - number: int, group: Optional[str] = None, optional: bool = False -) -> Any: - return dataclass_field(number, TYPE_SINT32, group=group, optional=optional) - - -def sint64_field( - number: int, group: Optional[str] = None, optional: bool = False -) -> Any: - return dataclass_field(number, TYPE_SINT64, group=group, optional=optional) - - -def float_field( - number: int, group: Optional[str] = None, optional: bool = False -) -> Any: - return dataclass_field(number, TYPE_FLOAT, group=group, optional=optional) - - -def double_field( - number: int, group: Optional[str] = None, optional: bool = False -) -> Any: - return dataclass_field(number, TYPE_DOUBLE, group=group, optional=optional) - - -def fixed32_field( - number: int, group: Optional[str] = None, optional: bool = False -) -> Any: - return dataclass_field(number, TYPE_FIXED32, group=group, optional=optional) - - -def fixed64_field( - number: int, group: Optional[str] = None, optional: bool = False -) -> Any: - return dataclass_field(number, TYPE_FIXED64, group=group, optional=optional) - - -def sfixed32_field( - number: int, group: Optional[str] = None, optional: bool = False -) -> Any: - return dataclass_field(number, TYPE_SFIXED32, group=group, optional=optional) - - -def sfixed64_field( - number: int, group: Optional[str] = None, optional: bool = False -) -> Any: - return dataclass_field(number, TYPE_SFIXED64, group=group, optional=optional) - - -def string_field( - number: int, group: Optional[str] = None, optional: bool = False -) -> Any: - return dataclass_field(number, TYPE_STRING, group=group, optional=optional) - - -def bytes_field( - number: int, group: Optional[str] = None, optional: bool = False -) -> Any: - return dataclass_field(number, TYPE_BYTES, group=group, optional=optional) - - -def message_field( - number: int, - group: Optional[str] = None, - wraps: Optional[str] = None, - optional: bool = False, -) -> Any: - return dataclass_field( - number, TYPE_MESSAGE, group=group, wraps=wraps, optional=optional - ) - - -def map_field( - number: int, key_type: str, value_type: str, group: Optional[str] = None -) -> Any: - return dataclass_field( - number, TYPE_MAP, map_types=(key_type, value_type), group=group - ) - - -class Enum(enum.IntEnum): - """ - The base class for protobuf enumerations, all generated enumerations will inherit - from this. Bases :class:`enum.IntEnum`. - """ - - @classmethod - def from_string(cls, name: str) -> "Enum": - """Return the value which corresponds to the string name. - - Parameters - ----------- - name: :class:`str` - The name of the enum member to get - - Raises - ------- - :exc:`ValueError` - The member was not found in the Enum. - """ - try: - return cls._member_map_[name] # type: ignore - except KeyError as e: - raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e - - -def _pack_fmt(proto_type: str) -> str: - """Returns a little-endian format string for reading/writing binary.""" - return { - TYPE_DOUBLE: " bytes: - """Encodes a single varint value for serialization.""" - b: List[int] = [] - - if value < 0: - value += 1 << 64 - - bits = value & 0x7F - value >>= 7 - while value: - b.append(0x80 | bits) - bits = value & 0x7F - value >>= 7 - return bytes(b + [bits]) - - -def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes: - """Adjusts values before serialization.""" - if proto_type in ( - TYPE_ENUM, - TYPE_BOOL, - TYPE_INT32, - TYPE_INT64, - TYPE_UINT32, - TYPE_UINT64, - ): - return encode_varint(value) - elif proto_type in (TYPE_SINT32, TYPE_SINT64): - # Handle zig-zag encoding. - return encode_varint(value << 1 if value >= 0 else (value << 1) ^ (~0)) - elif proto_type in FIXED_TYPES: - return struct.pack(_pack_fmt(proto_type), value) - elif proto_type == TYPE_STRING: - return value.encode("utf-8") - elif proto_type == TYPE_MESSAGE: - if isinstance(value, datetime): - # Convert the `datetime` to a timestamp message. - seconds = int(value.timestamp()) - nanos = int(value.microsecond * 1e3) - value = _Timestamp(seconds=seconds, nanos=nanos) - elif isinstance(value, timedelta): - # Convert the `timedelta` to a duration message. - total_ms = value // timedelta(microseconds=1) - seconds = int(total_ms / 1e6) - nanos = int((total_ms % 1e6) * 1e3) - value = _Duration(seconds=seconds, nanos=nanos) - elif wraps: - if value is None: - return b"" - value = _get_wrapper(wraps)(value=value) - - return bytes(value) - - return value - - -def _serialize_single( - field_number: int, - proto_type: str, - value: Any, - *, - serialize_empty: bool = False, - wraps: str = "", -) -> bytes: - """Serializes a single field and value.""" - value = _preprocess_single(proto_type, wraps, value) - - output = bytearray() - if proto_type in WIRE_VARINT_TYPES: - key = encode_varint(field_number << 3) - output += key + value - elif proto_type in WIRE_FIXED_32_TYPES: - key = encode_varint((field_number << 3) | 5) - output += key + value - elif proto_type in WIRE_FIXED_64_TYPES: - key = encode_varint((field_number << 3) | 1) - output += key + value - elif proto_type in WIRE_LEN_DELIM_TYPES: - if len(value) or serialize_empty or wraps: - key = encode_varint((field_number << 3) | 2) - output += key + encode_varint(len(value)) + value - else: - raise NotImplementedError(proto_type) - - return bytes(output) - - -def _parse_float(value: Any) -> float: - """Parse the given value to a float - - Parameters - ---------- - value : Any - Value to parse - - Returns - ------- - float - Parsed value - """ - if value == INFINITY: - return float("inf") - if value == NEG_INFINITY: - return -float("inf") - if value == NAN: - return float("nan") - return float(value) - - -def _dump_float(value: float) -> Union[float, str]: - """Dump the given float to JSON - - Parameters - ---------- - value : float - Value to dump - - Returns - ------- - Union[float, str] - Dumped valid, either a float or the strings - "Infinity" or "-Infinity" - """ - if value == float("inf"): - return INFINITY - if value == -float("inf"): - return NEG_INFINITY - if value == float("nan"): - return NAN - return value - - -def decode_varint(buffer: bytes, pos: int) -> Tuple[int, int]: - """ - Decode a single varint value from a byte buffer. Returns the value and the - new position in the buffer. - """ - result = 0 - shift = 0 - while 1: - b = buffer[pos] - result |= (b & 0x7F) << shift - pos += 1 - if not (b & 0x80): - return result, pos - shift += 7 - if shift >= 64: - raise ValueError("Too many bytes when decoding varint.") - - -@dataclasses.dataclass(frozen=True) -class ParsedField: - number: int - wire_type: int - value: Any - raw: bytes - - -def parse_fields(value: bytes) -> Generator[ParsedField, None, None]: - i = 0 - while i < len(value): - start = i - num_wire, i = decode_varint(value, i) - number = num_wire >> 3 - wire_type = num_wire & 0x7 - - decoded: Any = None - if wire_type == WIRE_VARINT: - decoded, i = decode_varint(value, i) - elif wire_type == WIRE_FIXED_64: - decoded, i = value[i : i + 8], i + 8 - elif wire_type == WIRE_LEN_DELIM: - length, i = decode_varint(value, i) - decoded = value[i : i + length] - i += length - elif wire_type == WIRE_FIXED_32: - decoded, i = value[i : i + 4], i + 4 - - yield ParsedField( - number=number, wire_type=wire_type, value=decoded, raw=value[start:i] - ) - - -class ProtoClassMetadata: - __slots__ = ( - "oneof_group_by_field", - "oneof_field_by_group", - "default_gen", - "cls_by_field", - "field_name_by_number", - "meta_by_field_name", - "sorted_field_names", - ) - - oneof_group_by_field: Dict[str, str] - oneof_field_by_group: Dict[str, Set[dataclasses.Field]] - field_name_by_number: Dict[int, str] - meta_by_field_name: Dict[str, FieldMetadata] - sorted_field_names: Tuple[str, ...] - default_gen: Dict[str, Callable[[], Any]] - cls_by_field: Dict[str, Type] - - def __init__(self, cls: Type["Message"]): - by_field = {} - by_group: Dict[str, Set] = {} - by_field_name = {} - by_field_number = {} - - fields = dataclasses.fields(cls) - for field in fields: - meta = FieldMetadata.get(field) - - if meta.group: - # This is part of a one-of group. - by_field[field.name] = meta.group - - by_group.setdefault(meta.group, set()).add(field) - - by_field_name[field.name] = meta - by_field_number[meta.number] = field.name - - self.oneof_group_by_field = by_field - self.oneof_field_by_group = by_group - self.field_name_by_number = by_field_number - self.meta_by_field_name = by_field_name - self.sorted_field_names = tuple( - by_field_number[number] for number in sorted(by_field_number) - ) - self.default_gen = self._get_default_gen(cls, fields) - self.cls_by_field = self._get_cls_by_field(cls, fields) - - @staticmethod - def _get_default_gen( - cls: Type["Message"], fields: Iterable[dataclasses.Field] - ) -> Dict[str, Callable[[], Any]]: - return {field.name: cls._get_field_default_gen(field) for field in fields} - - @staticmethod - def _get_cls_by_field( - cls: Type["Message"], fields: Iterable[dataclasses.Field] - ) -> Dict[str, Type]: - field_cls = {} - - for field in fields: - meta = FieldMetadata.get(field) - if meta.proto_type == TYPE_MAP: - assert meta.map_types - kt = cls._cls_for(field, index=0) - vt = cls._cls_for(field, index=1) - field_cls[field.name] = dataclasses.make_dataclass( - "Entry", - [ - ("key", kt, dataclass_field(1, meta.map_types[0])), - ("value", vt, dataclass_field(2, meta.map_types[1])), - ], - bases=(Message,), - ) - field_cls[f"{field.name}.value"] = vt - else: - field_cls[field.name] = cls._cls_for(field) - - return field_cls - - -class Message(ABC): - """ - The base class for protobuf messages, all generated messages will inherit from - this. This class registers the message fields which are used by the serializers and - parsers to go between the Python, binary and JSON representations of the message. - - .. container:: operations - - .. describe:: bytes(x) - - Calls :meth:`__bytes__`. - - .. describe:: bool(x) - - Calls :meth:`__bool__`. - """ - - _serialized_on_wire: bool - _unknown_fields: bytes - _group_current: Dict[str, str] - - def __post_init__(self) -> None: - # Keep track of whether every field was default - all_sentinel = True - - # Set current field of each group after `__init__` has already been run. - group_current: Dict[str, Optional[str]] = {} - for field_name, meta in self._betterproto.meta_by_field_name.items(): - - if meta.group: - group_current.setdefault(meta.group) - - value = self.__raw_get(field_name) - if value != PLACEHOLDER and not (meta.optional and value is None): - # Found a non-sentinel value - all_sentinel = False - - if meta.group: - # This was set, so make it the selected value of the one-of. - group_current[meta.group] = field_name - - # Now that all the defaults are set, reset it! - self.__dict__["_serialized_on_wire"] = not all_sentinel - self.__dict__["_unknown_fields"] = b"" - self.__dict__["_group_current"] = group_current - - def __raw_get(self, name: str) -> Any: - return super().__getattribute__(name) - - def __eq__(self, other) -> bool: - if type(self) is not type(other): - return False - - for field_name in self._betterproto.meta_by_field_name: - self_val = self.__raw_get(field_name) - other_val = other.__raw_get(field_name) - if self_val is PLACEHOLDER: - if other_val is PLACEHOLDER: - continue - self_val = self._get_field_default(field_name) - elif other_val is PLACEHOLDER: - other_val = other._get_field_default(field_name) - - if self_val != other_val: - # We consider two nan values to be the same for the - # purposes of comparing messages (otherwise a message - # is not equal to itself) - if ( - isinstance(self_val, float) - and isinstance(other_val, float) - and math.isnan(self_val) - and math.isnan(other_val) - ): - continue - else: - return False - - return True - - def __repr__(self) -> str: - parts = [ - f"{field_name}={value!r}" - for field_name in self._betterproto.sorted_field_names - for value in (self.__raw_get(field_name),) - if value is not PLACEHOLDER - ] - return f"{self.__class__.__name__}({', '.join(parts)})" - - def __getattribute__(self, name: str) -> Any: - """ - Lazily initialize default values to avoid infinite recursion for recursive - message types - """ - value = super().__getattribute__(name) - if value is not PLACEHOLDER: - return value - - value = self._get_field_default(name) - super().__setattr__(name, value) - return value - - def __setattr__(self, attr: str, value: Any) -> None: - if attr != "_serialized_on_wire": - # Track when a field has been set. - self.__dict__["_serialized_on_wire"] = True - - if hasattr(self, "_group_current"): # __post_init__ had already run - if attr in self._betterproto.oneof_group_by_field: - group = self._betterproto.oneof_group_by_field[attr] - for field in self._betterproto.oneof_field_by_group[group]: - if field.name == attr: - self._group_current[group] = field.name - else: - super().__setattr__(field.name, PLACEHOLDER) - - super().__setattr__(attr, value) - - def __bool__(self) -> bool: - """True if the Message has any fields with non-default values.""" - return any( - self.__raw_get(field_name) - not in (PLACEHOLDER, self._get_field_default(field_name)) - for field_name in self._betterproto.meta_by_field_name - ) - - @property - def _betterproto(self) -> ProtoClassMetadata: - """ - Lazy initialize metadata for each protobuf class. - It may be initialized multiple times in a multi-threaded environment, - but that won't affect the correctness. - """ - meta = getattr(self.__class__, "_betterproto_meta", None) - if not meta: - meta = ProtoClassMetadata(self.__class__) - self.__class__._betterproto_meta = meta # type: ignore - return meta - - def __bytes__(self) -> bytes: - """ - Get the binary encoded Protobuf representation of this message instance. - """ - output = bytearray() - for field_name, meta in self._betterproto.meta_by_field_name.items(): - value = getattr(self, field_name) - - if value is None: - # Optional items should be skipped. This is used for the Google - # wrapper types and proto3 field presence/optional fields. - continue - - # Being selected in a a group means this field is the one that is - # currently set in a `oneof` group, so it must be serialized even - # if the value is the default zero value. - # - # Note that proto3 field presence/optional fields are put in a - # synthetic single-item oneof by protoc, which helps us ensure we - # send the value even if the value is the default zero value. - selected_in_group = ( - meta.group and self._group_current[meta.group] == field_name - ) - - # Empty messages can still be sent on the wire if they were - # set (or received empty). - serialize_empty = isinstance(value, Message) and value._serialized_on_wire - - include_default_value_for_oneof = self._include_default_value_for_oneof( - field_name=field_name, meta=meta - ) - - if value == self._get_field_default(field_name) and not ( - selected_in_group or serialize_empty or include_default_value_for_oneof - ): - # Default (zero) values are not serialized. Two exceptions are - # if this is the selected oneof item or if we know we have to - # serialize an empty message (i.e. zero value was explicitly - # set by the user). - continue - - if isinstance(value, list): - if meta.proto_type in PACKED_TYPES: - # Packed lists look like a length-delimited field. First, - # preprocess/encode each value into a buffer and then - # treat it like a field of raw bytes. - buf = bytearray() - for item in value: - buf += _preprocess_single(meta.proto_type, "", item) - output += _serialize_single(meta.number, TYPE_BYTES, buf) - else: - for item in value: - output += ( - _serialize_single( - meta.number, - meta.proto_type, - item, - wraps=meta.wraps or "", - ) - # if it's an empty message it still needs to be represented - # as an item in the repeated list - or b"\n\x00" - ) - - elif isinstance(value, dict): - for k, v in value.items(): - assert meta.map_types - sk = _serialize_single(1, meta.map_types[0], k) - sv = _serialize_single(2, meta.map_types[1], v) - output += _serialize_single(meta.number, meta.proto_type, sk + sv) - else: - # If we have an empty string and we're including the default value for - # a oneof, make sure we serialize it. This ensures that the byte string - # output isn't simply an empty string. This also ensures that round trip - # serialization will keep `which_one_of` calls consistent. - if ( - isinstance(value, str) - and value == "" - and include_default_value_for_oneof - ): - serialize_empty = True - - output += _serialize_single( - meta.number, - meta.proto_type, - value, - serialize_empty=serialize_empty or bool(selected_in_group), - wraps=meta.wraps or "", - ) - - output += self._unknown_fields - return bytes(output) - - # For compatibility with other libraries - def SerializeToString(self: T) -> bytes: - """ - Get the binary encoded Protobuf representation of this message instance. - - .. note:: - This is a method for compatibility with other libraries, - you should really use ``bytes(x)``. - - Returns - -------- - :class:`bytes` - The binary encoded Protobuf representation of this message instance - """ - return bytes(self) - - @classmethod - def _type_hint(cls, field_name: str) -> Type: - return cls._type_hints()[field_name] - - @classmethod - def _type_hints(cls) -> Dict[str, Type]: - module = sys.modules[cls.__module__] - return get_type_hints(cls, module.__dict__, {}) - - @classmethod - def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type: - """Get the message class for a field from the type hints.""" - field_cls = cls._type_hint(field.name) - if hasattr(field_cls, "__args__") and index >= 0: - if field_cls.__args__ is not None: - field_cls = field_cls.__args__[index] - return field_cls - - def _get_field_default(self, field_name: str) -> Any: - return self._betterproto.default_gen[field_name]() - - @classmethod - def _get_field_default_gen(cls, field: dataclasses.Field) -> Any: - t = cls._type_hint(field.name) - - if hasattr(t, "__origin__"): - if t.__origin__ in (dict, Dict): - # This is some kind of map (dict in Python). - return dict - elif t.__origin__ in (list, List): - # This is some kind of list (repeated) field. - return list - elif t.__origin__ is Union and t.__args__[1] is type(None): - # This is an optional field (either wrapped, or using proto3 - # field presence). For setting the default we really don't care - # what kind of field it is. - return type(None) - else: - return t - elif issubclass(t, Enum): - # Enums always default to zero. - return int - elif t is datetime: - # Offsets are relative to 1970-01-01T00:00:00Z - return datetime_default_gen - else: - # This is either a primitive scalar or another message type. Calling - # it should result in its zero value. - return t - - def _postprocess_single( - self, wire_type: int, meta: FieldMetadata, field_name: str, value: Any - ) -> Any: - """Adjusts values after parsing.""" - if wire_type == WIRE_VARINT: - if meta.proto_type in (TYPE_INT32, TYPE_INT64): - bits = int(meta.proto_type[3:]) - value = value & ((1 << bits) - 1) - signbit = 1 << (bits - 1) - value = int((value ^ signbit) - signbit) - elif meta.proto_type in (TYPE_SINT32, TYPE_SINT64): - # Undo zig-zag encoding - value = (value >> 1) ^ (-(value & 1)) - elif meta.proto_type == TYPE_BOOL: - # Booleans use a varint encoding, so convert it to true/false. - value = value > 0 - elif wire_type in (WIRE_FIXED_32, WIRE_FIXED_64): - fmt = _pack_fmt(meta.proto_type) - value = struct.unpack(fmt, value)[0] - elif wire_type == WIRE_LEN_DELIM: - if meta.proto_type == TYPE_STRING: - value = str(value, "utf-8") - elif meta.proto_type == TYPE_MESSAGE: - cls = self._betterproto.cls_by_field[field_name] - - if cls == datetime: - value = _Timestamp().parse(value).to_datetime() - elif cls == timedelta: - value = _Duration().parse(value).to_timedelta() - elif meta.wraps: - # This is a Google wrapper value message around a single - # scalar type. - value = _get_wrapper(meta.wraps)().parse(value).value - else: - value = cls().parse(value) - value._serialized_on_wire = True - elif meta.proto_type == TYPE_MAP: - value = self._betterproto.cls_by_field[field_name]().parse(value) - - return value - - def _include_default_value_for_oneof( - self, field_name: str, meta: FieldMetadata - ) -> bool: - return ( - meta.group is not None and self._group_current.get(meta.group) == field_name - ) - - def parse(self: T, data: bytes) -> T: - """ - Parse the binary encoded Protobuf into this message instance. This - returns the instance itself and is therefore assignable and chainable. - - Parameters - ----------- - data: :class:`bytes` - The data to parse the protobuf from. - - Returns - -------- - :class:`Message` - The initialized message. - """ - # Got some data over the wire - self._serialized_on_wire = True - proto_meta = self._betterproto - for parsed in parse_fields(data): - field_name = proto_meta.field_name_by_number.get(parsed.number) - if not field_name: - self._unknown_fields += parsed.raw - continue - - meta = proto_meta.meta_by_field_name[field_name] - - value: Any - if parsed.wire_type == WIRE_LEN_DELIM and meta.proto_type in PACKED_TYPES: - # This is a packed repeated field. - pos = 0 - value = [] - while pos < len(parsed.value): - if meta.proto_type in (TYPE_FLOAT, TYPE_FIXED32, TYPE_SFIXED32): - decoded, pos = parsed.value[pos : pos + 4], pos + 4 - wire_type = WIRE_FIXED_32 - elif meta.proto_type in (TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64): - decoded, pos = parsed.value[pos : pos + 8], pos + 8 - wire_type = WIRE_FIXED_64 - else: - decoded, pos = decode_varint(parsed.value, pos) - wire_type = WIRE_VARINT - decoded = self._postprocess_single( - wire_type, meta, field_name, decoded - ) - value.append(decoded) - else: - value = self._postprocess_single( - parsed.wire_type, meta, field_name, parsed.value - ) - - current = getattr(self, field_name) - if meta.proto_type == TYPE_MAP: - # Value represents a single key/value pair entry in the map. - current[value.key] = value.value - elif isinstance(current, list) and not isinstance(value, list): - current.append(value) - else: - setattr(self, field_name, value) - - return self - - # For compatibility with other libraries. - @classmethod - def FromString(cls: Type[T], data: bytes) -> T: - """ - Parse the binary encoded Protobuf into this message instance. This - returns the instance itself and is therefore assignable and chainable. - - .. note:: - This is a method for compatibility with other libraries, - you should really use :meth:`parse`. - - - Parameters - ----------- - data: :class:`bytes` - The data to parse the protobuf from. - - Returns - -------- - :class:`Message` - The initialized message. - """ - return cls().parse(data) - - def to_dict( - self, casing: Casing = Casing.CAMEL, include_default_values: bool = False - ) -> Dict[str, Any]: - """ - Returns a JSON serializable dict representation of this object. - - Parameters - ----------- - casing: :class:`Casing` - The casing to use for key values. Default is :attr:`Casing.CAMEL` for - compatibility purposes. - include_default_values: :class:`bool` - If ``True`` will include the default values of fields. Default is ``False``. - E.g. an ``int32`` field will be included with a value of ``0`` if this is - set to ``True``, otherwise this would be ignored. - - Returns - -------- - Dict[:class:`str`, Any] - The JSON serializable dict representation of this object. - """ - output: Dict[str, Any] = {} - field_types = self._type_hints() - defaults = self._betterproto.default_gen - for field_name, meta in self._betterproto.meta_by_field_name.items(): - field_is_repeated = defaults[field_name] is list - value = getattr(self, field_name) - cased_name = casing(field_name).rstrip("_") # type: ignore - if meta.proto_type == TYPE_MESSAGE: - if isinstance(value, datetime): - if ( - value != DATETIME_ZERO - or include_default_values - or self._include_default_value_for_oneof( - field_name=field_name, meta=meta - ) - ): - output[cased_name] = _Timestamp.timestamp_to_json(value) - elif isinstance(value, timedelta): - if ( - value != timedelta(0) - or include_default_values - or self._include_default_value_for_oneof( - field_name=field_name, meta=meta - ) - ): - output[cased_name] = _Duration.delta_to_json(value) - elif meta.wraps: - if value is not None or include_default_values: - output[cased_name] = value - elif field_is_repeated: - # Convert each item. - cls = self._betterproto.cls_by_field[field_name] - if cls == datetime: - value = [_Timestamp.timestamp_to_json(i) for i in value] - elif cls == timedelta: - value = [_Duration.delta_to_json(i) for i in value] - else: - value = [ - i.to_dict(casing, include_default_values) for i in value - ] - if value or include_default_values: - output[cased_name] = value - elif value is None: - if include_default_values: - output[cased_name] = value - elif ( - value._serialized_on_wire - or include_default_values - or self._include_default_value_for_oneof( - field_name=field_name, meta=meta - ) - ): - output[cased_name] = value.to_dict(casing, include_default_values) - elif meta.proto_type == TYPE_MAP: - for k in value: - if hasattr(value[k], "to_dict"): - value[k] = value[k].to_dict(casing, include_default_values) - - if value or include_default_values: - output[cased_name] = value - elif ( - value != self._get_field_default(field_name) - or include_default_values - or self._include_default_value_for_oneof( - field_name=field_name, meta=meta - ) - ): - if meta.proto_type in INT_64_TYPES: - if field_is_repeated: - output[cased_name] = [str(n) for n in value] - elif value is None: - if include_default_values: - output[cased_name] = value - else: - output[cased_name] = str(value) - elif meta.proto_type == TYPE_BYTES: - if field_is_repeated: - output[cased_name] = [ - b64encode(b).decode("utf8") for b in value - ] - elif value is None and include_default_values: - output[cased_name] = value - else: - output[cased_name] = b64encode(value).decode("utf8") - elif meta.proto_type == TYPE_ENUM: - if field_is_repeated: - enum_class = field_types[field_name].__args__[0] - if isinstance(value, typing.Iterable) and not isinstance( - value, str - ): - output[cased_name] = [enum_class(el).name for el in value] - else: - # transparently upgrade single value to repeated - output[cased_name] = [enum_class(value).name] - elif value is None: - if include_default_values: - output[cased_name] = value - elif meta.optional: - enum_class = field_types[field_name].__args__[0] - output[cased_name] = enum_class(value).name - else: - enum_class = field_types[field_name] # noqa - output[cased_name] = enum_class(value).name - elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE): - if field_is_repeated: - output[cased_name] = [_dump_float(n) for n in value] - else: - output[cased_name] = _dump_float(value) - else: - output[cased_name] = value - return output - - def from_dict(self: T, value: Dict[str, Any]) -> T: - """ - Parse the key/value pairs into the current message instance. This returns the - instance itself and is therefore assignable and chainable. - - Parameters - ----------- - value: Dict[:class:`str`, Any] - The dictionary to parse from. - - Returns - -------- - :class:`Message` - The initialized message. - """ - self._serialized_on_wire = True - for key in value: - field_name = safe_snake_case(key) - meta = self._betterproto.meta_by_field_name.get(field_name) - if not meta: - continue - - if value[key] is not None: - if meta.proto_type == TYPE_MESSAGE: - v = getattr(self, field_name) - if isinstance(v, list): - cls = self._betterproto.cls_by_field[field_name] - if cls == datetime: - v = [isoparse(item) for item in value[key]] - elif cls == timedelta: - v = [ - timedelta(seconds=float(item[:-1])) - for item in value[key] - ] - else: - v = [cls().from_dict(item) for item in value[key]] - elif isinstance(v, datetime): - v = isoparse(value[key]) - setattr(self, field_name, v) - elif isinstance(v, timedelta): - v = timedelta(seconds=float(value[key][:-1])) - setattr(self, field_name, v) - elif meta.wraps: - setattr(self, field_name, value[key]) - elif v is None: - cls = self._betterproto.cls_by_field[field_name] - setattr(self, field_name, cls().from_dict(value[key])) - else: - # NOTE: `from_dict` mutates the underlying message, so no - # assignment here is necessary. - v.from_dict(value[key]) - elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE: - v = getattr(self, field_name) - cls = self._betterproto.cls_by_field[f"{field_name}.value"] - for k in value[key]: - v[k] = cls().from_dict(value[key][k]) - else: - v = value[key] - if meta.proto_type in INT_64_TYPES: - if isinstance(value[key], list): - v = [int(n) for n in value[key]] - else: - v = int(value[key]) - elif meta.proto_type == TYPE_BYTES: - if isinstance(value[key], list): - v = [b64decode(n) for n in value[key]] - else: - v = b64decode(value[key]) - elif meta.proto_type == TYPE_ENUM: - enum_cls = self._betterproto.cls_by_field[field_name] - if isinstance(v, list): - v = [enum_cls.from_string(e) for e in v] - elif isinstance(v, str): - v = enum_cls.from_string(v) - elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE): - if isinstance(value[key], list): - v = [_parse_float(n) for n in value[key]] - else: - v = _parse_float(value[key]) - - if v is not None: - setattr(self, field_name, v) - return self - - def to_json(self, indent: Union[None, int, str] = None) -> str: - """A helper function to parse the message instance into its JSON - representation. - - This is equivalent to:: - - json.dumps(message.to_dict(), indent=indent) - - Parameters - ----------- - indent: Optional[Union[:class:`int`, :class:`str`]] - The indent to pass to :func:`json.dumps`. - - Returns - -------- - :class:`str` - The JSON representation of the message. - """ - return json.dumps(self.to_dict(), indent=indent) - - def from_json(self: T, value: Union[str, bytes]) -> T: - """A helper function to return the message instance from its JSON - representation. This returns the instance itself and is therefore assignable - and chainable. - - This is equivalent to:: - - return message.from_dict(json.loads(value)) - - Parameters - ----------- - value: Union[:class:`str`, :class:`bytes`] - The value to pass to :func:`json.loads`. - - Returns - -------- - :class:`Message` - The initialized message. - """ - return self.from_dict(json.loads(value)) - - -def serialized_on_wire(message: Message) -> bool: - """ - If this message was or should be serialized on the wire. This can be used to detect - presence (e.g. optional wrapper message) and is used internally during - parsing/serialization. - - Returns - -------- - :class:`bool` - Whether this message was or should be serialized on the wire. - """ - return message._serialized_on_wire - - -def which_one_of(message: Message, group_name: str) -> Tuple[str, Optional[Any]]: - """ - Return the name and value of a message's one-of field group. - - Returns - -------- - Tuple[:class:`str`, Any] - The field name and the value for that field. - """ - field_name = message._group_current.get(group_name) - if not field_name: - return "", None - return field_name, getattr(message, field_name) - - -# Circular import workaround: google.protobuf depends on base classes defined above. +from .casing import * +from .grpc.grpclib_client import * +from .const import * +from .enum import * +from .message import * from .lib.google.protobuf import ( # noqa BoolValue, BytesValue, @@ -1325,58 +18,3 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Optional[Any]] UInt32Value, UInt64Value, ) - - -class _Duration(Duration): - def to_timedelta(self) -> timedelta: - return timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3) - - @staticmethod - def delta_to_json(delta: timedelta) -> str: - parts = str(delta.total_seconds()).split(".") - if len(parts) > 1: - while len(parts[1]) not in (3, 6, 9): - parts[1] = f"{parts[1]}0" - return f"{'.'.join(parts)}s" - - -class _Timestamp(Timestamp): - def to_datetime(self) -> datetime: - ts = self.seconds + (self.nanos / 1e9) - return datetime.fromtimestamp(ts, tz=timezone.utc) - - @staticmethod - def timestamp_to_json(dt: datetime) -> str: - nanos = dt.microsecond * 1e3 - copy = dt.replace(microsecond=0, tzinfo=None) - result = copy.isoformat() - if (nanos % 1e9) == 0: - # If there are 0 fractional digits, the fractional - # point '.' should be omitted when serializing. - return f"{result}Z" - if (nanos % 1e6) == 0: - # Serialize 3 fractional digits. - return f"{result}.{int(nanos // 1e6) :03d}Z" - if (nanos % 1e3) == 0: - # Serialize 6 fractional digits. - return f"{result}.{int(nanos // 1e3) :06d}Z" - # Serialize 9 fractional digits. - return f"{result}.{nanos:09d}" - - -def _get_wrapper(proto_type: str) -> Type: - """Get the wrapper message class for a wrapped type.""" - - # TODO: include ListValue and NullValue? - return { - TYPE_BOOL: BoolValue, - TYPE_BYTES: BytesValue, - TYPE_DOUBLE: DoubleValue, - TYPE_FLOAT: FloatValue, - TYPE_ENUM: EnumValue, - TYPE_INT32: Int32Value, - TYPE_INT64: Int64Value, - TYPE_STRING: StringValue, - TYPE_UINT32: UInt32Value, - TYPE_UINT64: UInt64Value, - }[proto_type] diff --git a/src/betterproto/casing.py b/src/betterproto/casing.py index cd3c34472..21caf1ab4 100644 --- a/src/betterproto/casing.py +++ b/src/betterproto/casing.py @@ -1,6 +1,9 @@ +import enum import keyword import re +__all__ = ("Casing",) + # Word delimiters and symbols that will not be preserved when re-casing. # language=PythonRegExp SYMBOLS = "[^a-zA-Z0-9]*" @@ -136,3 +139,10 @@ def lowercase_first(value: str) -> str: def sanitize_name(value: str) -> str: # https://www.python.org/dev/peps/pep-0008/#descriptive-naming-styles return f"{value}_" if keyword.iskeyword(value) else value + + +class Casing(enum.Enum): + """Casing constants for serialization.""" + + CAMEL = camel_case #: A camelCase sterilization function. + SNAKE = snake_case #: A snake_case sterilization function. diff --git a/src/betterproto/const.py b/src/betterproto/const.py new file mode 100644 index 000000000..a1b9a62c9 --- /dev/null +++ b/src/betterproto/const.py @@ -0,0 +1,91 @@ +# Proto 3 data types +from datetime import datetime, timezone + + +TYPE_ENUM = "enum" +TYPE_BOOL = "bool" +TYPE_INT32 = "int32" +TYPE_INT64 = "int64" +TYPE_UINT32 = "uint32" +TYPE_UINT64 = "uint64" +TYPE_SINT32 = "sint32" +TYPE_SINT64 = "sint64" +TYPE_FLOAT = "float" +TYPE_DOUBLE = "double" +TYPE_FIXED32 = "fixed32" +TYPE_SFIXED32 = "sfixed32" +TYPE_FIXED64 = "fixed64" +TYPE_SFIXED64 = "sfixed64" +TYPE_STRING = "string" +TYPE_BYTES = "bytes" +TYPE_MESSAGE = "message" +TYPE_MAP = "map" + + +# Fields that use a fixed amount of space (4 or 8 bytes) +FIXED_TYPES = [ + TYPE_FLOAT, + TYPE_DOUBLE, + TYPE_FIXED32, + TYPE_SFIXED32, + TYPE_FIXED64, + TYPE_SFIXED64, +] + +# Fields that are numerical 64-bit types +INT_64_TYPES = [TYPE_INT64, TYPE_UINT64, TYPE_SINT64, TYPE_FIXED64, TYPE_SFIXED64] + +# Fields that are efficiently packed when +PACKED_TYPES = [ + TYPE_ENUM, + TYPE_BOOL, + TYPE_INT32, + TYPE_INT64, + TYPE_UINT32, + TYPE_UINT64, + TYPE_SINT32, + TYPE_SINT64, + TYPE_FLOAT, + TYPE_DOUBLE, + TYPE_FIXED32, + TYPE_SFIXED32, + TYPE_FIXED64, + TYPE_SFIXED64, +] + +# Wire types +# https://developers.google.com/protocol-buffers/docs/encoding#structure +WIRE_VARINT = 0 +WIRE_FIXED_64 = 1 +WIRE_LEN_DELIM = 2 +WIRE_FIXED_32 = 5 + +# Mappings of which Proto 3 types correspond to which wire types. +WIRE_VARINT_TYPES = [ + TYPE_ENUM, + TYPE_BOOL, + TYPE_INT32, + TYPE_INT64, + TYPE_UINT32, + TYPE_UINT64, + TYPE_SINT32, + TYPE_SINT64, +] + +WIRE_FIXED_32_TYPES = [TYPE_FLOAT, TYPE_FIXED32, TYPE_SFIXED32] +WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64] +WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP] + + +# Protobuf datetimes start at the Unix Epoch in 1970 in UTC. +def datetime_default_gen() -> datetime: + return datetime(1970, 1, 1, tzinfo=timezone.utc) + + +DATETIME_ZERO = datetime_default_gen() + + +# Special protobuf json doubles +INFINITY = "Infinity" +NEG_INFINITY = "-Infinity" +NAN = "NaN" diff --git a/src/betterproto/enum.py b/src/betterproto/enum.py new file mode 100644 index 000000000..59c1f4f50 --- /dev/null +++ b/src/betterproto/enum.py @@ -0,0 +1,29 @@ +import enum + +__all__ = ("Enum",) + + +class Enum(enum.IntEnum): + """ + The base class for protobuf enumerations, all generated enumerations will inherit + from this. Bases :class:`enum.IntEnum`. + """ + + @classmethod + def from_string(cls, name: str) -> "Enum": + """Return the value which corresponds to the string name. + + Parameters + ----------- + name: :class:`str` + The name of the enum member to get + + Raises + ------- + :exc:`ValueError` + The member was not found in the Enum. + """ + try: + return cls._member_map_[name] # type: ignore + except KeyError as e: + raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e diff --git a/src/betterproto/grpc/grpclib_client.py b/src/betterproto/grpc/grpclib_client.py index a22b7e358..54f9abe71 100644 --- a/src/betterproto/grpc/grpclib_client.py +++ b/src/betterproto/grpc/grpclib_client.py @@ -27,6 +27,8 @@ _MessageLike = Union[T, ST] _MessageSource = Union[Iterable[ST], AsyncIterable[ST]] +__all__ = ("ServiceStub",) + class ServiceStub(ABC): """ diff --git a/src/betterproto/io.py b/src/betterproto/io.py new file mode 100644 index 000000000..700a8f0e2 --- /dev/null +++ b/src/betterproto/io.py @@ -0,0 +1,204 @@ +import dataclasses +from datetime import timedelta +import struct +from typing import Any, List, Union, Tuple + +from typing import Generator + +from .message import _Duration, _Timestamp, _get_wrapper +from .const import * + + +def _pack_fmt(proto_type: str) -> str: + """Returns a little-endian format string for reading/writing binary.""" + return { + TYPE_DOUBLE: " bytes: + """Encodes a single varint value for serialization.""" + b: List[int] = [] + + if value < 0: + value += 1 << 64 + + bits = value & 0x7F + value >>= 7 + while value: + b.append(0x80 | bits) + bits = value & 0x7F + value >>= 7 + return bytes(b + [bits]) + + +def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes: + """Adjusts values before serialization.""" + if proto_type in ( + TYPE_ENUM, + TYPE_BOOL, + TYPE_INT32, + TYPE_INT64, + TYPE_UINT32, + TYPE_UINT64, + ): + return encode_varint(value) + elif proto_type in (TYPE_SINT32, TYPE_SINT64): + # Handle zig-zag encoding. + return encode_varint(value << 1 if value >= 0 else (value << 1) ^ (~0)) + elif proto_type in FIXED_TYPES: + return struct.pack(_pack_fmt(proto_type), value) + elif proto_type == TYPE_STRING: + return value.encode("utf-8") + elif proto_type == TYPE_MESSAGE: + if isinstance(value, datetime): + # Convert the `datetime` to a timestamp message. + seconds = int(value.timestamp()) + nanos = int(value.microsecond * 1e3) + value = _Timestamp(seconds=seconds, nanos=nanos) + elif isinstance(value, timedelta): + # Convert the `timedelta` to a duration message. + total_ms = value // timedelta(microseconds=1) + seconds = int(total_ms / 1e6) + nanos = int((total_ms % 1e6) * 1e3) + value = _Duration(seconds=seconds, nanos=nanos) + elif wraps: + if value is None: + return b"" + value = _get_wrapper(wraps)(value=value) + + return bytes(value) + + return value + + +def _serialize_single( + field_number: int, + proto_type: str, + value: Any, + *, + serialize_empty: bool = False, + wraps: str = "", +) -> bytes: + """Serializes a single field and value.""" + value = _preprocess_single(proto_type, wraps, value) + + output = bytearray() + if proto_type in WIRE_VARINT_TYPES: + key = encode_varint(field_number << 3) + output += key + value + elif proto_type in WIRE_FIXED_32_TYPES: + key = encode_varint((field_number << 3) | 5) + output += key + value + elif proto_type in WIRE_FIXED_64_TYPES: + key = encode_varint((field_number << 3) | 1) + output += key + value + elif proto_type in WIRE_LEN_DELIM_TYPES: + if len(value) or serialize_empty or wraps: + key = encode_varint((field_number << 3) | 2) + output += key + encode_varint(len(value)) + value + else: + raise NotImplementedError(proto_type) + + return bytes(output) + + +def _parse_float(value: Any) -> float: + """Parse the given value to a float + + Parameters + ---------- + value : Any + Value to parse + + Returns + ------- + float + Parsed value + """ + if value == INFINITY: + return float("inf") + if value == NEG_INFINITY: + return -float("inf") + if value == NAN: + return float("nan") + return float(value) + + +def _dump_float(value: float) -> Union[float, str]: + """Dump the given float to JSON + + Parameters + ---------- + value : float + Value to dump + + Returns + ------- + Union[float, str] + Dumped valid, either a float or the strings + "Infinity" or "-Infinity" + """ + if value == float("inf"): + return INFINITY + if value == -float("inf"): + return NEG_INFINITY + if value == float("nan"): + return NAN + return value + + +def decode_varint(buffer: bytes, pos: int) -> Tuple[int, int]: + """ + Decode a single varint value from a byte buffer. Returns the value and the + new position in the buffer. + """ + result = 0 + shift = 0 + while 1: + b = buffer[pos] + result |= (b & 0x7F) << shift + pos += 1 + if not (b & 0x80): + return result, pos + shift += 7 + if shift >= 64: + raise ValueError("Too many bytes when decoding varint.") + + +@dataclasses.dataclass(frozen=True) +class ParsedField: + number: int + wire_type: int + value: Any + raw: bytes + + +def parse_fields(value: bytes) -> Generator[ParsedField, None, None]: + i = 0 + while i < len(value): + start = i + num_wire, i = decode_varint(value, i) + number = num_wire >> 3 + wire_type = num_wire & 0x7 + + decoded: Any = None + if wire_type == WIRE_VARINT: + decoded, i = decode_varint(value, i) + elif wire_type == WIRE_FIXED_64: + decoded, i = value[i : i + 8], i + 8 + elif wire_type == WIRE_LEN_DELIM: + length, i = decode_varint(value, i) + decoded = value[i : i + length] + i += length + elif wire_type == WIRE_FIXED_32: + decoded, i = value[i : i + 4], i + 4 + + yield ParsedField( + number=number, wire_type=wire_type, value=decoded, raw=value[start:i] + ) diff --git a/src/betterproto/lib/google/protobuf/__init__.py b/src/betterproto/lib/google/protobuf/__init__.py index e9b6de7b5..ceda878b4 100644 --- a/src/betterproto/lib/google/protobuf/__init__.py +++ b/src/betterproto/lib/google/protobuf/__init__.py @@ -5,11 +5,12 @@ from dataclasses import dataclass from typing import Dict, List -import betterproto -from betterproto.grpc.grpclib_server import ServiceBase +from ....enum import * +from ....const import * +from ....message import * -class Syntax(betterproto.Enum): +class Syntax(Enum): """The syntax in which a protocol buffer element is defined.""" # Syntax `proto2`. @@ -18,7 +19,7 @@ class Syntax(betterproto.Enum): SYNTAX_PROTO3 = 1 -class FieldKind(betterproto.Enum): +class FieldKind(Enum): TYPE_UNKNOWN = 0 TYPE_DOUBLE = 1 TYPE_FLOAT = 2 @@ -40,14 +41,14 @@ class FieldKind(betterproto.Enum): TYPE_SINT64 = 18 -class FieldCardinality(betterproto.Enum): +class FieldCardinality(Enum): CARDINALITY_UNKNOWN = 0 CARDINALITY_OPTIONAL = 1 CARDINALITY_REQUIRED = 2 CARDINALITY_REPEATED = 3 -class FieldDescriptorProtoType(betterproto.Enum): +class FieldDescriptorProtoType(Enum): TYPE_DOUBLE = 1 TYPE_FLOAT = 2 TYPE_INT64 = 3 @@ -68,37 +69,37 @@ class FieldDescriptorProtoType(betterproto.Enum): TYPE_SINT64 = 18 -class FieldDescriptorProtoLabel(betterproto.Enum): +class FieldDescriptorProtoLabel(Enum): LABEL_OPTIONAL = 1 LABEL_REQUIRED = 2 LABEL_REPEATED = 3 -class FileOptionsOptimizeMode(betterproto.Enum): +class FileOptionsOptimizeMode(Enum): SPEED = 1 CODE_SIZE = 2 LITE_RUNTIME = 3 -class FieldOptionsCType(betterproto.Enum): +class FieldOptionsCType(Enum): STRING = 0 CORD = 1 STRING_PIECE = 2 -class FieldOptionsJsType(betterproto.Enum): +class FieldOptionsJsType(Enum): JS_NORMAL = 0 JS_STRING = 1 JS_NUMBER = 2 -class MethodOptionsIdempotencyLevel(betterproto.Enum): +class MethodOptionsIdempotencyLevel(Enum): IDEMPOTENCY_UNKNOWN = 0 NO_SIDE_EFFECTS = 1 IDEMPOTENT = 2 -class NullValue(betterproto.Enum): +class NullValue(Enum): """ `NullValue` is a singleton enumeration to represent the null value for the `Value` type union. The JSON representation for `NullValue` is JSON @@ -110,7 +111,7 @@ class NullValue(betterproto.Enum): @dataclass(eq=False, repr=False) -class Any(betterproto.Message): +class Any(Message): """ `Any` contains an arbitrary serialized protocol buffer message along with a URL that describes the type of the serialized message. Protobuf library @@ -163,13 +164,13 @@ class Any(betterproto.Message): # and it is not used for type URLs beginning with type.googleapis.com. # Schemes other than `http`, `https` (or the empty scheme) might be used with # implementation specific semantics. - type_url: str = betterproto.string_field(1) + type_url: str = string_field(1) # Must be a valid serialized protocol buffer of the above specified type. - value: bytes = betterproto.bytes_field(2) + value: bytes = bytes_field(2) @dataclass(eq=False, repr=False) -class SourceContext(betterproto.Message): +class SourceContext(Message): """ `SourceContext` represents information about the source of a protobuf element, like the file in which it is defined. @@ -177,87 +178,85 @@ class SourceContext(betterproto.Message): # The path-qualified name of the .proto file that contained the associated # protobuf element. For example: `"google/protobuf/source_context.proto"`. - file_name: str = betterproto.string_field(1) + file_name: str = string_field(1) @dataclass(eq=False, repr=False) -class Type(betterproto.Message): +class Type(Message): """A protocol buffer message type.""" # The fully qualified message name. - name: str = betterproto.string_field(1) + name: str = string_field(1) # The list of fields. - fields: List["Field"] = betterproto.message_field(2) + fields: List["Field"] = message_field(2) # The list of types appearing in `oneof` definitions in this type. - oneofs: List[str] = betterproto.string_field(3) + oneofs: List[str] = string_field(3) # The protocol buffer options. - options: List["Option"] = betterproto.message_field(4) + options: List["Option"] = message_field(4) # The source context. - source_context: "SourceContext" = betterproto.message_field(5) + source_context: "SourceContext" = message_field(5) # The source syntax. - syntax: "Syntax" = betterproto.enum_field(6) + syntax: "Syntax" = enum_field(6) @dataclass(eq=False, repr=False) -class Field(betterproto.Message): +class Field(Message): """A single field of a message type.""" # The field type. - kind: "FieldKind" = betterproto.enum_field(1) + kind: "FieldKind" = enum_field(1) # The field cardinality. - cardinality: "FieldCardinality" = betterproto.enum_field(2) + cardinality: "FieldCardinality" = enum_field(2) # The field number. - number: int = betterproto.int32_field(3) + number: int = int32_field(3) # The field name. - name: str = betterproto.string_field(4) + name: str = string_field(4) # The field type URL, without the scheme, for message or enumeration types. # Example: `"type.googleapis.com/google.protobuf.Timestamp"`. - type_url: str = betterproto.string_field(6) + type_url: str = string_field(6) # The index of the field type in `Type.oneofs`, for message or enumeration # types. The first type has index 1; zero means the type is not in the list. - oneof_index: int = betterproto.int32_field(7) + oneof_index: int = int32_field(7) # Whether to use alternative packed wire representation. - packed: bool = betterproto.bool_field(8) + packed: bool = bool_field(8) # The protocol buffer options. - options: List["Option"] = betterproto.message_field(9) + options: List["Option"] = message_field(9) # The field JSON name. - json_name: str = betterproto.string_field(10) + json_name: str = string_field(10) # The string value of the default value of this field. Proto2 syntax only. - default_value: str = betterproto.string_field(11) + default_value: str = string_field(11) @dataclass(eq=False, repr=False) -class Enum(betterproto.Message): +class Enum(Message): """Enum type definition.""" # Enum type name. - name: str = betterproto.string_field(1) + name: str = string_field(1) # Enum value definitions. - enumvalue: List["EnumValue"] = betterproto.message_field( - 2, wraps=betterproto.TYPE_ENUM - ) + enumvalue: List["EnumValue"] = message_field(2, wraps=TYPE_ENUM) # Protocol buffer options. - options: List["Option"] = betterproto.message_field(3) + options: List["Option"] = message_field(3) # The source context. - source_context: "SourceContext" = betterproto.message_field(4) + source_context: "SourceContext" = message_field(4) # The source syntax. - syntax: "Syntax" = betterproto.enum_field(5) + syntax: "Syntax" = enum_field(5) @dataclass(eq=False, repr=False) -class EnumValue(betterproto.Message): +class EnumValue(Message): """Enum value definition.""" # Enum value name. - name: str = betterproto.string_field(1) + name: str = string_field(1) # Enum value number. - number: int = betterproto.int32_field(2) + number: int = int32_field(2) # Protocol buffer options. - options: List["Option"] = betterproto.message_field(3) + options: List["Option"] = message_field(3) @dataclass(eq=False, repr=False) -class Option(betterproto.Message): +class Option(Message): """ A protocol buffer option, which can be attached to a message, field, enumeration, etc. @@ -267,16 +266,16 @@ class Option(betterproto.Message): # descriptor.proto), this is the short name. For example, `"map_entry"`. For # custom options, it should be the fully-qualified name. For example, # `"google.api.http"`. - name: str = betterproto.string_field(1) + name: str = string_field(1) # The option's value packed in an Any message. If the value is a primitive, # the corresponding wrapper type defined in google/protobuf/wrappers.proto # should be used. If the value is an enum, it should be stored as an int32 # value using the google.protobuf.Int32Value type. - value: "Any" = betterproto.message_field(2) + value: "Any" = message_field(2) @dataclass(eq=False, repr=False) -class Api(betterproto.Message): +class Api(Message): """ Api is a light-weight descriptor for an API Interface. Interfaces are also described as "protocol buffer services" in some contexts, such as by the @@ -290,11 +289,11 @@ class Api(betterproto.Message): # The fully qualified name of this interface, including package name followed # by the interface's simple name. - name: str = betterproto.string_field(1) + name: str = string_field(1) # The methods of this interface, in unspecified order. - methods: List["Method"] = betterproto.message_field(2) + methods: List["Method"] = message_field(2) # Any metadata attached to the interface. - options: List["Option"] = betterproto.message_field(3) + options: List["Option"] = message_field(3) # A version string for this interface. If specified, must have the form # `major-version.minor-version`, as in `1.10`. If the minor version is # omitted, it defaults to zero. If the entire version field is empty, the @@ -309,37 +308,37 @@ class Api(betterproto.Message): # must end in `v`, as in `google.feature.v1`. For major # versions 0 and 1, the suffix can be omitted. Zero major versions must only # be used for experimental, non-GA interfaces. - version: str = betterproto.string_field(4) + version: str = string_field(4) # Source context for the protocol buffer service represented by this message. - source_context: "SourceContext" = betterproto.message_field(5) + source_context: "SourceContext" = message_field(5) # Included interfaces. See [Mixin][]. - mixins: List["Mixin"] = betterproto.message_field(6) + mixins: List["Mixin"] = message_field(6) # The source syntax of the service. - syntax: "Syntax" = betterproto.enum_field(7) + syntax: "Syntax" = enum_field(7) @dataclass(eq=False, repr=False) -class Method(betterproto.Message): +class Method(Message): """Method represents a method of an API interface.""" # The simple name of this method. - name: str = betterproto.string_field(1) + name: str = string_field(1) # A URL of the input message type. - request_type_url: str = betterproto.string_field(2) + request_type_url: str = string_field(2) # If true, the request is streamed. - request_streaming: bool = betterproto.bool_field(3) + request_streaming: bool = bool_field(3) # The URL of the output message type. - response_type_url: str = betterproto.string_field(4) + response_type_url: str = string_field(4) # If true, the response is streamed. - response_streaming: bool = betterproto.bool_field(5) + response_streaming: bool = bool_field(5) # Any metadata attached to the method. - options: List["Option"] = betterproto.message_field(6) + options: List["Option"] = message_field(6) # The source syntax of this method. - syntax: "Syntax" = betterproto.enum_field(7) + syntax: "Syntax" = enum_field(7) @dataclass(eq=False, repr=False) -class Mixin(betterproto.Message): +class Mixin(Message): """ Declares an API Interface to be included in this interface. The including interface must redeclare all the methods from the included interface, but @@ -378,127 +377,125 @@ class Mixin(betterproto.Message): """ # The fully qualified name of the interface which is included. - name: str = betterproto.string_field(1) + name: str = string_field(1) # If non-empty specifies a path under which inherited HTTP paths are rooted. - root: str = betterproto.string_field(2) + root: str = string_field(2) @dataclass(eq=False, repr=False) -class FileDescriptorSet(betterproto.Message): +class FileDescriptorSet(Message): """ The protocol compiler can output a FileDescriptorSet containing the .proto files it parses. """ - file: List["FileDescriptorProto"] = betterproto.message_field(1) + file: List["FileDescriptorProto"] = message_field(1) @dataclass(eq=False, repr=False) -class FileDescriptorProto(betterproto.Message): +class FileDescriptorProto(Message): """Describes a complete .proto file.""" - name: str = betterproto.string_field(1) - package: str = betterproto.string_field(2) + name: str = string_field(1) + package: str = string_field(2) # Names of files imported by this file. - dependency: List[str] = betterproto.string_field(3) + dependency: List[str] = string_field(3) # Indexes of the public imported files in the dependency list above. - public_dependency: List[int] = betterproto.int32_field(10) + public_dependency: List[int] = int32_field(10) # Indexes of the weak imported files in the dependency list. For Google- # internal migration only. Do not use. - weak_dependency: List[int] = betterproto.int32_field(11) + weak_dependency: List[int] = int32_field(11) # All top-level definitions in this file. - message_type: List["DescriptorProto"] = betterproto.message_field(4) - enum_type: List["EnumDescriptorProto"] = betterproto.message_field(5) - service: List["ServiceDescriptorProto"] = betterproto.message_field(6) - extension: List["FieldDescriptorProto"] = betterproto.message_field(7) - options: "FileOptions" = betterproto.message_field(8) + message_type: List["DescriptorProto"] = message_field(4) + enum_type: List["EnumDescriptorProto"] = message_field(5) + service: List["ServiceDescriptorProto"] = message_field(6) + extension: List["FieldDescriptorProto"] = message_field(7) + options: "FileOptions" = message_field(8) # This field contains optional information about the original source code. # You may safely remove this entire field without harming runtime # functionality of the descriptors -- the information is needed only by # development tools. - source_code_info: "SourceCodeInfo" = betterproto.message_field(9) + source_code_info: "SourceCodeInfo" = message_field(9) # The syntax of the proto file. The supported values are "proto2" and # "proto3". - syntax: str = betterproto.string_field(12) + syntax: str = string_field(12) @dataclass(eq=False, repr=False) -class DescriptorProto(betterproto.Message): +class DescriptorProto(Message): """Describes a message type.""" - name: str = betterproto.string_field(1) - field: List["FieldDescriptorProto"] = betterproto.message_field(2) - extension: List["FieldDescriptorProto"] = betterproto.message_field(6) - nested_type: List["DescriptorProto"] = betterproto.message_field(3) - enum_type: List["EnumDescriptorProto"] = betterproto.message_field(4) - extension_range: List["DescriptorProtoExtensionRange"] = betterproto.message_field( - 5 - ) - oneof_decl: List["OneofDescriptorProto"] = betterproto.message_field(8) - options: "MessageOptions" = betterproto.message_field(7) - reserved_range: List["DescriptorProtoReservedRange"] = betterproto.message_field(9) + name: str = string_field(1) + field: List["FieldDescriptorProto"] = message_field(2) + extension: List["FieldDescriptorProto"] = message_field(6) + nested_type: List["DescriptorProto"] = message_field(3) + enum_type: List["EnumDescriptorProto"] = message_field(4) + extension_range: List["DescriptorProtoExtensionRange"] = message_field(5) + oneof_decl: List["OneofDescriptorProto"] = message_field(8) + options: "MessageOptions" = message_field(7) + reserved_range: List["DescriptorProtoReservedRange"] = message_field(9) # Reserved field names, which may not be used by fields in the same message. # A given name may only be reserved once. - reserved_name: List[str] = betterproto.string_field(10) + reserved_name: List[str] = string_field(10) @dataclass(eq=False, repr=False) -class DescriptorProtoExtensionRange(betterproto.Message): - start: int = betterproto.int32_field(1) - end: int = betterproto.int32_field(2) - options: "ExtensionRangeOptions" = betterproto.message_field(3) +class DescriptorProtoExtensionRange(Message): + start: int = int32_field(1) + end: int = int32_field(2) + options: "ExtensionRangeOptions" = message_field(3) @dataclass(eq=False, repr=False) -class DescriptorProtoReservedRange(betterproto.Message): +class DescriptorProtoReservedRange(Message): """ Range of reserved tag numbers. Reserved tag numbers may not be used by fields or extension ranges in the same message. Reserved ranges may not overlap. """ - start: int = betterproto.int32_field(1) - end: int = betterproto.int32_field(2) + start: int = int32_field(1) + end: int = int32_field(2) @dataclass(eq=False, repr=False) -class ExtensionRangeOptions(betterproto.Message): +class ExtensionRangeOptions(Message): # The parser stores options it doesn't recognize here. See above. - uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999) + uninterpreted_option: List["UninterpretedOption"] = message_field(999) @dataclass(eq=False, repr=False) -class FieldDescriptorProto(betterproto.Message): +class FieldDescriptorProto(Message): """Describes a field within a message.""" - name: str = betterproto.string_field(1) - number: int = betterproto.int32_field(3) - label: "FieldDescriptorProtoLabel" = betterproto.enum_field(4) + name: str = string_field(1) + number: int = int32_field(3) + label: "FieldDescriptorProtoLabel" = enum_field(4) # If type_name is set, this need not be set. If both this and type_name are # set, this must be one of TYPE_ENUM, TYPE_MESSAGE or TYPE_GROUP. - type: "FieldDescriptorProtoType" = betterproto.enum_field(5) + type: "FieldDescriptorProtoType" = enum_field(5) # For message and enum types, this is the name of the type. If the name # starts with a '.', it is fully-qualified. Otherwise, C++-like scoping # rules are used to find the type (i.e. first the nested types within this # message are searched, then within the parent, on up to the root namespace). - type_name: str = betterproto.string_field(6) + type_name: str = string_field(6) # For extensions, this is the name of the type being extended. It is # resolved in the same manner as type_name. - extendee: str = betterproto.string_field(2) + extendee: str = string_field(2) # For numeric types, contains the original text representation of the value. # For booleans, "true" or "false". For strings, contains the default text # contents (not escaped in any way). For bytes, contains the C escaped value. # All bytes >= 128 are escaped. TODO(kenton): Base-64 encode? - default_value: str = betterproto.string_field(7) + default_value: str = string_field(7) # If set, gives the index of a oneof in the containing type's oneof_decl # list. This field is a member of that oneof. - oneof_index: int = betterproto.int32_field(9) + oneof_index: int = int32_field(9) # JSON name of this field. The value is set by protocol compiler. If the user # has set a "json_name" option on this field, that option's value will be # used. Otherwise, it's deduced from the field's name by converting it to # camelCase. - json_name: str = betterproto.string_field(10) - options: "FieldOptions" = betterproto.message_field(8) + json_name: str = string_field(10) + options: "FieldOptions" = message_field(8) # If true, this is a proto3 "optional". When a proto3 field is optional, it # tracks presence regardless of field type. When proto3_optional is true, # this field must be belong to a oneof to signal to old proto3 clients that @@ -515,37 +512,35 @@ class FieldDescriptorProto(betterproto.Message): # the parser can't tell if a field is a message or an enum, so it must always # create a synthetic oneof. Proto2 optional fields do not set this flag, # because they already indicate optional with `LABEL_OPTIONAL`. - proto3_optional: bool = betterproto.bool_field(17) + proto3_optional: bool = bool_field(17) @dataclass(eq=False, repr=False) -class OneofDescriptorProto(betterproto.Message): +class OneofDescriptorProto(Message): """Describes a oneof.""" - name: str = betterproto.string_field(1) - options: "OneofOptions" = betterproto.message_field(2) + name: str = string_field(1) + options: "OneofOptions" = message_field(2) @dataclass(eq=False, repr=False) -class EnumDescriptorProto(betterproto.Message): +class EnumDescriptorProto(Message): """Describes an enum type.""" - name: str = betterproto.string_field(1) - value: List["EnumValueDescriptorProto"] = betterproto.message_field(2) - options: "EnumOptions" = betterproto.message_field(3) + name: str = string_field(1) + value: List["EnumValueDescriptorProto"] = message_field(2) + options: "EnumOptions" = message_field(3) # Range of reserved numeric values. Reserved numeric values may not be used # by enum values in the same enum declaration. Reserved ranges may not # overlap. - reserved_range: List[ - "EnumDescriptorProtoEnumReservedRange" - ] = betterproto.message_field(4) + reserved_range: List["EnumDescriptorProtoEnumReservedRange"] = message_field(4) # Reserved enum value names, which may not be reused. A given name may only # be reserved once. - reserved_name: List[str] = betterproto.string_field(5) + reserved_name: List[str] = string_field(5) @dataclass(eq=False, repr=False) -class EnumDescriptorProtoEnumReservedRange(betterproto.Message): +class EnumDescriptorProtoEnumReservedRange(Message): """ Range of reserved numeric values. Reserved values may not be used by entries in the same enum. Reserved ranges may not overlap. Note that this @@ -553,79 +548,79 @@ class EnumDescriptorProtoEnumReservedRange(betterproto.Message): that it can appropriately represent the entire int32 domain. """ - start: int = betterproto.int32_field(1) - end: int = betterproto.int32_field(2) + start: int = int32_field(1) + end: int = int32_field(2) @dataclass(eq=False, repr=False) -class EnumValueDescriptorProto(betterproto.Message): +class EnumValueDescriptorProto(Message): """Describes a value within an enum.""" - name: str = betterproto.string_field(1) - number: int = betterproto.int32_field(2) - options: "EnumValueOptions" = betterproto.message_field(3) + name: str = string_field(1) + number: int = int32_field(2) + options: "EnumValueOptions" = message_field(3) @dataclass(eq=False, repr=False) -class ServiceDescriptorProto(betterproto.Message): +class ServiceDescriptorProto(Message): """Describes a service.""" - name: str = betterproto.string_field(1) - method: List["MethodDescriptorProto"] = betterproto.message_field(2) - options: "ServiceOptions" = betterproto.message_field(3) + name: str = string_field(1) + method: List["MethodDescriptorProto"] = message_field(2) + options: "ServiceOptions" = message_field(3) @dataclass(eq=False, repr=False) -class MethodDescriptorProto(betterproto.Message): +class MethodDescriptorProto(Message): """Describes a method of a service.""" - name: str = betterproto.string_field(1) + name: str = string_field(1) # Input and output type names. These are resolved in the same way as # FieldDescriptorProto.type_name, but must refer to a message type. - input_type: str = betterproto.string_field(2) - output_type: str = betterproto.string_field(3) - options: "MethodOptions" = betterproto.message_field(4) + input_type: str = string_field(2) + output_type: str = string_field(3) + options: "MethodOptions" = message_field(4) # Identifies if client streams multiple client messages - client_streaming: bool = betterproto.bool_field(5) + client_streaming: bool = bool_field(5) # Identifies if server streams multiple server messages - server_streaming: bool = betterproto.bool_field(6) + server_streaming: bool = bool_field(6) @dataclass(eq=False, repr=False) -class FileOptions(betterproto.Message): +class FileOptions(Message): # Sets the Java package where classes generated from this .proto will be # placed. By default, the proto package is used, but this is often # inappropriate because proto packages do not normally start with backwards # domain names. - java_package: str = betterproto.string_field(1) + java_package: str = string_field(1) # Controls the name of the wrapper Java class generated for the .proto file. # That class will always contain the .proto file's getDescriptor() method as # well as any top-level extensions defined in the .proto file. If # java_multiple_files is disabled, then all the other classes from the .proto # file will be nested inside the single wrapper outer class. - java_outer_classname: str = betterproto.string_field(8) + java_outer_classname: str = string_field(8) # If enabled, then the Java code generator will generate a separate .java # file for each top-level message, enum, and service defined in the .proto # file. Thus, these types will *not* be nested inside the wrapper class # named by java_outer_classname. However, the wrapper class will still be # generated to contain the file's getDescriptor() method as well as any top- # level extensions defined in the file. - java_multiple_files: bool = betterproto.bool_field(10) + java_multiple_files: bool = bool_field(10) # This option does nothing. - java_generate_equals_and_hash: bool = betterproto.bool_field(20) + java_generate_equals_and_hash: bool = bool_field(20) # If set true, then the Java2 code generator will generate code that throws # an exception whenever an attempt is made to assign a non-UTF-8 byte # sequence to a string field. Message reflection will do the same. However, # an extension field still accepts non-UTF-8 byte sequences. This option has # no effect on when used with the lite runtime. - java_string_check_utf8: bool = betterproto.bool_field(27) - optimize_for: "FileOptionsOptimizeMode" = betterproto.enum_field(9) + java_string_check_utf8: bool = bool_field(27) + optimize_for: "FileOptionsOptimizeMode" = enum_field(9) # Sets the Go package where structs generated from this .proto will be # placed. If omitted, the Go package will be derived from the following: - # The basename of the package import path, if provided. - Otherwise, the # package statement in the .proto file, if present. - Otherwise, the # basename of the .proto file, without extension. - go_package: str = betterproto.string_field(11) + go_package: str = string_field(11) # Should generic services be generated in each language? "Generic" services # are not specific to any particular RPC system. They are generated by the # main code generators in each language (without additional plugins). Generic @@ -634,45 +629,45 @@ class FileOptions(betterproto.Message): # in favor of using plugins that generate code specific to your particular # RPC system. Therefore, these default to false. Old code which depends on # generic services should explicitly set them to true. - cc_generic_services: bool = betterproto.bool_field(16) - java_generic_services: bool = betterproto.bool_field(17) - py_generic_services: bool = betterproto.bool_field(18) - php_generic_services: bool = betterproto.bool_field(42) + cc_generic_services: bool = bool_field(16) + java_generic_services: bool = bool_field(17) + py_generic_services: bool = bool_field(18) + php_generic_services: bool = bool_field(42) # Is this file deprecated? Depending on the target platform, this can emit # Deprecated annotations for everything in the file, or it will be completely # ignored; in the very least, this is a formalization for deprecating files. - deprecated: bool = betterproto.bool_field(23) + deprecated: bool = bool_field(23) # Enables the use of arenas for the proto messages in this file. This applies # only to generated classes for C++. - cc_enable_arenas: bool = betterproto.bool_field(31) + cc_enable_arenas: bool = bool_field(31) # Sets the objective c class prefix which is prepended to all objective c # generated classes from this .proto. There is no default. - objc_class_prefix: str = betterproto.string_field(36) + objc_class_prefix: str = string_field(36) # Namespace for generated classes; defaults to the package. - csharp_namespace: str = betterproto.string_field(37) + csharp_namespace: str = string_field(37) # By default Swift generators will take the proto package and CamelCase it # replacing '.' with underscore and use that to prefix the types/symbols # defined. When this options is provided, they will use this value instead to # prefix the types/symbols defined. - swift_prefix: str = betterproto.string_field(39) + swift_prefix: str = string_field(39) # Sets the php class prefix which is prepended to all php generated classes # from this .proto. Default is empty. - php_class_prefix: str = betterproto.string_field(40) + php_class_prefix: str = string_field(40) # Use this option to change the namespace of php generated classes. Default # is empty. When this option is empty, the package name will be used for # determining the namespace. - php_namespace: str = betterproto.string_field(41) + php_namespace: str = string_field(41) # Use this option to change the namespace of php generated metadata classes. # Default is empty. When this option is empty, the proto file name will be # used for determining the namespace. - php_metadata_namespace: str = betterproto.string_field(44) + php_metadata_namespace: str = string_field(44) # Use this option to change the package of ruby generated classes. Default is # empty. When this option is not set, the package name will be used for # determining the ruby package. - ruby_package: str = betterproto.string_field(45) + ruby_package: str = string_field(45) # The parser stores options it doesn't recognize here. See the documentation # for the "Options" section above. - uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999) + uninterpreted_option: List["UninterpretedOption"] = message_field(999) def __post_init__(self) -> None: super().__post_init__() @@ -684,7 +679,7 @@ def __post_init__(self) -> None: @dataclass(eq=False, repr=False) -class MessageOptions(betterproto.Message): +class MessageOptions(Message): # Set true to use the old proto1 MessageSet wire format for extensions. This # is provided for backwards-compatibility with the MessageSet wire format. # You should not use this for any other reason: It's less efficient, has @@ -695,15 +690,15 @@ class MessageOptions(betterproto.Message): # type must be singular messages; e.g. they cannot be int32s, enums, or # repeated messages. Because this is an option, the above two restrictions # are not enforced by the protocol compiler. - message_set_wire_format: bool = betterproto.bool_field(1) + message_set_wire_format: bool = bool_field(1) # Disables the generation of the standard "descriptor()" accessor, which can # conflict with a field of the same name. This is meant to make migration # from proto1 easier; new code should avoid fields named "descriptor". - no_standard_descriptor_accessor: bool = betterproto.bool_field(2) + no_standard_descriptor_accessor: bool = bool_field(2) # Is this message deprecated? Depending on the target platform, this can emit # Deprecated annotations for the message, or it will be completely ignored; # in the very least, this is a formalization for deprecating messages. - deprecated: bool = betterproto.bool_field(3) + deprecated: bool = bool_field(3) # Whether the message is an automatically generated map entry type for the # maps field. For maps fields: map map_field = 1; The # parsed descriptor looks like: message MapFieldEntry { option @@ -715,24 +710,24 @@ class MessageOptions(betterproto.Message): # is a repeated message field. NOTE: Do not set the option in .proto files. # Always use the maps syntax instead. The option should only be implicitly # set by the proto compiler parser. - map_entry: bool = betterproto.bool_field(7) + map_entry: bool = bool_field(7) # The parser stores options it doesn't recognize here. See above. - uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999) + uninterpreted_option: List["UninterpretedOption"] = message_field(999) @dataclass(eq=False, repr=False) -class FieldOptions(betterproto.Message): +class FieldOptions(Message): # The ctype option instructs the C++ code generator to use a different # representation of the field than it normally would. See the specific # options below. This option is not yet implemented in the open source # release -- sorry, we'll try to include it in a future version! - ctype: "FieldOptionsCType" = betterproto.enum_field(1) + ctype: "FieldOptionsCType" = enum_field(1) # The packed option can be enabled for repeated primitive fields to enable a # more efficient representation on the wire. Rather than repeatedly writing # the tag and type for each element, the entire array is encoded as a single # length-delimited blob. In proto3, only explicit setting it to false will # avoid using packed encoding. - packed: bool = betterproto.bool_field(2) + packed: bool = bool_field(2) # The jstype option determines the JavaScript type used for values of the # field. The option is permitted only for 64 bit integral and fixed types # (int64, uint64, sint64, fixed64, sfixed64). A field with jstype JS_STRING @@ -742,7 +737,7 @@ class FieldOptions(betterproto.Message): # use the JavaScript "number" type. The behavior of the default option # JS_NORMAL is implementation dependent. This option is an enum to permit # additional types to be added, e.g. goog.math.Integer. - jstype: "FieldOptionsJsType" = betterproto.enum_field(6) + jstype: "FieldOptionsJsType" = enum_field(6) # Should this field be parsed lazily? Lazy applies only to message-type # fields. It means that when the outer message is initially parsed, the # inner message's contents will not be parsed but instead stored in encoded @@ -766,70 +761,70 @@ class FieldOptions(betterproto.Message): # implementation must either *always* check its required fields, or *never* # check its required fields, regardless of whether or not the message has # been parsed. - lazy: bool = betterproto.bool_field(5) + lazy: bool = bool_field(5) # Is this field deprecated? Depending on the target platform, this can emit # Deprecated annotations for accessors, or it will be completely ignored; in # the very least, this is a formalization for deprecating fields. - deprecated: bool = betterproto.bool_field(3) + deprecated: bool = bool_field(3) # For Google-internal migration only. Do not use. - weak: bool = betterproto.bool_field(10) + weak: bool = bool_field(10) # The parser stores options it doesn't recognize here. See above. - uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999) + uninterpreted_option: List["UninterpretedOption"] = message_field(999) @dataclass(eq=False, repr=False) -class OneofOptions(betterproto.Message): +class OneofOptions(Message): # The parser stores options it doesn't recognize here. See above. - uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999) + uninterpreted_option: List["UninterpretedOption"] = message_field(999) @dataclass(eq=False, repr=False) -class EnumOptions(betterproto.Message): +class EnumOptions(Message): # Set this option to true to allow mapping different tag names to the same # value. - allow_alias: bool = betterproto.bool_field(2) + allow_alias: bool = bool_field(2) # Is this enum deprecated? Depending on the target platform, this can emit # Deprecated annotations for the enum, or it will be completely ignored; in # the very least, this is a formalization for deprecating enums. - deprecated: bool = betterproto.bool_field(3) + deprecated: bool = bool_field(3) # The parser stores options it doesn't recognize here. See above. - uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999) + uninterpreted_option: List["UninterpretedOption"] = message_field(999) @dataclass(eq=False, repr=False) -class EnumValueOptions(betterproto.Message): +class EnumValueOptions(Message): # Is this enum value deprecated? Depending on the target platform, this can # emit Deprecated annotations for the enum value, or it will be completely # ignored; in the very least, this is a formalization for deprecating enum # values. - deprecated: bool = betterproto.bool_field(1) + deprecated: bool = bool_field(1) # The parser stores options it doesn't recognize here. See above. - uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999) + uninterpreted_option: List["UninterpretedOption"] = message_field(999) @dataclass(eq=False, repr=False) -class ServiceOptions(betterproto.Message): +class ServiceOptions(Message): # Is this service deprecated? Depending on the target platform, this can emit # Deprecated annotations for the service, or it will be completely ignored; # in the very least, this is a formalization for deprecating services. - deprecated: bool = betterproto.bool_field(33) + deprecated: bool = bool_field(33) # The parser stores options it doesn't recognize here. See above. - uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999) + uninterpreted_option: List["UninterpretedOption"] = message_field(999) @dataclass(eq=False, repr=False) -class MethodOptions(betterproto.Message): +class MethodOptions(Message): # Is this method deprecated? Depending on the target platform, this can emit # Deprecated annotations for the method, or it will be completely ignored; in # the very least, this is a formalization for deprecating methods. - deprecated: bool = betterproto.bool_field(33) - idempotency_level: "MethodOptionsIdempotencyLevel" = betterproto.enum_field(34) + deprecated: bool = bool_field(33) + idempotency_level: "MethodOptionsIdempotencyLevel" = enum_field(34) # The parser stores options it doesn't recognize here. See above. - uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999) + uninterpreted_option: List["UninterpretedOption"] = message_field(999) @dataclass(eq=False, repr=False) -class UninterpretedOption(betterproto.Message): +class UninterpretedOption(Message): """ A message representing a option the parser does not recognize. This only appears in options protos created by the compiler::Parser class. @@ -839,19 +834,19 @@ class UninterpretedOption(betterproto.Message): UninterpretedOptions in them. """ - name: List["UninterpretedOptionNamePart"] = betterproto.message_field(2) + name: List["UninterpretedOptionNamePart"] = message_field(2) # The value of the uninterpreted option, in whatever type the tokenizer # identified it as during parsing. Exactly one of these should be set. - identifier_value: str = betterproto.string_field(3) - positive_int_value: int = betterproto.uint64_field(4) - negative_int_value: int = betterproto.int64_field(5) - double_value: float = betterproto.double_field(6) - string_value: bytes = betterproto.bytes_field(7) - aggregate_value: str = betterproto.string_field(8) + identifier_value: str = string_field(3) + positive_int_value: int = uint64_field(4) + negative_int_value: int = int64_field(5) + double_value: float = double_field(6) + string_value: bytes = bytes_field(7) + aggregate_value: str = string_field(8) @dataclass(eq=False, repr=False) -class UninterpretedOptionNamePart(betterproto.Message): +class UninterpretedOptionNamePart(Message): """ The name of the uninterpreted option. Each string represents a segment in a dot-separated name. is_extension is true iff a segment represents an @@ -860,12 +855,12 @@ class UninterpretedOptionNamePart(betterproto.Message): "foo.(bar.baz).qux". """ - name_part: str = betterproto.string_field(1) - is_extension: bool = betterproto.bool_field(2) + name_part: str = string_field(1) + is_extension: bool = bool_field(2) @dataclass(eq=False, repr=False) -class SourceCodeInfo(betterproto.Message): +class SourceCodeInfo(Message): """ Encapsulates information about the original source file from which a FileDescriptorProto was generated. @@ -900,11 +895,11 @@ class SourceCodeInfo(betterproto.Message): # their components will overlap. - Code which tries to interpret locations # should probably be designed to ignore those that it doesn't understand, # as more types of locations could be recorded in the future. - location: List["SourceCodeInfoLocation"] = betterproto.message_field(1) + location: List["SourceCodeInfoLocation"] = message_field(1) @dataclass(eq=False, repr=False) -class SourceCodeInfoLocation(betterproto.Message): +class SourceCodeInfoLocation(Message): # Identifies which part of the FileDescriptorProto was defined at this # location. Each element is a field number or an index. They form a path # from the root FileDescriptorProto to the place where the definition. For @@ -917,13 +912,13 @@ class SourceCodeInfoLocation(betterproto.Message): # Thus, the above path gives the location of a field name. If we removed the # last element: [ 4, 3, 2, 7 ] this path refers to the whole field # declaration (from the beginning of the label to the terminating semicolon). - path: List[int] = betterproto.int32_field(1) + path: List[int] = int32_field(1) # Always has exactly three or four elements: start line, start column, end # line (optional, otherwise assumed same as start line), end column. These # are packed into a single field for efficiency. Note that line and column # numbers are zero-based -- typically you will want to add 1 to each before # displaying to a user. - span: List[int] = betterproto.int32_field(2) + span: List[int] = int32_field(2) # If this SourceCodeInfo represents a complete declaration, these are any # comments appearing before and after the declaration which appear to be # attached to the declaration. A series of line comments appearing on @@ -945,13 +940,13 @@ class SourceCodeInfoLocation(betterproto.Message): # Block comment attached * to corge. Leading asterisks * will be # removed. */ /* Block comment attached to * grault. */ optional int32 # grault = 6; // ignored detached comments. - leading_comments: str = betterproto.string_field(3) - trailing_comments: str = betterproto.string_field(4) - leading_detached_comments: List[str] = betterproto.string_field(6) + leading_comments: str = string_field(3) + trailing_comments: str = string_field(4) + leading_detached_comments: List[str] = string_field(6) @dataclass(eq=False, repr=False) -class GeneratedCodeInfo(betterproto.Message): +class GeneratedCodeInfo(Message): """ Describes the relationship between generated code and its original source file. A GeneratedCodeInfo message is associated with only one generated @@ -960,27 +955,27 @@ class GeneratedCodeInfo(betterproto.Message): # An Annotation connects some span of text in generated code to an element of # its generating .proto file. - annotation: List["GeneratedCodeInfoAnnotation"] = betterproto.message_field(1) + annotation: List["GeneratedCodeInfoAnnotation"] = message_field(1) @dataclass(eq=False, repr=False) -class GeneratedCodeInfoAnnotation(betterproto.Message): +class GeneratedCodeInfoAnnotation(Message): # Identifies the element in the original source .proto file. This field is # formatted the same as SourceCodeInfo.Location.path. - path: List[int] = betterproto.int32_field(1) + path: List[int] = int32_field(1) # Identifies the filesystem path to the original source .proto. - source_file: str = betterproto.string_field(2) + source_file: str = string_field(2) # Identifies the starting offset in bytes in the generated code that relates # to the identified object. - begin: int = betterproto.int32_field(3) + begin: int = int32_field(3) # Identifies the ending offset in bytes in the generated code that relates to # the identified offset. The end offset should be one past the last relevant # byte (so the length of the text = end - begin). - end: int = betterproto.int32_field(4) + end: int = int32_field(4) @dataclass(eq=False, repr=False) -class Duration(betterproto.Message): +class Duration(Message): """ A Duration represents a signed, fixed-length span of time represented as a count of seconds and fractions of seconds at nanosecond resolution. It is @@ -1015,17 +1010,17 @@ class Duration(betterproto.Message): # Signed seconds of the span of time. Must be from -315,576,000,000 to # +315,576,000,000 inclusive. Note: these bounds are computed from: 60 # sec/min * 60 min/hr * 24 hr/day * 365.25 days/year * 10000 years - seconds: int = betterproto.int64_field(1) + seconds: int = int64_field(1) # Signed fractions of a second at nanosecond resolution of the span of time. # Durations less than one second are represented with a 0 `seconds` field and # a positive or negative `nanos` field. For durations of one second or more, # a non-zero value for the `nanos` field must be of the same sign as the # `seconds` field. Must be from -999,999,999 to +999,999,999 inclusive. - nanos: int = betterproto.int32_field(2) + nanos: int = int32_field(2) @dataclass(eq=False, repr=False) -class Empty(betterproto.Message): +class Empty(Message): """ A generic empty message that you can re-use to avoid defining duplicated empty messages in your APIs. A typical example is to use it as the request @@ -1038,7 +1033,7 @@ class Empty(betterproto.Message): @dataclass(eq=False, repr=False) -class FieldMask(betterproto.Message): +class FieldMask(Message): """ `FieldMask` represents a set of symbolic field paths, for example: paths: "f.a" paths: "f.b.d" Here `f` represents a field in some root @@ -1118,11 +1113,11 @@ class FieldMask(betterproto.Message): """ # The set of field mask paths. - paths: List[str] = betterproto.string_field(1) + paths: List[str] = string_field(1) @dataclass(eq=False, repr=False) -class Struct(betterproto.Message): +class Struct(Message): """ `Struct` represents a structured data value, consisting of fields which map to dynamically typed values. In some languages, `Struct` might be supported @@ -1133,13 +1128,11 @@ class Struct(betterproto.Message): """ # Unordered map of dynamically typed values. - fields: Dict[str, "Value"] = betterproto.map_field( - 1, betterproto.TYPE_STRING, betterproto.TYPE_MESSAGE - ) + fields: Dict[str, "Value"] = map_field(1, TYPE_STRING, TYPE_MESSAGE) @dataclass(eq=False, repr=False) -class Value(betterproto.Message): +class Value(Message): """ `Value` represents a dynamically typed value which can be either null, a number, a string, a boolean, a recursive struct value, or a list of values. @@ -1149,32 +1142,32 @@ class Value(betterproto.Message): """ # Represents a null value. - null_value: "NullValue" = betterproto.enum_field(1, group="kind") + null_value: "NullValue" = enum_field(1, group="kind") # Represents a double value. - number_value: float = betterproto.double_field(2, group="kind") + number_value: float = double_field(2, group="kind") # Represents a string value. - string_value: str = betterproto.string_field(3, group="kind") + string_value: str = string_field(3, group="kind") # Represents a boolean value. - bool_value: bool = betterproto.bool_field(4, group="kind") + bool_value: bool = bool_field(4, group="kind") # Represents a structured value. - struct_value: "Struct" = betterproto.message_field(5, group="kind") + struct_value: "Struct" = message_field(5, group="kind") # Represents a repeated `Value`. - list_value: "ListValue" = betterproto.message_field(6, group="kind") + list_value: "ListValue" = message_field(6, group="kind") @dataclass(eq=False, repr=False) -class ListValue(betterproto.Message): +class ListValue(Message): """ `ListValue` is a wrapper around a repeated field of values. The JSON representation for `ListValue` is JSON array. """ # Repeated field of dynamically typed values. - values: List["Value"] = betterproto.message_field(1) + values: List["Value"] = message_field(1) @dataclass(eq=False, repr=False) -class Timestamp(betterproto.Message): +class Timestamp(Message): """ A Timestamp represents a point in time independent of any time zone or local calendar, encoded as a count of seconds and fractions of seconds at @@ -1233,107 +1226,107 @@ class Timestamp(betterproto.Message): # Represents seconds of UTC time since Unix epoch 1970-01-01T00:00:00Z. Must # be from 0001-01-01T00:00:00Z to 9999-12-31T23:59:59Z inclusive. - seconds: int = betterproto.int64_field(1) + seconds: int = int64_field(1) # Non-negative fractions of a second at nanosecond resolution. Negative # second values with fractions must still have non-negative nanos values that # count forward in time. Must be from 0 to 999,999,999 inclusive. - nanos: int = betterproto.int32_field(2) + nanos: int = int32_field(2) @dataclass(eq=False, repr=False) -class DoubleValue(betterproto.Message): +class DoubleValue(Message): """ Wrapper message for `double`. The JSON representation for `DoubleValue` is JSON number. """ # The double value. - value: float = betterproto.double_field(1) + value: float = double_field(1) @dataclass(eq=False, repr=False) -class FloatValue(betterproto.Message): +class FloatValue(Message): """ Wrapper message for `float`. The JSON representation for `FloatValue` is JSON number. """ # The float value. - value: float = betterproto.float_field(1) + value: float = float_field(1) @dataclass(eq=False, repr=False) -class Int64Value(betterproto.Message): +class Int64Value(Message): """ Wrapper message for `int64`. The JSON representation for `Int64Value` is JSON string. """ # The int64 value. - value: int = betterproto.int64_field(1) + value: int = int64_field(1) @dataclass(eq=False, repr=False) -class UInt64Value(betterproto.Message): +class UInt64Value(Message): """ Wrapper message for `uint64`. The JSON representation for `UInt64Value` is JSON string. """ # The uint64 value. - value: int = betterproto.uint64_field(1) + value: int = uint64_field(1) @dataclass(eq=False, repr=False) -class Int32Value(betterproto.Message): +class Int32Value(Message): """ Wrapper message for `int32`. The JSON representation for `Int32Value` is JSON number. """ # The int32 value. - value: int = betterproto.int32_field(1) + value: int = int32_field(1) @dataclass(eq=False, repr=False) -class UInt32Value(betterproto.Message): +class UInt32Value(Message): """ Wrapper message for `uint32`. The JSON representation for `UInt32Value` is JSON number. """ # The uint32 value. - value: int = betterproto.uint32_field(1) + value: int = uint32_field(1) @dataclass(eq=False, repr=False) -class BoolValue(betterproto.Message): +class BoolValue(Message): """ Wrapper message for `bool`. The JSON representation for `BoolValue` is JSON `true` and `false`. """ # The bool value. - value: bool = betterproto.bool_field(1) + value: bool = bool_field(1) @dataclass(eq=False, repr=False) -class StringValue(betterproto.Message): +class StringValue(Message): """ Wrapper message for `string`. The JSON representation for `StringValue` is JSON string. """ # The string value. - value: str = betterproto.string_field(1) + value: str = string_field(1) @dataclass(eq=False, repr=False) -class BytesValue(betterproto.Message): +class BytesValue(Message): """ Wrapper message for `bytes`. The JSON representation for `BytesValue` is JSON string. """ # The bytes value. - value: bytes = betterproto.bytes_field(1) + value: bytes = bytes_field(1) diff --git a/src/betterproto/message.py b/src/betterproto/message.py new file mode 100644 index 000000000..8f8ab2e34 --- /dev/null +++ b/src/betterproto/message.py @@ -0,0 +1,1057 @@ +import dataclasses +import json +import math +import struct +import sys +import typing +from abc import ABC +from base64 import b64decode, b64encode +from datetime import datetime, timedelta, timezone +from dateutil.parser import isoparse +from typing import ( + Any, + Callable, + Dict, + Iterable, + Optional, + Set, + Tuple, + Type, + Union, + get_type_hints, +) + +from .casing import Casing +from .enum import Enum + +from ._types import T +from .const import * + + +PLACEHOLDER: Any = object() + +__all__ = ( + "Message", + "enum_field", + "bool_field", + "int32_field", + "int64_field", + "uint32_field", + "uint64_field", + "sint32_field", + "sint64_field", + "float_field", + "double_field", + "fixed32_field", + "fixed64_field", + "sfixed32_field", + "sfixed64_field", + "string_field", + "bytes_field", + "message_field", + "map_field", +) + + +@dataclasses.dataclass(frozen=True) +class FieldMetadata: + """Stores internal metadata used for parsing & serialization.""" + + # Protobuf field number + number: int + # Protobuf type name + proto_type: str + # Map information if the proto_type is a map + map_types: Optional[Tuple[str, str]] = None + # Groups several "one-of" fields together + group: Optional[str] = None + # Describes the wrapped type (e.g. when using google.protobuf.BoolValue) + wraps: Optional[str] = None + # Is the field optional + optional: Optional[bool] = False + + @staticmethod + def get(field: dataclasses.Field) -> "FieldMetadata": + """Returns the field metadata for a dataclass field.""" + return field.metadata["betterproto"] + + +def dataclass_field( + number: int, + proto_type: str, + *, + map_types: Optional[Tuple[str, str]] = None, + group: Optional[str] = None, + wraps: Optional[str] = None, + optional: bool = False, +) -> dataclasses.Field: + """Creates a dataclass field with attached protobuf metadata.""" + return dataclasses.field( + default=None if optional else PLACEHOLDER, + metadata={ + "betterproto": FieldMetadata( + number, proto_type, map_types, group, wraps, optional + ) + }, + ) + + +# Note: the fields below return `Any` to prevent type errors in the generated +# data classes since the types won't match with `Field` and they get swapped +# out at runtime. The generated dataclass variables are still typed correctly. + + +def enum_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any: + return dataclass_field(number, TYPE_ENUM, group=group, optional=optional) + + +def bool_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any: + return dataclass_field(number, TYPE_BOOL, group=group, optional=optional) + + +def int32_field( + number: int, group: Optional[str] = None, optional: bool = False +) -> Any: + return dataclass_field(number, TYPE_INT32, group=group, optional=optional) + + +def int64_field( + number: int, group: Optional[str] = None, optional: bool = False +) -> Any: + return dataclass_field(number, TYPE_INT64, group=group, optional=optional) + + +def uint32_field( + number: int, group: Optional[str] = None, optional: bool = False +) -> Any: + return dataclass_field(number, TYPE_UINT32, group=group, optional=optional) + + +def uint64_field( + number: int, group: Optional[str] = None, optional: bool = False +) -> Any: + return dataclass_field(number, TYPE_UINT64, group=group, optional=optional) + + +def sint32_field( + number: int, group: Optional[str] = None, optional: bool = False +) -> Any: + return dataclass_field(number, TYPE_SINT32, group=group, optional=optional) + + +def sint64_field( + number: int, group: Optional[str] = None, optional: bool = False +) -> Any: + return dataclass_field(number, TYPE_SINT64, group=group, optional=optional) + + +def float_field( + number: int, group: Optional[str] = None, optional: bool = False +) -> Any: + return dataclass_field(number, TYPE_FLOAT, group=group, optional=optional) + + +def double_field( + number: int, group: Optional[str] = None, optional: bool = False +) -> Any: + return dataclass_field(number, TYPE_DOUBLE, group=group, optional=optional) + + +def fixed32_field( + number: int, group: Optional[str] = None, optional: bool = False +) -> Any: + return dataclass_field(number, TYPE_FIXED32, group=group, optional=optional) + + +def fixed64_field( + number: int, group: Optional[str] = None, optional: bool = False +) -> Any: + return dataclass_field(number, TYPE_FIXED64, group=group, optional=optional) + + +def sfixed32_field( + number: int, group: Optional[str] = None, optional: bool = False +) -> Any: + return dataclass_field(number, TYPE_SFIXED32, group=group, optional=optional) + + +def sfixed64_field( + number: int, group: Optional[str] = None, optional: bool = False +) -> Any: + return dataclass_field(number, TYPE_SFIXED64, group=group, optional=optional) + + +def string_field( + number: int, group: Optional[str] = None, optional: bool = False +) -> Any: + return dataclass_field(number, TYPE_STRING, group=group, optional=optional) + + +def bytes_field( + number: int, group: Optional[str] = None, optional: bool = False +) -> Any: + return dataclass_field(number, TYPE_BYTES, group=group, optional=optional) + + +def message_field( + number: int, + group: Optional[str] = None, + wraps: Optional[str] = None, + optional: bool = False, +) -> Any: + return dataclass_field( + number, TYPE_MESSAGE, group=group, wraps=wraps, optional=optional + ) + + +def map_field( + number: int, key_type: str, value_type: str, group: Optional[str] = None +) -> Any: + return dataclass_field( + number, TYPE_MAP, map_types=(key_type, value_type), group=group + ) + + +class ProtoClassMetadata: + __slots__ = ( + "oneof_group_by_field", + "oneof_field_by_group", + "default_gen", + "cls_by_field", + "field_name_by_number", + "meta_by_field_name", + "sorted_field_names", + ) + + oneof_group_by_field: Dict[str, str] + oneof_field_by_group: Dict[str, Set[dataclasses.Field]] + field_name_by_number: Dict[int, str] + meta_by_field_name: Dict[str, FieldMetadata] + sorted_field_names: Tuple[str, ...] + default_gen: Dict[str, Callable[[], Any]] + cls_by_field: Dict[str, Type] + + def __init__(self, cls: Type["Message"]): + by_field = {} + by_group: Dict[str, Set] = {} + by_field_name = {} + by_field_number = {} + + fields = dataclasses.fields(cls) + for field in fields: + meta = FieldMetadata.get(field) + + if meta.group: + # This is part of a one-of group. + by_field[field.name] = meta.group + + by_group.setdefault(meta.group, set()).add(field) + + by_field_name[field.name] = meta + by_field_number[meta.number] = field.name + + self.oneof_group_by_field = by_field + self.oneof_field_by_group = by_group + self.field_name_by_number = by_field_number + self.meta_by_field_name = by_field_name + self.sorted_field_names = tuple( + by_field_number[number] for number in sorted(by_field_number) + ) + self.default_gen = self._get_default_gen(cls, fields) + self.cls_by_field = self._get_cls_by_field(cls, fields) + + @staticmethod + def _get_default_gen( + cls: Type["Message"], fields: Iterable[dataclasses.Field] + ) -> Dict[str, Callable[[], Any]]: + return {field.name: cls._get_field_default_gen(field) for field in fields} + + @staticmethod + def _get_cls_by_field( + cls: Type["Message"], fields: Iterable[dataclasses.Field] + ) -> Dict[str, Type]: + field_cls = {} + + for field in fields: + meta = FieldMetadata.get(field) + if meta.proto_type == TYPE_MAP: + assert meta.map_types + kt = cls._cls_for(field, index=0) + vt = cls._cls_for(field, index=1) + field_cls[field.name] = dataclasses.make_dataclass( + "Entry", + [ + ("key", kt, dataclass_field(1, meta.map_types[0])), + ("value", vt, dataclass_field(2, meta.map_types[1])), + ], + bases=(Message,), + ) + field_cls[f"{field.name}.value"] = vt + else: + field_cls[field.name] = cls._cls_for(field) + + return field_cls + + +class Message(ABC): + """ + The base class for protobuf messages, all generated messages will inherit from + this. This class registers the message fields which are used by the serializers and + parsers to go between the Python, binary and JSON representations of the message. + + .. container:: operations + + .. describe:: bytes(x) + + Calls :meth:`__bytes__`. + + .. describe:: bool(x) + + Calls :meth:`__bool__`. + """ + + _serialized_on_wire: bool + _unknown_fields: bytes + _group_current: Dict[str, str] + + def __post_init__(self) -> None: + # Keep track of whether every field was default + all_sentinel = True + + # Set current field of each group after `__init__` has already been run. + group_current: Dict[str, Optional[str]] = {} + for field_name, meta in self._betterproto.meta_by_field_name.items(): + + if meta.group: + group_current.setdefault(meta.group) + + value = self.__raw_get(field_name) + if value != PLACEHOLDER and not (meta.optional and value is None): + # Found a non-sentinel value + all_sentinel = False + + if meta.group: + # This was set, so make it the selected value of the one-of. + group_current[meta.group] = field_name + + # Now that all the defaults are set, reset it! + self.__dict__["_serialized_on_wire"] = not all_sentinel + self.__dict__["_unknown_fields"] = b"" + self.__dict__["_group_current"] = group_current + + def __raw_get(self, name: str) -> Any: + return super().__getattribute__(name) + + def __eq__(self, other) -> bool: + if type(self) is not type(other): + return False + + for field_name in self._betterproto.meta_by_field_name: + self_val = self.__raw_get(field_name) + other_val = other.__raw_get(field_name) + if self_val is PLACEHOLDER: + if other_val is PLACEHOLDER: + continue + self_val = self._get_field_default(field_name) + elif other_val is PLACEHOLDER: + other_val = other._get_field_default(field_name) + + if self_val != other_val: + # We consider two nan values to be the same for the + # purposes of comparing messages (otherwise a message + # is not equal to itself) + if ( + isinstance(self_val, float) + and isinstance(other_val, float) + and math.isnan(self_val) + and math.isnan(other_val) + ): + continue + else: + return False + + return True + + def __repr__(self) -> str: + parts = [ + f"{field_name}={value!r}" + for field_name in self._betterproto.sorted_field_names + for value in (self.__raw_get(field_name),) + if value is not PLACEHOLDER + ] + return f"{self.__class__.__name__}({', '.join(parts)})" + + def __getattribute__(self, name: str) -> Any: + """ + Lazily initialize default values to avoid infinite recursion for recursive + message types + """ + value = super().__getattribute__(name) + if value is not PLACEHOLDER: + return value + + value = self._get_field_default(name) + super().__setattr__(name, value) + return value + + def __setattr__(self, attr: str, value: Any) -> None: + if attr != "_serialized_on_wire": + # Track when a field has been set. + self.__dict__["_serialized_on_wire"] = True + + if hasattr(self, "_group_current"): # __post_init__ had already run + if attr in self._betterproto.oneof_group_by_field: + group = self._betterproto.oneof_group_by_field[attr] + for field in self._betterproto.oneof_field_by_group[group]: + if field.name == attr: + self._group_current[group] = field.name + else: + super().__setattr__(field.name, PLACEHOLDER) + + super().__setattr__(attr, value) + + def __bool__(self) -> bool: + """True if the Message has any fields with non-default values.""" + return any( + self.__raw_get(field_name) + not in (PLACEHOLDER, self._get_field_default(field_name)) + for field_name in self._betterproto.meta_by_field_name + ) + + @property + def _betterproto(self) -> ProtoClassMetadata: + """ + Lazy initialize metadata for each protobuf class. + It may be initialized multiple times in a multi-threaded environment, + but that won't affect the correctness. + """ + meta = getattr(self.__class__, "_betterproto_meta", None) + if not meta: + meta = ProtoClassMetadata(self.__class__) + self.__class__._betterproto_meta = meta # type: ignore + return meta + + def __bytes__(self) -> bytes: + """ + Get the binary encoded Protobuf representation of this message instance. + """ + output = bytearray() + for field_name, meta in self._betterproto.meta_by_field_name.items(): + value = getattr(self, field_name) + + if value is None: + # Optional items should be skipped. This is used for the Google + # wrapper types and proto3 field presence/optional fields. + continue + + # Being selected in a a group means this field is the one that is + # currently set in a `oneof` group, so it must be serialized even + # if the value is the default zero value. + # + # Note that proto3 field presence/optional fields are put in a + # synthetic single-item oneof by protoc, which helps us ensure we + # send the value even if the value is the default zero value. + selected_in_group = ( + meta.group and self._group_current[meta.group] == field_name + ) + + # Empty messages can still be sent on the wire if they were + # set (or received empty). + serialize_empty = isinstance(value, Message) and value._serialized_on_wire + + include_default_value_for_oneof = self._include_default_value_for_oneof( + field_name=field_name, meta=meta + ) + + if value == self._get_field_default(field_name) and not ( + selected_in_group or serialize_empty or include_default_value_for_oneof + ): + # Default (zero) values are not serialized. Two exceptions are + # if this is the selected oneof item or if we know we have to + # serialize an empty message (i.e. zero value was explicitly + # set by the user). + continue + + if isinstance(value, list): + if meta.proto_type in PACKED_TYPES: + # Packed lists look like a length-delimited field. First, + # preprocess/encode each value into a buffer and then + # treat it like a field of raw bytes. + buf = bytearray() + for item in value: + buf += _preprocess_single(meta.proto_type, "", item) + output += _serialize_single(meta.number, TYPE_BYTES, buf) + else: + for item in value: + output += ( + _serialize_single( + meta.number, + meta.proto_type, + item, + wraps=meta.wraps or "", + ) + # if it's an empty message it still needs to be represented + # as an item in the repeated list + or b"\n\x00" + ) + + elif isinstance(value, dict): + for k, v in value.items(): + assert meta.map_types + sk = _serialize_single(1, meta.map_types[0], k) + sv = _serialize_single(2, meta.map_types[1], v) + output += _serialize_single(meta.number, meta.proto_type, sk + sv) + else: + # If we have an empty string and we're including the default value for + # a oneof, make sure we serialize it. This ensures that the byte string + # output isn't simply an empty string. This also ensures that round trip + # serialization will keep `which_one_of` calls consistent. + if ( + isinstance(value, str) + and value == "" + and include_default_value_for_oneof + ): + serialize_empty = True + + output += _serialize_single( + meta.number, + meta.proto_type, + value, + serialize_empty=serialize_empty or bool(selected_in_group), + wraps=meta.wraps or "", + ) + + output += self._unknown_fields + return bytes(output) + + # For compatibility with other libraries + def SerializeToString(self: T) -> bytes: + """ + Get the binary encoded Protobuf representation of this message instance. + + .. note:: + This is a method for compatibility with other libraries, + you should really use ``bytes(x)``. + + Returns + -------- + :class:`bytes` + The binary encoded Protobuf representation of this message instance + """ + return bytes(self) + + @classmethod + def _type_hint(cls, field_name: str) -> Type: + return cls._type_hints()[field_name] + + @classmethod + def _type_hints(cls) -> Dict[str, Type]: + module = sys.modules[cls.__module__] + return get_type_hints(cls, module.__dict__, {}) + + @classmethod + def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type: + """Get the message class for a field from the type hints.""" + field_cls = cls._type_hint(field.name) + if hasattr(field_cls, "__args__") and index >= 0: + if field_cls.__args__ is not None: + field_cls = field_cls.__args__[index] + return field_cls + + def _get_field_default(self, field_name: str) -> Any: + return self._betterproto.default_gen[field_name]() + + @classmethod + def _get_field_default_gen(cls, field: dataclasses.Field) -> Any: + t = cls._type_hint(field.name) + + if hasattr(t, "__origin__"): + if t.__origin__ in (dict, Dict): + # This is some kind of map (dict in Python). + return dict + elif t.__origin__ in (list, List): + # This is some kind of list (repeated) field. + return list + elif t.__origin__ is Union and t.__args__[1] is type(None): + # This is an optional field (either wrapped, or using proto3 + # field presence). For setting the default we really don't care + # what kind of field it is. + return type(None) + else: + return t + elif issubclass(t, Enum): + # Enums always default to zero. + return int + elif t is datetime: + # Offsets are relative to 1970-01-01T00:00:00Z + return datetime_default_gen + else: + # This is either a primitive scalar or another message type. Calling + # it should result in its zero value. + return t + + def _postprocess_single( + self, wire_type: int, meta: FieldMetadata, field_name: str, value: Any + ) -> Any: + """Adjusts values after parsing.""" + if wire_type == WIRE_VARINT: + if meta.proto_type in (TYPE_INT32, TYPE_INT64): + bits = int(meta.proto_type[3:]) + value = value & ((1 << bits) - 1) + signbit = 1 << (bits - 1) + value = int((value ^ signbit) - signbit) + elif meta.proto_type in (TYPE_SINT32, TYPE_SINT64): + # Undo zig-zag encoding + value = (value >> 1) ^ (-(value & 1)) + elif meta.proto_type == TYPE_BOOL: + # Booleans use a varint encoding, so convert it to true/false. + value = value > 0 + elif wire_type in (WIRE_FIXED_32, WIRE_FIXED_64): + fmt = _pack_fmt(meta.proto_type) + value = struct.unpack(fmt, value)[0] + elif wire_type == WIRE_LEN_DELIM: + if meta.proto_type == TYPE_STRING: + value = str(value, "utf-8") + elif meta.proto_type == TYPE_MESSAGE: + cls = self._betterproto.cls_by_field[field_name] + + if cls == datetime: + value = _Timestamp().parse(value).to_datetime() + elif cls == timedelta: + value = _Duration().parse(value).to_timedelta() + elif meta.wraps: + # This is a Google wrapper value message around a single + # scalar type. + value = _get_wrapper(meta.wraps)().parse(value).value + else: + value = cls().parse(value) + value._serialized_on_wire = True + elif meta.proto_type == TYPE_MAP: + value = self._betterproto.cls_by_field[field_name]().parse(value) + + return value + + def _include_default_value_for_oneof( + self, field_name: str, meta: FieldMetadata + ) -> bool: + return ( + meta.group is not None and self._group_current.get(meta.group) == field_name + ) + + def parse(self: T, data: bytes) -> T: + """ + Parse the binary encoded Protobuf into this message instance. This + returns the instance itself and is therefore assignable and chainable. + + Parameters + ----------- + data: :class:`bytes` + The data to parse the protobuf from. + + Returns + -------- + :class:`Message` + The initialized message. + """ + # Got some data over the wire + self._serialized_on_wire = True + proto_meta = self._betterproto + for parsed in parse_fields(data): + field_name = proto_meta.field_name_by_number.get(parsed.number) + if not field_name: + self._unknown_fields += parsed.raw + continue + + meta = proto_meta.meta_by_field_name[field_name] + + value: Any + if parsed.wire_type == WIRE_LEN_DELIM and meta.proto_type in PACKED_TYPES: + # This is a packed repeated field. + pos = 0 + value = [] + while pos < len(parsed.value): + if meta.proto_type in (TYPE_FLOAT, TYPE_FIXED32, TYPE_SFIXED32): + decoded, pos = parsed.value[pos : pos + 4], pos + 4 + wire_type = WIRE_FIXED_32 + elif meta.proto_type in (TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64): + decoded, pos = parsed.value[pos : pos + 8], pos + 8 + wire_type = WIRE_FIXED_64 + else: + decoded, pos = decode_varint(parsed.value, pos) + wire_type = WIRE_VARINT + decoded = self._postprocess_single( + wire_type, meta, field_name, decoded + ) + value.append(decoded) + else: + value = self._postprocess_single( + parsed.wire_type, meta, field_name, parsed.value + ) + + current = getattr(self, field_name) + if meta.proto_type == TYPE_MAP: + # Value represents a single key/value pair entry in the map. + current[value.key] = value.value + elif isinstance(current, list) and not isinstance(value, list): + current.append(value) + else: + setattr(self, field_name, value) + + return self + + # For compatibility with other libraries. + @classmethod + def FromString(cls: Type[T], data: bytes) -> T: + """ + Parse the binary encoded Protobuf into this message instance. This + returns the instance itself and is therefore assignable and chainable. + + .. note:: + This is a method for compatibility with other libraries, + you should really use :meth:`parse`. + + + Parameters + ----------- + data: :class:`bytes` + The data to parse the protobuf from. + + Returns + -------- + :class:`Message` + The initialized message. + """ + return cls().parse(data) + + def to_dict( + self, casing: Casing = Casing.CAMEL, include_default_values: bool = False + ) -> Dict[str, Any]: + """ + Returns a JSON serializable dict representation of this object. + + Parameters + ----------- + casing: :class:`Casing` + The casing to use for key values. Default is :attr:`Casing.CAMEL` for + compatibility purposes. + include_default_values: :class:`bool` + If ``True`` will include the default values of fields. Default is ``False``. + E.g. an ``int32`` field will be included with a value of ``0`` if this is + set to ``True``, otherwise this would be ignored. + + Returns + -------- + Dict[:class:`str`, Any] + The JSON serializable dict representation of this object. + """ + output: Dict[str, Any] = {} + field_types = self._type_hints() + defaults = self._betterproto.default_gen + for field_name, meta in self._betterproto.meta_by_field_name.items(): + field_is_repeated = defaults[field_name] is list + value = getattr(self, field_name) + cased_name = casing(field_name).rstrip("_") # type: ignore + if meta.proto_type == TYPE_MESSAGE: + if isinstance(value, datetime): + if ( + value != DATETIME_ZERO + or include_default_values + or self._include_default_value_for_oneof( + field_name=field_name, meta=meta + ) + ): + output[cased_name] = _Timestamp.timestamp_to_json(value) + elif isinstance(value, timedelta): + if ( + value != timedelta(0) + or include_default_values + or self._include_default_value_for_oneof( + field_name=field_name, meta=meta + ) + ): + output[cased_name] = _Duration.delta_to_json(value) + elif meta.wraps: + if value is not None or include_default_values: + output[cased_name] = value + elif field_is_repeated: + # Convert each item. + cls = self._betterproto.cls_by_field[field_name] + if cls == datetime: + value = [_Timestamp.timestamp_to_json(i) for i in value] + elif cls == timedelta: + value = [_Duration.delta_to_json(i) for i in value] + else: + value = [ + i.to_dict(casing, include_default_values) for i in value + ] + if value or include_default_values: + output[cased_name] = value + elif value is None: + if include_default_values: + output[cased_name] = value + elif ( + value._serialized_on_wire + or include_default_values + or self._include_default_value_for_oneof( + field_name=field_name, meta=meta + ) + ): + output[cased_name] = value.to_dict(casing, include_default_values) + elif meta.proto_type == TYPE_MAP: + for k in value: + if hasattr(value[k], "to_dict"): + value[k] = value[k].to_dict(casing, include_default_values) + + if value or include_default_values: + output[cased_name] = value + elif ( + value != self._get_field_default(field_name) + or include_default_values + or self._include_default_value_for_oneof( + field_name=field_name, meta=meta + ) + ): + if meta.proto_type in INT_64_TYPES: + if field_is_repeated: + output[cased_name] = [str(n) for n in value] + elif value is None: + if include_default_values: + output[cased_name] = value + else: + output[cased_name] = str(value) + elif meta.proto_type == TYPE_BYTES: + if field_is_repeated: + output[cased_name] = [ + b64encode(b).decode("utf8") for b in value + ] + elif value is None and include_default_values: + output[cased_name] = value + else: + output[cased_name] = b64encode(value).decode("utf8") + elif meta.proto_type == TYPE_ENUM: + if field_is_repeated: + enum_class = field_types[field_name].__args__[0] + if isinstance(value, typing.Iterable) and not isinstance( + value, str + ): + output[cased_name] = [enum_class(el).name for el in value] + else: + # transparently upgrade single value to repeated + output[cased_name] = [enum_class(value).name] + elif value is None: + if include_default_values: + output[cased_name] = value + elif meta.optional: + enum_class = field_types[field_name].__args__[0] + output[cased_name] = enum_class(value).name + else: + enum_class = field_types[field_name] # noqa + output[cased_name] = enum_class(value).name + elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE): + if field_is_repeated: + output[cased_name] = [_dump_float(n) for n in value] + else: + output[cased_name] = _dump_float(value) + else: + output[cased_name] = value + return output + + def from_dict(self: T, value: Dict[str, Any]) -> T: + """ + Parse the key/value pairs into the current message instance. This returns the + instance itself and is therefore assignable and chainable. + + Parameters + ----------- + value: Dict[:class:`str`, Any] + The dictionary to parse from. + + Returns + -------- + :class:`Message` + The initialized message. + """ + self._serialized_on_wire = True + for key in value: + field_name = safe_snake_case(key) + meta = self._betterproto.meta_by_field_name.get(field_name) + if not meta: + continue + + if value[key] is not None: + if meta.proto_type == TYPE_MESSAGE: + v = getattr(self, field_name) + if isinstance(v, list): + cls = self._betterproto.cls_by_field[field_name] + if cls == datetime: + v = [isoparse(item) for item in value[key]] + elif cls == timedelta: + v = [ + timedelta(seconds=float(item[:-1])) + for item in value[key] + ] + else: + v = [cls().from_dict(item) for item in value[key]] + elif isinstance(v, datetime): + v = isoparse(value[key]) + setattr(self, field_name, v) + elif isinstance(v, timedelta): + v = timedelta(seconds=float(value[key][:-1])) + setattr(self, field_name, v) + elif meta.wraps: + setattr(self, field_name, value[key]) + elif v is None: + cls = self._betterproto.cls_by_field[field_name] + setattr(self, field_name, cls().from_dict(value[key])) + else: + # NOTE: `from_dict` mutates the underlying message, so no + # assignment here is necessary. + v.from_dict(value[key]) + elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE: + v = getattr(self, field_name) + cls = self._betterproto.cls_by_field[f"{field_name}.value"] + for k in value[key]: + v[k] = cls().from_dict(value[key][k]) + else: + v = value[key] + if meta.proto_type in INT_64_TYPES: + if isinstance(value[key], list): + v = [int(n) for n in value[key]] + else: + v = int(value[key]) + elif meta.proto_type == TYPE_BYTES: + if isinstance(value[key], list): + v = [b64decode(n) for n in value[key]] + else: + v = b64decode(value[key]) + elif meta.proto_type == TYPE_ENUM: + enum_cls = self._betterproto.cls_by_field[field_name] + if isinstance(v, list): + v = [enum_cls.from_string(e) for e in v] + elif isinstance(v, str): + v = enum_cls.from_string(v) + elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE): + if isinstance(value[key], list): + v = [_parse_float(n) for n in value[key]] + else: + v = _parse_float(value[key]) + + if v is not None: + setattr(self, field_name, v) + return self + + def to_json(self, indent: Union[None, int, str] = None) -> str: + """A helper function to parse the message instance into its JSON + representation. + + This is equivalent to:: + + json.dumps(message.to_dict(), indent=indent) + + Parameters + ----------- + indent: Optional[Union[:class:`int`, :class:`str`]] + The indent to pass to :func:`json.dumps`. + + Returns + -------- + :class:`str` + The JSON representation of the message. + """ + return json.dumps(self.to_dict(), indent=indent) + + def from_json(self: T, value: Union[str, bytes]) -> T: + """A helper function to return the message instance from its JSON + representation. This returns the instance itself and is therefore assignable + and chainable. + + This is equivalent to:: + + return message.from_dict(json.loads(value)) + + Parameters + ----------- + value: Union[:class:`str`, :class:`bytes`] + The value to pass to :func:`json.loads`. + + Returns + -------- + :class:`Message` + The initialized message. + """ + return self.from_dict(json.loads(value)) + + +# Circular import workaround: google.protobuf depends on base classes defined above. +from .lib.google.protobuf import ( # noqa + BoolValue, + BytesValue, + DoubleValue, + Duration, + EnumValue, + FloatValue, + Int32Value, + Int64Value, + StringValue, + Timestamp, + UInt32Value, + UInt64Value, +) + + +class _Duration(Duration): + def to_timedelta(self) -> timedelta: + return timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3) + + @staticmethod + def delta_to_json(delta: timedelta) -> str: + parts = str(delta.total_seconds()).split(".") + if len(parts) > 1: + while len(parts[1]) not in (3, 6, 9): + parts[1] = f"{parts[1]}0" + return f"{'.'.join(parts)}s" + + +class _Timestamp(Timestamp): + def to_datetime(self) -> datetime: + ts = self.seconds + (self.nanos / 1e9) + return datetime.fromtimestamp(ts, tz=timezone.utc) + + @staticmethod + def timestamp_to_json(dt: datetime) -> str: + nanos = dt.microsecond * 1e3 + copy = dt.replace(microsecond=0, tzinfo=None) + result = copy.isoformat() + if (nanos % 1e9) == 0: + # If there are 0 fractional digits, the fractional + # point '.' should be omitted when serializing. + return f"{result}Z" + if (nanos % 1e6) == 0: + # Serialize 3 fractional digits. + return f"{result}.{int(nanos // 1e6) :03d}Z" + if (nanos % 1e3) == 0: + # Serialize 6 fractional digits. + return f"{result}.{int(nanos // 1e3) :06d}Z" + # Serialize 9 fractional digits. + return f"{result}.{nanos:09d}" + + +def _get_wrapper(proto_type: str) -> Type: + """Get the wrapper message class for a wrapped type.""" + + # TODO: include ListValue and NullValue? + return { + TYPE_BOOL: BoolValue, + TYPE_BYTES: BytesValue, + TYPE_DOUBLE: DoubleValue, + TYPE_FLOAT: FloatValue, + TYPE_ENUM: EnumValue, + TYPE_INT32: Int32Value, + TYPE_INT64: Int64Value, + TYPE_STRING: StringValue, + TYPE_UINT32: UInt32Value, + TYPE_UINT64: UInt64Value, + }[proto_type] + + +from .io import * diff --git a/src/betterproto/utils.py b/src/betterproto/utils.py new file mode 100644 index 000000000..123c596c8 --- /dev/null +++ b/src/betterproto/utils.py @@ -0,0 +1,36 @@ +from typing import TYPE_CHECKING + + +from typing import TYPE_CHECKING, Tuple, Optional, Any + +if TYPE_CHECKING: + from .message import Message + + +def serialized_on_wire(message: "Message") -> bool: + """ + If this message was or should be serialized on the wire. This can be used to detect + presence (e.g. optional wrapper message) and is used internally during + parsing/serialization. + + Returns + -------- + :class:`bool` + Whether this message was or should be serialized on the wire. + """ + return message._serialized_on_wire + + +def which_one_of(message: "Message", group_name: str) -> Tuple[str, Optional[Any]]: + """ + Return the name and value of a message's one-of field group. + + Returns + -------- + Tuple[:class:`str`, Any] + The field name and the value for that field. + """ + field_name = message._group_current.get(group_name) + if not field_name: + return "", None + return field_name, getattr(message, field_name)