29
29
reference to `A` to `B`'s `fields` attribute.
30
30
"""
31
31
32
-
33
32
import builtins
34
33
import re
35
- import textwrap
36
34
from dataclasses import (
37
35
dataclass ,
38
36
field ,
49
47
)
50
48
51
49
import betterproto
52
- from betterproto import which_one_of
53
- from betterproto .casing import sanitize_name
54
- from betterproto .compile .importing import (
55
- get_type_reference ,
56
- parse_source_type_name ,
57
- )
58
50
from betterproto .compile .naming import (
59
51
pythonize_class_name ,
60
52
pythonize_field_name ,
72
64
)
73
65
from betterproto .lib .google .protobuf .compiler import CodeGeneratorRequest
74
66
67
+ from .. import which_one_of
75
68
from ..compile .importing import (
76
69
get_type_reference ,
77
70
parse_source_type_name ,
82
75
pythonize_field_name ,
83
76
pythonize_method_name ,
84
77
)
78
+ from .typing_compiler import (
79
+ DirectImportTypingCompiler ,
80
+ TypingCompiler ,
81
+ )
85
82
86
83
87
84
# Create a unique placeholder to deal with
@@ -173,6 +170,7 @@ class ProtoContentBase:
173
170
"""Methods common to MessageCompiler, ServiceCompiler and ServiceMethodCompiler."""
174
171
175
172
source_file : FileDescriptorProto
173
+ typing_compiler : TypingCompiler
176
174
path : List [int ]
177
175
comment_indent : int = 4
178
176
parent : Union ["betterproto.Message" , "OutputTemplate" ]
@@ -242,7 +240,6 @@ class OutputTemplate:
242
240
input_files : List [str ] = field (default_factory = list )
243
241
imports : Set [str ] = field (default_factory = set )
244
242
datetime_imports : Set [str ] = field (default_factory = set )
245
- typing_imports : Set [str ] = field (default_factory = set )
246
243
pydantic_imports : Set [str ] = field (default_factory = set )
247
244
builtins_import : bool = False
248
245
messages : List ["MessageCompiler" ] = field (default_factory = list )
@@ -251,6 +248,7 @@ class OutputTemplate:
251
248
imports_type_checking_only : Set [str ] = field (default_factory = set )
252
249
pydantic_dataclasses : bool = False
253
250
output : bool = True
251
+ typing_compiler : TypingCompiler = field (default_factory = DirectImportTypingCompiler )
254
252
255
253
@property
256
254
def package (self ) -> str :
@@ -289,6 +287,7 @@ class MessageCompiler(ProtoContentBase):
289
287
"""Representation of a protobuf message."""
290
288
291
289
source_file : FileDescriptorProto
290
+ typing_compiler : TypingCompiler
292
291
parent : Union ["MessageCompiler" , OutputTemplate ] = PLACEHOLDER
293
292
proto_obj : DescriptorProto = PLACEHOLDER
294
293
path : List [int ] = PLACEHOLDER
@@ -319,7 +318,7 @@ def py_name(self) -> str:
319
318
@property
320
319
def annotation (self ) -> str :
321
320
if self .repeated :
322
- return f"List[ { self .py_name } ]"
321
+ return self .typing_compiler . list ( self . py_name )
323
322
return self .py_name
324
323
325
324
@property
@@ -434,18 +433,6 @@ def datetime_imports(self) -> Set[str]:
434
433
imports .add ("datetime" )
435
434
return imports
436
435
437
- @property
438
- def typing_imports (self ) -> Set [str ]:
439
- imports = set ()
440
- annotation = self .annotation
441
- if "Optional[" in annotation :
442
- imports .add ("Optional" )
443
- if "List[" in annotation :
444
- imports .add ("List" )
445
- if "Dict[" in annotation :
446
- imports .add ("Dict" )
447
- return imports
448
-
449
436
@property
450
437
def pydantic_imports (self ) -> Set [str ]:
451
438
return set ()
@@ -458,7 +445,6 @@ def use_builtins(self) -> bool:
458
445
459
446
def add_imports_to (self , output_file : OutputTemplate ) -> None :
460
447
output_file .datetime_imports .update (self .datetime_imports )
461
- output_file .typing_imports .update (self .typing_imports )
462
448
output_file .pydantic_imports .update (self .pydantic_imports )
463
449
output_file .builtins_import = output_file .builtins_import or self .use_builtins
464
450
@@ -488,7 +474,9 @@ def optional(self) -> bool:
488
474
@property
489
475
def mutable (self ) -> bool :
490
476
"""True if the field is a mutable type, otherwise False."""
491
- return self .annotation .startswith (("List[" , "Dict[" ))
477
+ return self .annotation .startswith (
478
+ ("typing.List[" , "typing.Dict[" , "dict[" , "list[" , "Dict[" , "List[" )
479
+ )
492
480
493
481
@property
494
482
def field_type (self ) -> str :
@@ -562,6 +550,7 @@ def py_type(self) -> str:
562
550
package = self .output_file .package ,
563
551
imports = self .output_file .imports ,
564
552
source_type = self .proto_obj .type_name ,
553
+ typing_compiler = self .typing_compiler ,
565
554
pydantic = self .output_file .pydantic_dataclasses ,
566
555
)
567
556
else :
@@ -573,9 +562,9 @@ def annotation(self) -> str:
573
562
if self .use_builtins :
574
563
py_type = f"builtins.{ py_type } "
575
564
if self .repeated :
576
- return f"List[ { py_type } ]"
565
+ return self . typing_compiler . list ( py_type )
577
566
if self .optional :
578
- return f"Optional[ { py_type } ]"
567
+ return self . typing_compiler . optional ( py_type )
579
568
return py_type
580
569
581
570
@@ -623,11 +612,13 @@ def __post_init__(self) -> None:
623
612
source_file = self .source_file ,
624
613
parent = self ,
625
614
proto_obj = nested .field [0 ], # key
615
+ typing_compiler = self .typing_compiler ,
626
616
).py_type
627
617
self .py_v_type = FieldCompiler (
628
618
source_file = self .source_file ,
629
619
parent = self ,
630
620
proto_obj = nested .field [1 ], # value
621
+ typing_compiler = self .typing_compiler ,
631
622
).py_type
632
623
633
624
# Get proto types
@@ -645,7 +636,7 @@ def field_type(self) -> str:
645
636
646
637
@property
647
638
def annotation (self ) -> str :
648
- return f"Dict[ { self .py_k_type } , { self .py_v_type } ]"
639
+ return self .typing_compiler . dict ( self . py_k_type , self .py_v_type )
649
640
650
641
@property
651
642
def repeated (self ) -> bool :
@@ -702,7 +693,6 @@ class ServiceCompiler(ProtoContentBase):
702
693
def __post_init__ (self ) -> None :
703
694
# Add service to output file
704
695
self .output_file .services .append (self )
705
- self .output_file .typing_imports .add ("Dict" )
706
696
super ().__post_init__ () # check for unset fields
707
697
708
698
@property
@@ -725,22 +715,6 @@ def __post_init__(self) -> None:
725
715
# Add method to service
726
716
self .parent .methods .append (self )
727
717
728
- # Check for imports
729
- if "Optional" in self .py_output_message_type :
730
- self .output_file .typing_imports .add ("Optional" )
731
-
732
- # Check for Async imports
733
- if self .client_streaming :
734
- self .output_file .typing_imports .add ("AsyncIterable" )
735
- self .output_file .typing_imports .add ("Iterable" )
736
- self .output_file .typing_imports .add ("Union" )
737
-
738
- # Required by both client and server
739
- if self .client_streaming or self .server_streaming :
740
- self .output_file .typing_imports .add ("AsyncIterator" )
741
-
742
- # add imports required for request arguments timeout, deadline and metadata
743
- self .output_file .typing_imports .add ("Optional" )
744
718
self .output_file .imports_type_checking_only .add ("import grpclib.server" )
745
719
self .output_file .imports_type_checking_only .add (
746
720
"from betterproto.grpc.grpclib_client import MetadataLike"
@@ -806,6 +780,7 @@ def py_input_message_type(self) -> str:
806
780
package = self .output_file .package ,
807
781
imports = self .output_file .imports ,
808
782
source_type = self .proto_obj .input_type ,
783
+ typing_compiler = self .output_file .typing_compiler ,
809
784
unwrap = False ,
810
785
pydantic = self .output_file .pydantic_dataclasses ,
811
786
).strip ('"' )
@@ -835,6 +810,7 @@ def py_output_message_type(self) -> str:
835
810
package = self .output_file .package ,
836
811
imports = self .output_file .imports ,
837
812
source_type = self .proto_obj .output_type ,
813
+ typing_compiler = self .output_file .typing_compiler ,
838
814
unwrap = False ,
839
815
pydantic = self .output_file .pydantic_dataclasses ,
840
816
).strip ('"' )
0 commit comments