Skip to content

Commit 3164e70

Browse files
committed
Fix: properly allow multiple inits in Pathfinder
1 parent 72e9c5c commit 3164e70

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

cmdstanpy/cmdstan_args.py

+1
Original file line numberDiff line numberDiff line change
@@ -930,6 +930,7 @@ def validate(self) -> None:
930930
if not (
931931
isinstance(self.method_args, SamplerArgs)
932932
and self.method_args.num_chains > 1
933+
or isinstance(self.method_args, PathfinderArgs)
933934
):
934935
if not os.path.exists(self.inits):
935936
raise ValueError('no such file {}'.format(self.inits))

test/test_pathfinder.py

+22
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
Tests for the Pathfinder method.
33
"""
44

5+
import contextlib
6+
from io import StringIO
57
from pathlib import Path
68

79
import numpy as np
@@ -129,6 +131,26 @@ def test_pathfinder_init_sampling():
129131
assert fit.draws().shape == (1000, 4, 9)
130132

131133

134+
def test_inits_for_pathfinder():
135+
stan = DATAFILES_PATH / 'bernoulli.stan'
136+
bern_model = cmdstanpy.CmdStanModel(stan_file=stan)
137+
jdata = str(DATAFILES_PATH / 'bernoulli.data.json')
138+
bern_model.pathfinder(
139+
jdata, inits=[{"theta": 0.1}, {"theta": 0.9}], num_paths=2
140+
)
141+
142+
# second path is initialized too large!
143+
with contextlib.redirect_stdout(StringIO()) as captured:
144+
bern_model.pathfinder(
145+
jdata,
146+
inits=[{"theta": 0.1}, {"theta": 1.1}],
147+
num_paths=2,
148+
show_console=True,
149+
)
150+
151+
assert "Bounded variable is 1.1" in captured.getvalue()
152+
153+
132154
def test_pathfinder_no_psis():
133155
stan = DATAFILES_PATH / 'bernoulli.stan'
134156
bern_model = cmdstanpy.CmdStanModel(stan_file=stan)

0 commit comments

Comments
 (0)