Make Metropolis cope better with multiple dimensions #5823


Merged — 2 commits merged on May 31, 2022
121 changes: 85 additions & 36 deletions pymc/step_methods/metropolis.py
@@ -168,10 +168,12 @@ def __init__(
vars = model.value_vars
else:
vars = [model.rvs_to_values.get(var, var) for var in vars]

vars = pm.inputvars(vars)

initial_values_shape = [initial_values[v.name].shape for v in vars]
if S is None:
S = np.ones(sum(initial_values[v.name].size for v in vars))
S = np.ones(int(sum(np.prod(ivs) for ivs in initial_values_shape)))

if proposal_dist is not None:
self.proposal_dist = proposal_dist(S)
@@ -186,7 +188,6 @@ def __init__(
self.tune = tune
self.tune_interval = tune_interval
self.steps_until_tune = tune_interval
self.accepted = 0

# Determine type of variables
self.discrete = np.concatenate(
@@ -195,11 +196,33 @@
self.any_discrete = self.discrete.any()
self.all_discrete = self.discrete.all()

# remember initial settings before tuning so they can be reset
self._untuned_settings = dict(
scaling=self.scaling, steps_until_tune=tune_interval, accepted=self.accepted
# Metropolis will try to handle one batched dimension at a time This, however,
Member:

Suggested change
# Metropolis will try to handle one batched dimension at a time This, however,
# Metropolis will try to handle one batched dimension at a time. This, however,

# is not safe for discrete multivariate distributions (looking at you Multinomial),
# due to high dependency among the support dimensions. For continuous multivariate
# distributions we assume they are being transformed in a way that makes each
# dimension semi-independent.
Member:

Suggested change
# dimension semi-independent.
# dimension semi-independent.
# Consequently we'd like to tune the scaling based on
# acceptance rate in each dimension independently.

is_scalar = len(initial_values_shape) == 1 and initial_values_shape[0] == ()
self.elemwise_update = not (
is_scalar
or (
self.any_discrete
and max(getattr(model.values_to_rvs[var].owner.op, "ndim_supp", 1) for var in vars)
> 0
)
)
if self.elemwise_update:
dims = int(sum(np.prod(ivs) for ivs in initial_values_shape))
Comment on lines +213 to +214
Member:

  1. Should this be an internal attribute?
  2. Could you come up with a more informative name?

Right now (reading the diff top to bottom) I'm confused because this smells like CompoundStep, but it's only about elementwise tuning, right?

Member Author:

See my comment below. CompoundStep has nothing to do with this.

else:
dims = 1
self.enum_dims = np.arange(dims, dtype=int)
self.accept_rate_iter = np.zeros(dims, dtype=float)
self.accepted_iter = np.zeros(dims, dtype=bool)
self.accepted_sum = np.zeros(dims, dtype=int)

# remember initial settings before tuning so they can be reset
self._untuned_settings = dict(scaling=self.scaling, steps_until_tune=tune_interval)

# TODO: This is not being used when compiling the logp function!
self.mode = mode

shared = pm.make_shared_replacements(initial_values, vars, model)
@@ -210,6 +233,7 @@ def reset_tuning(self):
"""Resets the tuned sampler parameters to their initial values."""
for attr, initial_value in self._untuned_settings.items():
setattr(self, attr, initial_value)
self.accepted_sum[:] = 0
return

def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
@@ -219,10 +243,10 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:

if not self.steps_until_tune and self.tune:
# Tune scaling parameter
self.scaling = tune(self.scaling, self.accepted / float(self.tune_interval))
self.scaling = tune(self.scaling, self.accepted_sum / float(self.tune_interval))
# Reset counter
self.steps_until_tune = self.tune_interval
self.accepted = 0
self.accepted_sum[:] = 0

delta = self.proposal_dist() * self.scaling

@@ -237,23 +261,36 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
else:
q = floatX(q0 + delta)

accept = self.delta_logp(q, q0)
q_new, accepted = metrop_select(accept, q, q0)

self.accepted += accepted
if self.elemwise_update:
q_temp = q0.copy()
# Shuffle order of updates (probably we don't need to do this in every step)
np.random.shuffle(self.enum_dims)
for i in self.enum_dims:
q_temp[i] = q[i]
accept_rate_i = self.delta_logp(q_temp, q0)
q_temp_, accepted_i = metrop_select(accept_rate_i, q_temp, q0)
q_temp[i] = q_temp_[i]
self.accept_rate_iter[i] = accept_rate_i
self.accepted_iter[i] = accepted_i
self.accepted_sum[i] += accepted_i
q = q_temp
Comment on lines +268 to +276
Member:

Is this doing what we usually do with CompoundStep?

If not, maybe explain what the if/else blocks do in a code comment

Member Author:

CompoundStep assigns one step method per variable; here we are sampling one dimension within a variable (or within multiple variables) semi-independently.
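
For illustration, here is a minimal self-contained sketch of the two schemes (not PyMC code; the toy target `logp`, the proposal, and the shapes are invented for the example). The joint update accepts or rejects the whole proposal vector at once, while the element-wise variant accepts or rejects one coordinate at a time, which is what produces the per-dimension acceptance counts used for tuning:

```python
import numpy as np

rng = np.random.default_rng(123)

def logp(x):
    # Toy target: independent standard normals
    return -0.5 * np.sum(x ** 2)

def joint_update(q0, scaling=1.0):
    """Accept or reject the whole proposal vector at once."""
    q = q0 + rng.normal(size=q0.shape) * scaling
    accepted = np.log(rng.uniform()) < logp(q) - logp(q0)
    return (q if accepted else q0), accepted

def elemwise_update(q0, scaling=1.0):
    """Accept or reject each coordinate separately, in shuffled order."""
    q = q0 + rng.normal(size=q0.shape) * scaling
    q_new = q0.copy()
    accepted = np.zeros(q0.size, dtype=bool)
    for i in rng.permutation(q0.size):
        q_try = q_new.copy()
        q_try[i] = q[i]
        if np.log(rng.uniform()) < logp(q_try) - logp(q_new):
            q_new = q_try
            accepted[i] = True
    return q_new, accepted  # one acceptance flag per dimension

q0 = np.zeros(5)
print(joint_update(q0)[1], elemwise_update(q0)[1])
```

This is only meant to show why `accepted_iter`/`accepted_sum` carry one entry per dimension; it does not mirror the actual `astep` implementation.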

else:
accept_rate = self.delta_logp(q, q0)
q, accepted = metrop_select(accept_rate, q, q0)
self.accept_rate_iter = accept_rate
self.accepted_iter = accepted
self.accepted_sum += accepted

self.steps_until_tune -= 1

stats = {
"tune": self.tune,
"scaling": self.scaling,
"accept": np.exp(accept),
"accepted": accepted,
"scaling": np.mean(self.scaling),
"accept": np.mean(np.exp(self.accept_rate_iter)),
"accepted": np.mean(self.accepted_iter),
}

q_new = RaveledVars(q_new, point_map_info)

return q_new, [stats]
return RaveledVars(q, point_map_info), [stats]

@staticmethod
def competence(var, has_grad):
@@ -275,26 +312,38 @@ def tune(scale, acc_rate):
>0.95 x 10

"""
if acc_rate < 0.001:
return scale * np.where(
acc_rate < 0.001,
# reduce by 90 percent
return scale * 0.1
elif acc_rate < 0.05:
# reduce by 50 percent
return scale * 0.5
elif acc_rate < 0.2:
# reduce by ten percent
return scale * 0.9
elif acc_rate > 0.95:
# increase by factor of ten
return scale * 10.0
elif acc_rate > 0.75:
# increase by double
return scale * 2.0
elif acc_rate > 0.5:
# increase by ten percent
return scale * 1.1

return scale
0.1,
np.where(
acc_rate < 0.05,
# reduce by 50 percent
0.5,
np.where(
acc_rate < 0.2,
# reduce by ten percent
0.9,
np.where(
acc_rate > 0.95,
# increase by factor of ten
10.0,
np.where(
acc_rate > 0.75,
# increase by double
2.0,
np.where(
acc_rate > 0.5,
# increase by ten percent
1.1,
# Do not change
1.0,
),
),
),
),
),
)
Member:

Ugh, can't we do this in a for loop with a break?

Member Author:

We can probably index instead; that should be faster for large arrays, because all branches of np.where are evaluated by default.
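
As a rough sketch of that indexing idea (illustrative only, not necessarily how it would land in the step method), the factor can be picked with a single bucket lookup instead of evaluating every `np.where` branch. Note that behaviour exactly at the lower cutoffs (0.001, 0.05, 0.2) differs slightly from the strict `<` chain:

```python
import numpy as np

# Cutoffs between acceptance-rate buckets and the factor applied in each bucket:
# <0.001 -> 0.1, <0.05 -> 0.5, <0.2 -> 0.9, mid-range -> 1.0,
# >0.5 -> 1.1, >0.75 -> 2.0, >0.95 -> 10.0
_cutoffs = np.array([0.001, 0.05, 0.2, 0.5, 0.75, 0.95])
_factors = np.array([0.1, 0.5, 0.9, 1.0, 1.1, 2.0, 10.0])

def tune_by_indexing(scale, acc_rate):
    """Same rules as the nested np.where, via one searchsorted lookup."""
    return scale * _factors[np.searchsorted(_cutoffs, acc_rate)]
```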



class BinaryMetropolis(ArrayStep):
8 changes: 4 additions & 4 deletions pymc/step_methods/mlda.py
@@ -787,22 +787,22 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
if isinstance(self.step_method_below, MLDA):
self.base_tuning_stats = self.step_method_below.base_tuning_stats
elif isinstance(self.step_method_below, MetropolisMLDA):
self.base_tuning_stats.append({"base_scaling": self.step_method_below.scaling})
self.base_tuning_stats.append({"base_scaling": np.mean(self.step_method_below.scaling)})
elif isinstance(self.step_method_below, DEMetropolisZMLDA):
self.base_tuning_stats.append(
{
"base_scaling": self.step_method_below.scaling,
"base_scaling": np.mean(self.step_method_below.scaling),
"base_lambda": self.step_method_below.lamb,
}
)
elif isinstance(self.step_method_below, CompoundStep):
# Below method is CompoundStep
for method in self.step_method_below.methods:
if isinstance(method, MetropolisMLDA):
self.base_tuning_stats.append({"base_scaling": method.scaling})
self.base_tuning_stats.append({"base_scaling": np.mean(method.scaling)})
elif isinstance(method, DEMetropolisZMLDA):
self.base_tuning_stats.append(
{"base_scaling": method.scaling, "base_lambda": method.lamb}
{"base_scaling": np.mean(method.scaling), "base_lambda": method.lamb}
)

return q_new, [stats] + self.base_tuning_stats
99 changes: 64 additions & 35 deletions pymc/tests/test_step.py
@@ -35,7 +35,9 @@
Beta,
Binomial,
Categorical,
Dirichlet,
HalfNormal,
Multinomial,
MvNormal,
Normal,
)
@@ -174,33 +176,6 @@ def test_step_categorical(self, proposal):
self.check_stat(check, idata, step.__class__.__name__)


class TestMetropolisProposal:
def test_proposal_choice(self):
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
_, model, _ = mv_simple()
with model:
initial_point = model.initial_point()
initial_point_size = sum(initial_point[n.name].size for n in model.value_vars)

s = np.ones(initial_point_size)
sampler = Metropolis(S=s)
assert isinstance(sampler.proposal_dist, NormalProposal)
s = np.diag(s)
sampler = Metropolis(S=s)
assert isinstance(sampler.proposal_dist, MultivariateNormalProposal)
s[0, 0] = -s[0, 0]
with pytest.raises(np.linalg.LinAlgError):
sampler = Metropolis(S=s)

def test_mv_proposal(self):
np.random.seed(42)
cov = np.random.randn(5, 5)
cov = cov.dot(cov.T)
prop = MultivariateNormalProposal(cov)
samples = np.array([prop() for _ in range(10000)])
npt.assert_allclose(np.cov(samples.T), cov, rtol=0.2)


class TestCompoundStep:
samplers = (Metropolis, Slice, HamiltonianMC, NUTS, DEMetropolis)

@@ -383,6 +358,31 @@ def test_parallelized_chains_are_random(self):


class TestMetropolis:
def test_proposal_choice(self):
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
_, model, _ = mv_simple()
with model:
initial_point = model.initial_point()
initial_point_size = sum(initial_point[n.name].size for n in model.value_vars)

s = np.ones(initial_point_size)
sampler = Metropolis(S=s)
assert isinstance(sampler.proposal_dist, NormalProposal)
s = np.diag(s)
sampler = Metropolis(S=s)
assert isinstance(sampler.proposal_dist, MultivariateNormalProposal)
s[0, 0] = -s[0, 0]
with pytest.raises(np.linalg.LinAlgError):
sampler = Metropolis(S=s)

def test_mv_proposal(self):
np.random.seed(42)
cov = np.random.randn(5, 5)
cov = cov.dot(cov.T)
prop = MultivariateNormalProposal(cov)
samples = np.array([prop() for _ in range(10000)])
npt.assert_allclose(np.cov(samples.T), cov, rtol=0.2)

def test_tuning_reset(self):
"""Re-use of the step method instance with cores=1 must not leak tuning information between chains."""
with Model() as pmodel:
@@ -403,6 +403,40 @@ def test_tuning_reset(self):
assert tuned != 0.1
np.testing.assert_array_equal(idata.sample_stats["scaling"].sel(chain=c).values, tuned)

@pytest.mark.parametrize(
"batched_dist",
(
Binomial.dist(n=5, p=0.9), # scalar case
Binomial.dist(n=np.arange(40) + 1, p=np.linspace(0.1, 0.9, 40), shape=(40,)),
Binomial.dist(
n=(np.arange(20) + 1)[::-1],
p=np.linspace(0.1, 0.9, 20),
shape=(
2,
20,
),
),
Dirichlet.dist(a=np.ones(3) * (np.arange(40) + 1)[:, None], shape=(40, 3)),
Dirichlet.dist(a=np.ones(3) * (np.arange(20) + 1)[:, None], shape=(2, 20, 3)),
),
)
def test_elemwise_update(self, batched_dist):
Member Author (@ricardoV94, May 31, 2022):

I ran these new tests 10 times each, with the current new algorithm and with the one proposed by @aloctavodia that we use in SMC. Things look sensible in both cases:

10 runs each with different seeds

Tuning method based on switch statement:

1. (0D Bin) max_rhat(mean=1.005, std=0.005), min_ess=(mean=327.4, std=29.9)
2. (1D Bin) max_rhat(mean=1.031, std=0.007), min_ess=(mean=115.5, std=20.9)
3. (2D Bin) max_rhat(mean=1.021, std=0.005), min_ess=(mean=180.0, std=22.5)
4. (2D Dir) max_rhat(mean=1.041, std=0.010), min_ess=(mean=91.2, std=24.7)
5. (3D Dir) max_rhat(mean=1.037, std=0.008), min_ess=(mean=110.4, std=24.4)

Tuning method based on distance from 0.234:

1. (0D Bin) max_rhat(mean=1.009, std=0.008), min_ess=(mean=231.3, std=45.6)  # Worse
2. (1D Bin) max_rhat(mean=1.026, std=0.008), min_ess=(mean=206.9, std=21.5)  # Better
3. (2D Bin) max_rhat(mean=1.026, std=0.005), min_ess=(mean=188.7, std=22.3)
4. (2D Dir) max_rhat(mean=1.043, std=0.012), min_ess=(mean=98.9, std=27.8)
5. (3D Dir) max_rhat(mean=1.039, std=0.009), min_ess=(mean=102.6, std=21.5)
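
For reference, a tuning rule based on the distance from the 0.234 target acceptance rate could look roughly like this (a hedged sketch; the exact formula used by the SMC sampler may differ):

```python
import numpy as np

def tune_toward_target(scale, acc_rate, target=0.234):
    """Multiplicative update in log space: grow the scale when acceptance
    is above the target, shrink it when acceptance is below."""
    return np.exp(np.log(scale) + (acc_rate - target))
```

Both rules adjust the proposal scale multiplicatively; the switch-statement version moves it in coarse steps, while a distance-based rule adjusts it continuously.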

Member Author (@ricardoV94, May 31, 2022):

But the MLDA complains a lot about the SMC-like tuning... especially this test, which fails even if I increase the number of draws: https://github.com/ricardoV94/pymc/blob/3a84db6b91035a7c7f68633bdbb18c5f11efd46f/pymc/tests/test_step.py#L1184

Maybe I am just misunderstanding what that test is supposed to check

Member:

Looking at the standard deviations, I'm not that impressed, tbh.
(Considering that the code is now more complicated...)

Or do you have another application that motivated this?

Member Author (@ricardoV94, May 31, 2022):


@michaelosthege this is not a comparison of this PR vs. before the PR; it's a comparison between this PR and an alternative procedure that @aloctavodia suggested during the PR.

Before this PR there were zero accepted transitions in all the new tests (except the first scalar case)

I updated the headings to make it obvious

with Model() as m:
m.register_rv(batched_dist, name="batched_dist")
step = pm.Metropolis([batched_dist])
assert step.elemwise_update == (batched_dist.ndim > 0)
trace = pm.sample(draws=1000, chains=2, step=step, random_seed=428)

assert az.rhat(trace).max()["batched_dist"].values < 1.1
assert az.ess(trace).min()["batched_dist"].values > 50

def test_multinomial_no_elemwise_update(self):
with Model() as m:
batched_dist = Multinomial("batched_dist", n=5, p=np.ones(4) / 4, shape=(10, 4))
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
step = pm.Metropolis([batched_dist])
assert not step.elemwise_update


class TestDEMetropolisZ:
def test_tuning_lambda_sequential(self):
@@ -1217,8 +1251,6 @@ def perform(self, node, inputs, outputs):
mout = []
coarse_models = []

rng = np.random.RandomState(seed)

with Model() as coarse_model_0:
if aesara.config.floatX == "float32":
Q = Data("Q", np.float32(0.0))
@@ -1236,8 +1268,6 @@ def perform(self, node, inputs, outputs):

coarse_models.append(coarse_model_0)

rng = np.random.RandomState(seed)

with Model() as coarse_model_1:
if aesara.config.floatX == "float32":
Q = Data("Q", np.float32(0.0))
@@ -1255,8 +1285,6 @@ def perform(self, node, inputs, outputs):

coarse_models.append(coarse_model_1)

rng = np.random.RandomState(seed)

with Model() as model:
if aesara.config.floatX == "float32":
Q = Data("Q", np.float32(0.0))
@@ -1314,8 +1342,9 @@ def perform(self, node, inputs, outputs):
(nchains, ndraws * nsub)
)
Q_2_1 = np.concatenate(trace.get_sampler_stats("Q_2_1")).reshape((nchains, ndraws))
assert Q_1_0.mean(axis=1) == 0.0
assert Q_2_1.mean(axis=1) == 0.0
# This used to be a strict zero equality!
assert np.isclose(Q_1_0.mean(axis=1), 0.0, atol=1e-4)
assert np.isclose(Q_2_1.mean(axis=1), 0.0, atol=1e-4)
Comment on lines +1345 to +1347
Member Author:

I don't know who is maintaining the MLDA step method, but I had to change this strict equality... Is that okay?

Member:

Nobody is actively maintaining it. Also, didn't we want to move it to pymc-experimental?

Member Author:

I think it is still maintained, as it was ported to V4 by the original team.



class TestRVsAssignmentSteps: