Skip to content

Commit 861f95c

Browse files
ricardoV94twiecki
authored andcommitted
Add test for Blockwise of COp with params
1 parent 35f0df9 commit 861f95c

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

tests/tensor/test_blockwise.py

+16-7
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99
from pytensor.gradient import grad
1010
from pytensor.graph import Apply, Op
1111
from pytensor.graph.replace import vectorize_node
12+
from pytensor.raise_op import assert_op
1213
from pytensor.tensor import diagonal, log, tensor
1314
from pytensor.tensor.blockwise import Blockwise
1415
from pytensor.tensor.nlinalg import MatrixInverse
15-
from pytensor.tensor.shape import Shape
1616
from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve_triangular
1717
from pytensor.tensor.utils import _parse_gufunc_signature
1818

@@ -362,11 +362,20 @@ def test_batched_mvnormal_logp_and_dlogp(mu_batch_shape, cov_batch_shape, benchm
362362
benchmark(fn, *test_values)
363363

364364

365-
def test_op_with_params():
366-
matrix_shape_blockwise = Blockwise(core_op=Shape(), signature="(x1,x2)->(s)")
365+
def test_cop_with_params():
366+
matrix_assert = Blockwise(core_op=assert_op, signature="(x1,x2),()->(x1,x2)")
367+
367368
x = tensor("x", shape=(5, None, None), dtype="float64")
368-
x_shape = matrix_shape_blockwise(x)
369+
x_shape = matrix_assert(x, (x >= 0).all())
370+
369371
fn = pytensor.function([x], x_shape)
370-
pytensor.dprint(fn)
371-
# Assert blockwise
372-
print(fn(np.zeros((5, 3, 2))))
372+
[fn_out] = fn.maker.fgraph.outputs
373+
assert fn_out.owner.op == matrix_assert, "Blockwise should be in final graph"
374+
375+
np.testing.assert_allclose(
376+
fn(np.zeros((5, 3, 2))),
377+
np.zeros((5, 3, 2)),
378+
)
379+
380+
with pytest.raises(AssertionError):
381+
fn(np.zeros((5, 3, 2)) - 1)

0 commit comments

Comments
 (0)