Transform result of JAX sampling #4415
Hi Martin, thanks, that would definitely be useful. Is that something you'd be interested in working on? We're still looking for help on the JAX backend. |
Hi all, I believe the pull request addresses this issue. Sorry @kc611, I just gave it a shot. If you see anything that could be improved, please let me know! |
Fixed by #4427 |
Is the fix in #4427 exclusive to the Numpyro sampler? I used the JAX-TFP sampler today and got neither my transformed variables nor my deterministic variables. If that is the case, we should probably reopen this issue. (My install is pymc3 3.11.2, jax 0.2.11, tensorflow 2.5.0, tensorflow-probability 0.13.0, running on Google Colab with GPU.) |
Any reason you don't use the numpyro sampler? I thought we might want to
remove the TFP one.
|
I was running on a Google Colab instance, and using Numpyro caused the problem reported at #4645. I was not aware that the TFP sampler is going to be removed. Will that happen in the next release? |
This is fixed in a recent |
Hi all,
I'm assuming this is known but since no issue seems to have been raised about it (let me know if I just missed it, sorry!), I wanted to raise one.
When sampling using JAX, it appears that the result is given on the unconstrained representation of the variables. For example, in the Radon example, the parameter `sigma_a` yields samples called `sigma_a_log__`, which are in fact samples from `log(sigma_a)`. This is easy enough to change as a user in this case (e.g. just transform with `np.exp`), but it gets a bit more fiddly when sampling e.g. covariance matrices, where two transformations have to be composed (exp for the diagonal elements followed by L L^T to get the covariance). I'm assuming that `pymc3`'s original NUTS sampler required these transformations also, so hopefully the same logic can be used here, too.
Best,
Martin
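For illustration, the two manual back-transformations described above could be sketched roughly as follows. This is only a NumPy sketch under stated assumptions, not the fix that was merged: the helper names `constrain_positive` and `covariance_from_unconstrained` are hypothetical, and the row-major packing of the lower triangle is assumed here and may not match the layout pymc3 actually uses.

```python
import numpy as np


def constrain_positive(log_samples):
    """Map samples of log(x) (e.g. sigma_a_log__) back to x > 0."""
    return np.exp(np.asarray(log_samples, dtype=float))


def covariance_from_unconstrained(packed, n):
    """Rebuild covariance matrices from packed Cholesky samples.

    ASSUMPTION: the last axis of `packed` holds the lower triangle of L
    in row-major order, with the diagonal entries stored on the log scale.
    """
    packed = np.asarray(packed, dtype=float)
    rows, cols = np.tril_indices(n)
    # Unpack the flat vector into a lower-triangular matrix per sample.
    L = np.zeros(packed.shape[:-1] + (n, n))
    L[..., rows, cols] = packed
    # First transformation: exp on the diagonal elements.
    diag = np.arange(n)
    L[..., diag, diag] = np.exp(L[..., diag, diag])
    # Second transformation: L L^T gives the covariance.
    return L @ np.swapaxes(L, -1, -2)


if __name__ == "__main__":
    sigma_a = constrain_positive(np.log([0.5, 2.0]))
    cov = covariance_from_unconstrained([[0.0, 1.0, 0.0]], n=2)
    print(sigma_a)
    print(cov[0])
```

Both helpers accept a leading sample axis, so a whole trace of unconstrained draws can be converted in one call.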