From f446af0919a60ec0ad871c63e52617d5500eb9c8 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Thu, 20 Jan 2022 19:28:18 +0100 Subject: [PATCH 1/2] Add explanation to test_sample_does_not_set_seed --- pymc/tests/test_sampling.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pymc/tests/test_sampling.py b/pymc/tests/test_sampling.py index 634ce10986..ace13e81c9 100644 --- a/pymc/tests/test_sampling.py +++ b/pymc/tests/test_sampling.py @@ -63,11 +63,12 @@ def setup_method(self): self.model, self.start, self.step, _ = simple_init() def test_sample_does_not_set_seed(self): + # This tests that when random_seed is None, the global seed is not affected random_numbers = [] for _ in range(2): np.random.seed(1) with self.model: - pm.sample(1, tune=0, chains=1) + pm.sample(1, tune=0, chains=1, random_seed=None) random_numbers.append(np.random.random()) assert random_numbers[0] == random_numbers[1] From 0ebf04e1f37f7fb9d2a160adb91472db1b4c4e96 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Thu, 20 Jan 2022 19:28:45 +0100 Subject: [PATCH 2/2] Pass seed to `find_map` and set global seeds in `_iter_sample` and `_prepare_iter_sample` This reverts some changes in 47b61def5f24830e36714886356557a0e3918195 which wrongly disabled global seeding in some sampling contexts that still depended on it. --- pymc/sampling.py | 10 ++++++-- pymc/tests/test_sampling.py | 48 +++++++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 2 deletions(-) diff --git a/pymc/sampling.py b/pymc/sampling.py index cb439ad000..d7982e923e 100644 --- a/pymc/sampling.py +++ b/pymc/sampling.py @@ -969,6 +969,9 @@ def _iter_sample( if draws < 1: raise ValueError("Argument `draws` must be greater than 0.") + if random_seed is not None: + np.random.seed(random_seed) + try: step = CompoundStep(step) except TypeError: @@ -1229,6 +1232,9 @@ def _prepare_iter_population( if draws < 1: raise ValueError("Argument `draws` should be above 0.") + if random_seed is not None: + np.random.seed(random_seed) + # The initialization of traces, samplers and points must happen in the right order: # 1. population of points is created # 2. steppers are initialized and linked to the points object @@ -2511,7 +2517,7 @@ def init_nuts( cov = approx.std.eval() ** 2 potential = quadpotential.QuadPotentialDiag(cov) elif init == "advi_map": - start = pm.find_MAP(include_transformed=True) + start = pm.find_MAP(include_transformed=True, seed=seeds[0]) approx = pm.MeanField(model=model, start=start) pm.fit( random_seed=seeds[0], @@ -2526,7 +2532,7 @@ def init_nuts( cov = approx.std.eval() ** 2 potential = quadpotential.QuadPotentialDiag(cov) elif init == "map": - start = pm.find_MAP(include_transformed=True) + start = pm.find_MAP(include_transformed=True, seed=seeds[0]) cov = pm.find_hessian(point=start) initial_points = [start] * chains potential = quadpotential.QuadPotentialFull(cov) diff --git a/pymc/tests/test_sampling.py b/pymc/tests/test_sampling.py index ace13e81c9..ec7bac8395 100644 --- a/pymc/tests/test_sampling.py +++ b/pymc/tests/test_sampling.py @@ -62,6 +62,54 @@ def setup_method(self): super().setup_method() self.model, self.start, self.step, _ = simple_init() + @pytest.mark.parametrize("init", ("jitter+adapt_diag", "advi", "map")) + @pytest.mark.parametrize("cores", (1, 2)) + @pytest.mark.parametrize( + "chains, seeds", + [ + (1, None), + (1, 1), + (1, [1]), + (2, None), + (2, 1), + (2, [1, 2]), + ], + ) + def test_random_seed(self, chains, seeds, cores, init): + with pm.Model(rng_seeder=3): + x = pm.Normal("x", 0, 10, initval="prior") + tr1 = pm.sample( + chains=chains, + random_seed=seeds, + cores=cores, + init=init, + tune=0, + draws=10, + return_inferencedata=False, + compute_convergence_checks=False, + ) + tr2 = pm.sample( + chains=chains, + random_seed=seeds, + cores=cores, + init=init, + tune=0, + draws=10, + return_inferencedata=False, + compute_convergence_checks=False, + ) + + allequal = np.all(tr1["x"] == tr2["x"]) + if seeds is None: + assert not allequal + # TODO: ADVI init methods are not correctly seeded, as they rely on the state of + # the model RandomState/Generators which is updated in place when the function + # is compiled and evaluated. This elif branch must be removed once this is fixed + elif init == "advi": + assert not allequal + else: + assert allequal + def test_sample_does_not_set_seed(self): # This tests that when random_seed is None, the global seed is not affected random_numbers = []