Transform result of JAX sampling #4415

Closed · martiningram opened this issue Jan 11, 2021 · 10 comments

Comments

@martiningram
Contributor

martiningram commented Jan 11, 2021

Hi all,

I'm assuming this is known, but since no issue seems to have been raised about it (let me know if I just missed it, sorry!), I wanted to raise one.

When sampling using JAX, the result appears to be given in the unconstrained representation of the variables. For example, in the Radon example, the parameter sigma_a yields samples called sigma_a_log__, which are in fact samples of log(sigma_a). This is easy enough to fix as a user in this case (just transform with np.exp, for example), but it gets a bit more fiddly when sampling e.g. covariance matrices, where two transformations have to be composed (exp for the diagonal elements, followed by L L^T to get the covariance). I'm assuming that pymc3's original NUTS sampler required these transformations as well, so hopefully the same logic can be used here, too.

Best,
Martin
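
A rough sketch of the back-transformations Martin describes, for anyone who needs a stopgap in the meantime. The trace layout, variable names, and row-major lower-triangular packing below are illustrative assumptions, not PyMC3's exact internals:

```python
import numpy as np

# `trace` is assumed to be a dict of raw unconstrained draws keyed by their
# transformed names, e.g. trace["sigma_a_log__"] with shape (n_draws,).

def constrain_log(draws_log):
    """Undo a log transform: samples of log(x) -> samples of x."""
    return np.exp(draws_log)

def packed_chol_to_cov(packed, n):
    """Undo a packed-Cholesky transform: for each draw, unpack the lower
    triangle (assumed row-major), exponentiate the diagonal, and compose
    L @ L.T to recover the covariance matrix."""
    rows, cols = np.tril_indices(n)
    diag = np.diag_indices(n)
    covs = np.empty((packed.shape[0], n, n))
    for i, draw in enumerate(packed):
        L = np.zeros((n, n))
        L[rows, cols] = draw        # fill the lower triangle
        L[diag] = np.exp(L[diag])   # diagonal is stored on the log scale
        covs[i] = L @ L.T           # compose to the covariance
    return covs

# Usage (names are made up for the example):
# sigma_a = constrain_log(trace["sigma_a_log__"])
# cov = packed_chol_to_cov(trace["chol_packed"], n=3)
```

In the sampler itself the model's own transform objects should of course be reused rather than hand-rolled helpers like these.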

@twiecki
Member

twiecki commented Jan 20, 2021

Hi Martin, thanks - that'd definitely be useful. Is that something you'd be interested in working on? We're still looking for help on the JAX backend.

@kc611

This comment has been minimized.

@martiningram
Contributor Author

I'm happy to give it a go too, @twiecki! @kc611, I'm not sure if we could work together on this somehow? I'm planning to take a first look today to see if I can work out the original logic; I can keep you posted.

@martiningram
Contributor Author

Hi all, I believe the pull request addresses this issue. Sorry @kc611, I just went ahead and gave it a shot, but if you see anything that could be improved, please let me know!

@kc611

This comment has been minimized.

@martiningram
Contributor Author

Fixed by #4427

@PedroSebe

Is the fix in #4427 exclusive to the NumPyro sampler? I used the JAX-TFP sampler today and got neither my transformed variables nor my deterministic variables. If that is the case, we should probably reopen this issue.

(My install is pymc3 3.11.2, jax 0.2.11, tensorflow 2.5.0, and tensorflow-probability 0.13.0, running on Google Colab with a GPU.)

@twiecki
Member

twiecki commented Jul 14, 2021 via email

@PedroSebe

I was running on a Google Colab instance, and using NumPyro caused the problem reported in #4645. I was not aware that the TFP sampler is going to be removed; will that happen in the next release?

@twiecki
Member

twiecki commented Jul 19, 2021

This is fixed in a recent aesara version. Can you try installing aesara and pymc3 main?
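
For reference, that install might look something like the following (a sketch only; the repository URLs and branch setup are assumptions, check the projects' install docs for the current instructions):

```bash
pip install "git+https://github.com/aesara-devs/aesara.git"
pip install "git+https://github.com/pymc-devs/pymc3.git"
```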
