Skip to content

Commit be083ce

Browse files
wrongnullEclips4JelleZijlstra
authored
gh-123344: Add missing ast optimizations for PEP 696 (#123377)
Co-authored-by: Kirill Podoprigora <[email protected]> Co-authored-by: Jelle Zijlstra <[email protected]>
1 parent 9e108b8 commit be083ce

File tree

3 files changed

+99
-9
lines changed

3 files changed

+99
-9
lines changed

Lib/test/test_ast/test_ast.py

Lines changed: 95 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3062,8 +3062,8 @@ class ASTOptimiziationTests(unittest.TestCase):
30623062
def wrap_expr(self, expr):
30633063
return ast.Module(body=[ast.Expr(value=expr)])
30643064

3065-
def wrap_for(self, for_statement):
3066-
return ast.Module(body=[for_statement])
3065+
def wrap_statement(self, statement):
3066+
return ast.Module(body=[statement])
30673067

30683068
def assert_ast(self, code, non_optimized_target, optimized_target):
30693069
non_optimized_tree = ast.parse(code, optimize=-1)
@@ -3090,16 +3090,16 @@ def assert_ast(self, code, non_optimized_target, optimized_target):
30903090
f"{ast.dump(optimized_tree)}",
30913091
)
30923092

3093+
def create_binop(self, operand, left=ast.Constant(1), right=ast.Constant(1)):
3094+
return ast.BinOp(left=left, op=self.binop[operand], right=right)
3095+
30933096
def test_folding_binop(self):
30943097
code = "1 %s 1"
30953098
operators = self.binop.keys()
30963099

3097-
def create_binop(operand, left=ast.Constant(1), right=ast.Constant(1)):
3098-
return ast.BinOp(left=left, op=self.binop[operand], right=right)
3099-
31003100
for op in operators:
31013101
result_code = code % op
3102-
non_optimized_target = self.wrap_expr(create_binop(op))
3102+
non_optimized_target = self.wrap_expr(self.create_binop(op))
31033103
optimized_target = self.wrap_expr(ast.Constant(value=eval(result_code)))
31043104

31053105
with self.subTest(
@@ -3111,7 +3111,7 @@ def create_binop(operand, left=ast.Constant(1), right=ast.Constant(1)):
31113111

31123112
# Multiplication of constant tuples must be folded
31133113
code = "(1,) * 3"
3114-
non_optimized_target = self.wrap_expr(create_binop("*", ast.Tuple(elts=[ast.Constant(value=1)]), ast.Constant(value=3)))
3114+
non_optimized_target = self.wrap_expr(self.create_binop("*", ast.Tuple(elts=[ast.Constant(value=1)]), ast.Constant(value=3)))
31153115
optimized_target = self.wrap_expr(ast.Constant(eval(code)))
31163116

31173117
self.assert_ast(code, non_optimized_target, optimized_target)
@@ -3222,12 +3222,12 @@ def test_folding_iter(self):
32223222
]
32233223

32243224
for left, right, ast_cls, optimized_iter in braces:
3225-
non_optimized_target = self.wrap_for(ast.For(
3225+
non_optimized_target = self.wrap_statement(ast.For(
32263226
target=ast.Name(id="_", ctx=ast.Store()),
32273227
iter=ast_cls(elts=[ast.Constant(1)]),
32283228
body=[ast.Pass()]
32293229
))
3230-
optimized_target = self.wrap_for(ast.For(
3230+
optimized_target = self.wrap_statement(ast.For(
32313231
target=ast.Name(id="_", ctx=ast.Store()),
32323232
iter=ast.Constant(value=optimized_iter),
32333233
body=[ast.Pass()]
@@ -3245,6 +3245,92 @@ def test_folding_subscript(self):
32453245

32463246
self.assert_ast(code, non_optimized_target, optimized_target)
32473247

3248+
def test_folding_type_param_in_function_def(self):
3249+
code = "def foo[%s = 1 + 1](): pass"
3250+
3251+
unoptimized_binop = self.create_binop("+")
3252+
unoptimized_type_params = [
3253+
("T", "T", ast.TypeVar),
3254+
("**P", "P", ast.ParamSpec),
3255+
("*Ts", "Ts", ast.TypeVarTuple),
3256+
]
3257+
3258+
for type, name, type_param in unoptimized_type_params:
3259+
result_code = code % type
3260+
optimized_target = self.wrap_statement(
3261+
ast.FunctionDef(
3262+
name='foo',
3263+
args=ast.arguments(),
3264+
body=[ast.Pass()],
3265+
type_params=[type_param(name=name, default_value=ast.Constant(2))]
3266+
)
3267+
)
3268+
non_optimized_target = self.wrap_statement(
3269+
ast.FunctionDef(
3270+
name='foo',
3271+
args=ast.arguments(),
3272+
body=[ast.Pass()],
3273+
type_params=[type_param(name=name, default_value=unoptimized_binop)]
3274+
)
3275+
)
3276+
self.assert_ast(result_code, non_optimized_target, optimized_target)
3277+
3278+
def test_folding_type_param_in_class_def(self):
3279+
code = "class foo[%s = 1 + 1]: pass"
3280+
3281+
unoptimized_binop = self.create_binop("+")
3282+
unoptimized_type_params = [
3283+
("T", "T", ast.TypeVar),
3284+
("**P", "P", ast.ParamSpec),
3285+
("*Ts", "Ts", ast.TypeVarTuple),
3286+
]
3287+
3288+
for type, name, type_param in unoptimized_type_params:
3289+
result_code = code % type
3290+
optimized_target = self.wrap_statement(
3291+
ast.ClassDef(
3292+
name='foo',
3293+
body=[ast.Pass()],
3294+
type_params=[type_param(name=name, default_value=ast.Constant(2))]
3295+
)
3296+
)
3297+
non_optimized_target = self.wrap_statement(
3298+
ast.ClassDef(
3299+
name='foo',
3300+
body=[ast.Pass()],
3301+
type_params=[type_param(name=name, default_value=unoptimized_binop)]
3302+
)
3303+
)
3304+
self.assert_ast(result_code, non_optimized_target, optimized_target)
3305+
3306+
def test_folding_type_param_in_type_alias(self):
3307+
code = "type foo[%s = 1 + 1] = 1"
3308+
3309+
unoptimized_binop = self.create_binop("+")
3310+
unoptimized_type_params = [
3311+
("T", "T", ast.TypeVar),
3312+
("**P", "P", ast.ParamSpec),
3313+
("*Ts", "Ts", ast.TypeVarTuple),
3314+
]
3315+
3316+
for type, name, type_param in unoptimized_type_params:
3317+
result_code = code % type
3318+
optimized_target = self.wrap_statement(
3319+
ast.TypeAlias(
3320+
name=ast.Name(id='foo', ctx=ast.Store()),
3321+
type_params=[type_param(name=name, default_value=ast.Constant(2))],
3322+
value=ast.Constant(value=1),
3323+
)
3324+
)
3325+
non_optimized_target = self.wrap_statement(
3326+
ast.TypeAlias(
3327+
name=ast.Name(id='foo', ctx=ast.Store()),
3328+
type_params=[type_param(name=name, default_value=unoptimized_binop)],
3329+
value=ast.Constant(value=1),
3330+
)
3331+
)
3332+
self.assert_ast(result_code, non_optimized_target, optimized_target)
3333+
32483334

32493335
if __name__ == "__main__":
32503336
unittest.main()
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add AST optimizations for type parameter defaults.

Python/ast_opt.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1087,10 +1087,13 @@ astfold_type_param(type_param_ty node_, PyArena *ctx_, _PyASTOptimizeState *stat
10871087
switch (node_->kind) {
10881088
case TypeVar_kind:
10891089
CALL_OPT(astfold_expr, expr_ty, node_->v.TypeVar.bound);
1090+
CALL_OPT(astfold_expr, expr_ty, node_->v.TypeVar.default_value);
10901091
break;
10911092
case ParamSpec_kind:
1093+
CALL_OPT(astfold_expr, expr_ty, node_->v.ParamSpec.default_value);
10921094
break;
10931095
case TypeVarTuple_kind:
1096+
CALL_OPT(astfold_expr, expr_ty, node_->v.TypeVarTuple.default_value);
10941097
break;
10951098
}
10961099
return 1;

0 commit comments

Comments
 (0)