From a592208a3a8637afde1be0db2638fd1db80f5257 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Sun, 12 Jan 2025 17:35:38 +0100 Subject: [PATCH 1/5] Check for observed variables in the trace as well as the model --- pymc/sampling/forward.py | 4 ++++ tests/sampling/test_forward.py | 10 ++++++++++ 2 files changed, 14 insertions(+) diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index c07683555a..c7f52aaedd 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -767,6 +767,7 @@ def sample_posterior_predictive( if "coords" not in idata_kwargs: idata_kwargs["coords"] = {} idata: InferenceData | None = None + observed_data = None stacked_dims = None if isinstance(trace, InferenceData): _constant_data = getattr(trace, "constant_data", None) @@ -774,6 +775,7 @@ def sample_posterior_predictive( trace_coords.update({str(k): v.data for k, v in _constant_data.coords.items()}) constant_data.update({str(k): v.data for k, v in _constant_data.items()}) idata = trace + observed_data = trace["observed_data"] trace = trace["posterior"] if isinstance(trace, xarray.Dataset): trace_coords.update({str(k): v.data for k, v in trace.coords.items()}) @@ -817,6 +819,8 @@ def sample_posterior_predictive( vars_ = [model[x] for x in var_names] else: vars_ = model.observed_RVs + observed_dependent_deterministics(model) + if observed_data is not None: + vars_ += [model[x] for x in observed_data if x in model] vars_to_sample = list(get_default_varnames(vars_, include_transformed=False)) diff --git a/tests/sampling/test_forward.py b/tests/sampling/test_forward.py index 404f74a961..bd5197a04e 100644 --- a/tests/sampling/test_forward.py +++ b/tests/sampling/test_forward.py @@ -481,6 +481,16 @@ def test_normal_scalar(self): chains=nchains, ) + # test that trace is used in ppc + with pm.Model() as model_ppc: + mu = pm.Normal("mu", 0.0, 1.0) + a = pm.Normal("a", mu=mu, sigma=1) + + ppc = pm.sample_posterior_predictive( + trace=trace, model=model_ppc, return_inferencedata=False + ) + assert "a" in ppc + with model: # test list input ppc0 = pm.sample_posterior_predictive( From e33e517940922e609ec39e48ad1ec3ebbad2a0ce Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Sun, 12 Jan 2025 22:59:13 +0100 Subject: [PATCH 2/5] Bugfix --- pymc/sampling/forward.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index c7f52aaedd..a88bfdd0c9 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -775,7 +775,7 @@ def sample_posterior_predictive( trace_coords.update({str(k): v.data for k, v in _constant_data.coords.items()}) constant_data.update({str(k): v.data for k, v in _constant_data.items()}) idata = trace - observed_data = trace["observed_data"] + observed_data = trace.get("observed_data", None) trace = trace["posterior"] if isinstance(trace, xarray.Dataset): trace_coords.update({str(k): v.data for k, v in trace.coords.items()}) From e895a5c4861fbd5dac3cb206c7cd70e42b711980 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Sun, 19 Jan 2025 17:46:48 +0100 Subject: [PATCH 3/5] Update tests --- pymc/sampling/forward.py | 2 +- tests/sampling/test_forward.py | 55 +++++++++++++++++++++++++++------- 2 files changed, 46 insertions(+), 11 deletions(-) diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index a88bfdd0c9..1712352d5b 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -820,7 +820,7 @@ def sample_posterior_predictive( else: vars_ = model.observed_RVs + observed_dependent_deterministics(model) if observed_data is not None: - vars_ += [model[x] for x in observed_data if x in model] + vars_ += [model[x] for x in observed_data if x in model and x not in vars_] vars_to_sample = list(get_default_varnames(vars_, include_transformed=False)) diff --git a/tests/sampling/test_forward.py b/tests/sampling/test_forward.py index bd5197a04e..660b66327d 100644 --- a/tests/sampling/test_forward.py +++ b/tests/sampling/test_forward.py @@ -481,16 +481,6 @@ def test_normal_scalar(self): chains=nchains, ) - # test that trace is used in ppc - with pm.Model() as model_ppc: - mu = pm.Normal("mu", 0.0, 1.0) - a = pm.Normal("a", mu=mu, sigma=1) - - ppc = pm.sample_posterior_predictive( - trace=trace, model=model_ppc, return_inferencedata=False - ) - assert "a" in ppc - with model: # test list input ppc0 = pm.sample_posterior_predictive( @@ -550,6 +540,51 @@ def test_normal_scalar_idata(self): ppc = pm.sample_posterior_predictive(idata, return_inferencedata=False) assert ppc["a"].shape == (nchains, ndraws) + def test_external_trace(self): + nchains = 2 + ndraws = 500 + with pm.Model() as model: + mu = pm.Normal("mu", 0.0, 1.0) + a = pm.Normal("a", mu=mu, sigma=1, observed=0.0) + trace = pm.sample( + draws=ndraws, + chains=nchains, + ) + + # test that trace is used in ppc + with pm.Model() as model_ppc: + mu = pm.Normal("mu", 0.0, 1.0) + a = pm.Normal("a", mu=mu, sigma=1) + + ppc = pm.sample_posterior_predictive( + trace=trace, model=model_ppc, return_inferencedata=False + ) + assert list(ppc.keys()) == ["a"] + + @pytest.mark.xfail(reason="Auto-imputation of variables not supported in this setting") + def test_external_trace_det(self): + nchains = 2 + ndraws = 500 + with pm.Model() as model: + mu = pm.Normal("mu", 0.0, 1.0) + a = pm.Normal("a", mu=mu, sigma=1, observed=0.0) + b = pm.Deterministic("b", a + 1) + trace = pm.sample( + draws=ndraws, + chains=nchains, + ) + + # test that trace is used in ppc + with pm.Model() as model_ppc: + mu = pm.Normal("mu", 0.0, 1.0) + a = pm.Normal("a", mu=mu, sigma=1) + b = pm.Deterministic("b", a + 1) + + ppc = pm.sample_posterior_predictive( + trace=trace, model=model_ppc, return_inferencedata=False + ) + assert list(ppc.keys()) == ["a", "b"] + def test_normal_vector(self): with pm.Model() as model: mu = pm.Normal("mu", 0.0, 1.0) From d813da104d55fd9e1b89ebeb4208d7e0f1963849 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Sun, 19 Jan 2025 19:36:52 +0100 Subject: [PATCH 4/5] Add logic to handle conditional nodes for observed variables --- pymc/sampling/forward.py | 8 ++++++-- tests/sampling/test_forward.py | 5 ++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index 1712352d5b..861f19b886 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -345,10 +345,13 @@ def draw( return [np.stack(v) for v in drawn_values] -def observed_dependent_deterministics(model: Model): +def observed_dependent_deterministics(model: Model, extra_observeds=None): """Find deterministics that depend directly on observed variables.""" + if extra_observeds is None: + extra_observeds = [] + deterministics = model.deterministics - observed_rvs = set(model.observed_RVs) + observed_rvs = set(model.observed_RVs + extra_observeds) blockers = model.basic_RVs return [ deterministic @@ -821,6 +824,7 @@ def sample_posterior_predictive( vars_ = model.observed_RVs + observed_dependent_deterministics(model) if observed_data is not None: vars_ += [model[x] for x in observed_data if x in model and x not in vars_] + vars_ += observed_dependent_deterministics(model, vars_) vars_to_sample = list(get_default_varnames(vars_, include_transformed=False)) diff --git a/tests/sampling/test_forward.py b/tests/sampling/test_forward.py index 660b66327d..4b1acdc64b 100644 --- a/tests/sampling/test_forward.py +++ b/tests/sampling/test_forward.py @@ -561,7 +561,6 @@ def test_external_trace(self): ) assert list(ppc.keys()) == ["a"] - @pytest.mark.xfail(reason="Auto-imputation of variables not supported in this setting") def test_external_trace_det(self): nchains = 2 ndraws = 500 @@ -578,12 +577,12 @@ def test_external_trace_det(self): with pm.Model() as model_ppc: mu = pm.Normal("mu", 0.0, 1.0) a = pm.Normal("a", mu=mu, sigma=1) - b = pm.Deterministic("b", a + 1) + c = pm.Deterministic("c", a + 1) ppc = pm.sample_posterior_predictive( trace=trace, model=model_ppc, return_inferencedata=False ) - assert list(ppc.keys()) == ["a", "b"] + assert list(ppc.keys()) == ["a", "c"] def test_normal_vector(self): with pm.Model() as model: From 2fcf39581e11865d9acdbaf0e7bf7c25d859c0d5 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Mon, 20 Jan 2025 12:26:09 +0100 Subject: [PATCH 5/5] Remove redundant test --- pymc/sampling/forward.py | 8 +++++--- tests/sampling/test_forward.py | 28 +--------------------------- 2 files changed, 6 insertions(+), 30 deletions(-) diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index 861f19b886..d997e00466 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -821,10 +821,12 @@ def sample_posterior_predictive( if var_names is not None: vars_ = [model[x] for x in var_names] else: - vars_ = model.observed_RVs + observed_dependent_deterministics(model) + observed_vars = model.observed_RVs if observed_data is not None: - vars_ += [model[x] for x in observed_data if x in model and x not in vars_] - vars_ += observed_dependent_deterministics(model, vars_) + observed_vars += [ + model[x] for x in observed_data if x in model and x not in observed_vars + ] + vars_ = observed_vars + observed_dependent_deterministics(model, observed_vars) vars_to_sample = list(get_default_varnames(vars_, include_transformed=False)) diff --git a/tests/sampling/test_forward.py b/tests/sampling/test_forward.py index 4b1acdc64b..9348296297 100644 --- a/tests/sampling/test_forward.py +++ b/tests/sampling/test_forward.py @@ -540,38 +540,12 @@ def test_normal_scalar_idata(self): ppc = pm.sample_posterior_predictive(idata, return_inferencedata=False) assert ppc["a"].shape == (nchains, ndraws) - def test_external_trace(self): - nchains = 2 - ndraws = 500 - with pm.Model() as model: - mu = pm.Normal("mu", 0.0, 1.0) - a = pm.Normal("a", mu=mu, sigma=1, observed=0.0) - trace = pm.sample( - draws=ndraws, - chains=nchains, - ) - - # test that trace is used in ppc - with pm.Model() as model_ppc: - mu = pm.Normal("mu", 0.0, 1.0) - a = pm.Normal("a", mu=mu, sigma=1) - - ppc = pm.sample_posterior_predictive( - trace=trace, model=model_ppc, return_inferencedata=False - ) - assert list(ppc.keys()) == ["a"] - def test_external_trace_det(self): - nchains = 2 - ndraws = 500 with pm.Model() as model: mu = pm.Normal("mu", 0.0, 1.0) a = pm.Normal("a", mu=mu, sigma=1, observed=0.0) b = pm.Deterministic("b", a + 1) - trace = pm.sample( - draws=ndraws, - chains=nchains, - ) + trace = pm.sample(tune=50, draws=50, chains=1, compute_convergence_checks=False) # test that trace is used in ppc with pm.Model() as model_ppc: