Skip to content

Commit f96f516

Browse files
authored
fix: 3.10 style imports not resolving correctly (#594)
1 parent 5fdd0bb commit f96f516

File tree

4 files changed

+54
-37
lines changed

4 files changed

+54
-37
lines changed

src/betterproto/__init__.py

+23-16
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,13 @@
6262
SupportsWrite,
6363
)
6464

65+
if sys.version_info >= (3, 10):
66+
from types import UnionType as _types_UnionType
67+
else:
68+
69+
class _types_UnionType:
70+
...
71+
6572

6673
# Proto 3 data types
6774
TYPE_ENUM = "enum"
@@ -148,6 +155,7 @@ def datetime_default_gen() -> datetime:
148155

149156
DATETIME_ZERO = datetime_default_gen()
150157

158+
151159
# Special protobuf json doubles
152160
INFINITY = "Infinity"
153161
NEG_INFINITY = "-Infinity"
@@ -1166,30 +1174,29 @@ def _get_field_default(self, field_name: str) -> Any:
11661174
def _get_field_default_gen(cls, field: dataclasses.Field) -> Any:
11671175
t = cls._type_hint(field.name)
11681176

1169-
if hasattr(t, "__origin__"):
1170-
if t.__origin__ is dict:
1171-
# This is some kind of map (dict in Python).
1172-
return dict
1173-
elif t.__origin__ is list:
1174-
# This is some kind of list (repeated) field.
1175-
return list
1176-
elif t.__origin__ is Union and t.__args__[1] is type(None):
1177+
is_310_union = isinstance(t, _types_UnionType)
1178+
if hasattr(t, "__origin__") or is_310_union:
1179+
if is_310_union or t.__origin__ is Union:
11771180
# This is an optional field (either wrapped, or using proto3
11781181
# field presence). For setting the default we really don't care
11791182
# what kind of field it is.
11801183
return type(None)
1181-
else:
1182-
return t
1183-
elif issubclass(t, Enum):
1184+
if t.__origin__ is list:
1185+
# This is some kind of list (repeated) field.
1186+
return list
1187+
if t.__origin__ is dict:
1188+
# This is some kind of map (dict in Python).
1189+
return dict
1190+
return t
1191+
if issubclass(t, Enum):
11841192
# Enums always default to zero.
11851193
return t.try_value
1186-
elif t is datetime:
1194+
if t is datetime:
11871195
# Offsets are relative to 1970-01-01T00:00:00Z
11881196
return datetime_default_gen
1189-
else:
1190-
# This is either a primitive scalar or another message type. Calling
1191-
# it should result in its zero value.
1192-
return t
1197+
# This is either a primitive scalar or another message type. Calling
1198+
# it should result in its zero value.
1199+
return t
11931200

11941201
def _postprocess_single(
11951202
self, wire_type: int, meta: FieldMetadata, field_name: str, value: Any

src/betterproto/compile/importing.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
from __future__ import annotations
2+
13
import os
24
import re
35
from typing import (
6+
TYPE_CHECKING,
47
Dict,
58
List,
69
Set,
@@ -13,6 +16,9 @@
1316
from .naming import pythonize_class_name
1417

1518

19+
if TYPE_CHECKING:
20+
from ..plugin.typing_compiler import TypingCompiler
21+
1622
WRAPPER_TYPES: Dict[str, Type] = {
1723
".google.protobuf.DoubleValue": google_protobuf.DoubleValue,
1824
".google.protobuf.FloatValue": google_protobuf.FloatValue,
@@ -47,7 +53,7 @@ def get_type_reference(
4753
package: str,
4854
imports: set,
4955
source_type: str,
50-
typing_compiler: "TypingCompiler",
56+
typing_compiler: TypingCompiler,
5157
unwrap: bool = True,
5258
pydantic: bool = False,
5359
) -> str:

src/betterproto/plugin/typing_compiler.py

+16-10
Original file line numberDiff line numberDiff line change
@@ -139,29 +139,35 @@ def imports(self) -> Dict[str, Optional[Set[str]]]:
139139
class NoTyping310TypingCompiler(TypingCompiler):
140140
_imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set))
141141

142+
@staticmethod
143+
def _fmt(type: str) -> str: # for now this is necessary till 3.14
144+
if type.startswith('"'):
145+
return type[1:-1]
146+
return type
147+
142148
def optional(self, type: str) -> str:
143-
return f"{type} | None"
149+
return f'"{self._fmt(type)} | None"'
144150

145151
def list(self, type: str) -> str:
146-
return f"list[{type}]"
152+
return f'"list[{self._fmt(type)}]"'
147153

148154
def dict(self, key: str, value: str) -> str:
149-
return f"dict[{key}, {value}]"
155+
return f'"dict[{key}, {self._fmt(value)}]"'
150156

151157
def union(self, *types: str) -> str:
152-
return " | ".join(types)
158+
return f'"{" | ".join(map(self._fmt, types))}"'
153159

154160
def iterable(self, type: str) -> str:
155-
self._imports["typing"].add("Iterable")
156-
return f"Iterable[{type}]"
161+
self._imports["collections.abc"].add("Iterable")
162+
return f'"Iterable[{type}]"'
157163

158164
def async_iterable(self, type: str) -> str:
159-
self._imports["typing"].add("AsyncIterable")
160-
return f"AsyncIterable[{type}]"
165+
self._imports["collections.abc"].add("AsyncIterable")
166+
return f'"AsyncIterable[{type}]"'
161167

162168
def async_iterator(self, type: str) -> str:
163-
self._imports["typing"].add("AsyncIterator")
164-
return f"AsyncIterator[{type}]"
169+
self._imports["collections.abc"].add("AsyncIterator")
170+
return f'"AsyncIterator[{type}]"'
165171

166172
def imports(self) -> Dict[str, Optional[Set[str]]]:
167173
return {k: v if v else None for k, v in self._imports.items()}

tests/test_typing_compiler.py

+8-10
Original file line numberDiff line numberDiff line change
@@ -62,19 +62,17 @@ def test_typing_import_typing_compiler():
6262
def test_no_typing_311_typing_compiler():
6363
compiler = NoTyping310TypingCompiler()
6464
assert compiler.imports() == {}
65-
assert compiler.optional("str") == "str | None"
65+
assert compiler.optional("str") == '"str | None"'
6666
assert compiler.imports() == {}
67-
assert compiler.list("str") == "list[str]"
67+
assert compiler.list("str") == '"list[str]"'
6868
assert compiler.imports() == {}
69-
assert compiler.dict("str", "int") == "dict[str, int]"
69+
assert compiler.dict("str", "int") == '"dict[str, int]"'
7070
assert compiler.imports() == {}
71-
assert compiler.union("str", "int") == "str | int"
71+
assert compiler.union("str", "int") == '"str | int"'
7272
assert compiler.imports() == {}
73-
assert compiler.iterable("str") == "Iterable[str]"
74-
assert compiler.imports() == {"typing": {"Iterable"}}
75-
assert compiler.async_iterable("str") == "AsyncIterable[str]"
76-
assert compiler.imports() == {"typing": {"Iterable", "AsyncIterable"}}
77-
assert compiler.async_iterator("str") == "AsyncIterator[str]"
73+
assert compiler.iterable("str") == '"Iterable[str]"'
74+
assert compiler.async_iterable("str") == '"AsyncIterable[str]"'
75+
assert compiler.async_iterator("str") == '"AsyncIterator[str]"'
7876
assert compiler.imports() == {
79-
"typing": {"Iterable", "AsyncIterable", "AsyncIterator"}
77+
"collections.abc": {"Iterable", "AsyncIterable", "AsyncIterator"}
8078
}

0 commit comments

Comments
 (0)