Skip to content

Commit 8cdc9ee

Browse files
authored
Blackjax sampler fix for breaking change / enable progress bar under parallel chain_method (#7453)
* remove blackjax pmap warning * use gen_scan_fn * remove labels * retrigger checks * retrigger checks
1 parent f3cff73 commit 8cdc9ee

File tree

1 file changed

+2
-15
lines changed

1 file changed

+2
-15
lines changed

pymc/sampling/jax.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -278,15 +278,10 @@ def _one_step(state, xs):
278278
return state, (position, stats)
279279

280280
progress_bar = adaptation_kwargs.pop("progress_bar", False)
281-
if progress_bar:
282-
from blackjax.progress_bar import progress_bar_scan
283-
284-
one_step = jax.jit(progress_bar_scan(draws)(_one_step))
285-
else:
286-
one_step = jax.jit(_one_step)
287281

288282
keys = jax.random.split(seed, draws)
289-
_, (samples, stats) = jax.lax.scan(one_step, last_state, (jnp.arange(draws), keys))
283+
scan_fn = blackjax.progress_bar.gen_scan_fn(draws, progress_bar)
284+
_, (samples, stats) = scan_fn(_one_step, last_state, (jnp.arange(draws), keys))
290285

291286
return samples, stats
292287

@@ -365,14 +360,6 @@ def _sample_blackjax_nuts(
365360
# Adapted from numpyro
366361
if chain_method == "parallel":
367362
map_fn = jax.pmap
368-
if progressbar:
369-
import warnings
370-
371-
warnings.warn(
372-
"BlackJax currently only display progress bar correctly under "
373-
"`chain_method == 'vectorized'`. Setting `progressbar=False`."
374-
)
375-
progressbar = False
376363
elif chain_method == "vectorized":
377364
map_fn = jax.vmap
378365
else:

0 commit comments

Comments
 (0)