23
23
24
24
from collections import defaultdict
25
25
from copy import copy
26
- from typing import Any , Dict , Iterable , List , Optional , Union , cast
26
+ from typing import Any , Dict , Iterable , List , Optional , Set , Union , cast
27
27
28
28
import arviz
29
29
import numpy as np
56
56
Metropolis ,
57
57
Slice ,
58
58
)
59
- from pymc3 .step_methods .arraystep import PopulationArrayStepShared
59
+ from pymc3 .step_methods .arraystep import BlockedStep , PopulationArrayStepShared
60
60
from pymc3 .step_methods .hmc import quadpotential
61
61
from pymc3 .util import (
62
62
chains_and_samples ,
91
91
CategoricalGibbsMetropolis ,
92
92
PGBART ,
93
93
)
94
- Step = Union [
95
- NUTS ,
96
- HamiltonianMC ,
97
- Metropolis ,
98
- BinaryMetropolis ,
99
- BinaryGibbsMetropolis ,
100
- Slice ,
101
- CategoricalGibbsMetropolis ,
102
- PGBART ,
103
- CompoundStep ,
104
- ]
105
-
94
+ Step = Union [BlockedStep , CompoundStep ]
106
95
107
96
ArrayLike = Union [np .ndarray , List [float ]]
108
97
PointType = Dict [str , np .ndarray ]
@@ -1898,7 +1887,6 @@ def sample_posterior_predictive_w(
1898
1887
def sample_prior_predictive (
1899
1888
samples = 500 ,
1900
1889
model : Optional [Model ] = None ,
1901
- vars : Optional [Iterable [str ]] = None ,
1902
1890
var_names : Optional [Iterable [str ]] = None ,
1903
1891
random_seed = None ,
1904
1892
) -> Dict [str , np .ndarray ]:
@@ -1909,9 +1897,6 @@ def sample_prior_predictive(
1909
1897
samples : int
1910
1898
Number of samples from the prior predictive to generate. Defaults to 500.
1911
1899
model : Model (optional if in ``with`` context)
1912
- vars : Iterable[str]
1913
- A list of names of variables for which to compute the posterior predictive
1914
- samples. *DEPRECATED* - Use ``var_names`` argument instead.
1915
1900
var_names : Iterable[str]
1916
1901
A list of names of variables for which to compute the posterior predictive
1917
1902
samples. Defaults to both observed and unobserved RVs.
@@ -1926,20 +1911,14 @@ def sample_prior_predictive(
1926
1911
"""
1927
1912
model = modelcontext (model )
1928
1913
1929
- if vars is None and var_names is None :
1914
+ if var_names is None :
1930
1915
prior_pred_vars = model .observed_RVs
1931
1916
prior_vars = (
1932
1917
get_default_varnames (model .unobserved_RVs , include_transformed = True ) + model .potentials
1933
1918
)
1934
- vars_ : Iterable [str ] = [var .name for var in prior_vars + prior_pred_vars ]
1935
- elif vars is None :
1936
- assert var_names is not None # help mypy
1937
- vars_ = var_names
1938
- elif var_names is None :
1939
- warnings .warn ("vars argument is deprecated in favor of var_names." , DeprecationWarning )
1940
- vars_ = vars
1919
+ vars_ : Set [str ] = {var .name for var in prior_vars + prior_pred_vars }
1941
1920
else :
1942
- raise ValueError ( "Cannot supply both vars and var_names arguments." )
1921
+ vars_ = set ( var_names )
1943
1922
1944
1923
if random_seed is not None :
1945
1924
np .random .seed (random_seed )
0 commit comments