Skip to content

Commit 8b8b4ca

Browse files
committed
Fix bug in JAX test
1 parent c572c38 commit 8b8b4ca

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

Diff for: tests/link/jax/test_tensor_basic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def test_jax_Alloc():
2929
x = ptb.AllocEmpty("float32")(2, 3)
3030

3131
def compare_shape_dtype(x, y):
32-
np.testing.assert_array_equal(x, y, strict=True)
32+
assert x.shape == y.shape and x.dtype == y.dtype
3333

3434
compare_jax_and_py([], [x], [], assert_fn=compare_shape_dtype)
3535

0 commit comments

Comments
 (0)