diff --git a/pymc/initial_point.py b/pymc/initial_point.py index 15f4f887c0..0cb1a26ddc 100644 --- a/pymc/initial_point.py +++ b/pymc/initial_point.py @@ -25,7 +25,13 @@ from pytensor.tensor.variable import TensorVariable from pymc.logprob.transforms import Transform -from pymc.pytensorf import compile_pymc, find_rng_nodes, replace_rng_nodes, reseed_rngs +from pymc.pytensorf import ( + compile_pymc, + find_rng_nodes, + replace_rng_nodes, + reseed_rngs, + toposort_replace, +) from pymc.util import get_transformed_name, get_untransformed_name, is_transformed_name StartDict = dict[Variable | str, np.ndarray | Variable | str] @@ -288,8 +294,7 @@ def make_initial_point_expression( # order, so that later nodes do not reintroduce expressions with earlier # rvs that would need to once again be replaced by their initial_points graph = FunctionGraph(outputs=free_rvs_clone, clone=False) - replacements = reversed(list(zip(free_rvs_clone, initial_values_clone))) - graph.replace_all(replacements, import_missing=True) + toposort_replace(graph, tuple(zip(free_rvs_clone, initial_values_clone)), reverse=True) if not return_transformed: return graph.outputs diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py index 213831c9f1..527a1c8e0c 100644 --- a/pymc/pytensorf.py +++ b/pymc/pytensorf.py @@ -1122,10 +1122,40 @@ def toposort_replace( fgraph: FunctionGraph, replacements: Sequence[tuple[Variable, Variable]], reverse: bool = False ) -> None: """Replace multiple variables in place in topological order.""" - toposort = fgraph.toposort() + fgraph_toposort = {node: i for i, node in enumerate(fgraph.toposort())} + _inner_fgraph_toposorts = {} # Cache inner toposorts + + def _nested_toposort_index(var, fgraph_toposort) -> tuple[int]: + """Compute position of variable in fgraph toposort. + + When a variable is an OpFromGraph output, extend output with the toposort index of the inner graph(s). + + This allows ordering variables that come from the same OpFromGraph. + """ + if not var.owner: + return (-1,) + + index = fgraph_toposort[var.owner] + + # Recurse into OpFromGraphs + # TODO: Could also recurse into Scans + if isinstance(var.owner.op, OpFromGraph): + inner_fgraph = var.owner.op.fgraph + + if inner_fgraph not in _inner_fgraph_toposorts: + _inner_fgraph_toposorts[inner_fgraph] = { + node: i for i, node in enumerate(inner_fgraph.toposort()) + } + + inner_fgraph_toposort = _inner_fgraph_toposorts[inner_fgraph] + inner_var = inner_fgraph.outputs[var.owner.outputs.index(var)] + return (index, *_nested_toposort_index(inner_var, inner_fgraph_toposort)) + else: + return (index,) + sorted_replacements = sorted( replacements, - key=lambda pair: toposort.index(pair[0].owner) if pair[0].owner else -1, + key=lambda pair: _nested_toposort_index(pair[0], fgraph_toposort), reverse=reverse, ) fgraph.replace_all(sorted_replacements, import_missing=True) diff --git a/tests/test_initial_point.py b/tests/test_initial_point.py index 6d61b4b3fc..9138f37b3e 100644 --- a/tests/test_initial_point.py +++ b/tests/test_initial_point.py @@ -17,11 +17,12 @@ import pytensor.tensor as pt import pytest +from pytensor.compile.builders import OpFromGraph from pytensor.tensor.random.op import RandomVariable import pymc as pm -from pymc.distributions.distribution import support_point +from pymc.distributions.distribution import _support_point, support_point from pymc.initial_point import make_initial_point_fn, make_initial_point_fns_per_chain @@ -47,12 +48,17 @@ def test_make_initial_point_fns_per_chain_checks_kwargs(self): ) pass - def test_dependent_initvals(self): + @pytest.mark.parametrize("reverse_rvs", [False, True]) + def test_dependent_initvals(self, reverse_rvs): with pm.Model() as pmodel: L = pm.Uniform("L", 0, 1, initval=0.5) U = pm.Uniform("U", lower=9, upper=10, initval=9.5) B1 = pm.Uniform("B1", lower=L, upper=U, initval=5) B2 = pm.Uniform("B2", lower=L, upper=U, initval=(L + U) / 2) + + if reverse_rvs: + pmodel.free_RVs = pmodel.free_RVs[::-1] + ip = pmodel.initial_point(random_seed=0) assert ip["L_interval__"] == 0 assert ip["U_interval__"] == 0 @@ -187,6 +193,40 @@ def test_string_overrides_work(self): assert np.isclose(iv["B_log__"], 0) assert iv["C_log__"] == 0 + @pytest.mark.parametrize("reverse_rvs", [False, True]) + def test_dependent_initval_from_OFG(self, reverse_rvs): + class MyTestOp(OpFromGraph): + pass + + @_support_point.register(MyTestOp) + def my_test_op_support_point(op, out): + out1, out2 = out.owner.outputs + if out is out1: + return out1 + else: + return out1 * 4 + + out1 = pt.zeros(()) + out2 = out1 * 2 + rv_op = MyTestOp([], [out1, out2]) + + with pm.Model() as model: + A, B = rv_op() + if reverse_rvs: + model.register_rv(B, "B") + model.register_rv(A, "A") + else: + model.register_rv(A, "A") + model.register_rv(B, "B") + + assert model.initial_point() == {"A": 0, "B": 0} + + model.set_initval(A, 1) + assert model.initial_point() == {"A": 1, "B": 4} + + model.set_initval(B, 3) + assert model.initial_point() == {"A": 1, "B": 3} + class TestSupportPoint: def test_basic(self):