Skip to content

Commit 261583d

Browse files
committed
add more info to report
1 parent 43d7b6e commit 261583d

File tree

2 files changed

+12
-26
lines changed

2 files changed

+12
-26
lines changed

pymc3/smc/sample_smc.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -196,12 +196,15 @@ def sample_smc(
196196
for i in range(chains):
197197
results.append((sample_smc_int(*params, random_seed[i], i, _log)))
198198

199-
traces, log_marginal_likelihoods = zip(*results)
199+
traces, log_marginal_likelihoods, betas, accept_ratios, nsteps = zip(*results)
200200
trace = MultiTrace(traces)
201201
trace.report._n_draws = draws
202202
trace.report._n_tune = 0
203203
trace.report._t_sampling = time.time() - t1
204204
trace.report.log_marginal_likelihood = np.array(log_marginal_likelihoods)
205+
trace.report.betas = betas
206+
trace.report.accept_ratios = accept_ratios
207+
trace.report.nsteps = nsteps
205208

206209
return trace
207210

@@ -239,6 +242,9 @@ def sample_smc_int(
239242
chain=chain,
240243
)
241244
stage = 0
245+
betas = []
246+
accept_ratios = []
247+
nsteps = []
242248
smc.initialize_population()
243249
smc.setup_kernel()
244250
smc.initialize_logp()
@@ -252,5 +258,8 @@ def sample_smc_int(
252258
smc.mutate()
253259
smc.tune()
254260
stage += 1
261+
betas.append(smc.beta)
262+
accept_ratios.append(smc.acc_rate)
263+
nsteps.append(smc.n_steps)
255264

256-
return smc.posterior_to_trace()
265+
return smc.posterior_to_trace(), smc.log_marginal_likelihood, betas, accept_ratios, nsteps

pymc3/smc/smc.py

+1-24
Original file line numberDiff line numberDiff line change
@@ -229,29 +229,6 @@ def mutate(self):
229229
self.acc_per_chain = np.mean(ac_, axis=0)
230230
self.acc_rate = np.mean(ac_)
231231

232-
def posterior_to_trace_bk(self):
233-
"""
234-
Save results into a PyMC3 trace
235-
"""
236-
lenght_pos = len(self.posterior)
237-
varnames = [v.name for v in self.variables]
238-
straces = []
239-
with self.model:
240-
chain_lenght = int(lenght_pos / 10)
241-
for chain in range(10):
242-
strace = NDArray(self.model)
243-
strace.setup(chain_lenght, chain)
244-
for i in range(chain_lenght):
245-
value = []
246-
size = 0
247-
for var in varnames:
248-
shape, new_size = self.var_info[var]
249-
value.append(self.posterior[i][size : size + new_size].reshape(shape))
250-
size += new_size
251-
strace.record({k: v for k, v in zip(varnames, value)})
252-
straces.append(strace)
253-
return MultiTrace(straces)
254-
255232
def posterior_to_trace(self):
256233
"""
257234
Save results into a PyMC3 trace
@@ -270,7 +247,7 @@ def posterior_to_trace(self):
270247
value.append(self.posterior[i][size : size + new_size].reshape(shape))
271248
size += new_size
272249
strace.record(point={k: v for k, v in zip(varnames, value)})
273-
return strace, self.log_marginal_likelihood
250+
return strace
274251

275252

276253
def logp_forw(out_vars, vars, shared):

0 commit comments

Comments
 (0)