Skip to content

Commit e8de6d1

Browse files
authored
Add support for exception groups and except* (#14020)
Ref #12840 It looks like from the point of view of type checking support is quite easy. Mypyc support however requires some actual work, so I don't include it in this PR.
1 parent 807da26 commit e8de6d1

File tree

9 files changed

+104
-15
lines changed

9 files changed

+104
-15
lines changed

mypy/checker.py

+31-6
Original file line numberDiff line numberDiff line change
@@ -4305,7 +4305,7 @@ def visit_try_without_finally(self, s: TryStmt, try_frame: bool) -> None:
43054305
with self.binder.frame_context(can_skip=True, fall_through=4):
43064306
typ = s.types[i]
43074307
if typ:
4308-
t = self.check_except_handler_test(typ)
4308+
t = self.check_except_handler_test(typ, s.is_star)
43094309
var = s.vars[i]
43104310
if var:
43114311
# To support local variables, we make this a definition line,
@@ -4325,7 +4325,7 @@ def visit_try_without_finally(self, s: TryStmt, try_frame: bool) -> None:
43254325
if s.else_body:
43264326
self.accept(s.else_body)
43274327

4328-
def check_except_handler_test(self, n: Expression) -> Type:
4328+
def check_except_handler_test(self, n: Expression, is_star: bool) -> Type:
43294329
"""Type check an exception handler test clause."""
43304330
typ = self.expr_checker.accept(n)
43314331

@@ -4341,22 +4341,47 @@ def check_except_handler_test(self, n: Expression) -> Type:
43414341
item = ttype.items[0]
43424342
if not item.is_type_obj():
43434343
self.fail(message_registry.INVALID_EXCEPTION_TYPE, n)
4344-
return AnyType(TypeOfAny.from_error)
4345-
exc_type = item.ret_type
4344+
return self.default_exception_type(is_star)
4345+
exc_type = erase_typevars(item.ret_type)
43464346
elif isinstance(ttype, TypeType):
43474347
exc_type = ttype.item
43484348
else:
43494349
self.fail(message_registry.INVALID_EXCEPTION_TYPE, n)
4350-
return AnyType(TypeOfAny.from_error)
4350+
return self.default_exception_type(is_star)
43514351

43524352
if not is_subtype(exc_type, self.named_type("builtins.BaseException")):
43534353
self.fail(message_registry.INVALID_EXCEPTION_TYPE, n)
4354-
return AnyType(TypeOfAny.from_error)
4354+
return self.default_exception_type(is_star)
43554355

43564356
all_types.append(exc_type)
43574357

4358+
if is_star:
4359+
new_all_types: list[Type] = []
4360+
for typ in all_types:
4361+
if is_proper_subtype(typ, self.named_type("builtins.BaseExceptionGroup")):
4362+
self.fail(message_registry.INVALID_EXCEPTION_GROUP, n)
4363+
new_all_types.append(AnyType(TypeOfAny.from_error))
4364+
else:
4365+
new_all_types.append(typ)
4366+
return self.wrap_exception_group(new_all_types)
43584367
return make_simplified_union(all_types)
43594368

4369+
def default_exception_type(self, is_star: bool) -> Type:
4370+
"""Exception type to return in case of a previous type error."""
4371+
any_type = AnyType(TypeOfAny.from_error)
4372+
if is_star:
4373+
return self.named_generic_type("builtins.ExceptionGroup", [any_type])
4374+
return any_type
4375+
4376+
def wrap_exception_group(self, types: Sequence[Type]) -> Type:
4377+
"""Transform except* variable type into an appropriate exception group."""
4378+
arg = make_simplified_union(types)
4379+
if is_subtype(arg, self.named_type("builtins.Exception")):
4380+
base = "builtins.ExceptionGroup"
4381+
else:
4382+
base = "builtins.BaseExceptionGroup"
4383+
return self.named_generic_type(base, [arg])
4384+
43604385
def get_types_from_except_handler(self, typ: Type, n: Expression) -> list[Type]:
43614386
"""Helper for check_except_handler_test to retrieve handler types."""
43624387
typ = get_proper_type(typ)

mypy/fastparse.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1254,7 +1254,6 @@ def visit_Try(self, n: ast3.Try) -> TryStmt:
12541254
return self.set_line(node, n)
12551255

12561256
def visit_TryStar(self, n: TryStar) -> TryStmt:
1257-
# TODO: we treat TryStar exactly like Try, which makes mypy not crash. See #12840
12581257
vs = [
12591258
self.set_line(NameExpr(h.name), h) if h.name is not None else None for h in n.handlers
12601259
]
@@ -1269,6 +1268,7 @@ def visit_TryStar(self, n: TryStar) -> TryStmt:
12691268
self.as_block(n.orelse, n.lineno),
12701269
self.as_block(n.finalbody, n.lineno),
12711270
)
1271+
node.is_star = True
12721272
return self.set_line(node, n)
12731273

12741274
# Assert(expr test, expr? msg)

mypy/message_registry.py

+3
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ def with_additional_msg(self, info: str) -> ErrorMessage:
4444
NO_RETURN_EXPECTED: Final = ErrorMessage("Return statement in function which does not return")
4545
INVALID_EXCEPTION: Final = ErrorMessage("Exception must be derived from BaseException")
4646
INVALID_EXCEPTION_TYPE: Final = ErrorMessage("Exception type must be derived from BaseException")
47+
INVALID_EXCEPTION_GROUP: Final = ErrorMessage(
48+
"Exception type in except* cannot derive from BaseExceptionGroup"
49+
)
4750
RETURN_IN_ASYNC_GENERATOR: Final = ErrorMessage(
4851
'"return" with value in async generator is not allowed'
4952
)

mypy/nodes.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -1529,9 +1529,9 @@ def accept(self, visitor: StatementVisitor[T]) -> T:
15291529

15301530

15311531
class TryStmt(Statement):
1532-
__slots__ = ("body", "types", "vars", "handlers", "else_body", "finally_body")
1532+
__slots__ = ("body", "types", "vars", "handlers", "else_body", "finally_body", "is_star")
15331533

1534-
__match_args__ = ("body", "types", "vars", "handlers", "else_body", "finally_body")
1534+
__match_args__ = ("body", "types", "vars", "handlers", "else_body", "finally_body", "is_star")
15351535

15361536
body: Block # Try body
15371537
# Plain 'except:' also possible
@@ -1540,6 +1540,8 @@ class TryStmt(Statement):
15401540
handlers: list[Block] # Except bodies
15411541
else_body: Block | None
15421542
finally_body: Block | None
1543+
# Whether this is try ... except* (added in Python 3.11)
1544+
is_star: bool
15431545

15441546
def __init__(
15451547
self,
@@ -1557,6 +1559,7 @@ def __init__(
15571559
self.handlers = handlers
15581560
self.else_body = else_body
15591561
self.finally_body = finally_body
1562+
self.is_star = False
15601563

15611564
def accept(self, visitor: StatementVisitor[T]) -> T:
15621565
return visitor.visit_try_stmt(self)

mypy/strconv.py

+2
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,8 @@ def visit_del_stmt(self, o: mypy.nodes.DelStmt) -> str:
276276

277277
def visit_try_stmt(self, o: mypy.nodes.TryStmt) -> str:
278278
a: list[Any] = [o.body]
279+
if o.is_star:
280+
a.append("*")
279281

280282
for i in range(len(o.vars)):
281283
a.append(o.types[i])

mypy/treetransform.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -373,14 +373,16 @@ def visit_raise_stmt(self, node: RaiseStmt) -> RaiseStmt:
373373
return RaiseStmt(self.optional_expr(node.expr), self.optional_expr(node.from_expr))
374374

375375
def visit_try_stmt(self, node: TryStmt) -> TryStmt:
376-
return TryStmt(
376+
new = TryStmt(
377377
self.block(node.body),
378378
self.optional_names(node.vars),
379379
self.optional_expressions(node.types),
380380
self.blocks(node.handlers),
381381
self.optional_block(node.else_body),
382382
self.optional_block(node.finally_body),
383383
)
384+
new.is_star = node.is_star
385+
return new
384386

385387
def visit_with_stmt(self, node: WithStmt) -> WithStmt:
386388
new = WithStmt(

mypyc/irbuild/statement.py

+2
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,8 @@ def transform_try_stmt(builder: IRBuilder, t: TryStmt) -> None:
616616
# constructs that we compile separately. When we have a
617617
# try/except/else/finally, we treat the try/except/else as the
618618
# body of a try/finally block.
619+
if t.is_star:
620+
builder.error("Exception groups and except* cannot be compiled yet", t.line)
619621
if t.finally_body:
620622

621623
def transform_try_body() -> None:

test-data/unit/check-python311.test

+49-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,53 @@
1-
[case testTryStarDoesNotCrash]
1+
[case testTryStarSimple]
22
try:
33
pass
44
except* Exception as e:
5-
reveal_type(e) # N: Revealed type is "builtins.Exception"
5+
reveal_type(e) # N: Revealed type is "builtins.ExceptionGroup[builtins.Exception]"
6+
[builtins fixtures/exception.pyi]
7+
8+
[case testTryStarMultiple]
9+
try:
10+
pass
11+
except* Exception as e:
12+
reveal_type(e) # N: Revealed type is "builtins.ExceptionGroup[builtins.Exception]"
13+
except* RuntimeError as e:
14+
reveal_type(e) # N: Revealed type is "builtins.ExceptionGroup[builtins.RuntimeError]"
15+
[builtins fixtures/exception.pyi]
16+
17+
[case testTryStarBase]
18+
try:
19+
pass
20+
except* BaseException as e:
21+
reveal_type(e) # N: Revealed type is "builtins.BaseExceptionGroup[builtins.BaseException]"
22+
[builtins fixtures/exception.pyi]
23+
24+
[case testTryStarTuple]
25+
class Custom(Exception): ...
26+
27+
try:
28+
pass
29+
except* (RuntimeError, Custom) as e:
30+
reveal_type(e) # N: Revealed type is "builtins.ExceptionGroup[Union[builtins.RuntimeError, __main__.Custom]]"
31+
[builtins fixtures/exception.pyi]
32+
33+
[case testTryStarInvalidType]
34+
class Bad: ...
35+
try:
36+
pass
37+
except* (RuntimeError, Bad) as e: # E: Exception type must be derived from BaseException
38+
reveal_type(e) # N: Revealed type is "builtins.ExceptionGroup[Any]"
39+
[builtins fixtures/exception.pyi]
40+
41+
[case testTryStarGroupInvalid]
42+
try:
43+
pass
44+
except* ExceptionGroup as e: # E: Exception type in except* cannot derive from BaseExceptionGroup
45+
reveal_type(e) # N: Revealed type is "builtins.ExceptionGroup[Any]"
46+
[builtins fixtures/exception.pyi]
47+
48+
[case testTryStarGroupInvalidTuple]
49+
try:
50+
pass
51+
except* (RuntimeError, ExceptionGroup) as e: # E: Exception type in except* cannot derive from BaseExceptionGroup
52+
reveal_type(e) # N: Revealed type is "builtins.ExceptionGroup[Union[builtins.RuntimeError, Any]]"
653
[builtins fixtures/exception.pyi]

test-data/unit/fixtures/exception.pyi

+8-3
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,28 @@
1+
import sys
12
from typing import Generic, TypeVar
23
T = TypeVar('T')
34

45
class object:
56
def __init__(self): pass
67

78
class type: pass
8-
class tuple(Generic[T]): pass
9+
class tuple(Generic[T]):
10+
def __ge__(self, other: object) -> bool: ...
911
class function: pass
1012
class int: pass
1113
class str: pass
1214
class unicode: pass
1315
class bool: pass
1416
class ellipsis: pass
1517

16-
# Note: this is a slight simplification. In Python 2, the inheritance hierarchy
17-
# is actually Exception -> StandardError -> RuntimeError -> ...
1818
class BaseException:
1919
def __init__(self, *args: object) -> None: ...
2020
class Exception(BaseException): pass
2121
class RuntimeError(Exception): pass
2222
class NotImplementedError(RuntimeError): pass
2323

24+
if sys.version_info >= (3, 11):
25+
_BT_co = TypeVar("_BT_co", bound=BaseException, covariant=True)
26+
_T_co = TypeVar("_T_co", bound=Exception, covariant=True)
27+
class BaseExceptionGroup(BaseException, Generic[_BT_co]): ...
28+
class ExceptionGroup(BaseExceptionGroup[_T_co], Exception): ...

0 commit comments

Comments
 (0)