15
15
from cmdstanpy .cmdstan_args import CmdStanArgs , VariationalArgs
16
16
from cmdstanpy .model import CmdStanModel
17
17
from cmdstanpy .stanfit import CmdStanVB , RunSet , from_csv
18
+ from cmdstanpy .utils .cmdstan import cmdstan_version_before
18
19
19
20
HERE = os .path .dirname (os .path .abspath (__file__ ))
20
21
DATAFILES_PATH = os .path .join (HERE , 'data' )
@@ -150,12 +151,12 @@ def test_variational_good() -> None:
150
151
'mu[2]' ,
151
152
)
152
153
# fixed seed, id=1 by default will give known output values
153
- assert variational .eta == 100
154
+ assert variational .eta == 10
154
155
np .testing .assert_almost_equal (
155
- variational .variational_params_dict ['mu[1]' ], 311.545 , decimal = 2
156
+ variational .variational_params_dict ['mu[1]' ], 302.142 , decimal = 2
156
157
)
157
158
np .testing .assert_almost_equal (
158
- variational .variational_params_dict ['mu[2]' ], 532.801 , decimal = 2
159
+ variational .variational_params_dict ['mu[2]' ], 361.005 , decimal = 2
159
160
)
160
161
np .testing .assert_almost_equal (
161
162
variational .variational_params_np [0 ],
@@ -177,7 +178,11 @@ def test_variational_eta_small() -> None:
177
178
DATAFILES_PATH , 'variational' , 'eta_should_be_small.stan'
178
179
)
179
180
model = CmdStanModel (stan_file = stan )
180
- variational = model .variational (algorithm = 'meanfield' , seed = 12345 )
181
+ if cmdstan_version_before (2 , 35 ):
182
+ seed = 12345
183
+ else :
184
+ seed = 1234
185
+ variational = model .variational (algorithm = 'meanfield' , seed = seed )
181
186
assert variational .column_names == (
182
187
'lp__' ,
183
188
'log_p__' ,
0 commit comments