Skip to content

Commit 44cf8a7

Browse files
authored
Merge pull request #5182 from zaxtax/add_numpyro_deterministic
Adding Deterministic for sampling_numpyro
2 parents 4b7aaad + 5ead708 commit 44cf8a7

File tree

2 files changed

+37
-30
lines changed

2 files changed

+37
-30
lines changed

Diff for: pymc/sampling_jax.py

+20-30
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
from aesara.link.jax.dispatch import jax_funcify
2727

2828
from pymc import Model, modelcontext
29-
from pymc.aesaraf import compile_rv_inplace
29+
from pymc.aesaraf import compile_rv_inplace, inputvars
30+
from pymc.util import get_default_varnames
3031

3132
warnings.warn("This module is experimental.")
3233

@@ -101,13 +102,19 @@ def sample_numpyro_nuts(
101102
target_accept=0.8,
102103
random_seed=10,
103104
model=None,
105+
var_names=None,
104106
progress_bar=True,
105107
keep_untransformed=False,
106108
):
107109
from numpyro.infer import MCMC, NUTS
108110

109111
model = modelcontext(model)
110112

113+
if var_names is None:
114+
var_names = model.unobserved_value_vars
115+
116+
vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed))
117+
111118
tic1 = pd.Timestamp.now()
112119
print("Compiling...", file=sys.stdout)
113120

@@ -143,45 +150,28 @@ def sample_numpyro_nuts(
143150
seed = jax.random.PRNGKey(random_seed)
144151
map_seed = jax.random.split(seed, chains)
145152

146-
pmap_numpyro.run(map_seed, init_params=init_state_batched, extra_fields=("num_steps",))
153+
if chains == 1:
154+
pmap_numpyro.run(seed, init_params=init_state, extra_fields=("num_steps",))
155+
else:
156+
pmap_numpyro.run(map_seed, init_params=init_state_batched, extra_fields=("num_steps",))
157+
147158
raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)
148159

149160
tic3 = pd.Timestamp.now()
150161
print("Sampling time = ", tic3 - tic2, file=sys.stdout)
151162

152163
print("Transforming variables...", file=sys.stdout)
153-
mcmc_samples = []
154-
for i, (value_var, raw_samples) in enumerate(zip(model.value_vars, raw_mcmc_samples)):
155-
raw_samples = at.constant(np.asarray(raw_samples))
156-
157-
rv = model.values_to_rvs[value_var]
158-
transform = getattr(value_var.tag, "transform", None)
159-
160-
if transform is not None:
161-
# TODO: This will fail when the transformation depends on another variable
162-
# such as in interval transform with RVs as edges
163-
trans_samples = transform.backward(raw_samples, *rv.owner.inputs)
164-
trans_samples.name = rv.name
165-
mcmc_samples.append(trans_samples)
166-
167-
if keep_untransformed:
168-
raw_samples.name = value_var.name
169-
mcmc_samples.append(raw_samples)
170-
else:
171-
raw_samples.name = rv.name
172-
mcmc_samples.append(raw_samples)
173-
174-
mcmc_varnames = [var.name for var in mcmc_samples]
175-
mcmc_samples = compile_rv_inplace(
176-
[],
177-
mcmc_samples,
178-
mode="JAX",
179-
)()
164+
mcmc_samples = {}
165+
for v in vars_to_sample:
166+
fgraph = FunctionGraph(model.value_vars, [v], clone=False)
167+
jax_fn = jax_funcify(fgraph)
168+
result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)[0]
169+
mcmc_samples[v.name] = result
180170

181171
tic4 = pd.Timestamp.now()
182172
print("Transformation time = ", tic4 - tic3, file=sys.stdout)
183173

184-
posterior = {k: v for k, v in zip(mcmc_varnames, mcmc_samples)}
174+
posterior = mcmc_samples
185175
az_trace = az.from_dict(posterior=posterior)
186176

187177
return az_trace

Diff for: pymc/tests/test_sampling_jax.py

+17
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,23 @@ def test_transform_samples():
4444
assert 1.5 < trace.posterior["sigma"].mean() < 2.5
4545

4646

47+
def test_deterministic_samples():
48+
aesara.config.on_opt_error = "raise"
49+
np.random.seed(13244)
50+
51+
obs = np.random.normal(10, 2, size=100)
52+
obs_at = aesara.shared(obs, borrow=True, name="obs")
53+
with pm.Model() as model:
54+
a = pm.Uniform("a", -20, 20)
55+
b = pm.Deterministic("b", a / 2.0)
56+
c = pm.Normal("c", a, sigma=1.0, observed=obs_at)
57+
58+
trace = sample_numpyro_nuts(chains=2, random_seed=1322, keep_untransformed=True)
59+
60+
assert 8 < trace.posterior["a"].mean() < 11
61+
assert np.allclose(trace.posterior["b"].values, trace.posterior["a"].values / 2)
62+
63+
4764
def test_replace_shared_variables():
4865
x = aesara.shared(5, name="shared_x")
4966

0 commit comments

Comments
 (0)