Make Metropolis cope better with multiple dimensions #5823
`pymc/step_methods/metropolis.py`
@@ -168,10 +168,12 @@ def __init__(
             vars = model.value_vars
         else:
             vars = [model.rvs_to_values.get(var, var) for var in vars]

         vars = pm.inputvars(vars)

+        initial_values_shape = [initial_values[v.name].shape for v in vars]
         if S is None:
-            S = np.ones(sum(initial_values[v.name].size for v in vars))
+            S = np.ones(int(sum(np.prod(ivs) for ivs in initial_values_shape)))

         if proposal_dist is not None:
             self.proposal_dist = proposal_dist(S)
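For intuition, the new default for `S` sums `np.prod` over the value-variable shapes, which is exactly the total raveled size. A minimal standalone sketch (the shapes are invented for illustration):

```python
import numpy as np

# Hypothetical value-variable shapes: a scalar, a length-3 vector, a 2x2 matrix
initial_values_shape = [(), (3,), (2, 2)]

# Same computation as the new default above; note np.prod(()) == 1.0,
# so a scalar still contributes one dimension
n = int(sum(np.prod(ivs) for ivs in initial_values_shape))
assert n == 1 + 3 + 4

S = np.ones(n)  # one proposal scale entry per raveled dimension
```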

@@ -186,7 +188,6 @@ def __init__(
         self.tune = tune
         self.tune_interval = tune_interval
         self.steps_until_tune = tune_interval
-        self.accepted = 0

         # Determine type of variables
         self.discrete = np.concatenate(
@@ -195,11 +196,33 @@ def __init__(
         self.any_discrete = self.discrete.any()
         self.all_discrete = self.discrete.all()

-        # remember initial settings before tuning so they can be reset
-        self._untuned_settings = dict(
-            scaling=self.scaling, steps_until_tune=tune_interval, accepted=self.accepted
-        )
+        # Metropolis will try to handle one batched dimension at a time. This, however,
+        # is not safe for discrete multivariate distributions (looking at you Multinomial),
+        # due to high dependency among the support dimensions. For continuous multivariate
+        # distributions we assume they are being transformed in a way that makes each
+        # dimension semi-independent.
+        is_scalar = len(initial_values_shape) == 1 and initial_values_shape[0] == ()
+        self.elemwise_update = not (
+            is_scalar
+            or (
+                self.any_discrete
+                and max(getattr(model.values_to_rvs[var].owner.op, "ndim_supp", 1) for var in vars)
+                > 0
+            )
+        )
+        if self.elemwise_update:
+            dims = int(sum(np.prod(ivs) for ivs in initial_values_shape))
Review thread on lines +213 to +214:
- **Reviewer:** Right now (reading the diff top to bottom) I'm confused because this smells like `CompoundStep`.
- **Author:** See my comment below.
+        else:
+            dims = 1
+        self.enum_dims = np.arange(dims, dtype=int)
+        self.accept_rate_iter = np.zeros(dims, dtype=float)
+        self.accepted_iter = np.zeros(dims, dtype=bool)
+        self.accepted_sum = np.zeros(dims, dtype=int)

+        # remember initial settings before tuning so they can be reset
+        self._untuned_settings = dict(scaling=self.scaling, steps_until_tune=tune_interval)

         # TODO: This is not being used when compiling the logp function!
         self.mode = mode

         shared = pm.make_shared_replacements(initial_values, vars, model)
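For orientation, here is a hypothetical replay of the gating rule on the cases exercised by the new tests further down (the `ndim_supp` values are the usual ones for these Ops: 0 for Binomial, 1 for Multinomial and Dirichlet):

```python
# (is_scalar, any_discrete, max_ndim_supp) -> elemwise_update
cases = {
    "Binomial, shape ()": (True, True, 0),
    "Binomial, shape (40,)": (False, True, 0),
    "Multinomial, shape (10, 4)": (False, True, 1),
    "Dirichlet, shape (40, 3)": (False, False, 1),
}
for name, (is_scalar, any_discrete, max_ndim_supp) in cases.items():
    # same boolean expression as in __init__ above
    elemwise_update = not (is_scalar or (any_discrete and max_ndim_supp > 0))
    print(f"{name}: elemwise_update={elemwise_update}")
# Only the batched Binomial and the (continuous) Dirichlet get element-wise updates
```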

@@ -210,6 +233,7 @@ def reset_tuning(self):
         """Resets the tuned sampler parameters to their initial values."""
         for attr, initial_value in self._untuned_settings.items():
             setattr(self, attr, initial_value)
+        self.accepted_sum[:] = 0
         return

     def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
@@ -219,10 +243,10 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:

         if not self.steps_until_tune and self.tune:
             # Tune scaling parameter
-            self.scaling = tune(self.scaling, self.accepted / float(self.tune_interval))
+            self.scaling = tune(self.scaling, self.accepted_sum / float(self.tune_interval))
             # Reset counter
             self.steps_until_tune = self.tune_interval
-            self.accepted = 0
+            self.accepted_sum[:] = 0

         delta = self.proposal_dist() * self.scaling
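Since `accepted_sum` is now a vector, `tune` receives one acceptance rate per raveled dimension and `self.scaling` becomes per-dimension after the first tuning window. A small sketch (only two rungs of the tuning ladder are inlined here; the full vectorized `tune` appears in a later hunk):

```python
import numpy as np

tune_interval = 100
accepted_sum = np.array([0, 30, 99])           # accepts per raveled dimension
acc_rate = accepted_sum / float(tune_interval)

# 0.00 < 0.001 -> x0.1;  0.30 falls in the "do not change" band -> x1.0;  0.99 > 0.95 -> x10.0
scaling = 1.0 * np.where(acc_rate < 0.001, 0.1, np.where(acc_rate > 0.95, 10.0, 1.0))
assert np.allclose(scaling, [0.1, 1.0, 10.0])
```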
@@ -237,23 +261,36 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
         else:
             q = floatX(q0 + delta)

-        accept = self.delta_logp(q, q0)
-        q_new, accepted = metrop_select(accept, q, q0)
-
-        self.accepted += accepted
+        if self.elemwise_update:
+            q_temp = q0.copy()
+            # Shuffle order of updates (probably we don't need to do this in every step)
+            np.random.shuffle(self.enum_dims)
+            for i in self.enum_dims:
+                q_temp[i] = q[i]
+                accept_rate_i = self.delta_logp(q_temp, q0)
+                q_temp_, accepted_i = metrop_select(accept_rate_i, q_temp, q0)
+                q_temp[i] = q_temp_[i]
+                self.accept_rate_iter[i] = accept_rate_i
+                self.accepted_iter[i] = accepted_i
+                self.accepted_sum[i] += accepted_i
+            q = q_temp
Review thread on lines +268 to +276:
- **Reviewer:** Is this doing what we usually do with `CompoundStep`? If not, maybe explain what the if/else blocks do in a code comment.
- **Author:** `CompoundStep` assigns one variable per step; here we are sampling one dimension within a variable (or within multiple variables) semi-independently.
+        else:
+            accept_rate = self.delta_logp(q, q0)
+            q, accepted = metrop_select(accept_rate, q, q0)
+            self.accept_rate_iter = accept_rate
+            self.accepted_iter = accepted
+            self.accepted_sum += accepted

         self.steps_until_tune -= 1

         stats = {
             "tune": self.tune,
-            "scaling": self.scaling,
-            "accept": np.exp(accept),
-            "accepted": accepted,
+            "scaling": np.mean(self.scaling),
+            "accept": np.mean(np.exp(self.accept_rate_iter)),
+            "accepted": np.mean(self.accepted_iter),
         }

-        q_new = RaveledVars(q_new, point_map_info)
-
-        return q_new, [stats]
+        return RaveledVars(q, point_map_info), [stats]

     @staticmethod
     def competence(var, has_grad):
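To make the control flow of the new branch concrete, here is a stripped-down, self-contained sketch that mirrors it: propose jointly, then accept or reject one raveled dimension at a time in shuffled order, always comparing against the start point `q0`. The target and all names are illustrative stand-ins, and `metrop_select` is inlined as a log-uniform comparison:

```python
import numpy as np

rng = np.random.default_rng(42)

def delta_logp(q_new, q_old):
    # Toy stand-in for the compiled logp difference (independent standard normals)
    return 0.5 * np.sum(q_old**2 - q_new**2)

def elemwise_metropolis(q0, scaling=1.0):
    q = q0 + rng.normal(size=q0.shape) * scaling  # joint proposal, applied dimension-wise
    q_temp = q0.copy()
    accepted = np.zeros(q0.size, dtype=bool)
    for i in rng.permutation(q0.size):            # shuffled update order, as in astep
        q_temp[i] = q[i]                          # switch on dimension i of the proposal
        if np.log(rng.uniform()) < delta_logp(q_temp, q0):
            accepted[i] = True                    # keep the new value of dimension i
        else:
            q_temp[i] = q0[i]                     # reject: restore dimension i
    return q_temp, accepted

q_new, acc = elemwise_metropolis(np.zeros(5))
```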

@@ -275,26 +312,38 @@ def tune(scale, acc_rate):
     >0.95         x 10

     """
-    if acc_rate < 0.001:
-        # reduce by 90 percent
-        return scale * 0.1
-    elif acc_rate < 0.05:
-        # reduce by 50 percent
-        return scale * 0.5
-    elif acc_rate < 0.2:
-        # reduce by ten percent
-        return scale * 0.9
-    elif acc_rate > 0.95:
-        # increase by factor of ten
-        return scale * 10.0
-    elif acc_rate > 0.75:
-        # increase by double
-        return scale * 2.0
-    elif acc_rate > 0.5:
-        # increase by ten percent
-        return scale * 1.1
-
-    return scale
+    return scale * np.where(
+        acc_rate < 0.001,
+        # reduce by 90 percent
+        0.1,
+        np.where(
+            acc_rate < 0.05,
+            # reduce by 50 percent
+            0.5,
+            np.where(
+                acc_rate < 0.2,
+                # reduce by ten percent
+                0.9,
+                np.where(
+                    acc_rate > 0.95,
+                    # increase by factor of ten
+                    10.0,
+                    np.where(
+                        acc_rate > 0.75,
+                        # increase by double
+                        2.0,
+                        np.where(
+                            acc_rate > 0.5,
+                            # increase by ten percent
+                            1.1,
+                            # Do not change
+                            1.0,
+                        ),
+                    ),
+                ),
+            ),
+        ),
+    )

Review thread:
- **Reviewer:** Ugh, can't we do this in a …
- **Author:** We can probably index, that should be faster for large arrays because all branches of `np.where` are evaluated.
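For reference, the flat alternative the thread is pointing at could look like this (a hypothetical sketch with `np.select`, which applies the first matching condition and so preserves the if/elif ordering; this is not what the PR merged):

```python
import numpy as np

def tune(scale, acc_rate):
    acc_rate = np.asarray(acc_rate)
    conditions = [
        acc_rate < 0.001,  # reduce by 90 percent
        acc_rate < 0.05,   # reduce by 50 percent
        acc_rate < 0.2,    # reduce by ten percent
        acc_rate > 0.95,   # increase by factor of ten
        acc_rate > 0.75,   # increase by double
        acc_rate > 0.5,    # increase by ten percent
    ]
    factors = [0.1, 0.5, 0.9, 10.0, 2.0, 1.1]
    return scale * np.select(conditions, factors, default=1.0)  # default: do not change
```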

 class BinaryMetropolis(ArrayStep):
`pymc/tests/test_step.py`
@@ -35,7 +35,9 @@
     Beta,
     Binomial,
     Categorical,
+    Dirichlet,
     HalfNormal,
+    Multinomial,
     MvNormal,
     Normal,
 )

@@ -174,33 +176,6 @@ def test_step_categorical(self, proposal):
         self.check_stat(check, idata, step.__class__.__name__)


-class TestMetropolisProposal:
-    def test_proposal_choice(self):
-        with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
-            _, model, _ = mv_simple()
-        with model:
-            initial_point = model.initial_point()
-            initial_point_size = sum(initial_point[n.name].size for n in model.value_vars)
-
-            s = np.ones(initial_point_size)
-            sampler = Metropolis(S=s)
-            assert isinstance(sampler.proposal_dist, NormalProposal)
-            s = np.diag(s)
-            sampler = Metropolis(S=s)
-            assert isinstance(sampler.proposal_dist, MultivariateNormalProposal)
-            s[0, 0] = -s[0, 0]
-            with pytest.raises(np.linalg.LinAlgError):
-                sampler = Metropolis(S=s)
-
-    def test_mv_proposal(self):
-        np.random.seed(42)
-        cov = np.random.randn(5, 5)
-        cov = cov.dot(cov.T)
-        prop = MultivariateNormalProposal(cov)
-        samples = np.array([prop() for _ in range(10000)])
-        npt.assert_allclose(np.cov(samples.T), cov, rtol=0.2)


 class TestCompoundStep:
     samplers = (Metropolis, Slice, HamiltonianMC, NUTS, DEMetropolis)

@@ -383,6 +358,31 @@ def test_parallelized_chains_are_random(self):


+class TestMetropolis:
+    def test_proposal_choice(self):
+        with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
+            _, model, _ = mv_simple()
+        with model:
+            initial_point = model.initial_point()
+            initial_point_size = sum(initial_point[n.name].size for n in model.value_vars)
+
+            s = np.ones(initial_point_size)
+            sampler = Metropolis(S=s)
+            assert isinstance(sampler.proposal_dist, NormalProposal)
+            s = np.diag(s)
+            sampler = Metropolis(S=s)
+            assert isinstance(sampler.proposal_dist, MultivariateNormalProposal)
+            s[0, 0] = -s[0, 0]
+            with pytest.raises(np.linalg.LinAlgError):
+                sampler = Metropolis(S=s)
+
+    def test_mv_proposal(self):
+        np.random.seed(42)
+        cov = np.random.randn(5, 5)
+        cov = cov.dot(cov.T)
+        prop = MultivariateNormalProposal(cov)
+        samples = np.array([prop() for _ in range(10000)])
+        npt.assert_allclose(np.cov(samples.T), cov, rtol=0.2)

     def test_tuning_reset(self):
         """Re-use of the step method instance with cores=1 must not leak tuning information between chains."""
         with Model() as pmodel:

@@ -403,6 +403,40 @@ def test_tuning_reset(self):
             assert tuned != 0.1
             np.testing.assert_array_equal(idata.sample_stats["scaling"].sel(chain=c).values, tuned)

+    @pytest.mark.parametrize(
+        "batched_dist",
+        (
+            Binomial.dist(n=5, p=0.9),  # scalar case
+            Binomial.dist(n=np.arange(40) + 1, p=np.linspace(0.1, 0.9, 40), shape=(40,)),
+            Binomial.dist(
+                n=(np.arange(20) + 1)[::-1],
+                p=np.linspace(0.1, 0.9, 20),
+                shape=(
+                    2,
+                    20,
+                ),
+            ),
+            Dirichlet.dist(a=np.ones(3) * (np.arange(40) + 1)[:, None], shape=(40, 3)),
+            Dirichlet.dist(a=np.ones(3) * (np.arange(20) + 1)[:, None], shape=(2, 20, 3)),
+        ),
+    )
+    def test_elemwise_update(self, batched_dist):
+        with Model() as m:
+            m.register_rv(batched_dist, name="batched_dist")
+            step = pm.Metropolis([batched_dist])
+            assert step.elemwise_update == (batched_dist.ndim > 0)
+            trace = pm.sample(draws=1000, chains=2, step=step, random_seed=428)
+
+        assert az.rhat(trace).max()["batched_dist"].values < 1.1
+        assert az.ess(trace).min()["batched_dist"].values > 50
+
+    def test_multinomial_no_elemwise_update(self):
+        with Model() as m:
+            batched_dist = Multinomial("batched_dist", n=5, p=np.ones(4) / 4, shape=(10, 4))
+            with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
+                step = pm.Metropolis([batched_dist])
+                assert not step.elemwise_update


 class TestDEMetropolisZ:
     def test_tuning_lambda_sequential(self):

Review thread on `test_elemwise_update`:
- **ricardoV94:** I ran these new tests 10 times each, with the current new algorithm and the one proposed by @aloctavodia that we use in SMC. Things look sensible in both cases (10 runs each with different seeds). Tuning method based on switch statement: … Tuning method based on distance from 0.234: …
- **ricardoV94:** But the MLDA complains a lot about the SMC-like tuning... especially this test fails even if I increase the number of draws: https://github.com/ricardoV94/pymc/blob/3a84db6b91035a7c7f68633bdbb18c5f11efd46f/pymc/tests/test_step.py#L1184 Maybe I am just misunderstanding what that test is supposed to check.
- **michaelosthege:** Looking at the standard deviations I'm not that impressed tbh. Or do you have another application that motivated this?
- **ricardoV94:** @michaelosthege this is not the change from now vs before the PR, it's the change between this PR and an alternative procedure that @aloctavodia suggested during this PR. Before this PR there were zero accepted transitions in all the new tests (except the first scalar case). I updated the headings to make it obvious.
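As a usage illustration of what the new flag means in practice (a hypothetical model, not part of the test suite):

```python
import numpy as np
import pymc as pm

with pm.Model():
    # 40 batched binomial counts: each raveled dimension gets its own
    # acceptance statistics and, after tuning, its own proposal scaling
    x = pm.Binomial("x", n=np.arange(40) + 1, p=np.linspace(0.1, 0.9, 40), shape=(40,))
    step = pm.Metropolis([x])
    print(step.elemwise_update)  # True: batched, and not a discrete multivariate
    idata = pm.sample(draws=500, chains=2, step=step)
```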

@@ -1217,8 +1251,6 @@ def perform(self, node, inputs, outputs):
         mout = []
         coarse_models = []

-        rng = np.random.RandomState(seed)
-
         with Model() as coarse_model_0:
             if aesara.config.floatX == "float32":
                 Q = Data("Q", np.float32(0.0))

@@ -1236,8 +1268,6 @@ def perform(self, node, inputs, outputs):

         coarse_models.append(coarse_model_0)

-        rng = np.random.RandomState(seed)
-
         with Model() as coarse_model_1:
             if aesara.config.floatX == "float32":
                 Q = Data("Q", np.float32(0.0))

@@ -1255,8 +1285,6 @@ def perform(self, node, inputs, outputs):

         coarse_models.append(coarse_model_1)

-        rng = np.random.RandomState(seed)
-
         with Model() as model:
             if aesara.config.floatX == "float32":
                 Q = Data("Q", np.float32(0.0))

@@ -1314,8 +1342,9 @@ def perform(self, node, inputs, outputs):
             (nchains, ndraws * nsub)
         )
         Q_2_1 = np.concatenate(trace.get_sampler_stats("Q_2_1")).reshape((nchains, ndraws))
-        assert Q_1_0.mean(axis=1) == 0.0
-        assert Q_2_1.mean(axis=1) == 0.0
+        # This used to be a strict zero equality!
+        assert np.isclose(Q_1_0.mean(axis=1), 0.0, atol=1e-4)
+        assert np.isclose(Q_2_1.mean(axis=1), 0.0, atol=1e-4)

Review thread on lines +1345 to +1347:
- **Author:** I don't know who is maintaining the MLDA step method, but I had to change this strict equality... Is that okay?
- **Reviewer:** Nobody is actively maintaining it. Also, didn't we want to move it to …?
- **Reviewer:** I think it is still maintained, as it was ported by the original team to V4.
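A one-line illustration of why the strict check was fragile: Monte Carlo means are only approximately zero, so an absolute tolerance is the standard fix (numbers invented):

```python
import numpy as np

chain_means = np.array([3.2e-6, -1.7e-5])             # hypothetical noisy per-chain means
assert not (chain_means == 0.0).all()                 # strict equality fails
assert np.isclose(chain_means, 0.0, atol=1e-4).all()  # tolerant check passes
```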
||
|
||
|
||
class TestRVsAssignmentSteps: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.