
Commit bf628c9

Allow transposed argument in linalg.solve (#1231)
* Add transposed argument to `solve` and `solve_triangular`
* Expand test coverage for `Solve` and `SolveTriangular`
1 parent 757a10c commit bf628c9
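
The change in practice: both entry points now accept a transposition flag that is resolved at graph-construction time rather than inside the Op. A minimal usage sketch (variable names and shapes are illustrative, not taken from the commit):

```python
import pytensor.tensor as pt
from pytensor.tensor.slinalg import solve, solve_triangular

A = pt.matrix("A")
b = pt.vector("b")

# New keyword: solve A^T x = b without building A.T yourself.
x = solve(A, b, assume_a="gen", transposed=True)

# solve_triangular keeps its scipy-style trans argument (0/"N", 1/"T", 2/"C"),
# but the transposition is now resolved in the graph rather than inside the Op.
y = solve_triangular(A, b, trans="T", lower=True)
```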

File tree: 6 files changed, +210 −111 lines

pytensor/link/jax/dispatch/slinalg.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -53,7 +53,6 @@ def solve(a, b, lower=lower):
 @jax_funcify.register(SolveTriangular)
 def jax_funcify_SolveTriangular(op, **kwargs):
     lower = op.lower
-    trans = op.trans
     unit_diagonal = op.unit_diagonal
     check_finite = op.check_finite
 
@@ -62,7 +61,7 @@ def solve_triangular(A, b):
         A,
         b,
         lower=lower,
-        trans=trans,
+        trans=0,  # trans is handled by explicitly transposing A, so it is always 0 by the time we get here
         unit_diagonal=unit_diagonal,
         check_finite=check_finite,
     )
```
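
The same simplification appears in the Numba dispatcher below: because the graph-level code now rewrites `trans` into an explicit transpose of `A` (with the `lower` flag flipped), every backend can hard-code `trans=0`. The identity this relies on can be checked directly against scipy; this snippet is a sanity check, not part of the commit:

```python
import numpy as np
from scipy.linalg import solve_triangular

rng = np.random.default_rng(0)
A = np.tril(rng.normal(size=(4, 4))) + 4.0 * np.eye(4)  # well-conditioned lower triangular
b = rng.normal(size=4)

# trans=1 solves A^T x = b ...
x_trans = solve_triangular(A, b, trans=1, lower=True)
# ... which is the same as a plain (trans=0) solve against A.T with `lower` flipped.
x_explicit = solve_triangular(A.T, b, trans=0, lower=False)

np.testing.assert_allclose(x_trans, x_explicit)
```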

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -180,7 +180,6 @@ def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
 
 @numba_funcify.register(SolveTriangular)
 def numba_funcify_SolveTriangular(op, node, **kwargs):
-    trans = bool(op.trans)
     lower = op.lower
     unit_diagonal = op.unit_diagonal
     check_finite = op.check_finite
@@ -208,7 +207,7 @@ def solve_triangular(a, b):
         res = _solve_triangular(
             a,
             b,
-            trans=trans,
+            trans=0,  # transposing is handled explicitly on the graph, so we never use this argument
             lower=lower,
             unit_diagonal=unit_diagonal,
             overwrite_b=overwrite_b,
```

pytensor/tensor/slinalg.py

Lines changed: 27 additions & 15 deletions
```diff
@@ -296,13 +296,12 @@ def L_op(self, inputs, outputs, output_gradients):
         # We need to return (dC/d[inv(A)], dC/db)
         c_bar = output_gradients[0]
 
-        trans_solve_op = type(self)(
-            **{
-                k: (not getattr(self, k) if k == "lower" else getattr(self, k))
-                for k in self.__props__
-            }
-        )
-        b_bar = trans_solve_op(A.T, c_bar)
+        props_dict = self._props_dict()
+        props_dict["lower"] = not self.lower
+
+        solve_op = type(self)(**props_dict)
+
+        b_bar = solve_op(A.T, c_bar)
         # force outer product if vector second input
         A_bar = -ptm.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T)
 
```
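
The rewritten `L_op` uses `Op._props_dict()`, which returns the op's `__props__` as a name-to-value mapping, to rebuild the op with only `lower` flipped; the flip is needed because the backward pass solves against `A.T`, which swaps the triangle that carries the data. A sketch of the pattern (assuming, as the commit itself does, that the op's constructor accepts all of its props as keyword arguments):

```python
from pytensor.tensor.slinalg import Solve

op = Solve(assume_a="gen", lower=False, b_ndim=1)
props = op._props_dict()            # e.g. {"assume_a": "gen", "lower": False, ...}
props["lower"] = not props["lower"]
flipped_op = type(op)(**props)      # same op, opposite triangle
```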

```diff
@@ -385,19 +384,17 @@ class SolveTriangular(SolveBase):
     """Solve a system of linear equations."""
 
     __props__ = (
-        "trans",
         "unit_diagonal",
         "lower",
         "check_finite",
         "b_ndim",
         "overwrite_b",
     )
 
-    def __init__(self, *, trans=0, unit_diagonal=False, **kwargs):
+    def __init__(self, *, unit_diagonal=False, **kwargs):
         if kwargs.get("overwrite_a", False):
             raise ValueError("overwrite_a is not supported for SolveTriangular")
         super().__init__(**kwargs)
-        self.trans = trans
         self.unit_diagonal = unit_diagonal
 
     def perform(self, node, inputs, outputs):
@@ -406,7 +403,7 @@ def perform(self, node, inputs, outputs):
             A,
             b,
             lower=self.lower,
-            trans=self.trans,
+            trans=0,
             unit_diagonal=self.unit_diagonal,
             check_finite=self.check_finite,
             overwrite_b=self.overwrite_b,
```
```diff
@@ -445,9 +442,9 @@ def solve_triangular(
 
     Parameters
     ----------
-    a
+    a: TensorVariable
         Square input data
-    b
+    b: TensorVariable
         Input data for the right hand side.
     lower : bool, optional
         Use only data contained in the lower triangle of `a`. Default is to use upper triangle.
@@ -468,10 +465,17 @@
         This will influence how batched dimensions are interpreted.
     """
     b_ndim = _default_b_ndim(b, b_ndim)
+
+    if trans in [1, "T", True]:
+        a = a.mT
+        lower = not lower
+    if trans in [2, "C"]:
+        a = a.conj().mT
+        lower = not lower
+
     ret = Blockwise(
         SolveTriangular(
             lower=lower,
-            trans=trans,
             unit_diagonal=unit_diagonal,
             check_finite=check_finite,
             b_ndim=b_ndim,
```
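
`mT` here is the batched matrix transpose: it swaps only the last two axes, so the rewrite composes with `Blockwise` batching instead of scrambling leading batch dimensions. A quick illustration in NumPy (`ndarray.mT` requires NumPy >= 2.0; PyTensor variables expose the same property, as the diff above shows):

```python
import numpy as np

a = np.arange(24).reshape(2, 3, 4)
assert a.mT.shape == (2, 4, 3)                       # batch axis untouched
assert np.array_equal(a.mT, np.swapaxes(a, -1, -2))  # .mT swaps the last two axes

# Transposing a triangular matrix moves the data to the opposite triangle,
# which is why each transpose above is paired with `lower = not lower`.
L = np.tril(np.ones((3, 3)))
assert np.array_equal(L.mT, np.triu(L.mT))           # L.mT is upper triangular
```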
```diff
@@ -534,6 +538,7 @@ def solve(
     *,
     assume_a="gen",
     lower=False,
+    transposed=False,
     check_finite=True,
     b_ndim: int | None = None,
 ):
@@ -564,8 +569,10 @@
     b : (..., N, NRHS) array_like
         Input data for the right hand side.
     lower : bool, optional
-        If True, only the data contained in the lower triangle of `a`. Default
+        If True, use only the data contained in the lower triangle of `a`. Default
         is to use upper triangle. (ignored for ``'gen'``)
+    transposed: bool, optional
+        If True, solves the system A^T x = b. Default is False.
     check_finite : bool, optional
         Whether to check that the input matrices contain only finite numbers.
         Disabling may give a performance gain, but may result in problems
@@ -577,6 +584,11 @@
         This will influence how batched dimensions are interpreted.
     """
     b_ndim = _default_b_ndim(b, b_ndim)
+
+    if transposed:
+        a = a.mT
+        lower = not lower
+
     return Blockwise(
         Solve(
             lower=lower,
```
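
End to end, `transposed=True` should agree with a NumPy reference solve of A^T x = b. A small check along these lines (a sketch, not taken from the test suite):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import solve

A = pt.matrix("A")
b = pt.vector("b")
f = pytensor.function([A, b], solve(A, b, transposed=True, b_ndim=1))

rng = np.random.default_rng(42)
A_val = rng.normal(size=(4, 4))
b_val = rng.normal(size=4)

# transposed=True solves A^T x = b, so the reference answer is solve(A.T, b).
np.testing.assert_allclose(f(A_val, b_val), np.linalg.solve(A_val.T, b_val))
```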

tests/link/jax/test_slinalg.py

Lines changed: 32 additions & 18 deletions
```diff
@@ -5,6 +5,7 @@
 import pytest
 
 import pytensor.tensor as pt
+import tests.unittest_tools as utt
 from pytensor.configdefaults import config
 from pytensor.tensor import nlinalg as pt_nlinalg
 from pytensor.tensor import slinalg as pt_slinalg
@@ -103,28 +104,41 @@ def test_jax_basic():
     )
 
 
-@pytest.mark.parametrize("check_finite", [False, True])
-@pytest.mark.parametrize("lower", [False, True])
-@pytest.mark.parametrize("trans", [0, 1, 2])
-def test_jax_SolveTriangular(trans, lower, check_finite):
-    x = matrix("x")
-    b = vector("b")
+def test_jax_solve():
+    rng = np.random.default_rng(utt.fetch_seed())
+
+    A = pt.tensor("A", shape=(5, 5))
+    b = pt.tensor("B", shape=(5, 5))
+
+    out = pt_slinalg.solve(A, b, lower=False, transposed=False)
+
+    A_val = rng.normal(size=(5, 5)).astype(config.floatX)
+    b_val = rng.normal(size=(5, 5)).astype(config.floatX)
 
-    out = pt_slinalg.solve_triangular(
-        x,
-        b,
-        trans=trans,
-        lower=lower,
-        check_finite=check_finite,
-    )
     compare_jax_and_py(
-        [x, b],
+        [A, b],
         [out],
-        [
-            np.eye(10).astype(config.floatX),
-            np.arange(10).astype(config.floatX),
-        ],
+        [A_val, b_val],
+    )
+
+
+def test_jax_SolveTriangular():
+    rng = np.random.default_rng(utt.fetch_seed())
+
+    A = pt.tensor("A", shape=(5, 5))
+    b = pt.tensor("B", shape=(5, 5))
+
+    A_val = rng.normal(size=(5, 5)).astype(config.floatX)
+    b_val = rng.normal(size=(5, 5)).astype(config.floatX)
+
+    out = pt_slinalg.solve_triangular(
+        A,
+        b,
+        trans=0,
+        lower=True,
+        unit_diagonal=False,
     )
+    compare_jax_and_py([A, b], [out], [A_val, b_val])
 
 
 def test_jax_block_diag():
```

tests/link/numba/test_slinalg.py

Lines changed: 15 additions & 17 deletions
```diff
@@ -5,7 +5,6 @@
 import numpy as np
 import pytest
 from numpy.testing import assert_allclose
-from scipy import linalg as scipy_linalg
 
 import pytensor
 import pytensor.tensor as pt
```
```diff
@@ -26,9 +25,9 @@ def transpose_func(x, trans):
     if trans == 0:
         return x
     if trans == 1:
-        return x.conj().T
-    if trans == 2:
         return x.T
+    if trans == 2:
+        return x.conj().T
 
 
 @pytest.mark.parametrize(
```
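
This hunk fixes a real bug in the test helper: the old mapping gave `trans=1` the conjugate transpose and `trans=2` the plain transpose, which is backwards from scipy's documented convention (`0`/`'N'`: A x = b, `1`/`'T'`: A^T x = b, `2`/`'C'`: A^H x = b). The convention is easy to confirm against scipy directly (illustrative check, not part of the commit):

```python
import numpy as np
from scipy.linalg import solve_triangular

A = np.array([[2.0, 0.0], [1.0 + 1.0j, 3.0]])  # lower triangular, complex
b = np.array([1.0 + 0.0j, 2.0 + 0.0j])

# trans=1 solves A^T x = b; trans=2 solves A^H (conjugate transpose) x = b.
np.testing.assert_allclose(
    solve_triangular(A, b, trans=1, lower=True), np.linalg.solve(A.T, b)
)
np.testing.assert_allclose(
    solve_triangular(A, b, trans=2, lower=True), np.linalg.solve(A.conj().T, b)
)
```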
```diff
@@ -59,18 +58,18 @@ def test_solve_triangular(b_shape: tuple[int], lower, trans, unit_diag, is_complex):
 
     def A_func(x):
         x = x @ x.conj().T
-        x_tri = scipy_linalg.cholesky(x, lower=lower).astype(dtype)
+        x_tri = pt.linalg.cholesky(x, lower=lower).astype(dtype)
 
         if unit_diag:
-            x_tri[np.diag_indices_from(x_tri)] = 1.0
+            x_tri = pt.fill_diagonal(x_tri, 1.0)
 
-        return x_tri.astype(dtype)
+        return x_tri
 
     solve_op = partial(
         pt.linalg.solve_triangular, lower=lower, trans=trans, unit_diagonal=unit_diag
     )
 
-    X = solve_op(A, b)
+    X = solve_op(A_func(A), b)
     f = pytensor.function([A, b], X, mode="NUMBA")
 
     A_val = np.random.normal(size=(5, 5))
@@ -80,20 +79,20 @@ def A_func(x):
         A_val = A_val + np.random.normal(size=(5, 5)) * 1j
         b_val = b_val + np.random.normal(size=b_shape) * 1j
 
-    X_np = f(A_func(A_val), b_val)
-
-    test_input = transpose_func(A_func(A_val), trans)
-
-    ATOL = 1e-8 if floatX.endswith("64") else 1e-4
-    RTOL = 1e-8 if floatX.endswith("64") else 1e-4
-
-    np.testing.assert_allclose(test_input @ X_np, b_val, atol=ATOL, rtol=RTOL)
+    X_np = f(A_val.copy(), b_val.copy())
+    A_val_transformed = transpose_func(A_func(A_val), trans).eval()
+    np.testing.assert_allclose(
+        A_val_transformed @ X_np,
+        b_val,
+        atol=1e-8 if floatX.endswith("64") else 1e-4,
+        rtol=1e-8 if floatX.endswith("64") else 1e-4,
+    )
 
     compiled_fgraph = f.maker.fgraph
     compare_numba_and_py(
         compiled_fgraph.inputs,
         compiled_fgraph.outputs,
-        [A_func(A_val), b_val],
+        [A_val, b_val],
     )
 
 
@@ -145,7 +144,6 @@ def test_solve_triangular_overwrite_b_correct(overwrite_b):
     b_test_nb = b_test_py.copy(order="F")
 
     op = SolveTriangular(
-        trans=0,
         unit_diagonal=False,
         lower=False,
         check_finite=True,
```
