
Commit 7fae212

Add RandomVariable log-likelihoods
1 parent a15f356 commit 7fae212

7 files changed: +2040 -0 lines changed


aeppl/loglik.py (+209 lines)

@@ -0,0 +1,209 @@
from collections.abc import Mapping
from functools import singledispatch
from typing import Dict, Optional, Union

import aesara.tensor as at
from aesara import config
from aesara.gradient import disconnected_grad
from aesara.graph.basic import Constant, clone, graph_inputs, io_toposort
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op, compute_test_value
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.random.opt import local_subtensor_rv_lift
from aesara.tensor.subtensor import (
    AdvancedIncSubtensor,
    AdvancedIncSubtensor1,
    AdvancedSubtensor,
    AdvancedSubtensor1,
    IncSubtensor,
    Subtensor,
)
from aesara.tensor.var import TensorVariable

from aeppl.logpdf import _logpdf
from aeppl.utils import (
    extract_rv_and_value_vars,
    indices_from_subtensor,
    rvs_to_value_vars,
)


def loglik(
    var: TensorVariable,
    rv_values: Optional[
        Union[TensorVariable, Dict[TensorVariable, TensorVariable]]
    ] = None,
    sum: bool = True,
    **kwargs,
) -> TensorVariable:
    """Create a measure-space (i.e. log-likelihood) graph for a random variable at a given point.

    The input `var` determines which log-likelihood graph is used and
    `rv_values` provides that graph's input value.  For example, if `var` is
    the output of a ``NormalRV`` ``Op``, then the output is a graph of the
    density function for `var` evaluated at the value in `rv_values`.

    Parameters
    ==========
    var
        The `RandomVariable` output that determines the log-likelihood graph.
    rv_values
        A variable, or ``dict`` of variables, that represents the value of
        `var` in its log-likelihood.  If no value is provided,
        ``var.tag.value_var`` will be checked and, when available, used.
    sum
        Sum the log-likelihood across all of its dimensions when ``True``
        (the default); otherwise, return the elementwise log-likelihood.

    """
    if not isinstance(rv_values, Mapping):
        rv_values = {var: rv_values} if rv_values is not None else {}

    rv_var, rv_value_var = extract_rv_and_value_vars(var)

    rv_value = rv_values.get(rv_var, rv_value_var)

    if rv_var is not None and rv_value is None:
        raise ValueError(f"No value variable specified or associated with {rv_var}")

    if rv_value is not None:
        rv_value = at.as_tensor(rv_value)

        if rv_var is not None:
            # Make sure that the value is compatible with the random variable
            rv_value = rv_var.type.filter_variable(rv_value.astype(rv_var.dtype))

        if rv_value_var is None:
            rv_value_var = rv_value

    if rv_var is None:
        if var.owner is not None:
            return _loglik(
                var.owner.op,
                var,
                rv_values,
                *var.owner.inputs,
            )

        return at.zeros_like(var)

    rv_node = rv_var.owner

    rng, size, dtype, *dist_params = rv_node.inputs

    # Here, we plug the actual random variable into the log-likelihood graph,
    # because we want a log-likelihood graph that only contains
    # random variables.  This is important, because a random variable's
    # parameters can contain random variables themselves.
    # Ultimately, with a graph containing only random variables and
    # "deterministics", we can simply replace all the random variables with
    # their value variables and be done.

    # tmp_rv_values = rv_values.copy()
    # tmp_rv_values[rv_var] = rv_var

    logpdf_var = _logpdf(rv_node.op, rv_value_var, *dist_params, **kwargs)

    # Replace random variables with their value variables
    replacements = rv_values.copy()
    replacements.update({rv_var: rv_value, rv_value_var: rv_value})

    (logpdf_var,), _ = rvs_to_value_vars(
        (logpdf_var,),
        initial_replacements=replacements,
    )

    if sum:
        logpdf_var = at.sum(logpdf_var)

    # Recompute test values for the changes introduced by the replacements
    # above.
    if config.compute_test_value != "off":
        for node in io_toposort(graph_inputs((logpdf_var,)), (logpdf_var,)):
            compute_test_value(node)

    if rv_var.name is not None:
        logpdf_var.name = "__logp_%s" % rv_var.name

    return logpdf_var


@singledispatch
def _loglik(
    op: Op,
    var: TensorVariable,
    rvs_to_values: Dict[TensorVariable, TensorVariable],
    *inputs: TensorVariable,
    **kwargs,
):
    """Create a graph for the log-likelihood of a ``Variable``.

    This function dispatches on the type of ``op``.  If you want to implement
    new graphs for an ``Op``, register a new function on this dispatcher.

    The default returns a graph producing only zeros.

    """
    value_var = rvs_to_values.get(var, var)
    return at.zeros_like(value_var)


@_loglik.register(IncSubtensor)
@_loglik.register(AdvancedIncSubtensor)
@_loglik.register(AdvancedIncSubtensor1)
def incsubtensor_loglik(
    op, var, rvs_to_values, indexed_rv_var, rv_values, *indices, **kwargs
):
    """Create a log-likelihood graph for ``*IncSubtensor*`` variables."""

    index = indices_from_subtensor(getattr(op, "idx_list", None), indices)

    _, (new_rv_var,) = clone(
        tuple(
            v for v in graph_inputs((indexed_rv_var,)) if not isinstance(v, Constant)
        ),
        (indexed_rv_var,),
        copy_inputs=False,
        copy_orphans=False,
    )
    new_values = at.set_subtensor(disconnected_grad(new_rv_var)[index], rv_values)
    logp_var = loglik(indexed_rv_var, new_values, **kwargs)

    return logp_var


@_loglik.register(Subtensor)
@_loglik.register(AdvancedSubtensor)
@_loglik.register(AdvancedSubtensor1)
def subtensor_loglik(op, var, rvs_to_values, indexed_rv_var, *indices, **kwargs):
    """Create a log-likelihood graph for indexed (``*Subtensor*``) random variables."""

    index = indices_from_subtensor(getattr(op, "idx_list", None), indices)

    rv_value = rvs_to_values.get(var, getattr(var.tag, "value_var", None))

    if indexed_rv_var.owner and isinstance(indexed_rv_var.owner.op, RandomVariable):

        # We need to lift the index operation through the random variable so
        # that we have a new random variable consisting of only the relevant
        # subset of variables per the index.
        var_copy = var.owner.clone().default_output()
        fgraph = FunctionGraph(
            [i for i in graph_inputs((indexed_rv_var,)) if not isinstance(i, Constant)],
            [var_copy],
            clone=False,
        )

        (lifted_var,) = local_subtensor_rv_lift.transform(
            fgraph, fgraph.outputs[0].owner
        )

        new_rvs_to_values = rvs_to_values.copy()
        new_rvs_to_values[lifted_var] = rv_value

        logp_var = loglik(lifted_var, new_rvs_to_values, **kwargs)

        for idx_var in index:
            logp_var += loglik(idx_var, rvs_to_values, **kwargs)

    # TODO: We could add the constant case (i.e. `indexed_rv_var.owner is None`)
    else:
        raise NotImplementedError(
            f"`Subtensor` log-likelihood not implemented for {indexed_rv_var.owner}"
        )

    return logp_var
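
A minimal usage sketch, not part of the commit: it builds a normal random variable and asks `loglik` for its log-density graph at a value variable. It assumes a `_logpdf` implementation for `NormalRV` is registered elsewhere in this commit (e.g. in `aeppl/logpdf.py`); the variable names are illustrative.

import aesara
import aesara.tensor as at

from aeppl.loglik import loglik

# A standard-normal random variable and the value variable at which its
# log-density will be evaluated.
x_rv = at.random.normal(0.0, 1.0, name="x")
x_value = at.scalar("x_value")

# `loglik` returns an Aesara graph for log p(x = x_value).
logp = loglik(x_rv, {x_rv: x_value})

# The graph can be compiled and evaluated like any other Aesara expression.
logp_fn = aesara.function([x_value], logp)
logp_fn(0.0)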

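The `_loglik` docstring above invites registering new implementations on its dispatcher. The sketch below shows that extension pattern under stated assumptions: `MyOp` is a hypothetical placeholder Op, and the everywhere-zero log-likelihood it returns is purely illustrative, not something this commit provides.

import aesara.tensor as at
from aesara.graph.basic import Apply
from aesara.graph.op import Op

from aeppl.loglik import _loglik


class MyOp(Op):
    """A hypothetical identity Op, standing in for any Op one wants to support."""

    def make_node(self, x):
        x = at.as_tensor_variable(x)
        return Apply(self, [x], [x.type()])

    def perform(self, node, inputs, output_storage):
        output_storage[0][0] = inputs[0]


@_loglik.register(MyOp)
def myop_loglik(op, var, rvs_to_values, *inputs, **kwargs):
    # Mirror the default implementation's lookup of the value variable ...
    value_var = rvs_to_values.get(var, var)
    # ... and return whatever log-likelihood expression suits the Op; the
    # flat zero graph here is purely illustrative.
    return at.zeros_like(value_var)
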
0 commit comments
