Skip to content

Commit 6f0013f

Browse files
ozankabaktwiecki
authored and committed
Implemented CategoricalGibbsMetropolis, optimized BinaryGibbsMetropolis. (#1439)
* Implemented CategoricalGibbsMetropolis, optimized BinaryGibbsMetropolis. * Added unit test for CategoricalGibbsMetropolis. Also Added extra comments requested by PyMC3 devs, and fixed a minor bug in the proportional proposal utilized by CategoricalGibbsMetropolis. * Fix a minor bug in the newly added CategoricalGibbsMetropolis test.
1 parent 2f29af8 commit 6f0013f

File tree

6 files changed

+197
-32
lines changed

6 files changed

+197
-32
lines changed

pymc3/sampling.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
from .backends.ndarray import NDArray
1010
from .model import modelcontext, Point
1111
from .step_methods import (NUTS, HamiltonianMC, Metropolis, BinaryMetropolis,
12-
BinaryGibbsMetropolis, Slice, ElemwiseCategorical, CompoundStep)
12+
BinaryGibbsMetropolis, CategoricalGibbsMetropolis,
13+
Slice, CompoundStep)
1314
from tqdm import tqdm
1415

1516
import sys
@@ -20,7 +21,7 @@
2021

2122
def assign_step_methods(model, step=None, methods=(NUTS, HamiltonianMC, Metropolis,
2223
BinaryMetropolis, BinaryGibbsMetropolis,
23-
Slice, ElemwiseCategorical)):
24+
Slice, CategoricalGibbsMetropolis)):
2425
'''
2526
Assign model variables to appropriate step methods. Passing a specified
2627
model will auto-assign its constituent stochastic variables to step methods

pymc3/step_methods/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .metropolis import Metropolis
66
from .metropolis import BinaryMetropolis
77
from .metropolis import BinaryGibbsMetropolis
8+
from .metropolis import CategoricalGibbsMetropolis
89
from .metropolis import NormalProposal
910
from .metropolis import CauchyProposal
1011
from .metropolis import LaplaceProposal

pymc3/step_methods/gibbs.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from ..distributions.discrete import Categorical
88
from numpy import array, max, exp, cumsum, nested_iters, empty, searchsorted, ones, arange
99
from numpy.random import uniform
10+
from warnings import warn
1011

1112
from theano.gof.graph import inputs
1213
from theano.tensor import add
@@ -25,6 +26,8 @@ class ElemwiseCategorical(ArrayStep):
2526
# variables)
2627

2728
def __init__(self, vars, values=None, model=None):
29+
warn('ElemwiseCategorical is deprecated, switch to CategoricalGibbsMetropolis.',
30+
DeprecationWarning, stacklevel = 2)
2831
model = modelcontext(model)
2932
self.var = vars[0]
3033
self.sh = ones(self.var.dshape, self.var.dtype)
@@ -45,10 +48,7 @@ def competence(var):
4548
distribution = getattr(
4649
var.distribution, 'parent_dist', var.distribution)
4750
if isinstance(var.distribution, Categorical):
48-
if var.distribution.k > 2:
49-
return Competence.IDEAL
50-
else:
51-
return Competence.COMPATIBLE
51+
return Competence.COMPATIBLE
5252
return Competence.INCOMPATIBLE
5353

5454

pymc3/step_methods/metropolis.py

Lines changed: 153 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22
import numpy.random as nr
33
import theano
44

5+
from ..distributions import draw_values
56
from .arraystep import ArrayStepShared, ArrayStep, metrop_select, Competence
67
import pymc3 as pm
78

89

9-
__all__ = ['Metropolis', 'BinaryMetropolis', 'BinaryGibbsMetropolis', 'NormalProposal',
10-
'CauchyProposal', 'LaplaceProposal', 'PoissonProposal', 'MultivariateNormalProposal']
10+
__all__ = ['Metropolis', 'BinaryMetropolis', 'BinaryGibbsMetropolis',
11+
'CategoricalGibbsMetropolis', 'NormalProposal', 'CauchyProposal',
12+
'LaplaceProposal', 'PoissonProposal', 'MultivariateNormalProposal']
1113

1214
# Available proposal distributions for Metropolis
1315

@@ -239,41 +241,51 @@ def competence(var):
239241

240242

241243
class BinaryGibbsMetropolis(ArrayStep):
242-
"""Metropolis-Hastings optimized for binary variables"""
244+
"""A Metropolis-within-Gibbs step method optimized for binary variables"""
243245

244246
def __init__(self, vars, order='random', model=None):
245247

246248
model = pm.modelcontext(model)
247249

248250
self.dim = sum(v.dsize for v in vars)
249-
self.order = order
251+
252+
if order == 'random':
253+
self.shuffle_dims = True
254+
self.order = list(range(self.dim))
255+
else:
256+
if sorted(order) != list(range(self.dim)):
257+
raise ValueError('Argument \'order\' has to be a permutation')
258+
self.shuffle_dims = False
259+
self.order = order
250260

251261
if not all([v.dtype in pm.discrete_types for v in vars]):
252262
raise ValueError(
253-
'All variables must be Bernoulli for BinaryGibbsMetropolis')
263+
'All variables must be binary for BinaryGibbsMetropolis')
254264

255265
super(BinaryGibbsMetropolis, self).__init__(vars, [model.fastlogp])
256266

257267
def astep(self, q0, logp):
258-
order = list(range(self.dim))
259-
if self.order == 'random':
268+
order = self.order
269+
if self.shuffle_dims:
260270
nr.shuffle(order)
261271

262-
q_prop = np.copy(q0)
263-
q_cur = np.copy(q0)
272+
q = np.copy(q0)
273+
logp_curr = logp(q)
264274

265275
for idx in order:
266-
q_prop[idx] = True - q_prop[idx]
267-
q_cur = metrop_select(logp(q_prop) - logp(q_cur), q_prop, q_cur)
268-
q_prop = np.copy(q_cur)
276+
curr_val, q[idx] = q[idx], True - q[idx]
277+
logp_prop = logp(q)
278+
q[idx] = metrop_select(logp_prop - logp_curr, q[idx], curr_val)
279+
if q[idx] != curr_val:
280+
logp_curr = logp_prop
269281

270-
return q_cur
282+
return q
271283

272284
@staticmethod
273285
def competence(var):
274286
'''
275-
BinaryMetropolis is only suitable for binary (bool)
276-
and Categorical variables with k=1.
287+
BinaryMetropolis is only suitable for Bernoulli
288+
and Categorical variables with k=2.
277289
'''
278290
distribution = getattr(
279291
var.distribution, 'parent_dist', var.distribution)
@@ -283,6 +295,132 @@ def competence(var):
283295
return Competence.IDEAL
284296
return Competence.INCOMPATIBLE
285297

298+
class CategoricalGibbsMetropolis(ArrayStep):
    """A Metropolis-within-Gibbs step method optimized for categorical variables.

    This step method works for Bernoulli variables as well, but it is not
    optimized for them, like BinaryGibbsMetropolis is. Step method supports
    two types of proposals: A uniform proposal and a proportional proposal,
    which was introduced by Liu in his 1996 technical report
    "Metropolized Gibbs Sampler: An Improvement".

    Parameters
    ----------
    vars : list
        Variables to sample. Each one must be Categorical, Bernoulli or have
        a boolean dtype.
    proposal : {'uniform', 'proportional'}
        'uniform' proposes any other category with equal probability;
        'proportional' proposes another category with probability
        proportional to its conditional probability (Liu's sampler).
    order : 'random' or sequence of int
        'random' reshuffles the visiting order of the flattened dimensions on
        every step. Otherwise a permutation of ``range(total dimensions)``
        giving a fixed visiting order.
    model : optional
        Model to use; resolved through ``pm.modelcontext``.

    Raises
    ------
    ValueError
        If a variable is neither categorical nor binary, if ``order`` is not
        a permutation, or if ``proposal`` is not a supported name.
    """

    def __init__(self, vars, proposal='uniform', order='random', model=None):

        model = pm.modelcontext(model)
        vars = pm.inputvars(vars)

        dimcats = []
        # The above variable is a list of pairs (aggregate dimension, number
        # of categories). For example, if vars = [x, y] with x having 2
        # elements and M categories and y having 3 elements and N
        # categories, we will have
        # dimcats = [(0, M), (1, M), (2, N), (3, N), (4, N)].
        for v in vars:
            distr = getattr(v.distribution, 'parent_dist', v.distribution)
            if isinstance(distr, pm.Categorical):
                k = draw_values([distr.k])[0]
            elif isinstance(distr, pm.Bernoulli) or (v.dtype in pm.bool_types):
                k = 2
            else:
                # The trailing space keeps the two halves from fusing into
                # "binaryfor" in the final message.
                raise ValueError('All variables must be categorical or binary ' +
                                 'for CategoricalGibbsMetropolis')
            start = len(dimcats)
            dimcats += [(dim, k) for dim in range(start, start + v.dsize)]

        if order == 'random':
            self.shuffle_dims = True
            self.dimcats = dimcats
        else:
            if sorted(order) != list(range(len(dimcats))):
                raise ValueError('Argument \'order\' has to be a permutation')
            self.shuffle_dims = False
            self.dimcats = [dimcats[j] for j in order]

        if proposal == 'uniform':
            self.astep = self.astep_unif
        elif proposal == 'proportional':
            # Use the optimized "Metropolized Gibbs Sampler" described in Liu96.
            self.astep = self.astep_prop
        else:
            raise ValueError('Argument \'proposal\' should either be ' +
                             '\'uniform\' or \'proportional\'')

        super(CategoricalGibbsMetropolis, self).__init__(vars, [model.fastlogp])

    def astep_unif(self, q0, logp):
        """Sweep over all dimensions using the uniform proposal.

        Returns a modified copy of ``q0``; ``q0`` itself is left untouched.
        """
        dimcats = self.dimcats
        if self.shuffle_dims:
            # In-place shuffle of the stored list; only the visiting order
            # changes, the (dimension, k) pairs stay intact.
            nr.shuffle(dimcats)

        q = np.copy(q0)
        logp_curr = logp(q)

        for dim, k in dimcats:
            curr_val, q[dim] = q[dim], sample_except(k, q[dim])
            logp_prop = logp(q)
            q[dim] = metrop_select(logp_prop - logp_curr, q[dim], curr_val)
            if q[dim] != curr_val:
                # The proposal was accepted, so its log-probability becomes
                # the reference for the next dimension.
                logp_curr = logp_prop

        return q

    def astep_prop(self, q0, logp):
        """Sweep over all dimensions using the proportional proposal.

        Returns a modified copy of ``q0``; ``q0`` itself is left untouched.
        """
        dimcats = self.dimcats
        if self.shuffle_dims:
            nr.shuffle(dimcats)

        q = np.copy(q0)
        logp_curr = logp(q)

        for dim, k in dimcats:
            logp_curr = self.metropolis_proportional(q, logp, logp_curr, dim, k)

        return q

    def metropolis_proportional(self, q, logp, logp_curr, dim, k):
        """Update ``q[dim]`` in place with one Metropolized-Gibbs move.

        With pi denoting the conditional distribution of ``q[dim]`` over its
        ``k`` categories and x the current category, a category y != x is
        proposed with probability pi(y) / (1 - pi(x)) and accepted with
        probability min(1, (1 - pi(x)) / (1 - pi(y))).

        Returns the log-probability of the state ``q`` is left in.
        """
        given_cat = int(q[dim])
        log_probs = np.zeros(k)
        log_probs[given_cat] = logp_curr
        candidates = list(range(k))
        for candidate_cat in candidates:
            if candidate_cat != given_cat:
                q[dim] = candidate_cat
                log_probs[candidate_cat] = logp(q)

        # Conditional probabilities pi(.) of every category. These must be
        # kept un-renormalised because the acceptance ratio is defined in
        # terms of pi(y), not of the proposal probability.
        cond_probs = softmax(log_probs)

        proposal_probs = cond_probs.copy()
        proposal_probs[given_cat] = 0.0
        mass_elsewhere = proposal_probs.sum()  # equals 1 - pi(x)
        if not mass_elsewhere > 0.0:
            # Every alternative has zero (or NaN) probability: there is
            # nothing to propose, so stay put instead of handing NaNs to
            # nr.choice.
            q[dim] = given_cat
            return logp_curr
        proposal_probs /= mass_elsewhere

        proposed_cat = nr.choice(candidates, p=proposal_probs)

        # 1 - pi(y); a non-positive value means y carries all the mass, in
        # which case the ratio is unbounded and the move is always accepted.
        mass_without_proposed = 1.0 - cond_probs[proposed_cat]
        if (mass_without_proposed > 0.0 and
                nr.uniform() >= mass_elsewhere / mass_without_proposed):
            q[dim] = given_cat
            return logp_curr
        q[dim] = proposed_cat
        return log_probs[proposed_cat]

    @staticmethod
    def competence(var):
        '''
        CategoricalGibbsMetropolis is only suitable for Bernoulli and
        Categorical variables.
        '''
        distribution = getattr(
            var.distribution, 'parent_dist', var.distribution)
        if isinstance(distribution, pm.Categorical):
            if distribution.k > 2:
                return Competence.IDEAL
            return Competence.COMPATIBLE
        elif isinstance(distribution, pm.Bernoulli) or (var.dtype in pm.bool_types):
            return Competence.COMPATIBLE
        return Competence.INCOMPATIBLE
414+
415+
def sample_except(limit, excluded):
    """Draw a random integer from ``range(limit)`` that differs from ``excluded``.

    One of ``limit - 1`` slots is drawn with ``nr.choice`` and every slot at
    or above ``excluded`` is shifted up by one, so ``excluded`` itself is
    skipped.

    Raises
    ------
    ValueError
        If ``limit`` is smaller than 2, i.e. there is no alternative to draw.
    """
    if limit < 2:
        raise ValueError('sample_except needs at least two categories, '
                         'got limit={}'.format(limit))
    candidate = nr.choice(limit - 1)
    if candidate >= excluded:
        candidate += 1
    return candidate
420+
421+
def softmax(x):
    """Return the softmax of ``x`` taken along its first axis.

    The maximum is subtracted before exponentiating so that large inputs do
    not overflow. The shift is computed along the same axis as the
    normalising sum; using the global maximum instead can underflow a whole
    column of a multi-dimensional input to zero and produce 0/0.
    """
    x = np.asarray(x)
    # A 0-d input has no axis 0 to reduce over.
    axis = 0 if x.ndim else None
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / np.sum(e_x, axis=axis, keepdims=True)
286424

287425
def delta_logp(logp, vars, shared):
288426
[logp0], inarray0 = pm.join_nonshared_inputs([logp], vars, shared)

pymc3/tests/models.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from pymc3 import Model, Normal, Metropolis
1+
from pymc3 import Model, Normal, Categorical, Metropolis
22
import numpy as np
33
import pymc3 as pm
44
from itertools import product
@@ -14,6 +14,17 @@ def simple_model():
1414
return model.test_point, model, (mu, tau ** -1)
1515

1616

17+
def simple_categorical():
    """Build a model with one 3-element categorical variable ``x``.

    Returns the model's test point, the model itself, and the
    ``(mean, variance)`` of a single draw computed from the category
    probabilities.
    """
    probs = np.array([0.1, 0.2, 0.3, 0.4])
    support = np.array([0.0, 1.0, 2.0, 3.0])

    with Model() as model:
        Categorical('x', probs, shape=3, testval=[1, 2, 3])

    mean = probs.dot(support)
    variance = probs.dot((support - mean) ** 2)
    return model.test_point, model, (mean, variance)
26+
27+
1728
def multidimensional_model():
1829
mu = -2.1
1930
tau = 1.3

pymc3/tests/test_step.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import unittest
22

33
from .checks import close_to
4-
from .models import mv_simple, mv_simple_discrete, simple_2model
4+
from .models import simple_categorical, mv_simple, mv_simple_discrete, simple_2model
55
from pymc3.sampling import assign_step_methods, sample
66
from pymc3.model import Model
7-
from pymc3.step_methods import (NUTS, BinaryGibbsMetropolis, Metropolis, Constant, Slice,
8-
CompoundStep, MultivariateNormalProposal, HamiltonianMC)
7+
from pymc3.step_methods import (NUTS, BinaryGibbsMetropolis, CategoricalGibbsMetropolis,
8+
Metropolis, Constant, Slice, CompoundStep,
9+
MultivariateNormalProposal, HamiltonianMC)
910
from pymc3.distributions import Binomial, Normal, Bernoulli, Categorical
1011
from numpy.testing import assert_almost_equal
1112
import numpy as np
@@ -52,6 +53,21 @@ def test_step_discrete(self):
5253
trace = sample(20000, step=step, start=start, model=model, random_seed=1)
5354
self.check_stat(check, trace)
5455

56+
def test_step_categorical(self):
    """Sample the simple categorical model with both proposal types and
    check the mean and standard deviation of ``x`` against their expected
    values."""
    start, model, (mu, C) = simple_categorical()
    unc = C ** .5
    tolerance = unc / 10.
    check = (('x', np.mean, mu, tolerance),
             ('x', np.std, unc, tolerance))
    with model:
        # Both step methods are built up front, before any sampling starts.
        steps = tuple(
            CategoricalGibbsMetropolis(model.x, proposal=kind)
            for kind in ('uniform', 'proportional')
        )
    for step in steps:
        trace = sample(8000, step=step, start=start, model=model, random_seed=1)
        self.check_stat(check, trace)
69+
70+
5571
def test_constant_step(self):
5672
with Model():
5773
x = Normal('x', 0, 1)
@@ -95,17 +111,15 @@ def test_normal(self):
95111
self.assertIsInstance(steps, NUTS)
96112

97113
def test_categorical(self):
98-
"""Test categorical distribution is assigned binary gibbs metropolis method"""
114+
"""Test categorical distribution is assigned categorical gibbs metropolis method"""
99115
with Model() as model:
100116
Categorical('x', np.array([0.25, 0.75]))
101117
steps = assign_step_methods(model, [])
102118
self.assertIsInstance(steps, BinaryGibbsMetropolis)
103-
104-
# with Model() as model:
105-
# x = Categorical('x', np.array([0.25, 0.70, 0.05]))
106-
# steps = assign_step_methods(model, [])
107-
#
108-
# assert isinstance(steps, ElemwiseCategoricalStep)
119+
with Model() as model:
120+
Categorical('y', np.array([0.25, 0.70, 0.05]))
121+
steps = assign_step_methods(model, [])
122+
self.assertIsInstance(steps, CategoricalGibbsMetropolis)
109123

110124
def test_binomial(self):
111125
"""Test binomial distribution is assigned metropolis method."""

0 commit comments

Comments
 (0)