Skip to content

Commit b0a36d1

Browse files
guyszguysz-nvidia
andauthored
Fix compilation of fields with name identical to their type (#294)
* Revert "Fix compilation of fields named 'bytes' or 'str' (#226)" This reverts commit deb623e. * Fix compilation of fileds with name identical to their type * Added test for field-name identical to python type Co-authored-by: Guy Szweigman <[email protected]>
1 parent a4d2d39 commit b0a36d1

File tree

4 files changed

+38
-13
lines changed

4 files changed

+38
-13
lines changed

src/betterproto/casing.py

+1-11
Original file line numberDiff line numberDiff line change
@@ -133,16 +133,6 @@ def lowercase_first(value: str) -> str:
133133
return value[0:1].lower() + value[1:]
134134

135135

136-
def is_reserved_name(value: str) -> bool:
137-
if keyword.iskeyword(value):
138-
return True
139-
140-
if value in ("bytes", "str"):
141-
return True
142-
143-
return False
144-
145-
146136
def sanitize_name(value: str) -> str:
147137
# https://www.python.org/dev/peps/pep-0008/#descriptive-naming-styles
148-
return f"{value}_" if is_reserved_name(value) else value
138+
return f"{value}_" if keyword.iskeyword(value) else value

src/betterproto/plugin/models.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
"""
3131

3232

33+
import builtins
3334
import betterproto
3435
from betterproto import which_one_of
3536
from betterproto.casing import sanitize_name
@@ -237,6 +238,7 @@ class OutputTemplate:
237238
imports: Set[str] = field(default_factory=set)
238239
datetime_imports: Set[str] = field(default_factory=set)
239240
typing_imports: Set[str] = field(default_factory=set)
241+
builtins_import: bool = False
240242
messages: List["MessageCompiler"] = field(default_factory=list)
241243
enums: List["EnumDefinitionCompiler"] = field(default_factory=list)
242244
services: List["ServiceCompiler"] = field(default_factory=list)
@@ -268,6 +270,8 @@ def python_module_imports(self) -> Set[str]:
268270
imports = set()
269271
if any(x for x in self.messages if any(x.deprecated_fields)):
270272
imports.add("warnings")
273+
if self.builtins_import:
274+
imports.add("builtins")
271275
return imports
272276

273277

@@ -283,6 +287,7 @@ class MessageCompiler(ProtoContentBase):
283287
default_factory=list
284288
)
285289
deprecated: bool = field(default=False, init=False)
290+
builtins_types: Set[str] = field(default_factory=set)
286291

287292
def __post_init__(self) -> None:
288293
# Add message to output file
@@ -376,6 +381,8 @@ def get_field_string(self, indent: int = 4) -> str:
376381
betterproto_field_type = (
377382
f"betterproto.{self.field_type}_field({self.proto_obj.number}{field_args})"
378383
)
384+
if self.py_name in dir(builtins):
385+
self.parent.builtins_types.add(self.py_name)
379386
return f"{name}{annotations} = {betterproto_field_type}"
380387

381388
@property
@@ -408,9 +415,16 @@ def typing_imports(self) -> Set[str]:
408415
imports.add("Dict")
409416
return imports
410417

418+
@property
419+
def use_builtins(self) -> bool:
420+
return self.py_type in self.parent.builtins_types or (
421+
self.py_type == self.py_name and self.py_name in dir(builtins)
422+
)
423+
411424
def add_imports_to(self, output_file: OutputTemplate) -> None:
412425
output_file.datetime_imports.update(self.datetime_imports)
413426
output_file.typing_imports.update(self.typing_imports)
427+
output_file.builtins_import = output_file.builtins_import or self.use_builtins
414428

415429
@property
416430
def field_wraps(self) -> Optional[str]:
@@ -504,9 +518,12 @@ def py_type(self) -> str:
504518

505519
@property
506520
def annotation(self) -> str:
521+
py_type = self.py_type
522+
if self.use_builtins:
523+
py_type = f"builtins.{py_type}"
507524
if self.repeated:
508-
return f"List[{self.py_type}]"
509-
return self.py_type
525+
return f"List[{py_type}]"
526+
return py_type
510527

511528

512529
@dataclass
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
{
2+
"int": 26,
3+
"float": 26.0,
4+
"str": "value-for-str",
5+
"bytes": "001a",
6+
"bool": true
7+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
syntax = "proto3";
2+
3+
// Tests that messages may contain fields with names that are identical to their python types (PR #294)
4+
5+
message Test {
6+
int32 int = 1;
7+
float float = 2;
8+
string str = 3;
9+
bytes bytes = 4;
10+
bool bool = 5;
11+
}

0 commit comments

Comments
 (0)