Skip to content

Commit 86cf372

Browse files
Fix Union[..., NoneType] injection by get_type_hints if a None default value is used. (#482)
Co-authored-by: Jelle Zijlstra <[email protected]>
1 parent 8184ac6 commit 86cf372

File tree

3 files changed

+173
-0
lines changed

3 files changed

+173
-0
lines changed

Diff for: CHANGELOG.md

+4
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ aliases that have a `Concatenate` special form as their argument.
2121
Patch by [Daraan](https://github.com/Daraan).
2222
- Extended the `Concatenate` backport for Python 3.8-3.10 to now accept
2323
`Ellipsis` as an argument. Patch by [Daraan](https://github.com/Daraan).
24+
- Fix backport of `get_type_hints` to reflect Python 3.11+ behavior which does not add
25+
`Union[..., NoneType]` to annotations that have a `None` default value anymore.
26+
This fixes wrapping of `Annotated` in an unwanted `Optional` in such cases.
27+
Patch by [Daraan](https://github.com/Daraan).
2428
- Fix error in subscription of `Unpack` aliases causing nested Unpacks
2529
to not be resolved correctly. Patch by [Daraan](https://github.com/Daraan).
2630
- Backport CPython PR [#124795](https://github.com/python/cpython/pull/124795):

Diff for: src/test_typing_extensions.py

+89
Original file line numberDiff line numberDiff line change
@@ -1647,6 +1647,95 @@ def test_final_forward_ref(self):
16471647
self.assertNotEqual(gth(Loop, globals())['attr'], Final[int])
16481648
self.assertNotEqual(gth(Loop, globals())['attr'], Final)
16491649

1650+
def test_annotation_and_optional_default(self):
1651+
annotation = Annotated[Union[int, None], "data"]
1652+
NoneAlias = None
1653+
StrAlias = str
1654+
T_default = TypeVar("T_default", default=None)
1655+
Ts = TypeVarTuple("Ts")
1656+
1657+
cases = {
1658+
# annotation: expected_type_hints
1659+
Annotated[None, "none"] : Annotated[None, "none"],
1660+
annotation : annotation,
1661+
Optional[int] : Optional[int],
1662+
Optional[List[str]] : Optional[List[str]],
1663+
Optional[annotation] : Optional[annotation],
1664+
Union[str, None, str] : Optional[str],
1665+
Unpack[Tuple[int, None]]: Unpack[Tuple[int, None]],
1666+
# Note: A starred *Ts will use typing.Unpack in 3.11+ see Issue #485
1667+
Unpack[Ts] : Unpack[Ts],
1668+
}
1669+
# contains a ForwardRef, TypeVar(~prefix) or no expression
1670+
do_not_stringify_cases = {
1671+
() : {}, # Special-cased below to create an unannotated parameter
1672+
int : int,
1673+
"int" : int,
1674+
None : type(None),
1675+
"NoneAlias" : type(None),
1676+
List["str"] : List[str],
1677+
Union[str, "str"] : str,
1678+
Union[str, None, "str"] : Optional[str],
1679+
Union[str, "NoneAlias", "StrAlias"]: Optional[str],
1680+
Union[str, "Union[None, StrAlias]"]: Optional[str],
1681+
Union["annotation", T_default] : Union[annotation, T_default],
1682+
Annotated["annotation", "nested"] : Annotated[Union[int, None], "data", "nested"],
1683+
}
1684+
if TYPING_3_10_0: # cannot construct UnionTypes before 3.10
1685+
do_not_stringify_cases["str | NoneAlias | StrAlias"] = str | None
1686+
cases[str | None] = Optional[str]
1687+
cases.update(do_not_stringify_cases)
1688+
for (annot, expected), none_default, as_str, wrap_optional in itertools.product(
1689+
cases.items(), (False, True), (False, True), (False, True)
1690+
):
1691+
# Special case:
1692+
skip_reason = None
1693+
annot_unchanged = annot
1694+
if sys.version_info[:2] == (3, 10) and annot == "str | NoneAlias | StrAlias" and none_default:
1695+
# In 3.10 converts Optional[str | None] to Optional[str] which has a different repr
1696+
skip_reason = "UnionType not preserved in 3.10"
1697+
if wrap_optional:
1698+
if annot_unchanged == ():
1699+
continue
1700+
annot = Optional[annot]
1701+
expected = {"x": Optional[expected]}
1702+
else:
1703+
expected = {"x": expected} if annot_unchanged != () else {}
1704+
if as_str:
1705+
if annot_unchanged in do_not_stringify_cases or annot_unchanged == ():
1706+
continue
1707+
annot = str(annot)
1708+
with self.subTest(
1709+
annotation=annot,
1710+
as_str=as_str,
1711+
wrap_optional=wrap_optional,
1712+
none_default=none_default,
1713+
expected_type_hints=expected,
1714+
):
1715+
# Create function to check
1716+
if annot_unchanged == ():
1717+
if none_default:
1718+
def func(x=None): pass
1719+
else:
1720+
def func(x): pass
1721+
elif none_default:
1722+
def func(x: annot = None): pass
1723+
else:
1724+
def func(x: annot): pass
1725+
type_hints = get_type_hints(func, globals(), locals(), include_extras=True)
1726+
# Equality
1727+
self.assertEqual(type_hints, expected)
1728+
# Hash
1729+
for k in type_hints.keys():
1730+
self.assertEqual(hash(type_hints[k]), hash(expected[k]))
1731+
# Test if UnionTypes are preserved
1732+
self.assertIs(type(type_hints[k]), type(expected[k]))
1733+
# Repr
1734+
with self.subTest("Check str and repr"):
1735+
if skip_reason == "UnionType not preserved in 3.10":
1736+
self.skipTest(skip_reason)
1737+
self.assertEqual(repr(type_hints), repr(expected))
1738+
16501739

16511740
class GetUtilitiesTestCase(TestCase):
16521741
def test_get_origin(self):

Diff for: src/typing_extensions.py

+80
Original file line numberDiff line numberDiff line change
@@ -1242,10 +1242,90 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False):
12421242
)
12431243
else: # 3.8
12441244
hint = typing.get_type_hints(obj, globalns=globalns, localns=localns)
1245+
if sys.version_info < (3, 11):
1246+
_clean_optional(obj, hint, globalns, localns)
1247+
if sys.version_info < (3, 9):
1248+
# In 3.8 eval_type does not flatten Optional[ForwardRef] correctly
1249+
# This will recreate and and cache Unions.
1250+
hint = {
1251+
k: (t
1252+
if get_origin(t) != Union
1253+
else Union[t.__args__])
1254+
for k, t in hint.items()
1255+
}
12451256
if include_extras:
12461257
return hint
12471258
return {k: _strip_extras(t) for k, t in hint.items()}
12481259

1260+
_NoneType = type(None)
1261+
1262+
def _could_be_inserted_optional(t):
1263+
"""detects Union[..., None] pattern"""
1264+
# 3.8+ compatible checking before _UnionGenericAlias
1265+
if get_origin(t) is not Union:
1266+
return False
1267+
# Assume if last argument is not None they are user defined
1268+
if t.__args__[-1] is not _NoneType:
1269+
return False
1270+
return True
1271+
1272+
# < 3.11
1273+
def _clean_optional(obj, hints, globalns=None, localns=None):
1274+
# reverts injected Union[..., None] cases from typing.get_type_hints
1275+
# when a None default value is used.
1276+
# see https://github.com/python/typing_extensions/issues/310
1277+
if not hints or isinstance(obj, type):
1278+
return
1279+
defaults = typing._get_defaults(obj) # avoid accessing __annotations___
1280+
if not defaults:
1281+
return
1282+
original_hints = obj.__annotations__
1283+
for name, value in hints.items():
1284+
# Not a Union[..., None] or replacement conditions not fullfilled
1285+
if (not _could_be_inserted_optional(value)
1286+
or name not in defaults
1287+
or defaults[name] is not None
1288+
):
1289+
continue
1290+
original_value = original_hints[name]
1291+
# value=NoneType should have caused a skip above but check for safety
1292+
if original_value is None:
1293+
original_value = _NoneType
1294+
# Forward reference
1295+
if isinstance(original_value, str):
1296+
if globalns is None:
1297+
if isinstance(obj, _types.ModuleType):
1298+
globalns = obj.__dict__
1299+
else:
1300+
nsobj = obj
1301+
# Find globalns for the unwrapped object.
1302+
while hasattr(nsobj, '__wrapped__'):
1303+
nsobj = nsobj.__wrapped__
1304+
globalns = getattr(nsobj, '__globals__', {})
1305+
if localns is None:
1306+
localns = globalns
1307+
elif localns is None:
1308+
localns = globalns
1309+
if sys.version_info < (3, 9):
1310+
original_value = ForwardRef(original_value)
1311+
else:
1312+
original_value = ForwardRef(
1313+
original_value,
1314+
is_argument=not isinstance(obj, _types.ModuleType)
1315+
)
1316+
original_evaluated = typing._eval_type(original_value, globalns, localns)
1317+
if sys.version_info < (3, 9) and get_origin(original_evaluated) is Union:
1318+
# Union[str, None, "str"] is not reduced to Union[str, None]
1319+
original_evaluated = Union[original_evaluated.__args__]
1320+
# Compare if values differ. Note that even if equal
1321+
# value might be cached by typing._tp_cache contrary to original_evaluated
1322+
if original_evaluated != value or (
1323+
# 3.10: ForwardRefs of UnionType might be turned into _UnionGenericAlias
1324+
hasattr(_types, "UnionType")
1325+
and isinstance(original_evaluated, _types.UnionType)
1326+
and not isinstance(value, _types.UnionType)
1327+
):
1328+
hints[name] = original_evaluated
12491329

12501330
# Python 3.9+ has PEP 593 (Annotated)
12511331
if hasattr(typing, 'Annotated'):

0 commit comments

Comments
 (0)