Skip to content

Commit b6008af

Browse files
committed
Figuring out nested enum collissions - #212
1 parent 02aa4e8 commit b6008af

File tree

4 files changed

+129
-27
lines changed

4 files changed

+129
-27
lines changed

src/betterproto/plugin/models.py

+46-4
Original file line numberDiff line numberDiff line change
@@ -248,13 +248,55 @@ class OutputTemplate:
248248
typing_imports: Set[str] = field(default_factory=set)
249249
pydantic_imports: Set[str] = field(default_factory=set)
250250
builtins_import: bool = False
251-
messages: List["MessageCompiler"] = field(default_factory=list)
252-
enums: List["EnumDefinitionCompiler"] = field(default_factory=list)
251+
messages: Dict[str, "MessageCompiler"] = field(default_factory=dict)
252+
enums: Dict[str, "EnumDefinitionCompiler"] = field(default_factory=dict)
253253
services: List["ServiceCompiler"] = field(default_factory=list)
254254
imports_type_checking_only: Set[str] = field(default_factory=set)
255255
pydantic_dataclasses: bool = False
256256
output: bool = True
257257

258+
def structure(self):
259+
def recursive_structure(descriptor_proto):
260+
branch = {}
261+
for msg in descriptor_proto.nested_type:
262+
branch[msg.name] = {"children": recursive_structure(msg), "kind": "msg"}
263+
for enum_ in descriptor_proto.enum_type:
264+
branch[enum_.name] = {"kind": "enum"}
265+
return branch
266+
267+
tree = {}
268+
for msg in self.package_proto_obj.message_type:
269+
tree[msg.name] = {"children": recursive_structure(msg), "kind": "msg"}
270+
for enum_ in self.package_proto_obj.enum_type:
271+
tree[enum_.name] = {"kind": "enum"}
272+
273+
return {"root": tree}
274+
275+
def structure_with_obj(self):
276+
def recursive_structure(descriptor_proto):
277+
branch = {}
278+
for msg in descriptor_proto.nested_type:
279+
branch[msg.name] = {
280+
"children": recursive_structure(msg),
281+
"kind": "msg",
282+
"obj": self.messages[msg.name],
283+
}
284+
for enum_ in descriptor_proto.enum_type:
285+
branch[enum_.name] = {"kind": "enum", "obj": self.enums[enum_.name]}
286+
return branch
287+
288+
tree = {}
289+
for msg in self.package_proto_obj.message_type:
290+
tree[msg.name] = {
291+
"children": recursive_structure(msg),
292+
"kind": "msg",
293+
"obj": self.messages[msg.name],
294+
}
295+
for enum_ in self.package_proto_obj.enum_type:
296+
tree[enum_.name] = {"kind": "enum", "obj": self.enums[enum_.name]}
297+
298+
return {"root": tree}
299+
258300
@property
259301
def package(self) -> str:
260302
"""Name of input package.
@@ -305,9 +347,9 @@ def __post_init__(self) -> None:
305347
# Add message to output file
306348
if isinstance(self.parent, OutputTemplate):
307349
if isinstance(self, EnumDefinitionCompiler):
308-
self.output_file.enums.append(self)
350+
self.output_file.enums[self.proto_name] = self
309351
else:
310-
self.output_file.messages.append(self)
352+
self.output_file.messages[self.proto_name] = self
311353
self.deprecated = self.proto_obj.options.deprecated
312354
super().__post_init__()
313355

src/betterproto/templates/template.py.j2

+35-23
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,23 @@ if TYPE_CHECKING:
4848
{% endfor %}
4949
{% endif %}
5050

51-
{% if output_file.enums %}{% for enum in output_file.enums %}
51+
{#
52+
53+
54+
original markdown is missing until I finish debugging this
55+
56+
57+
#}
58+
59+
{% set tree = output_file.structure() %}
60+
"""
61+
{{ tree|tojson(indent=4) }}
62+
"""
63+
64+
{% macro render_class(tree_) %}
65+
{% for k, v in tree_.items() %}
66+
{% if v.kind == "enum" %}
67+
{% set enum = v.obj %}
5268
class {{ enum.py_name }}(betterproto.Enum):
5369
{% if enum.comment %}
5470
{{ enum.comment }}
@@ -61,17 +77,17 @@ class {{ enum.py_name }}(betterproto.Enum):
6177

6278
{% endif %}
6379
{% endfor %}
64-
65-
66-
{% endfor %}
67-
{% endif %}
68-
{% for message in output_file.messages %}
80+
{% else %}
81+
{% set message = v.obj %}
6982
@dataclass(eq=False, repr=False)
7083
class {{ message.py_name }}(betterproto.Message):
7184
{% if message.comment %}
7285
{{ message.comment }}
7386

7487
{% endif %}
88+
89+
{{ render_class(v.children)|indent }}
90+
7591
{% for field in message.fields %}
7692
{{ field.get_field_string() }}
7793
{% if field.comment %}
@@ -83,25 +99,21 @@ class {{ message.py_name }}(betterproto.Message):
8399
pass
84100
{% endif %}
85101

86-
{% if message.deprecated or message.has_deprecated_fields %}
87-
def __post_init__(self) -> None:
88-
{% if message.deprecated %}
89-
warnings.warn("{{ message.py_name }} is deprecated", DeprecationWarning)
90-
{% endif %}
91-
super().__post_init__()
92-
{% for field in message.deprecated_fields %}
93-
if self.is_set("{{ field }}"):
94-
warnings.warn("{{ message.py_name }}.{{ field }} is deprecated", DeprecationWarning)
95-
{% endfor %}
96-
{% endif %}
102+
{% endif %}
103+
{% endfor %}
104+
{% endmacro %}
97105

98-
{% if output_file.pydantic_dataclasses and message.has_oneof_fields %}
99-
@root_validator()
100-
def check_oneof(cls, values):
101-
return cls._validate_field_groups(values)
102-
{% endif %}
106+
{% set tree_with_objs = output_file.structure_with_obj() %}
107+
{{ render_class(tree_with_objs["root"]) }}
108+
109+
{#
110+
111+
112+
original markdown is missing until I finish debugging this
113+
114+
115+
#}
103116

104-
{% endfor %}
105117
{% for service in output_file.services %}
106118
class {{ service.py_name }}Stub(betterproto.ServiceStub):
107119
{% if service.comment %}
+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
syntax = "proto3";
2+
3+
package nested_enum;
4+
5+
message Test {
6+
enum Inner {
7+
NONE = 0;
8+
THIS = 1;
9+
}
10+
Inner status = 1;
11+
12+
message Doubly {
13+
enum Inner {
14+
NONE = 0;
15+
THIS = 1;
16+
}
17+
Inner status = 1;
18+
}
19+
}
20+
21+
22+
message TestInner {
23+
int32 foo = 1;
24+
}
25+
26+
message TestDoublyInner {
27+
int32 foo = 1;
28+
string bar = 2;
29+
}
30+
31+
enum Outer {
32+
foo = 0;
33+
bar = 1;
34+
}
35+
36+
message Content {
37+
message Status {
38+
string code = 1;
39+
}
40+
Status status = 1;
41+
}
42+
43+
message ContentStatus {
44+
int32 id = 1;
45+
}

tests/test_nested_enums.py

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
import betterproto
2+
from dataclasses import dataclass
3+

0 commit comments

Comments
 (0)