Skip to content

Commit bc3cfc5

Browse files
authored
Fix default values for enum service args #298 (#299)
1 parent b0a36d1 commit bc3cfc5

File tree

3 files changed

+26
-4
lines changed

3 files changed

+26
-4
lines changed

src/betterproto/plugin/models.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,7 @@
5959
import re
6060
import textwrap
6161
from dataclasses import dataclass, field
62-
from typing import Dict, Iterable, Iterator, List, Optional, Set, Text, Type, Union
63-
import sys
62+
from typing import Dict, Iterable, Iterator, List, Optional, Set, Type, Union
6463

6564
from ..casing import sanitize_name
6665
from ..compile.importing import get_type_reference, parse_source_type_name
@@ -460,7 +459,7 @@ def field_type(self) -> str:
460459
)
461460

462461
@property
463-
def default_value_string(self) -> Union[Text, None, float, int]:
462+
def default_value_string(self) -> str:
464463
"""Python representation of the default proto value."""
465464
if self.repeated:
466465
return "[]"
@@ -474,6 +473,14 @@ def default_value_string(self) -> Union[Text, None, float, int]:
474473
return '""'
475474
elif self.py_type == "bytes":
476475
return 'b""'
476+
elif self.field_type == "enum":
477+
enum_proto_obj_name = self.proto_obj.type_name.split(".").pop()
478+
enum = next(
479+
e
480+
for e in self.output_file.enums
481+
if e.proto_obj.name == enum_proto_obj_name
482+
)
483+
return enum.default_value_string
477484
else:
478485
# Message type
479486
return "None"

tests/inputs/service/service.proto

+7
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,16 @@ syntax = "proto3";
22

33
package service;
44

5+
enum ThingType {
6+
UNKNOWN = 0;
7+
LIVING = 1;
8+
DEAD = 2;
9+
}
10+
511
message DoThingRequest {
612
string name = 1;
713
repeated string comments = 2;
14+
ThingType type = 3;
815
}
916

1017
message DoThingResponse {

tests/test_features.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import betterproto
22
from dataclasses import dataclass
33
from typing import Optional, List, Dict
4-
from datetime import datetime, timedelta
4+
from datetime import datetime
5+
from inspect import signature
56

67

78
def test_has_field():
@@ -476,3 +477,10 @@ class Envelope(betterproto.Message):
476477

477478
msg.from_dict({"timestamps": iso_candidates})
478479
assert all([isinstance(item, datetime) for item in msg.timestamps])
480+
481+
482+
def test_enum_service_argument__expected_default_value():
483+
from tests.output_betterproto.service.service import ThingType, TestStub
484+
485+
sig = signature(TestStub.do_thing)
486+
assert sig.parameters["type"].default == ThingType.UNKNOWN

0 commit comments

Comments
 (0)