|
9 | 9 | from pytensor.gradient import grad
|
10 | 10 | from pytensor.graph import Apply, Op
|
11 | 11 | from pytensor.graph.replace import vectorize_node
|
| 12 | +from pytensor.raise_op import assert_op |
12 | 13 | from pytensor.tensor import diagonal, log, tensor
|
13 | 14 | from pytensor.tensor.blockwise import Blockwise
|
14 | 15 | from pytensor.tensor.nlinalg import MatrixInverse
|
15 |
| -from pytensor.tensor.shape import Shape |
16 | 16 | from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve_triangular
|
17 | 17 | from pytensor.tensor.utils import _parse_gufunc_signature
|
18 | 18 |
|
@@ -362,11 +362,20 @@ def test_batched_mvnormal_logp_and_dlogp(mu_batch_shape, cov_batch_shape, benchm
|
362 | 362 | benchmark(fn, *test_values)
|
363 | 363 |
|
364 | 364 |
|
365 |
| -def test_op_with_params(): |
366 |
| - matrix_shape_blockwise = Blockwise(core_op=Shape(), signature="(x1,x2)->(s)") |
| 365 | +def test_cop_with_params(): |
| 366 | + matrix_assert = Blockwise(core_op=assert_op, signature="(x1,x2),()->(x1,x2)") |
| 367 | + |
367 | 368 | x = tensor("x", shape=(5, None, None), dtype="float64")
|
368 |
| - x_shape = matrix_shape_blockwise(x) |
| 369 | + x_shape = matrix_assert(x, (x >= 0).all()) |
| 370 | + |
369 | 371 | fn = pytensor.function([x], x_shape)
|
370 |
| - pytensor.dprint(fn) |
371 |
| - # Assert blockwise |
372 |
| - print(fn(np.zeros((5, 3, 2)))) |
| 372 | + [fn_out] = fn.maker.fgraph.outputs |
| 373 | + assert fn_out.owner.op == matrix_assert, "Blockwise should be in final graph" |
| 374 | + |
| 375 | + np.testing.assert_allclose( |
| 376 | + fn(np.zeros((5, 3, 2))), |
| 377 | + np.zeros((5, 3, 2)), |
| 378 | + ) |
| 379 | + |
| 380 | + with pytest.raises(AssertionError): |
| 381 | + fn(np.zeros((5, 3, 2)) - 1) |
0 commit comments