Skip to content

Commit 4457ced

Browse files
committed
paramatrize cholesky_ldotlt test
1 parent d85d90d commit 4457ced

File tree

2 files changed

+62
-79
lines changed

2 files changed

+62
-79
lines changed

pytensor/sandbox/linalg/ops.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,11 @@ def psd_solve_with_chol(fgraph, node):
112112
@register_canonicalize
113113
@register_stabilize
114114
@node_rewriter([Cholesky])
115-
def chol_of_dot_chol(fgraph, node):
115+
def cholesky_ldotlt(fgraph, node):
116116
"""
117+
rewrite cholesky(dot(L, L.T), lower=True) = L, where L is lower triangular,
118+
or cholesky(dot(U.T, U), upper=True) = U where U is upper triangular.
119+
117120
This utilizes a boolean `lower_triangular` or `upper_triangular` tag on matrices.
118121
"""
119122
if not isinstance(node.op, Cholesky):
@@ -133,6 +136,7 @@ def chol_of_dot_chol(fgraph, node):
133136
and r.owner.op.new_order == (1, 0)
134137
and r.owner.inputs[0] == l
135138
):
139+
print("found right form")
136140
if node.op.lower:
137141
return [l]
138142
return [r]

tests/sandbox/linalg/test_linalg.py

Lines changed: 57 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import numpy as np
22
import numpy.linalg
3+
import pytest
34
import scipy.linalg
45

56
import pytensor
67
from pytensor import function
78
from pytensor import tensor as at
8-
from pytensor.compile import DeepCopyOp, ViewOp
9+
from pytensor.compile import get_default_mode
910
from pytensor.configdefaults import config
1011
from pytensor.sandbox.linalg.ops import inv_as_solve, spectral_radius_bound
1112
from pytensor.tensor.elemwise import DimShuffle
@@ -156,94 +157,72 @@ def test_matrix_inverse_solve():
156157
assert isinstance(out.owner.op, Solve)
157158

158159

159-
def test_cholesky_dot_lower():
160-
cholesky_lower = Cholesky(lower=True)
161-
cholesky_upper = Cholesky(lower=False)
160+
@pytest.mark.parametrize("tag", ("lower", "upper", None))
161+
@pytest.mark.parametrize("cholesky_form", ("lower", "upper"))
162+
@pytest.mark.parametrize("product", ("lower", "upper", None))
163+
def test_cholesky_ldotlt(tag, cholesky_form, product):
164+
cholesky = Cholesky(lower=(cholesky_form == "lower"))
162165

163-
L = matrix("L")
164-
L.tag.lower_triangular = True
166+
transform_removes_chol = tag is not None and product == tag
167+
transform_transposes = transform_removes_chol and cholesky_form != tag
165168

166-
C = cholesky_lower(L.dot(L.T))
167-
f = pytensor.function([L], C)
169+
A = matrix("L")
170+
if tag:
171+
setattr(A.tag, tag + "_triangular", True)
168172

169-
if config.mode != "FAST_COMPILE":
170-
assert (f.maker.fgraph.outputs[0] == f.maker.fgraph.inputs[0]) or (
171-
(o := f.maker.fgraph.outputs[0].owner)
172-
and isinstance(o.op, (DeepCopyOp, ViewOp))
173-
and o.inputs[0] == f.maker.fgraph.inputs[0]
174-
)
173+
if product == "lower":
174+
M = A.dot(A.T)
175+
elif product == "upper":
176+
M = A.T.dot(A)
177+
else:
178+
M = A
175179

176-
# Test some concrete value through f:
177-
Lv = np.array([[2, 0], [1, 4]])
178-
assert np.all(
179-
np.isclose(
180-
scipy.linalg.cholesky(np.dot(Lv, Lv.T), lower=True),
181-
f(Lv),
182-
)
183-
)
180+
C = cholesky(M)
181+
f = pytensor.function([A], C, mode=get_default_mode().including("cholesky_ldotlt"))
184182

185-
# Test upper decomposition factors down to a transpose
186-
C = cholesky_upper(L.dot(L.T))
187-
f = pytensor.function([L], C)
188-
if config.mode != "FAST_COMPILE":
189-
assert (
190-
(o1 := f.maker.fgraph.outputs[0].owner)
191-
and isinstance(o1.op, (DeepCopyOp, ViewOp))
192-
and (o2 := o1.inputs[0].owner)
193-
and isinstance(o2.op, DimShuffle)
194-
and o2.op.new_order == (1, 0)
195-
and o2.inputs[0] == f.maker.fgraph.inputs[0]
196-
)
183+
print(f.maker.fgraph.apply_nodes)
197184

198-
assert np.all(
199-
np.isclose(
200-
scipy.linalg.cholesky(np.dot(Lv, Lv.T), lower=False),
201-
f(Lv),
202-
)
185+
no_cholesky_in_graph = not any(
186+
isinstance(node.op, Cholesky) for node in f.maker.fgraph.apply_nodes
203187
)
204188

189+
assert no_cholesky_in_graph == transform_removes_chol
205190

206-
def test_cholesky_dot_upper():
207-
cholesky_lower = Cholesky(lower=True)
208-
cholesky_upper = Cholesky(lower=False)
209-
210-
U = matrix("U")
211-
U.tag.upper_triangular = True
212-
213-
C = cholesky_upper(U.T.dot(U))
214-
f = pytensor.function([U], C)
215-
if config.mode != "FAST_COMPILE":
216-
assert (f.maker.fgraph.outputs[0] == f.maker.fgraph.inputs[0]) or (
217-
(o := f.maker.fgraph.outputs[0].owner)
218-
and isinstance(o.op, (DeepCopyOp, ViewOp))
219-
and o.inputs[0] == f.maker.fgraph.inputs[0]
191+
if transform_transposes:
192+
assert any(
193+
isinstance(node.op, DimShuffle) and node.op.new_order == (1, 0)
194+
for node in f.maker.fgraph.apply_nodes
220195
)
221196

222-
# Test some concrete value through f:
223-
Uv = np.array([[2, 1], [0, 4]])
224-
assert np.all(
225-
np.isclose(
226-
scipy.linalg.cholesky(np.dot(Uv.T, Uv), lower=False),
227-
f(Uv),
197+
# Test some concrete value through f
198+
# there must be lower triangular (f assumes they are)
199+
Avs = [
200+
np.eye(1),
201+
np.eye(10),
202+
np.array([[2, 0], [1, 4]]),
203+
]
204+
if not tag:
205+
# these must be positive def
206+
Avs.extend(
207+
[
208+
np.ones((4, 4)) + np.eye(4),
209+
]
228210
)
229-
)
230211

231-
# Test lower decomposition factors down to a transpose
232-
C = cholesky_lower(U.T.dot(U))
233-
f = pytensor.function([U], C)
234-
if config.mode != "FAST_COMPILE":
235-
assert (
236-
(o1 := f.maker.fgraph.outputs[0].owner)
237-
and isinstance(o1.op, (DeepCopyOp, ViewOp))
238-
and (o2 := o1.inputs[0].owner)
239-
and isinstance(o2.op, DimShuffle)
240-
and o2.op.new_order == (1, 0)
241-
and o2.inputs[0] == f.maker.fgraph.inputs[0]
242-
)
243-
244-
assert np.all(
245-
np.isclose(
246-
scipy.linalg.cholesky(np.dot(Uv.T, Uv), lower=True),
247-
f(Uv),
212+
for Av in Avs:
213+
if tag == "upper":
214+
Av = Av.T
215+
216+
if product == "lower":
217+
Mv = Av.dot(Av.T)
218+
elif product == "upper":
219+
Mv = Av.T.dot(Av)
220+
else:
221+
Mv = Av
222+
223+
assert np.all(
224+
np.isclose(
225+
scipy.linalg.cholesky(Mv, lower=(cholesky_form == "lower")),
226+
f(Av),
227+
)
248228
)
249-
)

0 commit comments

Comments
 (0)