Supporting Population Samplers (implemented DE-MCMC) #2735


Merged (60 commits, Dec 5, 2017)
Changes from 50 commits

Commits (60)
b95af1f
catching other error types that may occur when gradients are not avai…
michaelosthege Oct 18, 2017
168c3f5
Merge branch 'master' of https://github.com/pymc-devs/pymc3
michaelosthege Nov 9, 2017
671b387
Merge branch 'master' of https://github.com/pymc-devs/pymc3
michaelosthege Nov 14, 2017
48ef144
Merge branch 'master' of https://github.com/pymc-devs/pymc3
michaelosthege Nov 16, 2017
c0ce6ec
Merge branch 'master' of https://github.com/pymc-devs/pymc3
michaelosthege Nov 21, 2017
fd16ba5
added a benchmark example for correlated dimensions
michaelosthege Nov 21, 2017
8ec0b80
marking Metropolis as COMPATIBLE for all types, added line breaks (co…
michaelosthege Nov 21, 2017
1d04f53
specifying single job to force the sample_many function
michaelosthege Nov 21, 2017
2bd6a5a
modified sampling procedure to iterate chains in parallel (instead o…
michaelosthege Nov 21, 2017
df37cbb
updated description
michaelosthege Nov 21, 2017
ba7d083
indexing samplers by chain number instead of chain id
michaelosthege Nov 22, 2017
a4894bf
print transposes result table
michaelosthege Nov 22, 2017
cd14d6a
created PopulationArrayStepShared base class that allows the individu…
michaelosthege Nov 23, 2017
dbc42bc
modified sampling loop to account for the PopulationArrayStepShared s…
michaelosthege Nov 23, 2017
0ca3804
added the DEMetropolis sampler
michaelosthege Nov 23, 2017
2206b45
raising an error when the population is too small
michaelosthege Nov 23, 2017
6e63f49
verbose debug logging
michaelosthege Nov 23, 2017
9395de3
removed debug print
michaelosthege Nov 23, 2017
6105dec
forcing CompoundStep type
michaelosthege Nov 23, 2017
0494bc0
formatting
michaelosthege Nov 23, 2017
e25ce19
setting DEMetropolis as a blocked step method
michaelosthege Nov 23, 2017
2841442
measuring the runtime, example with both 2D z-variable and two 1D x,y…
michaelosthege Nov 23, 2017
fd6a1ef
changed the initialization order such that variable transforms are ap…
michaelosthege Nov 24, 2017
cf4ee2f
fixed a bug caused by start=None
michaelosthege Nov 24, 2017
6733ae1
fixes a bug in computing lambda
michaelosthege Nov 27, 2017
368f4aa
using a Uniform proposal with low initial scale
michaelosthege Nov 27, 2017
abcb12f
renamed local variable
michaelosthege Nov 27, 2017
acf1311
logging the crossover and scaling
michaelosthege Nov 28, 2017
3b6a2d9
fixed a bug that caused step methods to not be copied
michaelosthege Nov 28, 2017
27e8263
smarter multiprocessing
michaelosthege Nov 28, 2017
d83c5f4
automatic multiprocessing decision, reporting relative sampling rates
michaelosthege Nov 28, 2017
fc4e1d0
print format
michaelosthege Nov 28, 2017
569731b
inheriting PopulationArraySharedStep from ArrayStepShared, using a bi…
michaelosthege Nov 28, 2017
8de3977
printing the number of effective samples per variable
michaelosthege Nov 28, 2017
67e07a4
docstrings and comments
michaelosthege Nov 28, 2017
5769b9d
falling back to sequential sampling if no population samplers are used
michaelosthege Nov 28, 2017
5f6c29c
removed debugging stats logging
michaelosthege Nov 28, 2017
b53b510
fixed nested if else
michaelosthege Nov 28, 2017
4af2aa5
updated print statement
michaelosthege Nov 28, 2017
e5f7ff2
fixed a bug related to bijection updating
michaelosthege Nov 29, 2017
874e6b2
docstring and comments
michaelosthege Nov 29, 2017
eba4a9d
refactoring for better clarity and less diff
michaelosthege Nov 29, 2017
f71a37c
code style
michaelosthege Nov 29, 2017
63ae017
removed unused import
michaelosthege Nov 29, 2017
e20295a
fixed a bug where Slice was preferred on multidimensional variables
michaelosthege Nov 29, 2017
acc7538
printing the stepper hierarchy, fixed a variable name, handling non-C…
michaelosthege Nov 29, 2017
c3c233c
fixed a bug where DEMetropolis assigned itself to discrete vars, fixe…
michaelosthege Nov 29, 2017
01adad8
improved code style, including Slice in comparison
michaelosthege Nov 29, 2017
44b1115
including DEMetropolis in existing tests, added test case for Populat…
michaelosthege Nov 30, 2017
9a07a43
fixes python 2.7 compatibility
michaelosthege Nov 30, 2017
e2cfbbb
Using multiprocessing to parallelize iteration of chain populations (…
michaelosthege Nov 30, 2017
30f437d
added references
michaelosthege Dec 1, 2017
4998005
added a warning that DEMetropolis is experimental
michaelosthege Dec 1, 2017
107a618
forgotten space
michaelosthege Dec 1, 2017
b94906a
modified the PopulationStepper to automatically use parallelization o…
michaelosthege Dec 1, 2017
58fa336
avoiding a reimport, disabled chain parallelization by default
michaelosthege Dec 1, 2017
3bcf0fd
increased nchains
michaelosthege Dec 1, 2017
c919742
resolving conflicts
michaelosthege Dec 4, 2017
24b4d63
resolving conflicts
michaelosthege Dec 4, 2017
7db836a
included DEMetropolis in new features
michaelosthege Dec 5, 2017
100 changes: 100 additions & 0 deletions pymc3/examples/samplers_mvnormal.py
@@ -0,0 +1,100 @@
"""
Comparing different samplers on a correlated bivariate normal distribution.

This example will sample a bivariate normal with Metropolis, NUTS and DEMetropolis
at two correlations (0, 0.9) and print out the effective sample sizes, runtime and
normalized effective sampling rates.
"""


import numpy as np
import time
import pandas as pd
import pymc3 as pm
import theano.tensor as tt

# With this flag one can switch between defining the bivariate normal as
# either a 2D MvNormal (USE_XY = False) or as two separate 1D variables
# 'x' and 'y' (USE_XY = True). The latter is recommended because it
# highlights different behaviour with respect to blocking.
USE_XY = True

def run(steppers, p):
    steppers = set(steppers)
    traces = {}
    effn = {}
    runtimes = {}

    with pm.Model() as model:
        if USE_XY:
            x = pm.Flat('x')
            y = pm.Flat('y')
            mu = np.array([0.,0.])
            cov = np.array([[1.,p],[p,1.]])
            z = pm.MvNormal.dist(mu=mu, cov=cov, shape=(2,)).logp(tt.stack([x,y]))
            pot = pm.Potential('logp_xy', z)
            start = {'x': 0, 'y': 0}
        else:
            mu = np.array([0.,0.])
            cov = np.array([[1.,p],[p,1.]])
            z = pm.MvNormal('z', mu=mu, cov=cov, shape=(2,))
            start = {'z': [0, 0]}

        for step_cls in steppers:
            name = step_cls.__name__
            t_start = time.time()
            mt = pm.sample(
                draws=10000,
                chains=6,
                step=step_cls(),
                start=start
            )
            runtimes[name] = time.time() - t_start
            print('{} samples across {} chains'.format(len(mt) * mt.nchains, mt.nchains))
            traces[name] = mt
            en = pm.diagnostics.effective_n(mt)
            print('effective: {}\r\n'.format(en))
            if USE_XY:
                effn[name] = np.mean(en['x']) / len(mt) / mt.nchains
            else:
                effn[name] = np.mean(en['z']) / len(mt) / mt.nchains
    return traces, effn, runtimes


if __name__ == '__main__':
    methods = [
        pm.Metropolis,
        pm.Slice,
        pm.NUTS,
        pm.DEMetropolis
    ]
    names = [c.__name__ for c in methods]

    df_base = pd.DataFrame(columns=['p'] + names)
    df_base['p'] = [.0,.9]
    df_base = df_base.set_index('p')

    df_effectiven = df_base.copy()
    df_runtime = df_base.copy()
    df_performance = df_base.copy()

    for p in df_effectiven.index:
        trace, rate, runtime = run(methods, p)
        for name in names:
            df_effectiven.set_value(p, name, rate[name])
            df_runtime.set_value(p, name, runtime[name])
            df_performance.set_value(p, name, rate[name] / runtime[name])

    print('\r\nEffective sample size [0...1]')
    print(df_effectiven.T.to_string(float_format='{:.3f}'.format))

    print('\r\nRuntime [s]')
    print(df_runtime.T.to_string(float_format='{:.1f}'.format))

    if 'NUTS' in names:
        print('\r\nNormalized effective sampling rate [0...1]')
        df_performance = df_performance.T / df_performance.loc[0]['NUTS']
    else:
        print('\r\nNormalized effective sampling rate [1/s]')
        df_performance = df_performance.T
    print(df_performance.to_string(float_format='{:.3f}'.format))
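For readers who want to try the new stepper outside of this benchmark script, a minimal usage sketch (assuming the pm.DEMetropolis stepper and the chains argument used above, with otherwise default sample() settings) could look like this:

import numpy as np
import pymc3 as pm

# Sketch: sample a strongly correlated 2D normal with the new DEMetropolis stepper.
# DE-MCMC proposes jumps from the difference of two other chains, so it needs a
# population of several chains; here 10 chains are used for a 2D problem.
mu = np.zeros(2)
cov = np.array([[1.0, 0.9], [0.9, 1.0]])

with pm.Model():
    pm.MvNormal('z', mu=mu, cov=cov, shape=(2,))
    trace = pm.sample(draws=5000, chains=10, step=pm.DEMetropolis())

print(pm.diagnostics.effective_n(trace))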
150 changes: 145 additions & 5 deletions pymc3/sampling.py
@@ -1,4 +1,5 @@
from collections import defaultdict, Iterable
from copy import copy
import pickle

from joblib import Parallel, delayed
@@ -10,6 +11,7 @@
from .backends.base import BaseTrace, MultiTrace
from .backends.ndarray import NDArray
from .model import modelcontext, Point
from .step_methods import arraystep
from .step_methods import (NUTS, HamiltonianMC, SGFS, Metropolis, BinaryMetropolis,
                           BinaryGibbsMetropolis, CategoricalGibbsMetropolis,
                           Slice, CompoundStep)
@@ -143,6 +145,19 @@ def assign_step_methods(model, step=None, methods=STEP_METHODS,
    return instantiate_steppers(model, steps, selected_steps, step_kwargs)


def print_step_hierarchy(s, level=0):
    if isinstance(s, (list, tuple)):
        pm._log.info('>' * level + 'list')
        for i in s:
            print_step_hierarchy(i, level+1)
    elif isinstance(s, CompoundStep):
        pm._log.info('>' * level + 'CompoundStep')
        for i in s.methods:
            print_step_hierarchy(i, level+1)
    else:
        pm._log.info('>' * level + '{}: {}'.format(s.__class__.__name__, s.vars))


def _cpu_count():
"""Try to guess the number of CPUs in the system.

@@ -357,8 +372,10 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None,
    else:
        step = assign_step_methods(model, step, step_kwargs=step_kwargs)

    if isinstance(step, list):
        step = CompoundStep(step)
    if start is None:
        start = [None] * chains
        start = {}
    if isinstance(start, dict):
        start = [start] * chains

@@ -380,23 +397,36 @@

    sample_args.update(kwargs)

    parallel = njobs > 1 and chains > 1
    has_population_samplers = np.any([
        isinstance(m, arraystep.PopulationArrayStepShared)
        for m in (step.methods if isinstance(step, CompoundStep) else [step])
    ])
    parallel = njobs > 1 and chains > 1 and not has_population_samplers
    if parallel:
        pm._log.info('Multiprocess sampling ({} chains in {} jobs)'.format(chains, njobs))
        print_step_hierarchy(step)
        try:
            trace = _mp_sample(**sample_args)
        except pickle.PickleError:
            pm._log.warn("Could not pickle model, sampling sequentially.")
            pm._log.warn("Could not pickle model, sampling singlethreaded.")
            pm._log.debug('Pickling error:', exec_info=True)
            parallel = False
        except AttributeError as e:
            if str(e).startswith("AttributeError: Can't pickle"):
                pm._log.warn("Could not pickle model, sampling sequentially.")
                pm._log.warn("Could not pickle model, sampling singlethreaded.")
                pm._log.debug('Pickling error:', exec_info=True)
                parallel = False
            else:
                raise
    if not parallel:
        trace = _sample_many(**sample_args)
        if has_population_samplers:
            pm._log.info('Population sampling ({} chains in 1 job)'.format(chains))
            print_step_hierarchy(step)
            trace = _sample_population(**sample_args)
        else:
            pm._log.info('Sequential sampling ({} chains in 1 job)'.format(chains))
            print_step_hierarchy(step)
            trace = _sample_many(**sample_args)

    discard = tune if discard_tuned_samples else 0
    return trace[discard:]
@@ -448,6 +478,23 @@ def _sample_many(draws, chain, chains, start, random_seed, **kwargs):
    return MultiTrace(traces)


def _sample_population(draws, chain, chains, start, random_seed, step, tune,
                       model, progressbar=None, **kwargs):
    # create the generator that iterates all chains in parallel
    chains = [chain + c for c in range(chains)]
    sampling = _iter_chains(draws, chains, step, start, tune=tune,
                            model=model, random_seed=random_seed)

    if progressbar:
        sampling = tqdm(sampling, total=draws)

    latest_traces = None
    for it,traces in enumerate(sampling):
        latest_traces = traces
        # TODO: add support for liveplot during population-sampling
    return MultiTrace(latest_traces)


def _sample(chain, progressbar, random_seed, start, draws=None, step=None,
            trace=None, tune=None, model=None, live_plot=False,
            live_plot_kwargs=None, **kwargs):
@@ -580,6 +627,99 @@ def _iter_sample(draws, step, start=None, trace=None, chain=0, tune=None,
            step.report._finalize(strace)


def _iter_chains(draws, chains, step, start, tune=None,
                 model=None, random_seed=None):
    # chains contains the chain numbers, but for indexing we need indices...
    nchains = len(chains)
    model = modelcontext(model)
    draws = int(draws)
    if random_seed is not None:
        np.random.seed(random_seed)
    if draws < 1:
        raise ValueError('Argument `draws` should be above 0.')

    # The initialization of traces, samplers and points must happen in the right order:
    # 1. traces are initialized and update_start_vals configures variable transforms
    # 2. population of points is created
    # 3. steppers are initialized and linked to the points object
    # 4. traces are configured to track the sampler stats


    # 1. prepare a BaseTrace for each chain
    traces = [_choose_backend(None, chain, model=model) for chain in chains]
    for c,strace in enumerate(traces):
        # initialize the trace size and variable transforms
        if len(strace) > 0:
            update_start_vals(start[c], strace.point(-1), model)
        else:
            update_start_vals(start[c], model.test_point, model)

    # 2. create a population (points) that tracks each chain
    # it is updated as the chains are advanced
    points = [Point(start[c], model=model) for c in range(nchains)]
    updates = [None] * nchains

    # 3. Set up the steppers
    steppers = [None] * nchains
    for c in range(nchains):
        # need independent samplers for each chain
        # it is important to copy the actual steppers (but not the delta_logp)
        if isinstance(step, CompoundStep):
            chainstep = CompoundStep([copy(m) for m in step.methods])
        else:
            chainstep = copy(step)
        # link population samplers to the shared population state
        for sm in (chainstep.methods if isinstance(step, CompoundStep) else [chainstep]):
            if isinstance(sm, arraystep.PopulationArrayStepShared):
                sm.link_population(points, c)
        steppers[c] = chainstep

    # 4. configure tracking of sampler stats
    for c in range(nchains):
        if steppers[c].generates_stats and traces[c].supports_sampler_stats:
            traces[c].setup(draws, c, steppers[c].stats_dtypes)
        else:
            traces[c].setup(draws, c)

    try:
        # iterate draws of all chains
        for i in range(draws):
            # step each of the chains
            for c in range(nchains):
                if i == tune:
                    steppers[c] = stop_tuning(steppers[c])
                updates[c] = steppers[c].step(points[c])

            # apply the update to the points and record to the traces
            for c,strace in enumerate(traces):
                if steppers[c].generates_stats:
                    points[c], states = updates[c]
                    if strace.supports_sampler_stats:
                        strace.record(points[c], states)
                    else:
                        strace.record(points[c])
                else:
                    points[c] = updates[c]
                    strace.record(points[c])
            # yield the state of all chains in parallel
            yield traces
    except KeyboardInterrupt:
        for c,strace in enumerate(traces):
            strace.close()
        if hasattr(step, 'report'):
            step.report._finalize(strace)
        raise
    except BaseException:
        for c,strace in enumerate(traces):
            strace.close()
        raise
    else:
        for c,strace in enumerate(traces):
            strace.close()
        if hasattr(step, 'report'):
            step.report._finalize(strace)


def _choose_backend(trace, chain, shortcuts=None, **kwds):
    if isinstance(trace, BaseTrace):
        return trace
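The essential pattern in _iter_chains above is that all chains advance in lock-step, one draw at a time, and the shared points are only written back after every chain has stepped, so a population sampler always proposes from the other chains' current states rather than stale ones. A stripped-down, self-contained sketch of that loop (toy stepper and plain lists standing in for the real steppers and traces; none of these names are part of the PyMC3 API):

import numpy as np

def iter_population(draws, steppers, points):
    """Toy version of the population loop in _iter_chains."""
    traces = [[] for _ in steppers]
    for _ in range(draws):
        # step every chain once before writing back, mirroring how
        # _iter_chains collects `updates` first and applies them afterwards
        updates = [stepper(points, c) for c, stepper in enumerate(steppers)]
        for c, new_point in enumerate(updates):
            points[c] = new_point
            traces[c].append(new_point)
    return traces

def toy_de_step(points, c, gamma=0.5, epsilon=1e-3):
    # difference-of-two-other-chains proposal, accepted unconditionally here;
    # a real DE-MC step would add a Metropolis accept/reject on the posterior
    others = [i for i in range(len(points)) if i != c]
    r1, r2 = np.random.choice(others, 2, replace=False)
    return points[c] + gamma * (points[r1] - points[r2]) + epsilon * np.random.randn()

population = [np.random.randn() for _ in range(10)]
chain_histories = iter_population(200, [toy_de_step] * 10, population)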
1 change: 1 addition & 0 deletions pymc3/step_methods/__init__.py
@@ -3,6 +3,7 @@
from .hmc import HamiltonianMC, NUTS

from .metropolis import Metropolis
from .metropolis import DEMetropolis
from .metropolis import BinaryMetropolis
from .metropolis import BinaryGibbsMetropolis
from .metropolis import CategoricalGibbsMetropolis
49 changes: 44 additions & 5 deletions pymc3/step_methods/arraystep.py
@@ -150,19 +150,58 @@ def __init__(self, vars, shared, blocked=True):
        self.ordering = ArrayOrdering(vars)
        self.shared = {str(var): shared for var, shared in shared.items()}
        self.blocked = blocked
        self.bij = None

    def step(self, point):
        for var, share in self.shared.items():
            share.set_value(point[var])

        bij = DictToArrayBijection(self.ordering, point)
        self.bij = DictToArrayBijection(self.ordering, point)

        if self.generates_stats:
            apoint, stats = self.astep(bij.map(point))
            return bij.rmap(apoint), stats
            apoint, stats = self.astep(self.bij.map(point))
            return self.bij.rmap(apoint), stats
        else:
            apoint = self.astep(bij.map(point))
            return bij.rmap(apoint)
            apoint = self.astep(self.bij.map(point))
            return self.bij.rmap(apoint)


class PopulationArrayStepShared(ArrayStepShared):
    """Version of ArrayStepShared that allows samplers to access the states
    of other chains in the population.

    Works by linking a list of Points that is updated as the chains are iterated.
    """

    def __init__(self, vars, shared, blocked=True):
        """
        Parameters
        ----------
        vars : list of sampling variables
        shared : dict of theano variable -> shared variable
        blocked : Boolean (default True)
        """
        self.population = None
        self.this_chain = None
        self.other_chains = None
        return super(PopulationArrayStepShared, self).__init__(vars, shared, blocked)

    def link_population(self, population, chain_index):
        """Links the sampler to the population.

        Parameters
        ----------
        population : list of Points. (The elements of this list must be
            replaced with current chain states in every iteration.)
        chain_index : int of the index of this sampler in the population
        """
        self.population = population
        self.this_chain = chain_index
        self.other_chains = [c for c in range(len(population)) if c != chain_index]
        if not len(self.other_chains) > 1:
            raise ValueError('Population is just {} + {}. This is too small. You should ' \
                             'increase the number of chains.'.format(self.this_chain, self.other_chains))
        return


class GradientSharedStep(BlockedStep):
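To show how a concrete sampler can build on PopulationArrayStepShared, here is an illustrative sketch of a differential-evolution style astep. This is not the DEMetropolis implementation added by this PR (that lives in pymc3/step_methods/metropolis.py and is outside this diff excerpt); it only demonstrates how the population, this_chain and other_chains attributes set by link_population, together with the self.bij bijection stored by the modified step() above, would typically be used:

import numpy as np
from pymc3.step_methods.arraystep import PopulationArrayStepShared

class ToyDEStep(PopulationArrayStepShared):
    """Illustrative only: jumps along the difference of two other chains.
    The usual vars/shared/delta_logp setup is omitted for brevity."""

    lamb = 0.9      # DE scaling factor; 2.38 / sqrt(2 d) is the classic DE-MC default
    epsilon = 1e-4  # small jitter so the proposal distribution is not degenerate

    def astep(self, q0):
        # pick two distinct other chains and read their *current* points
        # through the shared population list maintained during sampling
        r1, r2 = np.random.choice(self.other_chains, 2, replace=False)
        q1 = self.bij.map(self.population[r1])
        q2 = self.bij.map(self.population[r2])
        q = q0 + self.lamb * (q1 - q2) + self.epsilon * np.random.randn(len(q0))
        # a real sampler would evaluate the log-posterior difference and apply a
        # Metropolis accept/reject here; this sketch returns the raw proposal
        return q

The actual DEMetropolis additionally takes care of tuning, sampler statistics and the Metropolis accept/reject step.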