Adding log_likelihood, observed_data, and sample_stats to numpyro sampler #5189
Changes from 12 commits: 549f714, ff3a0cd, 6987bac, af108d4, f295288, 8a28fad, 59ebfdb, f5aeaf6, bf2ad0d, a3c6cc2, 0098876, 5f5cd87, 6e4fcab
@@ -26,7 +26,9 @@
 from aesara.link.jax.dispatch import jax_funcify

 from pymc import Model, modelcontext
-from pymc.aesaraf import compile_rv_inplace, inputvars
+from pymc.aesaraf import compile_rv_inplace
+from pymc.backends.arviz import find_observations
 from pymc.distributions import logpt
 from pymc.util import get_default_varnames

 warnings.warn("This module is experimental.")
@@ -95,6 +97,39 @@ def logp_fn_wrap(x):
     return logp_fn_wrap

+
+# Adopted from arviz numpyro extractor
+def _sample_stats_to_xarray(posterior):
+    """Extract sample_stats from NumPyro posterior."""
+    rename_key = {
+        "potential_energy": "lp",
+        "adapt_state.step_size": "step_size",
+        "num_steps": "n_steps",
+        "accept_prob": "acceptance_rate",
+    }
+    data = {}
+    for stat, value in posterior.get_extra_fields(group_by_chain=True).items():
+        if isinstance(value, (dict, tuple)):
+            continue
+        name = rename_key.get(stat, stat)
+        value = value.copy()
+        data[name] = value
+        if stat == "num_steps":
+            data["tree_depth"] = np.log2(value).astype(int) + 1
+    return data
+
+
+def _get_log_likelihood(model, samples):
+    "Compute log-likelihood for all observations"
+    data = {}
+    for v in model.observed_RVs:
+        logp_v = replace_shared_variables([logpt(v)])
+        fgraph = FunctionGraph(model.value_vars, logp_v, clone=False)
+        jax_fn = jax_funcify(fgraph)
+        result = jax.vmap(jax.vmap(jax_fn))(*samples)[0]
+        data[v.name] = result
+    return data

Review thread on the nested jax.vmap call (this conversation was marked as resolved by ricardoV94):

- Out of curiosity, would we expect any benefits to jit-compiling this outer …
- Would it be possible to use a similar approach with Aesara directly? Here we only loop over observed variables in order to get the pointwise log-likelihood. We had some discussion about this in #4489 but ended up keeping the 3 nested loops over variables, chains and draws.
- It should be possible, but it requires an Aesara Scan, and at least for small models this was not faster than Python looping when I checked it. Here is a notebook that documents some things I tried: https://gist.github.com/ricardoV94/6089a8c46a0e19665f01c79ea04e1cb2 It might be faster if using shared variables…
- No idea. I think the easiest thing to do is just benchmark it. I don't even call … When I run the model in the unit test with the change, I don't really get a speed-up until there are millions of samples.
- We should definitely call …
- Then it's probably not worth it. I was under the impression it would be possible to vectorize/broadcast the operation, from the conversations in #4489 and in Slack.
- It must be possible, since the vmap above works just fine. I just have no idea how they do it, or how/if you could do it in Aesara. I also wonder whether the vmap works for more complicated models with multivariate distributions and the like.
- Alright. I'm going to make a separate PR for some of this other stuff.
- Cool, feel free to tag me if you want me to review; I am not watching PRs. I can already say I won't be able to help with the vectorized log_likelihood thing; I tried and lost much more time with that than would have been healthy. I should be able to help with coords and dims, though.
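Whether the nested vmap can be replaced by plain broadcasting is exactly what the thread above debates. As an illustration in plain NumPy (not PyMC, Aesara, or JAX code; the Normal likelihood and every array name here are invented for the sketch), a double loop over chains and draws and a single broadcast produce the same (chains, draws, observations) array of pointwise log-likelihoods:

```python
import numpy as np

def normal_logpdf(y, mu, sigma=1.0):
    # Elementwise log-density of y under Normal(mu, sigma).
    return -0.5 * np.log(2 * np.pi * sigma**2) - 0.5 * ((y - mu) / sigma) ** 2

chains, draws = 2, 5
rng = np.random.default_rng(0)
observed = np.array([0.1, -0.3, 0.5])           # 3 observed data points
mu_samples = rng.normal(size=(chains, draws))   # fake posterior draws of mu

# The "3 nested loops" approach discussed above, minus the variable loop.
loop_result = np.empty((chains, draws, observed.size))
for c in range(chains):
    for d in range(draws):
        loop_result[c, d] = normal_logpdf(observed, mu_samples[c, d])

# The vectorized analogue of jax.vmap(jax.vmap(...)): add a trailing
# axis to the draws so NumPy broadcasts against the observations.
vec_result = normal_logpdf(observed, mu_samples[..., None])

assert vec_result.shape == (chains, draws, observed.size)
assert np.allclose(loop_result, vec_result)
```

This is only an analogy: vmap batches arbitrary functions, while broadcasting works here because the Normal log-density is already elementwise.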
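To make the sample-stats mapping concrete, here is a sketch against a mock posterior object (MockPosterior and its field values are invented; only get_extra_fields, the one method the helper relies on, is mocked):

```python
import numpy as np

class MockPosterior:
    """Hypothetical stand-in for a fitted NumPyro MCMC object."""

    def get_extra_fields(self, group_by_chain=True):
        # (chains, draws)-shaped arrays, as with group_by_chain=True.
        return {
            "num_steps": np.array([[1, 3, 7], [15, 3, 1]]),
            "potential_energy": np.array([[10.0, 9.5, 9.1], [11.0, 9.9, 9.0]]),
            "adapt_state": {"step_size": None},  # dict-valued fields are skipped
        }

rename_key = {
    "potential_energy": "lp",
    "adapt_state.step_size": "step_size",
    "num_steps": "n_steps",
    "accept_prob": "acceptance_rate",
}

data = {}
for stat, value in MockPosterior().get_extra_fields(group_by_chain=True).items():
    if isinstance(value, (dict, tuple)):
        continue
    data[rename_key.get(stat, stat)] = value.copy()
    if stat == "num_steps":
        # A full binary tree of depth d takes 2**d - 1 leapfrog steps.
        data["tree_depth"] = np.log2(value).astype(int) + 1

print(sorted(data))          # ['lp', 'n_steps', 'tree_depth']
print(data["tree_depth"][0]) # [1 2 3]
```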
def sample_numpyro_nuts(
    draws=1000,
    tune=1000,

@@ -151,9 +186,23 @@ def sample_numpyro_nuts(
     map_seed = jax.random.split(seed, chains)

     if chains == 1:
-        pmap_numpyro.run(seed, init_params=init_state, extra_fields=("num_steps",))
+        init_params = init_state
+        map_seed = seed
     else:
-        pmap_numpyro.run(map_seed, init_params=init_state_batched, extra_fields=("num_steps",))
+        init_params = init_state_batched
+
+    pmap_numpyro.run(
+        map_seed,
+        init_params=init_params,
+        extra_fields=(
+            "num_steps",
+            "potential_energy",
+            "energy",
+            "adapt_state.step_size",
+            "accept_prob",
+            "diverging",
+        ),
+    )

     raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)
@@ -172,6 +221,11 @@ def sample_numpyro_nuts(
     print("Transformation time = ", tic4 - tic3, file=sys.stdout)

     posterior = mcmc_samples
-    az_trace = az.from_dict(posterior=posterior)
+    az_posterior = az.from_dict(posterior=posterior)
+
+    az_obs = az.from_dict(observed_data=find_observations(model))
+    az_stats = az.from_dict(sample_stats=_sample_stats_to_xarray(pmap_numpyro))
+    az_ll = az.from_dict(log_likelihood=_get_log_likelihood(model, raw_mcmc_samples))
+    az_trace = az.concat(az_posterior, az_ll, az_obs, az_stats)

     return az_trace
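The ArviZ assembly in the return path can be reproduced standalone. In this sketch all shapes and variable names are toy stand-ins (sampled groups use (chains, draws, ...) arrays): az.from_dict builds one InferenceData per group, and az.concat merges objects with distinct groups into one.

```python
import arviz as az
import numpy as np

# Toy stand-ins for mcmc_samples, sample stats, and log-likelihoods.
posterior = {"mu": np.zeros((2, 10))}   # (chains, draws)
stats = {"lp": np.ones((2, 10))}
log_lik = {"y": np.zeros((2, 10, 3))}   # (chains, draws, observations)
observed = {"y": np.arange(3.0)}

az_posterior = az.from_dict(posterior=posterior)
az_stats = az.from_dict(sample_stats=stats)
az_ll = az.from_dict(log_likelihood=log_lik)
az_obs = az.from_dict(observed_data=observed)

# Concatenating InferenceData objects with distinct groups merges them
# group-wise, mirroring the az.concat call in the diff above.
idata = az.concat(az_posterior, az_ll, az_obs, az_stats)
print(sorted(idata.groups()))
```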