diff --git a/pymc_extras/model/marginal/__init__.py b/pymc_extras/model/marginal/__init__.py
index e69de29b..762faf2a 100644
--- a/pymc_extras/model/marginal/__init__.py
+++ b/pymc_extras/model/marginal/__init__.py
@@ -0,0 +1 @@
+import pymc_extras.model.marginal.rewrites  # Need import to register rewrites
\ No newline at end of file
diff --git a/pymc_extras/model/marginal/distributions.py b/pymc_extras/model/marginal/distributions.py
index 86aa5f02..872f0f65 100644
--- a/pymc_extras/model/marginal/distributions.py
+++ b/pymc_extras/model/marginal/distributions.py
@@ -24,37 +24,20 @@ from pymc_extras.distributions import DiscreteMarkovChain
 
-class MarginalRV(OpFromGraph, MeasurableOp):
+class MarginalRV(OpFromGraph):
     """Base class for Marginalized RVs"""
 
     def __init__(
         self,
         *args,
-        dims_connections: tuple[tuple[int | None], ...],
         dims: tuple[Variable, ...],
+        n_dependent_rvs: int,
         **kwargs,
     ) -> None:
-        self.dims_connections = dims_connections
         self.dims = dims
+        self.n_dependent_rvs = n_dependent_rvs
         super().__init__(*args, **kwargs)
 
-    @property
-    def support_axes(self) -> tuple[tuple[int]]:
-        """Dimensions of dependent RVs that belong to the core (non-batched) marginalized variable."""
-        marginalized_ndim_supp = self.inner_outputs[0].owner.op.ndim_supp
-        support_axes_vars = []
-        for dims_connection in self.dims_connections:
-            ndim = len(dims_connection)
-            marginalized_supp_axes = ndim - marginalized_ndim_supp
-            support_axes_vars.append(
-                tuple(
-                    -i
-                    for i, dim in enumerate(reversed(dims_connection), start=1)
-                    if (dim is None or dim > marginalized_supp_axes)
-                )
-            )
-        return tuple(support_axes_vars)
-
     def __eq__(self, other):
         # Just to allow easy testing of equivalent models,
         # This can be removed once https://github.com/pymc-devs/pytensor/issues/1114 is fixed
@@ -124,11 +107,35 @@ def support_point_marginal_rv(op: MarginalRV, rv, *inputs):
     return rv_support_point
 
 
-class MarginalFiniteDiscreteRV(MarginalRV):
+class MarginalEnumerableRV(MarginalRV, MeasurableOp):
+
+    def __init__(self, *args, dims_connections: tuple[tuple[int | None], ...], **kwargs):
+        super().__init__(*args, **kwargs)
+        self.dims_connections = dims_connections
+
+    @property
+    def support_axes(self) -> tuple[tuple[int]]:
+        """Dimensions of dependent RVs that belong to the core (non-batched) marginalized variable."""
+        marginalized_ndim_supp = self.inner_outputs[0].owner.op.ndim_supp
+        support_axes_vars = []
+        for dims_connection in self.dims_connections:
+            ndim = len(dims_connection)
+            marginalized_supp_axes = ndim - marginalized_ndim_supp
+            support_axes_vars.append(
+                tuple(
+                    -i
+                    for i, dim in enumerate(reversed(dims_connection), start=1)
+                    if (dim is None or dim > marginalized_supp_axes)
+                )
+            )
+        return tuple(support_axes_vars)
+
+
+class MarginalFiniteDiscreteRV(MarginalEnumerableRV):
     """Base class for Marginalized Finite Discrete RVs"""
 
 
-class MarginalDiscreteMarkovChainRV(MarginalRV):
+class MarginalDiscreteMarkovChainRV(MarginalEnumerableRV):
     """Base class for Marginalized Discrete Markov Chain RVs"""
 
 
@@ -239,7 +246,9 @@ def warn_non_separable_logp(values):
 @_logprob.register(MarginalFiniteDiscreteRV)
 def finite_discrete_marginal_rv_logp(op: MarginalFiniteDiscreteRV, values, *inputs, **kwargs):
     # Clone the inner RV graph of the Marginalized RV
-    marginalized_rv, *inner_rvs = inline_ofg_outputs(op, inputs)
+    marginalized_rv, *inner_rvs_and_rngs = inline_ofg_outputs(op, inputs)
+    inner_rvs = inner_rvs_and_rngs[: op.n_dependent_rvs]
+    assert len(values) == len(inner_rvs)
 
     # Obtain the joint_logp graph of the inner RV graph
     inner_rv_values = dict(zip(inner_rvs, values))
@@ -302,7 +311,9 @@ def logp_fn(marginalized_rv_const, *non_sequences):
 @_logprob.register(MarginalDiscreteMarkovChainRV)
 def marginal_hmm_logp(op, values, *inputs, **kwargs):
-    chain_rv, *dependent_rvs = inline_ofg_outputs(op, inputs)
+    chain_rv, *dependent_rvs_and_rngs = inline_ofg_outputs(op, inputs)
+    dependent_rvs = dependent_rvs_and_rngs[: op.n_dependent_rvs]
+    assert len(values) == len(dependent_rvs)
     P, n_steps_, init_dist_, rng = chain_rv.owner.inputs
     domain = pt.arange(P.shape[-1], dtype="int32")
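The key invariant these hunks introduce is the ordering of the inner-graph outputs, `[marginalized_rv, *dependent_rvs, *rng_updates]`, with `n_dependent_rvs` telling the logp dispatchers where the value-bearing outputs stop and the RNG updates begin. A schematic of that slicing (plain Python, names illustrative only, not the actual PyMC objects):

```python
# Assumed layout of MarginalRV inner outputs (illustrative names):
#   [marginalized_rv, dependent_rv_1, ..., dependent_rv_k, rng_update_1, ...]
outputs = ["marg", "dep1", "dep2", "rng1"]
n_dependent_rvs = 2

marginalized, rest = outputs[0], outputs[1:]
dependent = rest[:n_dependent_rvs]  # ["dep1", "dep2"] -> matched against values
rngs = rest[n_dependent_rvs:]       # ["rng1"] -> not scored by the logp
assert dependent == ["dep1", "dep2"] and rngs == ["rng1"]
```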
diff --git a/pymc_extras/model/marginal/marginal_model.py b/pymc_extras/model/marginal/marginal_model.py
index b6ca25bf..490892ae 100644
--- a/pymc_extras/model/marginal/marginal_model.py
+++ b/pymc_extras/model/marginal/marginal_model.py
@@ -62,50 +62,6 @@ class MarginalModel(Model):
-    """Subclass of PyMC Model that implements functionality for automatic
-    marginalization of variables in the logp transformation
-
-    After defining the full Model, the `marginalize` method can be used to indicate a
-    subset of variables that should be marginalized
-
-    Notes
-    -----
-    Marginalization functionality is still very restricted. Only finite discrete
-    variables can be marginalized. Deterministics and Potentials cannot be conditionally
-    dependent on the marginalized variables.
-
-    Furthermore, not all instances of such variables can be marginalized. If a variable
-    has batched dimensions, it is required that any conditionally dependent variables
-    use information from an individual batched dimension. In other words, the graph
-    connecting the marginalized variable(s) to the dependent variable(s) must be
-    composed strictly of Elemwise Operations. This is necessary to ensure an efficient
-    logprob graph can be generated. If you want to bypass this restriction you can
-    separate each dimension of the marginalized variable into the scalar components
-    and then stack them together. Note that such graphs will grow exponentially in the
-    number of marginalized variables.
-
-    For the same reason, it's not possible to marginalize RVs with multivariate
-    dependent RVs.
-
-    Examples
-    --------
-    Marginalize over a single variable
-
-    .. code-block:: python
-
-        import pymc as pm
-        from pymc_extras import MarginalModel
-
-        with MarginalModel() as m:
-            p = pm.Beta("p", 1, 1)
-            x = pm.Bernoulli("x", p=p, shape=(3,))
-            y = pm.Normal("y", pm.math.switch(x, -10, 10), observed=[10, 10, -10])
-
-        m.marginalize([x])
-
-        idata = pm.sample()
-
-    """
 
     def __init__(self, *args, **kwargs):
         raise TypeError(
@@ -147,10 +103,29 @@ def _unique(seq: Sequence) -> list:
 def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
     """Marginalize a subset of variables in a PyMC model.
 
-    This creates a class of `MarginalModel` from an existing `Model`, with the specified
-    variables marginalized.
+    Notes
+    -----
+    Marginalization functionality is still very restricted. Only finite discrete
+    variables and some closed-form graphs can be marginalized.
+    Deterministics and Potentials cannot be conditionally dependent on the marginalized variables.
 
-    See documentation for `MarginalModel` for more information.
+    Examples
+    --------
+    Marginalize over a single variable
+
+    .. code-block:: python
+
+        import pymc as pm
+        from pymc_extras import marginalize
+
+        with pm.Model() as m:
+            p = pm.Beta("p", 1, 1)
+            x = pm.Bernoulli("x", p=p, shape=(3,))
+            y = pm.Normal("y", pm.math.switch(x, -10, 10), observed=[10, 10, -10])
+
+        with marginalize(m, [x]) as marginal_m:
+            idata = pm.sample()
 
     Parameters
     ----------
@@ -161,8 +136,8 @@ def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
 
     Returns
     -------
-    marginal_model: MarginalModel
-        Marginal model with the specified variables marginalized.
+    marginal_model: Model
+        PyMC model with the specified variables marginalized.
     """
     if isinstance(rvs_to_marginalize, str | Variable):
         rvs_to_marginalize = (rvs_to_marginalize,)
@@ -176,20 +151,20 @@ def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
         if rv_to_marginalize not in model.free_RVs:
             raise ValueError(f"Marginalized RV {rv_to_marginalize} is not a free RV in the model")
 
-        rv_op = rv_to_marginalize.owner.op
-        if isinstance(rv_op, DiscreteMarkovChain):
-            if rv_op.n_lags > 1:
-                raise NotImplementedError(
-                    "Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported"
-                )
-            if rv_to_marginalize.owner.inputs[0].type.ndim > 2:
-                raise NotImplementedError(
-                    "Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
-                )
-        elif not isinstance(rv_op, Bernoulli | Categorical | DiscreteUniform):
-            raise NotImplementedError(
-                f"Marginalization of RV with distribution {rv_to_marginalize.owner.op} is not supported"
-            )
+        # Upfront eligibility checks were dropped: they now live in the logprob
+        # rewrites that lower MarginalRV (see rewrites.py), which log an error
+        # and bail out when a graph cannot be marginalized.
 
     fg, memo = fgraph_from_model(model)
     rvs_to_marginalize = [memo[rv] for rv in rvs_to_marginalize]
@@ -241,11 +216,52 @@ def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
     ]
     input_rvs = _unique((*marginalized_rv_input_rvs, *other_direct_rv_ancestors))
 
-    replace_finite_discrete_marginal_subgraph(fg, rv_to_marginalize, dependent_rvs, input_rvs)
+    marginalize_subgraph(fg, rv_to_marginalize, dependent_rvs, input_rvs)
 
     return model_from_fgraph(fg, mutate_fgraph=True)
 
 
+def marginalize_subgraph(fgraph, rv_to_marginalize, dependent_rvs, input_rvs) -> None:
+    output_rvs = [rv_to_marginalize, *dependent_rvs]
+    rng_updates = collect_default_updates(output_rvs, inputs=input_rvs, must_be_shared=False)
+    outputs = output_rvs + list(rng_updates.values())
+    inputs = input_rvs + list(rng_updates.keys())
+    # Add any other shared variable inputs
+    inputs += collect_shared_vars(output_rvs, blockers=inputs)
+
+    inner_inputs = [inp.clone() for inp in inputs]
+    inner_outputs = clone_replace(outputs, replace=dict(zip(inputs, inner_inputs)))
+    inner_outputs = remove_model_vars(inner_outputs)
+
+    _, _, *dims = rv_to_marginalize.owner.inputs
+    marginalization_op = MarginalRV(
+        inputs=inner_inputs,
+        outputs=inner_outputs,
+        dims=dims,
+        n_dependent_rvs=len(dependent_rvs),
+    )
+
+    new_outputs = marginalization_op(*inputs)
+    assert len(new_outputs) == len(outputs)
+    for old_output, new_output in zip(outputs, new_outputs):
+        new_output.name = old_output.name
+
+    model_replacements = []
+    for old_output, new_output in zip(outputs, new_outputs):
+        if old_output is rv_to_marginalize or not isinstance(old_output.owner.op, ModelValuedVar):
+            # Replace the marginalized ModelFreeRV (or non model-variables) themselves
+            var_to_replace = old_output
+        else:
+            # Replace the underlying RV, keeping the same value, transform and dims
+            var_to_replace = old_output.owner.inputs[0]
+        model_replacements.append((var_to_replace, new_output))
+
+    fgraph.replace_all(model_replacements)
+
+
 @node_rewriter(tracks=[MarginalRV])
 def local_unmarginalize(fgraph, node):
     unmarginalized_rv, *dependent_rvs_and_rngs = inline_ofg_outputs(node.op, node.inputs)
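`marginalize_subgraph` packs the marginalized RV, its dependents, and their RNG updates into a single `OpFromGraph` node, deferring all eligibility checks to rewrite time. For intuition about the wrapping itself, here is a minimal standalone multi-output `OpFromGraph` (a toy graph, unrelated to the PyMC internals above):

```python
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph

x = pt.scalar("x")
# Inner graph with two outputs, packaged as one reusable Op
op = OpFromGraph([x], [x + 1, x * 2])

out1, out2 = op(pt.constant(3.0))
print(out1.eval(), out2.eval())  # 4.0 6.0
```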
diff --git a/pymc_extras/model/marginal/rewrites.py b/pymc_extras/model/marginal/rewrites.py
new file mode 100644
index 00000000..aeee2526
--- /dev/null
+++ b/pymc_extras/model/marginal/rewrites.py
@@ -0,0 +1,143 @@
+from logging import getLogger
+
+import pytensor.tensor as pt
+from pymc import Normal, Bernoulli, Categorical, DiscreteUniform
+from pymc.logprob.rewriting import measurable_ir_rewrites_db
+from pymc.pytensorf import constant_fold
+from pytensor import clone_replace, graph_replace
+from pytensor.graph import node_rewriter, ancestors
+
+from pymc_extras import DiscreteMarkovChain
+from pymc_extras.model.marginal.distributions import (
+    MarginalRV,
+    MarginalFiniteDiscreteRV,
+    MarginalDiscreteMarkovChainRV,
+)
+from pymc_extras.model.marginal.graph_analysis import subgraph_batch_dim_connection
+
+
+logger = getLogger("pymc-logprob")
+
+
+def register_marginal_rewrite(func):
+    measurable_ir_rewrites_db.register(func.__name__, func, "basic", "marginal")
+    # Return the rewrite so the decorated name still refers to it
+    return func
+
+
+@register_marginal_rewrite
+@node_rewriter(tracks=[MarginalRV])
+def finite_discrete_marginal(fgraph, node):
+    if type(node.op) is not MarginalRV:
+        # Already not a raw MarginalRV
+        return None
+
+    inner_fgraph = node.op.fgraph
+
+    marginalized_rv = inner_fgraph.outputs[0]
+    marginalized_rv_op = marginalized_rv.owner.op
+    if not isinstance(
+        marginalized_rv_op, Bernoulli | Categorical | DiscreteUniform | DiscreteMarkovChain
+    ):
+        return None
+
+    if isinstance(marginalized_rv_op, DiscreteMarkovChain):
+        if marginalized_rv_op.n_lags > 1:
+            logger.error(
+                "Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported"
+            )
+            return None
+        if marginalized_rv.owner.inputs[0].type.ndim > 2:
+            logger.error(
+                "Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
+            )
+            return None
+
+    dependent_rvs = inner_fgraph.outputs[1 : 1 + node.op.n_dependent_rvs]
+    try:
+        dependent_rvs_dim_connections = subgraph_batch_dim_connection(
+            marginalized_rv, dependent_rvs
+        )
+    except (ValueError, NotImplementedError) as e:
+        logger.error(
+            "The graph between the marginalized and dependent RVs cannot be marginalized efficiently. "
+            "You can try splitting the marginalized RV into separate components and marginalizing them "
+            f"separately. {e}"
+        )
+        return None
+
+    if isinstance(marginalized_rv_op, DiscreteMarkovChain):
+        marginalize_constructor = MarginalDiscreteMarkovChainRV
+    else:
+        marginalize_constructor = MarginalFiniteDiscreteRV
+
+    marginalization_op = marginalize_constructor(
+        inputs=inner_fgraph.inputs,
+        outputs=inner_fgraph.outputs,
+        dims_connections=dependent_rvs_dim_connections,
+        dims=node.op.dims,
+        n_dependent_rvs=node.op.n_dependent_rvs,
+    )
+
+    new_outputs = marginalization_op(*node.inputs)
+    return new_outputs
+
+
+@register_marginal_rewrite
+@node_rewriter(tracks=[MarginalRV])
+def normal_normal_marginal(fgraph, node):
+    if type(node.op) is not MarginalRV:
+        # Already not a raw MarginalRV
+        return None
+
+    if node.op.n_dependent_rvs != 1:
+        # More than one dependent variable
+        return None
+
+    marginalized_rv, dependent_rv, *_ = node.op.fgraph.outputs
+    if not (
+        isinstance(marginalized_rv.owner.op, Normal)
+        and isinstance(dependent_rv.owner.op, Normal)
+    ):
+        return None
+
+    mu_dependent_rv, sigma_dependent_rv = dependent_rv.owner.op.dist_params(dependent_rv.owner)
+    mu_marginalized_rv, sigma_marginalized_rv = marginalized_rv.owner.op.dist_params(
+        marginalized_rv.owner
+    )
+
+    if marginalized_rv in ancestors([sigma_dependent_rv]):
+        return None
+
+    # Check that we have mu = marginalized_rv + offset
+    if mu_dependent_rv is not marginalized_rv:
+        add_node = mu_dependent_rv.owner
+        if not (add_node and add_node.op == pt.add and len(add_node.inputs) == 2):
+            return None
+        a, b = add_node.inputs
+        if a is marginalized_rv:
+            if marginalized_rv in ancestors([b]):
+                # The marginalized_rv shows up in both branches of the addition
+                return None
+        elif b is marginalized_rv:
+            if marginalized_rv in ancestors([a]):
+                # The marginalized_rv shows up in both branches of the addition
+                return None
+        else:
+            # There's a more complicated function between the marginalized_rv and the mean of the dependent_rv
+            return None
+
+    # Replace reference to marginalized RV by its mean (possibly broadcasted):
+    if marginalized_rv.type.broadcastable != mu_marginalized_rv.type.broadcastable:
+        mu_marginalized_rv = pt.broadcast_to(
+            mu_marginalized_rv,
+            constant_fold(marginalized_rv.shape, raise_not_constant=False),
+        )
+    rng_dependent_rv = dependent_rv.owner.op.rng_param(dependent_rv.owner)
+    size_dependent_rv = dependent_rv.owner.op.size_param(dependent_rv.owner)
+
+    new_mu = clone_replace(mu_dependent_rv, {marginalized_rv: mu_marginalized_rv})
+    new_sigma = pt.sqrt(sigma_dependent_rv**2 + sigma_marginalized_rv**2)
+    new_rv = Normal.dist(mu=new_mu, sigma=new_sigma, size=size_dependent_rv, rng=rng_dependent_rv)
+
+    # Replace inner inputs by outer inputs
+    new_rv = graph_replace(
+        new_rv,
+        replace=tuple(zip(node.op.inner_inputs, node.inputs)),
+        strict=False,
+    )
+    return {node.outputs[1]: new_rv}
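For reference, `normal_normal_marginal` implements the standard Gaussian convolution identity: if x ~ N(mu, s1) and y | x ~ N(x + c, s2) with c independent of x, then marginally y ~ N(mu + c, sqrt(s1^2 + s2^2)). A quick Monte Carlo sanity check of that identity, in plain numpy and independent of this patch:

```python
import numpy as np

rng = np.random.default_rng(0)
mu, c, s1, s2 = 0.3, 1.5, 0.8, 1.2

# Sample the hierarchical model x -> y and compare with the closed-form marginal
x = rng.normal(mu, s1, size=1_000_000)
y = rng.normal(x + c, s2)

assert abs(y.mean() - (mu + c)) < 1e-2
assert abs(y.std() - np.sqrt(s1**2 + s2**2)) < 1e-2
```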
+ f"{e}" + ) + return None + + + if isinstance(marginalized_rv_op, DiscreteMarkovChain): + marginalize_constructor = MarginalDiscreteMarkovChainRV + else: + marginalize_constructor = MarginalFiniteDiscreteRV + + # _, _, *dims = rv_to_marginalize.owner.inputs + marginalization_op = marginalize_constructor( + inputs=fgraph.inputs, + outputs=fgraph.outputs, + dims_connections=dependent_rvs_dim_connections, + dims=node.op.dims, + n_dependent_rvs=node.op.n_dependent_rvs, + ) + + new_outputs = marginalization_op(*node.inputs) + return new_outputs + + +@register_marginal_rewrite +@node_rewriter(tracks=[MarginalRV]) +def normal_normal_marginal(fgraph, node): + if type(node.op) is not MarginalRV: + # Already not a raw MarginalRV + return + + if node.op.n_dependent_rvs != 1: + # More than two dependent variables + return + + marginalized_rv, dependent_rv, *_ = node.op.fgraph.outputs + if not ( + isinstance(marginalized_rv.owner.op, Normal) + and isinstance(marginalized_rv.owner.op, Normal) + ): + return + + mu_dependent_rv, sigma_dependent_rv = dependent_rv.owner.op.dist_params(dependent_rv.owner) + mu_marginalized_rv, sigma_marginalized_rv = marginalized_rv.owner.op.dist_params(marginalized_rv.owner) + + if marginalized_rv in ancestors([sigma_dependent_rv]): + return + + # Check that we have mu = marginalized_rv + offset + if not mu_dependent_rv is marginalized_rv: + add_node = mu_dependent_rv.owner + if not (add_node and add_node.op == pt.add and len(add_node.inputs) == 2): + return + a, b = add_node.inputs + if a is marginalized_rv: + if marginalized_rv in ancestors([b]): + # The marginalized_rv shows up in both branches of the addition + return + elif b is marginalized_rv: + if marginalized_rv in ancestors([a]): + # The marginalized_rv shows up in both branches of the addition + return + else: + # There's a more complicated function between the marginalized_rv and the mean of the dependent_rv + return + + + # Replace reference to marginalized RV by its mean (possibly broadcasted): + if marginalized_rv.type.broadcastable != mu_marginalized_rv.type.broadcastable: + mu_marginalized_rv = pt.broadcast_to( + mu_marginalized_rv, + constant_fold(marginalized_rv.shape, raise_not_constant=False) + ) + rng_dependent_rv = dependent_rv.owner.op.rng_param(dependent_rv.owner) + size_dependent_rv = dependent_rv.owner.op.size_param(dependent_rv.owner) + + new_mu = clone_replace(mu_dependent_rv, {marginalized_rv: mu_marginalized_rv}) + new_sigma = pt.sqrt(sigma_dependent_rv ** 2 + sigma_marginalized_rv ** 2) + new_rv = Normal.dist(mu=new_mu, sigma=new_sigma, size=size_dependent_rv, rng=rng_dependent_rv) + + # Replace inner inputs by outer inputs + new_rv = graph_replace( + new_rv, + replace=tuple(zip(node.op.inner_inputs, node.inputs)), + strict=False, + ) + return {node.outputs[1]: new_rv} diff --git a/tests/model/marginal/test_closed_form.py b/tests/model/marginal/test_closed_form.py new file mode 100644 index 00000000..fc84d27c --- /dev/null +++ b/tests/model/marginal/test_closed_form.py @@ -0,0 +1,41 @@ +import numpy as np +import pytest +import scipy +import pymc as pm + +from pymc_extras import marginalize + + + +def test_normal_normal(): + with pm.Model() as m: + x = pm.Normal("x", mu=0, sigma=1) + y = pm.Normal("y", mu=x + np.pi - 1, sigma=1.0) + z = pm.Normal("z", mu=y + 2 * np.pi, sigma=np.sqrt(np.e)) + + marginal_m = marginalize(m, m["y"]) + + test_point = {"x": 1, "z": -1} + + np.testing.assert_allclose( + marginal_m.compile_logp([marginal_m["z"]])(test_point), + 
diff --git a/tests/model/marginal/test_distributions.py b/tests/model/marginal/test_distributions.py
index 434ec271..5b8e4108 100644
--- a/tests/model/marginal/test_distributions.py
+++ b/tests/model/marginal/test_distributions.py
@@ -22,6 +22,7 @@ def test_marginalized_bernoulli_logp():
         [idx, y],
         dims_connections=(((),),),
         dims=(),
+        n_dependent_rvs=1,
     )(mu)[0].owner
 
     y_vv = y.clone()
diff --git a/tests/model/marginal/test_marginal_model.py b/tests/model/marginal/test_marginal_model.py
index 0ab7991b..eb4b38b1 100644
--- a/tests/model/marginal/test_marginal_model.py
+++ b/tests/model/marginal/test_marginal_model.py
@@ -400,8 +400,8 @@ def test_mixed_dims_via_transposed_dot(self):
             idx = pm.Bernoulli("idx", p=0.7, shape=2)
             y = pm.Normal("y", mu=idx @ idx.T)
 
-        with pytest.raises(NotImplementedError):
-            marginalize(m, idx)
+        with pytest.raises(RuntimeError):
+            marginalize(m, idx).logp()
 
     def test_mixed_dims_via_indexing(self):
         mean = pt.as_tensor([[0.1, 0.9], [0.6, 0.4]])
@@ -409,14 +409,14 @@ def test_mixed_dims_via_indexing(self):
         with Model() as m:
             idx = pm.Bernoulli("idx", p=0.7, shape=2)
             y = pm.Normal("y", mu=mean[idx, :] + mean[:, idx])
-        with pytest.raises(NotImplementedError):
-            marginalize(m, idx)
+        with pytest.raises(RuntimeError):
+            marginalize(m, idx).logp()
 
         with Model() as m:
             idx = pm.Bernoulli("idx", p=0.7, shape=2)
             y = pm.Normal("y", mu=mean[idx, None] + mean[None, idx])
-        with pytest.raises(NotImplementedError):
-            marginalize(m, idx)
+        with pytest.raises(RuntimeError):
+            marginalize(m, idx).logp()
 
         with Model() as m:
             idx = pm.Bernoulli("idx", p=0.7, shape=2)
@@ -424,34 +424,34 @@ def test_mixed_dims_via_indexing(self):
                 mean[None, :][:, idx], 0
             )
             y = pm.Normal("y", mu=mu)
-        with pytest.raises(NotImplementedError):
-            marginalize(m, idx)
+        with pytest.raises(RuntimeError):
+            marginalize(m, idx).logp()
 
         with Model() as m:
             idx = pm.Bernoulli("idx", p=0.7, shape=2)
             y = pm.Normal("y", mu=idx[0] + idx[1])
-        with pytest.raises(NotImplementedError):
-            marginalize(m, idx)
+        with pytest.raises(RuntimeError):
+            marginalize(m, idx).logp()
 
     def test_mixed_dims_via_vector_indexing(self):
         with Model() as m:
             idx = pm.Bernoulli("idx", p=0.7, shape=2)
             y = pm.Normal("y", mu=idx[[0, 1, 0, 0]])
-        with pytest.raises(NotImplementedError):
-            marginalize(m, idx)
+        with pytest.raises(RuntimeError):
+            marginalize(m, idx).logp()
 
         with Model() as m:
             idx = pm.Categorical("key", p=[0.1, 0.3, 0.6], shape=(2, 2))
             y = pm.Normal("y", pt.as_tensor([[0, 1], [2, 3]])[idx.astype(bool)])
-        with pytest.raises(NotImplementedError):
-            marginalize(m, idx)
+        with pytest.raises(RuntimeError):
+            marginalize(m, idx).logp()
 
     def test_mixed_dims_via_support_dimension(self):
         with Model() as m:
             x = pm.Bernoulli("x", p=0.7, shape=3)
             y = pm.Dirichlet("y", a=x * 10 + 1)
-        with pytest.raises(NotImplementedError):
-            marginalize(m, x)
+        with pytest.raises(RuntimeError):
+            marginalize(m, x).logp()
 
     def test_mixed_dims_via_nested_marginalization(self):
         with Model() as m:
@@ -459,8 +459,8 @@ def test_mixed_dims_via_nested_marginalization(self):
             y = pm.Bernoulli("y", p=0.7, shape=(2,))
             z = pm.Normal("z", mu=pt.add.outer(x, y), shape=(3, 2))
 
-        with pytest.raises(NotImplementedError):
-            marginalize(m, [x, y])
+        with pytest.raises(RuntimeError):
+            marginalize(m, [x, y]).logp()
 
 
 def test_marginalized_deterministic_and_potential():
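End to end, the new closed-form path composes with ordinary sampling. A sketch of the intended usage, assuming the normal-normal rewrite applies (the dependent mean is the marginalized RV plus an offset); this mirrors the docstring example and `test_normal_normal` rather than adding new API:

```python
import numpy as np
import pymc as pm
from pymc_extras import marginalize

with pm.Model() as m:
    x = pm.Normal("x", mu=0, sigma=1)
    y = pm.Normal("y", mu=x, sigma=1)
    z = pm.Normal("z", mu=y, sigma=1, observed=np.zeros(5))

# y is integrated out analytically; only x is left to sample,
# with z | x ~ Normal(x, sqrt(2)) after marginalization
marginal_m = marginalize(m, [m["y"]])
with marginal_m:
    idata = pm.sample()
```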