From 77f41861fea25371bffe4f91250396d5bde72aab Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Tue, 1 Apr 2025 01:58:39 +0200 Subject: [PATCH 1/5] Transform to remove Minibatch from model --- pymc/model/transform/basic.py | 17 ++++++++++++++++- tests/model/transform/test_basic.py | 17 ++++++++++++++++- 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/pymc/model/transform/basic.py b/pymc/model/transform/basic.py index fcf42fdf8c..faf5d27bd9 100644 --- a/pymc/model/transform/basic.py +++ b/pymc/model/transform/basic.py @@ -14,8 +14,10 @@ from collections.abc import Sequence from pytensor import Variable -from pytensor.graph import ancestors +from pytensor.graph import ancestors, node_rewriter +from pytensor.graph.rewriting.basic import out2in +from pymc.data import MinibatchOp from pymc.model.core import Model from pymc.model.fgraph import ( ModelObservedRV, @@ -58,3 +60,16 @@ def parse_vars(model: Model, vars: ModelVariable | Sequence[ModelVariable]) -> l else: vars_seq = (vars,) return [model[var] if isinstance(var, str) else var for var in vars_seq] + + +def remove_minibatched_nodes(model: Model): + """Remove all uses of pm.Minibatch in the Model""" + + @node_rewriter([MinibatchOp]) + def local_remove_minibatch(fgraph, node): + return node.inputs + + remove_minibatch = out2in(local_remove_minibatch) + fgraph, _ = fgraph_from_model(model) + remove_minibatch.apply(fgraph) + return model_from_fgraph(fgraph, mutate_fgraph=True) diff --git a/tests/model/transform/test_basic.py b/tests/model/transform/test_basic.py index 25bf2324ec..dc193f24b7 100644 --- a/tests/model/transform/test_basic.py +++ b/tests/model/transform/test_basic.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np import pymc as pm -from pymc.model.transform.basic import prune_vars_detached_from_observed +from pymc.model.transform.basic import prune_vars_detached_from_observed, remove_minibatched_nodes def test_prune_vars_detached_from_observed(): @@ -30,3 +31,17 @@ def test_prune_vars_detached_from_observed(): assert set(m.named_vars.keys()) == {"obs_data", "a0", "a1", "a2", "obs", "d0", "d1"} pruned_m = prune_vars_detached_from_observed(m) assert set(pruned_m.named_vars.keys()) == {"obs_data", "a0", "a1", "a2", "obs"} + + +def test_remove_minibatches(): + data_size = 100 + data = np.zeros((data_size,)) + batch_size = 10 + with pm.Model() as m1: + mb = pm.Minibatch(data, batch_size=batch_size) + x = pm.Normal("x") + y = pm.Normal("y", x, observed=mb, total_size=100) + + m2 = remove_minibatched_nodes(m1) + assert m1.y.shape[0].eval() == batch_size + assert m2.y.shape[0].eval() == data_size From 5866f027af65620b70022fc2ae64d527ae982de8 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Tue, 1 Apr 2025 02:26:45 +0200 Subject: [PATCH 2/5] Lint Fix --- pymc/model/transform/basic.py | 2 +- tests/model/transform/test_basic.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pymc/model/transform/basic.py b/pymc/model/transform/basic.py index faf5d27bd9..45306ff144 100644 --- a/pymc/model/transform/basic.py +++ b/pymc/model/transform/basic.py @@ -63,7 +63,7 @@ def parse_vars(model: Model, vars: ModelVariable | Sequence[ModelVariable]) -> l def remove_minibatched_nodes(model: Model): - """Remove all uses of pm.Minibatch in the Model""" + """Remove all uses of pm.Minibatch in the Model.""" @node_rewriter([MinibatchOp]) def local_remove_minibatch(fgraph, node): diff --git a/tests/model/transform/test_basic.py b/tests/model/transform/test_basic.py index dc193f24b7..16e0c4ac1f 100644 --- a/tests/model/transform/test_basic.py +++ b/tests/model/transform/test_basic.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import numpy as np + import pymc as pm from pymc.model.transform.basic import prune_vars_detached_from_observed, remove_minibatched_nodes From b0d7f396db2d92c603f79db653c75b3d4c680b00 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Sat, 5 Apr 2025 23:40:47 +0200 Subject: [PATCH 3/5] Rework transform --- pymc/model/transform/basic.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/pymc/model/transform/basic.py b/pymc/model/transform/basic.py index 45306ff144..2b09d4e246 100644 --- a/pymc/model/transform/basic.py +++ b/pymc/model/transform/basic.py @@ -62,14 +62,24 @@ def parse_vars(model: Model, vars: ModelVariable | Sequence[ModelVariable]) -> l return [model[var] if isinstance(var, str) else var for var in vars_seq] -def remove_minibatched_nodes(model: Model): +def remove_minibatched_nodes(model: pm.Model) -> pm.Model: """Remove all uses of pm.Minibatch in the Model.""" + fgraph, _ = fgraph_from_model(model) - @node_rewriter([MinibatchOp]) - def local_remove_minibatch(fgraph, node): - return node.inputs + replacements = {} + for var in fgraph.apply_nodes: + if isinstance(var.op, MinibatchOp): + for inp, out in zip(var.inputs, var.outputs): + replacements[out] = inp - remove_minibatch = out2in(local_remove_minibatch) - fgraph, _ = fgraph_from_model(model) - remove_minibatch.apply(fgraph) + old_outs, old_coords, old_dim_lengths = fgraph.outputs, fgraph._coords, fgraph._dim_lengths + # Using `rebuild_strict=False` means all coords, names, and dim information is lost + # So we need to restore it from the old fgraph + new_outs = pytensor.clone_replace(old_outs, replacements, rebuild_strict=False) + for old_out, new_out in zip(old_outs, new_outs): + new_out.name = old_out.name + fgraph = pytensor.graph.fg.FunctionGraph(outputs=new_outs, clone=False) + fgraph._coords = old_coords + fgraph._dim_lengths = old_dim_lengths return model_from_fgraph(fgraph, mutate_fgraph=True) + From 30dc4cc135aea73837f00de8280ad81488829ab0 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Sun, 6 Apr 2025 00:08:31 +0200 Subject: [PATCH 4/5] Tidy --- pymc/model/transform/basic.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/pymc/model/transform/basic.py b/pymc/model/transform/basic.py index 2b09d4e246..dfe3e864a6 100644 --- a/pymc/model/transform/basic.py +++ b/pymc/model/transform/basic.py @@ -13,9 +13,9 @@ # limitations under the License. from collections.abc import Sequence -from pytensor import Variable -from pytensor.graph import ancestors, node_rewriter -from pytensor.graph.rewriting.basic import out2in +from pytensor import Variable, clone_replace +from pytensor.graph import ancestors +from pytensor.graph.fg import FunctionGraph from pymc.data import MinibatchOp from pymc.model.core import Model @@ -62,7 +62,7 @@ def parse_vars(model: Model, vars: ModelVariable | Sequence[ModelVariable]) -> l return [model[var] if isinstance(var, str) else var for var in vars_seq] -def remove_minibatched_nodes(model: pm.Model) -> pm.Model: +def remove_minibatched_nodes(model: Model) -> Model: """Remove all uses of pm.Minibatch in the Model.""" fgraph, _ = fgraph_from_model(model) @@ -75,11 +75,10 @@ def remove_minibatched_nodes(model: pm.Model) -> pm.Model: old_outs, old_coords, old_dim_lengths = fgraph.outputs, fgraph._coords, fgraph._dim_lengths # Using `rebuild_strict=False` means all coords, names, and dim information is lost # So we need to restore it from the old fgraph - new_outs = pytensor.clone_replace(old_outs, replacements, rebuild_strict=False) + new_outs = clone_replace(old_outs, replacements, rebuild_strict=False) for old_out, new_out in zip(old_outs, new_outs): new_out.name = old_out.name - fgraph = pytensor.graph.fg.FunctionGraph(outputs=new_outs, clone=False) + fgraph = FunctionGraph(outputs=new_outs, clone=False) fgraph._coords = old_coords fgraph._dim_lengths = old_dim_lengths return model_from_fgraph(fgraph, mutate_fgraph=True) - From ec3b3778f9bbdbf0dae885193b6469886655eaee Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Sun, 6 Apr 2025 00:13:16 +0200 Subject: [PATCH 5/5] Appease mypy --- pymc/model/transform/basic.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pymc/model/transform/basic.py b/pymc/model/transform/basic.py index dfe3e864a6..877814cd61 100644 --- a/pymc/model/transform/basic.py +++ b/pymc/model/transform/basic.py @@ -72,13 +72,13 @@ def remove_minibatched_nodes(model: Model) -> Model: for inp, out in zip(var.inputs, var.outputs): replacements[out] = inp - old_outs, old_coords, old_dim_lengths = fgraph.outputs, fgraph._coords, fgraph._dim_lengths + old_outs, old_coords, old_dim_lengths = fgraph.outputs, fgraph._coords, fgraph._dim_lengths # type: ignore[attr-defined] # Using `rebuild_strict=False` means all coords, names, and dim information is lost # So we need to restore it from the old fgraph - new_outs = clone_replace(old_outs, replacements, rebuild_strict=False) + new_outs = clone_replace(old_outs, replacements, rebuild_strict=False) # type: ignore[arg-type] for old_out, new_out in zip(old_outs, new_outs): new_out.name = old_out.name fgraph = FunctionGraph(outputs=new_outs, clone=False) - fgraph._coords = old_coords - fgraph._dim_lengths = old_dim_lengths + fgraph._coords = old_coords # type: ignore[attr-defined] + fgraph._dim_lengths = old_dim_lengths # type: ignore[attr-defined] return model_from_fgraph(fgraph, mutate_fgraph=True)