1
+ from typing import Dict , List , Optional , TYPE_CHECKING , cast
2
+ if TYPE_CHECKING :
3
+ from typing import Any
4
+ from typing import Iterable as TIterable
1
5
from collections import defaultdict , Iterable
2
6
from copy import copy
3
7
import pickle
6
10
7
11
import numpy as np
8
12
import theano .gradient as tg
13
+ from theano .tensor import Tensor
9
14
10
15
from .backends .base import BaseTrace , MultiTrace
11
16
from .backends .ndarray import NDArray
12
17
from .distributions .distribution import draw_values
13
- from .model import modelcontext , Point , all_continuous
18
+ from .model import modelcontext , Point , all_continuous , Model
14
19
from .step_methods import (NUTS , HamiltonianMC , Metropolis , BinaryMetropolis ,
15
20
BinaryGibbsMetropolis , CategoricalGibbsMetropolis ,
16
21
Slice , CompoundStep , arraystep , smc )
@@ -529,7 +534,6 @@ def _sample_population(draws, chain, chains, start, random_seed, step, tune,
529
534
def _sample (chain , progressbar , random_seed , start , draws = None , step = None ,
530
535
trace = None , tune = None , model = None , ** kwargs ):
531
536
skip_first = kwargs .get ('skip_first' , 0 )
532
- refresh_every = kwargs .get ('refresh_every' , 100 )
533
537
534
538
sampling = _iter_sample (draws , step , start , trace , chain ,
535
539
tune , model , random_seed )
@@ -1027,8 +1031,14 @@ def stop_tuning(step):
1027
1031
return step
1028
1032
1029
1033
1030
- def sample_posterior_predictive (trace , samples = None , model = None , vars = None , size = None ,
1031
- random_seed = None , progressbar = True ):
1034
+ def sample_posterior_predictive (trace ,
1035
+ samples : Optional [int ]= None ,
1036
+ model : Optional [Model ]= None ,
1037
+ vars : Optional [TIterable [Tensor ]]= None ,
1038
+ var_names : Optional [List [str ]]= None ,
1039
+ size : Optional [int ]= None ,
1040
+ random_seed = None ,
1041
+ progressbar : bool = True ) -> Dict [str , np .ndarray ]:
1032
1042
"""Generate posterior predictive samples from a model given a trace.
1033
1043
1034
1044
Parameters
@@ -1042,7 +1052,10 @@ def sample_posterior_predictive(trace, samples=None, model=None, vars=None, size
1042
1052
Model used to generate `trace`
1043
1053
vars : iterable
1044
1054
Variables for which to compute the posterior predictive samples.
1045
- Defaults to `model.observed_RVs`.
1055
+ Defaults to `model.observed_RVs`. Deprecated: please use `var_names` instead.
1056
+ var_names : Iterable[str]
1057
+ Alternative way to specify vars to sample, to make this function orthogonal with
1058
+ others.
1046
1059
size : int
1047
1060
The number of random draws from the distribution specified by the parameters in each
1048
1061
sample of the trace.
@@ -1056,7 +1069,7 @@ def sample_posterior_predictive(trace, samples=None, model=None, vars=None, size
1056
1069
Returns
1057
1070
-------
1058
1071
samples : dict
1059
- Dictionary with the variables as keys. The values corresponding to the
1072
+ Dictionary with the variable names as keys, and values numpy arrays containing
1060
1073
posterior predictive samples.
1061
1074
"""
1062
1075
len_trace = len (trace )
@@ -1070,6 +1083,14 @@ def sample_posterior_predictive(trace, samples=None, model=None, vars=None, size
1070
1083
1071
1084
model = modelcontext (model )
1072
1085
1086
+ if var_names is not None :
1087
+ if vars is not None :
1088
+ raise ValueError ("Should not specify both vars and var_names arguments." )
1089
+ else :
1090
+ vars = [model [x ] for x in var_names ]
1091
+ elif vars is not None : # var_names is None, and vars is not.
1092
+ warnings .warn ("vars argument is deprecated in favor of var_names." ,
1093
+ DeprecationWarning )
1073
1094
if vars is None :
1074
1095
vars = model .observed_RVs
1075
1096
@@ -1081,7 +1102,7 @@ def sample_posterior_predictive(trace, samples=None, model=None, vars=None, size
1081
1102
if progressbar :
1082
1103
indices = tqdm (indices , total = samples )
1083
1104
1084
- ppc_trace = defaultdict (list )
1105
+ ppc_trace = defaultdict (list ) # type: Dict[str, List[Any]]
1085
1106
try :
1086
1107
for idx in indices :
1087
1108
if nchain > 1 :
@@ -1250,18 +1271,28 @@ def sample_ppc_w(*args, **kwargs):
1250
1271
return sample_posterior_predictive_w (* args , ** kwargs )
1251
1272
1252
1273
1253
- def sample_prior_predictive (samples = 500 , model = None , vars = None , random_seed = None ):
1274
+ def sample_prior_predictive (samples = 500 ,
1275
+ model : Optional [Model ]= None ,
1276
+ vars : Optional [TIterable [str ]] = None ,
1277
+ var_names : Optional [TIterable [str ]] = None ,
1278
+ random_seed = None ) -> Dict [str , np .ndarray ]:
1254
1279
"""Generate samples from the prior predictive distribution.
1255
1280
1256
1281
Parameters
1257
1282
----------
1258
1283
samples : int
1259
1284
Number of samples from the prior predictive to generate. Defaults to 500.
1260
1285
model : Model (optional if in `with` context)
1261
- vars : iterable
1286
+ vars : Iterable[str]
1287
+ A list of names of variables for which to compute the prior predictive
1288
+ samples.
1289
+ Defaults to `model.named_vars`.
1290
+ DEPRECATED - Use `var_names` instead.
1291
+ var_names : Iterable[str]
1262
1292
A list of names of variables for which to compute the prior predictive
1263
1293
samples.
1264
1294
Defaults to `model.named_vars`.
1295
+
1265
1296
random_seed : int
1266
1297
Seed for the random number generator.
1267
1298
@@ -1273,8 +1304,16 @@ def sample_prior_predictive(samples=500, model=None, vars=None, random_seed=None
1273
1304
"""
1274
1305
model = modelcontext (model )
1275
1306
1276
- if vars is None :
1307
+ if vars is None and var_names is None :
1277
1308
vars = set (model .named_vars .keys ())
1309
+ elif vars is None :
1310
+ vars = var_names
1311
+ elif vars is not None :
1312
+ warnings .warn ("vars argument is deprecated in favor of var_names." ,
1313
+ DeprecationWarning )
1314
+ else :
1315
+ raise ValueError ("Cannot supply both vars and var_names arguments." )
1316
+ vars = cast (TIterable [str ], vars ) # tell mypy that vars cannot be None here.
1278
1317
1279
1318
if random_seed is not None :
1280
1319
np .random .seed (random_seed )
@@ -1283,8 +1322,10 @@ def sample_prior_predictive(samples=500, model=None, vars=None, random_seed=None
1283
1322
values = draw_values ([model [name ] for name in names ], size = samples )
1284
1323
1285
1324
data = {k : v for k , v in zip (names , values )}
1325
+ if data is None :
1326
+ raise AssertionError ("No variables sampled: attempting to sample %s" % names )
1286
1327
1287
- prior = {}
1328
+ prior = {} # type: Dict[str, np.ndarray]
1288
1329
for var_name in vars :
1289
1330
if var_name in data :
1290
1331
prior [var_name ] = data [var_name ]
0 commit comments