Skip to content

Commit b988ba9

Browse files
madanhJunpeng Lao
authored and
Junpeng Lao
committed
Make slice sampler sample from 1D conditionals as it should (#2446)
* Make Slice sampler sample from 1D conditionals In the previous implementation it would sample jointly from non-scalar variables, and hang for when the size is high (due to low probability to get a joint sample within the slice in high-D). * slicer.py Fix broken indentation due to copypaste * Apply autopep8 * Delete a superfluous commented line * Update the master sample for Slice in test_step.py
1 parent 15b8595 commit b988ba9

File tree

2 files changed

+67
-48
lines changed

2 files changed

+67
-48
lines changed

Diff for: pymc3/step_methods/slicer.py

+33-28
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@ def __init__(self, vars=None, w=1., tune=True, model=None, **kwargs):
3434
self.model = modelcontext(model)
3535
self.w = w
3636
self.tune = tune
37-
self.w_sum = 0
38-
self.n_tunes = 0
37+
self.n_tunes = 0.
3938

4039
if vars is None:
4140
vars = self.model.cont_vars
@@ -44,33 +43,39 @@ def __init__(self, vars=None, w=1., tune=True, model=None, **kwargs):
4443
super(Slice, self).__init__(vars, [self.model.fastlogp], **kwargs)
4544

4645
def astep(self, q0, logp):
47-
self.w = np.resize(self.w, len(q0))
48-
y = logp(q0) - nr.standard_exponential()
49-
50-
# Stepping out procedure
51-
q_left = q0 - nr.uniform(0, self.w)
52-
q_right = q_left + self.w
53-
54-
while (y < logp(q_left)).all():
55-
q_left -= self.w
56-
57-
while (y < logp(q_right)).all():
58-
q_right += self.w
59-
60-
q = nr.uniform(q_left, q_right, size=q_left.size) # new variable to avoid copies
61-
while logp(q) <= y:
62-
# Sample uniformly from slice
63-
if (q > q0).all():
64-
q_right = q
65-
elif (q < q0).all():
66-
q_left = q
67-
q = nr.uniform(q_left, q_right, size=q_left.size)
68-
46+
self.w = np.resize(self.w, len(q0)) # this is a repmat
47+
q = np.copy(q0) # TODO: find out if we need this
48+
ql = np.copy(q0) # l for left boundary
49+
qr = np.copy(q0) # r for right boudary
50+
for i in range(len(q0)):
51+
# uniformly sample from 0 to p(q), but in log space
52+
y = logp(q) - nr.standard_exponential()
53+
ql[i] = q[i] - nr.uniform(0, self.w[i])
54+
qr[i] = q[i] + self.w[i]
55+
# Stepping out procedure
56+
while(y <= logp(ql)): # changed lt to leq for locally uniform posteriors
57+
ql[i] -= self.w[i]
58+
while(y <= logp(qr)):
59+
qr[i] += self.w[i]
60+
61+
q[i] = nr.uniform(ql[i], qr[i])
62+
while logp(q) < y: # Changed leq to lt, to accomodate for locally flat posteriors
63+
# Sample uniformly from slice
64+
if q[i] > q0[i]:
65+
qr[i] = q[i]
66+
elif q[i] < q0[i]:
67+
ql[i] = q[i]
68+
q[i] = nr.uniform(ql[i], qr[i])
69+
70+
if self.tune: # I was under impression from MacKays lectures that slice width can be tuned without
71+
# breaking markovianness. Can we do it regardless of self.tune?(@madanh)
72+
self.w[i] = self.w[i] * (self.n_tunes / (self.n_tunes + 1)) +\
73+
(qr[i] - ql[i]) / (self.n_tunes + 1) # same as before
74+
# unobvious and important: return qr and ql to the same point
75+
qr[i] = q[i]
76+
ql[i] = q[i]
6977
if self.tune:
70-
# Tune sampler parameters
71-
self.w_sum += np.abs(q0 - q)
72-
self.n_tunes += 1.
73-
self.w = 2. * self.w_sum / self.n_tunes
78+
self.n_tunes += 1
7479
return q
7580

7681
@staticmethod

Diff for: pymc3/tests/test_step.py

+34-20
Original file line numberDiff line numberDiff line change
@@ -27,26 +27,40 @@
2727
class TestStepMethods(object): # yield test doesn't work subclassing object
2828
master_samples = {
2929
Slice: np.array([
30-
-8.13087389e-01, -3.08921856e-01, -6.79377098e-01, 6.50812585e-01, -7.63577596e-01,
31-
-8.13199793e-01, -1.63823548e+00, -7.03863676e-02, 2.05107771e+00, 1.68598170e+00,
32-
6.92463695e-01, -7.75120766e-01, -1.62296463e+00, 3.59722423e-01, -2.31421712e-01,
33-
-7.80686956e-02, -6.05860731e-01, -1.13000202e-01, 1.55675942e-01, -6.78527612e-01,
34-
6.31052333e-01, 6.09012517e-01, -1.56621643e+00, 5.04330883e-01, 3.14824082e-03,
35-
-1.31287073e+00, 4.10706927e-01, 8.93815792e-01, 8.19317020e-01, 3.71900919e-01,
36-
-2.62067312e+00, -3.47616592e+00, 1.50335041e+00, -1.05993351e+00, 2.41571723e-01,
37-
-1.06258156e+00, 5.87999429e-01, -1.78480091e-01, -3.60278680e-01, 1.90615274e-01,
38-
-1.24399204e-01, 4.03845589e-01, -1.47797573e-01, 7.90445804e-01, -1.21043819e+00,
39-
-1.33964776e+00, 1.36366329e+00, -7.50175388e-01, 9.25241839e-01, -4.17493767e-01,
40-
1.85311339e+00, -2.49715343e+00, -3.18571692e-01, -1.49099668e+00, -2.62079621e-01,
41-
-5.82376852e-01, -2.53033395e+00, 2.07580503e+00, -9.82615856e-01, 6.00517782e-01,
42-
-9.83941620e-01, -1.59014118e+00, -1.83931394e-03, -4.71163466e-01, 1.90073737e+00,
43-
-2.08929125e-01, -6.98388847e-01, 1.64502092e+00, -1.19525944e+00, 1.44424109e+00,
44-
1.52974876e+00, -5.70140077e-01, 5.08633322e-01, -1.70862492e-02, -1.69887948e-01,
45-
5.19760297e-01, -4.15149647e-01, 8.63685174e-02, -3.66805233e-01, -9.24988952e-01,
46-
2.33307122e+00, -2.60391496e-01, -5.86271814e-01, -5.01297170e-01, -1.53866195e+00,
47-
5.71285373e-01, -1.30571830e+00, 8.59587795e-01, 6.72170694e-01, 9.12433943e-01,
48-
7.04959179e-01, 8.37863464e-01, -5.24200836e-01, 1.28261340e+00, 9.08774240e-01,
49-
8.80566763e-01, 7.82911967e-01, 8.01843432e-01, 7.09251098e-01, 5.73803618e-01]),
30+
-5.95252353e-01, -1.81894861e-01, -4.98211488e-01,
31+
-1.02262800e-01, -4.26726030e-01, 1.75446860e+00,
32+
-1.30022548e+00, 8.35658004e-01, 8.95879638e-01,
33+
-8.85214481e-01, -6.63530918e-01, -8.39303080e-01,
34+
9.42792225e-01, 9.03554344e-01, 8.45254684e-01,
35+
-1.43299803e+00, 9.04897201e-01, -1.74303131e-01,
36+
-6.38611581e-01, 1.50013968e+00, 1.06864438e+00,
37+
-4.80484421e-01, -7.52199709e-01, 1.95067495e+00,
38+
-3.67960104e+00, 2.49291588e+00, -2.11039152e+00,
39+
1.61674758e-01, -1.59564182e-01, 2.19089873e-01,
40+
1.88643940e+00, 4.04098154e-01, -4.59352326e-01,
41+
-9.06370675e-01, 5.42817654e-01, 6.99040611e-03,
42+
1.66396391e-01, -4.74549281e-01, 8.19064437e-02,
43+
1.69689952e+00, -1.62667304e+00, 1.61295808e+00,
44+
1.30099144e+00, -5.46722750e-01, -7.87745494e-01,
45+
7.91027521e-01, -2.35706976e-02, 1.68824376e+00,
46+
7.10566880e-01, -7.23551374e-01, 8.85613069e-01,
47+
-1.27300146e+00, 1.80274430e+00, 9.34266276e-01,
48+
2.40427061e+00, -1.85132552e-01, 4.47234196e-01,
49+
-9.81894859e-01, -2.83399706e-01, 1.84717533e+00,
50+
-1.58593284e+00, 3.18027270e-02, 1.40566006e+00,
51+
-9.45758714e-01, 1.18813188e-01, -1.19938604e+00,
52+
-8.26038466e-01, 5.03469984e-01, -4.72742758e-01,
53+
2.27820946e-01, -1.02608915e-03, -6.02507158e-01,
54+
7.72739682e-01, 7.16064505e-01, -1.63693490e+00,
55+
-3.97161966e-01, 1.17147944e+00, -2.87796982e+00,
56+
-1.59533297e+00, 6.73096114e-01, -3.34397247e-01,
57+
1.22357427e-01, -4.57299104e-02, 1.32005771e+00,
58+
-1.29910645e+00, 8.16168850e-01, -1.47357594e+00,
59+
1.34688446e+00, 1.06377551e+00, 4.34296696e-02,
60+
8.23143354e-01, 8.40906324e-01, 1.88596864e+00,
61+
5.77120694e-01, 2.71732927e-01, -1.36217979e+00,
62+
2.41488213e+00, 4.68298379e-01, 4.86342250e-01,
63+
-8.43949966e-01]),
5064
HamiltonianMC: np.array([
5165
-0.74925631, -0.2566773 , -2.12480977, 1.64328926, -1.39315913,
5266
2.04200003, 0.00706711, 0.34240498, 0.44276674, -0.21368043,

0 commit comments

Comments
 (0)