Skip to content

Commit 1af8596

Browse files
committed
Fix tests for new cmdstan
1 parent cbea79f commit 1af8596

File tree

3 files changed

+10
-7
lines changed

3 files changed

+10
-7
lines changed

cmdstanpy/stanfit/pathfinder.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,10 @@ def is_resampled(self) -> bool:
214214
"""
215215
return ( # type: ignore
216216
self._metadata.cmdstan_config.get("num_paths", 4) > 1
217-
and self._metadata.cmdstan_config.get('psis_resample', 1) == 1
218-
and self._metadata.cmdstan_config.get('calculate_lp', 1) == 1
217+
and self._metadata.cmdstan_config.get('psis_resample', 1)
218+
in (1, 'true')
219+
and self._metadata.cmdstan_config.get('calculate_lp', 1)
220+
in (1, 'true')
219221
)
220222

221223
def save_csvfiles(self, dir: Optional[str] = None) -> None:

cmdstanpy/utils/stancsv.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def check_sampler_csv(
5252
)
5353
)
5454
if save_warmup:
55-
if not ('save_warmup' in meta and meta['save_warmup'] == 1):
55+
if not ('save_warmup' in meta and meta['save_warmup'] in (1, 'true')):
5656
raise ValueError(
5757
'bad Stan CSV file {}, '
5858
'config error, expected save_warmup = 1'.format(path)

test/test_optimize.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -262,12 +262,13 @@ def test_variables_3d() -> None:
262262
)
263263
vars_iters = multidim_mle_iters.stan_variables(inc_iterations=True)
264264
assert len(vars_iters) == len(multidim_mle_iters.metadata.stan_vars)
265+
assert 'frac_60' in vars_iters
266+
n_iter = vars_iters['frac_60'].shape[0]
267+
assert n_iter > 1
265268
assert 'y_rep' in vars_iters
266-
assert vars_iters['y_rep'].shape == (8, 5, 4, 3)
269+
assert vars_iters['y_rep'].shape == (n_iter, 5, 4, 3)
267270
assert 'beta' in vars_iters
268-
assert vars_iters['beta'].shape == (8, 2)
269-
assert 'frac_60' in vars_iters
270-
assert vars_iters['frac_60'].shape == (8,)
271+
assert vars_iters['beta'].shape == (n_iter, 2)
271272

272273

273274
def test_optimize_good() -> None:

0 commit comments

Comments
 (0)