Skip to content

Commit 6cebccd

Browse files
committed
move linalg rewrites, delete sandbox
Moving caused a circular dependency with tensor.blas. It seems most linalg rewrites are in the stablize set, so should run before the blas specializers anyway, so these checks were removed. This also deleted the unused `spectral_radius_bound` and dummy `Minimal(Op)`.
1 parent 4457ced commit 6cebccd

File tree

9 files changed

+5
-131
lines changed

9 files changed

+5
-131
lines changed

pytensor/sandbox/__init__.py

Whitespace-only changes.

pytensor/sandbox/linalg/__init__.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

pytensor/sandbox/minimal.py

Lines changed: 0 additions & 46 deletions
This file was deleted.

pytensor/tensor/rewriting/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
# Register JAX specializations
66
import pytensor.tensor.rewriting.jax
7+
import pytensor.tensor.rewriting.linalg
78
import pytensor.tensor.rewriting.math
89
import pytensor.tensor.rewriting.shape
910
import pytensor.tensor.rewriting.special

pytensor/sandbox/linalg/ops.py renamed to pytensor/tensor/rewriting/linalg.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from pytensor.graph.rewriting.basic import node_rewriter
44
from pytensor.tensor import basic as at
5-
from pytensor.tensor.blas import Dot22
65
from pytensor.tensor.elemwise import DimShuffle
76
from pytensor.tensor.math import Dot, Prod, dot, log
87
from pytensor.tensor.math import pow as at_pow
@@ -32,12 +31,12 @@ def transinv_to_invtrans(fgraph, node):
3231

3332

3433
@register_stabilize
35-
@node_rewriter([Dot, Dot22])
34+
@node_rewriter([Dot])
3635
def inv_as_solve(fgraph, node):
3736
"""
3837
This utilizes a boolean `symmetric` tag on the matrices.
3938
"""
40-
if isinstance(node.op, (Dot, Dot22)):
39+
if isinstance(node.op, Dot):
4140
l, r = node.inputs
4241
if l.owner and isinstance(l.owner.op, MatrixInverse):
4342
return [solve(l.owner.inputs[0], r)]
@@ -123,7 +122,7 @@ def cholesky_ldotlt(fgraph, node):
123122
return
124123

125124
A = node.inputs[0]
126-
if not (A.owner and isinstance(A.owner.op, (Dot, Dot22))):
125+
if not (A.owner and isinstance(A.owner.op, Dot)):
127126
return
128127

129128
l, r = A.owner.inputs

tests/sandbox/__init__.py

Whitespace-only changes.

tests/sandbox/linalg/__init__.py

Whitespace-only changes.

tests/sandbox/test_minimal.py

Lines changed: 0 additions & 32 deletions
This file was deleted.

tests/sandbox/linalg/test_linalg.py renamed to tests/tensor/rewriting/test_linalg.py

Lines changed: 1 addition & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
from pytensor import tensor as at
99
from pytensor.compile import get_default_mode
1010
from pytensor.configdefaults import config
11-
from pytensor.sandbox.linalg.ops import inv_as_solve, spectral_radius_bound
1211
from pytensor.tensor.elemwise import DimShuffle
1312
from pytensor.tensor.math import _allclose
1413
from pytensor.tensor.nlinalg import MatrixInverse, matrix_inverse
14+
from pytensor.tensor.rewriting.linalg import inv_as_solve
1515
from pytensor.tensor.slinalg import Cholesky, Solve, solve
1616
from pytensor.tensor.type import dmatrix, matrix, vector
1717
from tests import unittest_tools as utt
@@ -68,53 +68,6 @@ def test_rop_lop():
6868
assert _allclose(v1, v2), f"LOP mismatch: {v1} {v2}"
6969

7070

71-
def test_spectral_radius_bound():
72-
tol = 10 ** (-6)
73-
rng = np.random.default_rng(utt.fetch_seed())
74-
x = matrix()
75-
radius_bound = spectral_radius_bound(x, 5)
76-
f = pytensor.function([x], radius_bound)
77-
78-
shp = (3, 4)
79-
m = rng.random(shp)
80-
m = np.cov(m).astype(config.floatX)
81-
radius_bound_pytensor = f(m)
82-
83-
# test the approximation
84-
mm = m
85-
for i in range(5):
86-
mm = np.dot(mm, mm)
87-
radius_bound_numpy = np.trace(mm) ** (2 ** (-5))
88-
assert abs(radius_bound_numpy - radius_bound_pytensor) < tol
89-
90-
# test the bound
91-
eigen_val = numpy.linalg.eig(m)
92-
assert (eigen_val[0].max() - radius_bound_pytensor) < tol
93-
94-
# test type errors
95-
xx = vector()
96-
ok = False
97-
try:
98-
spectral_radius_bound(xx, 5)
99-
except TypeError:
100-
ok = True
101-
assert ok
102-
ok = False
103-
try:
104-
spectral_radius_bound(x, 5.0)
105-
except TypeError:
106-
ok = True
107-
assert ok
108-
109-
# test value error
110-
ok = False
111-
try:
112-
spectral_radius_bound(x, -5)
113-
except ValueError:
114-
ok = True
115-
assert ok
116-
117-
11871
def test_transinv_to_invtrans():
11972
X = matrix("X")
12073
Y = matrix_inverse(X)

0 commit comments

Comments
 (0)