Skip to content

Commit 4e2c099

Browse files
authored
UserWarning if doing predictive sampling with models containing Potentials (#4419)
* Raise warning when sampling with Potentials * Add warning to fast_sample_ppc and unittests * Add release note * Avoid sampling in unittests
1 parent 2a3d9a3 commit 4e2c099

File tree

4 files changed

+68
-0
lines changed

4 files changed

+68
-0
lines changed

Diff for: RELEASE-NOTES.md

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ It also brings some dreadfully awaited fixes, so be sure to go through the chang
3434
- Fixed `MatrixNormal` random method to work with parameters as random variables. (see [#4368](https://github.com/pymc-devs/pymc3/pull/4368))
3535
- Update the `logcdf` method of several continuous distributions to return -inf for invalid parameters and values, and raise an informative error when multiple values cannot be evaluated in a single call. (see [4393](https://github.com/pymc-devs/pymc3/pull/4393))
3636
- Improve numerical stability in `logp` and `logcdf` methods of `ExGaussian` (see [#4407](https://github.com/pymc-devs/pymc3/pull/4407))
37+
- Issue UserWarning when doing prior or posterior predictive sampling with models containing Potential factors (see [#4419](https://github.com/pymc-devs/pymc3/pull/4419))
3738

3839
## PyMC3 3.10.0 (7 December 2020)
3940

Diff for: pymc3/distributions/posterior_predictive.py

+8
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,14 @@ def fast_sample_posterior_predictive(
222222

223223
model = modelcontext(model)
224224
assert model is not None
225+
226+
if model.potentials:
227+
warnings.warn(
228+
"The effect of Potentials on other parameters is ignored during posterior predictive sampling. "
229+
"This is likely to lead to invalid or biased predictive samples.",
230+
UserWarning,
231+
)
232+
225233
with model:
226234

227235
if keep_size and samples is not None:

Diff for: pymc3/sampling.py

+23
Original file line numberDiff line numberDiff line change
@@ -1692,6 +1692,13 @@ def sample_posterior_predictive(
16921692

16931693
model = modelcontext(model)
16941694

1695+
if model.potentials:
1696+
warnings.warn(
1697+
"The effect of Potentials on other parameters is ignored during posterior predictive sampling. "
1698+
"This is likely to lead to invalid or biased predictive samples.",
1699+
UserWarning,
1700+
)
1701+
16951702
if var_names is not None:
16961703
vars_ = [model[x] for x in var_names]
16971704
else:
@@ -1791,6 +1798,15 @@ def sample_posterior_predictive_w(
17911798
if models is None:
17921799
models = [modelcontext(models)] * len(traces)
17931800

1801+
for model in models:
1802+
if model.potentials:
1803+
warnings.warn(
1804+
"The effect of Potentials on other parameters is ignored during posterior predictive sampling. "
1805+
"This is likely to lead to invalid or biased predictive samples.",
1806+
UserWarning,
1807+
)
1808+
break
1809+
17941810
if weights is None:
17951811
weights = [1] * len(traces)
17961812

@@ -1903,6 +1919,13 @@ def sample_prior_predictive(
19031919
"""
19041920
model = modelcontext(model)
19051921

1922+
if model.potentials:
1923+
warnings.warn(
1924+
"The effect of Potentials on other parameters is ignored during prior predictive sampling. "
1925+
"This is likely to lead to invalid or biased predictive samples.",
1926+
UserWarning,
1927+
)
1928+
19061929
if var_names is None:
19071930
prior_pred_vars = model.observed_RVs
19081931
prior_vars = (

Diff for: pymc3/tests/test_sampling.py

+36
Original file line numberDiff line numberDiff line change
@@ -722,6 +722,21 @@ def test_variable_type(self):
722722
assert ppc["a"].dtype.kind == "f"
723723
assert ppc["b"].dtype.kind == "i"
724724

725+
def test_potentials_warning(self):
726+
warning_msg = "The effect of Potentials on other parameters is ignored during"
727+
with pm.Model() as m:
728+
a = pm.Normal("a", 0, 1)
729+
p = pm.Potential("p", a + 1)
730+
obs = pm.Normal("obs", a, 1, observed=5)
731+
732+
trace = az.from_dict({"a": np.random.rand(10)})
733+
with m:
734+
with pytest.warns(UserWarning, match=warning_msg):
735+
pm.sample_posterior_predictive(trace, samples=5)
736+
737+
with pytest.warns(UserWarning, match=warning_msg):
738+
pm.fast_sample_posterior_predictive(trace, samples=5)
739+
725740

726741
class TestSamplePPCW(SeededTest):
727742
def test_sample_posterior_predictive_w(self):
@@ -773,6 +788,17 @@ def test_sample_posterior_predictive_w(self):
773788
):
774789
pm.sample_posterior_predictive_w([trace_0, trace_2], 100, [model_0, model_2])
775790

791+
def test_potentials_warning(self):
792+
warning_msg = "The effect of Potentials on other parameters is ignored during"
793+
with pm.Model() as m:
794+
a = pm.Normal("a", 0, 1)
795+
p = pm.Potential("p", a + 1)
796+
obs = pm.Normal("obs", a, 1, observed=5)
797+
798+
trace = az.from_dict({"a": np.random.rand(10)})
799+
with pytest.warns(UserWarning, match=warning_msg):
800+
pm.sample_posterior_predictive_w(samples=5, traces=[trace, trace], models=[m, m])
801+
776802

777803
@pytest.mark.parametrize(
778804
"method",
@@ -1012,6 +1038,16 @@ def test_bounded_dist(self):
10121038
prior_trace = pm.sample_prior_predictive(5)
10131039
assert prior_trace["x"].shape == (5, 3, 1)
10141040

1041+
def test_potentials_warning(self):
1042+
warning_msg = "The effect of Potentials on other parameters is ignored during"
1043+
with pm.Model() as m:
1044+
a = pm.Normal("a", 0, 1)
1045+
p = pm.Potential("p", a + 1)
1046+
1047+
with m:
1048+
with pytest.warns(UserWarning, match=warning_msg):
1049+
pm.sample_prior_predictive(samples=5)
1050+
10151051

10161052
class TestSamplePosteriorPredictive:
10171053
def test_point_list_arg_bug_fspp(self, point_list_arg_bug_fixture):

0 commit comments

Comments
 (0)