Skip to content

Commit 8b59234

Browse files
Handle typing collisions and add validation to a files module for overlaping declarations (#582)
* Fix 'typing' import collisions. * Fix formatting. * Fix self-test issues. * Validation for modules, different typing configurations * add readme * make warning * fix format --------- Co-authored-by: Scott Hendricks <[email protected]>
1 parent 7c6c627 commit 8b59234

13 files changed

+887
-157
lines changed

README.md

+43
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,50 @@ swap the dataclass implementation from the builtin python dataclass to the
391391
pydantic dataclass. You must have pydantic as a dependency in your project for
392392
this to work.
393393

394+
## Configuration typing imports
394395

396+
By default typing types will be imported directly from typing. This sometimes can lead to issues in generation if types that are being generated conflict with the name. In this case you can configure the way types are imported from 3 different options:
397+
398+
### Direct
399+
```
400+
protoc -I . --python_betterproto_opt=typing.direct --python_betterproto_out=lib example.proto
401+
```
402+
this configuration is the default, and will import types as follows:
403+
```
404+
from typing import (
405+
List,
406+
Optional,
407+
Union
408+
)
409+
...
410+
value: List[str] = []
411+
value2: Optional[str] = None
412+
value3: Union[str, int] = 1
413+
```
414+
### Root
415+
```
416+
protoc -I . --python_betterproto_opt=typing.root --python_betterproto_out=lib example.proto
417+
```
418+
this configuration loads the root typing module, and then access the types off of it directly:
419+
```
420+
import typing
421+
...
422+
value: typing.List[str] = []
423+
value2: typing.Optional[str] = None
424+
value3: typing.Union[str, int] = 1
425+
```
426+
427+
### 310
428+
```
429+
protoc -I . --python_betterproto_opt=typing.310 --python_betterproto_out=lib example.proto
430+
```
431+
this configuration avoid loading typing all together if possible and uses the python 3.10 pattern:
432+
```
433+
...
434+
value: list[str] = []
435+
value2: str | None = None
436+
value3: str | int = 1
437+
```
395438

396439
## Development
397440

src/betterproto/compile/importing.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def get_type_reference(
4747
package: str,
4848
imports: set,
4949
source_type: str,
50+
typing_compiler: "TypingCompiler",
5051
unwrap: bool = True,
5152
pydantic: bool = False,
5253
) -> str:
@@ -57,7 +58,7 @@ def get_type_reference(
5758
if unwrap:
5859
if source_type in WRAPPER_TYPES:
5960
wrapped_type = type(WRAPPER_TYPES[source_type]().value)
60-
return f"Optional[{wrapped_type.__name__}]"
61+
return typing_compiler.optional(wrapped_type.__name__)
6162

6263
if source_type == ".google.protobuf.Duration":
6364
return "timedelta"

src/betterproto/plugin/compiler.py

+20-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
import os.path
2+
import sys
3+
4+
from .module_validation import ModuleValidator
25

36

47
try:
@@ -30,9 +33,12 @@ def outputfile_compiler(output_file: OutputTemplate) -> str:
3033
lstrip_blocks=True,
3134
loader=jinja2.FileSystemLoader(templates_folder),
3235
)
33-
template = env.get_template("template.py.j2")
36+
# Load the body first so we have a compleate list of imports needed.
37+
body_template = env.get_template("template.py.j2")
38+
header_template = env.get_template("header.py.j2")
3439

35-
code = template.render(output_file=output_file)
40+
code = body_template.render(output_file=output_file)
41+
code = header_template.render(output_file=output_file) + code
3642
code = isort.api.sort_code_string(
3743
code=code,
3844
show_diff=False,
@@ -44,7 +50,18 @@ def outputfile_compiler(output_file: OutputTemplate) -> str:
4450
force_grid_wrap=2,
4551
known_third_party=["grpclib", "betterproto"],
4652
)
47-
return black.format_str(
53+
code = black.format_str(
4854
src_contents=code,
4955
mode=black.Mode(),
5056
)
57+
58+
# Validate the generated code.
59+
validator = ModuleValidator(iter(code.splitlines()))
60+
if not validator.validate():
61+
message_builder = ["[WARNING]: Generated code has collisions in the module:"]
62+
for collision, lines in validator.collisions.items():
63+
message_builder.append(f' "{collision}" on lines:')
64+
for num, line in lines:
65+
message_builder.append(f" {num}:{line}")
66+
print("\n".join(message_builder), file=sys.stderr)
67+
return code

src/betterproto/plugin/models.py

+20-44
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,8 @@
2929
reference to `A` to `B`'s `fields` attribute.
3030
"""
3131

32-
3332
import builtins
3433
import re
35-
import textwrap
3634
from dataclasses import (
3735
dataclass,
3836
field,
@@ -49,12 +47,6 @@
4947
)
5048

5149
import betterproto
52-
from betterproto import which_one_of
53-
from betterproto.casing import sanitize_name
54-
from betterproto.compile.importing import (
55-
get_type_reference,
56-
parse_source_type_name,
57-
)
5850
from betterproto.compile.naming import (
5951
pythonize_class_name,
6052
pythonize_field_name,
@@ -72,6 +64,7 @@
7264
)
7365
from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest
7466

67+
from .. import which_one_of
7568
from ..compile.importing import (
7669
get_type_reference,
7770
parse_source_type_name,
@@ -82,6 +75,10 @@
8275
pythonize_field_name,
8376
pythonize_method_name,
8477
)
78+
from .typing_compiler import (
79+
DirectImportTypingCompiler,
80+
TypingCompiler,
81+
)
8582

8683

8784
# Create a unique placeholder to deal with
@@ -173,6 +170,7 @@ class ProtoContentBase:
173170
"""Methods common to MessageCompiler, ServiceCompiler and ServiceMethodCompiler."""
174171

175172
source_file: FileDescriptorProto
173+
typing_compiler: TypingCompiler
176174
path: List[int]
177175
comment_indent: int = 4
178176
parent: Union["betterproto.Message", "OutputTemplate"]
@@ -242,7 +240,6 @@ class OutputTemplate:
242240
input_files: List[str] = field(default_factory=list)
243241
imports: Set[str] = field(default_factory=set)
244242
datetime_imports: Set[str] = field(default_factory=set)
245-
typing_imports: Set[str] = field(default_factory=set)
246243
pydantic_imports: Set[str] = field(default_factory=set)
247244
builtins_import: bool = False
248245
messages: List["MessageCompiler"] = field(default_factory=list)
@@ -251,6 +248,7 @@ class OutputTemplate:
251248
imports_type_checking_only: Set[str] = field(default_factory=set)
252249
pydantic_dataclasses: bool = False
253250
output: bool = True
251+
typing_compiler: TypingCompiler = field(default_factory=DirectImportTypingCompiler)
254252

255253
@property
256254
def package(self) -> str:
@@ -289,6 +287,7 @@ class MessageCompiler(ProtoContentBase):
289287
"""Representation of a protobuf message."""
290288

291289
source_file: FileDescriptorProto
290+
typing_compiler: TypingCompiler
292291
parent: Union["MessageCompiler", OutputTemplate] = PLACEHOLDER
293292
proto_obj: DescriptorProto = PLACEHOLDER
294293
path: List[int] = PLACEHOLDER
@@ -319,7 +318,7 @@ def py_name(self) -> str:
319318
@property
320319
def annotation(self) -> str:
321320
if self.repeated:
322-
return f"List[{self.py_name}]"
321+
return self.typing_compiler.list(self.py_name)
323322
return self.py_name
324323

325324
@property
@@ -434,18 +433,6 @@ def datetime_imports(self) -> Set[str]:
434433
imports.add("datetime")
435434
return imports
436435

437-
@property
438-
def typing_imports(self) -> Set[str]:
439-
imports = set()
440-
annotation = self.annotation
441-
if "Optional[" in annotation:
442-
imports.add("Optional")
443-
if "List[" in annotation:
444-
imports.add("List")
445-
if "Dict[" in annotation:
446-
imports.add("Dict")
447-
return imports
448-
449436
@property
450437
def pydantic_imports(self) -> Set[str]:
451438
return set()
@@ -458,7 +445,6 @@ def use_builtins(self) -> bool:
458445

459446
def add_imports_to(self, output_file: OutputTemplate) -> None:
460447
output_file.datetime_imports.update(self.datetime_imports)
461-
output_file.typing_imports.update(self.typing_imports)
462448
output_file.pydantic_imports.update(self.pydantic_imports)
463449
output_file.builtins_import = output_file.builtins_import or self.use_builtins
464450

@@ -488,7 +474,9 @@ def optional(self) -> bool:
488474
@property
489475
def mutable(self) -> bool:
490476
"""True if the field is a mutable type, otherwise False."""
491-
return self.annotation.startswith(("List[", "Dict["))
477+
return self.annotation.startswith(
478+
("typing.List[", "typing.Dict[", "dict[", "list[", "Dict[", "List[")
479+
)
492480

493481
@property
494482
def field_type(self) -> str:
@@ -562,6 +550,7 @@ def py_type(self) -> str:
562550
package=self.output_file.package,
563551
imports=self.output_file.imports,
564552
source_type=self.proto_obj.type_name,
553+
typing_compiler=self.typing_compiler,
565554
pydantic=self.output_file.pydantic_dataclasses,
566555
)
567556
else:
@@ -573,9 +562,9 @@ def annotation(self) -> str:
573562
if self.use_builtins:
574563
py_type = f"builtins.{py_type}"
575564
if self.repeated:
576-
return f"List[{py_type}]"
565+
return self.typing_compiler.list(py_type)
577566
if self.optional:
578-
return f"Optional[{py_type}]"
567+
return self.typing_compiler.optional(py_type)
579568
return py_type
580569

581570

@@ -623,11 +612,13 @@ def __post_init__(self) -> None:
623612
source_file=self.source_file,
624613
parent=self,
625614
proto_obj=nested.field[0], # key
615+
typing_compiler=self.typing_compiler,
626616
).py_type
627617
self.py_v_type = FieldCompiler(
628618
source_file=self.source_file,
629619
parent=self,
630620
proto_obj=nested.field[1], # value
621+
typing_compiler=self.typing_compiler,
631622
).py_type
632623

633624
# Get proto types
@@ -645,7 +636,7 @@ def field_type(self) -> str:
645636

646637
@property
647638
def annotation(self) -> str:
648-
return f"Dict[{self.py_k_type}, {self.py_v_type}]"
639+
return self.typing_compiler.dict(self.py_k_type, self.py_v_type)
649640

650641
@property
651642
def repeated(self) -> bool:
@@ -702,7 +693,6 @@ class ServiceCompiler(ProtoContentBase):
702693
def __post_init__(self) -> None:
703694
# Add service to output file
704695
self.output_file.services.append(self)
705-
self.output_file.typing_imports.add("Dict")
706696
super().__post_init__() # check for unset fields
707697

708698
@property
@@ -725,22 +715,6 @@ def __post_init__(self) -> None:
725715
# Add method to service
726716
self.parent.methods.append(self)
727717

728-
# Check for imports
729-
if "Optional" in self.py_output_message_type:
730-
self.output_file.typing_imports.add("Optional")
731-
732-
# Check for Async imports
733-
if self.client_streaming:
734-
self.output_file.typing_imports.add("AsyncIterable")
735-
self.output_file.typing_imports.add("Iterable")
736-
self.output_file.typing_imports.add("Union")
737-
738-
# Required by both client and server
739-
if self.client_streaming or self.server_streaming:
740-
self.output_file.typing_imports.add("AsyncIterator")
741-
742-
# add imports required for request arguments timeout, deadline and metadata
743-
self.output_file.typing_imports.add("Optional")
744718
self.output_file.imports_type_checking_only.add("import grpclib.server")
745719
self.output_file.imports_type_checking_only.add(
746720
"from betterproto.grpc.grpclib_client import MetadataLike"
@@ -806,6 +780,7 @@ def py_input_message_type(self) -> str:
806780
package=self.output_file.package,
807781
imports=self.output_file.imports,
808782
source_type=self.proto_obj.input_type,
783+
typing_compiler=self.output_file.typing_compiler,
809784
unwrap=False,
810785
pydantic=self.output_file.pydantic_dataclasses,
811786
).strip('"')
@@ -835,6 +810,7 @@ def py_output_message_type(self) -> str:
835810
package=self.output_file.package,
836811
imports=self.output_file.imports,
837812
source_type=self.proto_obj.output_type,
813+
typing_compiler=self.output_file.typing_compiler,
838814
unwrap=False,
839815
pydantic=self.output_file.pydantic_dataclasses,
840816
).strip('"')

0 commit comments

Comments
 (0)