Skip to content

Fix seeding issues in sequential sampling #5377

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged 2 commits on May 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions pymc/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,6 +969,9 @@ def _iter_sample(
if draws < 1:
raise ValueError("Argument `draws` must be greater than 0.")

if random_seed is not None:
np.random.seed(random_seed)

try:
step = CompoundStep(step)
except TypeError:
Expand Down Expand Up @@ -1229,6 +1232,9 @@ def _prepare_iter_population(
if draws < 1:
raise ValueError("Argument `draws` should be above 0.")

if random_seed is not None:
np.random.seed(random_seed)

# The initialization of traces, samplers and points must happen in the right order:
# 1. population of points is created
# 2. steppers are initialized and linked to the points object
Expand Down Expand Up @@ -2511,7 +2517,7 @@ def init_nuts(
cov = approx.std.eval() ** 2
potential = quadpotential.QuadPotentialDiag(cov)
elif init == "advi_map":
start = pm.find_MAP(include_transformed=True)
start = pm.find_MAP(include_transformed=True, seed=seeds[0])
approx = pm.MeanField(model=model, start=start)
pm.fit(
random_seed=seeds[0],
Expand All @@ -2526,7 +2532,7 @@ def init_nuts(
cov = approx.std.eval() ** 2
potential = quadpotential.QuadPotentialDiag(cov)
elif init == "map":
start = pm.find_MAP(include_transformed=True)
start = pm.find_MAP(include_transformed=True, seed=seeds[0])
cov = pm.find_hessian(point=start)
initial_points = [start] * chains
potential = quadpotential.QuadPotentialFull(cov)
Expand Down
51 changes: 50 additions & 1 deletion pymc/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,61 @@ def setup_method(self):
super().setup_method()
self.model, self.start, self.step, _ = simple_init()

@pytest.mark.parametrize("init", ("jitter+adapt_diag", "advi", "map"))
@pytest.mark.parametrize("cores", (1, 2))
@pytest.mark.parametrize(
    "chains, seeds",
    [
        (1, None),
        (1, 1),
        (1, [1]),
        (2, None),
        (2, 1),
        (2, [1, 2]),
    ],
)
def test_random_seed(self, chains, seeds, cores, init):
    """Sample one model twice with identical arguments and compare the draws.

    The two runs are expected to produce identical values of ``x`` only when
    an explicit ``random_seed`` is supplied and the init method is not
    ``"advi"``; in every other parametrized case the draws must differ.
    """
    # Both sampling runs receive exactly the same keyword arguments.
    sample_kwargs = dict(
        chains=chains,
        random_seed=seeds,
        cores=cores,
        init=init,
        tune=0,
        draws=10,
        return_inferencedata=False,
        compute_convergence_checks=False,
    )

    with pm.Model(rng_seeder=3):
        pm.Normal("x", 0, 10, initval="prior")
        first_trace = pm.sample(**sample_kwargs)
        second_trace = pm.sample(**sample_kwargs)

    draws_match = np.all(first_trace["x"] == second_trace["x"])

    # TODO: ADVI init methods are not correctly seeded, as they rely on the state of
    # the model RandomState/Generators which is updated in place when the function
    # is compiled and evaluated. The `init != "advi"` condition must be removed once
    # this is fixed
    expect_reproducible = seeds is not None and init != "advi"

    if expect_reproducible:
        assert draws_match
    else:
        assert not draws_match

def test_sample_does_not_set_seed(self):
# This tests that when random_seed is None, the global seed is not affected
random_numbers = []
for _ in range(2):
np.random.seed(1)
with self.model:
pm.sample(1, tune=0, chains=1)
pm.sample(1, tune=0, chains=1, random_seed=None)
random_numbers.append(np.random.random())
assert random_numbers[0] == random_numbers[1]

Expand Down