Skip to content

Commit c098195

Browse files
ricardoV94 authored and brandonwillard committed
Add unobserved_value_vars property to Model
Fixes pm.Deterministics still using rvs during pm.sample(). Fixes find_MAP() failing when pm.Deterministics were present in the model.
1 parent 7adf05d commit c098195

File tree

7 files changed

+66
-22
lines changed

7 files changed

+66
-22
lines changed

Diff for: pymc3/backends/base.py

+1-11
Original file line numberDiff line numberDiff line change
@@ -61,17 +61,7 @@ def __init__(self, name, model=None, vars=None, test_point=None):
6161
model = modelcontext(model)
6262
self.model = model
6363
if vars is None:
64-
vars = []
65-
for v in model.unobserved_RVs:
66-
var = getattr(v.tag, "value_var", v)
67-
transform = getattr(var.tag, "transform", None)
68-
if transform:
69-
# We need to create and add an un-transformed version of
70-
# each transformed variable
71-
untrans_var = transform.backward(v, var)
72-
untrans_var.name = v.name
73-
vars.append(untrans_var)
74-
vars.append(var)
64+
vars = model.unobserved_value_vars
7565

7666
self.vars = vars
7767
self.varnames = [var.name for var in vars]

Diff for: pymc3/model.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -789,6 +789,29 @@ def value_vars(self):
789789
"""
790790
return [self.rvs_to_values[v] for v in self.free_RVs]
791791

792+
@property
793+
def unobserved_value_vars(self):
794+
"""List of all random variables (including untransformed projections),
795+
as well as deterministics used as inputs and outputs of the model's
796+
log-likelihood graph
797+
"""
798+
vars = []
799+
for rv in self.free_RVs:
800+
value_var = self.rvs_to_values[rv]
801+
transform = getattr(value_var.tag, "transform", None)
802+
if transform is not None:
803+
# We need to create and add an un-transformed version of
804+
# each transformed variable
805+
untrans_value_var = transform.backward(rv, value_var)
806+
untrans_value_var.name = rv.name
807+
vars.append(untrans_value_var)
808+
vars.append(value_var)
809+
810+
# Remove rvs from deterministics graph
811+
deterministics, _ = rvs_to_value_vars(self.deterministics, apply_transforms=True)
812+
813+
return vars + deterministics
814+
792815
@property
793816
def basic_RVs(self):
794817
"""List of random variables the model is defined in terms of
@@ -803,7 +826,7 @@ def basic_RVs(self):
803826

804827
@property
805828
def unobserved_RVs(self):
806-
"""List of all random variable, including deterministic ones.
829+
"""List of all random variables, including deterministic ones.
807830
808831
These are the actual random variable terms that make up the
809832
"sample-space" graph (i.e. you can sample these graphs by compiling them
@@ -1049,10 +1072,12 @@ def make_obs_var(
10491072
self.add_random_variable(observed_rv_var, dims)
10501073
self.observed_RVs.append(observed_rv_var)
10511074

1075+
# Create deterministic that combines observed and missing
10521076
rv_var = at.zeros(data.shape)
10531077
rv_var = at.set_subtensor(rv_var[mask.nonzero()], missing_rv_var)
10541078
rv_var = at.set_subtensor(rv_var[antimask_idx], observed_rv_var)
10551079
rv_var = Deterministic(name, rv_var, self, dims)
1080+
10561081
elif sps.issparse(data):
10571082
data = sparse.basic.as_sparse(data, name=name)
10581083
rv_var.tag.observations = data

Diff for: pymc3/tests/test_missing.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,10 @@ def test_missing_dual_observations():
8080
prior_trace = sample_prior_predictive()
8181
assert {"beta1", "beta2", "theta", "o1", "o2"} <= set(prior_trace.keys())
8282
# TODO: Assert something
83-
trace = sample(chains=1)
83+
trace = sample(chains=1, draws=50)
8484

8585

86-
def test_internal_missing_observations():
86+
def test_interval_missing_observations():
8787
with Model() as model:
8888
obs1 = ma.masked_values([1, 2, -1, 4, -1], value=-1)
8989
obs2 = ma.masked_values([-1, -1, 6, -1, 8], value=-1)
@@ -109,8 +109,8 @@ def test_internal_missing_observations():
109109
assert prior_trace["theta2"].shape[-1] == obs2.shape[0]
110110

111111
# Make sure that the observed values are newly generated samples
112-
assert np.var(prior_trace["theta1_observed"]) > 0.0
113-
assert np.var(prior_trace["theta2_observed"]) > 0.0
112+
assert np.all(np.var(prior_trace["theta1_observed"], 0) > 0.0)
113+
assert np.all(np.var(prior_trace["theta2_observed"], 0) > 0.0)
114114

115115
# Make sure the missing parts of the combined deterministic matches the
116116
# sampled missing and observed variable values
@@ -121,7 +121,7 @@ def test_internal_missing_observations():
121121

122122
assert {"theta1", "theta2"} <= set(prior_trace.keys())
123123

124-
trace = sample(chains=1)
124+
trace = sample(chains=1, draws=50, compute_convergence_checks=False)
125125

126126
assert np.all(0 < trace["theta1_missing"].mean(0))
127127
assert np.all(0 < trace["theta2_missing"].mean(0))

Diff for: pymc3/tests/test_sampling.py

+9
Original file line numberDiff line numberDiff line change
@@ -1098,3 +1098,12 @@ def test_sample_from_xarray_posterior(self, point_list_arg_bug_fixture):
10981098
idat = pm.to_inference_data(trace)
10991099
with pmodel:
11001100
pp = pm.sample_posterior_predictive(idat.posterior, var_names=["d"])
1101+
1102+
1103+
def test_sample_deterministic():
1104+
with pm.Model() as model:
1105+
x = pm.HalfNormal("x", 1)
1106+
y = pm.Deterministic("y", x + 100)
1107+
trace = pm.sample(chains=1, draws=50, compute_convergence_checks=False)
1108+
1109+
np.testing.assert_allclose(trace["y"], trace["x"] + 100)

Diff for: pymc3/tests/test_starting.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,17 @@
1616

1717
from pytest import raises
1818

19-
from pymc3 import Beta, Binomial, Model, Normal, Point, Uniform, find_MAP
19+
from pymc3 import (
20+
Beta,
21+
Binomial,
22+
Deterministic,
23+
Gamma,
24+
Model,
25+
Normal,
26+
Point,
27+
Uniform,
28+
find_MAP,
29+
)
2030
from pymc3.tests.checks import close_to
2131
from pymc3.tests.helpers import select_by_precision
2232
from pymc3.tests.models import non_normal, simple_arbitrary_det, simple_model
@@ -88,6 +98,18 @@ def test_find_MAP():
8898
close_to(map_est2["sigma"], 1, tol)
8999

90100

101+
def test_find_MAP_issue_4488():
102+
# Test for https://github.com/pymc-devs/pymc3/issues/4488
103+
with Model() as m:
104+
x = Gamma("x", alpha=3, beta=10, observed=np.array([1, np.nan]))
105+
y = Deterministic("y", x + 1)
106+
map_estimate = find_MAP()
107+
108+
assert not set.difference({"x", "x_missing", "x_missing_log__", "y"}, set(map_estimate.keys()))
109+
assert np.isclose(map_estimate["x_missing"], 0.2)
110+
np.testing.assert_array_equal(map_estimate["y"], [2.0, map_estimate["x_missing"][0] + 1])
111+
112+
91113
def test_allinmodel():
92114
model1 = Model()
93115
model2 = Model()

Diff for: pymc3/tuning/scaling.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def find_hessian(point, vars=None, model=None):
6060
"""
6161
model = modelcontext(model)
6262
H = model.fastd2logp(vars)
63-
return H(Point(point, model=model))
63+
return H(Point(point, filter_model_vars=True, model=model))
6464

6565

6666
def find_hessian_diag(point, vars=None, model=None):

Diff for: pymc3/tuning/starting.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,7 @@ def dlogp_func(x):
157157

158158
mx0 = RaveledVars(mx0, x0.point_map_info)
159159

160-
vars = get_default_varnames(
161-
[v.tag.value_var for v in model.unobserved_RVs], include_transformed
162-
)
160+
vars = get_default_varnames(model.unobserved_value_vars, include_transformed)
163161
mx = {
164162
var.name: value
165163
for var, value in zip(vars, model.fastfn(vars)(DictToArrayBijection.rmap(mx0)))

0 commit comments

Comments
 (0)