Skip to content

Commit 156db06

Browse files
Add support for binary union types - Python 3.10 (#1977)
Co-authored-by: Pierre Sassoulas <[email protected]>
1 parent 0545192 commit 156db06

File tree

5 files changed

+203
-4
lines changed

5 files changed

+203
-4
lines changed

ChangeLog

+3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ What's New in astroid 2.14.0?
66
=============================
77
Release date: TBA
88

9+
* Add support for inferring binary union types added in Python 3.10.
10+
11+
Refs PyCQA/pylint#8119
912

1013

1114
What's New in astroid 2.13.4?

astroid/bases.py

+41-2
Original file line numberDiff line numberDiff line change
@@ -121,11 +121,12 @@ def __init__(
121121
if proxied is None:
122122
# This is a hack to allow calling this __init__ during bootstrapping of
123123
# builtin classes and their docstrings.
124-
# For Const and Generator nodes the _proxied attribute is set during bootstrapping
124+
# For Const, Generator, and UnionType nodes the _proxied attribute
125+
# is set during bootstrapping
125126
# as we first need to build the ClassDef that they can proxy.
126127
# Thus, if proxied is None self should be a Const or Generator
127128
# as that is the only way _proxied will be correctly set as a ClassDef.
128-
assert isinstance(self, (nodes.Const, Generator))
129+
assert isinstance(self, (nodes.Const, Generator, UnionType))
129130
else:
130131
self._proxied = proxied
131132

@@ -669,3 +670,41 @@ def __repr__(self) -> str:
669670

670671
def __str__(self) -> str:
671672
return f"AsyncGenerator({self._proxied.name})"
673+
674+
675+
class UnionType(BaseInstance):
676+
"""Special node representing new style typing unions.
677+
678+
Proxied class is set once for all in raw_building.
679+
"""
680+
681+
_proxied: nodes.ClassDef
682+
683+
def __init__(
684+
self,
685+
left: UnionType | nodes.ClassDef | nodes.Const,
686+
right: UnionType | nodes.ClassDef | nodes.Const,
687+
parent: nodes.NodeNG | None = None,
688+
) -> None:
689+
super().__init__()
690+
self.parent = parent
691+
self.left = left
692+
self.right = right
693+
694+
def callable(self) -> Literal[False]:
695+
return False
696+
697+
def bool_value(self, context: InferenceContext | None = None) -> Literal[True]:
698+
return True
699+
700+
def pytype(self) -> Literal["types.UnionType"]:
701+
return "types.UnionType"
702+
703+
def display_type(self) -> str:
704+
return "UnionType"
705+
706+
def __repr__(self) -> str:
707+
return f"<UnionType({self._proxied.name}) l.{self.lineno} at 0x{id(self)}>"
708+
709+
def __str__(self) -> str:
710+
return f"UnionType({self._proxied.name})"

astroid/inference.py

+25
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
1616

1717
from astroid import bases, constraint, decorators, helpers, nodes, protocols, util
18+
from astroid.const import PY310_PLUS
1819
from astroid.context import (
1920
CallContext,
2021
InferenceContext,
@@ -758,6 +759,14 @@ def _bin_op(
758759
)
759760

760761

762+
def _bin_op_or_union_type(
763+
left: bases.UnionType | nodes.ClassDef | nodes.Const,
764+
right: bases.UnionType | nodes.ClassDef | nodes.Const,
765+
) -> Generator[InferenceResult, None, None]:
766+
"""Create a new UnionType instance for binary or, e.g. int | str."""
767+
yield bases.UnionType(left, right)
768+
769+
761770
def _get_binop_contexts(context, left, right):
762771
"""Get contexts for binary operations.
763772
@@ -817,6 +826,22 @@ def _get_binop_flow(
817826
_bin_op(left, binary_opnode, op, right, context),
818827
_bin_op(right, binary_opnode, op, left, reverse_context, reverse=True),
819828
]
829+
830+
if (
831+
PY310_PLUS
832+
and op == "|"
833+
and (
834+
isinstance(left, (bases.UnionType, nodes.ClassDef))
835+
or isinstance(left, nodes.Const)
836+
and left.value is None
837+
)
838+
and (
839+
isinstance(right, (bases.UnionType, nodes.ClassDef))
840+
or isinstance(right, nodes.Const)
841+
and right.value is None
842+
)
843+
):
844+
methods.extend([functools.partial(_bin_op_or_union_type, left, right)])
820845
return methods
821846

822847

astroid/raw_building.py

+18
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,24 @@ def _astroid_bootstrapping() -> None:
575575
)
576576
bases.AsyncGenerator._proxied = _AsyncGeneratorType
577577
builder.object_build(bases.AsyncGenerator._proxied, types.AsyncGeneratorType)
578+
579+
if hasattr(types, "UnionType"):
580+
_UnionTypeType = nodes.ClassDef(types.UnionType.__name__)
581+
_UnionTypeType.parent = astroid_builtin
582+
union_type_doc_node = (
583+
nodes.Const(value=types.UnionType.__doc__)
584+
if types.UnionType.__doc__
585+
else None
586+
)
587+
_UnionTypeType.postinit(
588+
bases=[],
589+
body=[],
590+
decorators=None,
591+
doc_node=union_type_doc_node,
592+
)
593+
bases.UnionType._proxied = _UnionTypeType
594+
builder.object_build(bases.UnionType._proxied, types.UnionType)
595+
578596
builtin_types = (
579597
types.GetSetDescriptorType,
580598
types.GeneratorType,

tests/unittest_inference.py

+116-2
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121
from astroid import decorators as decoratorsmod
2222
from astroid import helpers, nodes, objects, test_utils, util
2323
from astroid.arguments import CallSite
24-
from astroid.bases import BoundMethod, Instance, UnboundMethod
24+
from astroid.bases import BoundMethod, Instance, UnboundMethod, UnionType
2525
from astroid.builder import AstroidBuilder, _extract_single_node, extract_node, parse
26-
from astroid.const import IS_PYPY, PY38_PLUS, PY39_PLUS
26+
from astroid.const import IS_PYPY, PY38_PLUS, PY39_PLUS, PY310_PLUS
2727
from astroid.context import InferenceContext
2828
from astroid.exceptions import (
2929
AstroidTypeError,
@@ -1209,6 +1209,120 @@ def randint(maximum):
12091209
],
12101210
)
12111211

1212+
def test_binary_op_or_union_type(self) -> None:
1213+
"""Binary or union is only defined for Python 3.10+."""
1214+
code = """
1215+
class A: ...
1216+
1217+
int | 2 #@
1218+
int | "Hello" #@
1219+
int | ... #@
1220+
int | A() #@
1221+
int | None | 2 #@
1222+
"""
1223+
ast_nodes = extract_node(code)
1224+
for n in ast_nodes:
1225+
assert n.inferred() == [util.Uninferable]
1226+
1227+
code = """
1228+
from typing import List
1229+
1230+
class A: ...
1231+
class B: ...
1232+
1233+
int | None #@
1234+
int | str #@
1235+
int | str | None #@
1236+
A | B #@
1237+
A | None #@
1238+
List[int] | int #@
1239+
tuple | int #@
1240+
"""
1241+
ast_nodes = extract_node(code)
1242+
if not PY310_PLUS:
1243+
for n in ast_nodes:
1244+
assert n.inferred() == [util.Uninferable]
1245+
else:
1246+
i0 = ast_nodes[0].inferred()[0]
1247+
assert isinstance(i0, UnionType)
1248+
assert isinstance(i0.left, nodes.ClassDef)
1249+
assert i0.left.name == "int"
1250+
assert isinstance(i0.right, nodes.Const)
1251+
assert i0.right.value is None
1252+
1253+
# Assert basic UnionType properties and methods
1254+
assert i0.callable() is False
1255+
assert i0.bool_value() is True
1256+
assert i0.pytype() == "types.UnionType"
1257+
assert i0.display_type() == "UnionType"
1258+
assert str(i0) == "UnionType(UnionType)"
1259+
assert repr(i0) == f"<UnionType(UnionType) l.None at 0x{id(i0)}>"
1260+
1261+
i1 = ast_nodes[1].inferred()[0]
1262+
assert isinstance(i1, UnionType)
1263+
1264+
i2 = ast_nodes[2].inferred()[0]
1265+
assert isinstance(i2, UnionType)
1266+
assert isinstance(i2.left, UnionType)
1267+
assert isinstance(i2.left.left, nodes.ClassDef)
1268+
assert i2.left.left.name == "int"
1269+
assert isinstance(i2.left.right, nodes.ClassDef)
1270+
assert i2.left.right.name == "str"
1271+
assert isinstance(i2.right, nodes.Const)
1272+
assert i2.right.value is None
1273+
1274+
i3 = ast_nodes[3].inferred()[0]
1275+
assert isinstance(i3, UnionType)
1276+
assert isinstance(i3.left, nodes.ClassDef)
1277+
assert i3.left.name == "A"
1278+
assert isinstance(i3.right, nodes.ClassDef)
1279+
assert i3.right.name == "B"
1280+
1281+
i4 = ast_nodes[4].inferred()[0]
1282+
assert isinstance(i4, UnionType)
1283+
1284+
i5 = ast_nodes[5].inferred()[0]
1285+
assert isinstance(i5, UnionType)
1286+
assert isinstance(i5.left, nodes.ClassDef)
1287+
assert i5.left.name == "List"
1288+
1289+
i6 = ast_nodes[6].inferred()[0]
1290+
assert isinstance(i6, UnionType)
1291+
assert isinstance(i6.left, nodes.ClassDef)
1292+
assert i6.left.name == "tuple"
1293+
1294+
code = """
1295+
from typing import List
1296+
1297+
Alias1 = List[int]
1298+
Alias2 = str | int
1299+
1300+
Alias1 | int #@
1301+
Alias2 | int #@
1302+
Alias1 | Alias2 #@
1303+
"""
1304+
ast_nodes = extract_node(code)
1305+
if not PY310_PLUS:
1306+
for n in ast_nodes:
1307+
assert n.inferred() == [util.Uninferable]
1308+
else:
1309+
i0 = ast_nodes[0].inferred()[0]
1310+
assert isinstance(i0, UnionType)
1311+
assert isinstance(i0.left, nodes.ClassDef)
1312+
assert i0.left.name == "List"
1313+
1314+
i1 = ast_nodes[1].inferred()[0]
1315+
assert isinstance(i1, UnionType)
1316+
assert isinstance(i1.left, UnionType)
1317+
assert isinstance(i1.left.left, nodes.ClassDef)
1318+
assert i1.left.left.name == "str"
1319+
1320+
i2 = ast_nodes[2].inferred()[0]
1321+
assert isinstance(i2, UnionType)
1322+
assert isinstance(i2.left, nodes.ClassDef)
1323+
assert i2.left.name == "List"
1324+
assert isinstance(i2.right, UnionType)
1325+
12121326
def test_nonregr_lambda_arg(self) -> None:
12131327
code = """
12141328
def f(g = lambda: None):

0 commit comments

Comments
 (0)