Skip to content

Commit 047c682

Browse files
committed
.WIP Remove MarginalModel in favor of model transform
1 parent 427ef18 commit 047c682

File tree

11 files changed

+881
-650
lines changed

11 files changed

+881
-650
lines changed

Diff for: docs/api_reference.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ methods in the current release of PyMC experimental.
1212
:toctree: generated/
1313

1414
as_model
15-
MarginalModel
1615
marginalize
16+
recover_marginals
1717
model_builder.ModelBuilder
1818

1919
Inference
@@ -53,6 +53,7 @@ Utils
5353

5454
spline.bspline_interpolation
5555
prior.prior_from_idata
56+
model_equivalence.equivalent_models
5657

5758

5859
Statespace Models

Diff for: pymc_experimental/__init__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@
1616
from pymc_experimental import gp, statespace, utils
1717
from pymc_experimental.distributions import *
1818
from pymc_experimental.inference.fit import fit
19-
from pymc_experimental.model.marginal.marginal_model import MarginalModel, marginalize
19+
from pymc_experimental.model.marginal.marginal_model import (
20+
MarginalModel,
21+
marginalize,
22+
recover_marginals,
23+
)
2024
from pymc_experimental.model.model_api import as_model
2125
from pymc_experimental.version import __version__
2226

Diff for: pymc_experimental/distributions/timeseries.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,8 @@ def transition(*args):
214214
discrete_mc_op = DiscreteMarkovChainRV(
215215
inputs=[P_, steps_, init_dist_, state_rng],
216216
outputs=[state_next_rng, discrete_mc_],
217-
ndim_supp=1,
218217
n_lags=n_lags,
218+
extended_signature="(p,p),(),(p),[rng]->[rng],(t)",
219219
)
220220

221221
discrete_mc = discrete_mc_op(P, steps, init_dist, state_rng)

Diff for: pymc_experimental/model/marginal/distributions.py

+87-2
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,25 @@
1+
import warnings
2+
13
from collections.abc import Sequence
24

35
import numpy as np
46
import pytensor.tensor as pt
57

68
from pymc.distributions import Bernoulli, Categorical, DiscreteUniform
9+
from pymc.distributions.distribution import _support_point, support_point
710
from pymc.logprob.abstract import MeasurableOp, _logprob
811
from pymc.logprob.basic import conditional_logp, logp
912
from pymc.pytensorf import constant_fold
1013
from pytensor import Variable
1114
from pytensor.compile.builders import OpFromGraph
1215
from pytensor.compile.mode import Mode
13-
from pytensor.graph import Op, vectorize_graph
16+
from pytensor.graph import FunctionGraph, Op, vectorize_graph
17+
from pytensor.graph.basic import equal_computations
1418
from pytensor.graph.replace import clone_replace, graph_replace
1519
from pytensor.scan import map as scan_map
1620
from pytensor.scan import scan
1721
from pytensor.tensor import TensorVariable
22+
from pytensor.tensor.random.type import RandomType
1823

1924
from pymc_experimental.distributions import DiscreteMarkovChain
2025

@@ -43,6 +48,74 @@ def support_axes(self) -> tuple[tuple[int]]:
4348
)
4449
return tuple(support_axes_vars)
4550

51+
def __eq__(self, other):
    """Return whether ``other`` is the same kind of Op wrapping an equivalent inner graph.

    Just to allow easy testing of equivalent models.
    This can be removed once https://github.com/pymc-devs/pytensor/issues/1114 is fixed.

    Returns ``False`` when the types differ or when the number of inner
    inputs/outputs differs; otherwise defers to ``equal_computations``.
    """
    if type(self) is not type(other):
        return False

    # Graphs with a different number of inputs or outputs can never be equivalent.
    # Decide that here, on the same quantities `__hash__` uses, so the answer does
    # not depend on how `equal_computations` reacts to lists of mismatched length.
    if len(self.inner_inputs) != len(other.inner_inputs):
        return False
    if len(self.inner_outputs) != len(other.inner_outputs):
        return False

    return equal_computations(
        self.inner_outputs,
        other.inner_outputs,
        self.inner_inputs,
        other.inner_inputs,
    )
63+
64+
def __hash__(self):
    """Hash on the Op type and on how many inner inputs and outputs it has.

    Just to allow easy testing of equivalent models.
    This can be removed once https://github.com/pymc-devs/pytensor/issues/1114 is fixed.
    """
    n_inner_inputs = len(self.inner_inputs)
    n_inner_outputs = len(self.inner_outputs)
    signature = (type(self), n_inner_inputs, n_inner_outputs)
    return hash(signature)
68+
69+
70+
@_support_point.register
def support_point_marginal_rv(op: MarginalRV, rv, *inputs):
    """Build the support point of one output of a marginalized RV Op.

    The result is the support point of the matching inner RV, evaluated with
    the marginalized RV fixed at its own support point.
    """
    outer_outputs = rv.owner.outputs
    target_inner_rv = op.inner_outputs[outer_outputs.index(rv)]

    # Of the remaining non-RNG inner outputs, the first is taken as the
    # marginalized RV and the rest as further dependent RVs.
    marginalized_inner_rv, *sibling_inner_rvs = (
        inner_out
        for inner_out in op.inner_outputs
        if not (inner_out is target_inner_rv or isinstance(inner_out.type, RandomType))
    )

    # Swap references to the other inner RVs (marginalized one included) for dummies.
    # This is needed because the inner RVs may depend on each other.
    marginalized_dummy = marginalized_inner_rv.clone()
    sibling_dummies = {sibling: sibling.clone() for sibling in sibling_inner_rvs}
    dummy_substitutions = {marginalized_inner_rv: marginalized_dummy, **sibling_dummies}
    target_inner_rv = clone_replace(target_inner_rv, replace=dummy_substitutions)

    # Support points of the requested inner RV and of the marginalized RV
    target_support_point = support_point(target_inner_rv)
    marginalized_support_point = support_point(marginalized_inner_rv)

    # The marginalized dummy becomes its support point
    replacements = [(marginalized_dummy, marginalized_support_point)]
    # Dummies of the other dependent RVs become the respective outer outputs.
    # PyMC will replace them by their support points later
    replacements.extend(
        (dummy, outer_outputs[op.inner_outputs.index(sibling)])
        for sibling, dummy in sibling_dummies.items()
    )
    # Inner inputs become the outer inputs
    replacements.extend(zip(op.inner_inputs, inputs))

    fgraph = FunctionGraph(outputs=[target_support_point], clone=False)
    fgraph.replace_all(replacements, import_missing=True)
    (rv_support_point,) = fgraph.outputs
    return rv_support_point
118+
46119

47120
class MarginalFiniteDiscreteRV(MarginalRV):
48121
"""Base class for Marginalized Finite Discrete RVs"""
@@ -132,12 +205,22 @@ def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Var
132205
Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps"
133206
the inner graph.
134207
"""
135-
return clone_replace(
208+
return graph_replace(
136209
op.inner_outputs,
137210
replace=tuple(zip(op.inner_inputs, inputs)),
138211
)
139212

140213

214+
def warn_logp_non_separable(values):
    """Warn that the joint logp of several dependent values lands on the first one.

    Does nothing when ``values`` holds at most one entry.
    """
    if len(values) <= 1:
        return

    # NOTE(review): the text says "FiniteDiscreteMarginalRV" while the class in this
    # file is named MarginalFiniteDiscreteRV; left untouched because callers or
    # tests may match on the exact wording - confirm before renaming.
    message = (
        "There are multiple dependent variables in a FiniteDiscreteMarginalRV. "
        f"Their joint logp terms will be assigned to the first value: {values[0]}."
    )
    warnings.warn(message, UserWarning, stacklevel=2)
222+
223+
141224
DUMMY_ZERO = pt.constant(0, name="dummy_zero")
142225

143226

@@ -200,6 +283,7 @@ def logp_fn(marginalized_rv_const, *non_sequences):
200283
joint_logp = align_logp_dims(dims=op.dims_connections[0], logp=joint_logp)
201284

202285
# We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
286+
warn_logp_non_separable(values)
203287
dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
204288
return joint_logp, *dummy_logps
205289

@@ -272,5 +356,6 @@ def step_alpha(logp_emission, log_alpha, log_P):
272356

273357
# If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
274358
# return is the joint probability of everything together, but PyMC still expects one logp for each emission stream.
359+
warn_logp_non_separable(values)
275360
dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
276361
return joint_logp, *dummy_logps

Diff for: pymc_experimental/model/marginal/graph_analysis.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from itertools import zip_longest
55

66
from pymc import SymbolicRandomVariable
7+
from pymc.model.fgraph import ModelVar
78
from pytensor.compile import SharedVariable
89
from pytensor.graph import Constant, Variable, ancestors
910
from pytensor.graph.basic import io_toposort
@@ -35,12 +36,12 @@ def static_shape_ancestors(vars):
3536

3637
def find_conditional_input_rvs(output_rvs, all_rvs):
3738
"""Find conditionally independent input RVs.

The "other" RVs are the entries of ``all_rvs`` that are not in ``output_rvs``.
The ancestors of ``output_rvs`` are traversed with those other RVs and the
static-shape ancestors passed as ``blockers``; the traversed variables that
are among the other RVs are returned as a list.
"""
38-
blockers = [other_rv for other_rv in all_rvs if other_rv not in output_rvs]
39-
blockers += static_shape_ancestors(tuple(all_rvs) + tuple(output_rvs))
39+
other_rvs = [other_rv for other_rv in all_rvs if other_rv not in output_rvs]
40+
blockers = other_rvs + static_shape_ancestors(tuple(all_rvs) + tuple(output_rvs))
4041
return [
4142
var
4243
for var in ancestors(output_rvs, blockers=blockers)
43-
if var in blockers or (var.owner is None and not isinstance(var, Constant | SharedVariable))
44+
if var in other_rvs
4445
]
4546

4647

@@ -141,6 +142,9 @@ def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars)
141142
# None of the inputs are related to the batch_axes of the input_vars
142143
continue
143144

145+
elif isinstance(node.op, ModelVar):
146+
var_dims[node.outputs[0]] = inputs_dims[0]
147+
144148
elif isinstance(node.op, DimShuffle):
145149
[input_dims] = inputs_dims
146150
output_dims = tuple(None if i == "x" else input_dims[i] for i in node.op.new_order)

0 commit comments

Comments
 (0)