@@ -1,30 +1,55 @@
from collections.abc import Sequence

import numpy as np
+import pytensor.tensor as pt

-from pymc import Bernoulli, Categorical, DiscreteUniform, SymbolicRandomVariable, logp
-from pymc.logprob import conditional_logp
-from pymc.logprob.abstract import _logprob
+from pymc.distributions import Bernoulli, Categorical, DiscreteUniform
+from pymc.logprob.abstract import MeasurableOp, _logprob
+from pymc.logprob.basic import conditional_logp, logp
from pymc.pytensorf import constant_fold
-from pytensor import Mode, clone_replace, graph_replace, scan
-from pytensor import map as scan_map
-from pytensor import tensor as pt
-from pytensor.graph import vectorize_graph
-from pytensor.tensor import TensorType, TensorVariable
+from pytensor import Variable
+from pytensor.compile.builders import OpFromGraph
+from pytensor.compile.mode import Mode
+from pytensor.graph import Op, vectorize_graph
+from pytensor.graph.replace import clone_replace, graph_replace
+from pytensor.scan import map as scan_map
+from pytensor.scan import scan
+from pytensor.tensor import TensorVariable

from pymc_experimental.distributions import DiscreteMarkovChain

-class MarginalRV(SymbolicRandomVariable):
+class MarginalRV(OpFromGraph, MeasurableOp):
    """Base class for Marginalized RVs"""

+    def __init__(self, *args, dims_connections: tuple[tuple[int | None]], **kwargs) -> None:
+        self.dims_connections = dims_connections
+        super().__init__(*args, **kwargs)

-class FiniteDiscreteMarginalRV(MarginalRV):
-    """Base class for Finite Discrete Marginalized RVs"""
+    @property
+    def support_axes(self) -> tuple[tuple[int]]:
+        """Dimensions of dependent RVs that belong to the core (non-batched) marginalized variable."""
+        marginalized_ndim_supp = self.inner_outputs[0].owner.op.ndim_supp
+        support_axes_vars = []
+        for dims_connection in self.dims_connections:
+            ndim = len(dims_connection)
+            marginalized_supp_axes = ndim - marginalized_ndim_supp
+            support_axes_vars.append(
+                tuple(
+                    -i
+                    for i, dim in enumerate(reversed(dims_connection), start=1)
+                    if (dim is None or dim > marginalized_supp_axes)
+                )
+            )
+        return tuple(support_axes_vars)


-class DiscreteMarginalMarkovChainRV(MarginalRV):
-    """Base class for Discrete Marginal Markov Chain RVs"""
+class MarginalFiniteDiscreteRV(MarginalRV):
+    """Base class for Marginalized Finite Discrete RVs"""
+
+
+class MarginalDiscreteMarkovChainRV(MarginalRV):
+    """Base class for Marginalized Discrete Markov Chain RVs"""


def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
@@ -34,7 +59,8 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
        return (0, 1)
    elif isinstance(op, Categorical):
        [p_param] = dist_params
-        return tuple(range(pt.get_vector_length(p_param)))
+        [p_param_length] = constant_fold([p_param.shape[-1]])
+        return tuple(range(p_param_length))
    elif isinstance(op, DiscreteUniform):
        lower, upper = constant_fold(dist_params)
        return tuple(np.arange(lower, upper + 1))
@@ -45,31 +71,81 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
    raise NotImplementedError(f"Cannot compute domain for op {op}")


-def _add_reduce_batch_dependent_logps(
-    marginalized_type: TensorType, dependent_logps: Sequence[TensorVariable]
-):
-    """Add the logps of dependent RVs while reducing extra batch dims relative to `marginalized_type`."""
+def reduce_batch_dependent_logps(
+    dependent_dims_connections: Sequence[tuple[int | None, ...]],
+    dependent_ops: Sequence[Op],
+    dependent_logps: Sequence[TensorVariable],
+) -> TensorVariable:
+    """Combine the logps of dependent RVs and align them with the marginalized logp.
+
+    This requires reducing extra batch dims and transposing when they are not aligned.
+
+        idx = pm.Bernoulli("idx", p=0.5, shape=(3, 2))  # takes values 0 or 1
+        pm.Normal("dep1", mu=idx.T[..., None] * 2, shape=(3, 2, 5))
+        pm.Normal("dep2", mu=idx * 2, shape=(7, 2, 3))
+
+        marginalize(idx)
+
+    The marginalized op will have dims_connections = [(1, 0, None), (None, 0, 1)],
+    which tells us we need to reduce the last axis of dep1 logp and the first of dep2 logp,
+    as well as transpose the remaining axes of dep1 logp before adding the two element-wise.
+
+    """
+    from pymc_experimental.model.marginal.graph_analysis import get_support_axes

-    mbcast = marginalized_type.broadcastable
    reduced_logps = []
-    for dependent_logp in dependent_logps:
-        dbcast = dependent_logp.type.broadcastable
-        dim_diff = len(dbcast) - len(mbcast)
-        mbcast_aligned = (True,) * dim_diff + mbcast
-        vbcast_axis = [i for i, (m, v) in enumerate(zip(mbcast_aligned, dbcast)) if m and not v]
-        reduced_logps.append(dependent_logp.sum(vbcast_axis))
-    return pt.add(*reduced_logps)
+    for dependent_op, dependent_logp, dependent_dims_connection in zip(
+        dependent_ops, dependent_logps, dependent_dims_connections
+    ):
+        if dependent_logp.type.ndim > 0:
+            # Find which support axes implied by the MarginalRV need to be reduced
+            # Some may have already been reduced by the logp expression of the dependent RV (e.g., multivariate RVs)
+            dep_supp_axes = get_support_axes(dependent_op)[0]

+            # Dependent RV support axes are already collapsed in the logp, so we ignore them
+            supp_axes = [
+                -i
+                for i, dim in enumerate(reversed(dependent_dims_connection), start=1)
+                if (dim is None and -i not in dep_supp_axes)
+            ]
+            dependent_logp = dependent_logp.sum(supp_axes)

-@_logprob.register(FiniteDiscreteMarginalRV)
-def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
-    # Clone the inner RV graph of the Marginalized RV
-    marginalized_rvs_node = op.make_node(*inputs)
-    marginalized_rv, *inner_rvs = clone_replace(
+            # Finally, we need to align the dependent logp batch dimensions with the marginalized logp
+            dims_alignment = [dim for dim in dependent_dims_connection if dim is not None]
+            dependent_logp = dependent_logp.transpose(*dims_alignment)
+
+        reduced_logps.append(dependent_logp)
+
+    reduced_logp = pt.add(*reduced_logps)
+    return reduced_logp
+
+
+def align_logp_dims(dims: tuple[tuple[int, None]], logp: TensorVariable) -> TensorVariable:
+    """Align the logp with the order specified in dims."""
+    dims_alignment = [dim for dim in dims if dim is not None]
+    return logp.transpose(*dims_alignment)
+
+
+def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Variable]:
+    """Inline the inner graph (outputs) of an OpFromGraph Op.
+
+    Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps"
+    the inner graph.
+    """
+    return clone_replace(
        op.inner_outputs,
-        replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
+        replace=tuple(zip(op.inner_inputs, inputs)),
    )

+
+DUMMY_ZERO = pt.constant(0, name="dummy_zero")
+
+
+@_logprob.register(MarginalFiniteDiscreteRV)
+def finite_discrete_marginal_rv_logp(op: MarginalFiniteDiscreteRV, values, *inputs, **kwargs):
+    # Clone the inner RV graph of the Marginalized RV
+    marginalized_rv, *inner_rvs = inline_ofg_outputs(op, inputs)
+
    # Obtain the joint_logp graph of the inner RV graph
    inner_rv_values = dict(zip(inner_rvs, values))
    marginalized_vv = marginalized_rv.clone()
@@ -78,8 +154,10 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
    # Reduce logp dimensions corresponding to broadcasted variables
    marginalized_logp = logps_dict.pop(marginalized_vv)
-    joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps(
-        marginalized_rv.type, logps_dict.values()
+    joint_logp = marginalized_logp + reduce_batch_dependent_logps(
+        dependent_dims_connections=op.dims_connections,
+        dependent_ops=[inner_rv.owner.op for inner_rv in inner_rvs],
+        dependent_logps=[logps_dict[value] for value in values],
    )

    # Compute the joint_logp for all possible n values of the marginalized RV. We assume
@@ -116,21 +194,20 @@ def logp_fn(marginalized_rv_const, *non_sequences):
        mode=Mode().including("local_remove_check_parameter"),
    )

-    joint_logps = pt.logsumexp(joint_logps, axis=0)
+    joint_logp = pt.logsumexp(joint_logps, axis=0)
+
+    # Align logp with non-collapsed batch dimensions of first RV
+    joint_logp = align_logp_dims(dims=op.dims_connections[0], logp=joint_logp)

    # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
-    return joint_logps, *(pt.constant(0),) * (len(values) - 1)
+    dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
+    return joint_logp, *dummy_logps


-@_logprob.register(DiscreteMarginalMarkovChainRV)
+@_logprob.register(MarginalDiscreteMarkovChainRV)
def marginal_hmm_logp(op, values, *inputs, **kwargs):
-    marginalized_rvs_node = op.make_node(*inputs)
-    inner_rvs = clone_replace(
-        op.inner_outputs,
-        replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
-    )
+    chain_rv, *dependent_rvs = inline_ofg_outputs(op, inputs)

-    chain_rv, *dependent_rvs = inner_rvs
    P, n_steps_, init_dist_, rng = chain_rv.owner.inputs
    domain = pt.arange(P.shape[-1], dtype="int32")
@@ -145,8 +222,10 @@ def marginal_hmm_logp(op, values, *inputs, **kwargs):
    logp_emissions_dict = conditional_logp(dict(zip(dependent_rvs, values)))

    # Reduce and add the batch dims beyond the chain dimension
-    reduced_logp_emissions = _add_reduce_batch_dependent_logps(
-        chain_rv.type, logp_emissions_dict.values()
+    reduced_logp_emissions = reduce_batch_dependent_logps(
+        dependent_dims_connections=op.dims_connections,
+        dependent_ops=[dependent_rv.owner.op for dependent_rv in dependent_rvs],
+        dependent_logps=[logp_emissions_dict[value] for value in values],
    )

    # Add a batch dimension for the domain of the chain
@@ -185,7 +264,13 @@ def step_alpha(logp_emission, log_alpha, log_P):
    # Final logp is just the sum of the last scan state
    joint_logp = pt.logsumexp(log_alpha_seq[-1], axis=0)

+    # Align logp with non-collapsed batch dimensions of first RV
+    remaining_dims_first_emission = list(op.dims_connections[0])
+    # The last dim of chain_rv was removed when computing the logp
+    remaining_dims_first_emission.remove(chain_rv.type.ndim - 1)
+    joint_logp = align_logp_dims(remaining_dims_first_emission, joint_logp)
+
    # If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
-    # return is the joint probability of everything together, but PyMC still expects one logp for each one.
-    dummy_logps = (pt.constant(0),) * (len(values) - 1)
+    # return is the joint probability of everything together, but PyMC still expects one logp for each emission stream.
+    dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
    return joint_logp, *dummy_logps
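
For context, a minimal usage sketch (not part of the diff) of the machinery these logp dispatchers implement. It assumes the `MarginalModel` API exposed by `pymc_experimental`; the variable names and parameter values are illustrative only:

# Sketch only: assumes pymc_experimental's MarginalModel; names and values are illustrative.
import numpy as np
import pymc as pm
from pymc_experimental import MarginalModel

with MarginalModel() as m:
    # Finite discrete RV to be marginalized out (takes values 0 or 1)
    idx = pm.Bernoulli("idx", p=0.7, shape=(3,))
    # Dependent RV whose mean switches with idx, sharing a batch dimension of size 3
    pm.Normal("dep", mu=pm.math.switch(idx, 1.0, -1.0), sigma=1.0, shape=(3,))

# Rewrites idx and its dependents into a MarginalFiniteDiscreteRV whose logp
# sums the dependent logps over idx in {0, 1} (the logsumexp in the diff above)
m.marginalize(["idx"])

# The marginalized model logp no longer references idx
logp_fn = m.compile_logp()
print(logp_fn(m.initial_point()))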