Skip to content

Commit 6eb8af1

Browse files
committed
.WIP refactor MarginalModel
1 parent d447e0e commit 6eb8af1

File tree

4 files changed

+444
-614
lines changed

4 files changed

+444
-614
lines changed

Diff for: pymc_experimental/model/marginal/distributions.py

+56
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytensor.tensor as pt
55

66
from pymc.distributions import Bernoulli, Categorical, DiscreteUniform
7+
from pymc.distributions.distribution import _support_point, support_point
78
from pymc.logprob.abstract import MeasurableOp, _logprob
89
from pymc.logprob.basic import conditional_logp, logp
910
from pymc.pytensorf import constant_fold
@@ -15,6 +16,7 @@
1516
from pytensor.scan import map as scan_map
1617
from pytensor.scan import scan
1718
from pytensor.tensor import TensorVariable
19+
from pytensor.tensor.random.type import RandomType
1820

1921
from pymc_experimental.distributions import DiscreteMarkovChain
2022

@@ -44,6 +46,60 @@ def support_axes(self) -> tuple[tuple[int]]:
4446
return tuple(support_axes_vars)
4547

4648

49+
@_support_point.register
50+
def support_point_marginal_rv(op: MarginalRV, rv, *inputs):
51+
"""Support point for a marginalized RV.
52+
53+
The support point of a marginalized RV is the support point of the inner RV,
54+
conditioned on the marginalized RV taking its support point.
55+
"""
56+
outputs = rv.owner.outputs
57+
rv_idx = outputs.index(rv)
58+
inner_rv = op.inner_outputs[rv_idx]
59+
other_inner_rvs = [
60+
out
61+
for out in op.inner_outputs
62+
if not isinstance(out.type, RandomType) and out is not inner_rv
63+
]
64+
65+
# Replace references to inner rvs by the dummy variables (including the marginalized RV)
66+
# This is necessary because the inner RVs may depend on each other
67+
other_inner_rv_to_dummies = {
68+
other_inner_rv: other_inner_rv.clone() for other_inner_rv in other_inner_rvs
69+
}
70+
inner_rv = clone_replace(inner_rv, other_inner_rv_to_dummies)
71+
inner_rv_support_point = support_point(inner_rv)
72+
73+
# Replace the dummy marginalized RV by the support point of the marginalized RV
74+
marginalized_rv = other_inner_rvs[0]
75+
marginalized_rv_support_point = support_point(marginalized_rv)
76+
dummy_marginalized_rv = other_inner_rv_to_dummies[marginalized_rv]
77+
inner_rv_support_point = clone_replace(
78+
inner_rv_support_point,
79+
{dummy_marginalized_rv: marginalized_rv_support_point},
80+
)
81+
82+
# Replace the remaining dummy variables by outer RVs
83+
rv_support_point = graph_replace(
84+
inner_rv_support_point,
85+
replace={
86+
v: outputs[op.inner_outputs.index(k)]
87+
for k, v in other_inner_rv_to_dummies.items()
88+
if k is not marginalized_rv
89+
},
90+
strict=False,
91+
)
92+
93+
# Make it a function of any remaining outer inputs
94+
rv_support_point = graph_replace(
95+
rv_support_point,
96+
replace=tuple(zip(op.inner_inputs, inputs)),
97+
strict=False,
98+
)
99+
100+
return rv_support_point
101+
102+
47103
class MarginalFiniteDiscreteRV(MarginalRV):
48104
"""Base class for Marginalized Finite Discrete RVs"""
49105

Diff for: pymc_experimental/model/marginal/graph_analysis.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from itertools import zip_longest
55

66
from pymc import SymbolicRandomVariable
7+
from pymc.model.fgraph import ModelVar
78
from pytensor.compile import SharedVariable
89
from pytensor.graph import Constant, Variable, ancestors
910
from pytensor.graph.basic import io_toposort
@@ -35,12 +36,12 @@ def static_shape_ancestors(vars):
3536

3637
def find_conditional_input_rvs(output_rvs, all_rvs):
3738
"""Find conditionally indepedent input RVs."""
38-
blockers = [other_rv for other_rv in all_rvs if other_rv not in output_rvs]
39-
blockers += static_shape_ancestors(tuple(all_rvs) + tuple(output_rvs))
39+
other_rvs = [other_rv for other_rv in all_rvs if other_rv not in output_rvs]
40+
blockers = other_rvs + static_shape_ancestors(tuple(all_rvs) + tuple(output_rvs))
4041
return [
4142
var
4243
for var in ancestors(output_rvs, blockers=blockers)
43-
if var in blockers or (var.owner is None and not isinstance(var, Constant | SharedVariable))
44+
if var in other_rvs
4445
]
4546

4647

@@ -141,6 +142,9 @@ def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars)
141142
# None of the inputs are related to the batch_axes of the input_vars
142143
continue
143144

145+
elif isinstance(node.op, ModelVar):
146+
var_dims[node.outputs[0]] = inputs_dims[0]
147+
144148
elif isinstance(node.op, DimShuffle):
145149
[input_dims] = inputs_dims
146150
output_dims = tuple(None if i == "x" else input_dims[i] for i in node.op.new_order)

0 commit comments

Comments
 (0)