Skip to content

Commit 4a2baf3

Browse files
authored
Merge pull request #46 from jameslan/perf/class-cache
Improve performance of serialize/deserialize by caching type information of fields in class
2 parents 5e2d9fe + 917de09 commit 4a2baf3

File tree

3 files changed

+186
-116
lines changed

3 files changed

+186
-116
lines changed

betterproto/__init__.py

+119-59
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,11 @@
120120

121121

122122
# Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
123-
DATETIME_ZERO = datetime(1970, 1, 1, tzinfo=timezone.utc)
123+
def datetime_default_gen() -> datetime:
    """
    Build the default value used for datetime fields.

    Returns a new timezone-aware ``datetime`` for 1970-01-01T00:00:00 UTC
    (the Unix epoch) on every call.
    """
    return datetime(1970, 1, 1, tzinfo=timezone.utc)
125+
126+
127+
DATETIME_ZERO = datetime_default_gen()
124128

125129

126130
class Casing(enum.Enum):
@@ -428,6 +432,63 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
428432
T = TypeVar("T", bound="Message")
429433

430434

435+
class ProtoClassMetadata:
    """
    Cache of per-class field information for a ``Message`` subclass.

    Everything here is derived from the dataclass fields of ``cls`` and their
    ``FieldMetadata``, so it is computed once per class instead of on every
    serialize/deserialize call.
    """

    # The message class this metadata describes.
    cls: Type["Message"]
    # Field name -> name of the one-of group the field belongs to.
    oneof_group_by_field: Dict[str, str]
    # One-of group name -> set of dataclass fields in that group.
    oneof_field_by_group: Dict[str, set]
    # Field name -> zero-argument callable producing the field's default value.
    default_gen: Dict[str, Any]
    # Field name -> class resolved from the field's type hint. Map fields store
    # a generated ``Entry`` message class under the field name and the map's
    # value type under ``"<field name>.value"``.
    cls_by_field: Dict[str, Type]

    def __init__(self, cls: Type["Message"]):
        """Collect the one-of tables for ``cls``, then build the other caches."""
        self.cls = cls
        by_field = {}
        by_group = {}

        for field in dataclasses.fields(cls):
            meta = FieldMetadata.get(field)

            if meta.group:
                # This is part of a one-of group.
                by_field[field.name] = meta.group

                by_group.setdefault(meta.group, set()).add(field)

        self.oneof_group_by_field = by_field
        self.oneof_field_by_group = by_group

        self.init_default_gen()
        self.init_cls_by_field()

    def init_default_gen(self):
        """Populate ``default_gen`` with one default generator per field."""
        default_gen = {}

        for field in dataclasses.fields(self.cls):
            meta = FieldMetadata.get(field)
            default_gen[field.name] = self.cls._get_field_default_gen(field, meta)

        self.default_gen = default_gen

    def init_cls_by_field(self):
        """Populate ``cls_by_field`` with the class resolved for each field."""
        field_cls = {}

        for field in dataclasses.fields(self.cls):
            meta = FieldMetadata.get(field)
            if meta.proto_type == TYPE_MAP:
                assert meta.map_types
                kt = self.cls._cls_for(field, index=0)
                vt = self.cls._cls_for(field, index=1)
                # Build the key/value entry message once per map field rather
                # than once per parsed key/value pair.
                Entry = dataclasses.make_dataclass(
                    "Entry",
                    [
                        ("key", kt, dataclass_field(1, meta.map_types[0])),
                        ("value", vt, dataclass_field(2, meta.map_types[1])),
                    ],
                    bases=(Message,),
                )
                field_cls[field.name] = Entry
                field_cls[field.name + ".value"] = vt
            else:
                field_cls[field.name] = self.cls._cls_for(field)

        self.cls_by_field = field_cls
490+
491+
431492
class Message(ABC):
432493
"""
433494
A protobuf message base class. Generated code will inherit from this and
@@ -445,25 +506,20 @@ def __post_init__(self) -> None:
445506

446507
# Set a default value for each field in the class after `__init__` has
447508
# already been run.
448-
group_map: Dict[str, dict] = {"fields": {}, "groups": {}}
509+
group_map: Dict[str, dataclasses.Field] = {}
449510
for field in dataclasses.fields(self):
450511
meta = FieldMetadata.get(field)
451512

452513
if meta.group:
453-
# This is part of a one-of group.
454-
group_map["fields"][field.name] = meta.group
455-
456-
if meta.group not in group_map["groups"]:
457-
group_map["groups"][meta.group] = {"current": None, "fields": set()}
458-
group_map["groups"][meta.group]["fields"].add(field)
514+
group_map.setdefault(meta.group)
459515

460516
if getattr(self, field.name) != PLACEHOLDER:
461517
# Skip anything not set to the sentinel value
462518
all_sentinel = False
463519

464520
if meta.group:
465521
# This was set, so make it the selected value of the one-of.
466-
group_map["groups"][meta.group]["current"] = field
522+
group_map[meta.group] = field
467523

468524
continue
469525

@@ -479,19 +535,33 @@ def __setattr__(self, attr: str, value: Any) -> None:
479535
# Track when a field has been set.
480536
self.__dict__["_serialized_on_wire"] = True
481537

482-
if attr in getattr(self, "_group_map", {}).get("fields", {}):
483-
group = self._group_map["fields"][attr]
484-
for field in self._group_map["groups"][group]["fields"]:
485-
if field.name == attr:
486-
self._group_map["groups"][group]["current"] = field
487-
else:
488-
super().__setattr__(
489-
field.name,
490-
self._get_field_default(field, FieldMetadata.get(field)),
491-
)
538+
if hasattr(self, "_group_map"): # __post_init__ had already run
539+
if attr in self._betterproto.oneof_group_by_field:
540+
group = self._betterproto.oneof_group_by_field[attr]
541+
for field in self._betterproto.oneof_field_by_group[group]:
542+
if field.name == attr:
543+
self._group_map[group] = field
544+
else:
545+
super().__setattr__(
546+
field.name,
547+
self._get_field_default(field, FieldMetadata.get(field)),
548+
)
492549

493550
super().__setattr__(attr, value)
494551

552+
@property
def _betterproto(self) -> "ProtoClassMetadata":
    """
    Lazily initialize and return the metadata cached on this object's class.

    It may be initialized multiple times in a multi-threaded environment,
    but that won't affect the correctness.
    """
    cls = self.__class__
    # Look only in this class's own namespace. A plain getattr() would also
    # find metadata cached on a parent message class, and that metadata
    # describes the parent's fields rather than this class's.
    meta = cls.__dict__.get("_betterproto_meta")
    if meta is None:
        meta = ProtoClassMetadata(cls)
        cls._betterproto_meta = meta
    return meta
564+
495565
def __bytes__(self) -> bytes:
496566
"""
497567
Get the binary encoded Protobuf representation of this instance.
@@ -510,7 +580,7 @@ def __bytes__(self) -> bytes:
510580
# currently set in a `oneof` group, so it must be serialized even
511581
# if the value is the default zero value.
512582
selected_in_group = False
513-
if meta.group and self._group_map["groups"][meta.group]["current"] == field:
583+
if meta.group and self._group_map[meta.group] == field:
514584
selected_in_group = True
515585

516586
serialize_empty = False
@@ -562,47 +632,50 @@ def __bytes__(self) -> bytes:
562632
# For compatibility with other libraries
563633
SerializeToString = __bytes__
564634

565-
def _type_hint(self, field_name: str) -> Type:
566-
module = inspect.getmodule(self.__class__)
567-
type_hints = get_type_hints(self.__class__, vars(module))
635+
@classmethod
636+
def _type_hint(cls, field_name: str) -> Type:
637+
module = inspect.getmodule(cls)
638+
type_hints = get_type_hints(cls, vars(module))
568639
return type_hints[field_name]
569640

570-
def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type:
641+
@classmethod
642+
def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type:
571643
"""Get the message class for a field from the type hints."""
572-
cls = self._type_hint(field.name)
573-
if hasattr(cls, "__args__") and index >= 0:
574-
cls = cls.__args__[index]
575-
return cls
644+
field_cls = cls._type_hint(field.name)
645+
if hasattr(field_cls, "__args__") and index >= 0:
646+
field_cls = field_cls.__args__[index]
647+
return field_cls
576648

577649
def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any:
578-
t = self._type_hint(field.name)
650+
return self._betterproto.default_gen[field.name]()
651+
652+
@classmethod
653+
def _get_field_default_gen(cls, field: dataclasses.Field, meta: FieldMetadata) -> Any:
654+
t = cls._type_hint(field.name)
579655

580-
value: Any = 0
581656
if hasattr(t, "__origin__"):
582657
if t.__origin__ in (dict, Dict):
583658
# This is some kind of map (dict in Python).
584-
value = {}
659+
return dict
585660
elif t.__origin__ in (list, List):
586661
# This is some kind of list (repeated) field.
587-
value = []
662+
return list
588663
elif t.__origin__ == Union and t.__args__[1] == type(None):
589664
# This is an optional (wrapped) field. For setting the default we
590665
# really don't care what kind of field it is.
591-
value = None
666+
return type(None)
592667
else:
593-
value = t()
668+
return t
594669
elif issubclass(t, Enum):
595670
# Enums always default to zero.
596-
value = 0
671+
return int
597672
elif t == datetime:
598673
# Offsets are relative to 1970-01-01T00:00:00Z
599-
value = DATETIME_ZERO
674+
return datetime_default_gen
600675
else:
601676
# This is either a primitive scalar or another message type. Calling
602677
# it should result in its zero value.
603-
value = t()
604-
605-
return value
678+
return t
606679

607680
def _postprocess_single(
608681
self, wire_type: int, meta: FieldMetadata, field: dataclasses.Field, value: Any
@@ -627,7 +700,7 @@ def _postprocess_single(
627700
if meta.proto_type == TYPE_STRING:
628701
value = value.decode("utf-8")
629702
elif meta.proto_type == TYPE_MESSAGE:
630-
cls = self._cls_for(field)
703+
cls = self._betterproto.cls_by_field[field.name]
631704

632705
if cls == datetime:
633706
value = _Timestamp().parse(value).to_datetime()
@@ -641,20 +714,7 @@ def _postprocess_single(
641714
value = cls().parse(value)
642715
value._serialized_on_wire = True
643716
elif meta.proto_type == TYPE_MAP:
644-
# TODO: This is slow, use a cache to make it faster since each
645-
# key/value pair will recreate the class.
646-
assert meta.map_types
647-
kt = self._cls_for(field, index=0)
648-
vt = self._cls_for(field, index=1)
649-
Entry = dataclasses.make_dataclass(
650-
"Entry",
651-
[
652-
("key", kt, dataclass_field(1, meta.map_types[0])),
653-
("value", vt, dataclass_field(2, meta.map_types[1])),
654-
],
655-
bases=(Message,),
656-
)
657-
value = Entry().parse(value)
717+
value = self._betterproto.cls_by_field[field.name]().parse(value)
658718

659719
return value
660720

@@ -769,7 +829,7 @@ def to_dict(
769829
else:
770830
output[cased_name] = b64encode(v).decode("utf8")
771831
elif meta.proto_type == TYPE_ENUM:
772-
enum_values = list(self._cls_for(field)) # type: ignore
832+
enum_values = list(self._betterproto.cls_by_field[field.name]) # type: ignore
773833
if isinstance(v, list):
774834
output[cased_name] = [enum_values[e].name for e in v]
775835
else:
@@ -795,7 +855,7 @@ def from_dict(self: T, value: dict) -> T:
795855
if meta.proto_type == "message":
796856
v = getattr(self, field.name)
797857
if isinstance(v, list):
798-
cls = self._cls_for(field)
858+
cls = self._betterproto.cls_by_field[field.name]
799859
for i in range(len(value[key])):
800860
v.append(cls().from_dict(value[key][i]))
801861
elif isinstance(v, datetime):
@@ -812,7 +872,7 @@ def from_dict(self: T, value: dict) -> T:
812872
v.from_dict(value[key])
813873
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
814874
v = getattr(self, field.name)
815-
cls = self._cls_for(field, index=1)
875+
cls = self._betterproto.cls_by_field[field.name + ".value"]
816876
for k in value[key]:
817877
v[k] = cls().from_dict(value[key][k])
818878
else:
@@ -828,7 +888,7 @@ def from_dict(self: T, value: dict) -> T:
828888
else:
829889
v = b64decode(value[key])
830890
elif meta.proto_type == TYPE_ENUM:
831-
enum_cls = self._cls_for(field)
891+
enum_cls = self._betterproto.cls_by_field[field.name]
832892
if isinstance(v, list):
833893
v = [enum_cls.from_string(e) for e in v]
834894
elif isinstance(v, str):
@@ -861,7 +921,7 @@ def serialized_on_wire(message: Message) -> bool:
861921

862922
def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
863923
"""Return the name and value of a message's one-of field group."""
864-
field = message._group_map["groups"].get(group_name, {}).get("current")
924+
field = message._group_map.get(group_name)
865925
if not field:
866926
return ("", None)
867927
return (field.name, getattr(message, field.name))

0 commit comments

Comments
 (0)