Skip to content

Fix nan for valid parameters in jax implementation of Multinomial #1328

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 30, 2025

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Mar 28, 2025

Closes #1327


📚 Documentation preview 📚: https://pytensor--1328.org.readthedocs.build/en/1328/

Copy link

codecov bot commented Mar 28, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 82.01%. Comparing base (0b56ed9) to head (bd85204).
Report is 1 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main    #1328   +/-   ##
=======================================
  Coverage   82.01%   82.01%           
=======================================
  Files         203      203           
  Lines       48805    48805           
  Branches     8688     8688           
=======================================
  Hits        40026    40026           
  Misses       6627     6627           
  Partials     2152     2152           
Files with missing lines Coverage Δ
pytensor/link/jax/dispatch/random.py 94.01% <100.00%> (ø)
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@educhesne
Copy link
Contributor

educhesne commented Mar 28, 2025

Your fix solves the problem when p is valid (the sum of p is 1), but if not you may have rho==0 without p==0; otherwise it is possible to do:
samples = jax.random.binomial(rng, s, jnp.where(rho == 0, 0, p / rho))

@ricardoV94
Copy link
Member Author

Ah the problem is the division by zero, yes that's better. I guess I should also test n=0 multinomial

@ricardoV94 ricardoV94 force-pushed the fix_jax_multinomial branch from 34cf7c1 to bd85204 Compare March 29, 2025 06:51
@ricardoV94
Copy link
Member Author

Your fix solves the problem when p is valid (the sum of p is 1), but if not you may have rho==0 without p==0; otherwise it is possible to do: samples = jax.random.binomial(rng, s, jnp.where(rho == 0, 0, p / rho))

Actually I don't think I want to mask that, if p adds to more than 1, then it's fine to get nans. In numpy it would raise an error, but jax doesn't have runtime errors.

@ricardoV94 ricardoV94 changed the title Fix nan in jax implementation of Multinomial Fix nan for valid parameters in jax implementation of Multinomial Mar 30, 2025
@jessegrabowski jessegrabowski merged commit 3af923b into pymc-devs:main Mar 30, 2025
73 checks passed
@ricardoV94 ricardoV94 deleted the fix_jax_multinomial branch March 30, 2025 12:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working jax random variables
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Nans in JAX multinomial dispatch
3 participants