Skip to content

Commit 6cadc76

Browse files
committed
Consider hermitian in jax dispatch of pinv
1 parent 5a8f313 commit 6cadc76

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

pytensor/link/jax/dispatch/nlinalg.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def dot(x, y):
8989
@jax_funcify.register(MatrixPinv)
9090
def jax_funcify_Pinv(op, **kwargs):
9191
def pinv(x):
92-
return jnp.linalg.pinv(x)
92+
return jnp.linalg.pinv(x, hermitian=op.hermitian)
9393

9494
return pinv
9595

tests/link/jax/test_nlinalg.py

+31
Original file line numberDiff line numberDiff line change
@@ -143,3 +143,34 @@ def test_pinv():
143143
fgraph = FunctionGraph([x], [x_inv])
144144
x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
145145
compare_jax_and_py(fgraph, [x_np])
146+
147+
148+
def test_pinv_hermitian():
149+
A = matrix("A", dtype="complex128")
150+
A_h_test = np.c_[[3, 3 + 2j], [3 - 2j, 2]]
151+
A_not_h_test = A_h_test + 0 + 1j
152+
153+
A_inv = at_nlinalg.pinv(A, hermitian=False)
154+
jax_fn = function([A], A_inv, mode="JAX")
155+
156+
assert np.allclose(jax_fn(A_h_test), np.linalg.pinv(A_h_test, hermitian=False))
157+
assert np.allclose(jax_fn(A_h_test), np.linalg.pinv(A_h_test, hermitian=True))
158+
assert np.allclose(
159+
jax_fn(A_not_h_test), np.linalg.pinv(A_not_h_test, hermitian=False)
160+
)
161+
assert not np.allclose(
162+
jax_fn(A_not_h_test), np.linalg.pinv(A_not_h_test, hermitian=True)
163+
)
164+
165+
A_inv = at_nlinalg.pinv(A, hermitian=True)
166+
jax_fn = function([A], A_inv, mode="JAX")
167+
168+
assert np.allclose(jax_fn(A_h_test), np.linalg.pinv(A_h_test, hermitian=False))
169+
assert np.allclose(jax_fn(A_h_test), np.linalg.pinv(A_h_test, hermitian=True))
170+
assert not np.allclose(
171+
jax_fn(A_not_h_test), np.linalg.pinv(A_not_h_test, hermitian=False)
172+
)
173+
# Numpy fails differently than JAX when hermitian assumption is violated
174+
assert not np.allclose(
175+
jax_fn(A_not_h_test), np.linalg.pinv(A_not_h_test, hermitian=True)
176+
)

0 commit comments

Comments
 (0)