|
5 | 5 | import pytest
|
6 | 6 |
|
7 | 7 | import pytensor.tensor as pt
|
| 8 | +import tests.unittest_tools as utt |
8 | 9 | from pytensor.configdefaults import config
|
9 | 10 | from pytensor.tensor import nlinalg as pt_nlinalg
|
10 | 11 | from pytensor.tensor import slinalg as pt_slinalg
|
@@ -103,28 +104,72 @@ def test_jax_basic():
|
103 | 104 | )
|
104 | 105 |
|
105 | 106 |
|
106 |
| -@pytest.mark.parametrize("check_finite", [False, True]) |
| 107 | +@pytest.mark.parametrize( |
| 108 | + "b_shape", |
| 109 | + [(5, 1), (5, 5), (5,)], |
| 110 | + ids=["b_col_vec", "b_matrix", "b_vec"], |
| 111 | +) |
| 112 | +@pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str) |
| 113 | +@pytest.mark.parametrize("lower", [False, True]) |
| 114 | +@pytest.mark.parametrize("transposed", [False, True]) |
| 115 | +def test_jax_solve(b_shape: tuple[int], assume_a, lower, transposed): |
| 116 | + rng = np.random.default_rng(utt.fetch_seed()) |
| 117 | + |
| 118 | + A = pt.tensor("A", shape=(5, 5)) |
| 119 | + b = pt.tensor("B", shape=b_shape) |
| 120 | + |
| 121 | + def A_func(x): |
| 122 | + if assume_a == "sym": |
| 123 | + return (x + x.T) / 2 |
| 124 | + if assume_a == "pos": |
| 125 | + return x @ x.T |
| 126 | + return x |
| 127 | + |
| 128 | + out = pt_slinalg.solve( |
| 129 | + A_func(A), b, assume_a=assume_a, lower=lower, transposed=transposed |
| 130 | + ) |
| 131 | + |
| 132 | + A_val = rng.normal(size=(5, 5)).astype(config.floatX) |
| 133 | + b_val = rng.normal(size=b_shape).astype(config.floatX) |
| 134 | + |
| 135 | + compare_jax_and_py( |
| 136 | + [A, b], |
| 137 | + [out], |
| 138 | + [A_val, b_val], |
| 139 | + ) |
| 140 | + |
| 141 | + |
| 142 | +@pytest.mark.parametrize( |
| 143 | + "b_shape", [(5, 1), (5, 5), (5,)], ids=["b_col_vec", "b_matrix", "b_vec"] |
| 144 | +) |
107 | 145 | @pytest.mark.parametrize("lower", [False, True])
|
108 | 146 | @pytest.mark.parametrize("trans", [0, 1, 2])
|
109 |
| -def test_jax_SolveTriangular(trans, lower, check_finite): |
110 |
| - x = matrix("x") |
111 |
| - b = vector("b") |
| 147 | +@pytest.mark.parametrize("unit_diagonal", [False, True]) |
| 148 | +def test_jax_SolveTriangular(b_shape: tuple[int], lower, trans, unit_diagonal): |
| 149 | + rng = np.random.default_rng(utt.fetch_seed()) |
| 150 | + |
| 151 | + A = pt.tensor("A", shape=(5, 5)) |
| 152 | + b = pt.tensor("B", shape=b_shape) |
| 153 | + |
| 154 | + def A_func(x): |
| 155 | + x = x @ x.T |
| 156 | + x = pt.linalg.cholesky(x, lower=lower) |
| 157 | + if unit_diagonal: |
| 158 | + x = pt.fill_diagonal(x, 1.0) |
| 159 | + |
| 160 | + return x |
| 161 | + |
| 162 | + A_val = rng.normal(size=(5, 5)).astype(config.floatX) |
| 163 | + b_val = rng.normal(size=b_shape).astype(config.floatX) |
112 | 164 |
|
113 | 165 | out = pt_slinalg.solve_triangular(
|
114 |
| - x, |
| 166 | + A_func(A), |
115 | 167 | b,
|
116 | 168 | trans=trans,
|
117 | 169 | lower=lower,
|
118 |
| - check_finite=check_finite, |
119 |
| - ) |
120 |
| - compare_jax_and_py( |
121 |
| - [x, b], |
122 |
| - [out], |
123 |
| - [ |
124 |
| - np.eye(10).astype(config.floatX), |
125 |
| - np.arange(10).astype(config.floatX), |
126 |
| - ], |
| 170 | + unit_diagonal=unit_diagonal, |
127 | 171 | )
|
| 172 | + compare_jax_and_py([A, b], [out], [A_val, b_val]) |
128 | 173 |
|
129 | 174 |
|
130 | 175 | def test_jax_block_diag():
|
|
0 commit comments