Skip to content

Commit 68d5201

Browse files
authored
Update ppc test (#4246)
1 parent b990e49 commit 68d5201

File tree

1 file changed

+21
-22
lines changed

1 file changed

+21
-22
lines changed

Diff for: pymc3/tests/test_sampling.py

+21-22
Original file line numberDiff line numberDiff line change
@@ -587,7 +587,7 @@ def test_model_shared_variable(self):
587587

588588
expected_p = np.array([logistic.eval({coeff: val}) for val in trace["x"][:samples]])
589589
assert post_pred["obs"].shape == (samples, 3)
590-
assert np.allclose(post_pred["p"], expected_p)
590+
npt.assert_allclose(post_pred["p"], expected_p)
591591

592592
# fast version
593593
samples = 100
@@ -598,11 +598,12 @@ def test_model_shared_variable(self):
598598

599599
expected_p = np.array([logistic.eval({coeff: val}) for val in trace["x"][:samples]])
600600
assert post_pred["obs"].shape == (samples, 3)
601-
assert np.allclose(post_pred["p"], expected_p)
601+
npt.assert_allclose(post_pred["p"], expected_p)
602602

603603
def test_deterministic_of_observed(self):
604-
meas_in_1 = pm.theanof.floatX(2 + 4 * np.random.randn(100))
605-
meas_in_2 = pm.theanof.floatX(5 + 4 * np.random.randn(100))
604+
meas_in_1 = pm.theanof.floatX(2 + 4 * np.random.randn(10))
605+
meas_in_2 = pm.theanof.floatX(5 + 4 * np.random.randn(10))
606+
nchains = 2
606607
with pm.Model() as model:
607608
mu_in_1 = pm.Normal("mu_in_1", 0, 1)
608609
sigma_in_1 = pm.HalfNormal("sd_in_1", 1)
@@ -614,40 +615,38 @@ def test_deterministic_of_observed(self):
614615
out_diff = in_1 + in_2
615616
pm.Deterministic("out", out_diff)
616617

617-
trace = pm.sample(100)
618-
ppc_trace = pm.trace_to_dataframe(
619-
trace, varnames=[n for n in trace.varnames if n != "out"]
620-
).to_dict("records")
618+
trace = pm.sample(100, chains=nchains)
619+
np.random.seed(0)
621620
with pytest.warns(DeprecationWarning):
622621
ppc = pm.sample_posterior_predictive(
623622
model=model,
624-
trace=ppc_trace,
625-
samples=len(ppc_trace),
623+
trace=trace,
624+
samples=len(trace) * nchains,
626625
vars=(model.deterministics + model.basic_RVs),
627626
)
628627

629-
rtol = 1e-5 if theano.config.floatX == "float64" else 1e-3
630-
assert np.allclose(ppc["in_1"] + ppc["in_2"], ppc["out"], rtol=rtol)
628+
rtol = 1e-5 if theano.config.floatX == "float64" else 1e-4
629+
npt.assert_allclose(ppc["in_1"] + ppc["in_2"], ppc["out"], rtol=rtol)
631630

631+
np.random.seed(0)
632632
ppc = pm.sample_posterior_predictive(
633633
model=model,
634-
trace=ppc_trace,
635-
samples=len(ppc_trace),
634+
trace=trace,
635+
samples=len(trace) * nchains,
636636
var_names=[var.name for var in (model.deterministics + model.basic_RVs)],
637637
)
638638

639-
rtol = 1e-5 if theano.config.floatX == "float64" else 1e-3
640-
assert np.allclose(ppc["in_1"] + ppc["in_2"], ppc["out"], rtol=rtol)
639+
npt.assert_allclose(ppc["in_1"] + ppc["in_2"], ppc["out"], rtol=rtol)
641640

641+
np.random.seed(0)
642642
ppc = pm.fast_sample_posterior_predictive(
643643
model=model,
644-
trace=ppc_trace,
645-
samples=len(ppc_trace),
644+
trace=trace,
645+
samples=len(trace) * nchains,
646646
var_names=[var.name for var in (model.deterministics + model.basic_RVs)],
647647
)
648648

649-
rtol = 1e-5 if theano.config.floatX == "float64" else 1e-3
650-
assert np.allclose(ppc["in_1"] + ppc["in_2"], ppc["out"], rtol=rtol)
649+
npt.assert_allclose(ppc["in_1"] + ppc["in_2"], ppc["out"], rtol=rtol)
651650

652651
def test_deterministic_of_observed_modified_interface(self):
653652
meas_in_1 = pm.theanof.floatX(2 + 4 * np.random.randn(100))
@@ -675,7 +674,7 @@ def test_deterministic_of_observed_modified_interface(self):
675674
)
676675

677676
rtol = 1e-5 if theano.config.floatX == "float64" else 1e-3
678-
assert np.allclose(ppc["in_1"] + ppc["in_2"], ppc["out"], rtol=rtol)
677+
npt.assert_allclose(ppc["in_1"] + ppc["in_2"], ppc["out"], rtol=rtol)
679678

680679
ppc = pm.fast_sample_posterior_predictive(
681680
model=model,
@@ -685,7 +684,7 @@ def test_deterministic_of_observed_modified_interface(self):
685684
)
686685

687686
rtol = 1e-5 if theano.config.floatX == "float64" else 1e-3
688-
assert np.allclose(ppc["in_1"] + ppc["in_2"], ppc["out"], rtol=rtol)
687+
npt.assert_allclose(ppc["in_1"] + ppc["in_2"], ppc["out"], rtol=rtol)
689688

690689
def test_variable_type(self):
691690
with pm.Model() as model:

0 commit comments

Comments
 (0)