Skip to content

Commit ae6c1bb

Browse files
committed
Fix numerical issue in NUTS acceptance ratio computation
This should reduce how often "Mass matrix contains zeros on the diagonal" is raised during warmup, and fix several of the issues reported in https://github.com/pymc-devs/pymc3/issues/3959
1 parent 8d241cd commit ae6c1bb

File tree

1 file changed

+98
-69
lines changed

1 file changed

+98
-69
lines changed

pymc3/step_methods/hmc/nuts.py

+98-69
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from pymc3.theanof import floatX
2525
from pymc3.vartypes import continuous_types
2626

27-
__all__ = ['NUTS']
27+
__all__ = ["NUTS"]
2828

2929

3030
def logbern(log_p):
@@ -33,8 +33,25 @@ def logbern(log_p):
3333
return np.log(nr.uniform()) < log_p
3434

3535

36+
def log1mexp_numpy(x):
    """Return log(1 - exp(-x)) in a numerically stable way.

    The naive formula loses precision for small ``x`` (``1 - exp(-x)``
    cancels) and for large ``x`` (``exp(-x)`` vanishes next to 1). As
    recommended in the note linked below, ``log(-expm1(-x))`` is used for
    ``x <= log(2)`` and ``log1p(-exp(-x))`` otherwise.

    For details, see
    https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf

    Parameters
    ----------
    x : array_like
        Non-negative value(s). ``x == 0`` yields ``-inf``; negative input
        yields ``nan``.

    Returns
    -------
    numpy.ndarray
        ``log(1 - exp(-x))`` evaluated elementwise (0-d for scalar input).
    """
    # Accept plain Python scalars and sequences as well as arrays.
    x = np.asarray(x)
    # np.where evaluates both branches; the cutoff only selects which of the
    # two (mathematically equal) expressions is more accurate for each entry.
    return np.where(
        x <= np.log(2.0),
        np.log(-np.expm1(-x)),
        np.log1p(-np.exp(-x)),
    )
46+
47+
48+
def logdiffexp(a, b):
    """Return log(exp(a) - exp(b)) computed in log space.

    Uses the identity log(exp(a) - exp(b)) = a + log(1 - exp(-(a - b)))
    and delegates the second term to ``log1mexp_numpy``.
    """
    gap = a - b
    return a + log1mexp_numpy(gap)
51+
52+
3653
class NUTS(BaseHMC):
37-
R"""A sampler for continuous variables based on Hamiltonian mechanics.
54+
r"""A sampler for continuous variables based on Hamiltonian mechanics.
3855
3956
NUTS automatically tunes the step size and the number of steps per
4057
sample. A detailed description can be found at [1], "Algorithm 6:
@@ -84,27 +101,28 @@ class NUTS(BaseHMC):
84101
Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo.
85102
"""
86103

87-
name = 'nuts'
104+
name = "nuts"
88105

89106
default_blocked = True
90107
generates_stats = True
91-
stats_dtypes = [{
92-
'depth': np.int64,
93-
'step_size': np.float64,
94-
'tune': np.bool,
95-
'mean_tree_accept': np.float64,
96-
'step_size_bar': np.float64,
97-
'tree_size': np.float64,
98-
'diverging': np.bool,
99-
'energy_error': np.float64,
100-
'energy': np.float64,
101-
'max_energy_error': np.float64,
102-
'model_logp': np.float64,
103-
}]
104-
105-
def __init__(self, vars=None, max_treedepth=10, early_max_treedepth=8,
106-
**kwargs):
107-
R"""Set up the No-U-Turn sampler.
108+
stats_dtypes = [
109+
{
110+
"depth": np.int64,
111+
"step_size": np.float64,
112+
"tune": np.bool,
113+
"mean_tree_accept": np.float64,
114+
"step_size_bar": np.float64,
115+
"tree_size": np.float64,
116+
"diverging": np.bool,
117+
"energy_error": np.float64,
118+
"energy": np.float64,
119+
"max_energy_error": np.float64,
120+
"model_logp": np.float64,
121+
}
122+
]
123+
124+
def __init__(self, vars=None, max_treedepth=10, early_max_treedepth=8, **kwargs):
125+
r"""Set up the No-U-Turn sampler.
108126
109127
Parameters
110128
----------
@@ -184,7 +202,7 @@ def _hamiltonian_step(self, start, p0, step_size):
184202
self._reached_max_treedepth += 1
185203

186204
stats = tree.stats()
187-
accept_stat = stats['mean_tree_accept']
205+
accept_stat = stats["mean_tree_accept"]
188206
return HMCStepData(tree.proposal, accept_stat, divergence_info, stats)
189207

190208
@staticmethod
@@ -200,10 +218,11 @@ def warnings(self):
200218
n_treedepth = self._reached_max_treedepth
201219

202220
if n_samples > 0 and n_treedepth / float(n_samples) > 0.05:
203-
msg = ('The chain reached the maximum tree depth. Increase '
204-
'max_treedepth, increase target_accept or reparameterize.')
205-
warn = SamplerWarning(WarningType.TREEDEPTH, msg, 'warn',
206-
None, None, None)
221+
msg = (
222+
"The chain reached the maximum tree depth. Increase "
223+
"max_treedepth, increase target_accept or reparameterize."
224+
)
225+
warn = SamplerWarning(WarningType.TREEDEPTH, msg, "warn", None, None, None)
207226
warnings.append(warn)
208227
return warnings
209228

@@ -213,8 +232,8 @@ def warnings(self):
213232

214233
# A subtree of the binary tree built by nuts.
215234
Subtree = namedtuple(
    "Subtree",
    [
        "left",
        "right",
        "p_sum",
        "proposal",
        "log_size",
        "log_accept_sum",
        "n_proposals",
    ],
)
218237

219238

220239
class _Tree:
@@ -242,11 +261,12 @@ def __init__(self, ndim, integrator, start, step_size, Emax):
242261

243262
self.left = self.right = start
244263
self.proposal = Proposal(
245-
start.q, start.q_grad, start.energy, 1.0, start.model_logp)
264+
start.q, start.q_grad, start.energy, 1.0, start.model_logp
265+
)
246266
self.depth = 0
247267
self.log_size = 0
248268
self.log_accept_sum = -np.inf
249-
self.mean_tree_accept = 0.
269+
self.mean_tree_accept = 0.0
250270
self.n_proposals = 0
251271
self.p_sum = start.p.copy()
252272
self.max_energy_change = 0
@@ -265,15 +285,17 @@ def extend(self, direction):
265285
"""
266286
if direction > 0:
267287
tree, diverging, turning = self._build_subtree(
268-
self.right, self.depth, floatX(np.asarray(self.step_size)))
288+
self.right, self.depth, floatX(np.asarray(self.step_size))
289+
)
269290
leftmost_begin, leftmost_end = self.left, self.right
270291
rightmost_begin, rightmost_end = tree.left, tree.right
271292
leftmost_p_sum = self.p_sum
272293
rightmost_p_sum = tree.p_sum
273294
self.right = tree.right
274295
else:
275296
tree, diverging, turning = self._build_subtree(
276-
self.left, self.depth, floatX(np.asarray(-self.step_size)))
297+
self.left, self.depth, floatX(np.asarray(-self.step_size))
298+
)
277299
leftmost_begin, leftmost_end = tree.right, tree.left
278300
rightmost_begin, rightmost_end = self.left, self.right
279301
leftmost_p_sum = tree.p_sum
@@ -291,8 +313,7 @@ def extend(self, direction):
291313
self.proposal = tree.proposal
292314

293315
self.log_size = np.logaddexp(self.log_size, tree.log_size)
294-
self.log_accept_sum = np.logaddexp(self.log_accept_sum,
295-
tree.log_accept_sum)
316+
self.log_accept_sum = np.logaddexp(self.log_accept_sum, tree.log_accept_sum)
296317
self.p_sum[:] += tree.p_sum
297318

298319
# Additional turning check only when tree depth > 0 to avoid redundant work
@@ -301,10 +322,14 @@ def extend(self, direction):
301322
p_sum = self.p_sum
302323
turning = (p_sum.dot(left.v) <= 0) or (p_sum.dot(right.v) <= 0)
303324
p_sum1 = leftmost_p_sum + rightmost_begin.p
304-
turning1 = (p_sum1.dot(leftmost_begin.v) <= 0) or (p_sum1.dot(rightmost_begin.v) <= 0)
325+
turning1 = (p_sum1.dot(leftmost_begin.v) <= 0) or (
326+
p_sum1.dot(rightmost_begin.v) <= 0
327+
)
305328
p_sum2 = leftmost_end.p + rightmost_p_sum
306-
turning2 = (p_sum2.dot(leftmost_end.v) <= 0) or (p_sum2.dot(rightmost_end.v) <= 0)
307-
turning = (turning | turning1 | turning2)
329+
turning2 = (p_sum2.dot(leftmost_end.v) <= 0) or (
330+
p_sum2.dot(rightmost_end.v) <= 0
331+
)
332+
turning = turning | turning1 | turning2
308333

309334
return diverging, turning
310335

@@ -324,21 +349,23 @@ def _single_step(self, left, epsilon):
324349
if np.abs(energy_change) > np.abs(self.max_energy_change):
325350
self.max_energy_change = energy_change
326351
if np.abs(energy_change) < self.Emax:
327-
# Acceptance statistic
328-
# e^{H(q_0, p_0) - H(q_n, p_n)} max(1, e^{H(q_0, p_0) - H(q_n, p_n)})
329-
# Saturated Metropolis accept probability with Boltzmann weight
330-
# if h - H0 < 0
331-
log_p_accept = -energy_change + min(0., -energy_change)
352+
# Acceptance statistic
353+
# e^{H(q_0, p_0) - H(q_n, p_n)} min(1, e^{H(q_0, p_0) - H(q_n, p_n)})
354+
# Saturated Metropolis accept probability with Boltzmann weight
355+
# if h - H0 < 0
356+
log_p_accept = -energy_change + min(0.0, -energy_change)
332357
log_size = -energy_change
333358
proposal = Proposal(
334-
right.q, right.q_grad, right.energy, log_p_accept,
335-
right.model_logp)
336-
tree = Subtree(right, right, right.p,
337-
proposal, log_size, log_p_accept, 1)
359+
right.q, right.q_grad, right.energy, log_p_accept, right.model_logp
360+
)
361+
tree = Subtree(
362+
right, right, right.p, proposal, log_size, log_p_accept, 1
363+
)
338364
return tree, None, False
339365
else:
340-
error_msg = ("Energy change in leapfrog step is too large: %s."
341-
% energy_change)
366+
error_msg = (
367+
"Energy change in leapfrog step is too large: %s." % energy_change
368+
)
342369
error = None
343370
tree = Subtree(None, None, None, None, -np.inf, -np.inf, 1)
344371
divergance_info = DivergenceInfo(error_msg, error, left)
@@ -348,13 +375,11 @@ def _build_subtree(self, left, depth, epsilon):
348375
if depth == 0:
349376
return self._single_step(left, epsilon)
350377

351-
tree1, diverging, turning = self._build_subtree(
352-
left, depth - 1, epsilon)
378+
tree1, diverging, turning = self._build_subtree(left, depth - 1, epsilon)
353379
if diverging or turning:
354380
return tree1, diverging, turning
355381

356-
tree2, diverging, turning = self._build_subtree(
357-
tree1.right, depth - 1, epsilon)
382+
tree2, diverging, turning = self._build_subtree(tree1.right, depth - 1, epsilon)
358383

359384
left, right = tree1.left, tree2.right
360385

@@ -364,14 +389,17 @@ def _build_subtree(self, left, depth, epsilon):
364389
# Additional U turn check only when depth > 1 to avoid redundant work.
365390
if depth - 1 > 0:
366391
p_sum1 = tree1.p_sum + tree2.left.p
367-
turning1 = (p_sum1.dot(tree1.left.v) <= 0) or (p_sum1.dot(tree2.left.v) <= 0)
392+
turning1 = (p_sum1.dot(tree1.left.v) <= 0) or (
393+
p_sum1.dot(tree2.left.v) <= 0
394+
)
368395
p_sum2 = tree1.right.p + tree2.p_sum
369-
turning2 = (p_sum2.dot(tree1.right.v) <= 0) or (p_sum2.dot(tree2.right.v) <= 0)
370-
turning = (turning | turning1 | turning2)
396+
turning2 = (p_sum2.dot(tree1.right.v) <= 0) or (
397+
p_sum2.dot(tree2.right.v) <= 0
398+
)
399+
turning = turning | turning1 | turning2
371400

372401
log_size = np.logaddexp(tree1.log_size, tree2.log_size)
373-
log_accept_sum = np.logaddexp(tree1.log_accept_sum,
374-
tree2.log_accept_sum)
402+
log_accept_sum = np.logaddexp(tree1.log_accept_sum, tree2.log_accept_sum)
375403
if logbern(tree2.log_size - log_size):
376404
proposal = tree2.proposal
377405
else:
@@ -384,23 +412,24 @@ def _build_subtree(self, left, depth, epsilon):
384412

385413
n_proposals = tree1.n_proposals + tree2.n_proposals
386414

387-
tree = Subtree(left, right, p_sum, proposal,
388-
log_size, log_accept_sum, n_proposals)
415+
tree = Subtree(
416+
left, right, p_sum, proposal, log_size, log_accept_sum, n_proposals
417+
)
389418
return tree, diverging, turning
390419

391420
def stats(self):
    """Return the statistics collected while building this tree.

    When at least one subtree has been merged in (``log_size > 0``), the
    mean acceptance statistic is refreshed first. The normalising weight
    excludes the initial state, whose log-weight is 0, so it equals
    ``log(exp(log_size) - exp(0))``; this is evaluated in log space to
    avoid the overflow/cancellation of ``exp``/``expm1`` on ``log_size``.

    Returns
    -------
    dict
        Keys ``depth``, ``mean_tree_accept``, ``energy_error``, ``energy``,
        ``tree_size``, ``max_energy_error`` and ``model_logp``.
    """
    # Update accept stat if any subtrees were accepted
    if self.log_size > 0:
        # Remove contribution from initial state which is always a perfect
        # accept
        # NOTE: the helper defined in this module is `logdiffexp`; the
        # previous call to `logdiffexp_numpy` raised NameError.
        log_sum_weight = logdiffexp(self.log_size, 0.0)
        self.mean_tree_accept = np.exp(self.log_accept_sum - log_sum_weight)
    return {
        "depth": self.depth,
        "mean_tree_accept": self.mean_tree_accept,
        "energy_error": self.proposal.energy - self.start.energy,
        "energy": self.proposal.energy,
        "tree_size": self.n_proposals,
        "max_energy_error": self.max_energy_change,
        "model_logp": self.proposal.logp,
    }

0 commit comments

Comments
 (0)