1
1
#!/usr/bin/env python
2
2
3
3
import itertools
4
- import json
5
4
import os .path
6
- import re
7
5
import sys
8
6
import textwrap
9
- from typing import Any , List , Tuple
7
+ from collections import defaultdict
8
+ from typing import Dict , List , Optional , Type
10
9
11
10
try :
12
11
import black
24
23
DescriptorProto ,
25
24
EnumDescriptorProto ,
26
25
FieldDescriptorProto ,
27
- FileDescriptorProto ,
28
- ServiceDescriptorProto ,
29
26
)
30
27
31
28
from betterproto .casing import safe_snake_case
32
29
33
- import google .protobuf .wrappers_pb2
30
+ import google .protobuf .wrappers_pb2 as google_wrappers
34
31
35
-
36
- WRAPPER_TYPES = {
37
- google .protobuf .wrappers_pb2 .DoubleValue : "float" ,
38
- google .protobuf .wrappers_pb2 .FloatValue : "float" ,
39
- google .protobuf .wrappers_pb2 .Int64Value : "int" ,
40
- google .protobuf .wrappers_pb2 .UInt64Value : "int" ,
41
- google .protobuf .wrappers_pb2 .Int32Value : "int" ,
42
- google .protobuf .wrappers_pb2 .UInt32Value : "int" ,
43
- google .protobuf .wrappers_pb2 .BoolValue : "bool" ,
44
- google .protobuf .wrappers_pb2 .StringValue : "str" ,
45
- google .protobuf .wrappers_pb2 .BytesValue : "bytes" ,
46
- }
47
-
48
-
49
- def get_wrapper_type (type_name : str ) -> (Any , str ):
50
- for wrapper , wrapped_type in WRAPPER_TYPES .items ():
51
- if wrapper .DESCRIPTOR .full_name == type_name :
52
- return wrapper , wrapped_type
53
- return None , None
32
+ WRAPPER_TYPES : Dict [str , Optional [Type ]] = defaultdict (lambda : None , {
33
+ 'google.protobuf.DoubleValue' : google_wrappers .DoubleValue ,
34
+ 'google.protobuf.FloatValue' : google_wrappers .FloatValue ,
35
+ 'google.protobuf.Int64Value' : google_wrappers .Int64Value ,
36
+ 'google.protobuf.UInt64Value' : google_wrappers .UInt64Value ,
37
+ 'google.protobuf.Int32Value' : google_wrappers .Int32Value ,
38
+ 'google.protobuf.UInt32Value' : google_wrappers .UInt32Value ,
39
+ 'google.protobuf.BoolValue' : google_wrappers .BoolValue ,
40
+ 'google.protobuf.StringValue' : google_wrappers .StringValue ,
41
+ 'google.protobuf.BytesValue' : google_wrappers .BytesValue ,
42
+ })
54
43
55
44
56
45
def get_ref_type (package : str , imports : set , type_name : str , unwrap : bool = True ) -> str :
@@ -64,20 +53,21 @@ def get_ref_type(package: str, imports: set, type_name: str, unwrap: bool = True
64
53
type_name = type_name .lstrip ("." )
65
54
66
55
# Check if type is wrapper.
67
- wrapper , wrapped_type = get_wrapper_type ( type_name )
56
+ wrapper_class = WRAPPER_TYPES [ type_name ]
68
57
69
58
if unwrap :
70
- if wrapper :
71
- return f"Optional[{ wrapped_type } ]"
59
+ if wrapper_class :
60
+ wrapped_type = type (wrapper_class ().value )
61
+ return f"Optional[{ wrapped_type .__name__ } ]"
72
62
73
63
if type_name == "google.protobuf.Duration" :
74
64
return "timedelta"
75
65
76
66
if type_name == "google.protobuf.Timestamp" :
77
67
return "datetime"
78
- elif wrapper :
79
- imports .add (f"from { wrapper .__module__ } import { wrapper .__name__ } " )
80
- return f"{ wrapper .__name__ } "
68
+ elif wrapper_class :
69
+ imports .add (f"from { wrapper_class .__module__ } import { wrapper_class .__name__ } " )
70
+ return f"{ wrapper_class .__name__ } "
81
71
82
72
if type_name .startswith (package ):
83
73
parts = type_name .lstrip (package ).lstrip ("." ).split ("." )
0 commit comments