
Commit c00c368

Variationally Informed Parameterization (#276)
1 parent f1ece1c commit c00c368

File tree

4 files changed: +484 -0 lines changed


docs/api_reference.rst

Lines changed: 10 additions & 0 deletions
@@ -60,3 +60,13 @@ Statespace Models

   statespace/core
   statespace/filters
   statespace/models


Model Transforms
================

.. automodule:: pymc_experimental.model.transforms
.. autosummary::
   :toctree: generated/

   autoreparam.vip_reparametrize
   autoreparam.VIP

pymc_experimental/model/transforms/__init__.py

Whitespace-only changes.
pymc_experimental/model/transforms/autoreparam.py

Lines changed: 376 additions & 0 deletions

@@ -0,0 +1,376 @@
from dataclasses import dataclass
from functools import singledispatch
from typing import Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt
import scipy.special
from pymc.logprob.transforms import Transform
from pymc.model.fgraph import (
    ModelDeterministic,
    ModelNamed,
    fgraph_from_model,
    model_deterministic,
    model_free_rv,
    model_from_fgraph,
    model_named,
)
from pymc.pytensorf import toposort_replace
from pytensor.graph.basic import Apply, Variable
from pytensor.tensor.random.op import RandomVariable


@dataclass
class VIP:
    r"""Helper to reparametrize a VIP model.

    Manipulation of :math:`\lambda` in the equation below is done using this helper class.

    .. math::

        \begin{align*}
        \eta_{k} &\sim \text{normal}(\lambda_{k} \cdot \mu, \sigma^{\lambda_{k}})\\
        \theta_{k} &= \mu + \sigma^{1 - \lambda_{k}} ( \eta_{k} - \lambda_{k} \cdot \mu)
            \sim \text{normal}(\mu, \sigma).
        \end{align*}
    """

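    # Editor's note (not part of the original commit): the endpoints of the formula above
    # recover the familiar parameterizations. With lambda_k = 1, eta_k ~ Normal(mu, sigma)
    # and theta_k = eta_k (centered); with lambda_k = 0, eta_k ~ Normal(0, 1) and
    # theta_k = mu + sigma * eta_k (non-centered).
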
    _logit_lambda: Dict[str, pytensor.tensor.sharedvar.TensorSharedVariable]

    @property
    def variational_parameters(self) -> List[pytensor.tensor.sharedvar.TensorSharedVariable]:
        r"""Return raw :math:`\operatorname{logit}(\lambda_k)` for custom optimization.

        Examples
        --------
        .. code-block:: python

            with model:
                # set all parameterizations to a mix of centered and non-centered
                vip.set_all_lambda(0.5)

                pm.fit(more_obj_params=vip.variational_parameters, method="fullrank_advi")
        """
        return list(self._logit_lambda.values())

    def truncate_lambda(self, **kwargs: float):
        r"""Truncate :math:`\lambda_k` with :math:`\varepsilon`.

        .. math::

            \hat \lambda_k = \begin{cases}
                0, \quad &\lambda_k \le \varepsilon\\
                \lambda_k, \quad &\varepsilon \lt \lambda_k \lt 1-\varepsilon\\
                1, \quad &\lambda_k \ge 1-\varepsilon\\
            \end{cases}

        Parameters
        ----------
        kwargs : Dict[str, float]
            Mapping from variable name to :math:`\varepsilon`.
            If :math:`\lambda_k` falls below :math:`\varepsilon` (or above
            :math:`1-\varepsilon`), it is rounded to :math:`0` (or :math:`1`).
        """
        lambdas = self.get_lambda()
        update = dict()
        for var, eps in kwargs.items():
            lam = lambdas[var]
            update[var] = np.piecewise(
                lam,
                [lam < eps, lam > (1 - eps)],
                [0, 1, lambda x: x],
            )
        self.set_lambda(**update)
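
    # Editor's illustration (not in the original commit): vip.truncate_lambda(theta=0.1)
    # rounds every lambda_k of "theta" below 0.1 down to 0 and every lambda_k above 0.9
    # up to 1, leaving intermediate values untouched.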

    def truncate_all_lambda(self, value: float):
        r"""Truncate all :math:`\lambda_k` with :math:`\varepsilon`.

        .. math::

            \hat \lambda_k = \begin{cases}
                0, \quad &\lambda_k \le \varepsilon\\
                \lambda_k, \quad &\varepsilon \lt \lambda_k \lt 1-\varepsilon\\
                1, \quad &\lambda_k \ge 1-\varepsilon\\
            \end{cases}

        Parameters
        ----------
        value : float
            :math:`\varepsilon`
        """
        truncate = dict.fromkeys(
            self._logit_lambda.keys(),
            value,
        )
        self.truncate_lambda(**truncate)

    def get_lambda(self) -> Dict[str, np.ndarray]:
        r"""Get the :math:`\lambda_k` that are currently used by the model.

        Returns
        -------
        Dict[str, np.ndarray]
            Mapping from variable name to :math:`\lambda_k`.
        """
        return {
            name: scipy.special.expit(shared.get_value())
            for name, shared in self._logit_lambda.items()
        }

    def set_lambda(self, **kwargs: Dict[str, Union[np.ndarray, float]]):
        r"""Set :math:`\lambda_k` per variable."""
        for key, value in kwargs.items():
            logit_lam = scipy.special.logit(value)
            shared = self._logit_lambda[key]
            fill = np.broadcast_to(
                logit_lam,
                shared.type.shape,
            )
            shared.set_value(fill)

    def set_all_lambda(self, value: Union[np.ndarray, float]):
        r"""Set :math:`\lambda_k` globally."""
        config = dict.fromkeys(
            self._logit_lambda.keys(),
            value,
        )
        self.set_lambda(**config)
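
    # Editor's illustration (not part of the original commit): lambdas are stored on the
    # logit scale, so vip.set_lambda(theta=0.5) writes logit(0.5) == 0.0 into the shared
    # variable, and vip.get_lambda()["theta"] then returns an array filled with 0.5.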

    def fit(self, *args, **kwargs) -> pm.Approximation:
        r"""Set :math:`\lambda_k` using Variational Inference.

        Examples
        --------

        .. code-block:: python

            with model:
                # set all parameterizations to a mix of centered and non-centered
                vip.set_all_lambda(0.5)

                # fit using ADVI
                mf = vip.fit(random_seed=42)
        """
        kwargs.setdefault("obj_optimizer", pm.adagrad_window(learning_rate=0.1))
        kwargs.setdefault("method", "advi")
        return pm.fit(
            *args,
            more_obj_params=self.variational_parameters,
            **kwargs,
        )
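
    # Editor's note (not in the original commit): positional and keyword arguments are
    # forwarded to pm.fit, so the ADVI defaults above can be overridden, e.g.
    #
    #   approx = vip.fit(50_000, method="fullrank_advi")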


def vip_reparam_node(
    op: RandomVariable,
    node: Apply,
    name: str,
    dims: List[Variable],
    transform: Optional[Transform],
) -> Tuple[ModelDeterministic, ModelNamed]:
    if not isinstance(node.op, RandomVariable):
        raise TypeError("Op should be RandomVariable type")
    size = node.inputs[1]
    if not isinstance(size, pt.TensorConstant):
        raise ValueError("Size should be static for autoreparametrization.")
    # One logit(lambda) entry per element; zeros correspond to lambda = 0.5.
    logit_lam_ = pytensor.shared(
        np.zeros(size.data),
        shape=size.data,
        name=f"{name}::lam_logit__",
    )
    logit_lam = model_named(logit_lam_, *dims)
    lam = pt.sigmoid(logit_lam)
    return (
        _vip_reparam_node(
            op,
            node=node,
            name=name,
            dims=dims,
            transform=transform,
            lam=lam,
        ),
        logit_lam,
    )


@singledispatch
def _vip_reparam_node(
    op: RandomVariable,
    node: Apply,
    name: str,
    dims: List[Variable],
    transform: Optional[Transform],
    lam: pt.TensorVariable,
) -> ModelDeterministic:
    raise NotImplementedError


@_vip_reparam_node.register
def _(
    op: pm.Normal,
    node: Apply,
    name: str,
    dims: List[Variable],
    transform: Optional[Transform],
    lam: pt.TensorVariable,
) -> ModelDeterministic:
    rng, size, _, loc, scale = node.inputs
    if transform is not None:
        raise NotImplementedError("Reparametrization of Normal with Transform is not implemented")
    # Auxiliary variable: eta ~ Normal(lambda * mu, sigma ** lambda)
    vip_rv_ = pm.Normal.dist(
        lam * loc,
        scale**lam,
        size=size,
        rng=rng,
    )
    vip_rv_.name = f"{name}::tau_"

    vip_rv = model_free_rv(
        vip_rv_,
        vip_rv_.clone(),
        None,
        *dims,
    )

    # Recover theta = mu + sigma ** (1 - lambda) * (eta - lambda * mu) ~ Normal(mu, sigma)
    vip_rep_ = loc + scale ** (1 - lam) * (vip_rv - lam * loc)

    vip_rep_.name = name

    vip_rep = model_deterministic(vip_rep_, *dims)
    return vip_rep
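

# Editor's sketch (hypothetical, not part of this commit): further location-scale families
# could in principle be supported by registering additional overloads on the same
# single-dispatch function, along the lines of
#
#   @_vip_reparam_node.register
#   def _(op: SomeLocScaleRV, node, name, dims, transform, lam) -> ModelDeterministic:
#       ...  # build the partially centered RV and the deterministic recovering theta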


def vip_reparametrize(
    model: pm.Model,
    var_names: Sequence[str],
) -> Tuple[pm.Model, VIP]:
    r"""Reparametrize a model using Variationally Informed Parametrization (VIP).

    .. math::

        \begin{align*}
        \eta_{k} &\sim \text{normal}(\lambda_{k} \cdot \mu, \sigma^{\lambda_{k}})\\
        \theta_{k} &= \mu + \sigma^{1 - \lambda_{k}} ( \eta_{k} - \lambda_{k} \cdot \mu)
            \sim \text{normal}(\mu, \sigma).
        \end{align*}

    Parameters
    ----------
    model : Model
        Model with centered parameterizations for variables.
    var_names : Sequence[str]
        Target variables to reparametrize.

    Returns
    -------
    Tuple[Model, VIP]
        Updated model and a VIP helper to reparametrize or infer the parametrization of the model.

    Examples
    --------
    The traditional eight-schools model.

    .. code-block:: python

        import pymc as pm
        import numpy as np

        J = 8
        y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
        sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

        with pm.Model() as Centered_eight:
            mu = pm.Normal("mu", mu=0, sigma=5)
            tau = pm.HalfCauchy("tau", beta=5)
            theta = pm.Normal("theta", mu=mu, sigma=tau, shape=J)
            obs = pm.Normal("obs", mu=theta, sigma=sigma, observed=y)

    The regular model definition with centered parametrization is sufficient to use VIP.
    To change the model parametrization, use the following function.

    .. code-block:: python

        from pymc_experimental.model.transforms.autoreparam import vip_reparametrize
        Reparam_eight, vip = vip_reparametrize(Centered_eight, ["theta"])

        with Reparam_eight:
            # set all parameterizations to centered (not needed)
            vip.set_all_lambda(1)

            # set all parameterizations to non-centered (desired)
            vip.set_all_lambda(0)

            # or per variable
            vip.set_lambda(theta=0)

            # with the non-centered parameterization set, sample as usual
            trace = pm.sample()

    However, setting :math:`\lambda` manually is not always a great experience; we can learn it instead.

    .. code-block:: python

        with Reparam_eight:
            # set all parameterizations to a mix of centered and non-centered
            vip.set_all_lambda(0.5)

            # fit using ADVI
            mf = vip.fit(random_seed=42)

            # display lambdas
            print(vip.get_lambda())

            # {'theta': array([0.01473405, 0.02221006, 0.03656685, 0.03798879, 0.04876761,
            #        0.0300203 , 0.02733082, 0.01817754])}

    Now you can use sampling again:

    .. code-block:: python

        with Reparam_eight:
            trace = pm.sample()

    Sometimes it makes sense to enable truncation (it is off by default).
    The idea is to round :math:`\lambda_k` to the closest extremum (:math:`0` or :math:`1`)
    whenever it is within :math:`\varepsilon` of that extremum:

    .. math::

        \hat \lambda_k = \begin{cases}
            0, \quad &\lambda_k \le \varepsilon\\
            \lambda_k, \quad &\varepsilon \lt \lambda_k \lt 1-\varepsilon\\
            1, \quad &\lambda_k \ge 1-\varepsilon
        \end{cases}

    .. code-block:: python

        vip.truncate_all_lambda(0.1)

    Sampling has to be performed again:

    .. code-block:: python

        with Reparam_eight:
            trace = pm.sample()
    """
    fmodel, memo = fgraph_from_model(model)
    lambda_names = []
    replacements = []
    for name in var_names:
        old = memo[model.named_vars[name]]
        rv, _, *dims = old.owner.inputs
        # Build the VIP-reparametrized counterpart of the RV plus its logit(lambda) shared variable
        new, lam = vip_reparam_node(
            rv.owner.op,
            rv.owner,
            name=rv.name,
            dims=dims,
            transform=old.owner.op.transform,
        )
        replacements.append((old, new))
        lambda_names.append(lam.name)
    toposort_replace(fmodel, replacements, reverse=True)
    reparam_model = model_from_fgraph(fmodel)
    model_lambdas = {n: reparam_model[l] for l, n in zip(lambda_names, var_names)}
    vip = VIP(model_lambdas)
    return reparam_model, vip

0 commit comments