Skip to content

Commit 48b3862

Browse files
committed
Revise sample_prior/posterior_predictive() parameters.
Add the option to take varnames (`var_names`), rather than var objects as parameters. Extend the set of parameters to accept varnames as an alternative to vars, preserving backwards compatibility. Also revise the docstring, to clarify the return type and add type comments.
1 parent f9916ea commit 48b3862

File tree

4 files changed

+96
-14
lines changed

4 files changed

+96
-14
lines changed

Diff for: RELEASE-NOTES.md

+7
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,15 @@
5151

5252
- `nuts_kwargs` and `step_kwargs` have been deprecated in favor of using the standard `kwargs` to pass optional step method arguments.
5353
- `SGFS` and `CSG` have been removed (Fix for [#3353](https://github.com/pymc-devs/pymc3/issues/3353)). They have been moved to [pymc3-experimental](https://github.com/pymc-devs/pymc3-experimental).
54+
<<<<<<< master
5455
- References to `live_plot` and corresponding notebooks have been removed.
5556
- Function `approx_hessian` was removed, due to `numdifftools` becoming incompatible with current `scipy`. The function was already optional, only available to a user who installed `numdifftools` separately, and not hit on any common codepaths. [#3485](https://github.com/pymc-devs/pymc3/pull/3485).
57+
- Deprecated `vars` parameter of `sample_posterior_predictive` in favor of `varnames`.
58+
=======
59+
- References to `live_plot` and corresponding notebooks have been removed.
60+
- Deprecated `vars` parameters of `sample_posterior_predictive` and `sample_prior_predictive` in favor of `var_names`. At least for the latter, this is more accurate, since the `vars` parameter actually took names.
61+
>>>>>>> Update release notes.
62+
5663

5764
## PyMC3 3.6 (Dec 21 2018)
5865

Diff for: pymc3/model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import itertools
44
import threading
55
import warnings
6-
from typing import Optional, Dict, Any
6+
from typing import Optional
77

88
import numpy as np
99
from pandas import Series

Diff for: pymc3/sampling.py

+52-10
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
from typing import Dict, List, Optional, TYPE_CHECKING, cast
2+
if TYPE_CHECKING:
3+
from typing import Any
4+
from typing import Iterable as TIterable
15
from collections import defaultdict, Iterable
26
from copy import copy
37
import pickle
@@ -6,11 +10,12 @@
610

711
import numpy as np
812
import theano.gradient as tg
13+
from theano.tensor import Tensor
914

1015
from .backends.base import BaseTrace, MultiTrace
1116
from .backends.ndarray import NDArray
1217
from .distributions.distribution import draw_values
13-
from .model import modelcontext, Point, all_continuous
18+
from .model import modelcontext, Point, all_continuous, Model
1419
from .step_methods import (NUTS, HamiltonianMC, Metropolis, BinaryMetropolis,
1520
BinaryGibbsMetropolis, CategoricalGibbsMetropolis,
1621
Slice, CompoundStep, arraystep, smc)
@@ -1026,8 +1031,14 @@ def stop_tuning(step):
10261031
return step
10271032

10281033

1029-
def sample_posterior_predictive(trace, samples=None, model=None, vars=None, size=None,
1030-
random_seed=None, progressbar=True):
1034+
def sample_posterior_predictive(trace,
1035+
samples: Optional[int]=None,
1036+
model: Optional[Model]=None,
1037+
vars: Optional[TIterable[Tensor]]=None,
1038+
var_names: Optional[List[str]]=None,
1039+
size: Optional[int]=None,
1040+
random_seed=None,
1041+
progressbar: bool=True) -> Dict[str, np.ndarray]:
10311042
"""Generate posterior predictive samples from a model given a trace.
10321043
10331044
Parameters
@@ -1041,7 +1052,10 @@ def sample_posterior_predictive(trace, samples=None, model=None, vars=None, size
10411052
Model used to generate `trace`
10421053
vars : iterable
10431054
Variables for which to compute the posterior predictive samples.
1044-
Defaults to `model.observed_RVs`.
1055+
Defaults to `model.observed_RVs`. Deprecated: please use `var_names` instead.
1056+
var_names : Iterable[str]
1057+
Alternative way to specify vars to sample, to make this function orthogonal with
1058+
others.
10451059
size : int
10461060
The number of random draws from the distribution specified by the parameters in each
10471061
sample of the trace.
@@ -1055,7 +1069,7 @@ def sample_posterior_predictive(trace, samples=None, model=None, vars=None, size
10551069
Returns
10561070
-------
10571071
samples : dict
1058-
Dictionary with the variables as keys. The values corresponding to the
1072+
Dictionary with the variable names as keys, and values numpy arrays containing
10591073
posterior predictive samples.
10601074
"""
10611075
len_trace = len(trace)
@@ -1069,6 +1083,14 @@ def sample_posterior_predictive(trace, samples=None, model=None, vars=None, size
10691083

10701084
model = modelcontext(model)
10711085

1086+
if var_names is not None:
1087+
if vars is not None:
1088+
raise ValueError("Should not specify both vars and var_names arguments.")
1089+
else:
1090+
vars = [model[x] for x in var_names]
1091+
elif vars is not None: # var_names is None, and vars is not.
1092+
warnings.warn("vars argument is deprecated in favor of var_names.",
1093+
DeprecationWarning)
10721094
if vars is None:
10731095
vars = model.observed_RVs
10741096

@@ -1080,7 +1102,7 @@ def sample_posterior_predictive(trace, samples=None, model=None, vars=None, size
10801102
if progressbar:
10811103
indices = tqdm(indices, total=samples)
10821104

1083-
ppc_trace = defaultdict(list)
1105+
ppc_trace = defaultdict(list) # type: Dict[str, List[Any]]
10841106
try:
10851107
for idx in indices:
10861108
if nchain > 1:
@@ -1249,18 +1271,28 @@ def sample_ppc_w(*args, **kwargs):
12491271
return sample_posterior_predictive_w(*args, **kwargs)
12501272

12511273

1252-
def sample_prior_predictive(samples=500, model=None, vars=None, random_seed=None):
1274+
def sample_prior_predictive(samples=500,
1275+
model: Optional[Model]=None,
1276+
vars: Optional[TIterable[str]] = None,
1277+
var_names: Optional[TIterable[str]] = None,
1278+
random_seed=None) -> Dict[str, np.ndarray]:
12531279
"""Generate samples from the prior predictive distribution.
12541280
12551281
Parameters
12561282
----------
12571283
samples : int
12581284
Number of samples from the prior predictive to generate. Defaults to 500.
12591285
model : Model (optional if in `with` context)
1260-
vars : iterable
1286+
vars : Iterable[str]
1287+
A list of names of variables for which to compute the posterior predictive
1288+
samples.
1289+
Defaults to `model.named_vars`.
1290+
DEPRECATED - Use `var_names` instead.
1291+
var_names : Iterable[str]
12611292
A list of names of variables for which to compute the posterior predictive
12621293
samples.
12631294
Defaults to `model.named_vars`.
1295+
12641296
random_seed : int
12651297
Seed for the random number generator.
12661298
@@ -1272,8 +1304,16 @@ def sample_prior_predictive(samples=500, model=None, vars=None, random_seed=None
12721304
"""
12731305
model = modelcontext(model)
12741306

1275-
if vars is None:
1307+
if vars is None and var_names is None:
12761308
vars = set(model.named_vars.keys())
1309+
elif vars is None:
1310+
vars = var_names
1311+
elif vars is not None:
1312+
warnings.warn("vars argument is deprecated in favor of var_names.",
1313+
DeprecationWarning)
1314+
else:
1315+
raise ValueError("Cannot supply both vars and var_names arguments.")
1316+
vars = cast(TIterable[str], vars) # tell mypy that vars cannot be None here.
12771317

12781318
if random_seed is not None:
12791319
np.random.seed(random_seed)
@@ -1282,8 +1322,10 @@ def sample_prior_predictive(samples=500, model=None, vars=None, random_seed=None
12821322
values = draw_values([model[name] for name in names], size=samples)
12831323

12841324
data = {k: v for k, v in zip(names, values)}
1325+
if data is None:
1326+
raise AssertionError("No variables sampled: attempting to sample %s"%names)
12851327

1286-
prior = {}
1328+
prior = {} # type: Dict[str, np.ndarray]
12871329
for var_name in vars:
12881330
if var_name in data:
12891331
prior[var_name] = data[var_name]

Diff for: pymc3/tests/test_sampling.py

+36-3
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,36 @@ def test_model_shared_variable(self):
359359
assert np.allclose(post_pred["p"], expected_p)
360360

361361
def test_deterministic_of_observed(self):
362+
meas_in_1 = pm.theanof.floatX(2 + 4 * np.random.randn(100))
363+
meas_in_2 = pm.theanof.floatX(5 + 4 * np.random.randn(100))
364+
with pm.Model() as model:
365+
mu_in_1 = pm.Normal("mu_in_1", 0, 1)
366+
sigma_in_1 = pm.HalfNormal("sd_in_1", 1)
367+
mu_in_2 = pm.Normal("mu_in_2", 0, 1)
368+
sigma_in_2 = pm.HalfNormal("sd__in_2", 1)
369+
370+
in_1 = pm.Normal("in_1", mu_in_1, sigma_in_1, observed=meas_in_1)
371+
in_2 = pm.Normal("in_2", mu_in_2, sigma_in_2, observed=meas_in_2)
372+
out_diff = in_1 + in_2
373+
pm.Deterministic("out", out_diff)
374+
375+
trace = pm.sample(100)
376+
ppc_trace = pm.trace_to_dataframe(
377+
trace, varnames=[n for n in trace.varnames if n != "out"]
378+
).to_dict("records")
379+
with pytest.warns(DeprecationWarning):
380+
ppc = pm.sample_posterior_predictive(
381+
model=model,
382+
trace=ppc_trace,
383+
samples=len(ppc_trace),
384+
vars=(model.deterministics + model.basic_RVs)
385+
)
386+
387+
rtol = 1e-5 if theano.config.floatX == "float64" else 1e-3
388+
assert np.allclose(ppc["in_1"] + ppc["in_2"], ppc["out"], rtol=rtol)
389+
390+
391+
def test_deterministic_of_observed_modified_interface(self):
362392
meas_in_1 = pm.theanof.floatX(2 + 4 * np.random.randn(100))
363393
meas_in_2 = pm.theanof.floatX(5 + 4 * np.random.randn(100))
364394
with pm.Model() as model:
@@ -380,7 +410,7 @@ def test_deterministic_of_observed(self):
380410
model=model,
381411
trace=ppc_trace,
382412
samples=len(ppc_trace),
383-
vars=(model.deterministics + model.basic_RVs),
413+
var_names=[x.name for x in (model.deterministics + model.basic_RVs)],
384414
)
385415

386416
rtol = 1e-5 if theano.config.floatX == "float64" else 1e-3
@@ -466,10 +496,13 @@ def test_respects_shape(self):
466496
with pm.Model():
467497
mu = pm.Gamma("mu", 3, 1, shape=1)
468498
goals = pm.Poisson("goals", mu, shape=shape)
469-
trace = pm.sample_prior_predictive(10)
499+
with pytest.warns(DeprecationWarning):
500+
trace1 = pm.sample_prior_predictive(10, vars=['mu', 'goals'])
501+
trace2 = pm.sample_prior_predictive(10, var_names=['mu', 'goals'])
470502
if shape == 2: # want to test shape as an int
471503
shape = (2,)
472-
assert trace["goals"].shape == (10,) + shape
504+
assert trace1["goals"].shape == (10,) + shape
505+
assert trace2["goals"].shape == (10,) + shape
473506

474507
def test_multivariate(self):
475508
with pm.Model():

0 commit comments

Comments
 (0)