Skip to content

Commit a90457a

Browse files
authored
Bring back SMC and allow prior_predictive_sampling to return transformed values (#4769)
* Enable prior_predictive to return transformed values * Add test which closes #4490 * Fix SMC regression and re-enable `test_smc.py` * Minor changes to the `pytest.yml` comments * Add workaround for floatX == 'float32' and discrete variables
1 parent 7e35cdd commit a90457a

File tree

6 files changed

+114
-8
lines changed

6 files changed

+114
-8
lines changed

Diff for: .github/workflows/pytest.yml

+3-2
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@ jobs:
1313
floatx: [float32, float64]
1414
test-subset:
1515
# Tests are split into multiple jobs to accelerate the CI.
16+
# Different jobs should be organized to take approximately the same
17+
# time to complete (and not be prohibitively slow).
1618
#
1719
# How this works:
1820
# 1st block: Only passes --ignore parameters to pytest.
1921
# → pytest will run all test_*.py files that are NOT ignored.
20-
# Other blocks: Only pass paths to test files.
22+
# Subsequent blocks: Only pass paths to test files.
2123
# → pytest will run only these files
2224
#
2325
# Any test that was not ignored runs in the first job.
@@ -30,7 +32,6 @@ jobs:
3032
--ignore=pymc3/tests/test_modelcontext.py
3133
--ignore=pymc3/tests/test_parallel_sampling.py
3234
--ignore=pymc3/tests/test_profile.py
33-
--ignore=pymc3/tests/test_smc.py
3435
--ignore=pymc3/tests/test_step.py
3536
--ignore=pymc3/tests/test_tuning.py
3637
--ignore=pymc3/tests/test_types.py

Diff for: RELEASE-NOTES.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
- The GLM submodule has been removed, please use [Bambi](https://bambinos.github.io/bambi/) instead.
88
- The `Distribution` keyword argument `testval` has been deprecated in favor of `initval`.
99
- `pm.sample` now returns results as `InferenceData` instead of `MultiTrace` by default (see [#4744](https://github.com/pymc-devs/pymc3/pull/4744)).
10-
- ...
10+
- `pm.sample_prior_predictive` no longer returns transformed variable values by default. Pass them by name in `var_names` if you want to obtain these draws (see [#4769](https://github.com/pymc-devs/pymc3/pull/4769)).
11+
...
1112

1213
### New Features
1314
- The `CAR` distribution has been added to allow for use of conditional autoregressions which often are used in spatial and network models.

Diff for: pymc3/sampling.py

+21-2
Original file line numberDiff line numberDiff line change
@@ -1943,7 +1943,8 @@ def sample_prior_predictive(
19431943
model : Model (optional if in ``with`` context)
19441944
var_names : Iterable[str]
19451945
A list of names of variables for which to compute the prior predictive
1946-
samples. Defaults to both observed and unobserved RVs.
1946+
samples. Defaults to both observed and unobserved RVs. Transformed values
1947+
are not included unless explicitly listed by name in ``var_names``.
19471948
random_seed : int
19481949
Seed for the random number generator.
19491950
mode:
@@ -1983,8 +1984,26 @@ def sample_prior_predictive(
19831984
)
19841985

19851986
names = get_default_varnames(vars_, include_transformed=False)
1986-
19871987
vars_to_sample = [model[name] for name in names]
1988+
1989+
# Any variables from var_names that are missing must be transformed variables.
1990+
# Misspelled variables would have raised a KeyError above.
1991+
missing_names = vars_.difference(names)
1992+
for name in missing_names:
1993+
transformed_value_var = model[name]
1994+
rv_var = model.values_to_rvs[transformed_value_var]
1995+
transform = transformed_value_var.tag.transform
1996+
transformed_rv_var = transform.forward(rv_var, rv_var)
1997+
1998+
names.append(name)
1999+
vars_to_sample.append(transformed_rv_var)
2000+
2001+
# If the user asked for the transformed variable in var_names, but not the
2002+
# original RV, we add it manually here
2003+
if rv_var.name not in names:
2004+
names.append(rv_var.name)
2005+
vars_to_sample.append(rv_var)
2006+
19882007
inputs = [i for i in inputvars(vars_to_sample) if not isinstance(i, SharedVariable)]
19892008

19902009
sampler_fn = compile_rv_inplace(

Diff for: pymc3/smc/smc.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,14 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import warnings
16+
1517
from collections import OrderedDict
1618

1719
import aesara.tensor as at
1820
import numpy as np
1921

22+
from aesara import config
2023
from aesara import function as aesara_function
2124
from scipy.special import logsumexp
2225
from scipy.stats import multivariate_normal
@@ -87,7 +90,7 @@ def initialize_population(self):
8790
if self.start is None:
8891
init_rnd = sample_prior_predictive(
8992
self.draws,
90-
var_names=[v.name for v in self.model.unobserved_RVs],
93+
var_names=[v.name for v in self.model.unobserved_value_vars],
9194
model=self.model,
9295
)
9396
else:
@@ -290,9 +293,21 @@ def logp_forw(point, out_vars, vars, shared):
290293
shared: List
291294
containing :class:`aesara.tensor.Tensor` for depended shared data
292295
"""
296+
293297
out_list, inarray0 = join_nonshared_inputs(point, out_vars, vars, shared)
294-
f = aesara_function([inarray0], out_list[0])
295-
f.trust_input = True
298+
# TODO: Figure out how to safely accept float32 (floatX) input when there are
299+
# discrete variables of int64 dtype in `vars`.
300+
# See https://github.com/pymc-devs/pymc3/pull/4769#issuecomment-861494080
301+
if config.floatX == "float32" and any(var.dtype == "int64" for var in vars):
302+
warnings.warn(
303+
"SMC sampling may run slower due to the presence of discrete variables "
304+
"together with aesara.config.floatX == `float32`",
305+
UserWarning,
306+
)
307+
f = aesara_function([inarray0], out_list[0], allow_input_downcast=True)
308+
else:
309+
f = aesara_function([inarray0], out_list[0])
310+
f.trust_input = False
296311
return f
297312

298313

Diff for: pymc3/tests/test_sampling.py

+60
Original file line numberDiff line numberDiff line change
@@ -1076,6 +1076,66 @@ def test_potentials_warning(self):
10761076
with pytest.warns(UserWarning, match=warning_msg):
10771077
pm.sample_prior_predictive(samples=5)
10781078

1079+
def test_transformed_vars(self):
1080+
# Test that prior predictive returns transformation of RVs when these are
1081+
# passed explicitly in `var_names`
1082+
1083+
def ub_interval_forward(x, ub):
1084+
# Interval transform assuming lower bound is zero
1085+
return np.log(x - 0) - np.log(ub - x)
1086+
1087+
with pm.Model(rng_seeder=123) as model:
1088+
ub = pm.HalfNormal("ub", 10)
1089+
x = pm.Uniform("x", 0, ub)
1090+
1091+
prior = pm.sample_prior_predictive(
1092+
var_names=["ub", "ub_log__", "x", "x_interval__"],
1093+
samples=10,
1094+
)
1095+
1096+
# Check values are correct
1097+
assert np.allclose(prior["ub_log__"], np.log(prior["ub"]))
1098+
assert np.allclose(
1099+
prior["x_interval__"],
1100+
ub_interval_forward(prior["x"], prior["ub"]),
1101+
)
1102+
1103+
# Check that it works when the original RVs are not mentioned in var_names
1104+
with pm.Model(rng_seeder=123) as model_transformed_only:
1105+
ub = pm.HalfNormal("ub", 10)
1106+
x = pm.Uniform("x", 0, ub)
1107+
1108+
prior_transformed_only = pm.sample_prior_predictive(
1109+
var_names=["ub_log__", "x_interval__"],
1110+
samples=10,
1111+
)
1112+
assert "ub" not in prior_transformed_only and "x" not in prior_transformed_only
1113+
assert np.allclose(prior["ub_log__"], prior_transformed_only["ub_log__"])
1114+
assert np.allclose(prior["x_interval__"], prior_transformed_only["x_interval__"])
1115+
1116+
def test_issue_4490(self):
1117+
# Test that samples do not depend on var_name order or, more fundamentally,
1118+
# that they do not depend on the set order used inside `sample_prior_predictive`
1119+
seed = 4490
1120+
with pm.Model(rng_seeder=seed) as m1:
1121+
a = pm.Normal("a")
1122+
b = pm.Normal("b")
1123+
c = pm.Normal("c")
1124+
d = pm.Normal("d")
1125+
prior1 = pm.sample_prior_predictive(samples=1, var_names=["a", "b", "c", "d"])
1126+
1127+
with pm.Model(rng_seeder=seed) as m2:
1128+
a = pm.Normal("a")
1129+
b = pm.Normal("b")
1130+
c = pm.Normal("c")
1131+
d = pm.Normal("d")
1132+
prior2 = pm.sample_prior_predictive(samples=1, var_names=["b", "a", "d", "c"])
1133+
1134+
assert prior1["a"] == prior2["a"]
1135+
assert prior1["b"] == prior2["b"]
1136+
assert prior1["c"] == prior2["c"]
1137+
assert prior1["d"] == prior2["d"]
1138+
10791139

10801140
class TestSamplePosteriorPredictive:
10811141
def test_point_list_arg_bug_spp(self, point_list_arg_bug_fixture):

Diff for: pymc3/tests/test_smc.py

+10
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import aesara
1516
import aesara.tensor as at
1617
import numpy as np
1718
import pytest
@@ -97,7 +98,16 @@ def test_start(self):
9798
}
9899
trace = pm.sample_smc(500, start=start)
99100

101+
def test_slowdown_warning(self):
102+
with aesara.config.change_flags(floatX="float32"):
103+
with pytest.warns(UserWarning, match="SMC sampling may run slower due to"):
104+
with pm.Model() as model:
105+
a = pm.Poisson("a", 5)
106+
y = pm.Normal("y", a, 5, observed=[1, 2, 3, 4])
107+
trace = pm.sample_smc()
100108

109+
110+
@pytest.mark.xfail(reason="SMC-ABC not refactored yet")
101111
class TestSMCABC(SeededTest):
102112
def setup_class(self):
103113
super().setup_class()

0 commit comments

Comments
 (0)