@@ -1,20 +1,25 @@
+import warnings
+
 from collections.abc import Sequence
 
 import numpy as np
 import pytensor.tensor as pt
 
 from pymc.distributions import Bernoulli, Categorical, DiscreteUniform
+from pymc.distributions.distribution import _support_point, support_point
 from pymc.logprob.abstract import MeasurableOp, _logprob
 from pymc.logprob.basic import conditional_logp, logp
 from pymc.pytensorf import constant_fold
 from pytensor import Variable
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.mode import Mode
-from pytensor.graph import Op, vectorize_graph
+from pytensor.graph import FunctionGraph, Op, vectorize_graph
+from pytensor.graph.basic import equal_computations
 from pytensor.graph.replace import clone_replace, graph_replace
 from pytensor.scan import map as scan_map
 from pytensor.scan import scan
 from pytensor.tensor import TensorVariable
+from pytensor.tensor.random.type import RandomType
 
 from pymc_experimental.distributions import DiscreteMarkovChain
 
@@ -43,6 +48,74 @@ def support_axes(self) -> tuple[tuple[int]]:
         )
         return tuple(support_axes_vars)
 
+    def __eq__(self, other):
+        # Just to allow easy testing of equivalent models.
+        # This can be removed once https://github.com/pymc-devs/pytensor/issues/1114 is fixed
+        if type(self) is not type(other):
+            return False
+
+        return equal_computations(
+            self.inner_outputs,
+            other.inner_outputs,
+            self.inner_inputs,
+            other.inner_inputs,
+        )
+
+    def __hash__(self):
+        # Just to allow easy testing of equivalent models.
+        # This can be removed once https://github.com/pymc-devs/pytensor/issues/1114 is fixed
+        return hash((type(self), len(self.inner_inputs), len(self.inner_outputs)))
+
+
+@_support_point.register
+def support_point_marginal_rv(op: MarginalRV, rv, *inputs):
+    """Support point for a marginalized RV.
+
+    The support point of a marginalized RV is the support point of the inner RV,
+    conditioned on the marginalized RV taking its support point.
+    """
+    outputs = rv.owner.outputs
+
+    inner_rv = op.inner_outputs[outputs.index(rv)]
+    marginalized_inner_rv, *other_dependent_inner_rvs = (
+        out
+        for out in op.inner_outputs
+        if out is not inner_rv and not isinstance(out.type, RandomType)
+    )
+
+    # Replace references to inner RVs by dummy variables (including the marginalized RV).
+    # This is necessary because the inner RVs may depend on each other.
+    marginalized_inner_rv_dummy = marginalized_inner_rv.clone()
+    other_dependent_inner_rv_to_dummies = {
+        inner_rv: inner_rv.clone() for inner_rv in other_dependent_inner_rvs
+    }
+    inner_rv = clone_replace(
+        inner_rv,
+        replace={marginalized_inner_rv: marginalized_inner_rv_dummy}
+        | other_dependent_inner_rv_to_dummies,
+    )
+
+    # Get the support points of the inner RV and the marginalized RV
+    inner_rv_support_point = support_point(inner_rv)
+    marginalized_inner_rv_support_point = support_point(marginalized_inner_rv)
+
+    replacements = [
+        # Replace the marginalized RV dummy by its support point
+        (marginalized_inner_rv_dummy, marginalized_inner_rv_support_point),
+        # Replace the other dependent RV dummies by the respective outer outputs.
+        # PyMC will replace them by their support points later.
+        *(
+            (v, outputs[op.inner_outputs.index(k)])
+            for k, v in other_dependent_inner_rv_to_dummies.items()
+        ),
+        # Replace outer input RVs
+        *zip(op.inner_inputs, inputs),
+    ]
+    fgraph = FunctionGraph(outputs=[inner_rv_support_point], clone=False)
+    fgraph.replace_all(replacements, import_missing=True)
+    [rv_support_point] = fgraph.outputs
+    return rv_support_point
+
 
 class MarginalFiniteDiscreteRV(MarginalRV):
     """Base class for Marginalized Finite Discrete RVs"""
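
As an aside (not part of the diff): the `__eq__` override goes through `equal_computations` because the default `Op` equality is identity-based. A minimal sketch of the idea with a plain `OpFromGraph` standing in for `MarginalRV`, using only public PyTensor APIs:

```python
import pytensor.tensor as pt

from pytensor.compile.builders import OpFromGraph
from pytensor.graph.basic import equal_computations

# Two structurally identical inner graphs, built independently
x1, x2 = pt.scalar("x"), pt.scalar("x")
op1 = OpFromGraph([x1], [x1 + 1])
op2 = OpFromGraph([x2], [x2 + 1])

# Default Op equality is identity-based, so these compare unequal
assert op1 != op2

# equal_computations compares the wrapped graphs themselves, which is what
# the MarginalRV.__eq__ override above relies on
assert equal_computations(
    op1.inner_outputs, op2.inner_outputs, op1.inner_inputs, op2.inner_inputs
)
```
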
@@ -132,12 +205,22 @@ def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Variable]:
     Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps"
     the inner graph.
     """
-    return clone_replace(
+    return graph_replace(
         op.inner_outputs,
         replace=tuple(zip(op.inner_inputs, inputs)),
     )
 
 
+def warn_logp_non_separable(values):
+    if len(values) > 1:
+        warnings.warn(
+            "There are multiple dependent variables in a FiniteDiscreteMarginalRV. "
+            f"Their joint logp terms will be assigned to the first value: {values[0]}.",
+            UserWarning,
+            stacklevel=2,
+        )
+
+
 DUMMY_ZERO = pt.constant(0, name="dummy_zero")
 
 
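
For illustration (outside the diff), this is roughly how the new helper behaves; the value names are hypothetical stand-ins:

```python
import warnings

# Two dependent value variables: the joint logp cannot be separated,
# so a UserWarning flags that it is assigned to the first value
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warn_logp_non_separable(["value_0", "value_1"])
assert "assigned to the first value" in str(caught[0].message)

# A single value variable: nothing to warn about
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warn_logp_non_separable(["value_0"])
assert not caught
```
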
@@ -200,6 +283,7 @@ def logp_fn(marginalized_rv_const, *non_sequences):
     joint_logp = align_logp_dims(dims=op.dims_connections[0], logp=joint_logp)
 
     # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
+    warn_logp_non_separable(values)
     dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
     return joint_logp, *dummy_logps
 
@@ -272,5 +356,6 @@ def step_alpha(logp_emission, log_alpha, log_P):
 
     # If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
     # return is the joint probability of everything together, but PyMC still expects one logp for each emission stream.
+    warn_logp_non_separable(values)
     dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
     return joint_logp, *dummy_logps
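
Taken together, the new dispatch makes `support_point`, and hence initial-point computation, work for marginalized models. A rough usage sketch, assuming the `MarginalModel` API that `pymc_experimental` exposes at the time of this change:

```python
import pymc as pm

from pymc_experimental import MarginalModel

with MarginalModel() as model:
    idx = pm.Bernoulli("idx", p=0.7)
    y = pm.Normal("y", mu=pm.math.switch(idx, 1.0, -1.0), sigma=1.0)
    model.marginalize([idx])

# `y` is now the output of a MarginalRV. With support_point registered for
# MarginalRV, PyMC can derive an initial point for `y` by fixing the
# marginalized `idx` at its own support point inside the inner graph.
print(model.initial_point())
```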