Skip to content

Allow closed form marginalization #441

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pymc_extras/model/marginal/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
import pymc_extras.model.marginal.rewrites # Need import to register rewrites
59 changes: 35 additions & 24 deletions pymc_extras/model/marginal/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,37 +24,20 @@
from pymc_extras.distributions import DiscreteMarkovChain


class MarginalRV(OpFromGraph, MeasurableOp):
class MarginalRV(OpFromGraph):
"""Base class for Marginalized RVs"""

def __init__(
self,
*args,
dims_connections: tuple[tuple[int | None], ...],
dims: tuple[Variable, ...],
n_dependent_rvs: int,
**kwargs,
) -> None:
self.dims_connections = dims_connections
self.dims = dims
self.n_dependent_rvs = n_dependent_rvs
super().__init__(*args, **kwargs)

@property
def support_axes(self) -> tuple[tuple[int]]:
"""Dimensions of dependent RVs that belong to the core (non-batched) marginalized variable."""
marginalized_ndim_supp = self.inner_outputs[0].owner.op.ndim_supp
support_axes_vars = []
for dims_connection in self.dims_connections:
ndim = len(dims_connection)
marginalized_supp_axes = ndim - marginalized_ndim_supp
support_axes_vars.append(
tuple(
-i
for i, dim in enumerate(reversed(dims_connection), start=1)
if (dim is None or dim > marginalized_supp_axes)
)
)
return tuple(support_axes_vars)

def __eq__(self, other):
# Just to allow easy testing of equivalent models,
# This can be removed once https://github.com/pymc-devs/pytensor/issues/1114 is fixed
Expand Down Expand Up @@ -124,11 +107,35 @@ def support_point_marginal_rv(op: MarginalRV, rv, *inputs):
return rv_support_point


class MarginalFiniteDiscreteRV(MarginalRV):
class MarginalEnumerableRV(MarginalRV, MeasurableOp):
    """Marginalized RV whose support can be handled by exhaustive enumeration.

    ``dims_connections`` records, for each dependent RV, which of its
    dimensions map to dims of the marginalized variable (``None`` marks a
    dimension with no direct mapping).
    """

    def __init__(self, *args, dims_connections: tuple[tuple[int | None], ...], **kwargs):
        super().__init__(*args, **kwargs)
        self.dims_connections = dims_connections

    @property
    def support_axes(self) -> tuple[tuple[int]]:
        """Dimensions of dependent RVs that belong to the core (non-batched) marginalized variable."""
        marginalized_ndim_supp = self.inner_outputs[0].owner.op.ndim_supp
        axes_per_rv = []
        for connection in self.dims_connections:
            # Dims beyond this threshold (or unmapped ones) are support dims.
            supp_threshold = len(connection) - marginalized_ndim_supp
            axes = tuple(
                -offset
                for offset, dim in enumerate(reversed(connection), start=1)
                if dim is None or dim > supp_threshold
            )
            axes_per_rv.append(axes)
        return tuple(axes_per_rv)


class MarginalFiniteDiscreteRV(MarginalEnumerableRV):
    """Base class for Marginalized Finite Discrete RVs.

    Marker subclass used to dispatch the logp implementation registered via
    ``_logprob.register(MarginalFiniteDiscreteRV)``, which enumerates the
    finite support of the marginalized variable.
    """


class MarginalDiscreteMarkovChainRV(MarginalRV):
class MarginalDiscreteMarkovChainRV(MarginalEnumerableRV):
    """Base class for Marginalized Discrete Markov Chain RVs.

    Marker subclass used to dispatch the logp implementation registered via
    ``_logprob.register(MarginalDiscreteMarkovChainRV)`` (``marginal_hmm_logp``),
    which marginalizes the chain over its discrete state domain.
    """


Expand Down Expand Up @@ -239,7 +246,9 @@ def warn_non_separable_logp(values):
@_logprob.register(MarginalFiniteDiscreteRV)
def finite_discrete_marginal_rv_logp(op: MarginalFiniteDiscreteRV, values, *inputs, **kwargs):
# Clone the inner RV graph of the Marginalized RV
marginalized_rv, *inner_rvs = inline_ofg_outputs(op, inputs)
marginalized_rv, *inner_rvs_and_rngs = inline_ofg_outputs(op, inputs)
inner_rvs = inner_rvs_and_rngs[:op.n_dependent_rvs]
assert len(values) == len(inner_rvs)

# Obtain the joint_logp graph of the inner RV graph
inner_rv_values = dict(zip(inner_rvs, values))
Expand Down Expand Up @@ -302,7 +311,9 @@ def logp_fn(marginalized_rv_const, *non_sequences):

@_logprob.register(MarginalDiscreteMarkovChainRV)
def marginal_hmm_logp(op, values, *inputs, **kwargs):
chain_rv, *dependent_rvs = inline_ofg_outputs(op, inputs)
chain_rv, *dependent_rvs_and_rngs = inline_ofg_outputs(op, inputs)
dependent_rvs = dependent_rvs_and_rngs[:op.n_dependent_rvs]
assert len(values) == len(dependent_rvs)

P, n_steps_, init_dist_, rng = chain_rv.owner.inputs
domain = pt.arange(P.shape[-1], dtype="int32")
Expand Down
144 changes: 80 additions & 64 deletions pymc_extras/model/marginal/marginal_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,50 +62,6 @@


class MarginalModel(Model):
"""Subclass of PyMC Model that implements functionality for automatic
marginalization of variables in the logp transformation

After defining the full Model, the `marginalize` method can be used to indicate a
subset of variables that should be marginalized

Notes
-----
Marginalization functionality is still very restricted. Only finite discrete
variables can be marginalized. Deterministics and Potentials cannot be conditionally
dependent on the marginalized variables.

Furthermore, not all instances of such variables can be marginalized. If a variable
has batched dimensions, it is required that any conditionally dependent variables
use information from an individual batched dimension. In other words, the graph
connecting the marginalized variable(s) to the dependent variable(s) must be
composed strictly of Elemwise Operations. This is necessary to ensure an efficient
logprob graph can be generated. If you want to bypass this restriction you can
separate each dimension of the marginalized variable into the scalar components
and then stack them together. Note that such graphs will grow exponentially in the
number of marginalized variables.

For the same reason, it's not possible to marginalize RVs with multivariate
dependent RVs.

Examples
--------
Marginalize over a single variable

.. code-block:: python

import pymc as pm
from pymc_extras import MarginalModel

with MarginalModel() as m:
p = pm.Beta("p", 1, 1)
x = pm.Bernoulli("x", p=p, shape=(3,))
y = pm.Normal("y", pm.math.switch(x, -10, 10), observed=[10, 10, -10])

m.marginalize([x])

idata = pm.sample()

"""

def __init__(self, *args, **kwargs):
raise TypeError(
Expand Down Expand Up @@ -147,10 +103,29 @@ def _unique(seq: Sequence) -> list:
def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
"""Marginalize a subset of variables in a PyMC model.

This creates a class of `MarginalModel` from an existing `Model`, with the specified
variables marginalized.
Notes
-----
Marginalization functionality is still very restricted. Only finite discrete
variables and some closed-form graphs can be marginalized.
Deterministics and Potentials cannot be conditionally dependent on the marginalized variables.

See documentation for `MarginalModel` for more information.

Examples
--------
Marginalize over a single variable

.. code-block:: python

import pymc as pm
from pymc_extras import marginalize

with pm.Model() as m:
p = pm.Beta("p", 1, 1)
x = pm.Bernoulli("x", p=p, shape=(3,))
y = pm.Normal("y", pm.math.switch(x, -10, 10), observed=[10, 10, -10])

with marginalize(m, [x]) as marginal_m:
idata = pm.sample()

Parameters
----------
Expand All @@ -161,8 +136,8 @@ def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:

Returns
-------
marginal_model: MarginalModel
Marginal model with the specified variables marginalized.
marginal_model: Model
PyMC model with the specified variables marginalized.
"""
if isinstance(rvs_to_marginalize, str | Variable):
rvs_to_marginalize = (rvs_to_marginalize,)
Expand All @@ -176,20 +151,20 @@ def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
if rv_to_marginalize not in model.free_RVs:
raise ValueError(f"Marginalized RV {rv_to_marginalize} is not a free RV in the model")

rv_op = rv_to_marginalize.owner.op
if isinstance(rv_op, DiscreteMarkovChain):
if rv_op.n_lags > 1:
raise NotImplementedError(
"Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported"
)
if rv_to_marginalize.owner.inputs[0].type.ndim > 2:
raise NotImplementedError(
"Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
)
elif not isinstance(rv_op, Bernoulli | Categorical | DiscreteUniform):
raise NotImplementedError(
f"Marginalization of RV with distribution {rv_to_marginalize.owner.op} is not supported"
)
# rv_op = rv_to_marginalize.owner.op
# if isinstance(rv_op, DiscreteMarkovChain):
# if rv_op.n_lags > 1:
# raise NotImplementedError(
# "Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported"
# )
# if rv_to_marginalize.owner.inputs[0].type.ndim > 2:
# raise NotImplementedError(
# "Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
# )
# elif not isinstance(rv_op, Bernoulli | Categorical | DiscreteUniform):
# raise NotImplementedError(
# f"Marginalization of RV with distribution {rv_to_marginalize.owner.op} is not supported"
# )

fg, memo = fgraph_from_model(model)
rvs_to_marginalize = [memo[rv] for rv in rvs_to_marginalize]
Expand Down Expand Up @@ -241,11 +216,52 @@ def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
]
input_rvs = _unique((*marginalized_rv_input_rvs, *other_direct_rv_ancestors))

replace_finite_discrete_marginal_subgraph(fg, rv_to_marginalize, dependent_rvs, input_rvs)
marginalize_subgraph(fg, rv_to_marginalize, dependent_rvs, input_rvs)

return model_from_fgraph(fg, mutate_fgraph=True)


def marginalize_subgraph(fgraph, rv_to_marginalize, dependent_rvs, input_rvs) -> None:
    """Wrap ``rv_to_marginalize`` and its dependent RVs in a single ``MarginalRV``.

    The subgraph from ``input_rvs`` to the marginalized RV and its dependent
    RVs (together with default RNG updates and shared-variable inputs) is
    cloned into a ``MarginalRV`` OpFromGraph, and ``fgraph`` is rewired in
    place so the original model variables point at the new outputs.
    """
    outer_outputs = [rv_to_marginalize, *dependent_rvs]
    # Default RNG updates must be made explicit inputs/outputs of the OpFromGraph.
    rng_updates = collect_default_updates(outer_outputs, inputs=input_rvs, must_be_shared=False)
    all_outputs = [*outer_outputs, *rng_updates.values()]
    all_inputs = [*input_rvs, *rng_updates.keys()]
    # Any other shared variables also have to be passed in explicitly.
    all_inputs += collect_shared_vars(outer_outputs, blockers=all_inputs)

    cloned_inputs = [inp.clone() for inp in all_inputs]
    cloned_outputs = clone_replace(all_outputs, replace=dict(zip(all_inputs, cloned_inputs)))
    cloned_outputs = remove_model_vars(cloned_outputs)

    _, _, *dims = rv_to_marginalize.owner.inputs
    marginalization_op = MarginalRV(
        inputs=cloned_inputs,
        outputs=cloned_outputs,
        dims=dims,
        n_dependent_rvs=len(dependent_rvs),
    )

    new_outputs = marginalization_op(*all_inputs)
    assert len(new_outputs) == len(all_outputs)

    replacements = []
    for old_out, new_out in zip(all_outputs, new_outputs):
        new_out.name = old_out.name
        if old_out is rv_to_marginalize or not isinstance(old_out.owner.op, ModelValuedVar):
            # Replace the marginalized ModelFreeRV (or non model-variables) themselves
            target = old_out
        else:
            # Replace the underlying RV, keeping the same value, transform and dims
            target = old_out.owner.inputs[0]
        replacements.append((target, new_out))

    fgraph.replace_all(replacements)


@node_rewriter(tracks=[MarginalRV])
def local_unmarginalize(fgraph, node):
unmarginalized_rv, *dependent_rvs_and_rngs = inline_ofg_outputs(node.op, node.inputs)
Expand Down
Loading
Loading