Skip to content

Commit af8a747

Browse files
committed
Allow warmup iterations when not adapting
And, prevent adapting with no warmup
1 parent 46f8608 commit af8a747

File tree

3 files changed

+23
-4
lines changed

3 files changed

+23
-4
lines changed

cmdstanpy/cmdstan_args.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -139,10 +139,9 @@ def validate(self, chains: Optional[int]) -> None:
139139
'Value for iter_warmup must be a non-negative integer,'
140140
' found {}.'.format(self.iter_warmup)
141141
)
142-
if self.iter_warmup > 0 and not self.adapt_engaged:
142+
if self.iter_warmup == 0 and self.adapt_engaged:
143143
raise ValueError(
144-
'Argument "adapt_engaged" is False, '
145-
'cannot specify warmup iterations.'
144+
'Must specify iter_warmup > 0 when adapt_engaged=True.'
146145
)
147146
if self.iter_sampling is not None:
148147
if self.iter_sampling < 0 or not isinstance(

test/test_cmdstan_args.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def test_bad() -> None:
111111
with pytest.raises(ValueError):
112112
args.validate(chains=2)
113113

114-
args = SamplerArgs(iter_warmup=10, adapt_engaged=False)
114+
args = SamplerArgs(iter_warmup=0, adapt_engaged=True)
115115
with pytest.raises(ValueError):
116116
args.validate(chains=2)
117117

test/test_sample.py

+20
Original file line numberDiff line numberDiff line change
@@ -1435,6 +1435,26 @@ def test_dont_save_warmup(caplog: pytest.LogCaptureFixture) -> None:
14351435
)
14361436

14371437

1438+
def test_warmup_no_adapt() -> None:
1439+
# we may want to have a "burn-in" period, even without adaptation
1440+
stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
1441+
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
1442+
1443+
bern_model = CmdStanModel(stan_file=stan)
1444+
bern_fit = bern_model.sample(
1445+
data=jdata,
1446+
chains=2,
1447+
seed=12345,
1448+
iter_warmup=200,
1449+
iter_sampling=100,
1450+
adapt_engaged=False,
1451+
)
1452+
1453+
assert bern_fit.column_names == tuple(BERNOULLI_COLS)
1454+
assert bern_fit.num_draws_sampling == 100
1455+
assert bern_fit.draws().shape == (100, 2, len(BERNOULLI_COLS))
1456+
1457+
14381458
def test_sampler_diags() -> None:
14391459
stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
14401460
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')

0 commit comments

Comments
 (0)