Skip to content

Commit 5535721

Browse files
committed
Transform samples from sample_numpyro_nuts
* Add `pymc3.sampling_jax._transform_samples` function which transforms draws * Modify `pymc3.sampling_jax.sample_numpyro_nuts` function to use this function to return transformed samples * Add release note
1 parent 823906a commit 5535721

File tree

2 files changed

+46
-1
lines changed

2 files changed

+46
-1
lines changed

Diff for: RELEASE-NOTES.md

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
### Breaking Changes
66

77
### New Features
8+
- `pymc3.sampling_jax.sample_numpyro_nuts` now returns samples from transformed random variables, rather than from the unconstrained representation (see [#4427](https://github.com/pymc-devs/pymc3/pull/4427)).
89

910
### Maintenance
1011
- `math.log1mexp_numpy` no longer raises RuntimeWarning when given very small inputs. These were commonly observed during NUTS sampling (see [#4428](https://github.com/pymc-devs/pymc3/pull/4428)).

Diff for: pymc3/sampling_jax.py

+45-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import re
44
import warnings
55

6+
from collections import defaultdict
7+
68
xla_flags = os.getenv("XLA_FLAGS", "").lstrip("--")
79
xla_flags = re.sub(r"xla_force_host_platform_device_count=.+\s", "", xla_flags).split()
810
os.environ["XLA_FLAGS"] = " ".join(["--xla_force_host_platform_device_count={}".format(100)])
@@ -175,8 +177,50 @@ def _sample(current_state, seed):
175177
# print("Sampling time = ", tic4 - tic3)
176178

177179
posterior = {k: v for k, v in zip(rv_names, mcmc_samples)}
180+
tic3 = pd.Timestamp.now()
181+
posterior = _transform_samples(posterior, model, keep_untransformed=False)
182+
tic4 = pd.Timestamp.now()
178183

179184
az_trace = az.from_dict(posterior=posterior)
180-
tic3 = pd.Timestamp.now()
181185
print("Compilation + sampling time = ", tic3 - tic2)
186+
print("Transformation time = ", tic4 - tic3)
187+
182188
return az_trace # , leapfrogs_taken, tic3 - tic2
189+
190+
191+
def _transform_samples(samples, model, keep_untransformed=False):
192+
193+
# Find out which RVs we need to compute:
194+
free_rv_names = {x.name for x in model.free_RVs}
195+
unobserved_names = {x.name for x in model.unobserved_RVs}
196+
197+
names_to_compute = unobserved_names - free_rv_names
198+
ops_to_compute = [x for x in model.unobserved_RVs if x.name in names_to_compute]
199+
200+
# Create function graph for these:
201+
fgraph = theano.graph.fg.FunctionGraph(model.free_RVs, ops_to_compute)
202+
203+
# Jaxify, which returns a list of functions, one for each op
204+
jax_fns = jax_funcify(fgraph)
205+
206+
# Put together the inputs
207+
inputs = [samples[x.name] for x in model.free_RVs]
208+
209+
for cur_op, cur_jax_fn in zip(ops_to_compute, jax_fns):
210+
211+
# We need a function taking a single argument to run vmap, while the
212+
# jax_fn takes a list, so:
213+
to_run = lambda x: cur_jax_fn(*x)
214+
215+
result = jax.vmap(jax.vmap(to_run))(inputs)
216+
217+
# Add to sample dict
218+
samples[cur_op.name] = result
219+
220+
# Discard unwanted transformed variables, if desired:
221+
vars_to_keep = set(
222+
pm.util.get_default_varnames(list(samples.keys()), include_transformed=keep_untransformed)
223+
)
224+
samples = {x: y for x, y in samples.items() if x in vars_to_keep}
225+
226+
return samples

0 commit comments

Comments
 (0)