diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 7361d3c6eb..cba911aed9 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -168,6 +168,11 @@ def assign_step_methods(model, step=None, methods=None, step_kwargs=None): except TypeError: steps.append(step) for step in steps: + for var in step.vars: + if var not in model.value_vars: + raise ValueError( + f"{var} assigned to {step} sampler is not a value variable in the model. You can use `util.get_value_vars_from_user_vars` to parse user provided variables." + ) assigned_vars = assigned_vars.union(set(step.vars)) # Use competence classmethods to select step methods for remaining diff --git a/tests/sampling/test_mcmc.py b/tests/sampling/test_mcmc.py index 44cb4fc038..950694e646 100644 --- a/tests/sampling/test_mcmc.py +++ b/tests/sampling/test_mcmc.py @@ -40,6 +40,7 @@ NUTS, BinaryGibbsMetropolis, CategoricalGibbsMetropolis, + CompoundStep, HamiltonianMC, Metropolis, Slice, @@ -817,6 +818,23 @@ def test_modify_step_methods(self): steps = assign_step_methods(model, []) assert isinstance(steps, NUTS) + def test_step_vars_in_model(self): + """Test if error is raised if step variable is not found in model.value_vars""" + with pm.Model() as model: + c1 = pm.HalfNormal("c1") + c2 = pm.HalfNormal("c2") + + with pytensor.config.change_flags(mode=fast_unstable_sampling_mode): + step1 = NUTS([c1]) + step2 = NUTS([c2]) + step2.vars = [c2] + step = CompoundStep([step1, step2]) + with pytest.raises( + ValueError, + match=r".* assigned to .* sampler is not a value variable in the model. You can use `util.get_value_vars_from_user_vars` to parse user provided variables.", + ): + assign_step_methods(model, step) + class TestType: samplers = (Metropolis, Slice, HamiltonianMC, NUTS)