Include observed_data and sample_stats in inferencedata returned from sample_numpyro_nuts #5121
Dear PyMC developers,
I noticed that the trace returned by ordinary sampling differs from the one returned by JAX sampling.
With return_inferencedata=True, ordinary sampling gives me a trace with the following groups:
posterior
log_likelihood
sample_stats
observed_data
With JAX sampling (sample_numpyro_nuts), however, I only get posterior. Without an observed_data group, I cannot call plot_ppc.
How can I resolve this? Is there a setting I can use to get these groups?
Thanks.
Below is my reproducible code:
```python
import warnings
import datetime as dt

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc3 as pm
import pymc3.sampling_jax
import numpyro
import theano
#from pandas_datareader import data

numpyro.util.set_platform('cpu')
print(f"Running on PyMC3 v{pm.__version__}")

%matplotlib inline

# Daily log returns of the S&P 500 closing price
returns = pd.read_csv(pm.get_data("SP500.csv"), index_col="Date")
returns["change"] = np.log(returns["Close"]).diff()
returns = returns.dropna()
n = len(returns)
returns.head()
```
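The `change` column holds log returns, r_t = ln P_t − ln P_{t−1}. The first row is therefore NaN, and after `dropna()` there are n = len(prices) − 1 observations. Here is a quick numpy check of the same transformation, using made-up prices (not the SP500 data):

```python
import numpy as np

# Illustrative closing prices only, not taken from SP500.csv
close = np.array([100.0, 101.0, 99.5, 102.0])

# np.diff of the log prices matches pandas' np.log(series).diff() minus its leading NaN
log_returns = np.diff(np.log(close))

# Equivalent form: the log of the gross return P_t / P_{t-1}
assert np.allclose(log_returns, np.log(close[1:] / close[:-1]))
print(len(log_returns))  # one fewer than the number of prices
```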
Wiggins
```python
def WigginsDGP(volatility_mu, volatility_theta, volatility_sigma):
    # Drift and diffusion of a mean-reverting (Ornstein-Uhlenbeck) process
    def sde(x, theta, mu, sigma):
        return theta * (mu - x), sigma
```
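The `sde` helper returns a drift θ(μ − x) and a diffusion σ. This assumes it is meant to be passed to `ts.EulerMaruyama`, which (as we understand it) discretizes the SDE as x_{t+1} ~ Normal(x_t + dt·drift, √dt·diffusion). The numpy sketch below simulates that recursion forward. The helper `simulate_euler_maruyama` and the parameter values are hypothetical and only for illustration:

```python
import numpy as np


def sde(x, theta, mu, sigma):
    # Same drift/diffusion pair as above: mean reversion toward mu at rate theta
    return theta * (mu - x), sigma


def simulate_euler_maruyama(x0, dt, n_steps, theta, mu, sigma, seed=0):
    """Hypothetical helper: Euler-Maruyama path of dx = theta*(mu - x) dt + sigma dW."""
    rng = np.random.default_rng(seed)
    path = np.empty(n_steps + 1)
    path[0] = x0
    for t in range(n_steps):
        drift, diffusion = sde(path[t], theta, mu, sigma)
        # x_{t+1} = x_t + drift*dt + diffusion*sqrt(dt)*N(0, 1)
        path[t + 1] = path[t] + drift * dt + diffusion * np.sqrt(dt) * rng.standard_normal()
    return path


# With sigma = 0 the path decays deterministically toward mu
path = simulate_euler_maruyama(x0=0.0, dt=0.1, n_steps=200, theta=2.0, mu=1.0, sigma=0.0)
print(path[-1])
```

Setting σ to zero makes the mean reversion easy to verify by eye. A positive σ adds Gaussian noise with standard deviation σ√dt at each step.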
```python
%%time
import pymc3.distributions.timeseries as ts

with pm.Model() as wiggins_model:
    ...  # model body not included in the report
```
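As a possible workaround while the JAX sampler returns only `posterior`, an observed_data group can be attached to the result with ArviZ. This is a minimal sketch that assumes the JAX result is an `az.InferenceData` containing only a `posterior` group. The names `mu` and `y` and the random arrays are hypothetical stand-ins for the real posterior draws and the observed `returns["change"]` values. Note that `az.plot_ppc` also reads a `posterior_predictive` group, which would still need to be generated separately.

```python
import arviz as az
import numpy as np

rng = np.random.default_rng(42)

# Hypothetical stand-in for the posterior-only JAX result:
# 4 chains x 500 draws of a scalar parameter "mu"
idata_jax = az.from_dict(posterior={"mu": rng.normal(size=(4, 500))})

# Hypothetical observed data, e.g. the array passed as `observed=` in the model
y_obs = rng.normal(size=100)
idata_obs = az.from_dict(observed_data={"y": y_obs})

# With dim=None, az.concat merges InferenceData objects whose groups do not overlap
idata = az.concat(idata_jax, idata_obs)
print(idata.groups())
```

`az.concat` returns a new object here. `InferenceData.extend` is an in-place alternative.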