Skip to content

Commit a149f6c

Browse files
committed
Enable new assume_a in Solve
1 parent 6e06f81 commit a149f6c

File tree

4 files changed

+164
-35
lines changed

4 files changed

+164
-35
lines changed

Diff for: pytensor/link/jax/dispatch/slinalg.py

+23-5
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import warnings
2+
13
import jax
24

35
from pytensor.link.jax.dispatch.basic import jax_funcify
@@ -39,13 +41,29 @@ def cholesky(a, lower=lower):
3941

4042
@jax_funcify.register(Solve)
4143
def jax_funcify_Solve(op, **kwargs):
42-
if op.assume_a != "gen" and op.lower:
43-
lower = True
44+
assume_a = op.assume_a
45+
lower = op.lower
46+
47+
if assume_a == "tridiagonal":
48+
# jax.scipy.linalg.solve does not yet support tridiagonal matrices
49+
# But there's a jax.lax.linalg.tridiagonal_solve we can use instead.
50+
def solve(a, b):
51+
dl = jax.numpy.diagonal(a, offset=-1, axis1=-2, axis2=-1)
52+
d = jax.numpy.diagonal(a, offset=0, axis1=-2, axis2=-1)
53+
du = jax.numpy.diagonal(a, offset=1, axis1=-2, axis2=-1)
54+
return jax.lax.linalg.tridiagonal_solve(dl, d, du, b, lower=lower)
55+
4456
else:
45-
lower = False
57+
if assume_a not in ("gen", "sym", "her", "pos"):
58+
warnings.warn(
59+
f"JAX solve does not support assume_a={op.assume_a}. Defaulting to assume_a='gen'.\n"
60+
f"If appropriate, you may want to set assume_a to one of 'sym', 'pos', 'her' or 'tridiagonal' to improve performance.",
61+
UserWarning,
62+
)
63+
assume_a = "gen"
4664

47-
def solve(a, b, lower=lower):
48-
return jax.scipy.linalg.solve(a, b, lower=lower)
65+
def solve(a, b):
66+
return jax.scipy.linalg.solve(a, b, lower=lower, assume_a=assume_a)
4967

5068
return solve
5169

Diff for: pytensor/link/numba/dispatch/slinalg.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from collections.abc import Callable
23

34
import numba
@@ -1071,14 +1072,17 @@ def numba_funcify_Solve(op, node, **kwargs):
10711072
elif assume_a == "sym":
10721073
solve_fn = _solve_symmetric
10731074
elif assume_a == "her":
1074-
raise NotImplementedError(
1075-
'Use assume_a = "sym" for symmetric real matrices. If you need compelx support, '
1076-
"please open an issue on github."
1077-
)
1075+
# We already ruled out complex inputs
1076+
solve_fn = _solve_symmetric
10781077
elif assume_a == "pos":
10791078
solve_fn = _solve_psd
10801079
else:
1081-
raise NotImplementedError(f"Assumption {assume_a} not supported in Numba mode")
1080+
warnings.warn(
1081+
f"Numba assume_a={assume_a} not implemented. Falling back to general solve.\n"
1082+
f"If appropriate, you may want to set assume_a to one of 'sym', 'pos', or 'her' to improve performance.",
1083+
UserWarning,
1084+
)
1085+
solve_fn = _solve_gen
10821086

10831087
@numba_basic.numba_njit(inline="always")
10841088
def solve(a, b):

Diff for: pytensor/tensor/slinalg.py

+98-23
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from pytensor.tensor import TensorLike, as_tensor_variable
1616
from pytensor.tensor import basic as ptb
1717
from pytensor.tensor import math as ptm
18+
from pytensor.tensor.basic import diagonal
1819
from pytensor.tensor.blockwise import Blockwise
1920
from pytensor.tensor.nlinalg import kron, matrix_dot
2021
from pytensor.tensor.shape import reshape
@@ -260,10 +261,10 @@ def make_node(self, A, b):
260261
raise ValueError(f"`b` must have {self.b_ndim} dims; got {b.type} instead.")
261262

262263
# Infer dtype by solving the most simple case with 1x1 matrices
263-
inp_arr = [np.eye(1).astype(A.dtype), np.eye(1).astype(b.dtype)]
264-
out_arr = [[None]]
265-
self.perform(None, inp_arr, out_arr)
266-
o_dtype = out_arr[0][0].dtype
264+
o_dtype = scipy_linalg.solve(
265+
np.ones((1, 1), dtype=A.dtype),
266+
np.ones((1,), dtype=b.dtype),
267+
).dtype
267268
x = tensor(dtype=o_dtype, shape=b.type.shape)
268269
return Apply(self, [A, b], [x])
269270

@@ -315,7 +316,7 @@ def _default_b_ndim(b, b_ndim):
315316

316317
b = as_tensor_variable(b)
317318
if b_ndim is None:
318-
return min(b.ndim, 2) # By default assume the core case is a matrix
319+
return min(b.ndim, 2) # By default, assume the core case is a matrix
319320

320321

321322
class CholeskySolve(SolveBase):
@@ -332,6 +333,19 @@ def __init__(self, **kwargs):
332333
kwargs.setdefault("lower", True)
333334
super().__init__(**kwargs)
334335

336+
def make_node(self, *inputs):
337+
# Allow base class to do input validation
338+
super_apply = super().make_node(*inputs)
339+
A, b = super_apply.inputs
340+
[super_out] = super_apply.outputs
341+
# The dtype of chol_solve does not match solve, which the base class checks
342+
dtype = scipy_linalg.cho_solve(
343+
(np.ones((1, 1), dtype=A.dtype), False),
344+
np.ones((1,), dtype=b.dtype),
345+
).dtype
346+
out = tensor(dtype=dtype, shape=super_out.type.shape)
347+
return Apply(self, [A, b], [out])
348+
335349
def perform(self, node, inputs, output_storage):
336350
C, b = inputs
337351
rval = scipy_linalg.cho_solve(
@@ -499,8 +513,33 @@ class Solve(SolveBase):
499513
)
500514

501515
def __init__(self, *, assume_a="gen", **kwargs):
502-
if assume_a not in ("gen", "sym", "her", "pos"):
503-
raise ValueError(f"{assume_a} is not a recognized matrix structure")
516+
# Triangular and diagonal are handled outside of Solve
517+
valid_options = ["gen", "sym", "her", "pos", "tridiagonal", "banded"]
518+
519+
assume_a = assume_a.lower()
520+
# We use the old names as the different dispatches are more likely to support them
521+
long_to_short = {
522+
"general": "gen",
523+
"symmetric": "sym",
524+
"hermitian": "her",
525+
"positive definite": "pos",
526+
}
527+
assume_a = long_to_short.get(assume_a, assume_a)
528+
529+
if assume_a not in valid_options:
530+
raise ValueError(
531+
f"Invalid assume_a: {assume_a}. It must be one of {valid_options} or {list(long_to_short.keys())}"
532+
)
533+
534+
if assume_a in ("tridiagonal", "banded"):
535+
from scipy import __version__ as sp_version
536+
537+
if tuple(map(int, sp_version.split(".")[:-1])) < (1, 15):
538+
warnings.warn(
539+
f"assume_a={assume_a} requires scipy>=1.15.0. Defaulting to assume_a='gen'.",
540+
UserWarning,
541+
)
542+
assume_a = "gen"
504543

505544
super().__init__(**kwargs)
506545
self.assume_a = assume_a
@@ -536,10 +575,12 @@ def solve(
536575
a,
537576
b,
538577
*,
539-
assume_a="gen",
540-
lower=False,
541-
transposed=False,
542-
check_finite=True,
578+
lower: bool = False,
579+
overwrite_a: bool = False,
580+
overwrite_b: bool = False,
581+
check_finite: bool = True,
582+
assume_a: str = "gen",
583+
transposed: bool = False,
543584
b_ndim: int | None = None,
544585
):
545586
"""Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix.
@@ -548,14 +589,19 @@ def solve(
548589
corresponding string to ``assume_a`` key chooses the dedicated solver.
549590
The available options are
550591
551-
=================== ========
552-
generic matrix 'gen'
553-
symmetric 'sym'
554-
hermitian 'her'
555-
positive definite 'pos'
556-
=================== ========
592+
=================== ================================
593+
diagonal 'diagonal'
594+
tridiagonal 'tridiagonal'
595+
banded 'banded'
596+
upper triangular 'upper triangular'
597+
lower triangular 'lower triangular'
598+
symmetric 'symmetric' (or 'sym')
599+
hermitian 'hermitian' (or 'her')
600+
positive definite 'positive definite' (or 'pos')
601+
general 'general' (or 'gen')
602+
=================== ================================
557603
558-
If omitted, ``'gen'`` is the default structure.
604+
If omitted, ``'general'`` is the default structure.
559605
560606
The datatype of the arrays define which solver is called regardless
561607
of the values. In other words, even when the complex array entries have
@@ -568,23 +614,52 @@ def solve(
568614
Square input data
569615
b : (..., N, NRHS) array_like
570616
Input data for the right hand side.
571-
lower : bool, optional
572-
If True, use only the data contained in the lower triangle of `a`. Default
573-
is to use upper triangle. (ignored for ``'gen'``)
574-
transposed: bool, optional
575-
If True, solves the system A^T x = b. Default is False.
617+
lower : bool, default False
618+
Ignored unless ``assume_a`` is one of ``'sym'``, ``'her'``, or ``'pos'``.
619+
If True, the calculation uses only the data in the lower triangle of `a`;
620+
entries above the diagonal are ignored. If False (default), the
621+
calculation uses only the data in the upper triangle of `a`; entries
622+
below the diagonal are ignored.
623+
overwrite_a : bool
624+
Unused by PyTensor. PyTensor will always perform the operation in-place if possible.
625+
overwrite_b : bool
626+
Unused by PyTensor. PyTensor will always perform the operation in-place if possible.
576627
check_finite : bool, optional
577628
Whether to check that the input matrices contain only finite numbers.
578629
Disabling may give a performance gain, but may result in problems
579630
(crashes, non-termination) if the inputs do contain infinities or NaNs.
580631
assume_a : str, optional
581632
Valid entries are explained above.
633+
transposed: bool, default False
634+
If True, solves the system A^T x = b. Default is False.
582635
b_ndim : int
583636
Whether the core case of b is a vector (1) or matrix (2).
584637
This will influence how batched dimensions are interpreted.
638+
By default, ``b_ndim = min(b.ndim, 2)``, i.e. 2 if ``b.ndim > 1``, else 1.
585639
"""
640+
assume_a = assume_a.lower()
641+
642+
if assume_a in ("lower triangular", "upper triangular"):
643+
lower = "lower" in assume_a
644+
return solve_triangular(
645+
a,
646+
b,
647+
lower=lower,
648+
trans=transposed,
649+
check_finite=check_finite,
650+
b_ndim=b_ndim,
651+
)
652+
586653
b_ndim = _default_b_ndim(b, b_ndim)
587654

655+
if assume_a == "diagonal":
656+
a_diagonal = diagonal(a, axis1=-2, axis2=-1)
657+
b_transposed = b[None, :] if b_ndim == 1 else b.mT
658+
x = (b_transposed / pt.expand_dims(a_diagonal, -2)).mT
659+
if b_ndim == 1:
660+
x = x.squeeze(-1)
661+
return x
662+
588663
if transposed:
589664
a = a.mT
590665
lower = not lower

Diff for: tests/tensor/test_slinalg.py

+34-2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from pytensor import function, grad
1111
from pytensor import tensor as pt
1212
from pytensor.configdefaults import config
13+
from pytensor.graph.basic import equal_computations
14+
from pytensor.tensor import TensorVariable
1315
from pytensor.tensor.slinalg import (
1416
Cholesky,
1517
CholeskySolve,
@@ -211,8 +213,8 @@ def test__repr__(self):
211213
)
212214

213215

214-
def test_solve_raises_on_invalid_A():
215-
with pytest.raises(ValueError, match="is not a recognized matrix structure"):
216+
def test_solve_raises_on_invalid_assume_a():
217+
with pytest.raises(ValueError, match="Invalid assume_a: test. It must be one of"):
216218
Solve(assume_a="test", b_ndim=2)
217219

218220

@@ -225,6 +227,10 @@ def test_solve_raises_on_invalid_A():
225227
("pos", False, False),
226228
("pos", True, False),
227229
("pos", True, True),
230+
("diagonal", False, False),
231+
("diagonal", False, True),
232+
("tridiagonal", False, False),
233+
("tridiagonal", False, True),
228234
]
229235
solve_test_ids = [
230236
f'{assume_a}_{"lower" if lower else "upper"}_{"A^T" if transposed else "A"}'
@@ -239,6 +245,16 @@ def A_func(x, assume_a):
239245
return x @ x.T
240246
elif assume_a == "sym":
241247
return (x + x.T) / 2
248+
elif assume_a == "diagonal":
249+
eye_fn = pt.eye if isinstance(x, TensorVariable) else np.eye
250+
return x * eye_fn(x.shape[1])
251+
elif assume_a == "tridiagonal":
252+
eye_fn = pt.eye if isinstance(x, TensorVariable) else np.eye
253+
return x * (
254+
eye_fn(x.shape[1], k=0)
255+
+ eye_fn(x.shape[1], k=-1)
256+
+ eye_fn(x.shape[1], k=1)
257+
)
242258
else:
243259
return x
244260

@@ -346,6 +362,22 @@ def test_solve_gradient(
346362
lambda A, b: solve_op(A_func(A), b), [A_val, b_val], 3, rng, eps=eps
347363
)
348364

365+
def test_solve_tringular_indirection(self):
366+
a = pt.matrix("a")
367+
b = pt.vector("b")
368+
369+
indirect = solve(a, b, assume_a="lower triangular")
370+
direct = solve_triangular(a, b, lower=True, trans=False)
371+
assert equal_computations([indirect], [direct])
372+
373+
indirect = solve(a, b, assume_a="upper triangular")
374+
direct = solve_triangular(a, b, lower=False, trans=False)
375+
assert equal_computations([indirect], [direct])
376+
377+
indirect = solve(a, b, assume_a="upper triangular", transposed=True)
378+
direct = solve_triangular(a, b, lower=False, trans=True)
379+
assert equal_computations([indirect], [direct])
380+
349381

350382
class TestSolveTriangular(utt.InferShapeTester):
351383
@staticmethod

0 commit comments

Comments
 (0)