@@ -18,6 +18,7 @@
 from collections.abc import Callable, Sequence
 from datetime import datetime
 from functools import partial
+from types import ModuleType
 from typing import Any, Literal

 import arviz as az
@@ -28,6 +29,7 @@

 from arviz.data.base import make_attrs
 from jax.lax import scan
+from numpy.typing import ArrayLike
 from pytensor.compile import SharedVariable, Supervisor, mode
 from pytensor.graph.basic import graph_inputs
 from pytensor.graph.fg import FunctionGraph
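The two new imports back the typing changes that follow: `ArrayLike` is NumPy's protocol for anything coercible to an array and annotates the inputs of the jaxified logp functions, while `ModuleType` annotates the backend module (`numpyro` or `blackjax`) that the private samplers now return alongside their samples. A minimal sketch of the return-the-module pattern, with a hypothetical helper name:

    import importlib
    from types import ModuleType

    def _import_backend(name: str) -> ModuleType:
        # returning the module lets the caller read metadata such as __version__
        return importlib.import_module(name)

    library = _import_backend("numpyro")
    print(library.__version__)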
@@ -120,7 +122,7 @@ def _replace_shared_variables(graph: list[TensorVariable]) -> list[TensorVariable]:
 def get_jaxified_graph(
     inputs: list[TensorVariable] | None = None,
     outputs: list[TensorVariable] | None = None,
-) -> list[TensorVariable]:
+) -> Callable[[list[TensorVariable]], list[TensorVariable]]:
     """Compile a PyTensor graph into an optimized JAX function."""
     graph = _replace_shared_variables(outputs) if outputs is not None else None
@@ -143,13 +145,13 @@ def get_jaxified_graph(
     return jax_funcify(fgraph)


-def get_jaxified_logp(model: Model, negative_logp=True) -> Callable:
+def get_jaxified_logp(model: Model, negative_logp: bool = True) -> Callable[[ArrayLike], jax.Array]:
     model_logp = model.logp()
     if not negative_logp:
         model_logp = -model_logp
     logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])

-    def logp_fn_wrap(x):
+    def logp_fn_wrap(x: ArrayLike) -> jax.Array:
         return logp_fn(*x)[0]

     return logp_fn_wrap
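The wrapper gives the compiled logp a sequence-in, scalar-out interface: one array per entry in `model.value_vars`, passed as a single sequence rather than as separate positional arguments. A small usage sketch (the model and values are hypothetical):

    import numpy as np
    import pymc as pm
    from pymc.sampling.jax import get_jaxified_logp

    with pm.Model() as model:
        mu = pm.Normal("mu")
        sigma = pm.HalfNormal("sigma")

    logp = get_jaxified_logp(model)
    # one array per value variable: mu and the log-transformed sigma
    print(logp([np.array(0.0), np.array(-1.0)]))  # scalar jax.Array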
@@ -210,23 +212,43 @@ def _get_batched_jittered_initial_points(
     chains: int,
     initvals: StartDict | Sequence[StartDict | None] | None,
     random_seed: RandomSeed,
+    logp_fn: Callable[[ArrayLike], jax.Array] | None = None,
     jitter: bool = True,
     jitter_max_retries: int = 10,
 ) -> np.ndarray | list[np.ndarray]:
-    """Get jittered initial point in format expected by NumPyro MCMC kernel.
+    """Get jittered initial point in format expected by JAX MCMC kernel.
+
+    Parameters
+    ----------
+    logp_fn : Callable[[Sequence[np.ndarray]], np.ndarray]
+        Jaxified logp function

     Returns
     -------
     out: list of ndarrays
         list with one item per variable and number of chains as batch dimension.
         Each item has shape `(chains, *var.shape)`
     """
+    if logp_fn is None:
+        eval_logp_initial_point = None
+
+    else:
+
+        def eval_logp_initial_point(point: dict[str, np.ndarray]) -> jax.Array:
+            """Wrap logp_fn to conform to _init_jitter logic.
+
+            Wraps jaxified logp function to accept a dict of
+            {model_variable: np.array} key:value pairs.
+            """
+            return logp_fn(point.values())
+
     initial_points = _init_jitter(
         model,
         initvals,
         seeds=_get_seeds_per_chain(random_seed, chains),
         jitter=jitter,
         jitter_max_retries=jitter_max_retries,
+        logp_fn=eval_logp_initial_point,
     )
     initial_points_values = [list(initial_point.values()) for initial_point in initial_points]
     if chains == 1:
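The `eval_logp_initial_point` closure bridges two calling conventions: `_init_jitter` evaluates candidate starting points as `{value_var_name: array}` dicts, while the jaxified logp expects a flat sequence. Because the point dicts are built in `model.value_vars` order and Python dicts preserve insertion order, `point.values()` lines up with the compiled function's positional inputs. Continuing the sketch after the `get_jaxified_logp` hunk above, with hypothetical names:

    point = {"mu": np.array(0.0), "sigma_log__": np.array(-1.0)}
    # .values() yields the arrays in insertion order, matching model.value_vars,
    # so the dict-based point feeds the sequence-based jaxified logp directly
    logp_at_point = logp(point.values())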
@@ -235,7 +257,7 @@ def _get_batched_jittered_initial_points(


 def _blackjax_inference_loop(
-    seed, init_position, logprob_fn, draws, tune, target_accept, **adaptation_kwargs
+    seed, init_position, logp_fn, draws, tune, target_accept, **adaptation_kwargs
 ):
     import blackjax
@@ -251,13 +273,13 @@ def _blackjax_inference_loop(

     adapt = blackjax.window_adaptation(
         algorithm=algorithm,
-        logdensity_fn=logprob_fn,
+        logdensity_fn=logp_fn,
         target_acceptance_rate=target_accept,
         adaptation_info_fn=get_filter_adapt_info_fn(),
         **adaptation_kwargs,
     )
     (last_state, tuned_params), _ = adapt.run(seed, init_position, num_steps=tune)
-    kernel = algorithm(logprob_fn, **tuned_params).step
+    kernel = algorithm(logp_fn, **tuned_params).step

     def _one_step(state, xs):
         _, rng_key = xs
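Past the point where this hunk is cut off, `_one_step` is driven by `jax.lax.scan` to produce the post-adaptation draws. A stripped-down sketch of that loop, assuming `kernel`, `last_state`, `draws`, and a PRNG key `seed` from the surrounding function:

    import jax
    import jax.numpy as jnp
    from jax.lax import scan

    def _one_step(state, xs):
        _, rng_key = xs
        state, info = kernel(rng_key, state)  # one NUTS transition
        return state, (state.position, info)

    # one rng key per draw; scan threads the chain state through all draws
    keys = jax.random.split(seed, draws)
    _, (samples, infos) = scan(_one_step, last_state, (jnp.arange(draws), keys))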
@@ -288,67 +310,51 @@ def _sample_blackjax_nuts(
     tune: int,
     draws: int,
     chains: int,
-    chain_method: str | None,
+    chain_method: Literal["parallel", "vectorized"],
     progressbar: bool,
     random_seed: int,
-    initial_points,
-    nuts_kwargs,
-) -> az.InferenceData:
+    initial_points: np.ndarray | list[np.ndarray],
+    nuts_kwargs: dict[str, Any],
+    logp_fn: Callable[[ArrayLike], jax.Array] | None = None,
+) -> tuple[Any, dict[str, Any], ModuleType]:
     """
     Draw samples from the posterior using the NUTS method from the ``blackjax`` library.

     Parameters
     ----------
-    draws : int, default 1000
-        The number of samples to draw. The number of tuned samples are discarded by
-        default.
-    tune : int, default 1000
+    model : Model
+        Model to sample from. The model needs to have free random variables.
+    target_accept : float in [0, 1].
+        The step size is tuned such that we approximate this acceptance rate. Higher
+        values like 0.9 or 0.95 often work better for problematic posteriors.
+    tune : int
         Number of iterations to tune. Samplers adjust the step sizes, scalings or
         similar during tuning. Tuning samples will be drawn in addition to the number
         specified in the ``draws`` argument.
-    chains : int, default 4
+    draws : int
+        The number of samples to draw. The number of tuned samples is discarded by default.
+    chains : int
         The number of chains to sample.
-    target_accept : float in [0, 1].
-        The step size is tuned such that we approximate this acceptance rate. Higher
-        values like 0.9 or 0.95 often work better for problematic posteriors.
-    random_seed : int, RandomState or Generator, optional
+    chain_method : "parallel" or "vectorized"
+        Specify how samples should be drawn.
+    progressbar : bool
+        Whether to show the progressbar during sampling.
+    random_seed : int, RandomState or Generator
         Random seed used by the sampling steps.
-    initvals: StartDict or Sequence[Optional[StartDict]], optional
-        Initial values for random variables provided as a dictionary (or sequence of
-        dictionaries) mapping the random variable (by name or reference) to desired
-        starting values.
-    jitter: bool, default True
-        If True, add jitter to initial points.
-    model : Model, optional
-        Model to sample from. The model needs to have free random variables. When inside
-        a ``with`` model context, it defaults to that model, otherwise the model must be
-        passed explicitly.
-    var_names : sequence of str, optional
-        Names of variables for which to compute the posterior samples. Defaults to all
-        variables in the posterior.
-    keep_untransformed : bool, default False
-        Include untransformed variables in the posterior samples. Defaults to False.
-    chain_method : str, default "parallel"
-        Specify how samples should be drawn. The choices include "parallel", and
-        "vectorized".
-    postprocessing_backend: Optional[Literal["cpu", "gpu"]], default None,
-        Specify how postprocessing should be computed. gpu or cpu
-    postprocessing_vectorize: Literal["vmap", "scan"], default "scan"
-        How to vectorize the postprocessing: vmap or sequential scan
-    idata_kwargs : dict, optional
-        Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as
-        value for the ``log_likelihood`` key to indicate that the pointwise log
-        likelihood should not be included in the returned object. Values for
-        ``observed_data``, ``constant_data``, ``coords``, and ``dims`` are inferred from
-        the ``model`` argument if not provided in ``idata_kwargs``. If ``coords`` and
-        ``dims`` are provided, they are used to update the inferred dictionaries.
+    initial_points : np.ndarray or list[np.ndarray]
+        Initial point(s) for sampler to begin sampling from.
+    nuts_kwargs : dict
+        Keyword arguments for the blackjax NUTS sampler.
+    logp_fn : Callable[[ArrayLike], jax.Array], optional, default None
+        Jaxified logp function. If not passed in, it will be created anew.

     Returns
     -------
-    InferenceData
-        ArviZ ``InferenceData`` object that contains the posterior samples, together
-        with their respective sample stats and pointwise log likeihood values (unless
-        skipped with ``idata_kwargs``).
+    raw_mcmc_samples
+        Data structure containing the raw MCMC samples.
+    sample_stats : dict[str, Any]
+        Dictionary containing sample stats.
+    blackjax : ModuleType
     """
     import blackjax
@@ -365,15 +371,16 @@ def _sample_blackjax_nuts(
     if chains == 1:
         initial_points = [np.stack(init_state) for init_state in zip(initial_points)]

-    logprob_fn = get_jaxified_logp(model)
+    if logp_fn is None:
+        logp_fn = get_jaxified_logp(model)

     seed = jax.random.PRNGKey(random_seed)
     keys = jax.random.split(seed, chains)

     nuts_kwargs["progress_bar"] = progressbar
     get_posterior_samples = partial(
         _blackjax_inference_loop,
-        logprob_fn=logprob_fn,
+        logp_fn=logp_fn,
         tune=tune,
         draws=draws,
         target_accept=target_accept,
@@ -385,7 +392,7 @@ def _sample_blackjax_nuts(


 # Adopted from arviz numpyro extractor
-def _numpyro_stats_to_dict(posterior):
+def _numpyro_stats_to_dict(posterior) -> dict[str, Any]:
     """Extract sample_stats from NumPyro posterior."""
     rename_key = {
         "potential_energy": "lp",
@@ -411,17 +418,58 @@ def _sample_numpyro_nuts(
     tune: int,
     draws: int,
     chains: int,
-    chain_method: str | None,
+    chain_method: Literal["parallel", "vectorized"],
     progressbar: bool,
     random_seed: int,
-    initial_points,
+    initial_points: np.ndarray | list[np.ndarray],
     nuts_kwargs: dict[str, Any],
-):
+    logp_fn: Callable[[ArrayLike], jax.Array] | None = None,
+) -> tuple[Any, dict[str, Any], ModuleType]:
+    """
+    Draw samples from the posterior using the NUTS method from the ``numpyro`` library.
+
+    Parameters
+    ----------
+    model : Model
+        Model to sample from. The model needs to have free random variables.
+    target_accept : float in [0, 1].
+        The step size is tuned such that we approximate this acceptance rate. Higher
+        values like 0.9 or 0.95 often work better for problematic posteriors.
+    tune : int
+        Number of iterations to tune. Samplers adjust the step sizes, scalings or
+        similar during tuning. Tuning samples will be drawn in addition to the number
+        specified in the ``draws`` argument.
+    draws : int
+        The number of samples to draw. The number of tuned samples is discarded by default.
+    chains : int
+        The number of chains to sample.
+    chain_method : "parallel" or "vectorized"
+        Specify how samples should be drawn.
+    progressbar : bool
+        Whether to show the progressbar during sampling.
+    random_seed : int, RandomState or Generator
+        Random seed used by the sampling steps.
+    initial_points : np.ndarray or list[np.ndarray]
+        Initial point(s) for sampler to begin sampling from.
+    nuts_kwargs : dict
+        Keyword arguments for the underlying numpyro NUTS sampler.
+    logp_fn : Callable[[ArrayLike], jax.Array], optional, default None
+        Jaxified logp function. If not passed in, it will be created anew.
+
+    Returns
+    -------
+    raw_mcmc_samples
+        Data structure containing the raw MCMC samples.
+    sample_stats : dict[str, Any]
+        Dictionary containing sample stats.
+    numpyro : ModuleType
+    """
     import numpyro

     from numpyro.infer import MCMC, NUTS

-    logp_fn = get_jaxified_logp(model, negative_logp=False)
+    if logp_fn is None:
+        logp_fn = get_jaxified_logp(model, negative_logp=False)

     nuts_kwargs.setdefault("adapt_step_size", True)
     nuts_kwargs.setdefault("adapt_mass_matrix", True)
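For orientation, the `logp_fn` resolved here is consumed further down in this function (outside the hunk) as NumPyro's potential function, and the raw samples named in the new return annotation come out of `MCMC.get_samples`. Roughly, following the existing structure of this module (the `extra_fields` shown are abbreviated):

    nuts_kernel = NUTS(potential_fn=logp_fn, target_accept_prob=target_accept, **nuts_kwargs)
    pmap_numpyro = MCMC(
        nuts_kernel,
        num_warmup=tune,
        num_samples=draws,
        num_chains=chains,
        chain_method=chain_method,
        progress_bar=progressbar,
    )
    map_seed = jax.random.PRNGKey(random_seed)
    pmap_numpyro.run(map_seed, init_params=initial_points, extra_fields=("num_steps",))
    raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)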
@@ -479,7 +527,7 @@ def sample_jax_nuts(
     nuts_kwargs: dict | None = None,
     progressbar: bool = True,
     keep_untransformed: bool = False,
-    chain_method: str = "parallel",
+    chain_method: Literal["parallel", "vectorized"] = "parallel",
     postprocessing_backend: Literal["cpu", "gpu"] | None = None,
     postprocessing_vectorize: Literal["vmap", "scan"] | None = None,
     postprocessing_chunks=None,
@@ -525,7 +573,7 @@ def sample_jax_nuts(
         If True, display a progressbar while sampling
     keep_untransformed : bool, default False
         Include untransformed variables in the posterior samples.
-    chain_method : str, default "parallel"
+    chain_method : Literal["parallel", "vectorized"], default "parallel"
         Specify how samples should be drawn. The choices include "parallel", and
         "vectorized".
     postprocessing_backend : Optional[Literal["cpu", "gpu"]], default None,
@@ -589,6 +637,15 @@ def sample_jax_nuts(
         get_default_varnames(filtered_var_names, include_transformed=keep_untransformed)
     )

+    if nuts_sampler == "numpyro":
+        sampler_fn = _sample_numpyro_nuts
+        logp_fn = get_jaxified_logp(model, negative_logp=False)
+    elif nuts_sampler == "blackjax":
+        sampler_fn = _sample_blackjax_nuts
+        logp_fn = get_jaxified_logp(model)
+    else:
+        raise ValueError(f"{nuts_sampler=} not recognized")
+
     (random_seed,) = _get_seeds_per_chain(random_seed, 1)

     initial_points = _get_batched_jittered_initial_points(
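The dispatch in the hunk above was hoisted ahead of the initial-point computation so that one compiled `logp_fn` serves both jittering and sampling, and so that each backend's sign convention is fixed in one place: blackjax's `window_adaptation` consumes a log-density, while NumPyro's NUTS (as used in this module) consumes a potential function, i.e. the negated log-density. Note that despite its name, `negative_logp=True` (the default) leaves the log-density unnegated, as the definition of `get_jaxified_logp` earlier in the diff shows. A sketch of the relationship, assuming a model whose free variables are all scalar:

    logdensity_fn = get_jaxified_logp(model)                      # blackjax flavor: +log p
    potential_fn = get_jaxified_logp(model, negative_logp=False)  # numpyro flavor: -log p

    point = [np.zeros(()) for _ in model.value_vars]  # all-zeros point, for illustration
    assert np.isclose(float(logdensity_fn(point)), -float(potential_fn(point)))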
@@ -597,15 +654,9 @@ def sample_jax_nuts(
         initvals=initvals,
         random_seed=random_seed,
         jitter=jitter,
+        logp_fn=logp_fn,
     )

-    if nuts_sampler == "numpyro":
-        sampler_fn = _sample_numpyro_nuts
-    elif nuts_sampler == "blackjax":
-        sampler_fn = _sample_blackjax_nuts
-    else:
-        raise ValueError(f"{nuts_sampler=} not recognized")
-
     tic1 = datetime.now()
     raw_mcmc_samples, sample_stats, library = sampler_fn(
         model=model,