-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Reuse jaxified logp when sampling via jax #7681
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #7681 +/- ##
==========================================
- Coverage 92.70% 92.64% -0.06%
==========================================
Files 107 107
Lines 18391 18324 -67
==========================================
- Hits 17050 16977 -73
- Misses 1341 1347 +6
🚀 New features to boost your workflow:
|
@ricardoV94 not sure if you've seen this, but it's a super tiny change that we should have included with #7610 that I just missed |
Claude Code: |
Useless |
Thanks @nataziel |
reuse jaxified logp when sampling via jax
Description
#7610 added logic to handle passing a pre-jaxified logp function into the blackjax/numpyro samplers, but missed actually passing the jaxified logp that is computed in
sample_jax_nuts
Checklist
Type of change
📚 Documentation preview 📚: https://pymc--7681.org.readthedocs.build/en/7681/