Skip to content

Commit 5a8f313

Browse files
committed
Fix failing float32 jax pinv test
1 parent 2ea5a54 commit 5a8f313

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tests/link/jax/test_nlinalg.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -141,5 +141,5 @@ def test_pinv():
141141
x_inv = at_nlinalg.pinv(x)
142142

143143
fgraph = FunctionGraph([x], [x_inv])
144-
x_np = np.array([[1.0, 2.0], [3.0, 4.0]])
144+
x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
145145
compare_jax_and_py(fgraph, [x_np])

0 commit comments

Comments
 (0)