Skip to content

Commit cf25319

Browse files
Michael0x2ailevkivskyi
authored andcommitted
Add tests for literals and generics (#6035)
This pull request adds some tests verifying that basic interactions between literals and generics work as expected. It also tweaks `checkexpr` so it adds another exception for type variable return vs a Literal[...] context.
1 parent acc7740 commit cf25319

File tree

2 files changed

+249
-6
lines changed

2 files changed

+249
-6
lines changed

mypy/checkexpr.py

+22-6
Original file line numberDiff line numberDiff line change
@@ -900,17 +900,30 @@ def infer_function_type_arguments_using_context(
900900
# variables in an expression are inferred at the same time.
901901
# (And this is hard, also we need to be careful with lambdas that require
902902
# two passes.)
903-
if isinstance(ret_type, TypeVarType) and not is_generic_instance(ctx):
903+
if isinstance(ret_type, TypeVarType):
904904
# Another special case: the return type is a type variable. If it's unrestricted,
905905
# we could infer a too general type for the type variable if we use context,
906906
# and this could result in confusing and spurious type errors elsewhere.
907907
#
908-
# Give up and just use function arguments for type inference. As an exception,
909-
# if the context is a generic instance type, actually use it as context, as
910-
# this *seems* to usually be the reasonable thing to do.
908+
# So we give up and just use function arguments for type inference, with just two
909+
# exceptions:
911910
#
912-
# See also github issues #462 and #360.
913-
return callable.copy_modified()
911+
# 1. If the context is a generic instance type, actually use it as context, as
912+
# this *seems* to usually be the reasonable thing to do.
913+
#
914+
# See also github issues #462 and #360.
915+
#
916+
# 2. If the context is some literal type, we want to "propagate" that information
917+
# down so that we infer a more precise type for literal expressions. For example,
918+
# the expression `3` normally has an inferred type of `builtins.int`: but if it's
919+
# in a literal context like below, we want it to infer `Literal[3]` instead.
920+
#
921+
# def expects_literal(x: Literal[3]) -> None: pass
922+
# def identity(x: T) -> T: return x
923+
#
924+
# expects_literal(identity(3)) # Should type-check
925+
if not is_generic_instance(ctx) and not is_literal_type_like(ctx):
926+
return callable.copy_modified()
914927
args = infer_type_arguments(callable.type_var_ids(), ret_type, erased_ctx)
915928
# Only substitute non-Uninhabited and non-erased types.
916929
new_args = [] # type: List[Optional[Type]]
@@ -3638,6 +3651,9 @@ def is_literal_type_like(t: Optional[Type]) -> bool:
36383651
return True
36393652
elif isinstance(t, UnionType):
36403653
return any(is_literal_type_like(item) for item in t.items)
3654+
elif isinstance(t, TypeVarType):
3655+
return (is_literal_type_like(t.upper_bound)
3656+
or any(is_literal_type_like(item) for item in t.values))
36413657
else:
36423658
return False
36433659

test-data/unit/check-literal.test

+227
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,18 @@ b: bt # E: Invalid type "__main__.bt"
489489
[builtins fixtures/set.pyi]
490490
[out]
491491

492+
[case testLiteralDisallowTypeVar]
493+
from typing import TypeVar
494+
from typing_extensions import Literal
495+
496+
T = TypeVar('T')
497+
498+
at = Literal[T] # E: Parameter 1 of Literal[...] is invalid
499+
a: at
500+
501+
def foo(b: Literal[T]) -> T: pass # E: Parameter 1 of Literal[...] is invalid
502+
[out]
503+
492504

493505
--
494506
-- Test mixing and matching literals with other types
@@ -1348,6 +1360,221 @@ indirect.Literal()
13481360
[out]
13491361

13501362

1363+
--
1364+
-- Test to make sure literals interact with generics as expected
1365+
--
1366+
1367+
[case testLiteralAndGenericsWithSimpleFunctions]
1368+
from typing import TypeVar
1369+
from typing_extensions import Literal
1370+
1371+
T = TypeVar('T')
1372+
def foo(x: T) -> T: pass
1373+
def expects_literal(x: Literal[3]) -> None: pass
1374+
def expects_int(x: int) -> None: pass
1375+
1376+
a: Literal[3]
1377+
reveal_type(foo(3)) # E: Revealed type is 'builtins.int*'
1378+
reveal_type(foo(a)) # E: Revealed type is 'Literal[3]'
1379+
1380+
expects_literal(3)
1381+
expects_literal(foo(3))
1382+
expects_literal(foo(foo(3)))
1383+
1384+
expects_literal(a)
1385+
expects_literal(foo(a))
1386+
expects_literal(foo(foo(a)))
1387+
1388+
expects_literal(5) # E: Argument 1 to "expects_literal" has incompatible type "Literal[5]"; expected "Literal[3]"
1389+
expects_literal(foo(5)) # E: Argument 1 to "foo" has incompatible type "Literal[5]"; expected "Literal[3]"
1390+
expects_literal(foo(foo(5))) # E: Argument 1 to "foo" has incompatible type "Literal[5]"; expected "Literal[3]"
1391+
1392+
expects_int(a)
1393+
expects_int(foo(a))
1394+
expects_int(foo(foo(a)))
1395+
[out]
1396+
1397+
[case testLiteralAndGenericWithUnion]
1398+
from typing import TypeVar, Union
1399+
from typing_extensions import Literal
1400+
1401+
T = TypeVar('T')
1402+
def identity(x: T) -> T: return x
1403+
1404+
a: Union[int, Literal['foo']] = identity('foo')
1405+
b: Union[int, Literal['foo']] = identity('bar') # E: Argument 1 to "identity" has incompatible type "Literal['bar']"; expected "Union[int, Literal['foo']]"
1406+
[out]
1407+
1408+
[case testLiteralAndGenericsNoMatch]
1409+
from typing import TypeVar, Union, List
1410+
from typing_extensions import Literal
1411+
1412+
def identity(x: T) -> T:
1413+
return x
1414+
1415+
Ok1 = Union[List[int], Literal['bad']]
1416+
Ok2 = Union[List[Literal[42]], Literal['bad']]
1417+
Bad = Union[List[Literal[43]], Literal['bad']]
1418+
1419+
x: Ok1 = identity([42])
1420+
y: Ok2 = identity([42])
1421+
z: Bad = identity([42]) # E: List item 0 has incompatible type "Literal[42]"; expected "Literal[43]"
1422+
[builtins fixtures/list.pyi]
1423+
[out]
1424+
1425+
[case testLiteralAndGenericsWithSimpleClasses]
1426+
from typing import TypeVar, Generic
1427+
from typing_extensions import Literal
1428+
1429+
T = TypeVar('T')
1430+
class Wrapper(Generic[T]):
1431+
def __init__(self, val: T) -> None:
1432+
self.val = val
1433+
def inner(self) -> T:
1434+
return self.val
1435+
1436+
def expects_literal(a: Literal[3]) -> None: pass
1437+
def expects_literal_wrapper(x: Wrapper[Literal[3]]) -> None: pass
1438+
1439+
a: Literal[3]
1440+
reveal_type(Wrapper(3)) # E: Revealed type is '__main__.Wrapper[builtins.int*]'
1441+
reveal_type(Wrapper[Literal[3]](3)) # E: Revealed type is '__main__.Wrapper[Literal[3]]'
1442+
reveal_type(Wrapper(a)) # E: Revealed type is '__main__.Wrapper[Literal[3]]'
1443+
1444+
expects_literal(Wrapper(a).inner())
1445+
1446+
# Note: the following probably ought to type-check: it's reasonable to infer
1447+
# Wrapper[Literal[3]] here.
1448+
# TODO: Consider finding a way to handle this edge case better
1449+
expects_literal(Wrapper(3).inner()) # E: Argument 1 to "expects_literal" has incompatible type "int"; expected "Literal[3]"
1450+
1451+
# Note: if we handle the edge case above, we should make sure this error
1452+
# message switches to warning about an incompatible type 'Literal[5]' rather
1453+
# then an incompatible type 'int'
1454+
expects_literal(Wrapper(5).inner()) # E: Argument 1 to "expects_literal" has incompatible type "int"; expected "Literal[3]"
1455+
1456+
expects_literal_wrapper(Wrapper(a))
1457+
expects_literal_wrapper(Wrapper(3))
1458+
expects_literal_wrapper(Wrapper(5)) # E: Argument 1 to "Wrapper" has incompatible type "Literal[5]"; expected "Literal[3]"
1459+
[out]
1460+
1461+
[case testLiteralAndGenericsRespectsUpperBound]
1462+
from typing import TypeVar
1463+
from typing_extensions import Literal
1464+
1465+
TLiteral = TypeVar('TLiteral', bound=Literal[3])
1466+
TInt = TypeVar('TInt', bound=int)
1467+
1468+
def func1(x: TLiteral) -> TLiteral: pass
1469+
def func2(x: TInt) -> TInt: pass
1470+
1471+
def func3(x: TLiteral) -> TLiteral:
1472+
y = func2(x)
1473+
return y
1474+
def func4(x: TInt) -> TInt:
1475+
y = func1(x) # E: Value of type variable "TLiteral" of "func1" cannot be "TInt"
1476+
return y
1477+
1478+
a: Literal[3]
1479+
b: Literal[4]
1480+
c: int
1481+
1482+
reveal_type(func1) # E: Revealed type is 'def [TLiteral <: Literal[3]] (x: TLiteral`-1) -> TLiteral`-1'
1483+
1484+
reveal_type(func1(3)) # E: Revealed type is 'Literal[3]'
1485+
reveal_type(func1(a)) # E: Revealed type is 'Literal[3]'
1486+
reveal_type(func1(4)) # E: Revealed type is 'Literal[4]' \
1487+
# E: Value of type variable "TLiteral" of "func1" cannot be "Literal[4]"
1488+
reveal_type(func1(b)) # E: Revealed type is 'Literal[4]' \
1489+
# E: Value of type variable "TLiteral" of "func1" cannot be "Literal[4]"
1490+
reveal_type(func1(c)) # E: Revealed type is 'builtins.int*' \
1491+
# E: Value of type variable "TLiteral" of "func1" cannot be "int"
1492+
1493+
reveal_type(func2(3)) # E: Revealed type is 'builtins.int*'
1494+
reveal_type(func2(a)) # E: Revealed type is 'Literal[3]'
1495+
reveal_type(func2(4)) # E: Revealed type is 'builtins.int*'
1496+
reveal_type(func2(b)) # E: Revealed type is 'Literal[4]'
1497+
reveal_type(func2(c)) # E: Revealed type is 'builtins.int*'
1498+
[out]
1499+
1500+
[case testLiteralAndGenericsRespectsValueRestriction]
1501+
from typing import TypeVar
1502+
from typing_extensions import Literal
1503+
1504+
TLiteral = TypeVar('TLiteral', Literal[3], Literal['foo'])
1505+
TNormal = TypeVar('TNormal', int, str)
1506+
1507+
def func1(x: TLiteral) -> TLiteral: pass
1508+
def func2(x: TNormal) -> TNormal: pass
1509+
1510+
def func3(x: TLiteral) -> TLiteral:
1511+
y = func2(x)
1512+
return y # E: Incompatible return value type (got "int", expected "Literal[3]") \
1513+
# E: Incompatible return value type (got "str", expected "Literal['foo']")
1514+
def func4(x: TNormal) -> TNormal:
1515+
y = func1(x) # E: Value of type variable "TLiteral" of "func1" cannot be "int" \
1516+
# E: Value of type variable "TLiteral" of "func1" cannot be "str"
1517+
return y
1518+
1519+
i1: Literal[3]
1520+
i2: Literal[4]
1521+
i: int
1522+
1523+
s1: Literal['foo']
1524+
s2: Literal['bar']
1525+
s: str
1526+
1527+
reveal_type(func1) # E: Revealed type is 'def [TLiteral in (Literal[3], Literal['foo'])] (x: TLiteral`-1) -> TLiteral`-1'
1528+
1529+
reveal_type(func1(3)) # E: Revealed type is 'Literal[3]'
1530+
reveal_type(func1(i1)) # E: Revealed type is 'Literal[3]'
1531+
reveal_type(func1(4)) # E: Revealed type is 'Literal[4]' \
1532+
# E: Value of type variable "TLiteral" of "func1" cannot be "Literal[4]"
1533+
reveal_type(func1(i2)) # E: Revealed type is 'Literal[4]' \
1534+
# E: Value of type variable "TLiteral" of "func1" cannot be "Literal[4]"
1535+
reveal_type(func1(i)) # E: Revealed type is 'builtins.int*' \
1536+
# E: Value of type variable "TLiteral" of "func1" cannot be "int"
1537+
1538+
reveal_type(func1("foo")) # E: Revealed type is 'Literal['foo']'
1539+
reveal_type(func1(s1)) # E: Revealed type is 'Literal['foo']'
1540+
reveal_type(func1("bar")) # E: Revealed type is 'Literal['bar']' \
1541+
# E: Value of type variable "TLiteral" of "func1" cannot be "Literal['bar']"
1542+
reveal_type(func1(s2)) # E: Revealed type is 'Literal['bar']' \
1543+
# E: Value of type variable "TLiteral" of "func1" cannot be "Literal['bar']"
1544+
reveal_type(func1(s)) # E: Revealed type is 'builtins.str*' \
1545+
# E: Value of type variable "TLiteral" of "func1" cannot be "str"
1546+
1547+
reveal_type(func2(3)) # E: Revealed type is 'builtins.int*'
1548+
reveal_type(func2(i1)) # E: Revealed type is 'builtins.int*'
1549+
reveal_type(func2(4)) # E: Revealed type is 'builtins.int*'
1550+
reveal_type(func2(i2)) # E: Revealed type is 'builtins.int*'
1551+
reveal_type(func2("foo")) # E: Revealed type is 'builtins.str*'
1552+
reveal_type(func2(s1)) # E: Revealed type is 'builtins.str*'
1553+
reveal_type(func2("bar")) # E: Revealed type is 'builtins.str*'
1554+
reveal_type(func2(s2)) # E: Revealed type is 'builtins.str*'
1555+
[out]
1556+
1557+
[case testLiteralAndGenericsWithOverloads]
1558+
from typing import TypeVar, overload, Union
1559+
from typing_extensions import Literal
1560+
1561+
@overload
1562+
def func1(x: Literal[4]) -> Literal[19]: ...
1563+
@overload
1564+
def func1(x: int) -> int: ...
1565+
def func1(x: int) -> int: pass
1566+
1567+
T = TypeVar('T')
1568+
def identity(x: T) -> T: pass
1569+
1570+
a: Literal[4]
1571+
b: Literal[5]
1572+
1573+
reveal_type(func1(identity(4))) # E: Revealed type is 'Literal[19]'
1574+
reveal_type(func1(identity(5))) # E: Revealed type is 'builtins.int'
1575+
reveal_type(func1(identity(a))) # E: Revealed type is 'Literal[19]'
1576+
reveal_type(func1(identity(b))) # E: Revealed type is 'builtins.int'
1577+
13511578
--
13521579
-- Other misc interactions
13531580
--

0 commit comments

Comments
 (0)