@@ -3062,8 +3062,8 @@ class ASTOptimiziationTests(unittest.TestCase):
3062
3062
def wrap_expr (self , expr ):
3063
3063
return ast .Module (body = [ast .Expr (value = expr )])
3064
3064
3065
- def wrap_for (self , for_statement ):
3066
- return ast .Module (body = [for_statement ])
3065
+ def wrap_statement (self , statement ):
3066
+ return ast .Module (body = [statement ])
3067
3067
3068
3068
def assert_ast (self , code , non_optimized_target , optimized_target ):
3069
3069
non_optimized_tree = ast .parse (code , optimize = - 1 )
@@ -3090,16 +3090,16 @@ def assert_ast(self, code, non_optimized_target, optimized_target):
3090
3090
f"{ ast .dump (optimized_tree )} " ,
3091
3091
)
3092
3092
3093
+ def create_binop (self , operand , left = ast .Constant (1 ), right = ast .Constant (1 )):
3094
+ return ast .BinOp (left = left , op = self .binop [operand ], right = right )
3095
+
3093
3096
def test_folding_binop (self ):
3094
3097
code = "1 %s 1"
3095
3098
operators = self .binop .keys ()
3096
3099
3097
- def create_binop (operand , left = ast .Constant (1 ), right = ast .Constant (1 )):
3098
- return ast .BinOp (left = left , op = self .binop [operand ], right = right )
3099
-
3100
3100
for op in operators :
3101
3101
result_code = code % op
3102
- non_optimized_target = self .wrap_expr (create_binop (op ))
3102
+ non_optimized_target = self .wrap_expr (self . create_binop (op ))
3103
3103
optimized_target = self .wrap_expr (ast .Constant (value = eval (result_code )))
3104
3104
3105
3105
with self .subTest (
@@ -3111,7 +3111,7 @@ def create_binop(operand, left=ast.Constant(1), right=ast.Constant(1)):
3111
3111
3112
3112
# Multiplication of constant tuples must be folded
3113
3113
code = "(1,) * 3"
3114
- non_optimized_target = self .wrap_expr (create_binop ("*" , ast .Tuple (elts = [ast .Constant (value = 1 )]), ast .Constant (value = 3 )))
3114
+ non_optimized_target = self .wrap_expr (self . create_binop ("*" , ast .Tuple (elts = [ast .Constant (value = 1 )]), ast .Constant (value = 3 )))
3115
3115
optimized_target = self .wrap_expr (ast .Constant (eval (code )))
3116
3116
3117
3117
self .assert_ast (code , non_optimized_target , optimized_target )
@@ -3222,12 +3222,12 @@ def test_folding_iter(self):
3222
3222
]
3223
3223
3224
3224
for left , right , ast_cls , optimized_iter in braces :
3225
- non_optimized_target = self .wrap_for (ast .For (
3225
+ non_optimized_target = self .wrap_statement (ast .For (
3226
3226
target = ast .Name (id = "_" , ctx = ast .Store ()),
3227
3227
iter = ast_cls (elts = [ast .Constant (1 )]),
3228
3228
body = [ast .Pass ()]
3229
3229
))
3230
- optimized_target = self .wrap_for (ast .For (
3230
+ optimized_target = self .wrap_statement (ast .For (
3231
3231
target = ast .Name (id = "_" , ctx = ast .Store ()),
3232
3232
iter = ast .Constant (value = optimized_iter ),
3233
3233
body = [ast .Pass ()]
@@ -3245,6 +3245,92 @@ def test_folding_subscript(self):
3245
3245
3246
3246
self .assert_ast (code , non_optimized_target , optimized_target )
3247
3247
3248
+ def test_folding_type_param_in_function_def (self ):
3249
+ code = "def foo[%s = 1 + 1](): pass"
3250
+
3251
+ unoptimized_binop = self .create_binop ("+" )
3252
+ unoptimized_type_params = [
3253
+ ("T" , "T" , ast .TypeVar ),
3254
+ ("**P" , "P" , ast .ParamSpec ),
3255
+ ("*Ts" , "Ts" , ast .TypeVarTuple ),
3256
+ ]
3257
+
3258
+ for type , name , type_param in unoptimized_type_params :
3259
+ result_code = code % type
3260
+ optimized_target = self .wrap_statement (
3261
+ ast .FunctionDef (
3262
+ name = 'foo' ,
3263
+ args = ast .arguments (),
3264
+ body = [ast .Pass ()],
3265
+ type_params = [type_param (name = name , default_value = ast .Constant (2 ))]
3266
+ )
3267
+ )
3268
+ non_optimized_target = self .wrap_statement (
3269
+ ast .FunctionDef (
3270
+ name = 'foo' ,
3271
+ args = ast .arguments (),
3272
+ body = [ast .Pass ()],
3273
+ type_params = [type_param (name = name , default_value = unoptimized_binop )]
3274
+ )
3275
+ )
3276
+ self .assert_ast (result_code , non_optimized_target , optimized_target )
3277
+
3278
+ def test_folding_type_param_in_class_def (self ):
3279
+ code = "class foo[%s = 1 + 1]: pass"
3280
+
3281
+ unoptimized_binop = self .create_binop ("+" )
3282
+ unoptimized_type_params = [
3283
+ ("T" , "T" , ast .TypeVar ),
3284
+ ("**P" , "P" , ast .ParamSpec ),
3285
+ ("*Ts" , "Ts" , ast .TypeVarTuple ),
3286
+ ]
3287
+
3288
+ for type , name , type_param in unoptimized_type_params :
3289
+ result_code = code % type
3290
+ optimized_target = self .wrap_statement (
3291
+ ast .ClassDef (
3292
+ name = 'foo' ,
3293
+ body = [ast .Pass ()],
3294
+ type_params = [type_param (name = name , default_value = ast .Constant (2 ))]
3295
+ )
3296
+ )
3297
+ non_optimized_target = self .wrap_statement (
3298
+ ast .ClassDef (
3299
+ name = 'foo' ,
3300
+ body = [ast .Pass ()],
3301
+ type_params = [type_param (name = name , default_value = unoptimized_binop )]
3302
+ )
3303
+ )
3304
+ self .assert_ast (result_code , non_optimized_target , optimized_target )
3305
+
3306
+ def test_folding_type_param_in_type_alias (self ):
3307
+ code = "type foo[%s = 1 + 1] = 1"
3308
+
3309
+ unoptimized_binop = self .create_binop ("+" )
3310
+ unoptimized_type_params = [
3311
+ ("T" , "T" , ast .TypeVar ),
3312
+ ("**P" , "P" , ast .ParamSpec ),
3313
+ ("*Ts" , "Ts" , ast .TypeVarTuple ),
3314
+ ]
3315
+
3316
+ for type , name , type_param in unoptimized_type_params :
3317
+ result_code = code % type
3318
+ optimized_target = self .wrap_statement (
3319
+ ast .TypeAlias (
3320
+ name = ast .Name (id = 'foo' , ctx = ast .Store ()),
3321
+ type_params = [type_param (name = name , default_value = ast .Constant (2 ))],
3322
+ value = ast .Constant (value = 1 ),
3323
+ )
3324
+ )
3325
+ non_optimized_target = self .wrap_statement (
3326
+ ast .TypeAlias (
3327
+ name = ast .Name (id = 'foo' , ctx = ast .Store ()),
3328
+ type_params = [type_param (name = name , default_value = unoptimized_binop )],
3329
+ value = ast .Constant (value = 1 ),
3330
+ )
3331
+ )
3332
+ self .assert_ast (result_code , non_optimized_target , optimized_target )
3333
+
3248
3334
3249
3335
if __name__ == "__main__" :
3250
3336
unittest .main ()
0 commit comments