120
120
121
121
122
122
# Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
123
- DATETIME_ZERO = datetime (1970 , 1 , 1 , tzinfo = timezone .utc )
123
def datetime_default_gen() -> datetime:
    """Return the protobuf zero datetime: 1970-01-01T00:00:00 in UTC."""
    epoch = datetime(1970, 1, 1, tzinfo=timezone.utc)
    return epoch


# Module-level instance of the zero datetime.
DATETIME_ZERO = datetime_default_gen()
124
128
125
129
126
130
class Casing (enum .Enum ):
@@ -428,6 +432,63 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
428
432
T = TypeVar ("T" , bound = "Message" )
429
433
430
434
435
class ProtoClassMetadata:
    """
    Per-class lookup tables derived from the dataclass fields of a message class.

    Everything is computed once in the constructor so it can be shared by
    all instances of the message class.
    """

    # The message class this metadata describes.
    cls: Type["Message"]

    def __init__(self, cls: Type["Message"]):
        """Build the one-of, default-generator and field-class tables for `cls`."""
        self.cls = cls

        group_of_field = {}
        fields_of_group = {}
        for fld in dataclasses.fields(cls):
            group = FieldMetadata.get(fld).group
            if not group:
                continue

            # This is part of a one-of group.
            group_of_field[fld.name] = group
            fields_of_group.setdefault(group, set()).add(fld)

        # field name -> one-of group name
        self.oneof_group_by_field = group_of_field
        # one-of group name -> set of dataclass fields in that group
        self.oneof_field_by_group = fields_of_group

        self.init_default_gen()
        self.init_cls_by_field()

    def init_default_gen(self):
        """Populate `default_gen`: field name -> callable that makes the default."""
        self.default_gen = {
            fld.name: self.cls._get_field_default_gen(fld, FieldMetadata.get(fld))
            for fld in dataclasses.fields(self.cls)
        }

    def init_cls_by_field(self):
        """
        Populate `cls_by_field`: field name -> class taken from the type hints.

        Map fields get a generated `Entry` message class under the field
        name, plus the value class under `"<field name>.value"`.
        """
        classes = {}

        for fld in dataclasses.fields(self.cls):
            meta = FieldMetadata.get(fld)
            if meta.proto_type != TYPE_MAP:
                classes[fld.name] = self.cls._cls_for(fld)
                continue

            assert meta.map_types
            key_cls = self.cls._cls_for(fld, index=0)
            value_cls = self.cls._cls_for(fld, index=1)
            entry_fields = [
                ("key", key_cls, dataclass_field(1, meta.map_types[0])),
                ("value", value_cls, dataclass_field(2, meta.map_types[1])),
            ]
            classes[fld.name] = dataclasses.make_dataclass(
                "Entry", entry_fields, bases=(Message,)
            )
            classes[fld.name + ".value"] = value_cls

        self.cls_by_field = classes
490
+
491
+
431
492
class Message (ABC ):
432
493
"""
433
494
A protobuf message base class. Generated code will inherit from this and
@@ -445,25 +506,20 @@ def __post_init__(self) -> None:
445
506
446
507
# Set a default value for each field in the class after `__init__` has
447
508
# already been run.
448
- group_map : Dict [str , dict ] = {"fields" : {}, "groups" : {} }
509
+ group_map : Dict [str , dataclasses . Field ] = {}
449
510
for field in dataclasses .fields (self ):
450
511
meta = FieldMetadata .get (field )
451
512
452
513
if meta .group :
453
- # This is part of a one-of group.
454
- group_map ["fields" ][field .name ] = meta .group
455
-
456
- if meta .group not in group_map ["groups" ]:
457
- group_map ["groups" ][meta .group ] = {"current" : None , "fields" : set ()}
458
- group_map ["groups" ][meta .group ]["fields" ].add (field )
514
+ group_map .setdefault (meta .group )
459
515
460
516
if getattr (self , field .name ) != PLACEHOLDER :
461
517
# Skip anything not set to the sentinel value
462
518
all_sentinel = False
463
519
464
520
if meta .group :
465
521
# This was set, so make it the selected value of the one-of.
466
- group_map ["groups" ][ meta .group ][ "current" ] = field
522
+ group_map [meta .group ] = field
467
523
468
524
continue
469
525
@@ -479,19 +535,33 @@ def __setattr__(self, attr: str, value: Any) -> None:
479
535
# Track when a field has been set.
480
536
self .__dict__ ["_serialized_on_wire" ] = True
481
537
482
- if attr in getattr (self , "_group_map" , {}).get ("fields" , {}):
483
- group = self ._group_map ["fields" ][attr ]
484
- for field in self ._group_map ["groups" ][group ]["fields" ]:
485
- if field .name == attr :
486
- self ._group_map ["groups" ][group ]["current" ] = field
487
- else :
488
- super ().__setattr__ (
489
- field .name ,
490
- self ._get_field_default (field , FieldMetadata .get (field )),
491
- )
538
+ if hasattr (self , "_group_map" ): # __post_init__ had already run
539
+ if attr in self ._betterproto .oneof_group_by_field :
540
+ group = self ._betterproto .oneof_group_by_field [attr ]
541
+ for field in self ._betterproto .oneof_field_by_group [group ]:
542
+ if field .name == attr :
543
+ self ._group_map [group ] = field
544
+ else :
545
+ super ().__setattr__ (
546
+ field .name ,
547
+ self ._get_field_default (field , FieldMetadata .get (field )),
548
+ )
492
549
493
550
super ().__setattr__ (attr , value )
494
551
552
@property
def _betterproto(self):
    """
    Lazily initialize and return the metadata for this protobuf class.

    The metadata is cached on the class itself. It may be initialized
    multiple times in a multi-threaded environment, but that won't affect
    the correctness.
    """
    cls = self.__class__
    # Look only in this class's own namespace. `getattr` would also find
    # metadata cached on a parent message class, and that metadata was
    # built from the parent's fields, not this class's.
    meta = cls.__dict__.get("_betterproto_meta")
    if meta is None:
        meta = ProtoClassMetadata(cls)
        cls._betterproto_meta = meta
    return meta
564
+
495
565
def __bytes__ (self ) -> bytes :
496
566
"""
497
567
Get the binary encoded Protobuf representation of this instance.
@@ -510,7 +580,7 @@ def __bytes__(self) -> bytes:
510
580
# currently set in a `oneof` group, so it must be serialized even
511
581
# if the value is the default zero value.
512
582
selected_in_group = False
513
- if meta .group and self ._group_map ["groups" ][ meta .group ][ "current" ] == field :
583
+ if meta .group and self ._group_map [meta .group ] == field :
514
584
selected_in_group = True
515
585
516
586
serialize_empty = False
@@ -562,47 +632,50 @@ def __bytes__(self) -> bytes:
562
632
# For compatibility with other libraries
563
633
SerializeToString = __bytes__
564
634
565
@classmethod
def _type_hint(cls, field_name: str) -> Type:
    """Return the resolved type hint of `field_name` on this class."""
    # Resolve annotations against the globals of the module defining `cls`.
    namespace = vars(inspect.getmodule(cls))
    return get_type_hints(cls, namespace)[field_name]
569
640
570
@classmethod
def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type:
    """
    Get the message class for a field from the type hints.

    If the hint has type arguments and `index` is non-negative, the
    argument at `index` is returned; otherwise the hint itself is returned.
    """
    hint = cls._type_hint(field.name)
    if not hasattr(hint, "__args__") or index < 0:
        return hint
    return hint.__args__[index]
576
648
577
649
def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any:
    """
    Return the default value for `field` by calling its cached generator.

    `meta` is not used here.
    """
    make_default = self._betterproto.default_gen[field.name]
    return make_default()
651
+
652
@classmethod
def _get_field_default_gen(cls, field: dataclasses.Field, meta: FieldMetadata) -> Any:
    """
    Return a callable that produces the default value for `field`.

    The callable is chosen from the field's type hint; `meta` is not used.
    """
    hint = cls._type_hint(field.name)

    if hasattr(hint, "__origin__"):
        origin = hint.__origin__
        if origin in (dict, Dict):
            # This is some kind of map (dict in Python).
            return dict
        if origin in (list, List):
            # This is some kind of list (repeated) field.
            return list
        if origin == Union and hint.__args__[1] == type(None):
            # This is an optional (wrapped) field. For setting the default we
            # really don't care what kind of field it is.
            return type(None)
        return hint

    if issubclass(hint, Enum):
        # Enums always default to zero.
        return int

    if hint == datetime:
        # Offsets are relative to 1970-01-01T00:00:00Z
        return datetime_default_gen

    # This is either a primitive scalar or another message type. Calling
    # it should result in its zero value.
    return hint
606
679
607
680
def _postprocess_single (
608
681
self , wire_type : int , meta : FieldMetadata , field : dataclasses .Field , value : Any
@@ -627,7 +700,7 @@ def _postprocess_single(
627
700
if meta .proto_type == TYPE_STRING :
628
701
value = value .decode ("utf-8" )
629
702
elif meta .proto_type == TYPE_MESSAGE :
630
- cls = self ._cls_for ( field )
703
+ cls = self ._betterproto . cls_by_field [ field . name ]
631
704
632
705
if cls == datetime :
633
706
value = _Timestamp ().parse (value ).to_datetime ()
@@ -641,20 +714,7 @@ def _postprocess_single(
641
714
value = cls ().parse (value )
642
715
value ._serialized_on_wire = True
643
716
elif meta .proto_type == TYPE_MAP :
644
- # TODO: This is slow, use a cache to make it faster since each
645
- # key/value pair will recreate the class.
646
- assert meta .map_types
647
- kt = self ._cls_for (field , index = 0 )
648
- vt = self ._cls_for (field , index = 1 )
649
- Entry = dataclasses .make_dataclass (
650
- "Entry" ,
651
- [
652
- ("key" , kt , dataclass_field (1 , meta .map_types [0 ])),
653
- ("value" , vt , dataclass_field (2 , meta .map_types [1 ])),
654
- ],
655
- bases = (Message ,),
656
- )
657
- value = Entry ().parse (value )
717
+ value = self ._betterproto .cls_by_field [field .name ]().parse (value )
658
718
659
719
return value
660
720
@@ -769,7 +829,7 @@ def to_dict(
769
829
else :
770
830
output [cased_name ] = b64encode (v ).decode ("utf8" )
771
831
elif meta .proto_type == TYPE_ENUM :
772
- enum_values = list (self ._cls_for ( field ) ) # type: ignore
832
+ enum_values = list (self ._betterproto . cls_by_field [ field . name ] ) # type: ignore
773
833
if isinstance (v , list ):
774
834
output [cased_name ] = [enum_values [e ].name for e in v ]
775
835
else :
@@ -795,7 +855,7 @@ def from_dict(self: T, value: dict) -> T:
795
855
if meta .proto_type == "message" :
796
856
v = getattr (self , field .name )
797
857
if isinstance (v , list ):
798
- cls = self ._cls_for ( field )
858
+ cls = self ._betterproto . cls_by_field [ field . name ]
799
859
for i in range (len (value [key ])):
800
860
v .append (cls ().from_dict (value [key ][i ]))
801
861
elif isinstance (v , datetime ):
@@ -812,7 +872,7 @@ def from_dict(self: T, value: dict) -> T:
812
872
v .from_dict (value [key ])
813
873
elif meta .map_types and meta .map_types [1 ] == TYPE_MESSAGE :
814
874
v = getattr (self , field .name )
815
- cls = self ._cls_for ( field , index = 1 )
875
+ cls = self ._betterproto . cls_by_field [ field . name + ".value" ]
816
876
for k in value [key ]:
817
877
v [k ] = cls ().from_dict (value [key ][k ])
818
878
else :
@@ -828,7 +888,7 @@ def from_dict(self: T, value: dict) -> T:
828
888
else :
829
889
v = b64decode (value [key ])
830
890
elif meta .proto_type == TYPE_ENUM :
831
- enum_cls = self ._cls_for ( field )
891
+ enum_cls = self ._betterproto . cls_by_field [ field . name ]
832
892
if isinstance (v , list ):
833
893
v = [enum_cls .from_string (e ) for e in v ]
834
894
elif isinstance (v , str ):
@@ -861,7 +921,7 @@ def serialized_on_wire(message: Message) -> bool:
861
921
862
922
def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
    """
    Return the name and value of a message's one-of field group.

    Returns `("", None)` when the group is unknown or has no selected field.
    """
    selected = message._group_map.get(group_name)
    if selected:
        return (selected.name, getattr(message, selected.name))
    return ("", None)
0 commit comments