From 4b0c09a926a530a5038a5eff5f45c4469084b4e7 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sun, 14 May 2023 11:14:22 +0200 Subject: [PATCH 1/2] Fix failing float32 jax pinv test --- tests/link/jax/test_nlinalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/link/jax/test_nlinalg.py b/tests/link/jax/test_nlinalg.py index 88b29f10be..ff37433e22 100644 --- a/tests/link/jax/test_nlinalg.py +++ b/tests/link/jax/test_nlinalg.py @@ -141,5 +141,5 @@ def test_pinv(): x_inv = at_nlinalg.pinv(x) fgraph = FunctionGraph([x], [x_inv]) - x_np = np.array([[1.0, 2.0], [3.0, 4.0]]) + x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) compare_jax_and_py(fgraph, [x_np]) From bde3e5df315ea9c488cad5856968f275300ceef6 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sun, 14 May 2023 11:14:53 +0200 Subject: [PATCH 2/2] Consider hermitian in jax dispatch of pinv --- pytensor/link/jax/dispatch/nlinalg.py | 2 +- tests/link/jax/test_nlinalg.py | 31 +++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/pytensor/link/jax/dispatch/nlinalg.py b/pytensor/link/jax/dispatch/nlinalg.py index 21d7ade849..6f6467cff7 100644 --- a/pytensor/link/jax/dispatch/nlinalg.py +++ b/pytensor/link/jax/dispatch/nlinalg.py @@ -89,7 +89,7 @@ def dot(x, y): @jax_funcify.register(MatrixPinv) def jax_funcify_Pinv(op, **kwargs): def pinv(x): - return jnp.linalg.pinv(x) + return jnp.linalg.pinv(x, hermitian=op.hermitian) return pinv diff --git a/tests/link/jax/test_nlinalg.py b/tests/link/jax/test_nlinalg.py index ff37433e22..98bfbb610c 100644 --- a/tests/link/jax/test_nlinalg.py +++ b/tests/link/jax/test_nlinalg.py @@ -143,3 +143,34 @@ def test_pinv(): fgraph = FunctionGraph([x], [x_inv]) x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) compare_jax_and_py(fgraph, [x_np]) + + +def test_pinv_hermitian(): + A = matrix("A", dtype="complex128") + A_h_test = np.c_[[3, 3 + 2j], [3 - 2j, 2]] + A_not_h_test = A_h_test + 0 + 1j + + A_inv = at_nlinalg.pinv(A, hermitian=False) + jax_fn = function([A], A_inv, mode="JAX") + + assert np.allclose(jax_fn(A_h_test), np.linalg.pinv(A_h_test, hermitian=False)) + assert np.allclose(jax_fn(A_h_test), np.linalg.pinv(A_h_test, hermitian=True)) + assert np.allclose( + jax_fn(A_not_h_test), np.linalg.pinv(A_not_h_test, hermitian=False) + ) + assert not np.allclose( + jax_fn(A_not_h_test), np.linalg.pinv(A_not_h_test, hermitian=True) + ) + + A_inv = at_nlinalg.pinv(A, hermitian=True) + jax_fn = function([A], A_inv, mode="JAX") + + assert np.allclose(jax_fn(A_h_test), np.linalg.pinv(A_h_test, hermitian=False)) + assert np.allclose(jax_fn(A_h_test), np.linalg.pinv(A_h_test, hermitian=True)) + assert not np.allclose( + jax_fn(A_not_h_test), np.linalg.pinv(A_not_h_test, hermitian=False) + ) + # Numpy fails differently than JAX when hermitian assumption is violated + assert not np.allclose( + jax_fn(A_not_h_test), np.linalg.pinv(A_not_h_test, hermitian=True) + )