Skip to content

Drop nuts init method from pm.sample #3863

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 31, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 1 addition & 11 deletions pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,6 @@ def sample(
* advi: Run ADVI to estimate posterior mean and diagonal mass matrix.
* advi_map: Initialize ADVI with MAP and use MAP as starting point.
* map: Use the MAP as starting point. This is discouraged.
* nuts: Run NUTS and estimate posterior mean and mass matrix from the trace.
* adapt_full: Adapt a dense mass matrix using the sample covariances
step: function or iterable of functions
A step function or collection of functions. If there are variables without step methods,
Expand Down Expand Up @@ -1865,14 +1864,12 @@ def init_nuts(
* advi: Run ADVI to estimate posterior mean and diagonal mass matrix.
* advi_map: Initialize ADVI with MAP and use MAP as starting point.
* map: Use the MAP as starting point. This is discouraged.
* nuts: Run NUTS and estimate posterior mean and mass matrix from
the trace.
* adapt_full: Adapt a dense mass matrix using the sample covariances
chains: int
Number of jobs to start.
n_init: int
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think n_init is only used for advi now, right? COuld make that clear in the doc string.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Number of iterations of initializer
If 'ADVI', number of iterations, if 'nuts', number of draws.
If 'ADVI', number of iterations.
model: Model (optional if in ``with`` context)
progressbar: bool
Whether or not to display a progressbar for advi sampling.
Expand Down Expand Up @@ -2001,13 +1998,6 @@ def init_nuts(
cov = pm.find_hessian(point=start)
start = [start] * chains
potential = quadpotential.QuadPotentialFull(cov)
elif init == "nuts":
init_trace = pm.sample(
draws=n_init, step=pm.NUTS(), tune=n_init // 2, random_seed=random_seed
)
cov = np.atleast_1d(pm.trace_cov(init_trace))
start = list(np.random.choice(init_trace, chains))
potential = quadpotential.QuadPotentialFull(cov)
elif init == "adapt_full":
start = [model.test_point] * chains
mean = np.mean([model.dict_to_array(vals) for vals in start], axis=0)
Expand Down
3 changes: 1 addition & 2 deletions pymc3/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_sample(self):

def test_sample_init(self):
with self.model:
for init in ("advi", "advi_map", "map", "nuts"):
for init in ("advi", "advi_map", "map"):
pm.sample(
init=init,
tune=0,
Expand Down Expand Up @@ -675,7 +675,6 @@ def test_sample_posterior_predictive_w(self):
"advi+adapt_diag_grad",
"map",
"advi_map",
"nuts",
],
)
def test_exec_nuts_init(method):
Expand Down