|
26 | 26 | from aesara.link.jax.dispatch import jax_funcify
|
27 | 27 |
|
28 | 28 | from pymc import Model, modelcontext
|
29 |
| -from pymc.aesaraf import compile_rv_inplace |
| 29 | +from pymc.aesaraf import compile_rv_inplace, inputvars |
| 30 | +from pymc.util import get_default_varnames |
30 | 31 |
|
31 | 32 | warnings.warn("This module is experimental.")
|
32 | 33 |
|
@@ -101,13 +102,19 @@ def sample_numpyro_nuts(
|
101 | 102 | target_accept=0.8,
|
102 | 103 | random_seed=10,
|
103 | 104 | model=None,
|
| 105 | + var_names=None, |
104 | 106 | progress_bar=True,
|
105 | 107 | keep_untransformed=False,
|
106 | 108 | ):
|
107 | 109 | from numpyro.infer import MCMC, NUTS
|
108 | 110 |
|
109 | 111 | model = modelcontext(model)
|
110 | 112 |
|
| 113 | + if var_names is None: |
| 114 | + var_names = model.unobserved_value_vars |
| 115 | + |
| 116 | + vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed)) |
| 117 | + |
111 | 118 | tic1 = pd.Timestamp.now()
|
112 | 119 | print("Compiling...", file=sys.stdout)
|
113 | 120 |
|
@@ -143,45 +150,28 @@ def sample_numpyro_nuts(
|
143 | 150 | seed = jax.random.PRNGKey(random_seed)
|
144 | 151 | map_seed = jax.random.split(seed, chains)
|
145 | 152 |
|
146 |
| - pmap_numpyro.run(map_seed, init_params=init_state_batched, extra_fields=("num_steps",)) |
| 153 | + if chains == 1: |
| 154 | + pmap_numpyro.run(seed, init_params=init_state, extra_fields=("num_steps",)) |
| 155 | + else: |
| 156 | + pmap_numpyro.run(map_seed, init_params=init_state_batched, extra_fields=("num_steps",)) |
| 157 | + |
147 | 158 | raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)
|
148 | 159 |
|
149 | 160 | tic3 = pd.Timestamp.now()
|
150 | 161 | print("Sampling time = ", tic3 - tic2, file=sys.stdout)
|
151 | 162 |
|
152 | 163 | print("Transforming variables...", file=sys.stdout)
|
153 |
| - mcmc_samples = [] |
154 |
| - for i, (value_var, raw_samples) in enumerate(zip(model.value_vars, raw_mcmc_samples)): |
155 |
| - raw_samples = at.constant(np.asarray(raw_samples)) |
156 |
| - |
157 |
| - rv = model.values_to_rvs[value_var] |
158 |
| - transform = getattr(value_var.tag, "transform", None) |
159 |
| - |
160 |
| - if transform is not None: |
161 |
| - # TODO: This will fail when the transformation depends on another variable |
162 |
| - # such as in interval transform with RVs as edges |
163 |
| - trans_samples = transform.backward(raw_samples, *rv.owner.inputs) |
164 |
| - trans_samples.name = rv.name |
165 |
| - mcmc_samples.append(trans_samples) |
166 |
| - |
167 |
| - if keep_untransformed: |
168 |
| - raw_samples.name = value_var.name |
169 |
| - mcmc_samples.append(raw_samples) |
170 |
| - else: |
171 |
| - raw_samples.name = rv.name |
172 |
| - mcmc_samples.append(raw_samples) |
173 |
| - |
174 |
| - mcmc_varnames = [var.name for var in mcmc_samples] |
175 |
| - mcmc_samples = compile_rv_inplace( |
176 |
| - [], |
177 |
| - mcmc_samples, |
178 |
| - mode="JAX", |
179 |
| - )() |
| 164 | + mcmc_samples = {} |
| 165 | + for v in vars_to_sample: |
| 166 | + fgraph = FunctionGraph(model.value_vars, [v], clone=False) |
| 167 | + jax_fn = jax_funcify(fgraph) |
| 168 | + result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)[0] |
| 169 | + mcmc_samples[v.name] = result |
180 | 170 |
|
181 | 171 | tic4 = pd.Timestamp.now()
|
182 | 172 | print("Transformation time = ", tic4 - tic3, file=sys.stdout)
|
183 | 173 |
|
184 |
| - posterior = {k: v for k, v in zip(mcmc_varnames, mcmc_samples)} |
| 174 | + posterior = mcmc_samples |
185 | 175 | az_trace = az.from_dict(posterior=posterior)
|
186 | 176 |
|
187 | 177 | return az_trace
|
0 commit comments