Skip to content

Commit dfe3fe0

Browse files
committed
Allow creating MarginalModel from existing Model
1 parent 6d12203 commit dfe3fe0

File tree

4 files changed

+134
-23
lines changed

4 files changed

+134
-23
lines changed

Diff for: docs/api_reference.rst

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ methods in the current release of PyMC experimental.
1010

1111
as_model
1212
MarginalModel
13+
marginalize
1314
model_builder.ModelBuilder
1415

1516
Inference

Diff for: pymc_experimental/model/marginal_model.py

+69-23
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import warnings
2-
from typing import Sequence
2+
from typing import Sequence, Union
33

44
import numpy as np
55
import pymc
@@ -25,10 +25,12 @@
2525
from pytensor.tensor.shape import Shape
2626
from pytensor.tensor.special import log_softmax
2727

28-
__all__ = ["MarginalModel"]
28+
__all__ = ["MarginalModel", "marginalize"]
2929

3030
from pymc_experimental.distributions import DiscreteMarkovChain
3131

32+
ModelRVs = TensorVariable | Sequence[TensorVariable] | str | Sequence[str]
33+
3234

3335
class MarginalModel(Model):
3436
"""Subclass of PyMC Model that implements functionality for automatic
@@ -207,35 +209,50 @@ def logp(self, vars=None, **kwargs):
207209
vars = [m[var.name] for var in vars]
208210
return m._logp(vars=vars, **kwargs)
209211

210-
def clone(self):
211-
m = MarginalModel(coords=self.coords)
212-
model_vars = self.basic_RVs + self.potentials + self.deterministics + self.marginalized_rvs
213-
data_vars = [var for name, var in self.named_vars.items() if var not in model_vars]
212+
@staticmethod
213+
def from_model(model: Union[Model, "MarginalModel"]) -> "MarginalModel":
214+
new_model = MarginalModel(coords=model.coords)
215+
if isinstance(model, MarginalModel):
216+
marginalized_rvs = model.marginalized_rvs
217+
marginalized_named_vars_to_dims = model._marginalized_named_vars_to_dims
218+
else:
219+
marginalized_rvs = []
220+
marginalized_named_vars_to_dims = {}
221+
222+
model_vars = model.basic_RVs + model.potentials + model.deterministics + marginalized_rvs
223+
data_vars = [var for name, var in model.named_vars.items() if var not in model_vars]
214224
vars = model_vars + data_vars
215225
cloned_vars = clone_replace(vars)
216226
vars_to_clone = {var: cloned_var for var, cloned_var in zip(vars, cloned_vars)}
217-
m.vars_to_clone = vars_to_clone
218-
219-
m.named_vars = treedict({name: vars_to_clone[var] for name, var in self.named_vars.items()})
220-
m.named_vars_to_dims = self.named_vars_to_dims
221-
m.values_to_rvs = {i: vars_to_clone[rv] for i, rv in self.values_to_rvs.items()}
222-
m.rvs_to_values = {vars_to_clone[rv]: i for rv, i in self.rvs_to_values.items()}
223-
m.rvs_to_transforms = {vars_to_clone[rv]: i for rv, i in self.rvs_to_transforms.items()}
224-
m.rvs_to_initial_values = {
225-
vars_to_clone[rv]: i for rv, i in self.rvs_to_initial_values.items()
227+
new_model.vars_to_clone = vars_to_clone
228+
229+
new_model.named_vars = treedict(
230+
{name: vars_to_clone[var] for name, var in model.named_vars.items()}
231+
)
232+
new_model.named_vars_to_dims = model.named_vars_to_dims
233+
new_model.values_to_rvs = {vv: vars_to_clone[rv] for vv, rv in model.values_to_rvs.items()}
234+
new_model.rvs_to_values = {vars_to_clone[rv]: vv for rv, vv in model.rvs_to_values.items()}
235+
new_model.rvs_to_transforms = {
236+
vars_to_clone[rv]: tr for rv, tr in model.rvs_to_transforms.items()
237+
}
238+
new_model.rvs_to_initial_values = {
239+
vars_to_clone[rv]: iv for rv, iv in model.rvs_to_initial_values.items()
226240
}
227-
m.free_RVs = [vars_to_clone[rv] for rv in self.free_RVs]
228-
m.observed_RVs = [vars_to_clone[rv] for rv in self.observed_RVs]
229-
m.potentials = [vars_to_clone[pot] for pot in self.potentials]
230-
m.deterministics = [vars_to_clone[det] for det in self.deterministics]
241+
new_model.free_RVs = [vars_to_clone[rv] for rv in model.free_RVs]
242+
new_model.observed_RVs = [vars_to_clone[rv] for rv in model.observed_RVs]
243+
new_model.potentials = [vars_to_clone[pot] for pot in model.potentials]
244+
new_model.deterministics = [vars_to_clone[det] for det in model.deterministics]
231245

232-
m.marginalized_rvs = [vars_to_clone[rv] for rv in self.marginalized_rvs]
233-
m._marginalized_named_vars_to_dims = self._marginalized_named_vars_to_dims
234-
return m
246+
new_model.marginalized_rvs = [vars_to_clone[rv] for rv in marginalized_rvs]
247+
new_model._marginalized_named_vars_to_dims = marginalized_named_vars_to_dims
248+
return new_model
249+
250+
def clone(self):
251+
return self.from_model(self)
235252

236253
def marginalize(
237254
self,
238-
rvs_to_marginalize: TensorVariable | Sequence[TensorVariable] | str | Sequence[str],
255+
rvs_to_marginalize: ModelRVs,
239256
):
240257
if not isinstance(rvs_to_marginalize, Sequence):
241258
rvs_to_marginalize = (rvs_to_marginalize,)
@@ -491,6 +508,35 @@ def transform_input(inputs):
491508
return rv_dataset
492509

493510

511+
def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
    """Marginalize a subset of variables in a PyMC model.

    Builds a new `MarginalModel` instance from an existing `Model` via
    `MarginalModel.from_model` and marginalizes the requested variables in it.

    See documentation for `MarginalModel` for more information.

    Parameters
    ----------
    model : Model
        PyMC model to marginalize. Original variables will be cloned.
    rvs_to_marginalize : TensorVariable, str, or a list/tuple of either
        Variables (or their names) to marginalize in the returned model.
        Anything that is not a tuple or list is treated as a single entry.

    Returns
    -------
    marginal_model: MarginalModel
        Marginal model with the specified variables marginalized.
    """
    # Normalise a lone variable / name into a one-element collection.
    if isinstance(rvs_to_marginalize, tuple | list):
        requested = rvs_to_marginalize
    else:
        requested = (rvs_to_marginalize,)

    # Refer to variables by name: the new model holds clones, so the original
    # variable objects are not the ones stored in it.
    names = []
    for rv in requested:
        names.append(rv if isinstance(rv, str) else rv.name)

    result = MarginalModel.from_model(model)
    result.marginalize(names)
    return result
538+
539+
494540
class MarginalRV(SymbolicRandomVariable):
495541
"""Base class for Marginalized RVs"""
496542

Diff for: pymc_experimental/tests/model/test_marginal_model.py

+33
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from pymc import ImputationWarning, inputvars
1111
from pymc.distributions import transforms
1212
from pymc.logprob.abstract import _logprob
13+
from pymc.model.fgraph import fgraph_from_model
1314
from pymc.util import UNSET
1415
from scipy.special import log_softmax, logsumexp
1516
from scipy.stats import halfnorm, norm
@@ -19,7 +20,9 @@
1920
FiniteDiscreteMarginalRV,
2021
MarginalModel,
2122
is_conditional_dependent,
23+
marginalize,
2224
)
25+
from pymc_experimental.tests.utils import equal_computations_up_to_root
2326

2427

2528
@pytest.fixture
@@ -776,3 +779,33 @@ def test_mutable_indexing_jax_backend():
776779
pm.LogNormal("y", mu=cat_effect[cat_effect_idx], sigma=1 + is_outlier, observed=data)
777780
model.marginalize(["is_outlier"])
778781
get_jaxified_logp(model)
782+
783+
784+
def test_marginal_model_func():
    """`marginalize(Model, ...)` must match building a `MarginalModel` directly.

    The same model is built twice, once as a plain `pm.Model` passed through the
    `marginalize` function and once as a `MarginalModel` using its own
    `marginalize` method. The forward graphs and compiled logp values are compared.
    """

    def create_model(model_class):
        # Same model body for both classes; only the container type differs.
        with model_class(coords={"trial": range(10)}) as built:
            idx = pm.Bernoulli("idx", p=0.5, dims="trial")
            mean = pt.where(idx, 1, -1)
            sigma = pm.HalfNormal("sigma")
            pm.Normal("y", mu=mean, sigma=sigma, dims="trial", observed=[1] * 10)
        return built

    plain_model = create_model(pm.Model)
    marginal_m = marginalize(plain_model, ["idx"])
    assert isinstance(marginal_m, MarginalModel)

    reference_m = create_model(MarginalModel)
    reference_m.marginalize(["idx"])

    # Check forward graph representation is the same
    marginal_fgraph, _marginal_memo = fgraph_from_model(marginal_m)
    reference_fgraph, _reference_memo = fgraph_from_model(reference_m)
    assert equal_computations_up_to_root(marginal_fgraph.outputs, reference_fgraph.outputs)

    # Check logp graph is the same
    # This fails because OpFromGraphs comparison is broken
    # assert equal_computations_up_to_root([marginal_m.logp()], [reference_m.logp()])
    # Fall back to comparing the compiled logp numerically at one shared point.
    ip = marginal_m.initial_point()
    marginal_logp = marginal_m.compile_logp()(ip)
    reference_logp = reference_m.compile_logp()(ip)
    np.testing.assert_allclose(marginal_logp, reference_logp)

Diff for: pymc_experimental/tests/utils.py

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from typing import Sequence
2+
3+
from pytensor.compile import SharedVariable
4+
from pytensor.graph import Constant, graph_inputs
5+
from pytensor.graph.basic import Variable, equal_computations
6+
from pytensor.tensor.random.type import RandomType
7+
8+
9+
def equal_computations_up_to_root(
    xs: Sequence[Variable], ys: Sequence[Variable], ignore_rng_values=True
) -> bool:
    """Check if graphs are equivalent even if root variables have distinct identities.

    Non-constant graph inputs of `xs` and `ys` are paired positionally and must
    agree in type and name. When the input from `xs` is a shared variable, the
    paired input must be shared too and hold an equal value, except that values
    of RNG-typed shared variables are skipped when `ignore_rng_values` is true.
    The paired inputs are then handed to `equal_computations` as equivalent roots.
    """

    def non_constant_roots(outputs):
        return [root for root in graph_inputs(outputs) if not isinstance(root, Constant)]

    roots_x = non_constant_roots(xs)
    roots_y = non_constant_roots(ys)
    if len(roots_x) != len(roots_y):
        return False

    for root_x, root_y in zip(roots_x, roots_y):
        if root_x.type != root_y.type or root_x.name != root_y.name:
            return False
        # Only shared variables from the `xs` side trigger a value comparison.
        if not isinstance(root_x, SharedVariable):
            continue
        if not isinstance(root_y, SharedVariable):
            return False
        if isinstance(root_x.type, RandomType) and ignore_rng_values:
            continue
        if not root_x.type.values_eq(root_x.get_value(), root_y.get_value()):
            return False

    return equal_computations(xs, ys, in_xs=roots_x, in_ys=roots_y)

0 commit comments

Comments
 (0)