-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Check for observed variables in the trace #7641
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
a592208
e33e517
e895a5c
d813da1
2fcf395
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -345,10 +345,13 @@ def draw( | |
return [np.stack(v) for v in drawn_values] | ||
|
||
|
||
def observed_dependent_deterministics(model: Model): | ||
def observed_dependent_deterministics(model: Model, extra_observeds=None): | ||
"""Find deterministics that depend directly on observed variables.""" | ||
if extra_observeds is None: | ||
extra_observeds = [] | ||
|
||
deterministics = model.deterministics | ||
observed_rvs = set(model.observed_RVs) | ||
observed_rvs = set(model.observed_RVs + extra_observeds) | ||
blockers = model.basic_RVs | ||
return [ | ||
deterministic | ||
|
@@ -767,13 +770,15 @@ def sample_posterior_predictive( | |
if "coords" not in idata_kwargs: | ||
idata_kwargs["coords"] = {} | ||
idata: InferenceData | None = None | ||
observed_data = None | ||
stacked_dims = None | ||
if isinstance(trace, InferenceData): | ||
_constant_data = getattr(trace, "constant_data", None) | ||
if _constant_data is not None: | ||
trace_coords.update({str(k): v.data for k, v in _constant_data.coords.items()}) | ||
constant_data.update({str(k): v.data for k, v in _constant_data.items()}) | ||
idata = trace | ||
observed_data = trace.get("observed_data", None) | ||
trace = trace["posterior"] | ||
if isinstance(trace, xarray.Dataset): | ||
trace_coords.update({str(k): v.data for k, v in trace.coords.items()}) | ||
|
@@ -817,6 +822,9 @@ def sample_posterior_predictive( | |
vars_ = [model[x] for x in var_names] | ||
else: | ||
vars_ = model.observed_RVs + observed_dependent_deterministics(model) | ||
if observed_data is not None: | ||
vars_ += [model[x] for x in observed_data if x in model and x not in vars_] | ||
vars_ += observed_dependent_deterministics(model, vars_) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is going to duplicate deterministics of an observed variable, in case there's a mix of observed model variables and implied observed variables from the idata. the |
||
|
||
vars_to_sample = list(get_default_varnames(vars_, include_transformed=False)) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -540,6 +540,50 @@ def test_normal_scalar_idata(self): | |
ppc = pm.sample_posterior_predictive(idata, return_inferencedata=False) | ||
assert ppc["a"].shape == (nchains, ndraws) | ||
|
||
def test_external_trace(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Remove this test? The one with |
||
nchains = 2 | ||
ndraws = 500 | ||
with pm.Model() as model: | ||
mu = pm.Normal("mu", 0.0, 1.0) | ||
a = pm.Normal("a", mu=mu, sigma=1, observed=0.0) | ||
trace = pm.sample( | ||
draws=ndraws, | ||
chains=nchains, | ||
) | ||
|
||
# test that trace is used in ppc | ||
with pm.Model() as model_ppc: | ||
mu = pm.Normal("mu", 0.0, 1.0) | ||
a = pm.Normal("a", mu=mu, sigma=1) | ||
|
||
ppc = pm.sample_posterior_predictive( | ||
trace=trace, model=model_ppc, return_inferencedata=False | ||
) | ||
assert list(ppc.keys()) == ["a"] | ||
|
||
def test_external_trace_det(self): | ||
nchains = 2 | ||
ndraws = 500 | ||
with pm.Model() as model: | ||
mu = pm.Normal("mu", 0.0, 1.0) | ||
a = pm.Normal("a", mu=mu, sigma=1, observed=0.0) | ||
b = pm.Deterministic("b", a + 1) | ||
trace = pm.sample( | ||
draws=ndraws, | ||
chains=nchains, | ||
) | ||
zaxtax marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# test that trace is used in ppc | ||
with pm.Model() as model_ppc: | ||
mu = pm.Normal("mu", 0.0, 1.0) | ||
a = pm.Normal("a", mu=mu, sigma=1) | ||
c = pm.Deterministic("c", a + 1) | ||
|
||
ppc = pm.sample_posterior_predictive( | ||
trace=trace, model=model_ppc, return_inferencedata=False | ||
) | ||
assert list(ppc.keys()) == ["a", "c"] | ||
|
||
def test_normal_vector(self): | ||
with pm.Model() as model: | ||
mu = pm.Normal("mu", 0.0, 1.0) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BTW, the `observed_dependent_deterministics` above is not going to work if these variables are not observed in the model. That happens with auto-imputation models, which I assume the `as_model` wrapper won't handle correctly either, because the models are different depending on whether you pass data or not.
Just something to keep in mind.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with this. You'll have to adapt `observed_dependent_deterministics` to also accept a list of extra variables that will depend on your `observed_data`