 from pymc3.theanof import floatX
 from pymc3.vartypes import continuous_types
 
-__all__ = ['NUTS']
+__all__ = ["NUTS"]
 
 
 def logbern(log_p):
@@ -33,8 +33,25 @@ def logbern(log_p):
     return np.log(nr.uniform()) < log_p
 
 
+def log1mexp_numpy(x):
+    """Return log(1 - exp(-x)).
+    This function is numerically more stable than the naive approach.
+    For details, see
+    https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
+    """
+    return np.where(
+        x < 0.683,
+        np.log(-np.expm1(-x)),
+        np.log1p(-np.exp(-x)))
+
+
+def logdiffexp(a, b):
+    """log(exp(a) - exp(b))"""
+    return a + log1mexp_numpy(a - b)
+
+
 class NUTS(BaseHMC):
-    R"""A sampler for continuous variables based on Hamiltonian mechanics.
+    r"""A sampler for continuous variables based on Hamiltonian mechanics.
 
     NUTS automatically tunes the step size and the number of steps per
     sample. A detailed description can be found at [1], "Algorithm 6:
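
(Not part of the diff.) A quick standalone sketch of why the two helpers added above are written the way they are; the function bodies mirror the diff, and the test value of x is made up for illustration.

import numpy as np

def log1mexp_numpy(x):
    # log(1 - exp(-x)), switching formulas so the result stays accurate for small x.
    return np.where(x < 0.683, np.log(-np.expm1(-x)), np.log1p(-np.exp(-x)))

def logdiffexp(a, b):
    # log(exp(a) - exp(b)); assumes a > b.
    return a + log1mexp_numpy(a - b)

x = 1e-20  # hypothetical tiny argument
with np.errstate(divide="ignore"):  # np.where evaluates both branches
    naive = np.log(1.0 - np.exp(-x))  # exp(-x) rounds to 1.0, so this is -inf
    stable = log1mexp_numpy(x)        # approximately log(x) = -46.05
print(naive, stable)
print(logdiffexp(np.log(5.0), np.log(2.0)))  # approximately log(3) = 1.0986
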
@@ -84,27 +101,28 @@ class NUTS(BaseHMC):
        Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo.
     """
 
-    name = 'nuts'
+    name = "nuts"
 
     default_blocked = True
     generates_stats = True
-    stats_dtypes = [{
-        'depth': np.int64,
-        'step_size': np.float64,
-        'tune': np.bool,
-        'mean_tree_accept': np.float64,
-        'step_size_bar': np.float64,
-        'tree_size': np.float64,
-        'diverging': np.bool,
-        'energy_error': np.float64,
-        'energy': np.float64,
-        'max_energy_error': np.float64,
-        'model_logp': np.float64,
-    }]
-
-    def __init__(self, vars=None, max_treedepth=10, early_max_treedepth=8,
-                 **kwargs):
-        R"""Set up the No-U-Turn sampler.
+    stats_dtypes = [
+        {
+            "depth": np.int64,
+            "step_size": np.float64,
+            "tune": np.bool,
+            "mean_tree_accept": np.float64,
+            "step_size_bar": np.float64,
+            "tree_size": np.float64,
+            "diverging": np.bool,
+            "energy_error": np.float64,
+            "energy": np.float64,
+            "max_energy_error": np.float64,
+            "model_logp": np.float64,
+        }
+    ]
+
+    def __init__(self, vars=None, max_treedepth=10, early_max_treedepth=8, **kwargs):
+        r"""Set up the No-U-Turn sampler.
 
         Parameters
         ----------
@@ -184,7 +202,7 @@ def _hamiltonian_step(self, start, p0, step_size):
             self._reached_max_treedepth += 1
 
         stats = tree.stats()
-        accept_stat = stats['mean_tree_accept']
+        accept_stat = stats["mean_tree_accept"]
         return HMCStepData(tree.proposal, accept_stat, divergence_info, stats)
 
     @staticmethod
@@ -200,10 +218,11 @@ def warnings(self):
         n_treedepth = self._reached_max_treedepth
 
         if n_samples > 0 and n_treedepth / float(n_samples) > 0.05:
-            msg = ('The chain reached the maximum tree depth. Increase '
-                   'max_treedepth, increase target_accept or reparameterize.')
-            warn = SamplerWarning(WarningType.TREEDEPTH, msg, 'warn',
-                                  None, None, None)
+            msg = (
+                "The chain reached the maximum tree depth. Increase "
+                "max_treedepth, increase target_accept or reparameterize."
+            )
+            warn = SamplerWarning(WarningType.TREEDEPTH, msg, "warn", None, None, None)
             warnings.append(warn)
         return warnings
 
@@ -213,8 +232,8 @@ def warnings(self):
 
 # A subtree of the binary tree built by nuts.
 Subtree = namedtuple(
-    "Subtree",
-    "left, right, p_sum, proposal, log_size, log_accept_sum, n_proposals")
+    "Subtree", "left, right, p_sum, proposal, log_size, log_accept_sum, n_proposals"
+)
 
 
 class _Tree:
@@ -242,11 +261,12 @@ def __init__(self, ndim, integrator, start, step_size, Emax):
 
         self.left = self.right = start
         self.proposal = Proposal(
-            start.q, start.q_grad, start.energy, 1.0, start.model_logp)
+            start.q, start.q_grad, start.energy, 1.0, start.model_logp
+        )
         self.depth = 0
         self.log_size = 0
         self.log_accept_sum = -np.inf
-        self.mean_tree_accept = 0.
+        self.mean_tree_accept = 0.0
         self.n_proposals = 0
         self.p_sum = start.p.copy()
         self.max_energy_change = 0
@@ -265,15 +285,17 @@ def extend(self, direction):
         """
         if direction > 0:
             tree, diverging, turning = self._build_subtree(
-                self.right, self.depth, floatX(np.asarray(self.step_size)))
+                self.right, self.depth, floatX(np.asarray(self.step_size))
+            )
             leftmost_begin, leftmost_end = self.left, self.right
             rightmost_begin, rightmost_end = tree.left, tree.right
             leftmost_p_sum = self.p_sum
             rightmost_p_sum = tree.p_sum
             self.right = tree.right
         else:
             tree, diverging, turning = self._build_subtree(
-                self.left, self.depth, floatX(np.asarray(-self.step_size)))
+                self.left, self.depth, floatX(np.asarray(-self.step_size))
+            )
             leftmost_begin, leftmost_end = tree.right, tree.left
             rightmost_begin, rightmost_end = self.left, self.right
             leftmost_p_sum = tree.p_sum
@@ -291,8 +313,7 @@ def extend(self, direction):
             self.proposal = tree.proposal
 
         self.log_size = np.logaddexp(self.log_size, tree.log_size)
-        self.log_accept_sum = np.logaddexp(self.log_accept_sum,
-                                           tree.log_accept_sum)
+        self.log_accept_sum = np.logaddexp(self.log_accept_sum, tree.log_accept_sum)
         self.p_sum[:] += tree.p_sum
 
         # Additional turning check only when tree depth > 0 to avoid redundant work
@@ -301,10 +322,14 @@ def extend(self, direction):
             p_sum = self.p_sum
             turning = (p_sum.dot(left.v) <= 0) or (p_sum.dot(right.v) <= 0)
             p_sum1 = leftmost_p_sum + rightmost_begin.p
-            turning1 = (p_sum1.dot(leftmost_begin.v) <= 0) or (p_sum1.dot(rightmost_begin.v) <= 0)
+            turning1 = (p_sum1.dot(leftmost_begin.v) <= 0) or (
+                p_sum1.dot(rightmost_begin.v) <= 0
+            )
             p_sum2 = leftmost_end.p + rightmost_p_sum
-            turning2 = (p_sum2.dot(leftmost_end.v) <= 0) or (p_sum2.dot(rightmost_end.v) <= 0)
-            turning = (turning | turning1 | turning2)
+            turning2 = (p_sum2.dot(leftmost_end.v) <= 0) or (
+                p_sum2.dot(rightmost_end.v) <= 0
+            )
+            turning = turning | turning1 | turning2
 
         return diverging, turning
 
@@ -324,21 +349,23 @@ def _single_step(self, left, epsilon):
             if np.abs(energy_change) > np.abs(self.max_energy_change):
                 self.max_energy_change = energy_change
             if np.abs(energy_change) < self.Emax:
-                # Acceptance statistic
-                # e^{H(q_0, p_0) - H(q_n, p_n)} max(1, e^{H(q_0, p_0) - H(q_n, p_n)})
-                # Saturated Metropolis accept probability with Boltzmann weight
-                # if h - H0 < 0
-                log_p_accept = -energy_change + min(0., -energy_change)
+                # Acceptance statistic
+                # e^{H(q_0, p_0) - H(q_n, p_n)} max(1, e^{H(q_0, p_0) - H(q_n, p_n)})
+                # Saturated Metropolis accept probability with Boltzmann weight
+                # if h - H0 < 0
+                log_p_accept = -energy_change + min(0.0, -energy_change)
                 log_size = -energy_change
                 proposal = Proposal(
-                    right.q, right.q_grad, right.energy, log_p_accept,
-                    right.model_logp)
-                tree = Subtree(right, right, right.p,
-                               proposal, log_size, log_p_accept, 1)
+                    right.q, right.q_grad, right.energy, log_p_accept, right.model_logp
+                )
+                tree = Subtree(
+                    right, right, right.p, proposal, log_size, log_p_accept, 1
+                )
                 return tree, None, False
             else:
-                error_msg = ("Energy change in leapfrog step is too large: %s."
-                             % energy_change)
+                error_msg = (
+                    "Energy change in leapfrog step is too large: %s." % energy_change
+                )
                 error = None
         tree = Subtree(None, None, None, None, -np.inf, -np.inf, 1)
         divergance_info = DivergenceInfo(error_msg, error, left)
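
(Not part of the diff.) A small numeric sanity check of the acceptance statistic built above; energy_change is an invented value. exp(log_p_accept) equals the Boltzmann weight e^{-dE} times the saturated Metropolis acceptance probability min(1, e^{-dE}):

import numpy as np

energy_change = 0.3  # hypothetical H(q_n, p_n) - H(q_0, p_0)
log_p_accept = -energy_change + min(0.0, -energy_change)
expected = np.exp(-energy_change) * min(1.0, np.exp(-energy_change))
assert np.isclose(np.exp(log_p_accept), expected)
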
@@ -348,13 +375,11 @@ def _build_subtree(self, left, depth, epsilon):
         if depth == 0:
             return self._single_step(left, epsilon)
 
-        tree1, diverging, turning = self._build_subtree(
-            left, depth - 1, epsilon)
+        tree1, diverging, turning = self._build_subtree(left, depth - 1, epsilon)
         if diverging or turning:
             return tree1, diverging, turning
 
-        tree2, diverging, turning = self._build_subtree(
-            tree1.right, depth - 1, epsilon)
+        tree2, diverging, turning = self._build_subtree(tree1.right, depth - 1, epsilon)
 
         left, right = tree1.left, tree2.right
 
@@ -364,14 +389,17 @@ def _build_subtree(self, left, depth, epsilon):
         # Additional U turn check only when depth > 1 to avoid redundant work.
         if depth - 1 > 0:
             p_sum1 = tree1.p_sum + tree2.left.p
-            turning1 = (p_sum1.dot(tree1.left.v) <= 0) or (p_sum1.dot(tree2.left.v) <= 0)
+            turning1 = (p_sum1.dot(tree1.left.v) <= 0) or (
+                p_sum1.dot(tree2.left.v) <= 0
+            )
             p_sum2 = tree1.right.p + tree2.p_sum
-            turning2 = (p_sum2.dot(tree1.right.v) <= 0) or (p_sum2.dot(tree2.right.v) <= 0)
-            turning = (turning | turning1 | turning2)
+            turning2 = (p_sum2.dot(tree1.right.v) <= 0) or (
+                p_sum2.dot(tree2.right.v) <= 0
+            )
+            turning = turning | turning1 | turning2
 
         log_size = np.logaddexp(tree1.log_size, tree2.log_size)
-        log_accept_sum = np.logaddexp(tree1.log_accept_sum,
-                                      tree2.log_accept_sum)
+        log_accept_sum = np.logaddexp(tree1.log_accept_sum, tree2.log_accept_sum)
         if logbern(tree2.log_size - log_size):
             proposal = tree2.proposal
         else:
@@ -384,23 +412,24 @@ def _build_subtree(self, left, depth, epsilon):
 
         n_proposals = tree1.n_proposals + tree2.n_proposals
 
-        tree = Subtree(left, right, p_sum, proposal,
-                       log_size, log_accept_sum, n_proposals)
+        tree = Subtree(
+            left, right, p_sum, proposal, log_size, log_accept_sum, n_proposals
+        )
         return tree, diverging, turning
 
     def stats(self):
         # Update accept stat if any subtrees were accepted
         if self.log_size > 0:
-            # Remove contribution from initial state which is always a perfect
-            # accept
-            sum_weight = np.expm1(self.log_size)
-            self.mean_tree_accept = np.exp(self.log_accept_sum) / sum_weight
+            # Remove contribution from initial state which is always a perfect
+            # accept
+            log_sum_weight = logdiffexp(self.log_size, 0.)
+            self.mean_tree_accept = np.exp(self.log_accept_sum - log_sum_weight)
         return {
-            'depth': self.depth,
-            'mean_tree_accept': self.mean_tree_accept,
-            'energy_error': self.proposal.energy - self.start.energy,
-            'energy': self.proposal.energy,
-            'tree_size': self.n_proposals,
-            'max_energy_error': self.max_energy_change,
-            'model_logp': self.proposal.logp,
+            "depth": self.depth,
+            "mean_tree_accept": self.mean_tree_accept,
+            "energy_error": self.proposal.energy - self.start.energy,
+            "energy": self.proposal.energy,
+            "tree_size": self.n_proposals,
+            "max_energy_error": self.max_energy_change,
+            "model_logp": self.proposal.logp,
         }
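
(Not part of the diff.) A brief sketch of why stats() now works in log space: the pre- and post-change formulas compute the same mean acceptance statistic, but the log-space form does not overflow when the accumulated log weights are large. All numbers below are invented, and the helpers are copied from the diff so the snippet runs on its own.

import numpy as np

def log1mexp_numpy(x):
    return np.where(x < 0.683, np.log(-np.expm1(-x)), np.log1p(-np.exp(-x)))

def logdiffexp(a, b):
    return a + log1mexp_numpy(a - b)

log_size, log_accept_sum = 3.2, 2.9
old = np.exp(log_accept_sum) / np.expm1(log_size)         # pre-change formula
new = np.exp(log_accept_sum - logdiffexp(log_size, 0.0))  # post-change formula
assert np.isclose(old, new)

big_size, big_accept = 800.0, 799.5
with np.errstate(over="ignore", invalid="ignore"):
    print(np.exp(big_accept) / np.expm1(big_size))        # nan: inf / inf
print(np.exp(big_accept - logdiffexp(big_size, 0.0)))     # exp(-0.5), about 0.61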