@@ -3201,8 +3201,8 @@ class ASTOptimiziationTests(unittest.TestCase):
3201
3201
def wrap_expr (self , expr ):
3202
3202
return ast .Module (body = [ast .Expr (value = expr )])
3203
3203
3204
- def wrap_for (self , for_statement ):
3205
- return ast .Module (body = [for_statement ])
3204
+ def wrap_statement (self , statement ):
3205
+ return ast .Module (body = [statement ])
3206
3206
3207
3207
def assert_ast (self , code , non_optimized_target , optimized_target ):
3208
3208
@@ -3230,16 +3230,16 @@ def assert_ast(self, code, non_optimized_target, optimized_target):
3230
3230
f"{ ast .dump (optimized_tree )} " ,
3231
3231
)
3232
3232
3233
+ def create_binop (self , operand , left = ast .Constant (1 ), right = ast .Constant (1 )):
3234
+ return ast .BinOp (left = left , op = self .binop [operand ], right = right )
3235
+
3233
3236
def test_folding_binop (self ):
3234
3237
code = "1 %s 1"
3235
3238
operators = self .binop .keys ()
3236
3239
3237
- def create_binop (operand , left = ast .Constant (1 ), right = ast .Constant (1 )):
3238
- return ast .BinOp (left = left , op = self .binop [operand ], right = right )
3239
-
3240
3240
for op in operators :
3241
3241
result_code = code % op
3242
- non_optimized_target = self .wrap_expr (create_binop (op ))
3242
+ non_optimized_target = self .wrap_expr (self . create_binop (op ))
3243
3243
optimized_target = self .wrap_expr (ast .Constant (value = eval (result_code )))
3244
3244
3245
3245
with self .subTest (
@@ -3251,7 +3251,7 @@ def create_binop(operand, left=ast.Constant(1), right=ast.Constant(1)):
3251
3251
3252
3252
# Multiplication of constant tuples must be folded
3253
3253
code = "(1,) * 3"
3254
- non_optimized_target = self .wrap_expr (create_binop ("*" , ast .Tuple (elts = [ast .Constant (value = 1 )]), ast .Constant (value = 3 )))
3254
+ non_optimized_target = self .wrap_expr (self . create_binop ("*" , ast .Tuple (elts = [ast .Constant (value = 1 )]), ast .Constant (value = 3 )))
3255
3255
optimized_target = self .wrap_expr (ast .Constant (eval (code )))
3256
3256
3257
3257
self .assert_ast (code , non_optimized_target , optimized_target )
@@ -3362,12 +3362,12 @@ def test_folding_iter(self):
3362
3362
]
3363
3363
3364
3364
for left , right , ast_cls , optimized_iter in braces :
3365
- non_optimized_target = self .wrap_for (ast .For (
3365
+ non_optimized_target = self .wrap_statement (ast .For (
3366
3366
target = ast .Name (id = "_" , ctx = ast .Store ()),
3367
3367
iter = ast_cls (elts = [ast .Constant (1 )]),
3368
3368
body = [ast .Pass ()]
3369
3369
))
3370
- optimized_target = self .wrap_for (ast .For (
3370
+ optimized_target = self .wrap_statement (ast .For (
3371
3371
target = ast .Name (id = "_" , ctx = ast .Store ()),
3372
3372
iter = ast .Constant (value = optimized_iter ),
3373
3373
body = [ast .Pass ()]
@@ -3385,6 +3385,92 @@ def test_folding_subscript(self):
3385
3385
3386
3386
self .assert_ast (code , non_optimized_target , optimized_target )
3387
3387
3388
+ def test_folding_type_param_in_function_def (self ):
3389
+ code = "def foo[%s = 1 + 1](): pass"
3390
+
3391
+ unoptimized_binop = self .create_binop ("+" )
3392
+ unoptimized_type_params = [
3393
+ ("T" , "T" , ast .TypeVar ),
3394
+ ("**P" , "P" , ast .ParamSpec ),
3395
+ ("*Ts" , "Ts" , ast .TypeVarTuple ),
3396
+ ]
3397
+
3398
+ for type , name , type_param in unoptimized_type_params :
3399
+ result_code = code % type
3400
+ optimized_target = self .wrap_statement (
3401
+ ast .FunctionDef (
3402
+ name = 'foo' ,
3403
+ args = ast .arguments (),
3404
+ body = [ast .Pass ()],
3405
+ type_params = [type_param (name = name , default_value = ast .Constant (2 ))]
3406
+ )
3407
+ )
3408
+ non_optimized_target = self .wrap_statement (
3409
+ ast .FunctionDef (
3410
+ name = 'foo' ,
3411
+ args = ast .arguments (),
3412
+ body = [ast .Pass ()],
3413
+ type_params = [type_param (name = name , default_value = unoptimized_binop )]
3414
+ )
3415
+ )
3416
+ self .assert_ast (result_code , non_optimized_target , optimized_target )
3417
+
3418
+ def test_folding_type_param_in_class_def (self ):
3419
+ code = "class foo[%s = 1 + 1]: pass"
3420
+
3421
+ unoptimized_binop = self .create_binop ("+" )
3422
+ unoptimized_type_params = [
3423
+ ("T" , "T" , ast .TypeVar ),
3424
+ ("**P" , "P" , ast .ParamSpec ),
3425
+ ("*Ts" , "Ts" , ast .TypeVarTuple ),
3426
+ ]
3427
+
3428
+ for type , name , type_param in unoptimized_type_params :
3429
+ result_code = code % type
3430
+ optimized_target = self .wrap_statement (
3431
+ ast .ClassDef (
3432
+ name = 'foo' ,
3433
+ body = [ast .Pass ()],
3434
+ type_params = [type_param (name = name , default_value = ast .Constant (2 ))]
3435
+ )
3436
+ )
3437
+ non_optimized_target = self .wrap_statement (
3438
+ ast .ClassDef (
3439
+ name = 'foo' ,
3440
+ body = [ast .Pass ()],
3441
+ type_params = [type_param (name = name , default_value = unoptimized_binop )]
3442
+ )
3443
+ )
3444
+ self .assert_ast (result_code , non_optimized_target , optimized_target )
3445
+
3446
+ def test_folding_type_param_in_type_alias (self ):
3447
+ code = "type foo[%s = 1 + 1] = 1"
3448
+
3449
+ unoptimized_binop = self .create_binop ("+" )
3450
+ unoptimized_type_params = [
3451
+ ("T" , "T" , ast .TypeVar ),
3452
+ ("**P" , "P" , ast .ParamSpec ),
3453
+ ("*Ts" , "Ts" , ast .TypeVarTuple ),
3454
+ ]
3455
+
3456
+ for type , name , type_param in unoptimized_type_params :
3457
+ result_code = code % type
3458
+ optimized_target = self .wrap_statement (
3459
+ ast .TypeAlias (
3460
+ name = ast .Name (id = 'foo' , ctx = ast .Store ()),
3461
+ type_params = [type_param (name = name , default_value = ast .Constant (2 ))],
3462
+ value = ast .Constant (value = 1 ),
3463
+ )
3464
+ )
3465
+ non_optimized_target = self .wrap_statement (
3466
+ ast .TypeAlias (
3467
+ name = ast .Name (id = 'foo' , ctx = ast .Store ()),
3468
+ type_params = [type_param (name = name , default_value = unoptimized_binop )],
3469
+ value = ast .Constant (value = 1 ),
3470
+ )
3471
+ )
3472
+ self .assert_ast (result_code , non_optimized_target , optimized_target )
3473
+
3388
3474
3389
3475
if __name__ == "__main__" :
3390
3476
unittest .main ()
0 commit comments