Skip to content

Commit 6784767

Browse files
committed
Implement proto3 field presence
1 parent 29d8947 commit 6784767

File tree

6 files changed

+80
-43
lines changed

6 files changed

+80
-43
lines changed

src/betterproto/__init__.py

+41-36
Original file line numberDiff line numberDiff line change
@@ -159,10 +159,11 @@ def dataclass_field(
159159
map_types: Optional[Tuple[str, str]] = None,
160160
group: Optional[str] = None,
161161
wraps: Optional[str] = None,
162+
optional: bool = False,
162163
) -> dataclasses.Field:
163164
"""Creates a dataclass field with attached protobuf metadata."""
164165
return dataclasses.field(
165-
default=PLACEHOLDER,
166+
default=None if optional else PLACEHOLDER,
166167
metadata={
167168
"betterproto": FieldMetadata(number, proto_type, map_types, group, wraps)
168169
},
@@ -174,74 +175,74 @@ def dataclass_field(
174175
# out at runtime. The generated dataclass variables are still typed correctly.
175176

176177

177-
def enum_field(number: int, group: Optional[str] = None) -> Any:
178-
return dataclass_field(number, TYPE_ENUM, group=group)
178+
def enum_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
179+
return dataclass_field(number, TYPE_ENUM, group=group, optional=optional)
179180

180181

181-
def bool_field(number: int, group: Optional[str] = None) -> Any:
182-
return dataclass_field(number, TYPE_BOOL, group=group)
182+
def bool_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
183+
return dataclass_field(number, TYPE_BOOL, group=group, optional=optional)
183184

184185

185-
def int32_field(number: int, group: Optional[str] = None) -> Any:
186-
return dataclass_field(number, TYPE_INT32, group=group)
186+
def int32_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
187+
return dataclass_field(number, TYPE_INT32, group=group, optional=optional)
187188

188189

189-
def int64_field(number: int, group: Optional[str] = None) -> Any:
190-
return dataclass_field(number, TYPE_INT64, group=group)
190+
def int64_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
191+
return dataclass_field(number, TYPE_INT64, group=group, optional=optional)
191192

192193

193-
def uint32_field(number: int, group: Optional[str] = None) -> Any:
194-
return dataclass_field(number, TYPE_UINT32, group=group)
194+
def uint32_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
195+
return dataclass_field(number, TYPE_UINT32, group=group, optional=optional)
195196

196197

197-
def uint64_field(number: int, group: Optional[str] = None) -> Any:
198-
return dataclass_field(number, TYPE_UINT64, group=group)
198+
def uint64_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
199+
return dataclass_field(number, TYPE_UINT64, group=group, optional=optional)
199200

200201

201-
def sint32_field(number: int, group: Optional[str] = None) -> Any:
202-
return dataclass_field(number, TYPE_SINT32, group=group)
202+
def sint32_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
203+
return dataclass_field(number, TYPE_SINT32, group=group, optional=optional)
203204

204205

205-
def sint64_field(number: int, group: Optional[str] = None) -> Any:
206-
return dataclass_field(number, TYPE_SINT64, group=group)
206+
def sint64_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
207+
return dataclass_field(number, TYPE_SINT64, group=group, optional=optional)
207208

208209

209-
def float_field(number: int, group: Optional[str] = None) -> Any:
210-
return dataclass_field(number, TYPE_FLOAT, group=group)
210+
def float_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
211+
return dataclass_field(number, TYPE_FLOAT, group=group, optional=optional)
211212

212213

213-
def double_field(number: int, group: Optional[str] = None) -> Any:
214-
return dataclass_field(number, TYPE_DOUBLE, group=group)
214+
def double_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
215+
return dataclass_field(number, TYPE_DOUBLE, group=group, optional=optional)
215216

216217

217-
def fixed32_field(number: int, group: Optional[str] = None) -> Any:
218-
return dataclass_field(number, TYPE_FIXED32, group=group)
218+
def fixed32_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
219+
return dataclass_field(number, TYPE_FIXED32, group=group, optional=optional)
219220

220221

221-
def fixed64_field(number: int, group: Optional[str] = None) -> Any:
222-
return dataclass_field(number, TYPE_FIXED64, group=group)
222+
def fixed64_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
223+
return dataclass_field(number, TYPE_FIXED64, group=group, optional=optional)
223224

224225

225-
def sfixed32_field(number: int, group: Optional[str] = None) -> Any:
226-
return dataclass_field(number, TYPE_SFIXED32, group=group)
226+
def sfixed32_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
227+
return dataclass_field(number, TYPE_SFIXED32, group=group, optional=optional)
227228

228229

229-
def sfixed64_field(number: int, group: Optional[str] = None) -> Any:
230-
return dataclass_field(number, TYPE_SFIXED64, group=group)
230+
def sfixed64_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
231+
return dataclass_field(number, TYPE_SFIXED64, group=group, optional=optional)
231232

232233

233-
def string_field(number: int, group: Optional[str] = None) -> Any:
234-
return dataclass_field(number, TYPE_STRING, group=group)
234+
def string_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
235+
return dataclass_field(number, TYPE_STRING, group=group, optional=optional)
235236

236237

237-
def bytes_field(number: int, group: Optional[str] = None) -> Any:
238-
return dataclass_field(number, TYPE_BYTES, group=group)
238+
def bytes_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
239+
return dataclass_field(number, TYPE_BYTES, group=group, optional=optional)
239240

240241

241242
def message_field(
242-
number: int, group: Optional[str] = None, wraps: Optional[str] = None
243+
number: int, group: Optional[str] = None, wraps: Optional[str] = None, optional: bool = False
243244
) -> Any:
244-
return dataclass_field(number, TYPE_MESSAGE, group=group, wraps=wraps)
245+
return dataclass_field(number, TYPE_MESSAGE, group=group, wraps=wraps, optional=optional)
245246

246247

247248
def map_field(
@@ -701,12 +702,16 @@ def __bytes__(self) -> bytes:
701702

702703
if value is None:
703704
# Optional items should be skipped. This is used for the Google
704-
# wrapper types.
705+
# wrapper types and proto3 field presence/optional fields.
705706
continue
706707

707708
# Being selected in a a group means this field is the one that is
708709
# currently set in a `oneof` group, so it must be serialized even
709710
# if the value is the default zero value.
711+
#
712+
# Note that proto3 field presence/optional fields are put in a
713+
# synthetic single-item oneof by protoc, which helps us ensure we
714+
# send the value even if the value is the default zero value.
710715
selected_in_group = (
711716
meta.group and self._group_current[meta.group] == field_name
712717
)

src/betterproto/plugin/main.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,8 @@ def main() -> None:
2828
if dump_file:
2929
dump_request(dump_file, request)
3030

31-
# Create response
32-
response = CodeGeneratorResponse()
33-
3431
# Generate code
35-
generate_code(request, response)
32+
response = generate_code(request)
3633

3734
# Serialise response message
3835
output = response.SerializeToString()

src/betterproto/plugin/models.py

+12
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,8 @@ def betterproto_field_args(self) -> List[str]:
383383
args = []
384384
if self.field_wraps:
385385
args.append(f"wraps={self.field_wraps}")
386+
if self.optional:
387+
args.append(f"optional=True")
386388
return args
387389

388390
@property
@@ -431,6 +433,12 @@ def repeated(self) -> bool:
431433
and not is_map(self.proto_obj, self.parent)
432434
)
433435

436+
@property
437+
def optional(self) -> bool:
438+
# TODO: Should proto2 optional fields with kind=Message also be
439+
# considered Optional.
440+
return self.proto_obj.proto3_optional
441+
434442
@property
435443
def mutable(self) -> bool:
436444
"""True if the field is a mutable type, otherwise False."""
@@ -450,6 +458,8 @@ def default_value_string(self) -> Union[Text, None, float, int]:
450458
"""Python representation of the default proto value."""
451459
if self.repeated:
452460
return "[]"
461+
if self.optional:
462+
return "None"
453463
if self.py_type == "int":
454464
return "0"
455465
if self.py_type == "float":
@@ -506,6 +516,8 @@ def py_type(self) -> str:
506516
def annotation(self) -> str:
507517
if self.repeated:
508518
return f"List[{self.py_type}]"
519+
if self.optional:
520+
return f"Optional[{self.py_type}]"
509521
return self.py_type
510522

511523

src/betterproto/plugin/parser.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from betterproto.lib.google.protobuf.compiler import (
99
CodeGeneratorRequest,
1010
CodeGeneratorResponse,
11+
CodeGeneratorResponseFeature,
1112
CodeGeneratorResponseFile,
1213
)
1314
import itertools
@@ -60,10 +61,11 @@ def _traverse(
6061
)
6162

6263

63-
def generate_code(
64-
request: CodeGeneratorRequest, response: CodeGeneratorResponse
65-
) -> None:
64+
def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
65+
response = CodeGeneratorResponse()
66+
6667
plugin_options = request.parameter.split(",") if request.parameter else []
68+
response.supported_features = CodeGeneratorResponseFeature.FEATURE_PROTO3_OPTIONAL
6769

6870
request_data = PluginRequestCompiler(plugin_request_obj=request)
6971
# Gather output packages
@@ -133,6 +135,7 @@ def generate_code(
133135
for output_package_name in sorted(output_paths.union(init_files)):
134136
print(f"Writing {output_package_name}", file=sys.stderr)
135137

138+
return response
136139

137140
def read_protobuf_type(
138141
item: DescriptorProto,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
{
2+
"test1": null,
3+
"test2": null,
4+
"test3": null,
5+
"test4": null,
6+
"test5": null
7+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
syntax = "proto3";
2+
3+
message InnerTest {
4+
string test = 1;
5+
}
6+
7+
message Test {
8+
optional uint32 test1 = 1;
9+
optional bool test2 = 2;
10+
optional string test3 = 3;
11+
optional bytes test4 = 4;
12+
optional InnerTest test5 = 5;
13+
}

0 commit comments

Comments
 (0)