|
1 | 1 | import numpy as np
|
2 | 2 | import numpy.linalg
|
| 3 | +import pytest |
3 | 4 | import scipy.linalg
|
4 | 5 |
|
5 | 6 | import pytensor
|
6 | 7 | from pytensor import function
|
7 | 8 | from pytensor import tensor as at
|
8 |
| -from pytensor.compile import DeepCopyOp, ViewOp |
| 9 | +from pytensor.compile import get_default_mode |
9 | 10 | from pytensor.configdefaults import config
|
10 | 11 | from pytensor.sandbox.linalg.ops import inv_as_solve, spectral_radius_bound
|
11 | 12 | from pytensor.tensor.elemwise import DimShuffle
|
@@ -156,94 +157,72 @@ def test_matrix_inverse_solve():
|
156 | 157 | assert isinstance(out.owner.op, Solve)
|
157 | 158 |
|
158 | 159 |
|
159 |
| -def test_cholesky_dot_lower(): |
160 |
| - cholesky_lower = Cholesky(lower=True) |
161 |
| - cholesky_upper = Cholesky(lower=False) |
| 160 | +@pytest.mark.parametrize("tag", ("lower", "upper", None)) |
| 161 | +@pytest.mark.parametrize("cholesky_form", ("lower", "upper")) |
| 162 | +@pytest.mark.parametrize("product", ("lower", "upper", None)) |
| 163 | +def test_cholesky_ldotlt(tag, cholesky_form, product): |
| 164 | + cholesky = Cholesky(lower=(cholesky_form == "lower")) |
162 | 165 |
|
163 |
| - L = matrix("L") |
164 |
| - L.tag.lower_triangular = True |
| 166 | + transform_removes_chol = tag is not None and product == tag |
| 167 | + transform_transposes = transform_removes_chol and cholesky_form != tag |
165 | 168 |
|
166 |
| - C = cholesky_lower(L.dot(L.T)) |
167 |
| - f = pytensor.function([L], C) |
| 169 | + A = matrix("L") |
| 170 | + if tag: |
| 171 | + setattr(A.tag, tag + "_triangular", True) |
168 | 172 |
|
169 |
| - if config.mode != "FAST_COMPILE": |
170 |
| - assert (f.maker.fgraph.outputs[0] == f.maker.fgraph.inputs[0]) or ( |
171 |
| - (o := f.maker.fgraph.outputs[0].owner) |
172 |
| - and isinstance(o.op, (DeepCopyOp, ViewOp)) |
173 |
| - and o.inputs[0] == f.maker.fgraph.inputs[0] |
174 |
| - ) |
| 173 | + if product == "lower": |
| 174 | + M = A.dot(A.T) |
| 175 | + elif product == "upper": |
| 176 | + M = A.T.dot(A) |
| 177 | + else: |
| 178 | + M = A |
175 | 179 |
|
176 |
| - # Test some concrete value through f: |
177 |
| - Lv = np.array([[2, 0], [1, 4]]) |
178 |
| - assert np.all( |
179 |
| - np.isclose( |
180 |
| - scipy.linalg.cholesky(np.dot(Lv, Lv.T), lower=True), |
181 |
| - f(Lv), |
182 |
| - ) |
183 |
| - ) |
| 180 | + C = cholesky(M) |
| 181 | + f = pytensor.function([A], C, mode=get_default_mode().including("cholesky_ldotlt")) |
184 | 182 |
|
185 |
| - # Test upper decomposition factors down to a transpose |
186 |
| - C = cholesky_upper(L.dot(L.T)) |
187 |
| - f = pytensor.function([L], C) |
188 |
| - if config.mode != "FAST_COMPILE": |
189 |
| - assert ( |
190 |
| - (o1 := f.maker.fgraph.outputs[0].owner) |
191 |
| - and isinstance(o1.op, (DeepCopyOp, ViewOp)) |
192 |
| - and (o2 := o1.inputs[0].owner) |
193 |
| - and isinstance(o2.op, DimShuffle) |
194 |
| - and o2.op.new_order == (1, 0) |
195 |
| - and o2.inputs[0] == f.maker.fgraph.inputs[0] |
196 |
| - ) |
| 183 | + print(f.maker.fgraph.apply_nodes) |
197 | 184 |
|
198 |
| - assert np.all( |
199 |
| - np.isclose( |
200 |
| - scipy.linalg.cholesky(np.dot(Lv, Lv.T), lower=False), |
201 |
| - f(Lv), |
202 |
| - ) |
| 185 | + no_cholesky_in_graph = not any( |
| 186 | + isinstance(node.op, Cholesky) for node in f.maker.fgraph.apply_nodes |
203 | 187 | )
|
204 | 188 |
|
| 189 | + assert no_cholesky_in_graph == transform_removes_chol |
205 | 190 |
|
206 |
| -def test_cholesky_dot_upper(): |
207 |
| - cholesky_lower = Cholesky(lower=True) |
208 |
| - cholesky_upper = Cholesky(lower=False) |
209 |
| - |
210 |
| - U = matrix("U") |
211 |
| - U.tag.upper_triangular = True |
212 |
| - |
213 |
| - C = cholesky_upper(U.T.dot(U)) |
214 |
| - f = pytensor.function([U], C) |
215 |
| - if config.mode != "FAST_COMPILE": |
216 |
| - assert (f.maker.fgraph.outputs[0] == f.maker.fgraph.inputs[0]) or ( |
217 |
| - (o := f.maker.fgraph.outputs[0].owner) |
218 |
| - and isinstance(o.op, (DeepCopyOp, ViewOp)) |
219 |
| - and o.inputs[0] == f.maker.fgraph.inputs[0] |
| 191 | + if transform_transposes: |
| 192 | + assert any( |
| 193 | + isinstance(node.op, DimShuffle) and node.op.new_order == (1, 0) |
| 194 | + for node in f.maker.fgraph.apply_nodes |
220 | 195 | )
|
221 | 196 |
|
222 |
| - # Test some concrete value through f: |
223 |
| - Uv = np.array([[2, 1], [0, 4]]) |
224 |
| - assert np.all( |
225 |
| - np.isclose( |
226 |
| - scipy.linalg.cholesky(np.dot(Uv.T, Uv), lower=False), |
227 |
| - f(Uv), |
| 197 | + # Test some concrete value through f |
| 198 | + # there must be lower triangular (f assumes they are) |
| 199 | + Avs = [ |
| 200 | + np.eye(1), |
| 201 | + np.eye(10), |
| 202 | + np.array([[2, 0], [1, 4]]), |
| 203 | + ] |
| 204 | + if not tag: |
| 205 | + # these must be positive def |
| 206 | + Avs.extend( |
| 207 | + [ |
| 208 | + np.ones((4, 4)) + np.eye(4), |
| 209 | + ] |
228 | 210 | )
|
229 |
| - ) |
230 | 211 |
|
231 |
| - # Test lower decomposition factors down to a transpose |
232 |
| - C = cholesky_lower(U.T.dot(U)) |
233 |
| - f = pytensor.function([U], C) |
234 |
| - if config.mode != "FAST_COMPILE": |
235 |
| - assert ( |
236 |
| - (o1 := f.maker.fgraph.outputs[0].owner) |
237 |
| - and isinstance(o1.op, (DeepCopyOp, ViewOp)) |
238 |
| - and (o2 := o1.inputs[0].owner) |
239 |
| - and isinstance(o2.op, DimShuffle) |
240 |
| - and o2.op.new_order == (1, 0) |
241 |
| - and o2.inputs[0] == f.maker.fgraph.inputs[0] |
242 |
| - ) |
243 |
| - |
244 |
| - assert np.all( |
245 |
| - np.isclose( |
246 |
| - scipy.linalg.cholesky(np.dot(Uv.T, Uv), lower=True), |
247 |
| - f(Uv), |
| 212 | + for Av in Avs: |
| 213 | + if tag == "upper": |
| 214 | + Av = Av.T |
| 215 | + |
| 216 | + if product == "lower": |
| 217 | + Mv = Av.dot(Av.T) |
| 218 | + elif product == "upper": |
| 219 | + Mv = Av.T.dot(Av) |
| 220 | + else: |
| 221 | + Mv = Av |
| 222 | + |
| 223 | + assert np.all( |
| 224 | + np.isclose( |
| 225 | + scipy.linalg.cholesky(Mv, lower=(cholesky_form == "lower")), |
| 226 | + f(Av), |
| 227 | + ) |
248 | 228 | )
|
249 |
| - ) |
0 commit comments