Skip to content

Commit 13d6565

Browse files
authored
Add support for pydantic dataclasses (#406)
1 parent 6df8cef commit 13d6565

File tree

11 files changed

+283
-19
lines changed

11 files changed

+283
-19
lines changed

README.md

+20
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ This project aims to provide an improved experience when using Protobuf / gRPC i
1414
- Timezone-aware `datetime` and `timedelta` objects
1515
- Relative imports
1616
- Mypy type checking
17+
- [Pydantic Models](https://docs.pydantic.dev/) generation (see #generating-pydantic-models)
1718

1819
This project is heavily inspired by, and borrows functionality from:
1920

@@ -364,6 +365,25 @@ datetime.datetime(2019, 1, 1, 11, 59, 58, 800000, tzinfo=datetime.timezone.utc)
364365
{'ts': '2019-01-01T12:00:00Z', 'duration': '1.200s'}
365366
```
366367

368+
## Generating Pydantic Models
369+
370+
You can use python-betterproto to generate pydantic based models, using
371+
pydantic dataclasses. This means the results of the protobuf unmarshalling will
372+
be typed checked. The usage is the same, but you need to add a custom option
373+
when calling the protobuf compiler:
374+
375+
376+
```
377+
protoc -I . --custom_opt=pydantic_dataclasses --python_betterproto_out=lib example.proto
378+
```
379+
380+
With the important change being `--custom_opt=pydantic_dataclasses`. This will
381+
swap the dataclass implementation from the builtin python dataclass to the
382+
pydantic dataclass. You must have pydantic as a dependency in your project for
383+
this to work.
384+
385+
386+
367387
## Development
368388

369389
- _Join us on [Slack](https://join.slack.com/t/betterproto/shared_invite/zt-f0n0uolx-iN8gBNrkPxtKHTLpG3o1OQ)!_

poetry.lock

+61-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ sphinx-rtd-theme = "0.5.0"
3636
tomlkit = "^0.7.0"
3737
tox = "^3.15.1"
3838
pre-commit = "^2.17.0"
39+
pydantic = ">=1.8.0"
3940

4041

4142
[tool.poetry.scripts]

src/betterproto/__init__.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -628,7 +628,6 @@ def __post_init__(self) -> None:
628628
# Set current field of each group after `__init__` has already been run.
629629
group_current: Dict[str, Optional[str]] = {}
630630
for field_name, meta in self._betterproto.meta_by_field_name.items():
631-
632631
if meta.group:
633632
group_current.setdefault(meta.group)
634633

@@ -1470,6 +1469,24 @@ def is_set(self, name: str) -> bool:
14701469
)
14711470
return self.__raw_get(name) is not default
14721471

1472+
@classmethod
1473+
def _validate_field_groups(cls, values):
1474+
meta = cls._betterproto_meta.oneof_field_by_group # type: ignore
1475+
1476+
for group, field_set in meta.items():
1477+
set_fields = [
1478+
field.name for field in field_set if values[field.name] is not None
1479+
]
1480+
if not set_fields:
1481+
raise ValueError(f"Group {group} has no value; all fields are None")
1482+
elif len(set_fields) > 1:
1483+
set_fields_str = ", ".join(set_fields)
1484+
raise ValueError(
1485+
f"Group {group} has more than one value; fields {set_fields_str} are not None"
1486+
)
1487+
1488+
return values
1489+
14731490

14741491
def serialized_on_wire(message: Message) -> bool:
14751492
"""

src/betterproto/plugin/models.py

+35-2
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,6 @@ def comment(self) -> str:
214214

215215
@dataclass
216216
class PluginRequestCompiler:
217-
218217
plugin_request_obj: CodeGeneratorRequest
219218
output_packages: Dict[str, "OutputTemplate"] = field(default_factory=dict)
220219

@@ -247,11 +246,13 @@ class OutputTemplate:
247246
imports: Set[str] = field(default_factory=set)
248247
datetime_imports: Set[str] = field(default_factory=set)
249248
typing_imports: Set[str] = field(default_factory=set)
249+
pydantic_imports: Set[str] = field(default_factory=set)
250250
builtins_import: bool = False
251251
messages: List["MessageCompiler"] = field(default_factory=list)
252252
enums: List["EnumDefinitionCompiler"] = field(default_factory=list)
253253
services: List["ServiceCompiler"] = field(default_factory=list)
254254
imports_type_checking_only: Set[str] = field(default_factory=set)
255+
pydantic_dataclasses: bool = False
255256
output: bool = True
256257

257258
@property
@@ -334,6 +335,20 @@ def deprecated_fields(self) -> Iterator[str]:
334335
def has_deprecated_fields(self) -> bool:
335336
return any(self.deprecated_fields)
336337

338+
@property
339+
def has_oneof_fields(self) -> bool:
340+
return any(isinstance(field, OneOfFieldCompiler) for field in self.fields)
341+
342+
@property
343+
def has_message_field(self) -> bool:
344+
return any(
345+
(
346+
field.proto_obj.type in PROTO_MESSAGE_TYPES
347+
for field in self.fields
348+
if isinstance(field.proto_obj, FieldDescriptorProto)
349+
)
350+
)
351+
337352

338353
def is_map(
339354
proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProto
@@ -431,6 +446,10 @@ def typing_imports(self) -> Set[str]:
431446
imports.add("Dict")
432447
return imports
433448

449+
@property
450+
def pydantic_imports(self) -> Set[str]:
451+
return set()
452+
434453
@property
435454
def use_builtins(self) -> bool:
436455
return self.py_type in self.parent.builtins_types or (
@@ -440,6 +459,7 @@ def use_builtins(self) -> bool:
440459
def add_imports_to(self, output_file: OutputTemplate) -> None:
441460
output_file.datetime_imports.update(self.datetime_imports)
442461
output_file.typing_imports.update(self.typing_imports)
462+
output_file.pydantic_imports.update(self.pydantic_imports)
443463
output_file.builtins_import = output_file.builtins_import or self.use_builtins
444464

445465
@property
@@ -568,6 +588,20 @@ def betterproto_field_args(self) -> List[str]:
568588
return args
569589

570590

591+
@dataclass
592+
class PydanticOneOfFieldCompiler(OneOfFieldCompiler):
593+
@property
594+
def optional(self) -> bool:
595+
# Force the optional to be True. This will allow the pydantic dataclass
596+
# to validate the object correctly by allowing the field to be let empty.
597+
# We add a pydantic validator later to ensure exactly one field is defined.
598+
return True
599+
600+
@property
601+
def pydantic_imports(self) -> Set[str]:
602+
return {"root_validator"}
603+
604+
571605
@dataclass
572606
class MapEntryCompiler(FieldCompiler):
573607
py_k_type: Type = PLACEHOLDER
@@ -679,7 +713,6 @@ def py_name(self) -> str:
679713

680714
@dataclass
681715
class ServiceMethodCompiler(ProtoContentBase):
682-
683716
parent: ServiceCompiler
684717
proto_obj: MethodDescriptorProto
685718
path: List[int] = PLACEHOLDER

src/betterproto/plugin/parser.py

+27-5
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from betterproto.lib.google.protobuf import (
1212
DescriptorProto,
1313
EnumDescriptorProto,
14+
FieldDescriptorProto,
1415
FileDescriptorProto,
1516
ServiceDescriptorProto,
1617
)
@@ -30,6 +31,7 @@
3031
OneOfFieldCompiler,
3132
OutputTemplate,
3233
PluginRequestCompiler,
34+
PydanticOneOfFieldCompiler,
3335
ServiceCompiler,
3436
ServiceMethodCompiler,
3537
is_map,
@@ -91,6 +93,11 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
9193
# skip outputting Google's well-known types
9294
request_data.output_packages[output_package_name].output = False
9395

96+
if "pydantic_dataclasses" in plugin_options:
97+
request_data.output_packages[
98+
output_package_name
99+
].pydantic_dataclasses = True
100+
94101
# Read Messages and Enums
95102
# We need to read Messages before Services in so that we can
96103
# get the references to input/output messages for each service
@@ -145,6 +152,24 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
145152
return response
146153

147154

155+
def _make_one_of_field_compiler(
156+
output_package: OutputTemplate,
157+
source_file: "FileDescriptorProto",
158+
parent: MessageCompiler,
159+
proto_obj: "FieldDescriptorProto",
160+
path: List[int],
161+
) -> FieldCompiler:
162+
163+
pydantic = output_package.pydantic_dataclasses
164+
Cls = PydanticOneOfFieldCompiler if pydantic else OneOfFieldCompiler
165+
return Cls(
166+
source_file=source_file,
167+
parent=parent,
168+
proto_obj=proto_obj,
169+
path=path,
170+
)
171+
172+
148173
def read_protobuf_type(
149174
item: DescriptorProto,
150175
path: List[int],
@@ -168,11 +193,8 @@ def read_protobuf_type(
168193
path=path + [2, index],
169194
)
170195
elif is_oneof(field):
171-
OneOfFieldCompiler(
172-
source_file=source_file,
173-
parent=message_data,
174-
proto_obj=field,
175-
path=path + [2, index],
196+
_make_one_of_field_compiler(
197+
output_package, source_file, message_data, field, path + [2, index]
176198
)
177199
else:
178200
FieldCompiler(

src/betterproto/templates/template.py.j2

+24
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,13 @@
55
{% for i in output_file.python_module_imports|sort %}
66
import {{ i }}
77
{% endfor %}
8+
9+
{% if output_file.pydantic_dataclasses %}
10+
from pydantic.dataclasses import dataclass
11+
{%- else -%}
812
from dataclasses import dataclass
13+
{% endif %}
14+
915
{% if output_file.datetime_imports %}
1016
from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
1117

@@ -15,6 +21,11 @@ from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if no
1521

1622
{% endif %}
1723

24+
{% if output_file.pydantic_imports %}
25+
from pydantic import {% for i in output_file.pydantic_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
26+
27+
{% endif %}
28+
1829
import betterproto
1930
{% if output_file.services %}
2031
from betterproto.grpc.grpclib_server import ServiceBase
@@ -80,6 +91,11 @@ class {{ message.py_name }}(betterproto.Message):
8091
{% endfor %}
8192
{% endif %}
8293

94+
{% if output_file.pydantic_dataclasses and message.has_oneof_fields %}
95+
@root_validator()
96+
def check_oneof(cls, values):
97+
return cls._validate_field_groups(values)
98+
{% endif %}
8399

84100
{% endfor %}
85101
{% for service in output_file.services %}
@@ -226,3 +242,11 @@ class {{ service.py_name }}Base(ServiceBase):
226242
}
227243

228244
{% endfor %}
245+
246+
{% if output_file.pydantic_dataclasses %}
247+
{% for message in output_file.messages %}
248+
{% if message.has_message_field %}
249+
{{ message.py_name }}.__pydantic_model__.update_forward_refs() # type: ignore
250+
{% endif %}
251+
{% endfor %}
252+
{% endif %}

0 commit comments

Comments
 (0)