
Fix failed test related to PR #443 by replacing pt.batched_dot with pt.vectorize(pt.dot,...) #453


Merged
5 commits merged into pymc-devs:main on Apr 12, 2025

Conversation

aphc14 (Contributor) commented Apr 10, 2025

Fix failed test from PR #443 and replace pt.batched_dot with pt.einsum

  • Modified the calculation of mu in bfgs_sample_dense and bfgs_sample_sparse to use einsum instead of the deprecated batched_dot function (see the sketch after this list).
  • Added calculations for H_inv in bfgs_sample_sparse to improve readability.
  • Renamed compile_pymc to compile in pathfinder functions to reflect API changes.
  • Updated warning filters to ignore UserWarnings related to einsum subscripts.
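
A minimal sketch of the mu change described in the first bullet (shapes are assumed for illustration: H_inv is (L, N, N), x and g are (L, N); this is a sketch of the idea, not the exact diff):

import pytensor.tensor as pt

x = pt.matrix("x")            # assumed batch of positions, shape (L, N)
g = pt.matrix("g")            # assumed batch of gradients, shape (L, N)
H_inv = pt.tensor3("H_inv")   # assumed batch of inverse Hessians, shape (L, N, N)

# Previous spelling, deprecated in recent PyTensor releases:
# mu = x - pt.batched_dot(H_inv, g)

# einsum spelling: batch over i, contract over k
mu = x - pt.einsum("ijk,ik->ij", H_inv, g)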

aphc14 added 3 commits March 30, 2025 02:19
* Moved the import statement for blackjax to ensure it is only imported when needed.
* Moving the blackjax import statement prevents import errors for users on Windows.
* Updated the fit function to specify the return type as az.InferenceData.
Fix failed test from PR #443 and replace pt.batched_dot with pt.einsum

* Modified the calculation of mu in bfgs_sample_dense and bfgs_sample_sparse to use einsum instead of the deprecated batched_dot function.
* Added calculations for H_inv in bfgs_sample_sparse to improve readability.
* Renamed compile_pymc to compile in pathfinder functions to reflect API changes.
* Updated warning filters to ignore UserWarnings related to einsum subscripts.
aphc14 (Contributor, Author) commented Apr 10, 2025

hi @zaxtax, FYI recent test runs have been failing since PR #443 was merged. This PR should resolve those issues (pending pytest results).

zaxtax (Contributor) commented Apr 10, 2025

It looks like it still has some errors. Once this passes CI we can merge

aphc14 (Contributor, Author) commented Apr 10, 2025

Ahh damn I spoke too soon. The failures are mostly due to pytensor.graph.utils.InconsistencyError: Multiple destroyers of g, which I will look into sometime tomorrow.

I don't get this error when I run pytest locally (with the same package dependencies). If there's a better way to run tests locally that'll pick up on CI errors, please do share, as I'm keen to know if I'm missing anything!

zaxtax (Contributor) commented Apr 10, 2025 via email

zaxtax (Contributor) commented Apr 12, 2025

@aphc14 it might make sense, while we figure out what caused this, to make a more modest change where we just filter the warning for batched_dot?

The regression appears to have been introduced between pytensor 2.30.2 and 2.30.3. So we should try to make a minimal failing example and open an issue there too!
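
A possible starting point for that minimal failing example, rebuilding just the expression from the failing tests (shapes and the compile step are assumptions; whether it actually raises the InconsistencyError will depend on the installed PyTensor version):

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
g = pt.matrix("g")
H_inv = pt.tensor3("H_inv")

# The expression reported to cause "Multiple destroyers of g" in CI
mu = x - pt.einsum("ijk,ik->ij", H_inv, g)

# Errors like this typically surface during graph rewriting, so a repro
# likely needs to compile and run the function, not just build the graph.
f = pytensor.function([x, g, H_inv], mu)
print(f(np.zeros((2, 3)), np.zeros((2, 3)), np.zeros((2, 3, 3))))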

ricardoV94 (Member)

The warning is there for performance reasons: it means you're using the same letter for something that has a different shape (but broadcasts fine), which is less efficient than using a different letter. Also, users will still see the warning even if you ignore it in the tests, which is not a great experience.

aphc14 (Contributor, Author) commented Apr 12, 2025

Thanks @zaxtax @ricardoV94, that narrows it down quite a bit. Will attempt to resolve this within the next couple of hours 🤞

aphc14 added 2 commits April 13, 2025 01:44
Fix failed test related to PR #443 by replacing pt.batched_dot with pt.vectorize(pt.dot,...)

* Fixed errors caused by the einsum usage in bfgs_sample_dense and bfgs_sample_sparse by implementing pt.vectorize(pt.dot,...).
* Updated test_pathfinder to filter out deprecation warnings related to JAXopt (see the sketch after this list).
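
A minimal sketch of the warning-filter pattern described in the second bullet (the mark argument and test body are illustrative, not the merged test code):

import warnings
import pytest

@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_ignores_deprecation_warnings():
    # With the filter applied, DeprecationWarnings (such as the JAXopt-related
    # ones mentioned above) no longer fail the run under a strict warnings config.
    warnings.warn("jaxopt-style deprecation", DeprecationWarning)
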
aphc14 changed the title from "Fix failed test related to PR #443 by replacing pt.batched_dot with pt.einsum" to "Fix failed test related to PR #443 by replacing pt.batched_dot with pt.vectorize(pt.dot,...)" on Apr 12, 2025
aphc14 (Contributor, Author) commented Apr 12, 2025

@zaxtax @ricardoV94 done :)

replaced:

mu = x - pt.einsum("ijk,ik->ij", H_inv, g) # causes error: Multiple destroyers of g

with:

batched_dot = pt.vectorize(pt.dot, signature="(ijk),(ilk)->(ij)")
mu = x - batched_dot(H_inv, pt.matrix_transpose(g[..., None]))

pt.einsum is fussier for some reason, even though both pt.vectorize(pt.dot,...) and the previous pt.batched_dot work.
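
For reference, a self-contained sketch of the vectorize approach with an explicit matrix-vector signature (shapes assumed: H_inv is (L, N, N), x and g are (L, N); an illustration of the idea rather than the merged code):

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
g = pt.matrix("g")
H_inv = pt.tensor3("H_inv")

# Treat pt.dot as a core (N, N) x (N,) -> (N,) product and let vectorize
# handle the leading batch axis L.
batched_matvec = pt.vectorize(pt.dot, signature="(n,n),(n)->(n)")
mu = x - batched_matvec(H_inv, g)

f = pytensor.function([x, g, H_inv], mu)
print(f(np.ones((2, 3)), np.ones((2, 3)), np.ones((2, 3, 3))).shape)  # (2, 3)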

zaxtax (Contributor) commented Apr 12, 2025 via email

zaxtax self-requested a review April 12, 2025 17:10
zaxtax merged commit abf45b7 into pymc-devs:main Apr 12, 2025
5 checks passed