Skip to content

Commit c72acd0

Browse files
committed
Update seed-dependent vb tests
1 parent 48a3e90 commit c72acd0

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

test/test_variational.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from cmdstanpy.cmdstan_args import CmdStanArgs, VariationalArgs
1616
from cmdstanpy.model import CmdStanModel
1717
from cmdstanpy.stanfit import CmdStanVB, RunSet, from_csv
18+
from cmdstanpy.utils.cmdstan import cmdstan_version_before
1819

1920
HERE = os.path.dirname(os.path.abspath(__file__))
2021
DATAFILES_PATH = os.path.join(HERE, 'data')
@@ -150,12 +151,12 @@ def test_variational_good() -> None:
150151
'mu[2]',
151152
)
152153
# fixed seed, id=1 by default will give known output values
153-
assert variational.eta == 100
154+
assert variational.eta == 10
154155
np.testing.assert_almost_equal(
155-
variational.variational_params_dict['mu[1]'], 311.545, decimal=2
156+
variational.variational_params_dict['mu[1]'], 302.142, decimal=2
156157
)
157158
np.testing.assert_almost_equal(
158-
variational.variational_params_dict['mu[2]'], 532.801, decimal=2
159+
variational.variational_params_dict['mu[2]'], 361.005, decimal=2
159160
)
160161
np.testing.assert_almost_equal(
161162
variational.variational_params_np[0],
@@ -177,7 +178,11 @@ def test_variational_eta_small() -> None:
177178
DATAFILES_PATH, 'variational', 'eta_should_be_small.stan'
178179
)
179180
model = CmdStanModel(stan_file=stan)
180-
variational = model.variational(algorithm='meanfield', seed=12345)
181+
if cmdstan_version_before(2, 35):
182+
seed = 12345
183+
else:
184+
seed = 1234
185+
variational = model.variational(algorithm='meanfield', seed=seed)
181186
assert variational.column_names == (
182187
'lp__',
183188
'log_p__',

0 commit comments

Comments
 (0)