Skip to content

Commit 0f5da80

Browse files
committed
More stable fix for JAX Multinomial
1 parent d9b1085 commit 0f5da80

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

Diff for: pytensor/link/jax/dispatch/random.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,9 @@ def _binomial_sample_fn(carry, p_rng):
412412
remaining_n, remaining_p = carry
413413
p, rng = p_rng
414414
samples = jnp.where(
415-
p == 0, 0, jax.random.binomial(rng, remaining_n, p / remaining_p)
415+
remaining_n == 0,
416+
0,
417+
jax.random.binomial(rng, remaining_n, p / remaining_p),
416418
)
417419
remaining_n -= samples
418420
remaining_p -= p

0 commit comments

Comments
 (0)