Skip to content

Commit 45cb4eb

Browse files
ricardoV94 authored and brandonwillard committed
Add auto_deterministics list to Model
Ensures that when missing variables are present in the model, the automatic deterministic (x_observed + x_missing) appears only in predictive sampling and not in normal sampling. Fixes `x` missing from prior_predictive when missing values were present (only `x_missing` was present).
1 parent c098195 commit 45cb4eb

File tree

5 files changed

+41
-8
lines changed

5 files changed

+41
-8
lines changed

Diff for: pymc3/model.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -621,6 +621,7 @@ def __init__(self, name="", model=None, aesara_config=None, coords=None, check_b
621621
self.rvs_to_values = treedict(parent=self.parent.rvs_to_values)
622622
self.free_RVs = treelist(parent=self.parent.free_RVs)
623623
self.observed_RVs = treelist(parent=self.parent.observed_RVs)
624+
self.auto_deterministics = treelist(parent=self.parent.auto_deterministics)
624625
self.deterministics = treelist(parent=self.parent.deterministics)
625626
self.potentials = treelist(parent=self.parent.potentials)
626627
else:
@@ -629,6 +630,7 @@ def __init__(self, name="", model=None, aesara_config=None, coords=None, check_b
629630
self.rvs_to_values = treedict()
630631
self.free_RVs = treelist()
631632
self.observed_RVs = treelist()
633+
self.auto_deterministics = treelist()
632634
self.deterministics = treelist()
633635
self.potentials = treelist()
634636

@@ -1076,7 +1078,7 @@ def make_obs_var(
10761078
rv_var = at.zeros(data.shape)
10771079
rv_var = at.set_subtensor(rv_var[mask.nonzero()], missing_rv_var)
10781080
rv_var = at.set_subtensor(rv_var[antimask_idx], observed_rv_var)
1079-
rv_var = Deterministic(name, rv_var, self, dims)
1081+
rv_var = Deterministic(name, rv_var, self, dims, auto=True)
10801082

10811083
elif sps.issparse(data):
10821084
data = sparse.basic.as_sparse(data, name=name)
@@ -1594,21 +1596,28 @@ def __call__(self, *args, **kwargs):
15941596
compilef = fastfn
15951597

15961598

1597-
def Deterministic(name, var, model=None, dims=None):
1599+
def Deterministic(name, var, model=None, dims=None, auto=False):
15981600
"""Create a named deterministic variable
15991601
16001602
Parameters
16011603
----------
16021604
name: str
16031605
var: Aesara variables
1606+
auto: bool
1607+
Add automatically created deterministics (e.g., when imputing missing values)
1608+
to a separate model.auto_deterministics list for filtering during sampling.
1609+
16041610
16051611
Returns
16061612
-------
16071613
var: var, with name attribute
16081614
"""
16091615
model = modelcontext(model)
16101616
var = var.copy(model.name_for(name))
1611-
model.deterministics.append(var)
1617+
if auto:
1618+
model.auto_deterministics.append(var)
1619+
else:
1620+
model.deterministics.append(var)
16121621
model.add_random_variable(var, dims)
16131622

16141623
return var

Diff for: pymc3/sampling.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1679,7 +1679,7 @@ def sample_posterior_predictive(
16791679
if var_names is not None:
16801680
vars_ = [model[x] for x in var_names]
16811681
else:
1682-
vars_ = model.observed_RVs
1682+
vars_ = model.observed_RVs + model.auto_deterministics
16831683

16841684
if random_seed is not None:
16851685
# np.random.seed(random_seed)
@@ -1955,7 +1955,7 @@ def sample_prior_predictive(
19551955
)
19561956

19571957
if var_names is None:
1958-
prior_pred_vars = model.observed_RVs
1958+
prior_pred_vars = model.observed_RVs + model.auto_deterministics
19591959
prior_vars = (
19601960
get_default_varnames(model.unobserved_RVs, include_transformed=True) + model.potentials
19611961
)

Diff for: pymc3/tests/test_idata_conversion.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ def test_missing_data_model(self):
309309
assert "y_missing" in model.named_vars
310310

311311
test_dict = {
312-
"posterior": ["x", "y", "y_missing"],
312+
"posterior": ["x", "y_missing"],
313313
"observed_data": ["y_observed"],
314314
"log_likelihood": ["y_observed"],
315315
}

Diff for: pymc3/tests/test_missing.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from pymc3.distributions.transforms import Interval
2424
from pymc3.exceptions import ImputationWarning
2525
from pymc3.model import Model
26-
from pymc3.sampling import sample, sample_prior_predictive
26+
from pymc3.sampling import sample, sample_posterior_predictive, sample_prior_predictive
2727

2828

2929
@pytest.mark.parametrize(
@@ -125,6 +125,16 @@ def test_interval_missing_observations():
125125

126126
assert np.all(0 < trace["theta1_missing"].mean(0))
127127
assert np.all(0 < trace["theta2_missing"].mean(0))
128+
assert "theta1" not in trace.varnames
129+
assert "theta2" not in trace.varnames
130+
131+
# Make sure that the observed values are newly generated samples and that
132+
# the observed and deterministic match
133+
pp_trace = sample_posterior_predictive(trace)
134+
assert np.all(np.var(pp_trace["theta1"], 0) > 0.0)
135+
assert np.all(np.var(pp_trace["theta2"], 0) > 0.0)
136+
assert np.mean(pp_trace["theta1"][:, ~obs1.mask] - pp_trace["theta1_observed"]) == 0.0
137+
assert np.mean(pp_trace["theta2"][:, ~obs2.mask] - pp_trace["theta2_observed"]) == 0.0
128138

129139

130140
def test_double_counting():
@@ -139,3 +149,17 @@ def test_double_counting():
139149

140150
logp_val = m2.logp({"x_missing_log__": np.array([0])})
141151
assert logp_val == -4.0
152+
153+
154+
def test_missing_logp():
155+
with Model() as m:
156+
theta1 = Normal("theta1", 0, 5, observed=[0, 1, 2, 3, 4])
157+
theta2 = Normal("theta2", mu=theta1, observed=[0, 1, 2, 3, 4])
158+
m_logp = m.logp()
159+
160+
with Model() as m_missing:
161+
theta1 = Normal("theta1", 0, 5, observed=np.array([0, 1, np.nan, 3, np.nan]))
162+
theta2 = Normal("theta2", mu=theta1, observed=np.array([np.nan, np.nan, 2, np.nan, 4]))
163+
m_missing_logp = m_missing.logp({"theta1_missing": [2, 4], "theta2_missing": [0, 1, 3]})
164+
165+
assert m_logp == m_missing_logp

Diff for: pymc3/tests/test_starting.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def test_find_MAP_issue_4488():
105105
y = Deterministic("y", x + 1)
106106
map_estimate = find_MAP()
107107

108-
assert not set.difference({"x", "x_missing", "x_missing_log__", "y"}, set(map_estimate.keys()))
108+
assert not set.difference({"x_missing", "x_missing_log__", "y"}, set(map_estimate.keys()))
109109
assert np.isclose(map_estimate["x_missing"], 0.2)
110110
np.testing.assert_array_equal(map_estimate["y"], [2.0, map_estimate["x_missing"][0] + 1])
111111

0 commit comments

Comments
 (0)