Skip to content

Commit e895a5c

Browse files
committed
Update tests
1 parent e33e517 commit e895a5c

File tree

2 files changed

+46
-11
lines changed

2 files changed

+46
-11
lines changed

pymc/sampling/forward.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -820,7 +820,7 @@ def sample_posterior_predictive(
820820
else:
821821
vars_ = model.observed_RVs + observed_dependent_deterministics(model)
822822
if observed_data is not None:
823-
vars_ += [model[x] for x in observed_data if x in model]
823+
vars_ += [model[x] for x in observed_data if x in model and x not in vars_]
824824

825825
vars_to_sample = list(get_default_varnames(vars_, include_transformed=False))
826826

tests/sampling/test_forward.py

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -481,16 +481,6 @@ def test_normal_scalar(self):
481481
chains=nchains,
482482
)
483483

484-
# test that trace is used in ppc
485-
with pm.Model() as model_ppc:
486-
mu = pm.Normal("mu", 0.0, 1.0)
487-
a = pm.Normal("a", mu=mu, sigma=1)
488-
489-
ppc = pm.sample_posterior_predictive(
490-
trace=trace, model=model_ppc, return_inferencedata=False
491-
)
492-
assert "a" in ppc
493-
494484
with model:
495485
# test list input
496486
ppc0 = pm.sample_posterior_predictive(
@@ -550,6 +540,51 @@ def test_normal_scalar_idata(self):
550540
ppc = pm.sample_posterior_predictive(idata, return_inferencedata=False)
551541
assert ppc["a"].shape == (nchains, ndraws)
552542

543+
def test_external_trace(self):
544+
nchains = 2
545+
ndraws = 500
546+
with pm.Model() as model:
547+
mu = pm.Normal("mu", 0.0, 1.0)
548+
a = pm.Normal("a", mu=mu, sigma=1, observed=0.0)
549+
trace = pm.sample(
550+
draws=ndraws,
551+
chains=nchains,
552+
)
553+
554+
# test that trace is used in ppc
555+
with pm.Model() as model_ppc:
556+
mu = pm.Normal("mu", 0.0, 1.0)
557+
a = pm.Normal("a", mu=mu, sigma=1)
558+
559+
ppc = pm.sample_posterior_predictive(
560+
trace=trace, model=model_ppc, return_inferencedata=False
561+
)
562+
assert list(ppc.keys()) == ["a"]
563+
564+
@pytest.mark.xfail(reason="Auto-imputation of variables not supported in this setting")
565+
def test_external_trace_det(self):
566+
nchains = 2
567+
ndraws = 500
568+
with pm.Model() as model:
569+
mu = pm.Normal("mu", 0.0, 1.0)
570+
a = pm.Normal("a", mu=mu, sigma=1, observed=0.0)
571+
b = pm.Deterministic("b", a + 1)
572+
trace = pm.sample(
573+
draws=ndraws,
574+
chains=nchains,
575+
)
576+
577+
# test that trace is used in ppc
578+
with pm.Model() as model_ppc:
579+
mu = pm.Normal("mu", 0.0, 1.0)
580+
a = pm.Normal("a", mu=mu, sigma=1)
581+
b = pm.Deterministic("b", a + 1)
582+
583+
ppc = pm.sample_posterior_predictive(
584+
trace=trace, model=model_ppc, return_inferencedata=False
585+
)
586+
assert list(ppc.keys()) == ["a", "b"]
587+
553588
def test_normal_vector(self):
554589
with pm.Model() as model:
555590
mu = pm.Normal("mu", 0.0, 1.0)

0 commit comments

Comments
 (0)