@@ -4305,7 +4305,7 @@ def visit_try_without_finally(self, s: TryStmt, try_frame: bool) -> None:
4305
4305
with self .binder .frame_context (can_skip = True , fall_through = 4 ):
4306
4306
typ = s .types [i ]
4307
4307
if typ :
4308
- t = self .check_except_handler_test (typ )
4308
+ t = self .check_except_handler_test (typ , s . is_star )
4309
4309
var = s .vars [i ]
4310
4310
if var :
4311
4311
# To support local variables, we make this a definition line,
@@ -4325,7 +4325,7 @@ def visit_try_without_finally(self, s: TryStmt, try_frame: bool) -> None:
4325
4325
if s .else_body :
4326
4326
self .accept (s .else_body )
4327
4327
4328
- def check_except_handler_test (self , n : Expression ) -> Type :
4328
+ def check_except_handler_test (self , n : Expression , is_star : bool ) -> Type :
4329
4329
"""Type check an exception handler test clause."""
4330
4330
typ = self .expr_checker .accept (n )
4331
4331
@@ -4341,22 +4341,47 @@ def check_except_handler_test(self, n: Expression) -> Type:
4341
4341
item = ttype .items [0 ]
4342
4342
if not item .is_type_obj ():
4343
4343
self .fail (message_registry .INVALID_EXCEPTION_TYPE , n )
4344
- return AnyType ( TypeOfAny . from_error )
4345
- exc_type = item .ret_type
4344
+ return self . default_exception_type ( is_star )
4345
+ exc_type = erase_typevars ( item .ret_type )
4346
4346
elif isinstance (ttype , TypeType ):
4347
4347
exc_type = ttype .item
4348
4348
else :
4349
4349
self .fail (message_registry .INVALID_EXCEPTION_TYPE , n )
4350
- return AnyType ( TypeOfAny . from_error )
4350
+ return self . default_exception_type ( is_star )
4351
4351
4352
4352
if not is_subtype (exc_type , self .named_type ("builtins.BaseException" )):
4353
4353
self .fail (message_registry .INVALID_EXCEPTION_TYPE , n )
4354
- return AnyType ( TypeOfAny . from_error )
4354
+ return self . default_exception_type ( is_star )
4355
4355
4356
4356
all_types .append (exc_type )
4357
4357
4358
+ if is_star :
4359
+ new_all_types : list [Type ] = []
4360
+ for typ in all_types :
4361
+ if is_proper_subtype (typ , self .named_type ("builtins.BaseExceptionGroup" )):
4362
+ self .fail (message_registry .INVALID_EXCEPTION_GROUP , n )
4363
+ new_all_types .append (AnyType (TypeOfAny .from_error ))
4364
+ else :
4365
+ new_all_types .append (typ )
4366
+ return self .wrap_exception_group (new_all_types )
4358
4367
return make_simplified_union (all_types )
4359
4368
4369
+ def default_exception_type (self , is_star : bool ) -> Type :
4370
+ """Exception type to return in case of a previous type error."""
4371
+ any_type = AnyType (TypeOfAny .from_error )
4372
+ if is_star :
4373
+ return self .named_generic_type ("builtins.ExceptionGroup" , [any_type ])
4374
+ return any_type
4375
+
4376
+ def wrap_exception_group (self , types : Sequence [Type ]) -> Type :
4377
+ """Transform except* variable type into an appropriate exception group."""
4378
+ arg = make_simplified_union (types )
4379
+ if is_subtype (arg , self .named_type ("builtins.Exception" )):
4380
+ base = "builtins.ExceptionGroup"
4381
+ else :
4382
+ base = "builtins.BaseExceptionGroup"
4383
+ return self .named_generic_type (base , [arg ])
4384
+
4360
4385
def get_types_from_except_handler (self , typ : Type , n : Expression ) -> list [Type ]:
4361
4386
"""Helper for check_except_handler_test to retrieve handler types."""
4362
4387
typ = get_proper_type (typ )
0 commit comments