Skip to content

Commit 5fb56ba

Browse files
AdvH039ricardoV94
andauthored
Add overwrite_existing flag (#1119)
* Add 'overwrite_existing' flag to allow graph rewrites and include appropriate testing * Encapsulate test rewriters and use user-facing API --------- Co-authored-by: Ricardo Vieira <[email protected]>
1 parent b065112 commit 5fb56ba

File tree

2 files changed

+62
-13
lines changed

2 files changed

+62
-13
lines changed

pytensor/graph/rewriting/db.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def register(
3535
rewriter: Union["RewriteDatabase", RewritesType],
3636
*tags: str,
3737
use_db_name_as_tag=True,
38+
overwrite_existing=False,
3839
):
3940
"""Register a new rewriter to the database.
4041
@@ -56,7 +57,8 @@ def register(
5657
``local_remove_all_assert``. Setting `use_db_name_as_tag` to
5758
``False`` removes that behavior. This means that only the rewrite's name
5859
and/or its tags will enable it.
59-
60+
overwrite_existing:
61+
Overwrite the existing rewriter with a new one having the same name
6062
"""
6163
if not isinstance(
6264
rewriter,
@@ -66,22 +68,27 @@ def register(
6668
):
6769
raise TypeError(f"{rewriter} is not a valid rewrite type.")
6870

69-
if name in self.__db__:
70-
raise ValueError(f"The tag '{name}' is already present in the database.")
71-
7271
if use_db_name_as_tag:
7372
if self.name is not None:
7473
tags = (*tags, self.name)
7574

7675
rewriter.name = name
77-
# This restriction is there because in many place we suppose that
78-
# something in the RewriteDatabase is there only once.
79-
if rewriter.name in self.__db__:
80-
raise ValueError(
81-
f"Tried to register {rewriter.name} again under the new name {name}. "
82-
"The same rewrite cannot be registered multiple times in"
83-
" an `RewriteDatabase`; use `ProxyDB` instead."
84-
)
76+
77+
# if tag collides with name
78+
if name in self.__db__ and name not in self._names:
79+
raise ValueError(f"The tag '{name}' is already present in the database.")
80+
81+
if name in self.__db__ or rewriter.name in self.__db__:
82+
if overwrite_existing:
83+
self.remove_tags(name, *tags)
84+
old_rewriter = self.__db__[name].pop()
85+
self._names.remove(name)
86+
self.__db__[old_rewriter.__class__.__name__].remove(old_rewriter)
87+
else:
88+
raise ValueError(
89+
f"The tag '{name}' is already present in the database."
90+
)
91+
8592
self.__db__[name] = OrderedSet([rewriter])
8693
self._names.add(name)
8794
self.__db__[rewriter.__class__.__name__].add(rewriter)

tests/graph/rewriting/test_db.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pytest
22

3+
from pytensor.graph.fg import FunctionGraph
34
from pytensor.graph.rewriting.basic import GraphRewriter, SequentialGraphRewriter
45
from pytensor.graph.rewriting.db import (
56
EquilibriumDB,
@@ -17,6 +18,13 @@ def apply(self, fgraph):
1718
pass
1819

1920

21+
class NewTestRewriter(GraphRewriter):
22+
name = "bleh"
23+
24+
def apply(self, fgraph):
25+
pass
26+
27+
2028
class TestDB:
2129
def test_register(self):
2230
db = RewriteDatabase()
@@ -31,7 +39,7 @@ def test_register(self):
3139
assert "c" in db
3240

3341
with pytest.raises(ValueError, match=r"The tag.*"):
34-
db.register("c", TestRewriter()) # name taken
42+
db.register("c", NewTestRewriter()) # name taken
3543

3644
with pytest.raises(ValueError, match=r"The tag.*"):
3745
db.register("z", TestRewriter()) # name collides with tag
@@ -42,6 +50,40 @@ def test_register(self):
4250
with pytest.raises(TypeError, match=r".* is not a valid.*"):
4351
db.register("d", 1)
4452

53+
def test_overwrite_existing(self):
54+
class TestOverwrite1(GraphRewriter):
55+
def apply(self, fgraph):
56+
fgraph.counter[0] += 1
57+
58+
class TestOverwrite2(GraphRewriter):
59+
def apply(self, fgraph):
60+
fgraph.counter[1] += 1
61+
62+
db = SequenceDB()
63+
fg = FunctionGraph([], [])
64+
fg.counter = [0, 0]
65+
66+
db.register("a", TestRewriter(), "basic")
67+
rewriter = db.query("+basic")
68+
rewriter.rewrite(fg)
69+
assert fg.counter == [0, 0]
70+
71+
with pytest.raises(ValueError, match=r"The tag.*"):
72+
db.register("a", TestOverwrite1(), "basic")
73+
rewriter = db.query("+basic")
74+
rewriter.rewrite(fg)
75+
assert fg.counter == [0, 0]
76+
77+
db.register("a", TestOverwrite1(), "basic", overwrite_existing=True)
78+
rewriter = db.query("+basic")
79+
rewriter.rewrite(fg)
80+
assert fg.counter == [1, 0]
81+
82+
db.register("a", TestOverwrite2(), "basic", overwrite_existing=True)
83+
rewriter = db.query("+basic")
84+
rewriter.rewrite(fg)
85+
assert fg.counter == [1, 1]
86+
4587
def test_EquilibriumDB(self):
4688
eq_db = EquilibriumDB()
4789

0 commit comments

Comments
 (0)