
Commit 87d4aea

Authored by carsten-j, pre-commit-ci[bot], ricardoV94, and jessegrabowski
Implement Laplace (quadratic) approximation (#345)
* First draft of quadratic approximation
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Review comments incorporated
* License and copyright information added
* Only add additional data to inferencedata when chains != 0
* Raise error if Hessian is singular
* Replace for loop with call to remove_value_transforms
* Pass model directly when finding MAP and the Hessian
* Update pymc_experimental/inference/laplace.py
  Co-authored-by: Ricardo Vieira <[email protected]>
* Remove chains from public parameters for Laplace approx method
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Parameter draws is not optional, with default value 1000
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Add warning if the number of variables in vars does not equal the number of model variables
* Update version.txt
* `shock_size` should never be scalar
* Blackjax API change
* Handle latest PyMC/PyTensor breaking changes
* Temporarily mark two tests as xfail
* More bugfixes for statespace (#346)
  * Allow forward sampling of statespace models in JAX mode
    Explicitly set data shape to avoid broadcasting error
    Better handling of measurement error dims in `SARIMAX` models
    Freeze auxiliary models before forward sampling
    Bugfixes for posterior predictive sampling helpers
    Allow specification of time dimension name when registering data
    Save info about exogenous data for post-estimation tasks
    Restore `_exog_data_info` member variable
    Be more consistent with the names of filter outputs
  * Adjust test suite to reflect API changes
    Modify structural tests to accommodate deterministic models
    Save kalman filter outputs to idata for statespace tests
    Remove test related to `add_exogenous`
    Adjust structural module tests
  * Add JAX test suite
  * Bug-fixes and changes to statespace distributions
    Remove tests related to the `add_exogenous` method
    Add dummy `MvNormalSVDRV` for forward jax sampling with `method="SVD"`
    Dynamically generate `LinearGaussianStateSpaceRV` signature from inputs
    Add signature and simple test for `SequenceMvNormal`
  * Re-run example notebooks
  * Add helper function to sample prior/posterior statespace matrices
  * Fix tests
  * Wrap jax MvNormal rewrite in try/except block
  * Don't use `action` keyword in `catch_warnings`
  * Skip JAX test if `numpyro` is not installed
  * Handle batch dims on `SequenceMvNormal`
  * Remove unused batch_dim logic in SequenceMvNormal
  * Restore `get_support_shape_1d` import
* Fix failing test case for laplace

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ricardo Vieira <[email protected]>
Co-authored-by: Jesse Grabowski <[email protected]>
Co-authored-by: Ricardo Vieira <[email protected]>
Co-authored-by: Jesse Grabowski <[email protected]>
1 parent e85677b commit 87d4aea
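
For context on the method this commit adds: a Laplace (quadratic) approximation replaces the posterior with a Gaussian centered at the MAP estimate, with covariance equal to the inverse Hessian of the negative log posterior at that point — exactly the pair of quantities the new `laplace` function computes via `pm.find_MAP` and `pm.find_hessian`. A minimal self-contained NumPy/SciPy sketch of the idea, independent of the PyMC machinery below (the toy model, a normal likelihood with known unit variance under a flat prior, is illustrative and not part of the commit):

import numpy as np
from scipy.optimize import minimize

# Toy posterior: unknown mean of a N(mu, 1) likelihood under a flat prior.
data = np.array([2642.0, 3503.0, 4358.0])

def neg_log_posterior(theta):
    # Negative log posterior, up to an additive constant.
    return 0.5 * np.sum((data - theta[0]) ** 2)

# Step 1: find the MAP estimate by minimizing the negative log posterior.
map_point = minimize(neg_log_posterior, x0=np.array([0.0])).x

# Step 2: Hessian of the negative log posterior at the MAP. Here it is
# analytic (n observations, unit variance); pm.find_hessian computes it
# numerically for arbitrary models.
hessian = np.array([[float(len(data))]])

# Step 3: the Laplace approximation is N(map_point, inv(hessian)).
cov = np.linalg.inv(hessian)
print(map_point, cov)  # mean ~= data.mean(), variance ~= 1/3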

File tree

3 files changed: +334 −1 lines changed

  pymc_experimental/inference/fit.py (+7 −1)
  pymc_experimental/inference/laplace.py (+190)
  pymc_experimental/tests/test_laplace.py (+137)

Diff for: pymc_experimental/inference/fit.py

+7-1
@@ -21,7 +21,7 @@ def fit(method, **kwargs):
     ----------
     method : str
         Which inference method to run.
-        Supported: pathfinder
+        Supported: pathfinder or laplace
 
     kwargs are passed on.
 
@@ -38,3 +38,9 @@ def fit(method, **kwargs):
         from pymc_experimental.inference.pathfinder import fit_pathfinder
 
         return fit_pathfinder(**kwargs)
+
+    if method == "laplace":
+
+        from pymc_experimental.inference.laplace import laplace
+
+        return laplace(**kwargs)
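
With this change the new method is reachable through the same `pmx.fit` entry point as pathfinder, and keyword arguments are forwarded unchanged. A usage sketch, mirroring the docstring example from `laplace.py` below:

import numpy as np
import pymc as pm

import pymc_experimental as pmx

y = np.array([2642, 3503, 4358] * 10)

with pm.Model() as m:
    logsigma = pm.Uniform("logsigma", 1, 100)
    mu = pm.Uniform("mu", -10000, 10000)
    pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)

# fit(method="laplace", ...) dispatches to laplace(**kwargs).
idata = pmx.fit(method="laplace", vars=[mu, logsigma], model=m, draws=1000)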

Diff for: pymc_experimental/inference/laplace.py

+190
@@ -0,0 +1,190 @@
# Copyright 2024 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from collections.abc import Sequence
from typing import Optional

import arviz as az
import numpy as np
import pymc as pm
import xarray as xr
from arviz import dict_to_dataset
from pymc.backends.arviz import (
    coords_and_dims_for_inferencedata,
    find_constants,
    find_observations,
)
from pymc.model.transform.conditioning import remove_value_transforms
from pymc.util import RandomSeed
from pytensor import Variable


def laplace(
    vars: Sequence[Variable],
    draws: Optional[int] = 1000,
    model=None,
    random_seed: Optional[RandomSeed] = None,
    progressbar=True,
):
    """
    Create a Laplace (quadratic) approximation for a posterior distribution.

    This function generates a Laplace approximation for a given posterior distribution using a
    specified number of draws. This is useful for obtaining a parametric approximation to the
    posterior distribution that can be used for further analysis.

    Parameters
    ----------
    vars : Sequence[Variable]
        A sequence of variables for which the Laplace approximation of the posterior distribution
        is to be created.
    draws : Optional[int], default=1_000
        The number of draws to sample from the posterior distribution for creating the
        approximation. For draws=None only the fit of the Laplace approximation is returned.
    model : object, optional, default=None
        The model object that defines the posterior distribution. If None, the default model
        will be used.
    random_seed : Optional[RandomSeed], default=None
        An optional random seed to ensure reproducibility of the draws. If None, the draws will
        be generated using the current random state.
    progressbar : bool, optional, default=True
        Whether to display a progress bar in the command line.

    Returns
    -------
    arviz.InferenceData
        An `InferenceData` object from the `arviz` library containing the Laplace
        approximation of the posterior distribution. The InferenceData object also
        contains constant and observed data as well as deterministic variables.
        InferenceData also contains a group 'fit' with the mean and covariance
        for the Laplace approximation.

    Examples
    --------

    >>> import numpy as np
    >>> import pymc as pm
    >>> import arviz as az
    >>> from pymc_experimental.inference.laplace import laplace
    >>> y = np.array([2642, 3503, 4358] * 10)
    >>> with pm.Model() as m:
    ...     logsigma = pm.Uniform("logsigma", 1, 100)
    ...     mu = pm.Uniform("mu", -10000, 10000)
    ...     yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)
    ...     idata = laplace([mu, logsigma], model=m)

    Notes
    -----
    This method of approximation may not be suitable for all types of posterior distributions,
    especially those with significant skewness or multimodality.

    See Also
    --------
    fit : Calling the inference function 'fit' like pmx.fit(method="laplace", vars=[mu, logsigma], model=m)
        forwards the call to 'laplace'.

    """

    rng = np.random.default_rng(seed=random_seed)

    transformed_m = pm.modelcontext(model)

    if len(vars) != len(transformed_m.free_RVs):
        warnings.warn(
            "Number of variables in vars does not equal the number of variables in the model.",
            UserWarning,
        )

    map = pm.find_MAP(vars=vars, progressbar=progressbar, model=transformed_m)

    # Compute the Hessian on the untransformed model so that the mean and covariance
    # live on the original parameter space. See
    # https://www.pymc.io/projects/docs/en/stable/api/model/generated/pymc.model.transform.conditioning.remove_value_transforms.html
    untransformed_m = remove_value_transforms(transformed_m)
    untransformed_vars = [untransformed_m[v.name] for v in vars]
    hessian = pm.find_hessian(point=map, vars=untransformed_vars, model=untransformed_m)

    # A singular Hessian cannot be inverted, so no Gaussian approximation exists.
    if np.linalg.det(hessian) == 0:
        raise np.linalg.LinAlgError("Hessian is singular.")

    cov = np.linalg.inv(hessian)
    mean = np.concatenate([np.atleast_1d(map[v.name]) for v in vars])

    chains = 1

    if draws is not None:
        samples = rng.multivariate_normal(mean, cov, size=(chains, draws))

        data_vars = {}
        for i, var in enumerate(vars):
            data_vars[str(var)] = xr.DataArray(samples[:, :, i], dims=("chain", "draw"))

        coords = {"chain": np.arange(chains), "draw": np.arange(draws)}
        ds = xr.Dataset(data_vars, coords=coords)

        idata = az.convert_to_inference_data(ds)
        idata = addDataToInferenceData(model, idata, progressbar)
    else:
        idata = az.InferenceData()

    idata = addFitToInferenceData(vars, idata, mean, cov)

    return idata


def addFitToInferenceData(vars, idata, mean, covariance):
    coord_names = [v.name for v in vars]

    # Convert the mean vector and covariance matrix to xarray DataArrays
    mean_dataarray = xr.DataArray(mean, dims=["rows"], coords={"rows": coord_names})
    cov_dataarray = xr.DataArray(
        covariance, dims=["rows", "columns"], coords={"rows": coord_names, "columns": coord_names}
    )

    # Create an xarray dataset holding the fitted Gaussian
    dataset = xr.Dataset({"mean_vector": mean_dataarray, "covariance_matrix": cov_dataarray})

    idata.add_groups(fit=dataset)

    return idata


def addDataToInferenceData(model, trace, progressbar):
    # Add deterministic variables to the inference data
    trace.posterior = pm.compute_deterministics(
        trace.posterior, model=model, merge_dataset=True, progressbar=progressbar
    )

    coords, dims = coords_and_dims_for_inferencedata(model)

    observed_data = dict_to_dataset(
        find_observations(model),
        library=pm,
        coords=coords,
        dims=dims,
        default_dims=[],
    )

    constant_data = dict_to_dataset(
        find_constants(model),
        library=pm,
        coords=coords,
        dims=dims,
        default_dims=[],
    )

    trace.add_groups(
        {"observed_data": observed_data, "constant_data": constant_data},
        coords=coords,
        dims=dims,
    )

    return trace
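
One consequence of `addFitToInferenceData` storing the fitted Gaussian itself in the `fit` group (`mean_vector` and `covariance_matrix`) is that further draws can be generated later without refitting. A short sketch, assuming `idata` was produced by a call like the docstring example above:

import numpy as np

# Recover the fitted Gaussian from the 'fit' group.
mean = idata.fit["mean_vector"].values
cov = idata.fit["covariance_matrix"].values

# Draw fresh samples from the approximation directly.
rng = np.random.default_rng(42)
extra = rng.multivariate_normal(mean, cov, size=5_000)
print(extra.mean(axis=0))  # close to the MAP estimate of [mu, logsigma]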

Diff for: pymc_experimental/tests/test_laplace.py

+137
@@ -0,0 +1,137 @@
# Copyright 2024 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np
import pymc as pm
import pytest

import pymc_experimental as pmx


@pytest.mark.filterwarnings(
    "ignore:hessian will stop negating the output in a future version of PyMC.\n"
    + "To suppress this warning set `negate_output=False`:FutureWarning",
)
def test_laplace():
    # Example originates from Bayesian Data Analysis, 3rd Edition,
    # by Andrew Gelman, John Carlin, Hal Stern, David Dunson,
    # Aki Vehtari, and Donald Rubin.
    # See section 4.1.

    y = np.array([2642, 3503, 4358], dtype=np.float64)
    n = y.size
    draws = 100000

    with pm.Model() as m:
        logsigma = pm.Uniform("logsigma", 1, 100)
        mu = pm.Uniform("mu", -10000, 10000)
        yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)
        vars = [mu, logsigma]

        idata = pmx.fit(
            method="laplace",
            vars=vars,
            model=m,
            draws=draws,
            random_seed=173300,
        )

    assert idata.posterior["mu"].shape == (1, draws)
    assert idata.posterior["logsigma"].shape == (1, draws)
    assert idata.observed_data["y"].shape == (n,)
    assert idata.fit["mean_vector"].shape == (len(vars),)
    assert idata.fit["covariance_matrix"].shape == (len(vars), len(vars))

    bda_map = [y.mean(), np.log(y.std())]
    bda_cov = np.array([[y.var() / n, 0], [0, 1 / (2 * n)]])

    assert np.allclose(idata.fit["mean_vector"].values, bda_map)
    assert np.allclose(idata.fit["covariance_matrix"].values, bda_cov, atol=1e-4)


@pytest.mark.filterwarnings(
    "ignore:hessian will stop negating the output in a future version of PyMC.\n"
    + "To suppress this warning set `negate_output=False`:FutureWarning",
)
def test_laplace_only_fit():
    # Example originates from Bayesian Data Analysis, 3rd Edition,
    # by Andrew Gelman, John Carlin, Hal Stern, David Dunson,
    # Aki Vehtari, and Donald Rubin.
    # See section 4.1.

    y = np.array([2642, 3503, 4358], dtype=np.float64)
    n = y.size

    with pm.Model() as m:
        logsigma = pm.Uniform("logsigma", 1, 100)
        mu = pm.Uniform("mu", -10000, 10000)
        yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)
        vars = [mu, logsigma]

        idata = pmx.fit(
            method="laplace",
            vars=vars,
            draws=None,
            model=m,
            random_seed=173300,
        )

    assert idata.fit["mean_vector"].shape == (len(vars),)
    assert idata.fit["covariance_matrix"].shape == (len(vars), len(vars))

    bda_map = [y.mean(), np.log(y.std())]
    bda_cov = np.array([[y.var() / n, 0], [0, 1 / (2 * n)]])

    assert np.allclose(idata.fit["mean_vector"].values, bda_map)
    assert np.allclose(idata.fit["covariance_matrix"].values, bda_cov, atol=1e-4)


@pytest.mark.filterwarnings(
    "ignore:hessian will stop negating the output in a future version of PyMC.\n"
    + "To suppress this warning set `negate_output=False`:FutureWarning",
)
def test_laplace_subset_of_rv(recwarn):
    # Example originates from Bayesian Data Analysis, 3rd Edition,
    # by Andrew Gelman, John Carlin, Hal Stern, David Dunson,
    # Aki Vehtari, and Donald Rubin.
    # See section 4.1.

    y = np.array([2642, 3503, 4358], dtype=np.float64)
    n = y.size

    with pm.Model() as m:
        logsigma = pm.Uniform("logsigma", 1, 100)
        mu = pm.Uniform("mu", -10000, 10000)
        yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)
        vars = [mu]

        idata = pmx.fit(
            method="laplace",
            vars=vars,
            draws=None,
            model=m,
            random_seed=173300,
        )

    assert len(recwarn) == 3
    w = recwarn.pop(UserWarning)
    assert issubclass(w.category, UserWarning)
    assert (
        str(w.message)
        == "Number of variables in vars does not equal the number of variables in the model."
    )
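
The reference values `bda_map` and `bda_cov` in these tests are the closed-form results from BDA section 4.1: under a flat prior the negative log posterior is n*logsigma + sum((y - mu)^2) / (2*exp(2*logsigma)) up to a constant, its mode is (mean(y), log(std(y))), and the inverse Hessian at the mode is diag(var(y)/n, 1/(2n)). A standalone finite-difference check of that claim (no PyMC involved; step sizes are chosen to match each coordinate's scale):

import numpy as np

y = np.array([2642, 3503, 4358], dtype=np.float64)
n = y.size

def neg_log_post(theta):
    # Negative log posterior under a flat prior, up to an additive constant.
    mu, logsigma = theta
    return n * logsigma + np.sum((y - mu) ** 2) / (2 * np.exp(2 * logsigma))

# Analytic mode from BDA section 4.1.
mode = np.array([y.mean(), np.log(y.std())])

# Central finite-difference Hessian at the mode.
steps = np.array([1.0, 1e-3])
hess = np.zeros((2, 2))
for i in range(2):
    for j in range(2):
        e_i = np.eye(2)[i] * steps[i]
        e_j = np.eye(2)[j] * steps[j]
        hess[i, j] = (
            neg_log_post(mode + e_i + e_j)
            - neg_log_post(mode + e_i - e_j)
            - neg_log_post(mode - e_i + e_j)
            + neg_log_post(mode - e_i - e_j)
        ) / (4 * steps[i] * steps[j])

# The inverse Hessian matches the covariance asserted in the tests.
bda_cov = np.array([[y.var() / n, 0], [0, 1 / (2 * n)]])
assert np.allclose(np.linalg.inv(hess), bda_cov, rtol=1e-3, atol=1e-4)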
