@@ -143,3 +143,34 @@ def test_pinv():
143
143
fgraph = FunctionGraph ([x ], [x_inv ])
144
144
x_np = np .array ([[1.0 , 2.0 ], [3.0 , 4.0 ]], dtype = config .floatX )
145
145
compare_jax_and_py (fgraph , [x_np ])
146
+
147
+
148
+ def test_pinv_hermitian ():
149
+ A = matrix ("A" , dtype = "complex128" )
150
+ A_h_test = np .c_ [[3 , 3 + 2j ], [3 - 2j , 2 ]]
151
+ A_not_h_test = A_h_test + 0 + 1j
152
+
153
+ A_inv = at_nlinalg .pinv (A , hermitian = False )
154
+ jax_fn = function ([A ], A_inv , mode = "JAX" )
155
+
156
+ assert np .allclose (jax_fn (A_h_test ), np .linalg .pinv (A_h_test , hermitian = False ))
157
+ assert np .allclose (jax_fn (A_h_test ), np .linalg .pinv (A_h_test , hermitian = True ))
158
+ assert np .allclose (
159
+ jax_fn (A_not_h_test ), np .linalg .pinv (A_not_h_test , hermitian = False )
160
+ )
161
+ assert not np .allclose (
162
+ jax_fn (A_not_h_test ), np .linalg .pinv (A_not_h_test , hermitian = True )
163
+ )
164
+
165
+ A_inv = at_nlinalg .pinv (A , hermitian = True )
166
+ jax_fn = function ([A ], A_inv , mode = "JAX" )
167
+
168
+ assert np .allclose (jax_fn (A_h_test ), np .linalg .pinv (A_h_test , hermitian = False ))
169
+ assert np .allclose (jax_fn (A_h_test ), np .linalg .pinv (A_h_test , hermitian = True ))
170
+ assert not np .allclose (
171
+ jax_fn (A_not_h_test ), np .linalg .pinv (A_not_h_test , hermitian = False )
172
+ )
173
+ # Numpy fails differently than JAX when hermitian assumption is violated
174
+ assert not np .allclose (
175
+ jax_fn (A_not_h_test ), np .linalg .pinv (A_not_h_test , hermitian = True )
176
+ )
0 commit comments