Commit fa43eba

nataziel and Goose authored
Use jaxified logp for initial point evaluation when sampling via Jax (#7610)
* use jaxified logp for initial point evaluation when sampling via Jax
* correcting initial point type hinting
* refactor init_jitter inputs

Co-authored-by: Goose <[email protected]>
1 parent 7a995a0 commit fa43eba
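
In short: `sample_jax_nuts` already compiles the model logp to JAX for the NUTS kernel, so the same compiled function can now also score the jittered initial points, instead of compiling a second, PyTensor-based logp just for that check. A minimal sketch of the idea (the toy model and values are illustrative, not part of this commit):

```python
import numpy as np
import pymc as pm

from pymc.sampling.jax import get_jaxified_logp

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("obs", mu, 1.0, observed=np.array([0.1, -0.3, 0.5]))

# Compile the model logp to JAX once...
logp_fn = get_jaxified_logp(model)

# ...then evaluate it at an initial point, a {value_var_name: np.ndarray} dict.
point = model.initial_point(random_seed=0)
print(logp_fn(list(point.values())))
```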

File tree

3 files changed: +147 -82 lines changed


pymc/initial_point.py (+12 -2)
```diff
@@ -26,6 +26,7 @@
 
 from pymc.logprob.transforms import Transform
 from pymc.pytensorf import (
+    SeedSequenceSeed,
     compile,
     find_rng_nodes,
     replace_rng_nodes,
@@ -67,7 +68,7 @@ def make_initial_point_fns_per_chain(
     overrides: StartDict | Sequence[StartDict | None] | None,
     jitter_rvs: set[TensorVariable] | None = None,
     chains: int,
-) -> list[Callable]:
+) -> list[Callable[[SeedSequenceSeed], PointType]]:
     """Create an initial point function for each chain, as defined by initvals.
 
     If a single initval dictionary is passed, the function is replicated for each
@@ -82,6 +83,11 @@ def make_initial_point_fns_per_chain(
         Random variable tensors for which U(-1, 1) jitter shall be applied.
         (To the transformed space if applicable.)
 
+    Returns
+    -------
+    ipfns : list[Callable[[SeedSequenceSeed], dict[str, np.ndarray]]]
+        list of functions that return initial points for each chain.
+
     Raises
     ------
     ValueError
@@ -124,7 +130,7 @@ def make_initial_point_fn(
     jitter_rvs: set[TensorVariable] | None = None,
     default_strategy: str = "support_point",
     return_transformed: bool = True,
-) -> Callable:
+) -> Callable[[SeedSequenceSeed], PointType]:
     """Create seeded function that computes initial values for all free model variables.
 
     Parameters
@@ -138,6 +144,10 @@ def make_initial_point_fn(
         Initial value (strategies) to use instead of what's specified in `Model.initial_values`.
     return_transformed : bool
         If `True` the returned variables will correspond to transformed initial values.
+
+    Returns
+    -------
+    initial_point_fn : Callable[[SeedSequenceSeed], dict[str, np.ndarray]]
     """
     sdict_overrides = convert_str_to_rv_dict(model, overrides or {})
     initval_strats = {
```
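
For reference, the newly documented return type in action — a small usage sketch, assuming a model with one transformed free variable (not taken from the diff):

```python
import pymc as pm

from pymc.initial_point import make_initial_point_fn

with pm.Model() as model:
    sigma = pm.HalfNormal("sigma")

# A seeded callable: SeedSequenceSeed -> dict[str, np.ndarray].
# With return_transformed=True (the default) the key is the
# transformed value variable, e.g. "sigma_log__".
ipfn = make_initial_point_fn(model=model)
print(ipfn(123))
```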

pymc/sampling/jax.py (+120 -69)
```diff
@@ -18,6 +18,7 @@
 from collections.abc import Callable, Sequence
 from datetime import datetime
 from functools import partial
+from types import ModuleType
 from typing import Any, Literal
 
 import arviz as az
@@ -28,6 +29,7 @@
 
 from arviz.data.base import make_attrs
 from jax.lax import scan
+from numpy.typing import ArrayLike
 from pytensor.compile import SharedVariable, Supervisor, mode
 from pytensor.graph.basic import graph_inputs
 from pytensor.graph.fg import FunctionGraph
@@ -120,7 +122,7 @@ def _replace_shared_variables(graph: list[TensorVariable]) -> list[TensorVariable]:
 def get_jaxified_graph(
     inputs: list[TensorVariable] | None = None,
     outputs: list[TensorVariable] | None = None,
-) -> list[TensorVariable]:
+) -> Callable[[list[TensorVariable]], list[TensorVariable]]:
     """Compile a PyTensor graph into an optimized JAX function."""
     graph = _replace_shared_variables(outputs) if outputs is not None else None
 
@@ -143,13 +145,13 @@ def get_jaxified_graph(
     return jax_funcify(fgraph)
 
 
-def get_jaxified_logp(model: Model, negative_logp=True) -> Callable:
+def get_jaxified_logp(model: Model, negative_logp: bool = True) -> Callable[[ArrayLike], jax.Array]:
     model_logp = model.logp()
     if not negative_logp:
         model_logp = -model_logp
     logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])
 
-    def logp_fn_wrap(x):
+    def logp_fn_wrap(x: ArrayLike) -> jax.Array:
         return logp_fn(*x)[0]
 
     return logp_fn_wrap
```
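
A wrinkle worth noting when reading this hunk: despite the name, `negative_logp=True` (the default) returns the log density unchanged, while `negative_logp=False` negates it, matching the two backends' conventions (blackjax consumes a log density, numpyro a potential energy). Continuing the sketch from the commit summary above:

```python
logdensity_fn = get_jaxified_logp(model)                      # blackjax: log density
potential_fn = get_jaxified_logp(model, negative_logp=False)  # numpyro: negative log density
```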
```diff
@@ -210,23 +212,43 @@ def _get_batched_jittered_initial_points(
     chains: int,
     initvals: StartDict | Sequence[StartDict | None] | None,
     random_seed: RandomSeed,
+    logp_fn: Callable[[ArrayLike], jax.Array] | None = None,
     jitter: bool = True,
     jitter_max_retries: int = 10,
 ) -> np.ndarray | list[np.ndarray]:
-    """Get jittered initial point in format expected by NumPyro MCMC kernel.
+    """Get jittered initial point in format expected by Jax MCMC kernel.
+
+    Parameters
+    ----------
+    logp_fn : Callable[[Sequence[np.ndarray]], np.ndarray]
+        Jaxified logp function
 
     Returns
     -------
     out: list of ndarrays
         list with one item per variable and number of chains as batch dimension.
         Each item has shape `(chains, *var.shape)`
     """
+    if logp_fn is None:
+        eval_logp_initial_point = None
+
+    else:
+
+        def eval_logp_initial_point(point: dict[str, np.ndarray]) -> jax.Array:
+            """Wrap logp_fn to conform to _init_jitter logic.
+
+            Wraps jaxified logp function to accept a dict of
+            {model_variable: np.array} key:value pairs.
+            """
+            return logp_fn(point.values())
+
     initial_points = _init_jitter(
         model,
         initvals,
         seeds=_get_seeds_per_chain(random_seed, chains),
         jitter=jitter,
         jitter_max_retries=jitter_max_retries,
+        logp_fn=eval_logp_initial_point,
     )
     initial_points_values = [list(initial_point.values()) for initial_point in initial_points]
     if chains == 1:
```
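
The `eval_logp_initial_point` wrapper exists because `_init_jitter` scores candidate points as `{value_var_name: array}` dicts, while the jaxified function takes the arrays positionally. Roughly, the retry loop it plugs into looks like this (a hypothetical sketch of the pattern, not the actual `_init_jitter` code):

```python
import numpy as np

def jitter_until_finite(point, eval_logp, rng, max_retries=10):
    """Re-draw U(-1, 1) jitter until the jittered point has finite logp."""
    for _ in range(max_retries):
        jittered = {
            name: value + rng.uniform(-1, 1, size=np.shape(value))
            for name, value in point.items()
        }
        if np.isfinite(eval_logp(jittered)):
            return jittered
    return point  # fall back to the unjittered point
```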
```diff
@@ -235,7 +257,7 @@ def _get_batched_jittered_initial_points(
 
 
 def _blackjax_inference_loop(
-    seed, init_position, logprob_fn, draws, tune, target_accept, **adaptation_kwargs
+    seed, init_position, logp_fn, draws, tune, target_accept, **adaptation_kwargs
 ):
     import blackjax
 
@@ -251,13 +273,13 @@ def _blackjax_inference_loop(
 
     adapt = blackjax.window_adaptation(
         algorithm=algorithm,
-        logdensity_fn=logprob_fn,
+        logdensity_fn=logp_fn,
         target_acceptance_rate=target_accept,
         adaptation_info_fn=get_filter_adapt_info_fn(),
         **adaptation_kwargs,
     )
     (last_state, tuned_params), _ = adapt.run(seed, init_position, num_steps=tune)
-    kernel = algorithm(logprob_fn, **tuned_params).step
+    kernel = algorithm(logp_fn, **tuned_params).step
 
     def _one_step(state, xs):
         _, rng_key = xs
@@ -288,67 +310,51 @@ def _sample_blackjax_nuts(
     tune: int,
     draws: int,
     chains: int,
-    chain_method: str | None,
+    chain_method: Literal["parallel", "vectorized"],
     progressbar: bool,
     random_seed: int,
-    initial_points,
-    nuts_kwargs,
-) -> az.InferenceData:
+    initial_points: np.ndarray | list[np.ndarray],
+    nuts_kwargs: dict[str, Any],
+    logp_fn: Callable[[ArrayLike], jax.Array] | None = None,
+) -> tuple[Any, dict[str, Any], ModuleType]:
     """
     Draw samples from the posterior using the NUTS method from the ``blackjax`` library.
 
     Parameters
     ----------
-    draws : int, default 1000
-        The number of samples to draw. The number of tuned samples are discarded by
-        default.
-    tune : int, default 1000
+    model : Model
+        Model to sample from. The model needs to have free random variables.
+    target_accept : float in [0, 1].
+        The step size is tuned such that we approximate this acceptance rate. Higher
+        values like 0.9 or 0.95 often work better for problematic posteriors.
+    tune : int
         Number of iterations to tune. Samplers adjust the step sizes, scalings or
         similar during tuning. Tuning samples will be drawn in addition to the number
         specified in the ``draws`` argument.
-    chains : int, default 4
+    draws : int
+        The number of samples to draw. The number of tuned samples are discarded by default.
+    chains : int
         The number of chains to sample.
-    target_accept : float in [0, 1].
-        The step size is tuned such that we approximate this acceptance rate. Higher
-        values like 0.9 or 0.95 often work better for problematic posteriors.
-    random_seed : int, RandomState or Generator, optional
+    chain_method : "parallel" or "vectorized"
+        Specify how samples should be drawn.
+    progressbar : bool
+        Whether to show progressbar or not during sampling.
+    random_seed : int, RandomState or Generator
         Random seed used by the sampling steps.
-    initvals: StartDict or Sequence[Optional[StartDict]], optional
-        Initial values for random variables provided as a dictionary (or sequence of
-        dictionaries) mapping the random variable (by name or reference) to desired
-        starting values.
-    jitter: bool, default True
-        If True, add jitter to initial points.
-    model : Model, optional
-        Model to sample from. The model needs to have free random variables. When inside
-        a ``with`` model context, it defaults to that model, otherwise the model must be
-        passed explicitly.
-    var_names : sequence of str, optional
-        Names of variables for which to compute the posterior samples. Defaults to all
-        variables in the posterior.
-    keep_untransformed : bool, default False
-        Include untransformed variables in the posterior samples. Defaults to False.
-    chain_method : str, default "parallel"
-        Specify how samples should be drawn. The choices include "parallel", and
-        "vectorized".
-    postprocessing_backend: Optional[Literal["cpu", "gpu"]], default None,
-        Specify how postprocessing should be computed. gpu or cpu
-    postprocessing_vectorize: Literal["vmap", "scan"], default "scan"
-        How to vectorize the postprocessing: vmap or sequential scan
-    idata_kwargs : dict, optional
-        Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as
-        value for the ``log_likelihood`` key to indicate that the pointwise log
-        likelihood should not be included in the returned object. Values for
-        ``observed_data``, ``constant_data``, ``coords``, and ``dims`` are inferred from
-        the ``model`` argument if not provided in ``idata_kwargs``. If ``coords`` and
-        ``dims`` are provided, they are used to update the inferred dictionaries.
+    initial_points : np.ndarray or list[np.ndarray]
+        Initial point(s) for sampler to begin sampling from.
+    nuts_kwargs : dict
+        Keyword arguments for the blackjax nuts sampler
+    logp_fn : Callable[[ArrayLike], jax.Array], optional, default None
+        jaxified logp function. If not passed in it will be created anew.
 
     Returns
     -------
-    InferenceData
-        ArviZ ``InferenceData`` object that contains the posterior samples, together
-        with their respective sample stats and pointwise log likeihood values (unless
-        skipped with ``idata_kwargs``).
+    raw_mcmc_samples
+        Datastructure containing raw mcmc samples
+    sample_stats : dict[str, Any]
+        Dictionary containing sample stats
+    blackjax : ModuleType["blackjax"]
     """
     import blackjax
 
@@ -365,15 +371,16 @@ def _sample_blackjax_nuts(
     if chains == 1:
         initial_points = [np.stack(init_state) for init_state in zip(initial_points)]
 
-    logprob_fn = get_jaxified_logp(model)
+    if logp_fn is None:
+        logp_fn = get_jaxified_logp(model)
 
     seed = jax.random.PRNGKey(random_seed)
     keys = jax.random.split(seed, chains)
 
     nuts_kwargs["progress_bar"] = progressbar
     get_posterior_samples = partial(
         _blackjax_inference_loop,
-        logprob_fn=logprob_fn,
+        logp_fn=logp_fn,
         tune=tune,
         draws=draws,
         target_accept=target_accept,
```
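
Note the contract change in this helper: `_sample_blackjax_nuts` (and `_sample_numpyro_nuts` below) no longer assembles an `az.InferenceData` itself. Each now returns the raw draws, the sample stats, and the backend module, and `sample_jax_nuts` does the conversion after the `raw_mcmc_samples, sample_stats, library = sampler_fn(...)` call shown in the last hunk.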
```diff
@@ -385,7 +392,7 @@ def _sample_blackjax_nuts(
 
 
 # Adopted from arviz numpyro extractor
-def _numpyro_stats_to_dict(posterior):
+def _numpyro_stats_to_dict(posterior) -> dict[str, Any]:
     """Extract sample_stats from NumPyro posterior."""
     rename_key = {
         "potential_energy": "lp",
@@ -411,17 +418,58 @@ def _sample_numpyro_nuts(
     tune: int,
     draws: int,
     chains: int,
-    chain_method: str | None,
+    chain_method: Literal["parallel", "vectorized"],
     progressbar: bool,
     random_seed: int,
-    initial_points,
+    initial_points: np.ndarray | list[np.ndarray],
     nuts_kwargs: dict[str, Any],
-):
+    logp_fn: Callable[[ArrayLike], jax.Array] | None = None,
+) -> tuple[Any, dict[str, Any], ModuleType]:
+    """
+    Draw samples from the posterior using the NUTS method from the ``numpyro`` library.
+
+    Parameters
+    ----------
+    model : Model
+        Model to sample from. The model needs to have free random variables.
+    target_accept : float in [0, 1].
+        The step size is tuned such that we approximate this acceptance rate. Higher
+        values like 0.9 or 0.95 often work better for problematic posteriors.
+    tune : int
+        Number of iterations to tune. Samplers adjust the step sizes, scalings or
+        similar during tuning. Tuning samples will be drawn in addition to the number
+        specified in the ``draws`` argument.
+    draws : int
+        The number of samples to draw. The number of tuned samples are discarded by default.
+    chains : int
+        The number of chains to sample.
+    chain_method : "parallel" or "vectorized"
+        Specify how samples should be drawn.
+    progressbar : bool
+        Whether to show progressbar or not during sampling.
+    random_seed : int, RandomState or Generator
+        Random seed used by the sampling steps.
+    initial_points : np.ndarray or list[np.ndarray]
+        Initial point(s) for sampler to begin sampling from.
+    nuts_kwargs : dict
+        Keyword arguments for the underlying numpyro nuts sampler
+    logp_fn : Callable[[ArrayLike], jax.Array], optional, default None
+        jaxified logp function. If not passed in it will be created anew.
+
+    Returns
+    -------
+    raw_mcmc_samples
+        Datastructure containing raw mcmc samples
+    sample_stats : dict[str, Any]
+        Dictionary containing sample stats
+    numpyro : ModuleType["numpyro"]
+    """
     import numpyro
 
     from numpyro.infer import MCMC, NUTS
 
-    logp_fn = get_jaxified_logp(model, negative_logp=False)
+    if logp_fn is None:
+        logp_fn = get_jaxified_logp(model, negative_logp=False)
 
     nuts_kwargs.setdefault("adapt_step_size", True)
     nuts_kwargs.setdefault("adapt_mass_matrix", True)
@@ -479,7 +527,7 @@ def sample_jax_nuts(
     nuts_kwargs: dict | None = None,
     progressbar: bool = True,
     keep_untransformed: bool = False,
-    chain_method: str = "parallel",
+    chain_method: Literal["parallel", "vectorized"] = "parallel",
     postprocessing_backend: Literal["cpu", "gpu"] | None = None,
     postprocessing_vectorize: Literal["vmap", "scan"] | None = None,
     postprocessing_chunks=None,
@@ -525,7 +573,7 @@ def sample_jax_nuts(
         If True, display a progressbar while sampling
     keep_untransformed : bool, default False
         Include untransformed variables in the posterior samples.
-    chain_method : str, default "parallel"
+    chain_method : Literal["parallel", "vectorized"], default "parallel"
         Specify how samples should be drawn. The choices include "parallel", and
         "vectorized".
     postprocessing_backend : Optional[Literal["cpu", "gpu"]], default None,
@@ -589,6 +637,15 @@ def sample_jax_nuts(
         get_default_varnames(filtered_var_names, include_transformed=keep_untransformed)
     )
 
+    if nuts_sampler == "numpyro":
+        sampler_fn = _sample_numpyro_nuts
+        logp_fn = get_jaxified_logp(model, negative_logp=False)
+    elif nuts_sampler == "blackjax":
+        sampler_fn = _sample_blackjax_nuts
+        logp_fn = get_jaxified_logp(model)
+    else:
+        raise ValueError(f"{nuts_sampler=} not recognized")
+
     (random_seed,) = _get_seeds_per_chain(random_seed, 1)
 
     initial_points = _get_batched_jittered_initial_points(
@@ -597,15 +654,9 @@ def sample_jax_nuts(
         initvals=initvals,
         random_seed=random_seed,
         jitter=jitter,
+        logp_fn=logp_fn,
     )
 
-    if nuts_sampler == "numpyro":
-        sampler_fn = _sample_numpyro_nuts
-    elif nuts_sampler == "blackjax":
-        sampler_fn = _sample_blackjax_nuts
-    else:
-        raise ValueError(f"{nuts_sampler=} not recognized")
-
     tic1 = datetime.now()
     raw_mcmc_samples, sample_stats, library = sampler_fn(
         model=model,
```
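
Putting it together: `sample_jax_nuts` now builds the jaxified logp once, uses it to vet the jittered initial points, and hands the same function to whichever backend was selected. A usage sketch under the same toy-model assumptions as above:

```python
import pymc as pm

from pymc.sampling.jax import sample_jax_nuts

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("obs", mu, 1.0, observed=[0.1, -0.3, 0.5])
    # One get_jaxified_logp call now serves both the initial-point
    # check and the NUTS kernel.
    idata = sample_jax_nuts(tune=500, draws=500, chains=2, nuts_sampler="numpyro")
```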
