Skip to content

Commit ba77d85

Browse files
authored
Change SMC metropolis kernel to independent metropolis kernel (#4115)
* change metropolis kernel for independent metropolis kernel * fix pyupgrade * fix float32 * clean code * add inline comments and update docstring * add inline comments and update docstring
1 parent 66f078c commit ba77d85

File tree

3 files changed

+40
-39
lines changed

3 files changed

+40
-39
lines changed

RELEASE-NOTES.md

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
### New features
1313
- `sample_posterior_predictive_w` can now feed on `xarray.Dataset` - e.g. from `InferenceData.posterior`. (see [#4042](https://github.com/pymc-devs/pymc3/pull/4042))
1414
- Add MLDA, a new stepper for multilevel sampling. MLDA can be used when a hierarchy of approximate posteriors of varying accuracy is available, offering improved sampling efficiency especially in high-dimensional problems and/or where gradients are not available (see [#3926](https://github.com/pymc-devs/pymc3/pull/3926))
15+
- Change SMC metropolis kernel to independent metropolis kernel (see [#4115](https://github.com/pymc-devs/pymc3/pull/4115))
1516

1617

1718
## PyMC3 3.9.3 (11 August 2020)

pymc3/smc/sample_smc.py

+12-11
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def sample_smc(
3535
n_steps=25,
3636
start=None,
3737
tune_steps=True,
38-
p_acc_rate=0.99,
38+
p_acc_rate=0.85,
3939
threshold=0.5,
4040
save_sim_data=False,
4141
model=None,
@@ -61,7 +61,7 @@ def sample_smc(
6161
acceptance rate and `p_acc_rate`, the max number of steps is ``n_steps``.
6262
start: dict, or array of dict
6363
Starting point in parameter space. It should be a list of dict with length `chains`.
64-
When None (default) the starting point is sampled from the prior distribution.
64+
When None (default) the starting point is sampled from the prior distribution.
6565
tune_steps: bool
6666
Whether to compute the number of steps automatically or not. Defaults to True
6767
p_acc_rate: float
@@ -105,17 +105,18 @@ def sample_smc(
105105
106106
1. Initialize :math:`\beta` at zero and stage at zero.
107107
2. Generate N samples :math:`S_{\beta}` from the prior (because when :math:`\beta = 0` the
108-
tempered posterior is the prior).
108+
tempered posterior is the prior).
109109
3. Increase :math:`\beta` in order to make the effective sample size equal some predefined
110110
value (we use :math:`Nt`, where :math:`t` is 0.5 by default).
111111
4. Compute a set of N importance weights W. The weights are computed as the ratio of the
112112
likelihoods of a sample at stage i+1 and stage i.
113113
5. Obtain :math:`S_{w}` by re-sampling according to W.
114-
6. Use W to compute the covariance for the proposal distribution.
115-
7. For stages other than 0 use the acceptance rate from the previous stage to estimate the
116-
scaling of the proposal distribution and `n_steps`.
117-
8. Run N Metropolis chains (each one of length `n_steps`), starting each one from a different
118-
sample in :math:`S_{w}`.
114+
6. Use W to compute the mean and covariance for the proposal distribution, a MVNormal.
115+
7. For stages other than 0 use the acceptance rate from the previous stage to estimate
116+
`n_steps`.
117+
8. Run N independent Metropolis-Hastings (IMH) chains (each one of length `n_steps`),
118+
starting each one from a different sample in :math:`S_{w}`. Samples are IMH as the proposal
119+
mean is the mean of the previous posterior stage and not the current point in parameter space.
119120
9. Repeat from step 3 until :math:`\beta \ge 1`.
120121
10. The final result is a collection of N samples from the posterior.
121122
@@ -147,8 +148,8 @@ def sample_smc(
147148
cores = 1
148149

149150
_log.info(
150-
f"Multiprocess sampling ({chains} chain{'s' if chains > 1 else ''} "
151-
f"in {cores} job{'s' if cores > 1 else ''})"
151+
f"Sampling {chains} chain{'s' if chains > 1 else ''} "
152+
f"in {cores} job{'s' if cores > 1 else ''}"
152153
)
153154

154155
if random_seed == -1:
@@ -200,11 +201,11 @@ def sample_smc(
200201
trace = MultiTrace(traces)
201202
trace.report._n_draws = draws
202203
trace.report._n_tune = 0
203-
trace.report._t_sampling = time.time() - t1
204204
trace.report.log_marginal_likelihood = np.array(log_marginal_likelihoods)
205205
trace.report.betas = betas
206206
trace.report.accept_ratios = accept_ratios
207207
trace.report.nsteps = nsteps
208+
trace.report._t_sampling = time.time() - t1
208209

209210
if save_sim_data:
210211
return trace, {modelcontext(model).observed_RVs[0].name: np.array(sim_data)}

pymc3/smc/smc.py

+27-28
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import numpy as np
1818
from scipy.special import logsumexp
19+
from scipy.stats import multivariate_normal
1920
from theano import function as theano_function
2021
import theano.tensor as tt
2122

@@ -33,7 +34,7 @@ def __init__(
3334
n_steps=25,
3435
start=None,
3536
tune_steps=True,
36-
p_acc_rate=0.99,
37+
p_acc_rate=0.85,
3738
threshold=0.5,
3839
save_sim_data=False,
3940
model=None,
@@ -42,7 +43,7 @@ def __init__(
4243
):
4344

4445
self.draws = draws
45-
self.kernel = kernel
46+
self.kernel = kernel.lower()
4647
self.n_steps = n_steps
4748
self.start = start
4849
self.tune_steps = tune_steps
@@ -62,10 +63,7 @@ def __init__(
6263
self.max_steps = n_steps
6364
self.proposed = draws * n_steps
6465
self.acc_rate = 1
65-
self.acc_per_chain = np.ones(self.draws)
6666
self.variables = inputvars(self.model.vars)
67-
self.dimension = sum(v.dsize for v in self.variables)
68-
self.scalings = np.ones(self.draws) * 2.38 / (self.dimension) ** 0.5
6967
self.weights = np.ones(self.draws) / self.draws
7068
self.log_marginal_likelihood = 0
7169
self.sim_data = []
@@ -78,7 +76,9 @@ def initialize_population(self):
7876
var_info = OrderedDict()
7977
if self.start is None:
8078
init_rnd = sample_prior_predictive(
81-
self.draws, var_names=[v.name for v in self.model.unobserved_RVs], model=self.model,
79+
self.draws,
80+
var_names=[v.name for v in self.model.unobserved_RVs],
81+
model=self.model,
8282
)
8383
else:
8484
init_rnd = self.start
@@ -102,7 +102,7 @@ def setup_kernel(self):
102102
"""
103103
shared = make_shared_replacements(self.variables, self.model)
104104

105-
if self.kernel.lower() == "abc":
105+
if self.kernel == "abc":
106106
factors = [var.logpt for var in self.model.free_RVs]
107107
factors += [tt.sum(factor) for factor in self.model.potentials]
108108
self.prior_logp_func = logp_forw([tt.sum(factors)], self.variables, shared)
@@ -122,7 +122,7 @@ def setup_kernel(self):
122122
self.draws,
123123
self.save_sim_data,
124124
)
125-
elif self.kernel.lower() == "metropolis":
125+
elif self.kernel == "metropolis":
126126
self.prior_logp_func = logp_forw([self.model.varlogpt], self.variables, shared)
127127
self.likelihood_logp_func = logp_forw([self.model.datalogpt], self.variables, shared)
128128

@@ -136,7 +136,7 @@ def initialize_logp(self):
136136
self.prior_logp = np.array(priors).squeeze()
137137
self.likelihood_logp = np.array(likelihoods).squeeze()
138138

139-
if self.save_sim_data:
139+
if self.kernel == "abc" and self.save_sim_data:
140140
self.sim_data = self.likelihood_logp_func.get_data()
141141

142142
def update_weights_beta(self):
@@ -180,8 +180,6 @@ def resample(self):
180180
self.prior_logp = self.prior_logp[resampling_indexes]
181181
self.likelihood_logp = self.likelihood_logp[resampling_indexes]
182182
self.posterior_logp = self.prior_logp + self.likelihood_logp * self.beta
183-
self.acc_per_chain = self.acc_per_chain[resampling_indexes]
184-
self.scalings = self.scalings[resampling_indexes]
185183
if self.save_sim_data:
186184
self.sim_data = self.sim_data[resampling_indexes]
187185

@@ -198,47 +196,48 @@ def update_proposal(self):
198196

199197
def tune(self):
200198
"""
201-
Tune scaling and n_steps based on the acceptance rate.
199+
Tune n_steps based on the acceptance rate.
202200
"""
203-
ave_scaling = np.exp(np.log(self.scalings.mean()) + (self.acc_per_chain.mean() - 0.234))
204-
self.scalings = 0.5 * (
205-
ave_scaling + np.exp(np.log(self.scalings) + (self.acc_per_chain - 0.234))
206-
)
207-
208201
if self.tune_steps:
209202
acc_rate = max(1.0 / self.proposed, self.acc_rate)
210203
self.n_steps = min(
211-
self.max_steps, max(2, int(np.log(1 - self.p_acc_rate) / np.log(1 - acc_rate))),
204+
self.max_steps,
205+
max(2, int(np.log(1 - self.p_acc_rate) / np.log(1 - acc_rate))),
212206
)
213207

214208
self.proposed = self.draws * self.n_steps
215209

216210
def mutate(self):
217211
ac_ = np.empty((self.n_steps, self.draws))
218212

219-
proposals = (
220-
np.random.multivariate_normal(
221-
np.zeros(self.dimension), self.cov, size=(self.n_steps, self.draws)
222-
)
223-
* self.scalings[:, None]
224-
)
225213
log_R = np.log(np.random.rand(self.n_steps, self.draws))
226214

215+
# The proposal distribution is a MVNormal, with mean and covariance computed from the previous tempered posterior
216+
dist = multivariate_normal(self.posterior.mean(axis=0), self.cov)
217+
227218
for n_step in range(self.n_steps):
228-
proposal = floatX(self.posterior + proposals[n_step])
219+
# The proposal is independent from the current point.
220+
# We have to take that into account to compute the Metropolis-Hastings acceptance
221+
proposal = floatX(dist.rvs(size=self.draws))
222+
proposal = proposal.reshape(len(proposal), -1)
223+
# To do that we compute the logp of moving to a new point
224+
forward = dist.logpdf(proposal)
225+
# And of going back from that new point
226+
backward = multivariate_normal(proposal.mean(axis=0), self.cov).logpdf(self.posterior)
229227
ll = np.array([self.likelihood_logp_func(prop) for prop in proposal])
230228
pl = np.array([self.prior_logp_func(prop) for prop in proposal])
231229
proposal_logp = pl + ll * self.beta
232-
accepted = log_R[n_step] < (proposal_logp - self.posterior_logp)
230+
accepted = log_R[n_step] < (
231+
(proposal_logp + backward) - (self.posterior_logp + forward)
232+
)
233233
ac_[n_step] = accepted
234234
self.posterior[accepted] = proposal[accepted]
235235
self.posterior_logp[accepted] = proposal_logp[accepted]
236236
self.prior_logp[accepted] = pl[accepted]
237237
self.likelihood_logp[accepted] = ll[accepted]
238-
if self.save_sim_data:
238+
if self.kernel == "abc" and self.save_sim_data:
239239
self.sim_data[accepted] = self.likelihood_logp_func.get_data()[accepted]
240240

241-
self.acc_per_chain = np.mean(ac_, axis=0)
242241
self.acc_rate = np.mean(ac_)
243242

244243
def posterior_to_trace(self):

0 commit comments

Comments
 (0)