Skip to content

Commit 8f0caf1

Browse files
Read desired wrapper type directly from wrapper definition
1 parent c50d9e2 commit 8f0caf1

File tree

1 file changed

+21
-31
lines changed

1 file changed

+21
-31
lines changed

betterproto/plugin.py

+21-31
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
#!/usr/bin/env python
22

33
import itertools
4-
import json
54
import os.path
6-
import re
75
import sys
86
import textwrap
9-
from typing import Any, List, Tuple
7+
from collections import defaultdict
8+
from typing import Dict, List, Optional, Type
109

1110
try:
1211
import black
@@ -24,33 +23,23 @@
2423
DescriptorProto,
2524
EnumDescriptorProto,
2625
FieldDescriptorProto,
27-
FileDescriptorProto,
28-
ServiceDescriptorProto,
2926
)
3027

3128
from betterproto.casing import safe_snake_case
3229

33-
import google.protobuf.wrappers_pb2
30+
import google.protobuf.wrappers_pb2 as google_wrappers
3431

35-
36-
WRAPPER_TYPES = {
37-
google.protobuf.wrappers_pb2.DoubleValue: "float",
38-
google.protobuf.wrappers_pb2.FloatValue: "float",
39-
google.protobuf.wrappers_pb2.Int64Value: "int",
40-
google.protobuf.wrappers_pb2.UInt64Value: "int",
41-
google.protobuf.wrappers_pb2.Int32Value: "int",
42-
google.protobuf.wrappers_pb2.UInt32Value: "int",
43-
google.protobuf.wrappers_pb2.BoolValue: "bool",
44-
google.protobuf.wrappers_pb2.StringValue: "str",
45-
google.protobuf.wrappers_pb2.BytesValue: "bytes",
46-
}
47-
48-
49-
def get_wrapper_type(type_name: str) -> (Any, str):
50-
for wrapper, wrapped_type in WRAPPER_TYPES.items():
51-
if wrapper.DESCRIPTOR.full_name == type_name:
52-
return wrapper, wrapped_type
53-
return None, None
32+
WRAPPER_TYPES: Dict[str, Optional[Type]] = defaultdict(lambda: None, {
33+
'google.protobuf.DoubleValue': google_wrappers.DoubleValue,
34+
'google.protobuf.FloatValue': google_wrappers.FloatValue,
35+
'google.protobuf.Int64Value': google_wrappers.Int64Value,
36+
'google.protobuf.UInt64Value': google_wrappers.UInt64Value,
37+
'google.protobuf.Int32Value': google_wrappers.Int32Value,
38+
'google.protobuf.UInt32Value': google_wrappers.UInt32Value,
39+
'google.protobuf.BoolValue': google_wrappers.BoolValue,
40+
'google.protobuf.StringValue': google_wrappers.StringValue,
41+
'google.protobuf.BytesValue': google_wrappers.BytesValue,
42+
})
5443

5544

5645
def get_ref_type(package: str, imports: set, type_name: str, unwrap: bool = True) -> str:
@@ -64,20 +53,21 @@ def get_ref_type(package: str, imports: set, type_name: str, unwrap: bool = True
6453
type_name = type_name.lstrip(".")
6554

6655
# Check if type is wrapper.
67-
wrapper, wrapped_type = get_wrapper_type(type_name)
56+
wrapper_class = WRAPPER_TYPES[type_name]
6857

6958
if unwrap:
70-
if wrapper:
71-
return f"Optional[{wrapped_type}]"
59+
if wrapper_class:
60+
wrapped_type = type(wrapper_class().value)
61+
return f"Optional[{wrapped_type.__name__}]"
7262

7363
if type_name == "google.protobuf.Duration":
7464
return "timedelta"
7565

7666
if type_name == "google.protobuf.Timestamp":
7767
return "datetime"
78-
elif wrapper:
79-
imports.add(f"from {wrapper.__module__} import {wrapper.__name__}")
80-
return f"{wrapper.__name__}"
68+
elif wrapper_class:
69+
imports.add(f"from {wrapper_class.__module__} import {wrapper_class.__name__}")
70+
return f"{wrapper_class.__name__}"
8171

8272
if type_name.startswith(package):
8373
parts = type_name.lstrip(package).lstrip(".").split(".")

0 commit comments

Comments
 (0)