Skip to content

Commit 4dd2333

Browse files
committed
[IMP] util/misc: dynamic literal_replace
Allow to match in a glob-like fashion and produce the replacement node dynamically.
1 parent f6f772e commit 4dd2333

File tree

2 files changed

+58
-10
lines changed

2 files changed

+58
-10
lines changed

src/base/tests/test_util.py

+35
Original file line numberDiff line numberDiff line change
@@ -1674,6 +1674,41 @@ def test_literal_replace(self, orig, expected):
16741674
)
16751675
self.assertEqual(repl, expected)
16761676

1677+
@unittest.skipUnless(util.ast_unparse is not None, "`ast.unparse` available from Python3.9")
1678+
def test_literal_replace_callable(self):
1679+
def adapter(node):
1680+
return ast.parse("this.get('{}')".format(node.attr), mode="eval").body
1681+
1682+
repl = util.literal_replace(
1683+
"result = this. x if that . y == 2 else this .z",
1684+
{ast.Attribute(ast.Name("this"), util.literal_replace.WILDCARD): adapter},
1685+
)
1686+
self.assertEqual(repl, "result = this.get('x') if that.y == 2 else this.get('z')")
1687+
1688+
@unittest.skipUnless(util.ast_unparse is not None, "`ast.unparse` available from Python3.9")
1689+
def test_literal_replace_wildcards(self):
1690+
repl = util.literal_replace(
1691+
"x+1 - z* 18",
1692+
{ast.Name(util.literal_replace.WILDCARD): "y", ast.Constant(util.literal_replace.WILDCARD): "2"},
1693+
)
1694+
self.assertEqual(repl, "y + 2 - y * 2")
1695+
1696+
@parametrize(
1697+
[
1698+
(ast.Attribute(ast.Name(util.literal_replace.WILDCARD), util.literal_replace.WILDCARD), "*.*"),
1699+
(ast.Constant(util.literal_replace.WILDCARD), "*"),
1700+
(
1701+
ast.BinOp(
1702+
ast.Constant(util.literal_replace.WILDCARD), ast.Add(), ast.Constant(util.literal_replace.WILDCARD)
1703+
),
1704+
"* + *",
1705+
),
1706+
]
1707+
)
1708+
@unittest.skipUnless(util.ast_unparse is not None, "`ast.unparse` available from Python3.9")
1709+
def test_literal_replace_wildcard_unparse(self, orig, expected):
1710+
self.assertEqual(ast.unparse(orig), expected)
1711+
16771712

16781713
def not_doing_anything_converter(el):
16791714
return True

src/util/misc.py

+23-10
Original file line numberDiff line numberDiff line change
@@ -510,7 +510,7 @@ def visit_UnaryOp(self, node):
510510
return self._replace_node("_upg_UnaryOp_Not", node) if isinstance(node.op, ast.Not) else node
511511

512512
replacer = RewriteName()
513-
root = ast.parse(expr.strip(), mode="eval").body
513+
root = ast.parse(expr.strip(), mode="exec")
514514
visited = replacer.visit(root)
515515
return (ast_unparse(visited).strip(), SelfPrintEvalContext(replacer.replaces))
516516

@@ -521,36 +521,40 @@ class _Replacer(ast.NodeTransformer):
521521
def __init__(self, mapping):
522522
self.mapping = collections.defaultdict(list)
523523
for key, value in mapping.items():
524-
key_ast = ast.parse(key, mode="eval").body
524+
key_ast = key if isinstance(key, ast.AST) else ast.parse(key, mode="eval").body
525525
self.mapping[key_ast.__class__.__name__].append((key_ast, value))
526526

527527
def _no_match(self, left, right):
528528
return False
529529

530530
def _match(self, left, right):
531-
same_ctx = getattr(left, "ctx", None).__class__ is getattr(right, "ctx", None).__class__
531+
same_ctx = not hasattr(right, "ctx") or getattr(left, "ctx", None).__class__ is right.ctx.__class__
532532
cname = left.__class__.__name__
533533
same_type = right.__class__.__name__ == cname
534534
matcher = getattr(self, "_match_" + cname, self._no_match)
535535
return matcher(left, right) if same_type and same_ctx else False
536536

537537
def _match_Constant(self, left, right):
538538
# we don't care about kind for u-strings
539-
return type(left.value) is type(right.value) and left.value == right.value
539+
return right.value is literal_replace.WILDCARD or (
540+
type(left.value) is type(right.value) and left.value == right.value
541+
)
540542

541543
def _match_Num(self, left, right):
542544
# Dreprecated, for Python <3.8
543-
return type(left.n) is type(right.n) and left.n == right.n
545+
return right.n is literal_replace.WILDCARD or (type(left.n) is type(right.n) and left.n == right.n)
544546

545547
def _match_Str(self, left, right):
546548
# Deprecated, for Python <3.8
547-
return left.s == right.s
549+
return right.s is literal_replace.WILDCARD or left.s == right.s
548550

549551
def _match_Name(self, left, right):
550-
return left.id == right.id
552+
return right.id is literal_replace.WILDCARD or left.id == right.id
551553

552554
def _match_Attribute(self, left, right):
553-
return left.attr == right.attr and self._match(left.value, right.value)
555+
return (right.attr is literal_replace.WILDCARD or left.attr == right.attr) and self._match(
556+
left.value, right.value
557+
)
554558

555559
def _match_List(self, left, right):
556560
return len(left.elts) == len(right.elts) and all(
@@ -581,14 +585,23 @@ def visit(self, node):
581585
if node.__class__.__name__ in self.mapping:
582586
for target_ast, new in self.mapping[node.__class__.__name__]:
583587
if self._match(node, target_ast):
584-
return ast.parse(new, mode="eval").body
588+
# since we use eval below, non-callable must be an expression (no assignment)
589+
return new(node) if callable(new) else ast.parse(new, mode="eval").body
585590
return super(_Replacer, self).visit(node)
586591

587592

588593
def literal_replace(expr, mapping):
589594
if ast_unparse is None:
590595
_logger.critical("AST unparse unavailable")
591596
return expr
592-
root = ast.parse(expr.strip(), mode="eval").body
597+
root = ast.parse(expr.strip(), mode="exec")
593598
visited = _Replacer(mapping).visit(root)
594599
return ast_unparse(visited).strip()
600+
601+
602+
class _Symbol(str):
603+
def __repr__(self):
604+
return str(self)
605+
606+
607+
literal_replace.WILDCARD = _Symbol("*")

0 commit comments

Comments
 (0)