
Commit e96d07f

Support more kinds of marginalization via dim analysis
This commit lifts the restriction that only Elemwise operations may link marginalized RVs to their dependent RVs. We map input dims to output dims to assess whether an operation mixes information across dims. Graphs where information is not mixed can be marginalized efficiently.
1 parent d965959 commit e96d07f
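
As an illustration of the new capability, here is a minimal sketch (assuming the MarginalModel API from pymc_experimental; the model, names, and shapes are made up for this example). The dependent RV is linked to the marginalized idx through a transpose and a new broadcast dimension rather than a plain Elemwise operation, which this commit now handles:

import pymc as pm

from pymc_experimental import MarginalModel

with MarginalModel() as m:
    # Finite discrete RV to be marginalized out
    idx = pm.Bernoulli("idx", p=0.7, shape=(3, 2))
    # idx reaches the dependent RV through a transpose and a new broadcast
    # dimension (a DimShuffle), which does not mix information across dims
    pm.Normal("dep", mu=idx.T[..., None] * 2, sigma=1.0, shape=(2, 3, 5))
    # Sum idx out of the model's logp
    m.marginalize(["idx"])

Because no operation in this graph mixes information across dims, each batch element of idx can be summed out independently.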

9 files changed (+926, −216 lines)


conda-envs/environment-test.yml (+1, −1)
@@ -10,6 +10,6 @@ dependencies:
 - xhistogram
 - statsmodels
 - pip:
-  - pymc>=5.16.1 # CI was failing to resolve
+  - pymc>=5.17.0 # CI was failing to resolve
   - blackjax
   - scikit-learn

conda-envs/windows-environment-test.yml (+1, −1)
@@ -10,6 +10,6 @@ dependencies:
 - xhistogram
 - statsmodels
 - pip:
-  - pymc>=5.16.1 # CI was failing to resolve
+  - pymc>=5.17.0 # CI was failing to resolve
   - blackjax
   - scikit-learn

pymc_experimental/model/marginal/distributions.py (+132, −47)
@@ -1,30 +1,55 @@
 from collections.abc import Sequence
 
 import numpy as np
+import pytensor.tensor as pt
 
-from pymc import Bernoulli, Categorical, DiscreteUniform, SymbolicRandomVariable, logp
-from pymc.logprob import conditional_logp
-from pymc.logprob.abstract import _logprob
+from pymc.distributions import Bernoulli, Categorical, DiscreteUniform
+from pymc.logprob.abstract import MeasurableOp, _logprob
+from pymc.logprob.basic import conditional_logp, logp
 from pymc.pytensorf import constant_fold
-from pytensor import Mode, clone_replace, graph_replace, scan
-from pytensor import map as scan_map
-from pytensor import tensor as pt
-from pytensor.graph import vectorize_graph
-from pytensor.tensor import TensorType, TensorVariable
+from pytensor import Variable
+from pytensor.compile.builders import OpFromGraph
+from pytensor.compile.mode import Mode
+from pytensor.graph import Op, vectorize_graph
+from pytensor.graph.replace import clone_replace, graph_replace
+from pytensor.scan import map as scan_map
+from pytensor.scan import scan
+from pytensor.tensor import TensorVariable
 
 from pymc_experimental.distributions import DiscreteMarkovChain
 
 
-class MarginalRV(SymbolicRandomVariable):
+class MarginalRV(OpFromGraph, MeasurableOp):
     """Base class for Marginalized RVs"""
 
+    def __init__(self, *args, dims_connections: tuple[tuple[int | None]], **kwargs) -> None:
+        self.dims_connections = dims_connections
+        super().__init__(*args, **kwargs)
 
-class FiniteDiscreteMarginalRV(MarginalRV):
-    """Base class for Finite Discrete Marginalized RVs"""
+    @property
+    def support_axes(self) -> tuple[tuple[int]]:
+        """Dimensions of dependent RVs that belong to the core (non-batched) marginalized variable."""
+        marginalized_ndim_supp = self.inner_outputs[0].owner.op.ndim_supp
+        support_axes_vars = []
+        for dims_connection in self.dims_connections:
+            ndim = len(dims_connection)
+            marginalized_supp_axes = ndim - marginalized_ndim_supp
+            support_axes_vars.append(
+                tuple(
+                    -i
+                    for i, dim in enumerate(reversed(dims_connection), start=1)
+                    if (dim is None or dim > marginalized_supp_axes)
+                )
+            )
+        return tuple(support_axes_vars)
 
 
-class DiscreteMarginalMarkovChainRV(MarginalRV):
-    """Base class for Discrete Marginal Markov Chain RVs"""
+class MarginalFiniteDiscreteRV(MarginalRV):
+    """Base class for Marginalized Finite Discrete RVs"""
+
+
+class MarginalDiscreteMarkovChainRV(MarginalRV):
+    """Base class for Marginalized Discrete Markov Chain RVs"""
 
 
 def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
@@ -34,7 +59,8 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
         return (0, 1)
     elif isinstance(op, Categorical):
         [p_param] = dist_params
-        return tuple(range(pt.get_vector_length(p_param)))
+        [p_param_length] = constant_fold([p_param.shape[-1]])
+        return tuple(range(p_param_length))
     elif isinstance(op, DiscreteUniform):
         lower, upper = constant_fold(dist_params)
         return tuple(np.arange(lower, upper + 1))
@@ -45,31 +71,81 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
     raise NotImplementedError(f"Cannot compute domain for op {op}")
 
 
-def _add_reduce_batch_dependent_logps(
-    marginalized_type: TensorType, dependent_logps: Sequence[TensorVariable]
-):
-    """Add the logps of dependent RVs while reducing extra batch dims relative to `marginalized_type`."""
+def reduce_batch_dependent_logps(
+    dependent_dims_connections: Sequence[tuple[int | None, ...]],
+    dependent_ops: Sequence[Op],
+    dependent_logps: Sequence[TensorVariable],
+) -> TensorVariable:
+    """Combine the logps of dependent RVs and align them with the marginalized logp.
+
+    This requires reducing extra batch dims and transposing when they are not aligned.
+
+        idx = pm.Bernoulli(idx, shape=(3, 2)) # 0, 1
+        pm.Normal("dep1", mu=idx.T[..., None] * 2, shape=(3, 2, 5))
+        pm.Normal("dep2", mu=idx * 2, shape=(7, 2, 3))
+
+        marginalize(idx)
+
+    The marginalized op will have dims_connections = [(1, 0, None), (None, 0, 1)]
+    which tells us we need to reduce the last axis of dep1 logp and the first of dep2 logp,
+    as well as transpose the remaining axis of dep1 logp before adding the two element-wise.
+
+    """
+    from pymc_experimental.model.marginal.graph_analysis import get_support_axes
 
-    mbcast = marginalized_type.broadcastable
     reduced_logps = []
-    for dependent_logp in dependent_logps:
-        dbcast = dependent_logp.type.broadcastable
-        dim_diff = len(dbcast) - len(mbcast)
-        mbcast_aligned = (True,) * dim_diff + mbcast
-        vbcast_axis = [i for i, (m, v) in enumerate(zip(mbcast_aligned, dbcast)) if m and not v]
-        reduced_logps.append(dependent_logp.sum(vbcast_axis))
-    return pt.add(*reduced_logps)
+    for dependent_op, dependent_logp, dependent_dims_connection in zip(
+        dependent_ops, dependent_logps, dependent_dims_connections
+    ):
+        if dependent_logp.type.ndim > 0:
+            # Find which support axis implied by the MarginalRV need to be reduced
+            # Some may have already been reduced by the logp expression of the dependent RV (e.g., multivariate RVs)
+            dep_supp_axes = get_support_axes(dependent_op)[0]
 
+            # Dependent RV support axes are already collapsed in the logp, so we ignore them
+            supp_axes = [
+                -i
+                for i, dim in enumerate(reversed(dependent_dims_connection), start=1)
+                if (dim is None and -i not in dep_supp_axes)
+            ]
+            dependent_logp = dependent_logp.sum(supp_axes)
 
-@_logprob.register(FiniteDiscreteMarginalRV)
-def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
-    # Clone the inner RV graph of the Marginalized RV
-    marginalized_rvs_node = op.make_node(*inputs)
-    marginalized_rv, *inner_rvs = clone_replace(
+            # Finally, we need to align the dependent logp batch dimensions with the marginalized logp
+            dims_alignment = [dim for dim in dependent_dims_connection if dim is not None]
+            dependent_logp = dependent_logp.transpose(*dims_alignment)
+
+        reduced_logps.append(dependent_logp)
+
+    reduced_logp = pt.add(*reduced_logps)
+    return reduced_logp
+
+
+def align_logp_dims(dims: tuple[tuple[int, None]], logp: TensorVariable) -> TensorVariable:
+    """Align the logp with the order specified in dims."""
+    dims_alignment = [dim for dim in dims if dim is not None]
+    return logp.transpose(*dims_alignment)
+
+
+def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Variable]:
+    """Inline the inner graph (outputs) of an OpFromGraph Op.
+
+    Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps"
+    the inner graph.
+    """
+    return clone_replace(
         op.inner_outputs,
-        replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
+        replace=tuple(zip(op.inner_inputs, inputs)),
     )
 
+
+DUMMY_ZERO = pt.constant(0, name="dummy_zero")
+
+
+@_logprob.register(MarginalFiniteDiscreteRV)
+def finite_discrete_marginal_rv_logp(op: MarginalFiniteDiscreteRV, values, *inputs, **kwargs):
+    # Clone the inner RV graph of the Marginalized RV
+    marginalized_rv, *inner_rvs = inline_ofg_outputs(op, inputs)
+
     # Obtain the joint_logp graph of the inner RV graph
     inner_rv_values = dict(zip(inner_rvs, values))
     marginalized_vv = marginalized_rv.clone()
@@ -78,8 +154,10 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
 
     # Reduce logp dimensions corresponding to broadcasted variables
     marginalized_logp = logps_dict.pop(marginalized_vv)
-    joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps(
-        marginalized_rv.type, logps_dict.values()
+    joint_logp = marginalized_logp + reduce_batch_dependent_logps(
+        dependent_dims_connections=op.dims_connections,
+        dependent_ops=[inner_rv.owner.op for inner_rv in inner_rvs],
+        dependent_logps=[logps_dict[value] for value in values],
     )
 
     # Compute the joint_logp for all possible n values of the marginalized RV. We assume
@@ -116,21 +194,20 @@ def logp_fn(marginalized_rv_const, *non_sequences):
         mode=Mode().including("local_remove_check_parameter"),
     )
 
-    joint_logps = pt.logsumexp(joint_logps, axis=0)
+    joint_logp = pt.logsumexp(joint_logps, axis=0)
+
+    # Align logp with non-collapsed batch dimensions of first RV
+    joint_logp = align_logp_dims(dims=op.dims_connections[0], logp=joint_logp)
 
     # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
-    return joint_logps, *(pt.constant(0),) * (len(values) - 1)
+    dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
+    return joint_logp, *dummy_logps
 
 
-@_logprob.register(DiscreteMarginalMarkovChainRV)
+@_logprob.register(MarginalDiscreteMarkovChainRV)
 def marginal_hmm_logp(op, values, *inputs, **kwargs):
-    marginalized_rvs_node = op.make_node(*inputs)
-    inner_rvs = clone_replace(
-        op.inner_outputs,
-        replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
-    )
+    chain_rv, *dependent_rvs = inline_ofg_outputs(op, inputs)
 
-    chain_rv, *dependent_rvs = inner_rvs
     P, n_steps_, init_dist_, rng = chain_rv.owner.inputs
     domain = pt.arange(P.shape[-1], dtype="int32")
 
@@ -145,8 +222,10 @@ def marginal_hmm_logp(op, values, *inputs, **kwargs):
     logp_emissions_dict = conditional_logp(dict(zip(dependent_rvs, values)))
 
     # Reduce and add the batch dims beyond the chain dimension
-    reduced_logp_emissions = _add_reduce_batch_dependent_logps(
-        chain_rv.type, logp_emissions_dict.values()
+    reduced_logp_emissions = reduce_batch_dependent_logps(
+        dependent_dims_connections=op.dims_connections,
+        dependent_ops=[dependent_rv.owner.op for dependent_rv in dependent_rvs],
+        dependent_logps=[logp_emissions_dict[value] for value in values],
     )
 
     # Add a batch dimension for the domain of the chain
@@ -185,7 +264,13 @@ def step_alpha(logp_emission, log_alpha, log_P):
     # Final logp is just the sum of the last scan state
     joint_logp = pt.logsumexp(log_alpha_seq[-1], axis=0)
 
+    # Align logp with non-collapsed batch dimensions of first RV
+    remaining_dims_first_emission = list(op.dims_connections[0])
+    # The last dim of chain_rv was removed when computing the logp
+    remaining_dims_first_emission.remove(chain_rv.type.ndim - 1)
+    joint_logp = align_logp_dims(remaining_dims_first_emission, joint_logp)
+
     # If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
-    # return is the joint probability of everything together, but PyMC still expects one logp for each one.
-    dummy_logps = (pt.constant(0),) * (len(values) - 1)
+    # return is the joint probability of everything together, but PyMC still expects one logp for each emission stream.
+    dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
     return joint_logp, *dummy_logps
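
For intuition about the new reduce_batch_dependent_logps/align_logp_dims helpers, here is a rough NumPy sketch of the reduction and alignment they perform on a single dependent logp (shapes made up, mirroring the dims_connections example in the docstring above):

import numpy as np

# Suppose the marginalized RV has batch shape (3, 2) and a dependent logp has
# shape (2, 3, 5) with dims_connection (1, 0, None): its axis 0 carries the
# marginalized dim 1, axis 1 carries dim 0, and axis 2 is unrelated.
dims_connection = (1, 0, None)
dep_logp = np.zeros((2, 3, 5))

# Sum away axes that are not connected to the marginalized RV
unconnected = tuple(i for i, d in enumerate(dims_connection) if d is None)
reduced = dep_logp.sum(axis=unconnected)  # shape (2, 3)

# Transpose the surviving axes so they line up with the marginalized RV's dims
alignment = [d for d in dims_connection if d is not None]  # [1, 0]
aligned = reduced.transpose(alignment)  # shape (3, 2), ready to add to the marginalized logp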
