
BUG: tag_solve_triangular doesn't use a triangular solver #382

Closed
jessegrabowski opened this issue Jul 14, 2023 · 9 comments · Fixed by #383
Labels
bug Something isn't working

Comments

@jessegrabowski
Member

Describe the issue:

After #303 I was thinking it would be nice to use the new lower_triangular and upper_triangular tags to rewrite inverses and solves involving triangular matrices so that they use solve_triangular. I saw this optimization already exists as tag_solve_triangular. For example, this graph:

import pytensor
import pytensor.tensor as pt
import numpy as np

A = pt.dmatrix('A')
b = pt.dmatrix('b')
L = pt.linalg.cholesky(A)
X = pt.linalg.solve(L, b)
f = pytensor.function([A, b], [X])
pytensor.dprint(X)

Solve{assume_a='gen', lower=False, check_finite=True} [id A]
 ├─ Cholesky{lower=True, destructive=False, on_error='raise'} [id B]
 │  └─ A [id C]
 └─ b [id D]

Gets rewritten to:

pytensor.dprint(f)
Solve{assume_a='sym', lower=True, check_finite=True} [id A] 1
 ├─ Cholesky{lower=True, destructive=False, on_error='raise'} [id B] 0
 │  └─ A [id C]
 └─ b [id D]

But as I point out in #291, Solve(assume_a='sym', lower=True)(A, b) is not the same as solve_triangular(A, b, lower=True). Indeed, a lot of speed is being left on the table:

Z = np.random.normal(size=(5000, 5000))
P = Z @ Z.T
P_chol = np.linalg.cholesky(P)
eye = np.eye(5000)
%timeit f(P, eye)
# 3.69 s ± 48.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

from pytensor.tensor.slinalg import solve_triangular
X2 = solve_triangular(L, b, lower=True)
f2 = pytensor.function([A, b], [X2])
%timeit f2(P, eye)
# 1.36 s ± 15.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
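A rough, leading-order flop count gives an idea of why the gap is this size when b is the n x n identity. These are textbook estimates, not measurements: they ignore memory traffic, pivoting, and BLAS/LAPACK blocking, and they assume the 'sym' path goes through an LDLᵀ factorization (as LAPACK's ?sysv does). A triangular solve needs only one substitution per right-hand side, while the symmetric and general paths first factor the matrix and then do two substitutions per column. The helper `solve_flops` is purely illustrative.

```python
def solve_flops(n, k):
    """Leading-order flop counts for solving A X = B, A of shape (n, n), B of shape (n, k).

    Textbook estimates only: lower-order terms, pivoting, memory traffic and
    BLAS/LAPACK blocking are all ignored.
    """
    substitution = n * n  # one forward or back substitution for a single column of B
    return {
        # triangular solve: no factorization, one substitution per column of B
        "triangular": substitution * k,
        # symmetric solve via LDL^T: ~n^3/3 to factor, then two substitutions per column
        "symmetric": n**3 / 3 + 2 * substitution * k,
        # general solve via LU: ~2n^3/3 to factor, then two substitutions per column
        "general": 2 * n**3 / 3 + 2 * substitution * k,
    }


counts = solve_flops(5000, 5000)  # b = eye(5000), as in the timings above
print(counts["symmetric"] / counts["triangular"])  # about 2.33
print(counts["general"] / counts["triangular"])    # about 2.67
```

For n = k = 5000 this predicts roughly a 2.3x gap for the symmetric path, the same order as the ~2.7x measured above.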

And it seems like something is going wrong with the Solve() approach, because the following test fails:

from numpy.testing import assert_allclose
x1 = f(P, eye)[0]
x2 = f2(P, eye)[0]
assert_allclose(x1 @ P_chol, eye, atol=1e-8) # fails
assert_allclose(x2 @ P_chol, eye, atol=1e-8) # passes

Am I missing something with all this?
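A dependency-free check of what each call computes. According to the SciPy documentation, `scipy.linalg.solve` with `assume_a='sym'` and `lower=True` uses only the lower triangle of `a` and assumes the matrix is symmetric. Applied to a lower-triangular `L`, it therefore solves with the symmetric matrix L + L.T - diag(L) rather than with `L` itself, which would explain why `x1 @ P_chol` is not the identity. The sketch uses exact `Fraction` arithmetic on a made-up 3x3 example; `gaussian_solve` is a plain stand-in for whatever factorization the symmetric solver actually uses, and none of the helper names are library functions.

```python
from fractions import Fraction


def forward_substitution(L, b):
    """Solve L x = b for lower-triangular L, as a triangular solver does."""
    x = []
    for i, row in enumerate(L):
        s = sum(row[j] * x[j] for j in range(i))
        x.append((b[i] - s) / row[i])
    return x


def symmetrize_from_lower(L):
    """The matrix a 'sym' solver with lower=True effectively sees:
    the lower triangle is read and mirrored into the upper triangle."""
    n = len(L)
    return [[L[i][j] if i >= j else L[j][i] for j in range(n)] for i in range(n)]


def gaussian_solve(A, b):
    """Dense solve by Gaussian elimination with partial pivoting."""
    n = len(A)
    M = [list(row) + [rhs] for row, rhs in zip(A, b)]  # augmented matrix [A | b]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[pivot] = M[pivot], M[col]
        for r in range(col + 1, n):
            factor = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= factor * M[col][c]
    x = [Fraction(0)] * n
    for i in reversed(range(n)):  # back substitution on the upper-triangular result
        s = sum(M[i][j] * x[j] for j in range(i + 1, n))
        x[i] = (M[i][n] - s) / M[i][i]
    return x


def matvec(A, x):
    return [sum(a * xi for a, xi in zip(row, x)) for row in A]


L = [[Fraction(v) for v in row] for row in [[2, 0, 0], [1, 3, 0], [4, 5, 6]]]
b = [Fraction(1)] * 3

x_tri = forward_substitution(L, b)                   # what solve_triangular computes
x_sym = gaussian_solve(symmetrize_from_lower(L), b)  # what the 'sym' solve computes

print(x_tri)  # [Fraction(1, 2), Fraction(1, 6), Fraction(-11, 36)]
print(x_sym)  # [Fraction(0, 1), Fraction(-1, 7), Fraction(2, 7)]
print(matvec(L, x_sym) == b)  # False: x_sym does not solve L x = b
```

With SciPy installed, the same comparison can be made directly with `scipy.linalg.solve(L, b, assume_a='sym', lower=True)` versus `scipy.linalg.solve_triangular(L, b, lower=True)`.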

Reproducible code example:

See above

Error message:

No response

PyTensor version information:

PyTensor version: 2.12.3

Context for the issue:

No response

@jessegrabowski jessegrabowski added the bug Something isn't working label Jul 14, 2023
@ricardoV94
Member

@jessegrabowski Are you saying the rewrite should just use SolveTriangular instead of the generic Solve?

@ricardoV94
Member

@dehorsley is this something you would be interested in tweaking?

@jessegrabowski
Member Author

Yeah, that rewrite should use SolveTriangular, but it can also be made more general by checking for a lower_x tag instead of looking for a Cholesky parent. I believe the current rewrite would miss a matrix sampled from LKJ, for example?

But I'm also curious why Solve(assume_a='sym') is treated as equivalent to SolveTriangular everywhere in the code base. I assume someone knew something I don't? Also, why does that assert fail? It should just be inv(P_chol) @ P_chol = Eye, but it isn't.

@ricardoV94
Member

ricardoV94 commented Jul 14, 2023

Edit: Never mind, I was looking at the wrong output.


Anyway, it's fine to tweak any useful rewrite that you think of that can make use of the new tags. The tag solution is not very elegant but it's better than nothing for now.

@ricardoV94
Member

ricardoV94 commented Jul 14, 2023

The fault comes from this sequence of commits:

  • 86282bd refactored Solve so that any assume_a != "gen" defaulted to scipy.linalg.solve_triangular, whereas before, anything that was not in (lower_triangular, upper_triangular) defaulted to the generic scipy.linalg.solve. In this context the rewrite was still correct, although all the other scipy assume_a options (sym, her, pos) were potentially faulty (unless the triangular solve also works for them).
  • 79961a6 then fixed this inconsistency: Solve now behaved as scipy.linalg.solve and the new SolveTriangular behaved as scipy.linalg.solve_triangular. However, the rewrites were not changed to use the new Op. Apparently they were not being numerically tested either.

At this point, the deprecation warning mentioned in #291 should also have been updated to direct people to the new solve_triangular instead of solve. That warning was removed during the cleanup in 6e3758f.

@ricardoV94
Member

I got confused by the "tag" in the name of this rewrite. It doesn't actually make any use of tags, unlike those in #303.

@ricardoV94
Member

ricardoV94 commented Jul 15, 2023

> @dehorsley is this something you would be interested in tweaking?

@dehorsley we realized there was a bug (not in your rewrites) and went ahead and fixed it.

However, I think @jessegrabowski suggested we also expand the rewrite (renamed after #383) to anything that has a "triangular" tag? Did I get that correctly?

Also should we make the outputs of "Cholesky" have such a tag from the get go?

If Jesse confirms, we would appreciate your help with those tasks :)

@jessegrabowski
Member Author

jessegrabowski commented Jul 15, 2023

Yeah, my idea was to cast a wider net. Right now, the rewrite checks for "triangularity" by looking for a Cholesky parent node. So this graph is rewritten:

import pytensor.tensor as pt
import pytensor

A = pt.dmatrix("A")
b = pt.dmatrix('b')
L = pt.linalg.cholesky(A)
x = pt.linalg.solve(L, b)
f = pytensor.function([A, b], x)

But, importantly for PyMC users, this graph is not rewritten:

import pymc as pm
L2, *_ = pm.LKJCholeskyCov.dist(n=3, sd_dist=pm.Exponential.dist(1), eta=1)
x2 = pt.linalg.solve(L2, b)
f2 = pytensor.function([b], x2)

Because the Cholesky factor returned by LKJCholeskyCov isn't actually created via the Cholesky Op. My thought was that this case could be caught by looking at the new tags instead?
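A toy model of the two triangularity checks being discussed: the current one, which only recognizes a Cholesky parent, and the proposed one, which also accepts a tag. `ToyOp`, `ToyVar`, and the `"lower_triangular"` tag key are hypothetical stand-ins for illustration, not PyTensor's actual Variable, Op, or tag API.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ToyOp:
    """Hypothetical stand-in for an Op: a name plus its parameters."""
    name: str
    params: dict = field(default_factory=dict)


@dataclass
class ToyVar:
    """Hypothetical stand-in for a graph variable: the Op that produced it
    (None for graph inputs) and a free-form tag dictionary."""
    owner_op: Optional[ToyOp] = None
    tags: dict = field(default_factory=dict)


def has_cholesky_parent(var):
    """Current check: only a lower Cholesky parent proves triangularity."""
    op = var.owner_op
    return bool(op is not None and op.name == "Cholesky" and op.params.get("lower", True))


def is_known_lower_triangular(var):
    """Proposed wider net: a Cholesky parent OR an explicit (hypothetical) tag."""
    return has_cholesky_parent(var) or bool(var.tags.get("lower_triangular", False))


def choose_solver(a):
    """Which solver a rewrite of solve(a, b) would pick under the proposed check."""
    if is_known_lower_triangular(a):
        return "SolveTriangular(lower=True)"
    return "Solve (left as is)"


chol = ToyVar(ToyOp("Cholesky", {"lower": True}))
lkj_chol = ToyVar(ToyOp("SomeOtherOp"), tags={"lower_triangular": True})  # e.g. an LKJ factor, if tagged
print(has_cholesky_parent(lkj_chol))  # False: missed by the current check
print(choose_solver(lkj_chol))        # SolveTriangular(lower=True)
```

If Cholesky outputs were tagged when created, as suggested earlier in the thread, the owner check would become redundant and the tag alone would suffice.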

I also wanted to verify that, if we have a tagged triangular matrix, pt.linalg.inv(L) gets rewritten to pt.linalg.solve_triangular(L, pt.eye(L.shape[0])) via sequential application of rewrites. I think it should, but I didn't check.
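The identity behind an inv(L) to triangular-solve rewrite is that X = inv(L) is the unique solution of L X = I, so solving against each identity column by forward substitution recovers the inverse. A dependency-free sketch with exact arithmetic; the helper names are illustrative, not PyTensor functions.

```python
from fractions import Fraction


def forward_substitution(L, b):
    """Solve L x = b for lower-triangular L, one row at a time."""
    x = []
    for i, row in enumerate(L):
        s = sum(row[j] * x[j] for j in range(i))
        x.append((b[i] - s) / row[i])
    return x


def identity(n):
    return [[Fraction(int(i == j)) for j in range(n)] for i in range(n)]


def inverse_via_triangular_solve(L):
    """inv(L) as the solution X of L X = I: one forward substitution per column of I."""
    n = len(L)
    eye = identity(n)
    columns = [forward_substitution(L, [eye[i][k] for i in range(n)]) for k in range(n)]
    # columns[k] is the k-th column of inv(L); transpose back to row-major order
    return [[columns[k][i] for k in range(n)] for i in range(n)]


def matmul(A, B):
    return [[sum(A[i][m] * B[m][j] for m in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]


L = [[Fraction(v) for v in row] for row in [[2, 0, 0], [1, 3, 0], [4, 5, 6]]]
L_inv = inverse_via_triangular_solve(L)
print(matmul(L, L_inv) == identity(3))  # True
print(matmul(L_inv, L) == identity(3))  # True
```

Each identity column costs one forward substitution, and the result is itself lower triangular.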

> Also should we make the outputs of "Cholesky" have such a tag from the get go?

Yes, this was my thinking

@ricardoV94
Member

Thanks @jessegrabowski. I opened #385 to keep track of it
