Commit d837e26

.WIP Remove MarginalModel in favor of model transform
1 parent d447e0e commit d837e26

5 files changed: +441 -616 lines changed


Diff for: pymc_experimental/model/marginal/distributions.py

+53 -1

@@ -4,17 +4,19 @@
 import pytensor.tensor as pt
 
 from pymc.distributions import Bernoulli, Categorical, DiscreteUniform
+from pymc.distributions.distribution import _support_point, support_point
 from pymc.logprob.abstract import MeasurableOp, _logprob
 from pymc.logprob.basic import conditional_logp, logp
 from pymc.pytensorf import constant_fold
 from pytensor import Variable
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.mode import Mode
-from pytensor.graph import Op, vectorize_graph
+from pytensor.graph import FunctionGraph, Op, vectorize_graph
 from pytensor.graph.replace import clone_replace, graph_replace
 from pytensor.scan import map as scan_map
 from pytensor.scan import scan
 from pytensor.tensor import TensorVariable
+from pytensor.tensor.random.type import RandomType
 
 from pymc_experimental.distributions import DiscreteMarkovChain
 
@@ -44,6 +46,56 @@ def support_axes(self) -> tuple[tuple[int]]:
         return tuple(support_axes_vars)
 
 
+@_support_point.register
+def support_point_marginal_rv(op: MarginalRV, rv, *inputs):
+    """Support point for a marginalized RV.
+
+    The support point of a marginalized RV is the support point of the inner RV,
+    conditioned on the marginalized RV taking its support point.
+    """
+    outputs = rv.owner.outputs
+
+    inner_rv = op.inner_outputs[outputs.index(rv)]
+    marginalized_inner_rv, *other_dependent_inner_rvs = (
+        out
+        for out in op.inner_outputs
+        if out is not inner_rv and not isinstance(out.type, RandomType)
+    )
+
+    # Replace references to inner rvs by the dummy variables (including the marginalized RV)
+    # This is necessary because the inner RVs may depend on each other
+    marginalized_inner_rv_dummy = marginalized_inner_rv.clone()
+    other_dependent_inner_rv_to_dummies = {
+        inner_rv: inner_rv.clone() for inner_rv in other_dependent_inner_rvs
+    }
+    inner_rv = clone_replace(
+        inner_rv,
+        replace={marginalized_inner_rv: marginalized_inner_rv_dummy}
+        | other_dependent_inner_rv_to_dummies,
+    )
+
+    # Get support point of inner RV and marginalized RV
+    inner_rv_support_point = support_point(inner_rv)
+    marginalized_inner_rv_support_point = support_point(marginalized_inner_rv)
+
+    replacements = [
+        # Replace the marginalized RV dummy by its support point
+        (marginalized_inner_rv_dummy, marginalized_inner_rv_support_point),
+        # Replace other dependent RVs dummies by the respective outer outputs.
+        # PyMC will replace them by their support points later
+        *(
+            (v, outputs[op.inner_outputs.index(k)])
+            for k, v in other_dependent_inner_rv_to_dummies.items()
+        ),
+        # Replace outer input RVs
+        *zip(op.inner_inputs, inputs),
+    ]
+    fgraph = FunctionGraph(outputs=[inner_rv_support_point], clone=False)
+    fgraph.replace_all(replacements, import_missing=True)
+    [rv_support_point] = fgraph.outputs
+    return rv_support_point
+
+
 class MarginalFiniteDiscreteRV(MarginalRV):
     """Base class for Marginalized Finite Discrete RVs"""
 
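For illustration, a minimal sketch (not part of this commit) of the conditioning rule the new support_point dispatch encodes: the support point of a dependent RV is obtained with the marginalized RV pinned at its own support point. It uses plain PyMC distributions rather than a MarginalRV graph, so the names and values below are only an assumed example.

# Hypothetical example (not from the commit): the same conditioning rule,
# shown with standalone PyMC distributions instead of a MarginalRV graph.
import pymc as pm
from pymc.distributions.distribution import support_point

# "idx" plays the role of the marginalized RV; its support point picks a component.
idx = pm.Bernoulli.dist(p=0.7)
idx_support_point = support_point(idx)  # evaluates to 1, since p > 0.5

# "x" plays the role of the dependent inner RV; conditioning it on idx's
# support point yields x's support point.
mu = pm.math.switch(idx_support_point, 10.0, -10.0)
x = pm.Normal.dist(mu=mu, sigma=1.0)
print(support_point(x).eval())  # expected: 10.0, the mean of the selected component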

Diff for: pymc_experimental/model/marginal/graph_analysis.py

+7 -3

@@ -4,6 +4,7 @@
 from itertools import zip_longest
 
 from pymc import SymbolicRandomVariable
+from pymc.model.fgraph import ModelVar
 from pytensor.compile import SharedVariable
 from pytensor.graph import Constant, Variable, ancestors
 from pytensor.graph.basic import io_toposort
@@ -35,12 +36,12 @@ def static_shape_ancestors(vars):
 
 def find_conditional_input_rvs(output_rvs, all_rvs):
     """Find conditionally indepedent input RVs."""
-    blockers = [other_rv for other_rv in all_rvs if other_rv not in output_rvs]
-    blockers += static_shape_ancestors(tuple(all_rvs) + tuple(output_rvs))
+    other_rvs = [other_rv for other_rv in all_rvs if other_rv not in output_rvs]
+    blockers = other_rvs + static_shape_ancestors(tuple(all_rvs) + tuple(output_rvs))
     return [
         var
         for var in ancestors(output_rvs, blockers=blockers)
-        if var in blockers or (var.owner is None and not isinstance(var, Constant | SharedVariable))
+        if var in other_rvs
     ]
 
 
@@ -141,6 +142,9 @@ def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars)
             # None of the inputs are related to the batch_axes of the input_vars
             continue
 
+        elif isinstance(node.op, ModelVar):
+            var_dims[node.outputs[0]] = inputs_dims[0]
+
         elif isinstance(node.op, DimShuffle):
             [input_dims] = inputs_dims
             output_dims = tuple(None if i == "x" else input_dims[i] for i in node.op.new_order)
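For context, a small sketch (not from the commit) of the blockers-based ancestor walk that find_conditional_input_rvs builds on: pytensor's ancestors() stops traversing past any variable listed in blockers, which is how the function limits the search to conditionally independent input RVs. The tiny graph below is made up for illustration.

# Hypothetical example: ancestors() stops walking past variables in blockers.
import pytensor.tensor as pt
from pytensor.graph import ancestors

a = pt.scalar("a")
b = a + 1
c = b * 2

# Walking back from c with b as a blocker reaches b but never a.
reached = list(ancestors([c], blockers=[b]))
assert b in reached
assert a not in reached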

0 commit comments
