|
3 | 3 | import re
|
4 | 4 | import warnings
|
5 | 5 |
|
| 6 | +from collections import defaultdict |
| 7 | + |
6 | 8 | xla_flags = os.getenv("XLA_FLAGS", "").lstrip("--")
|
7 | 9 | xla_flags = re.sub(r"xla_force_host_platform_device_count=.+\s", "", xla_flags).split()
|
8 | 10 | os.environ["XLA_FLAGS"] = " ".join(["--xla_force_host_platform_device_count={}".format(100)])
|
@@ -175,8 +177,50 @@ def _sample(current_state, seed):
|
175 | 177 | # print("Sampling time = ", tic4 - tic3)
|
176 | 178 |
|
177 | 179 | posterior = {k: v for k, v in zip(rv_names, mcmc_samples)}
|
| 180 | + tic3 = pd.Timestamp.now() |
| 181 | + posterior = _transform_samples(posterior, model, keep_untransformed=False) |
| 182 | + tic4 = pd.Timestamp.now() |
178 | 183 |
|
179 | 184 | az_trace = az.from_dict(posterior=posterior)
|
180 |
| - tic3 = pd.Timestamp.now() |
181 | 185 | print("Compilation + sampling time = ", tic3 - tic2)
|
| 186 | + print("Transformation time = ", tic4 - tic3) |
| 187 | + |
182 | 188 | return az_trace # , leapfrogs_taken, tic3 - tic2
|
| 189 | + |
| 190 | + |
| 191 | +def _transform_samples(samples, model, keep_untransformed=False): |
| 192 | + |
| 193 | + # Find out which RVs we need to compute: |
| 194 | + free_rv_names = {x.name for x in model.free_RVs} |
| 195 | + unobserved_names = {x.name for x in model.unobserved_RVs} |
| 196 | + |
| 197 | + names_to_compute = unobserved_names - free_rv_names |
| 198 | + ops_to_compute = [x for x in model.unobserved_RVs if x.name in names_to_compute] |
| 199 | + |
| 200 | + # Create function graph for these: |
| 201 | + fgraph = theano.graph.fg.FunctionGraph(model.free_RVs, ops_to_compute) |
| 202 | + |
| 203 | + # Jaxify, which returns a list of functions, one for each op |
| 204 | + jax_fns = jax_funcify(fgraph) |
| 205 | + |
| 206 | + # Put together the inputs |
| 207 | + inputs = [samples[x.name] for x in model.free_RVs] |
| 208 | + |
| 209 | + for cur_op, cur_jax_fn in zip(ops_to_compute, jax_fns): |
| 210 | + |
| 211 | + # We need a function taking a single argument to run vmap, while the |
| 212 | + # jax_fn takes a list, so: |
| 213 | + to_run = lambda x: cur_jax_fn(*x) |
| 214 | + |
| 215 | + result = jax.vmap(jax.vmap(to_run))(inputs) |
| 216 | + |
| 217 | + # Add to sample dict |
| 218 | + samples[cur_op.name] = result |
| 219 | + |
| 220 | + # Discard unwanted transformed variables, if desired: |
| 221 | + vars_to_keep = set( |
| 222 | + pm.util.get_default_varnames(list(samples.keys()), include_transformed=keep_untransformed) |
| 223 | + ) |
| 224 | + samples = {x: y for x, y in samples.items() if x in vars_to_keep} |
| 225 | + |
| 226 | + return samples |
0 commit comments