Skip to content

Commit aefb963

Browse files
committed
Add 'overwrite_existing' flag to allow graph rewrites
and include appropriate testing
1 parent 2f1d25a commit aefb963

File tree

2 files changed

+65
-13
lines changed

2 files changed

+65
-13
lines changed

pytensor/graph/rewriting/db.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def register(
3535
rewriter: Union["RewriteDatabase", RewritesType],
3636
*tags: str,
3737
use_db_name_as_tag=True,
38+
overwrite_existing=False,
3839
):
3940
"""Register a new rewriter to the database.
4041
@@ -56,7 +57,8 @@ def register(
5657
``local_remove_all_assert``. Setting `use_db_name_as_tag` to
5758
``False`` removes that behavior. This means that only the rewrite's name
5859
and/or its tags will enable it.
59-
60+
overwrite_existing:
61+
Overwrite the existing rewriter with a new one having the same name
6062
"""
6163
if not isinstance(
6264
rewriter,
@@ -66,22 +68,27 @@ def register(
6668
):
6769
raise TypeError(f"{rewriter} is not a valid rewrite type.")
6870

69-
if name in self.__db__:
70-
raise ValueError(f"The tag '{name}' is already present in the database.")
71-
7271
if use_db_name_as_tag:
7372
if self.name is not None:
7473
tags = (*tags, self.name)
7574

7675
rewriter.name = name
77-
# This restriction is there because in many place we suppose that
78-
# something in the RewriteDatabase is there only once.
79-
if rewriter.name in self.__db__:
80-
raise ValueError(
81-
f"Tried to register {rewriter.name} again under the new name {name}. "
82-
"The same rewrite cannot be registered multiple times in"
83-
" an `RewriteDatabase`; use `ProxyDB` instead."
84-
)
76+
77+
# if tag collides with name
78+
if name in self.__db__ and name not in self._names:
79+
raise ValueError(f"The tag '{name}' is already present in the database.")
80+
81+
if name in self.__db__ or rewriter.name in self.__db__:
82+
if overwrite_existing:
83+
self.remove_tags(name, *tags)
84+
old_rewriter = self.__db__[name].pop()
85+
self._names.remove(name)
86+
self.__db__[old_rewriter.__class__.__name__].remove(old_rewriter)
87+
else:
88+
raise ValueError(
89+
f"The tag '{name}' is already present in the database."
90+
)
91+
8592
self.__db__[name] = OrderedSet([rewriter])
8693
self._names.add(name)
8794
self.__db__[rewriter.__class__.__name__].add(rewriter)

tests/graph/rewriting/test_db.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pytest
22

3+
from pytensor.graph.fg import FunctionGraph
34
from pytensor.graph.rewriting.basic import GraphRewriter, SequentialGraphRewriter
45
from pytensor.graph.rewriting.db import (
56
EquilibriumDB,
@@ -17,6 +18,30 @@ def apply(self, fgraph):
1718
pass
1819

1920

21+
class NewTestRewriter(GraphRewriter):
22+
name = "bleh"
23+
24+
def apply(self, fgraph):
25+
pass
26+
27+
28+
counter1 = 0
29+
30+
counter2 = 0
31+
32+
33+
class TestOverwrite1(GraphRewriter):
34+
def apply(self, fgraph):
35+
global counter1
36+
counter1 += 1
37+
38+
39+
class TestOverwrite2(GraphRewriter):
40+
def apply(self, fgraph):
41+
global counter2
42+
counter2 += 1
43+
44+
2045
class TestDB:
2146
def test_register(self):
2247
db = RewriteDatabase()
@@ -31,7 +56,9 @@ def test_register(self):
3156
assert "c" in db
3257

3358
with pytest.raises(ValueError, match=r"The tag.*"):
34-
db.register("c", TestRewriter()) # name taken
59+
db.register("c", NewTestRewriter()) # name taken
60+
61+
db.register("c", NewTestRewriter(), overwrite_existing=True)
3562

3663
with pytest.raises(ValueError, match=r"The tag.*"):
3764
db.register("z", TestRewriter()) # name collides with tag
@@ -42,6 +69,24 @@ def test_register(self):
4269
with pytest.raises(TypeError, match=r".* is not a valid.*"):
4370
db.register("d", 1)
4471

72+
def test_overwrite(self):
73+
db = RewriteDatabase()
74+
fg = FunctionGraph([], [])
75+
76+
db.register("a", TestRewriter())
77+
Rewriter = db.__getitem__("a")
78+
Rewriter.rewrite(fg)
79+
80+
db.register("a", TestOverwrite1(), overwrite_existing=True)
81+
Rewriter = db.__getitem__("a")
82+
Rewriter.rewrite(fg)
83+
assert counter1 == 1 and counter2 == 0
84+
85+
db.register("a", TestOverwrite2(), overwrite_existing=True)
86+
Rewriter = db.__getitem__("a")
87+
Rewriter.rewrite(fg)
88+
assert counter1 == 1 and counter2 == 1
89+
4590
def test_EquilibriumDB(self):
4691
eq_db = EquilibriumDB()
4792

0 commit comments

Comments
 (0)