Skip to content

Commit 6cdae5f

Browse files
committed
[IMP] util/misc: more coverage
More nodes + more tests
1 parent 4dd2333 commit 6cdae5f

File tree

2 files changed

+43
-8
lines changed

2 files changed

+43
-8
lines changed

src/base/tests/test_util.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -1686,6 +1686,25 @@ def adapter(node):
16861686
self.assertEqual(repl, "result = this.get('x') if that.y == 2 else this.get('z')")
16871687

16881688
@unittest.skipUnless(util.ast_unparse is not None, "`ast.unparse` available from Python3.9")
1689+
@unittest.skipUnless(hasattr(ast, "Constant"), "`ast.Constant` available from Python3.6")
1690+
def test_literal_replace_callable2(self):
1691+
def adapter2(node):
1692+
return ast.parse("{} / 42".format(node.left.value), mode="eval").body
1693+
1694+
def adapter3(node):
1695+
return ast.parse("y == 42", mode="eval").body
1696+
1697+
repl = util.literal_replace(
1698+
"16 * w or y == 2",
1699+
{
1700+
ast.BinOp(ast.Constant(16), ast.Mult(), ast.Name(util.literal_replace.WILDCARD)): adapter2,
1701+
ast.Compare(ast.Name("y"), [ast.Eq()], [ast.Constant(util.literal_replace.WILDCARD)]): adapter3,
1702+
},
1703+
)
1704+
self.assertEqual(repl, "16 / 42 or y == 42")
1705+
1706+
@unittest.skipUnless(util.ast_unparse is not None, "`ast.unparse` available from Python3.9")
1707+
@unittest.skipUnless(hasattr(ast, "Constant"), "`ast.Constant` available from Python3.6")
16891708
def test_literal_replace_wildcards(self):
16901709
repl = util.literal_replace(
16911710
"x+1 - z* 18",
@@ -1695,7 +1714,7 @@ def test_literal_replace_wildcards(self):
16951714

16961715
@parametrize(
16971716
[
1698-
(ast.Attribute(ast.Name(util.literal_replace.WILDCARD), util.literal_replace.WILDCARD), "*.*"),
1717+
(ast.Attribute(ast.Name(util.literal_replace.WILDCARD, None), util.literal_replace.WILDCARD, None), "*.*"),
16991718
(ast.Constant(util.literal_replace.WILDCARD), "*"),
17001719
(
17011720
ast.BinOp(

src/util/misc.py

+23-7
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,9 @@ def __init__(self, mapping):
527527
def _no_match(self, left, right):
528528
return False
529529

530+
def _all_match(self, lefts, rights):
531+
return len(lefts) == len(rights) and all(self._match(left_, right_) for left_, right_ in zip(lefts, rights))
532+
530533
def _match(self, left, right):
531534
same_ctx = not hasattr(right, "ctx") or getattr(left, "ctx", None).__class__ is right.ctx.__class__
532535
cname = left.__class__.__name__
@@ -565,20 +568,33 @@ def _match_Tuple(self, left, right):
565568
return self._match_List(left, right)
566569

567570
def _match_ListComp(self, left, right):
568-
return (
569-
self._match(left.elt, right.elt)
570-
and len(left.generators) == len(right.generators)
571-
and all(self._match(left_, right_) for left_, right_ in zip(left.generators, right.generators))
572-
)
571+
return self._match(left.elt, right.elt) and self._all_match(left.generators, right.generators)
573572

574573
def _match_comprehension(self, left, right):
575574
return (
576575
# async is not expected in our use cases, just for completeness
577576
getattr(left, "is_async", 0) == getattr(right, "is_async", 0)
578-
and len(left.ifs) == len(right.ifs)
579577
and self._match(left.target, right.target)
580578
and self._match(left.iter, right.iter)
581-
and all(self._match(left_, right_) for left_, right_ in zip(left.ifs, right.ifs))
579+
and self._all_match(left.ifs, right.ifs)
580+
)
581+
582+
def _match_BinOp(self, left, right):
583+
return (
584+
left.op.__class__ is right.op.__class__
585+
and self._match(left.left, right.left)
586+
and self._match(left.right, right.right)
587+
)
588+
589+
def _match_BoolOp(self, left, right):
590+
return left.op.__class__ is right.op.__class__ and self._all_match(left.values, right.values)
591+
592+
def _match_Compare(self, left, right):
593+
return (
594+
len(left.ops) == len(right.ops)
595+
and all(lop.__class__ is rop.__class__ for lop, rop in zip(left.ops, right.ops))
596+
and self._match(left.left, right.left)
597+
and self._all_match(left.comparators, right.comparators)
582598
)
583599

584600
def visit(self, node):

0 commit comments

Comments
 (0)