Skip to content

Commit 9b59771

Browse files
Add check for variables in step samplers (#6524)
* added value check for step samplers * changing error message to be more informative * added test checking if variable not being in model.value_vars will trigger Value error * implemented changes from PR review * Fix rebase --------- Co-authored-by: Michael Osthege <[email protected]>
1 parent b9a3dcc commit 9b59771

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

pymc/sampling/mcmc.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,11 @@ def assign_step_methods(model, step=None, methods=None, step_kwargs=None):
168168
except TypeError:
169169
steps.append(step)
170170
for step in steps:
171+
for var in step.vars:
172+
if var not in model.value_vars:
173+
raise ValueError(
174+
f"{var} assigned to {step} sampler is not a value variable in the model. You can use `util.get_value_vars_from_user_vars` to parse user provided variables."
175+
)
171176
assigned_vars = assigned_vars.union(set(step.vars))
172177

173178
# Use competence classmethods to select step methods for remaining

tests/sampling/test_mcmc.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
NUTS,
4141
BinaryGibbsMetropolis,
4242
CategoricalGibbsMetropolis,
43+
CompoundStep,
4344
HamiltonianMC,
4445
Metropolis,
4546
Slice,
@@ -817,6 +818,23 @@ def test_modify_step_methods(self):
817818
steps = assign_step_methods(model, [])
818819
assert isinstance(steps, NUTS)
819820

821+
def test_step_vars_in_model(self):
822+
"""Test if error is raised if step variable is not found in model.value_vars"""
823+
with pm.Model() as model:
824+
c1 = pm.HalfNormal("c1")
825+
c2 = pm.HalfNormal("c2")
826+
827+
with pytensor.config.change_flags(mode=fast_unstable_sampling_mode):
828+
step1 = NUTS([c1])
829+
step2 = NUTS([c2])
830+
step2.vars = [c2]
831+
step = CompoundStep([step1, step2])
832+
with pytest.raises(
833+
ValueError,
834+
match=r".* assigned to .* sampler is not a value variable in the model. You can use `util.get_value_vars_from_user_vars` to parse user provided variables.",
835+
):
836+
assign_step_methods(model, step)
837+
820838

821839
class TestType:
822840
samplers = (Metropolis, Slice, HamiltonianMC, NUTS)

0 commit comments

Comments
 (0)