Skip to content

Commit aa7cb73

Browse files
ColCarrolltwiecki
authored andcommitted
Refactor Slice.astep (#1371)
1 parent c147e75 commit aa7cb73

File tree

1 file changed

+29
-47
lines changed

1 file changed

+29
-47
lines changed

pymc3/step_methods/slicer.py

Lines changed: 29 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
# Modified from original implementation by Dominik Wabersich (2013)
22

3+
import numpy as np
4+
import numpy.random as nr
5+
36
from .arraystep import ArrayStep, Competence
47
from ..model import modelcontext
58
from ..theanof import inputvars
69
from ..vartypes import continuous_types
7-
from numpy import floor, abs, atleast_1d, empty, isfinite, sum, resize
8-
from numpy.random import standard_exponential, random, uniform
910

1011
__all__ = ['Slice']
1112

@@ -28,66 +29,47 @@ class Slice(ArrayStep):
2829
"""
2930
default_blocked = False
3031

31-
def __init__(self, vars=None, w=1, tune=True, model=None, **kwargs):
32-
33-
model = modelcontext(model)
32+
def __init__(self, vars=None, w=1., tune=True, model=None, **kwargs):
33+
self.model = modelcontext(model)
34+
self.w = w
35+
self.tune = tune
36+
self.w_sum = 0
37+
self.n_tunes = 0
3438

3539
if vars is None:
36-
vars = model.cont_vars
40+
vars = self.model.cont_vars
3741
vars = inputvars(vars)
3842

39-
self.w = w
40-
self.tune = tune
41-
self.w_tune = []
42-
self.model = model
43-
44-
super(Slice, self).__init__(vars, [model.fastlogp], **kwargs)
43+
super(Slice, self).__init__(vars, [self.model.fastlogp], **kwargs)
4544

4645
def astep(self, q0, logp):
47-
48-
q = q0.copy()
49-
self.w = resize(self.w, len(q))
50-
51-
y = logp(q0) - standard_exponential()
46+
self.w = np.resize(self.w, len(q0))
47+
y = logp(q0) - nr.standard_exponential()
5248

5349
# Stepping out procedure
54-
ql = q0.copy()
55-
ql -= uniform(0, self.w)
56-
qr = q0.copy()
57-
qr = ql + self.w
58-
59-
yl = logp(ql)
60-
yr = logp(qr)
61-
62-
while((y < yl).all()):
63-
ql -= self.w
64-
yl = logp(ql)
50+
q_left = q0 - nr.uniform(0, self.w)
51+
q_right = q_left + self.w
6552

66-
while((y < yr).all()):
67-
qr += self.w
68-
yr = logp(qr)
53+
while (y < logp(q_left)).all():
54+
q_left -= self.w
6955

70-
q_next = q0.copy()
71-
while True:
56+
while (y < logp(q_right)).all():
57+
q_right += self.w
7258

59+
q = nr.uniform(q_left, q_right, size=q_left.size) # new variable to avoid copies
60+
while logp(q) <= y:
7361
# Sample uniformly from slice
74-
qi = uniform(ql, qr, size=ql.size)
75-
76-
yi = logp(qi)
77-
78-
if yi > y:
79-
q = qi
80-
break
81-
elif (qi > q).all():
82-
qr = qi
83-
elif (qi < q).all():
84-
ql = qi
62+
if (q > q0).all():
63+
q_right = q
64+
elif (q < q0).all():
65+
q_left = q
66+
q = nr.uniform(q_left, q_right, size=q_left.size)
8567

8668
if self.tune:
8769
# Tune sampler parameters
88-
self.w_tune.append(abs(q0 - q))
89-
self.w = 2 * sum(self.w_tune, 0) / len(self.w_tune)
90-
70+
self.w_sum += np.abs(q0 - q)
71+
self.n_tunes += 1.
72+
self.w = 2. * self.w_sum / self.n_tunes
9173
return q
9274

9375
@staticmethod

0 commit comments

Comments
 (0)