Skip to content

Commit 66f4ed6

Browse files
committed
🔥 remove deprecated vars from sample_prior_predictive
1 parent 60cf2cd commit 66f4ed6

File tree

3 files changed

+10
-31
lines changed

3 files changed

+10
-31
lines changed

Diff for: pymc3/sampling.py

+6-27
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
from collections import defaultdict
2525
from copy import copy
26-
from typing import Any, Dict, Iterable, List, Optional, Union, cast
26+
from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast
2727

2828
import arviz
2929
import numpy as np
@@ -56,7 +56,7 @@
5656
Metropolis,
5757
Slice,
5858
)
59-
from pymc3.step_methods.arraystep import PopulationArrayStepShared
59+
from pymc3.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
6060
from pymc3.step_methods.hmc import quadpotential
6161
from pymc3.util import (
6262
chains_and_samples,
@@ -91,18 +91,7 @@
9191
CategoricalGibbsMetropolis,
9292
PGBART,
9393
)
94-
Step = Union[
95-
NUTS,
96-
HamiltonianMC,
97-
Metropolis,
98-
BinaryMetropolis,
99-
BinaryGibbsMetropolis,
100-
Slice,
101-
CategoricalGibbsMetropolis,
102-
PGBART,
103-
CompoundStep,
104-
]
105-
94+
Step = Union[BlockedStep, CompoundStep]
10695

10796
ArrayLike = Union[np.ndarray, List[float]]
10897
PointType = Dict[str, np.ndarray]
@@ -1898,7 +1887,6 @@ def sample_posterior_predictive_w(
18981887
def sample_prior_predictive(
18991888
samples=500,
19001889
model: Optional[Model] = None,
1901-
vars: Optional[Iterable[str]] = None,
19021890
var_names: Optional[Iterable[str]] = None,
19031891
random_seed=None,
19041892
) -> Dict[str, np.ndarray]:
@@ -1909,9 +1897,6 @@ def sample_prior_predictive(
19091897
samples : int
19101898
Number of samples from the prior predictive to generate. Defaults to 500.
19111899
model : Model (optional if in ``with`` context)
1912-
vars : Iterable[str]
1913-
A list of names of variables for which to compute the posterior predictive
1914-
samples. *DEPRECATED* - Use ``var_names`` argument instead.
19151900
var_names : Iterable[str]
19161901
A list of names of variables for which to compute the posterior predictive
19171902
samples. Defaults to both observed and unobserved RVs.
@@ -1926,20 +1911,14 @@ def sample_prior_predictive(
19261911
"""
19271912
model = modelcontext(model)
19281913

1929-
if vars is None and var_names is None:
1914+
if var_names is None:
19301915
prior_pred_vars = model.observed_RVs
19311916
prior_vars = (
19321917
get_default_varnames(model.unobserved_RVs, include_transformed=True) + model.potentials
19331918
)
1934-
vars_: Iterable[str] = [var.name for var in prior_vars + prior_pred_vars]
1935-
elif vars is None:
1936-
assert var_names is not None # help mypy
1937-
vars_ = var_names
1938-
elif var_names is None:
1939-
warnings.warn("vars argument is deprecated in favor of var_names.", DeprecationWarning)
1940-
vars_ = vars
1919+
vars_: Set[str] = {var.name for var in prior_vars + prior_pred_vars}
19411920
else:
1942-
raise ValueError("Cannot supply both vars and var_names arguments.")
1921+
vars_ = set(var_names)
19431922

19441923
if random_seed is not None:
19451924
np.random.seed(random_seed)

Diff for: pymc3/step_methods/arraystep.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from numpy.random import uniform
2121

2222
from pymc3.blocking import ArrayOrdering, DictToArrayBijection
23-
from pymc3.model import modelcontext
23+
from pymc3.model import PyMC3Variable, modelcontext
2424
from pymc3.step_methods.compound import CompoundStep
2525
from pymc3.theanof import inputvars
2626
from pymc3.util import get_var_name
@@ -48,6 +48,7 @@ class BlockedStep:
4848

4949
generates_stats = False
5050
stats_dtypes: List[Dict[str, np.dtype]] = []
51+
vars: List[PyMC3Variable] = []
5152

5253
def __new__(cls, *args, **kwargs):
5354
blocked = kwargs.get("blocked")

Diff for: pymc3/tests/test_sampling.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -903,9 +903,8 @@ def test_respects_shape(self):
903903
with pm.Model():
904904
mu = pm.Gamma("mu", 3, 1, shape=1)
905905
goals = pm.Poisson("goals", mu, shape=shape)
906-
with pytest.warns(DeprecationWarning):
907-
trace1 = pm.sample_prior_predictive(10, vars=["mu", "goals"])
908-
trace2 = pm.sample_prior_predictive(10, var_names=["mu", "goals"])
906+
trace1 = pm.sample_prior_predictive(10, var_names=["mu", "mu", "goals"])
907+
trace2 = pm.sample_prior_predictive(10, var_names=["mu", "goals"])
909908
if shape == 2: # want to test shape as an int
910909
shape = (2,)
911910
assert trace1["goals"].shape == (10,) + shape

0 commit comments

Comments
 (0)