diff --git a/pytensor/link/jax/dispatch/random.py b/pytensor/link/jax/dispatch/random.py index 678b2dc486..37f9362ed1 100644 --- a/pytensor/link/jax/dispatch/random.py +++ b/pytensor/link/jax/dispatch/random.py @@ -409,12 +409,14 @@ def sample_fn(rng_key, size, dtype, n, p): sampling_rng = jax.random.split(rng_key, binom_p.shape[0]) def _binomial_sample_fn(carry, p_rng): - s, rho = carry + remaining_n, remaining_p = carry p, rng = p_rng - samples = jax.random.binomial(rng, s, p / rho) - s = s - samples - rho = rho - p - return ((s, rho), samples) + samples = jnp.where( + p == 0, 0, jax.random.binomial(rng, remaining_n, p / remaining_p) + ) + remaining_n -= samples + remaining_p -= p + return ((remaining_n, remaining_p), samples) (remain, _), samples = jax.lax.scan( _binomial_sample_fn, diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py index 183b629f79..04be3c881e 100644 --- a/tests/link/jax/test_random.py +++ b/tests/link/jax/test_random.py @@ -733,6 +733,18 @@ def test_multinomial(): samples.std(axis=0), np.sqrt(n[0, :, None] * p * (1 - p)), rtol=0.1 ) + # Test with p=0 + g = pt.random.multinomial(n=5, p=pt.eye(4)) + g_fn = compile_random_function([], g, mode="JAX") + samples = g_fn() + np.testing.assert_array_equal(samples, np.eye(4) * 5) + + # Test with n=0 + g = pt.random.multinomial(n=0, p=np.ones(4) / 4) + g_fn = compile_random_function([], g, mode="JAX") + samples = g_fn() + np.testing.assert_array_equal(samples, np.zeros(4)) + @pytest.mark.skipif(not numpyro_available, reason="VonMises dispatch requires numpyro") def test_vonmises_mu_outside_circle():