diff --git a/pymc/sampling.py b/pymc/sampling.py index cb439ad000..d7982e923e 100644 --- a/pymc/sampling.py +++ b/pymc/sampling.py @@ -969,6 +969,9 @@ def _iter_sample( if draws < 1: raise ValueError("Argument `draws` must be greater than 0.") + if random_seed is not None: + np.random.seed(random_seed) + try: step = CompoundStep(step) except TypeError: @@ -1229,6 +1232,9 @@ def _prepare_iter_population( if draws < 1: raise ValueError("Argument `draws` should be above 0.") + if random_seed is not None: + np.random.seed(random_seed) + # The initialization of traces, samplers and points must happen in the right order: # 1. population of points is created # 2. steppers are initialized and linked to the points object @@ -2511,7 +2517,7 @@ def init_nuts( cov = approx.std.eval() ** 2 potential = quadpotential.QuadPotentialDiag(cov) elif init == "advi_map": - start = pm.find_MAP(include_transformed=True) + start = pm.find_MAP(include_transformed=True, seed=seeds[0]) approx = pm.MeanField(model=model, start=start) pm.fit( random_seed=seeds[0], @@ -2526,7 +2532,7 @@ def init_nuts( cov = approx.std.eval() ** 2 potential = quadpotential.QuadPotentialDiag(cov) elif init == "map": - start = pm.find_MAP(include_transformed=True) + start = pm.find_MAP(include_transformed=True, seed=seeds[0]) cov = pm.find_hessian(point=start) initial_points = [start] * chains potential = quadpotential.QuadPotentialFull(cov) diff --git a/pymc/tests/test_sampling.py b/pymc/tests/test_sampling.py index 634ce10986..ec7bac8395 100644 --- a/pymc/tests/test_sampling.py +++ b/pymc/tests/test_sampling.py @@ -62,12 +62,61 @@ def setup_method(self): super().setup_method() self.model, self.start, self.step, _ = simple_init() + @pytest.mark.parametrize("init", ("jitter+adapt_diag", "advi", "map")) + @pytest.mark.parametrize("cores", (1, 2)) + @pytest.mark.parametrize( + "chains, seeds", + [ + (1, None), + (1, 1), + (1, [1]), + (2, None), + (2, 1), + (2, [1, 2]), + ], + ) + def test_random_seed(self, chains, seeds, cores, init): + with pm.Model(rng_seeder=3): + x = pm.Normal("x", 0, 10, initval="prior") + tr1 = pm.sample( + chains=chains, + random_seed=seeds, + cores=cores, + init=init, + tune=0, + draws=10, + return_inferencedata=False, + compute_convergence_checks=False, + ) + tr2 = pm.sample( + chains=chains, + random_seed=seeds, + cores=cores, + init=init, + tune=0, + draws=10, + return_inferencedata=False, + compute_convergence_checks=False, + ) + + allequal = np.all(tr1["x"] == tr2["x"]) + if seeds is None: + assert not allequal + # TODO: ADVI init methods are not correctly seeded, as they rely on the state of + # the model RandomState/Generators which is updated in place when the function + # is compiled and evaluated. This elif branch must be removed once this is fixed + elif init == "advi": + assert not allequal + else: + assert allequal + def test_sample_does_not_set_seed(self): + # This tests that when random_seed is None, the global seed is not affected random_numbers = [] for _ in range(2): np.random.seed(1) with self.model: - pm.sample(1, tune=0, chains=1) + pm.sample(1, tune=0, chains=1, random_seed=None) random_numbers.append(np.random.random()) assert random_numbers[0] == random_numbers[1]