Skip to content

Commit 1d5c020

Browse files
Expand test coverage for Solve and SolveTriangular
1 parent 4084f82 commit 1d5c020

File tree

3 files changed

+198
-99
lines changed

3 files changed

+198
-99
lines changed

tests/link/jax/test_slinalg.py

Lines changed: 59 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pytest
66

77
import pytensor.tensor as pt
8+
import tests.unittest_tools as utt
89
from pytensor.configdefaults import config
910
from pytensor.tensor import nlinalg as pt_nlinalg
1011
from pytensor.tensor import slinalg as pt_slinalg
@@ -103,28 +104,72 @@ def test_jax_basic():
103104
)
104105

105106

106-
@pytest.mark.parametrize("check_finite", [False, True])
107+
@pytest.mark.parametrize(
108+
"b_shape",
109+
[(5, 1), (5, 5), (5,)],
110+
ids=["b_col_vec", "b_matrix", "b_vec"],
111+
)
112+
@pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str)
113+
@pytest.mark.parametrize("lower", [False, True])
114+
@pytest.mark.parametrize("transposed", [False, True])
115+
def test_jax_solve(b_shape: tuple[int], assume_a, lower, transposed):
116+
rng = np.random.default_rng(utt.fetch_seed())
117+
118+
A = pt.tensor("A", shape=(5, 5))
119+
b = pt.tensor("B", shape=b_shape)
120+
121+
def A_func(x):
122+
if assume_a == "sym":
123+
return (x + x.T) / 2
124+
if assume_a == "pos":
125+
return x @ x.T
126+
return x
127+
128+
out = pt_slinalg.solve(
129+
A_func(A), b, assume_a=assume_a, lower=lower, transposed=transposed
130+
)
131+
132+
A_val = rng.normal(size=(5, 5)).astype(config.floatX)
133+
b_val = rng.normal(size=b_shape).astype(config.floatX)
134+
135+
compare_jax_and_py(
136+
[A, b],
137+
[out],
138+
[A_val, b_val],
139+
)
140+
141+
142+
@pytest.mark.parametrize(
143+
"b_shape", [(5, 1), (5, 5), (5,)], ids=["b_col_vec", "b_matrix", "b_vec"]
144+
)
107145
@pytest.mark.parametrize("lower", [False, True])
108146
@pytest.mark.parametrize("trans", [0, 1, 2])
109-
def test_jax_SolveTriangular(trans, lower, check_finite):
110-
x = matrix("x")
111-
b = vector("b")
147+
@pytest.mark.parametrize("unit_diagonal", [False, True])
148+
def test_jax_SolveTriangular(b_shape: tuple[int], lower, trans, unit_diagonal):
149+
rng = np.random.default_rng(utt.fetch_seed())
150+
151+
A = pt.tensor("A", shape=(5, 5))
152+
b = pt.tensor("B", shape=b_shape)
153+
154+
def A_func(x):
155+
x = x @ x.T
156+
x = pt.linalg.cholesky(x, lower=lower)
157+
if unit_diagonal:
158+
x = pt.fill_diagonal(x, 1.0)
159+
160+
return x
161+
162+
A_val = rng.normal(size=(5, 5)).astype(config.floatX)
163+
b_val = rng.normal(size=b_shape).astype(config.floatX)
112164

113165
out = pt_slinalg.solve_triangular(
114-
x,
166+
A_func(A),
115167
b,
116168
trans=trans,
117169
lower=lower,
118-
check_finite=check_finite,
119-
)
120-
compare_jax_and_py(
121-
[x, b],
122-
[out],
123-
[
124-
np.eye(10).astype(config.floatX),
125-
np.arange(10).astype(config.floatX),
126-
],
170+
unit_diagonal=unit_diagonal,
127171
)
172+
compare_jax_and_py([A, b], [out], [A_val, b_val])
128173

129174

130175
def test_jax_block_diag():

tests/link/numba/test_slinalg.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import numpy as np
66
import pytest
77
from numpy.testing import assert_allclose
8-
from scipy import linalg as scipy_linalg
98

109
import pytensor
1110
import pytensor.tensor as pt
@@ -26,9 +25,9 @@ def transpose_func(x, trans):
2625
if trans == 0:
2726
return x
2827
if trans == 1:
29-
return x.conj().T
30-
if trans == 2:
3128
return x.T
29+
if trans == 2:
30+
return x.conj().T
3231

3332

3433
@pytest.mark.parametrize(
@@ -59,18 +58,18 @@ def test_solve_triangular(b_shape: tuple[int], lower, trans, unit_diag, is_compl
5958

6059
def A_func(x):
6160
x = x @ x.conj().T
62-
x_tri = scipy_linalg.cholesky(x, lower=lower).astype(dtype)
61+
x_tri = pt.linalg.cholesky(x, lower=lower).astype(dtype)
6362

6463
if unit_diag:
65-
x_tri[np.diag_indices_from(x_tri)] = 1.0
64+
x_tri = pt.fill_diagonal(x_tri, 1.0)
6665

67-
return x_tri.astype(dtype)
66+
return x_tri
6867

6968
solve_op = partial(
7069
pt.linalg.solve_triangular, lower=lower, trans=trans, unit_diagonal=unit_diag
7170
)
7271

73-
X = solve_op(A, b)
72+
X = solve_op(A_func(A), b)
7473
f = pytensor.function([A, b], X, mode="NUMBA")
7574

7675
A_val = np.random.normal(size=(5, 5))
@@ -80,20 +79,20 @@ def A_func(x):
8079
A_val = A_val + np.random.normal(size=(5, 5)) * 1j
8180
b_val = b_val + np.random.normal(size=b_shape) * 1j
8281

83-
X_np = f(A_func(A_val), b_val)
84-
85-
test_input = transpose_func(A_func(A_val), trans)
86-
87-
ATOL = 1e-8 if floatX.endswith("64") else 1e-4
88-
RTOL = 1e-8 if floatX.endswith("64") else 1e-4
89-
90-
np.testing.assert_allclose(test_input @ X_np, b_val, atol=ATOL, rtol=RTOL)
82+
X_np = f(A_val.copy(), b_val.copy())
83+
A_val_transformed = transpose_func(A_func(A_val), trans).eval()
84+
np.testing.assert_allclose(
85+
A_val_transformed @ X_np,
86+
b_val,
87+
atol=1e-8 if floatX.endswith("64") else 1e-4,
88+
rtol=1e-8 if floatX.endswith("64") else 1e-4,
89+
)
9190

9291
compiled_fgraph = f.maker.fgraph
9392
compare_numba_and_py(
9493
compiled_fgraph.inputs,
9594
compiled_fgraph.outputs,
96-
[A_func(A_val), b_val],
95+
[A_val, b_val],
9796
)
9897

9998

0 commit comments

Comments (0)