Skip to content

Commit 78d15f4

Browse files
committed
Added a small test
1 parent c18f0ce commit 78d15f4

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

Diff for: pymc3/sampling_jax.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def sample_numpyro_nuts(
123123
random_seed=10,
124124
model=None,
125125
progress_bar=True,
126+
keep_untransformed=False,
126127
):
127128
from numpyro.infer import MCMC, NUTS
128129

@@ -178,7 +179,7 @@ def _sample(current_state, seed):
178179

179180
posterior = {k: v for k, v in zip(rv_names, mcmc_samples)}
180181
tic3 = pd.Timestamp.now()
181-
posterior = _transform_samples(posterior, model, keep_untransformed=False)
182+
posterior = _transform_samples(posterior, model, keep_untransformed=keep_untransformed)
182183
tic4 = pd.Timestamp.now()
183184

184185
az_trace = az.from_dict(posterior=posterior)

Diff for: pymc3/tests/test_sampling_jax.py

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import numpy as np
2+
3+
import pymc3 as pm
4+
5+
from pymc3.sampling_jax import sample_numpyro_nuts
6+
7+
8+
def test_transform_samples():
9+
10+
with pm.Model() as model:
11+
12+
sigma = pm.HalfNormal("sigma")
13+
b = pm.Normal("b", sigma=sigma)
14+
trace = sample_numpyro_nuts(keep_untransformed=True)
15+
16+
log_vals = trace.posterior["sigma_log__"].values
17+
trans_vals = trace.posterior["sigma"].values
18+
19+
assert np.allclose(np.exp(log_vals), trans_vals)

0 commit comments

Comments
 (0)