diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py index a3f5ea9491..d311a7e302 100644 --- a/pytensor/link/numba/dispatch/slinalg.py +++ b/pytensor/link/numba/dispatch/slinalg.py @@ -126,13 +126,17 @@ def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b): B_is_1d = B.ndim == 1 - if not overwrite_b: - B_copy = _copy_to_fortran_order(B) - else: + if overwrite_b: B_copy = B + else: + if B_is_1d: + # _copy_to_fortran_order does nothing with vectors + B_copy = np.copy(B) + else: + B_copy = _copy_to_fortran_order(B) if B_is_1d: - B_copy = np.expand_dims(B, -1) + B_copy = np.expand_dims(B_copy, -1) NRHS = 1 if B_is_1d else int(B_copy.shape[-1]) diff --git a/tests/link/jax/test_basic.py b/tests/link/jax/test_basic.py index d0f748f3e7..4a6eee1890 100644 --- a/tests/link/jax/test_basic.py +++ b/tests/link/jax/test_basic.py @@ -7,12 +7,12 @@ from pytensor.compile.builders import OpFromGraph from pytensor.compile.function import function from pytensor.compile.mode import JAX, Mode -from pytensor.compile.sharedvalue import SharedVariable, shared +from pytensor.compile.sharedvalue import shared from pytensor.configdefaults import config from pytensor.graph import RewriteDatabaseQuery -from pytensor.graph.basic import Apply +from pytensor.graph.basic import Apply, Variable from pytensor.graph.fg import FunctionGraph -from pytensor.graph.op import Op, get_test_value +from pytensor.graph.op import Op from pytensor.ifelse import ifelse from pytensor.link.jax import JAXLinker from pytensor.raise_op import assert_op @@ -34,25 +34,28 @@ def set_pytensor_flags(): def compare_jax_and_py( - fgraph: FunctionGraph, + graph_inputs: Iterable[Variable], + graph_outputs: Variable | Iterable[Variable], test_inputs: Iterable, + *, assert_fn: Callable | None = None, must_be_device_array: bool = True, jax_mode=jax_mode, py_mode=py_mode, ): - """Function to compare python graph output and jax compiled output for testing equality + """Function to compare python function output and jax compiled output for testing equality - In the tests below computational graphs are defined in PyTensor. These graphs are then passed to - this function which then compiles the graphs in both jax and python, runs the calculation - in both and checks if the results are the same + The inputs and outputs are then passed to this function which then compiles the given function in both + jax and python, runs the calculation in both and checks if the results are the same Parameters ---------- - fgraph: FunctionGraph - PyTensor function Graph object + graph_inputs: + Symbolic inputs to the graph + outputs: + Symbolic outputs of the graph test_inputs: iter - Numerical inputs for testing the function graph + Numerical inputs for testing the function. assert_fn: func, opt Assert function used to check for equality between python and jax. If not provided uses np.testing.assert_allclose @@ -68,8 +71,10 @@ def compare_jax_and_py( if assert_fn is None: assert_fn = partial(np.testing.assert_allclose, rtol=1e-4) - fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)] - pytensor_jax_fn = function(fn_inputs, fgraph.outputs, mode=jax_mode) + if any(inp.owner is not None for inp in graph_inputs): + raise ValueError("Inputs must be root variables") + + pytensor_jax_fn = function(graph_inputs, graph_outputs, mode=jax_mode) jax_res = pytensor_jax_fn(*test_inputs) if must_be_device_array: @@ -78,10 +83,10 @@ def compare_jax_and_py( else: assert isinstance(jax_res, jax.Array) - pytensor_py_fn = function(fn_inputs, fgraph.outputs, mode=py_mode) + pytensor_py_fn = function(graph_inputs, graph_outputs, mode=py_mode) py_res = pytensor_py_fn(*test_inputs) - if len(fgraph.outputs) > 1: + if isinstance(graph_outputs, list | tuple): for j, p in zip(jax_res, py_res, strict=True): assert_fn(j, p) else: @@ -187,16 +192,14 @@ def test_jax_ifelse(): false_vals = np.r_[-1, -2, -3] x = ifelse(np.array(True), true_vals, false_vals) - x_fg = FunctionGraph([], [x]) - compare_jax_and_py(x_fg, []) + compare_jax_and_py([], [x], []) a = dscalar("a") - a.tag.test_value = np.array(0.2, dtype=config.floatX) + a_test = np.array(0.2, dtype=config.floatX) x = ifelse(a < 0.5, true_vals, false_vals) - x_fg = FunctionGraph([a], [x]) # I.e. False - compare_jax_and_py(x_fg, [get_test_value(i) for i in x_fg.inputs]) + compare_jax_and_py([a], [x], [a_test]) def test_jax_checkandraise(): @@ -209,11 +212,6 @@ def test_jax_checkandraise(): function((p,), res, mode=jax_mode) -def set_test_value(x, v): - x.tag.test_value = v - return x - - def test_OpFromGraph(): x, y, z = matrices("xyz") ofg_1 = OpFromGraph([x, y], [x + y], inline=False) @@ -221,10 +219,9 @@ def test_OpFromGraph(): o1, o2 = ofg_2(y, z) out = ofg_1(x, o1) + o2 - out_fg = FunctionGraph([x, y, z], [out]) xv = np.ones((2, 2), dtype=config.floatX) yv = np.ones((2, 2), dtype=config.floatX) * 3 zv = np.ones((2, 2), dtype=config.floatX) * 5 - compare_jax_and_py(out_fg, [xv, yv, zv]) + compare_jax_and_py([x, y, z], [out], [xv, yv, zv]) diff --git a/tests/link/jax/test_blas.py b/tests/link/jax/test_blas.py index fe162d1d45..aedd52eca1 100644 --- a/tests/link/jax/test_blas.py +++ b/tests/link/jax/test_blas.py @@ -4,8 +4,6 @@ from pytensor.compile.function import function from pytensor.compile.mode import Mode from pytensor.configdefaults import config -from pytensor.graph.fg import FunctionGraph -from pytensor.graph.op import get_test_value from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.link.jax import JAXLinker from pytensor.tensor import blas as pt_blas @@ -16,21 +14,20 @@ def test_jax_BatchedDot(): # tensor3 . tensor3 a = tensor3("a") - a.tag.test_value = ( + a_test_value = ( np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3)) ) b = tensor3("b") - b.tag.test_value = ( + b_test_value = ( np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2)) ) out = pt_blas.BatchedDot()(a, b) - fgraph = FunctionGraph([a, b], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + compare_jax_and_py([a, b], [out], [a_test_value, b_test_value]) # A dimension mismatch should raise a TypeError for compatibility - inputs = [get_test_value(a)[:-1], get_test_value(b)] + inputs = [a_test_value[:-1], b_test_value] opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"]) jax_mode = Mode(JAXLinker(), opts) - pytensor_jax_fn = function(fgraph.inputs, fgraph.outputs, mode=jax_mode) + pytensor_jax_fn = function([a, b], [out], mode=jax_mode) with pytest.raises(TypeError): pytensor_jax_fn(*inputs) diff --git a/tests/link/jax/test_blockwise.py b/tests/link/jax/test_blockwise.py index 64569b0274..74d518c891 100644 --- a/tests/link/jax/test_blockwise.py +++ b/tests/link/jax/test_blockwise.py @@ -2,7 +2,6 @@ import pytest from pytensor import config -from pytensor.graph import FunctionGraph from pytensor.tensor import tensor from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.math import Dot, matmul @@ -32,8 +31,7 @@ def test_matmul(matmul_op): out = matmul_op(a, b) assert isinstance(out.owner.op, Blockwise) - fg = FunctionGraph([a, b], [out]) - fn, _ = compare_jax_and_py(fg, test_values) + fn, _ = compare_jax_and_py([a, b], [out], test_values) # Check we are not adding any unnecessary stuff jaxpr = str(jax.make_jaxpr(fn.vm.jit_fn)(*test_values)) diff --git a/tests/link/jax/test_einsum.py b/tests/link/jax/test_einsum.py index 4f1d25acfe..18fce217be 100644 --- a/tests/link/jax/test_einsum.py +++ b/tests/link/jax/test_einsum.py @@ -2,7 +2,6 @@ import pytest import pytensor.tensor as pt -from pytensor.graph import FunctionGraph from tests.link.jax.test_basic import compare_jax_and_py @@ -22,8 +21,7 @@ def test_jax_einsum(): } x_pt, y_pt, z_pt = (pt.tensor(name, shape=shape) for name, shape in shapes.items()) out = pt.einsum(subscripts, x_pt, y_pt, z_pt) - fg = FunctionGraph([x_pt, y_pt, z_pt], [out]) - compare_jax_and_py(fg, [x, y, z]) + compare_jax_and_py([x_pt, y_pt, z_pt], [out], [x, y, z]) def test_ellipsis_einsum(): @@ -34,5 +32,4 @@ def test_ellipsis_einsum(): x_pt = pt.tensor("x", shape=x.shape) y_pt = pt.tensor("y", shape=y.shape) out = pt.einsum(subscripts, x_pt, y_pt) - fg = FunctionGraph([x_pt, y_pt], [out]) - compare_jax_and_py(fg, [x, y]) + compare_jax_and_py([x_pt, y_pt], [out], [x, y]) diff --git a/tests/link/jax/test_elemwise.py b/tests/link/jax/test_elemwise.py index 687049f7e1..796d25d07b 100644 --- a/tests/link/jax/test_elemwise.py +++ b/tests/link/jax/test_elemwise.py @@ -6,8 +6,6 @@ import pytensor.tensor as pt from pytensor.compile import get_mode from pytensor.configdefaults import config -from pytensor.graph.fg import FunctionGraph -from pytensor.graph.op import get_test_value from pytensor.tensor import elemwise as pt_elemwise from pytensor.tensor.math import all as pt_all from pytensor.tensor.math import prod @@ -26,22 +24,22 @@ def test_jax_Dimshuffle(): a_pt = matrix("a") x = a_pt.T - x_fg = FunctionGraph([a_pt], [x]) - compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]) + compare_jax_and_py( + [a_pt], [x], [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)] + ) x = a_pt.dimshuffle([0, 1, "x"]) - x_fg = FunctionGraph([a_pt], [x]) - compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]) + compare_jax_and_py( + [a_pt], [x], [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)] + ) a_pt = tensor(dtype=config.floatX, shape=(None, 1)) x = a_pt.dimshuffle((0,)) - x_fg = FunctionGraph([a_pt], [x]) - compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)]) + compare_jax_and_py([a_pt], [x], [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)]) a_pt = tensor(dtype=config.floatX, shape=(None, 1)) x = pt_elemwise.DimShuffle(input_ndim=2, new_order=(0,))(a_pt) - x_fg = FunctionGraph([a_pt], [x]) - compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)]) + compare_jax_and_py([a_pt], [x], [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)]) def test_jax_CAReduce(): @@ -49,64 +47,58 @@ def test_jax_CAReduce(): a_pt.tag.test_value = np.r_[1, 2, 3].astype(config.floatX) x = pt_sum(a_pt, axis=None) - x_fg = FunctionGraph([a_pt], [x]) - compare_jax_and_py(x_fg, [np.r_[1, 2, 3].astype(config.floatX)]) + compare_jax_and_py([a_pt], [x], [np.r_[1, 2, 3].astype(config.floatX)]) a_pt = matrix("a") a_pt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX) x = pt_sum(a_pt, axis=0) - x_fg = FunctionGraph([a_pt], [x]) - compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)]) + compare_jax_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)]) x = pt_sum(a_pt, axis=1) - x_fg = FunctionGraph([a_pt], [x]) - compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)]) + compare_jax_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)]) a_pt = matrix("a") a_pt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX) x = prod(a_pt, axis=0) - x_fg = FunctionGraph([a_pt], [x]) - compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)]) + compare_jax_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)]) x = pt_all(a_pt) - x_fg = FunctionGraph([a_pt], [x]) - compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)]) + compare_jax_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)]) @pytest.mark.parametrize("axis", [None, 0, 1]) def test_softmax(axis): x = matrix("x") - x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3) + x_test_value = np.arange(6, dtype=config.floatX).reshape(2, 3) out = softmax(x, axis=axis) - fgraph = FunctionGraph([x], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + compare_jax_and_py([x], [out], [x_test_value]) @pytest.mark.parametrize("axis", [None, 0, 1]) def test_logsoftmax(axis): x = matrix("x") - x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3) + x_test_value = np.arange(6, dtype=config.floatX).reshape(2, 3) out = log_softmax(x, axis=axis) - fgraph = FunctionGraph([x], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + compare_jax_and_py([x], [out], [x_test_value]) @pytest.mark.parametrize("axis", [None, 0, 1]) def test_softmax_grad(axis): dy = matrix("dy") - dy.tag.test_value = np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX) + dy_test_value = np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX) sm = matrix("sm") - sm.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3) + sm_test_value = np.arange(6, dtype=config.floatX).reshape(2, 3) out = SoftmaxGrad(axis=axis)(dy, sm) - fgraph = FunctionGraph([dy, sm], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + compare_jax_and_py([dy, sm], [out], [dy_test_value, sm_test_value]) @pytest.mark.parametrize("size", [(10, 10), (1000, 1000)]) @@ -134,6 +126,4 @@ def test_logsumexp_benchmark(size, axis, benchmark): def test_multiple_input_multiply(): x, y, z = vectors("xyz") out = pt.mul(x, y, z) - - fg = FunctionGraph(outputs=[out], clone=False) - compare_jax_and_py(fg, [[1.5], [2.5], [3.5]]) + compare_jax_and_py([x, y, z], [out], test_inputs=[[1.5], [2.5], [3.5]]) diff --git a/tests/link/jax/test_extra_ops.py b/tests/link/jax/test_extra_ops.py index 0c8fb92810..f1c7609a66 100644 --- a/tests/link/jax/test_extra_ops.py +++ b/tests/link/jax/test_extra_ops.py @@ -3,8 +3,6 @@ import pytensor.tensor.basic as ptb from pytensor.configdefaults import config -from pytensor.graph.fg import FunctionGraph -from pytensor.graph.op import get_test_value from pytensor.tensor import extra_ops as pt_extra_ops from pytensor.tensor.sort import argsort from pytensor.tensor.type import matrix, tensor @@ -19,57 +17,45 @@ def test_extra_ops(): a_test = np.arange(6, dtype=config.floatX).reshape((3, 2)) out = pt_extra_ops.cumsum(a, axis=0) - fgraph = FunctionGraph([a], [out]) - compare_jax_and_py(fgraph, [a_test]) + compare_jax_and_py([a], [out], [a_test]) out = pt_extra_ops.cumprod(a, axis=1) - fgraph = FunctionGraph([a], [out]) - compare_jax_and_py(fgraph, [a_test]) + compare_jax_and_py([a], [out], [a_test]) out = pt_extra_ops.diff(a, n=2, axis=1) - fgraph = FunctionGraph([a], [out]) - compare_jax_and_py(fgraph, [a_test]) + compare_jax_and_py([a], [out], [a_test]) out = pt_extra_ops.repeat(a, (3, 3), axis=1) - fgraph = FunctionGraph([a], [out]) - compare_jax_and_py(fgraph, [a_test]) + compare_jax_and_py([a], [out], [a_test]) c = ptb.as_tensor(5) out = pt_extra_ops.fill_diagonal(a, c) - fgraph = FunctionGraph([a], [out]) - compare_jax_and_py(fgraph, [a_test]) + compare_jax_and_py([a], [out], [a_test]) with pytest.raises(NotImplementedError): out = pt_extra_ops.fill_diagonal_offset(a, c, c) - fgraph = FunctionGraph([a], [out]) - compare_jax_and_py(fgraph, [a_test]) + compare_jax_and_py([a], [out], [a_test]) with pytest.raises(NotImplementedError): out = pt_extra_ops.Unique(axis=1)(a) - fgraph = FunctionGraph([a], [out]) - compare_jax_and_py(fgraph, [a_test]) + compare_jax_and_py([a], [out], [a_test]) indices = np.arange(np.prod((3, 4))) out = pt_extra_ops.unravel_index(indices, (3, 4), order="C") - fgraph = FunctionGraph([], out) - compare_jax_and_py( - fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False - ) + compare_jax_and_py([], out, [], must_be_device_array=False) v = ptb.as_tensor_variable(6.0) sorted_idx = argsort(a.ravel()) out = pt_extra_ops.searchsorted(a.ravel()[sorted_idx], v) - fgraph = FunctionGraph([a], [out]) - compare_jax_and_py(fgraph, [a_test]) + compare_jax_and_py([a], [out], [a_test]) @pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes") def test_bartlett_dynamic_shape(): c = tensor(shape=(), dtype=int) out = pt_extra_ops.bartlett(c) - fgraph = FunctionGraph([], [out]) - compare_jax_and_py(fgraph, [np.array(5)]) + compare_jax_and_py([], [out], [np.array(5)]) @pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes") @@ -79,8 +65,7 @@ def test_ravel_multi_index_dynamic_shape(): x = tensor(shape=(None,), dtype=int) y = tensor(shape=(None,), dtype=int) out = pt_extra_ops.ravel_multi_index((x, y), (3, 4)) - fgraph = FunctionGraph([], [out]) - compare_jax_and_py(fgraph, [x_test, y_test]) + compare_jax_and_py([], [out], [x_test, y_test]) @pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes") @@ -89,5 +74,4 @@ def test_unique_dynamic_shape(): a_test = np.arange(6, dtype=config.floatX).reshape((3, 2)) out = pt_extra_ops.Unique()(a) - fgraph = FunctionGraph([a], [out]) - compare_jax_and_py(fgraph, [a_test]) + compare_jax_and_py([a], [out], [a_test]) diff --git a/tests/link/jax/test_math.py b/tests/link/jax/test_math.py index 0a1e91b4da..9f0172675a 100644 --- a/tests/link/jax/test_math.py +++ b/tests/link/jax/test_math.py @@ -2,8 +2,6 @@ import pytest from pytensor.configdefaults import config -from pytensor.graph.fg import FunctionGraph -from pytensor.graph.op import get_test_value from pytensor.tensor.math import Argmax, Max, maximum from pytensor.tensor.math import max as pt_max from pytensor.tensor.type import dvector, matrix, scalar, vector @@ -20,33 +18,39 @@ def test_jax_max_and_argmax(): mx = Max([0])(x) amx = Argmax([0])(x) out = mx * amx - out_fg = FunctionGraph([x], [out]) - compare_jax_and_py(out_fg, [np.r_[1, 2]]) + compare_jax_and_py([x], [out], [np.r_[1, 2]]) def test_dot(): y = vector("y") - y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX) + y_test_value = np.r_[1.0, 2.0].astype(config.floatX) x = vector("x") - x.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX) + x_test_value = np.r_[3.0, 4.0].astype(config.floatX) A = matrix("A") - A.tag.test_value = np.empty((2, 2), dtype=config.floatX) + A_test_value = np.empty((2, 2), dtype=config.floatX) alpha = scalar("alpha") - alpha.tag.test_value = np.array(3.0, dtype=config.floatX) + alpha_test_value = np.array(3.0, dtype=config.floatX) beta = scalar("beta") - beta.tag.test_value = np.array(5.0, dtype=config.floatX) + beta_test_value = np.array(5.0, dtype=config.floatX) # This should be converted into a `Gemv` `Op` when the non-JAX compatible # optimizations are turned on; however, when using JAX mode, it should # leave the expression alone. out = y.dot(alpha * A).dot(x) + beta * y - fgraph = FunctionGraph([y, x, A, alpha, beta], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + compare_jax_and_py( + [y, x, A, alpha, beta], + out, + [ + y_test_value, + x_test_value, + A_test_value, + alpha_test_value, + beta_test_value, + ], + ) out = maximum(y, x) - fgraph = FunctionGraph([y, x], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + compare_jax_and_py([y, x], [out], [y_test_value, x_test_value]) out = pt_max(y) - fgraph = FunctionGraph([y], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + compare_jax_and_py([y], [out], [y_test_value]) diff --git a/tests/link/jax/test_nlinalg.py b/tests/link/jax/test_nlinalg.py index cd6ca2ac71..866d99ce71 100644 --- a/tests/link/jax/test_nlinalg.py +++ b/tests/link/jax/test_nlinalg.py @@ -3,7 +3,6 @@ from pytensor.compile.function import function from pytensor.configdefaults import config -from pytensor.graph.fg import FunctionGraph from pytensor.tensor import nlinalg as pt_nlinalg from pytensor.tensor.type import matrix from tests.link.jax.test_basic import compare_jax_and_py @@ -21,41 +20,34 @@ def test_jax_basic_multiout(): x = matrix("x") outs = pt_nlinalg.eig(x) - out_fg = FunctionGraph([x], outs) def assert_fn(x, y): np.testing.assert_allclose(x.astype(config.floatX), y, rtol=1e-3) - compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn) + compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn) outs = pt_nlinalg.eigh(x) - out_fg = FunctionGraph([x], outs) - compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn) + compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn) outs = pt_nlinalg.qr(x, mode="full") - out_fg = FunctionGraph([x], outs) - compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn) + compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn) outs = pt_nlinalg.qr(x, mode="reduced") - out_fg = FunctionGraph([x], outs) - compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn) + compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn) outs = pt_nlinalg.svd(x) - out_fg = FunctionGraph([x], outs) - compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn) + compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn) outs = pt_nlinalg.slogdet(x) - out_fg = FunctionGraph([x], outs) - compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn) + compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn) def test_pinv(): x = matrix("x") x_inv = pt_nlinalg.pinv(x) - fgraph = FunctionGraph([x], [x_inv]) x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) - compare_jax_and_py(fgraph, [x_np]) + compare_jax_and_py([x], [x_inv], [x_np]) def test_pinv_hermitian(): @@ -94,8 +86,7 @@ def test_kron(): y = matrix("y") z = pt_nlinalg.kron(x, y) - fgraph = FunctionGraph([x, y], [z]) x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) y_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) - compare_jax_and_py(fgraph, [x_np, y_np]) + compare_jax_and_py([x, y], [z], [x_np, y_np]) diff --git a/tests/link/jax/test_pad.py b/tests/link/jax/test_pad.py index 2321645741..8ecb460ace 100644 --- a/tests/link/jax/test_pad.py +++ b/tests/link/jax/test_pad.py @@ -3,7 +3,6 @@ import pytensor.tensor as pt from pytensor import config -from pytensor.graph import FunctionGraph from pytensor.tensor.pad import PadMode from tests.link.jax.test_basic import compare_jax_and_py @@ -53,10 +52,10 @@ def test_jax_pad(mode: PadMode, kwargs): x = np.random.normal(size=(3, 3)) res = pt.pad(x_pt, mode=mode, pad_width=3, **kwargs) - res_fg = FunctionGraph([x_pt], [res]) compare_jax_and_py( - res_fg, + [x_pt], + [res], [x], assert_fn=lambda x, y: np.testing.assert_allclose(x, y, rtol=RTOL, atol=ATOL), py_mode="FAST_RUN", diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py index fa25f3aac0..2a6ebca0af 100644 --- a/tests/link/jax/test_random.py +++ b/tests/link/jax/test_random.py @@ -7,13 +7,11 @@ import pytensor.tensor.random.basic as ptr from pytensor import clone_replace from pytensor.compile.function import function -from pytensor.compile.sharedvalue import SharedVariable, shared -from pytensor.graph.basic import Constant -from pytensor.graph.fg import FunctionGraph +from pytensor.compile.sharedvalue import shared from pytensor.tensor.random.basic import RandomVariable from pytensor.tensor.random.type import RandomType from pytensor.tensor.random.utils import RandomStream -from tests.link.jax.test_basic import compare_jax_and_py, jax_mode, set_test_value +from tests.link.jax.test_basic import compare_jax_and_py, jax_mode from tests.tensor.random.test_basic import ( batched_permutation_tester, batched_unweighted_choice_without_replacement_tester, @@ -147,11 +145,11 @@ def test_replaced_shared_rng_storage_ordering_equality(): ( ptr.beta, [ - set_test_value( + ( pt.dvector(), np.array([1.0, 2.0], dtype=np.float64), ), - set_test_value( + ( pt.dscalar(), np.array(1.0, dtype=np.float64), ), @@ -163,11 +161,11 @@ def test_replaced_shared_rng_storage_ordering_equality(): ( ptr.cauchy, [ - set_test_value( + ( pt.dvector(), np.array([1.0, 2.0], dtype=np.float64), ), - set_test_value( + ( pt.dscalar(), np.array(1.0, dtype=np.float64), ), @@ -179,7 +177,7 @@ def test_replaced_shared_rng_storage_ordering_equality(): ( ptr.exponential, [ - set_test_value( + ( pt.dvector(), np.array([1.0, 2.0], dtype=np.float64), ), @@ -191,11 +189,11 @@ def test_replaced_shared_rng_storage_ordering_equality(): ( ptr._gamma, [ - set_test_value( + ( pt.dvector(), np.array([1.0, 2.0], dtype=np.float64), ), - set_test_value( + ( pt.dvector(), np.array([0.5, 3.0], dtype=np.float64), ), @@ -207,11 +205,11 @@ def test_replaced_shared_rng_storage_ordering_equality(): ( ptr.gumbel, [ - set_test_value( + ( pt.lvector(), np.array([1, 2], dtype=np.int64), ), - set_test_value( + ( pt.dscalar(), np.array(1.0, dtype=np.float64), ), @@ -223,8 +221,8 @@ def test_replaced_shared_rng_storage_ordering_equality(): ( ptr.laplace, [ - set_test_value(pt.dvector(), np.array([1.0, 2.0], dtype=np.float64)), - set_test_value(pt.dscalar(), np.array(1.0, dtype=np.float64)), + (pt.dvector(), np.array([1.0, 2.0], dtype=np.float64)), + (pt.dscalar(), np.array(1.0, dtype=np.float64)), ], (2,), "laplace", @@ -233,11 +231,11 @@ def test_replaced_shared_rng_storage_ordering_equality(): ( ptr.logistic, [ - set_test_value( + ( pt.dvector(), np.array([1.0, 2.0], dtype=np.float64), ), - set_test_value( + ( pt.dscalar(), np.array(1.0, dtype=np.float64), ), @@ -249,11 +247,11 @@ def test_replaced_shared_rng_storage_ordering_equality(): ( ptr.lognormal, [ - set_test_value( + ( pt.lvector(), np.array([0, 0], dtype=np.int64), ), - set_test_value( + ( pt.dscalar(), np.array(1.0, dtype=np.float64), ), @@ -265,11 +263,11 @@ def test_replaced_shared_rng_storage_ordering_equality(): ( ptr.normal, [ - set_test_value( + ( pt.lvector(), np.array([1, 2], dtype=np.int64), ), - set_test_value( + ( pt.dscalar(), np.array(1.0, dtype=np.float64), ), @@ -281,11 +279,11 @@ def test_replaced_shared_rng_storage_ordering_equality(): ( ptr.pareto, [ - set_test_value( + ( pt.dvector(), np.array([1.0, 2.0], dtype=np.float64), ), - set_test_value( + ( pt.dvector(), np.array([2.0, 10.0], dtype=np.float64), ), @@ -297,7 +295,7 @@ def test_replaced_shared_rng_storage_ordering_equality(): ( ptr.poisson, [ - set_test_value( + ( pt.dvector(), np.array([100000.0, 200000.0], dtype=np.float64), ), @@ -309,11 +307,11 @@ def test_replaced_shared_rng_storage_ordering_equality(): ( ptr.integers, [ - set_test_value( + ( pt.lscalar(), np.array(0, dtype=np.int64), ), - set_test_value( # high-value necessary since test on cdf + ( # high-value necessary since test on cdf pt.lscalar(), np.array(1000, dtype=np.int64), ), @@ -332,15 +330,15 @@ def test_replaced_shared_rng_storage_ordering_equality(): ( ptr.t, [ - set_test_value( + ( pt.dscalar(), np.array(2.0, dtype=np.float64), ), - set_test_value( + ( pt.dvector(), np.array([1.0, 2.0], dtype=np.float64), ), - set_test_value( + ( pt.dscalar(), np.array(1.0, dtype=np.float64), ), @@ -352,11 +350,11 @@ def test_replaced_shared_rng_storage_ordering_equality(): ( ptr.uniform, [ - set_test_value( + ( pt.dvector(), np.array([1.0, 2.0], dtype=np.float64), ), - set_test_value( + ( pt.dscalar(), np.array(1000.0, dtype=np.float64), ), @@ -368,11 +366,11 @@ def test_replaced_shared_rng_storage_ordering_equality(): ( ptr.halfnormal, [ - set_test_value( + ( pt.dvector(), np.array([-1.0, 200.0], dtype=np.float64), ), - set_test_value( + ( pt.dscalar(), np.array(1000.0, dtype=np.float64), ), @@ -384,11 +382,11 @@ def test_replaced_shared_rng_storage_ordering_equality(): ( ptr.invgamma, [ - set_test_value( + ( pt.dvector(), np.array([10.4, 2.8], dtype=np.float64), ), - set_test_value( + ( pt.dvector(), np.array([3.4, 7.3], dtype=np.float64), ), @@ -400,7 +398,7 @@ def test_replaced_shared_rng_storage_ordering_equality(): ( ptr.chisquare, [ - set_test_value( + ( pt.dvector(), np.array([2.4, 4.9], dtype=np.float64), ), @@ -412,15 +410,15 @@ def test_replaced_shared_rng_storage_ordering_equality(): ( ptr.gengamma, [ - set_test_value( + ( pt.dvector(), np.array([10.4, 2.8], dtype=np.float64), ), - set_test_value( + ( pt.dvector(), np.array([3.4, 7.3], dtype=np.float64), ), - set_test_value( + ( pt.dvector(), np.array([0.9, 2.0], dtype=np.float64), ), @@ -432,11 +430,11 @@ def test_replaced_shared_rng_storage_ordering_equality(): ( ptr.wald, [ - set_test_value( + ( pt.dvector(), np.array([10.4, 2.8], dtype=np.float64), ), - set_test_value( + ( pt.dvector(), np.array([4.5, 2.0], dtype=np.float64), ), @@ -449,11 +447,11 @@ def test_replaced_shared_rng_storage_ordering_equality(): pytest.param( ptr.vonmises, [ - set_test_value( + ( pt.dvector(), np.array([-0.5, 1.3], dtype=np.float64), ), - set_test_value( + ( pt.dvector(), np.array([5.5, 13.0], dtype=np.float64), ), @@ -478,20 +476,16 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c The transpiled `RandomVariable` `Op`. dist_params The parameters passed to the op. - """ + dist_params, test_values = ( + zip(*dist_params, strict=True) if dist_params else ([], []) + ) rng = shared(np.random.default_rng(29403)) g = rv_op(*dist_params, size=(10000, *base_size), rng=rng) g_fn = compile_random_function(dist_params, g, mode=jax_mode) - samples = g_fn( - *[ - i.tag.test_value - for i in g_fn.maker.fgraph.inputs - if not isinstance(i, SharedVariable | Constant) - ] - ) + samples = g_fn(*test_values) - bcast_dist_args = np.broadcast_arrays(*[i.tag.test_value for i in dist_params]) + bcast_dist_args = np.broadcast_arrays(*test_values) for idx in np.ndindex(*base_size): cdf_params = params_conv(*(arg[idx] for arg in bcast_dist_args)) @@ -775,13 +769,12 @@ def rng_fn(cls, rng, size): nonexistentrv = NonExistentRV() rng = shared(np.random.default_rng(123)) out = nonexistentrv(rng=rng) - fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False) with pytest.raises(NotImplementedError): with pytest.warns( UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used" ): - compare_jax_and_py(fgraph, []) + compare_jax_and_py([], [out], []) def test_random_custom_implementation(): @@ -810,11 +803,10 @@ def sample_fn(rng, size, dtype, *parameters): nonexistentrv = CustomRV() rng = shared(np.random.default_rng(123)) out = nonexistentrv(rng=rng) - fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False) with pytest.warns( UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used" ): - compare_jax_and_py(fgraph, []) + compare_jax_and_py([], [out], []) def test_random_concrete_shape(): diff --git a/tests/link/jax/test_scalar.py b/tests/link/jax/test_scalar.py index 475062e86c..463405fff4 100644 --- a/tests/link/jax/test_scalar.py +++ b/tests/link/jax/test_scalar.py @@ -5,7 +5,6 @@ import pytensor.tensor as pt from pytensor.configdefaults import config from pytensor.graph.fg import FunctionGraph -from pytensor.graph.op import get_test_value from pytensor.scalar.basic import Composite from pytensor.tensor import as_tensor from pytensor.tensor.elemwise import Elemwise @@ -51,20 +50,19 @@ def test_second(): b = scalar("b") out = ps.second(a0, b) - fgraph = FunctionGraph([a0, b], [out]) - compare_jax_and_py(fgraph, [10.0, 5.0]) + compare_jax_and_py([a0, b], [out], [10.0, 5.0]) a1 = vector("a1") out = pt.second(a1, b) - fgraph = FunctionGraph([a1, b], [out]) - compare_jax_and_py(fgraph, [np.zeros([5], dtype=config.floatX), 5.0]) + compare_jax_and_py([a1, b], [out], [np.zeros([5], dtype=config.floatX), 5.0]) a2 = matrix("a2", shape=(1, None), dtype="float64") b2 = matrix("b2", shape=(None, 1), dtype="int32") out = pt.second(a2, b2) - fgraph = FunctionGraph([a2, b2], [out]) compare_jax_and_py( - fgraph, [np.zeros((1, 3), dtype="float64"), np.ones((5, 1), dtype="int32")] + [a2, b2], + [out], + [np.zeros((1, 3), dtype="float64"), np.ones((5, 1), dtype="int32")], ) @@ -81,11 +79,10 @@ def test_second_constant_scalar(): def test_identity(): a = scalar("a") - a.tag.test_value = 10 + a_test_value = 10 out = ps.identity(a) - fgraph = FunctionGraph([a], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + compare_jax_and_py([a], [out], [a_test_value]) @pytest.mark.parametrize( @@ -109,13 +106,11 @@ def test_jax_Composite_singe_output(x, y, x_val, y_val): out = comp_op(x, y) - out_fg = FunctionGraph([x, y], [out]) - test_input_vals = [ x_val.astype(config.floatX), y_val.astype(config.floatX), ] - _ = compare_jax_and_py(out_fg, test_input_vals) + _ = compare_jax_and_py([x, y], [out], test_input_vals) def test_jax_Composite_multi_output(): @@ -124,32 +119,28 @@ def test_jax_Composite_multi_output(): x_s = ps.float64("xs") outs = Elemwise(Composite(inputs=[x_s], outputs=[x_s + 1, x_s - 1]))(x) - fgraph = FunctionGraph([x], outs) - compare_jax_and_py(fgraph, [np.arange(10, dtype=config.floatX)]) + compare_jax_and_py([x], outs, [np.arange(10, dtype=config.floatX)]) def test_erf(): x = scalar("x") out = erf(x) - fg = FunctionGraph([x], [out]) - compare_jax_and_py(fg, [1.0]) + compare_jax_and_py([x], [out], [1.0]) def test_erfc(): x = scalar("x") out = erfc(x) - fg = FunctionGraph([x], [out]) - compare_jax_and_py(fg, [1.0]) + compare_jax_and_py([x], [out], [1.0]) def test_erfinv(): x = scalar("x") out = erfinv(x) - fg = FunctionGraph([x], [out]) - compare_jax_and_py(fg, [0.95]) + compare_jax_and_py([x], [out], [0.95]) @pytest.mark.parametrize( @@ -166,8 +157,7 @@ def test_tfp_ops(op, test_values): inputs = [as_tensor(test_value).type() for test_value in test_values] output = op(*inputs) - fg = FunctionGraph(inputs, [output]) - compare_jax_and_py(fg, test_values) + compare_jax_and_py(inputs, [output], test_values) def test_betaincinv(): @@ -175,9 +165,10 @@ def test_betaincinv(): b = vector("b", dtype="float64") x = vector("x", dtype="float64") out = betaincinv(a, b, x) - fg = FunctionGraph([a, b, x], [out]) + compare_jax_and_py( - fg, + [a, b, x], + [out], [ np.array([5.5, 7.0]), np.array([5.5, 7.0]), @@ -190,39 +181,40 @@ def test_gammaincinv(): k = vector("k", dtype="float64") x = vector("x", dtype="float64") out = gammaincinv(k, x) - fg = FunctionGraph([k, x], [out]) - compare_jax_and_py(fg, [np.array([5.5, 7.0]), np.array([0.25, 0.7])]) + + compare_jax_and_py([k, x], [out], [np.array([5.5, 7.0]), np.array([0.25, 0.7])]) def test_gammainccinv(): k = vector("k", dtype="float64") x = vector("x", dtype="float64") out = gammainccinv(k, x) - fg = FunctionGraph([k, x], [out]) - compare_jax_and_py(fg, [np.array([5.5, 7.0]), np.array([0.25, 0.7])]) + + compare_jax_and_py([k, x], [out], [np.array([5.5, 7.0]), np.array([0.25, 0.7])]) def test_psi(): x = scalar("x") out = psi(x) - fg = FunctionGraph([x], [out]) - compare_jax_and_py(fg, [3.0]) + + compare_jax_and_py([x], [out], [3.0]) def test_tri_gamma(): x = vector("x", dtype="float64") out = tri_gamma(x) - fg = FunctionGraph([x], [out]) - compare_jax_and_py(fg, [np.array([3.0, 5.0])]) + + compare_jax_and_py([x], [out], [np.array([3.0, 5.0])]) def test_polygamma(): n = vector("n", dtype="int32") x = vector("x", dtype="float32") out = polygamma(n, x) - fg = FunctionGraph([n, x], [out]) + compare_jax_and_py( - fg, + [n, x], + [out], [ np.array([0, 1, 2]).astype("int32"), np.array([0.5, 0.9, 2.5]).astype("float32"), @@ -233,41 +225,34 @@ def test_polygamma(): def test_log1mexp(): x = vector("x") out = log1mexp(x) - fg = FunctionGraph([x], [out]) - compare_jax_and_py(fg, [[-1.0, -0.75, -0.5, -0.25]]) + compare_jax_and_py([x], [out], [[-1.0, -0.75, -0.5, -0.25]]) def test_nnet(): x = vector("x") - x.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX) + x_test_value = np.r_[1.0, 2.0].astype(config.floatX) out = sigmoid(x) - fgraph = FunctionGraph([x], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + compare_jax_and_py([x], [out], [x_test_value]) out = softplus(x) - fgraph = FunctionGraph([x], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + compare_jax_and_py([x], [out], [x_test_value]) def test_jax_variadic_Scalar(): mu = vector("mu", dtype=config.floatX) - mu.tag.test_value = np.r_[0.1, 1.1].astype(config.floatX) + mu_test_value = np.r_[0.1, 1.1].astype(config.floatX) tau = vector("tau", dtype=config.floatX) - tau.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX) + tau_test_value = np.r_[1.0, 2.0].astype(config.floatX) res = -tau * mu - fgraph = FunctionGraph([mu, tau], [res]) - - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + compare_jax_and_py([mu, tau], [res], [mu_test_value, tau_test_value]) res = -tau * (tau - mu) ** 2 - fgraph = FunctionGraph([mu, tau], [res]) - - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + compare_jax_and_py([mu, tau], [res], [mu_test_value, tau_test_value]) def test_add_scalars(): @@ -275,8 +260,7 @@ def test_add_scalars(): size = x.shape[0] + x.shape[0] + x.shape[1] out = pt.ones(size).astype(config.floatX) - out_fg = FunctionGraph([x], [out]) - compare_jax_and_py(out_fg, [np.ones((2, 3)).astype(config.floatX)]) + compare_jax_and_py([x], [out], [np.ones((2, 3)).astype(config.floatX)]) def test_mul_scalars(): @@ -284,8 +268,7 @@ def test_mul_scalars(): size = x.shape[0] * x.shape[0] * x.shape[1] out = pt.ones(size).astype(config.floatX) - out_fg = FunctionGraph([x], [out]) - compare_jax_and_py(out_fg, [np.ones((2, 3)).astype(config.floatX)]) + compare_jax_and_py([x], [out], [np.ones((2, 3)).astype(config.floatX)]) def test_div_scalars(): @@ -293,8 +276,7 @@ def test_div_scalars(): size = x.shape[0] // x.shape[1] out = pt.ones(size).astype(config.floatX) - out_fg = FunctionGraph([x], [out]) - compare_jax_and_py(out_fg, [np.ones((12, 3)).astype(config.floatX)]) + compare_jax_and_py([x], [out], [np.ones((12, 3)).astype(config.floatX)]) def test_mod_scalars(): @@ -302,39 +284,43 @@ def test_mod_scalars(): size = x.shape[0] % x.shape[1] out = pt.ones(size).astype(config.floatX) - out_fg = FunctionGraph([x], [out]) - compare_jax_and_py(out_fg, [np.ones((12, 3)).astype(config.floatX)]) + compare_jax_and_py([x], [out], [np.ones((12, 3)).astype(config.floatX)]) def test_jax_multioutput(): x = vector("x") - x.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX) + x_test_value = np.r_[1.0, 2.0].astype(config.floatX) y = vector("y") - y.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX) + y_test_value = np.r_[3.0, 4.0].astype(config.floatX) w = cosh(x**2 + y / 3.0) v = cosh(x / 3.0 + y**2) - fgraph = FunctionGraph([x, y], [w, v]) - - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + compare_jax_and_py([x, y], [w, v], [x_test_value, y_test_value]) def test_jax_logp(): mu = vector("mu") - mu.tag.test_value = np.r_[0.0, 0.0].astype(config.floatX) + mu_test_value = np.r_[0.0, 0.0].astype(config.floatX) tau = vector("tau") - tau.tag.test_value = np.r_[1.0, 1.0].astype(config.floatX) + tau_test_value = np.r_[1.0, 1.0].astype(config.floatX) sigma = vector("sigma") - sigma.tag.test_value = (1.0 / get_test_value(tau)).astype(config.floatX) + sigma_test_value = (1.0 / tau_test_value).astype(config.floatX) value = vector("value") - value.tag.test_value = np.r_[0.1, -10].astype(config.floatX) + value_test_value = np.r_[0.1, -10].astype(config.floatX) logp = (-tau * (value - mu) ** 2 + log(tau / np.pi / 2.0)) / 2.0 conditions = [sigma > 0] alltrue = pt_all([pt_all(1 * val) for val in conditions]) normal_logp = pt.switch(alltrue, logp, -np.inf) - fgraph = FunctionGraph([mu, tau, sigma, value], [normal_logp]) - - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + compare_jax_and_py( + [mu, tau, sigma, value], + [normal_logp], + [ + mu_test_value, + tau_test_value, + sigma_test_value, + value_test_value, + ], + ) diff --git a/tests/link/jax/test_scan.py b/tests/link/jax/test_scan.py index ae64cad4c0..4ee95ab527 100644 --- a/tests/link/jax/test_scan.py +++ b/tests/link/jax/test_scan.py @@ -7,7 +7,6 @@ from pytensor import function, shared from pytensor.compile import get_mode from pytensor.configdefaults import config -from pytensor.graph.fg import FunctionGraph from pytensor.scan import until from pytensor.scan.basic import scan from pytensor.scan.op import Scan @@ -30,9 +29,8 @@ def test_scan_sit_sot(view): ) if view: xs = xs[view] - fg = FunctionGraph([x0], [xs]) test_input_vals = [np.e] - compare_jax_and_py(fg, test_input_vals, jax_mode="JAX") + compare_jax_and_py([x0], [xs], test_input_vals, jax_mode="JAX") @pytest.mark.parametrize("view", [None, (-1,), slice(-4, -1, None)]) @@ -45,9 +43,8 @@ def test_scan_mit_sot(view): ) if view: xs = xs[view] - fg = FunctionGraph([x0], [xs]) test_input_vals = [np.full((3,), np.e)] - compare_jax_and_py(fg, test_input_vals, jax_mode="JAX") + compare_jax_and_py([x0], [xs], test_input_vals, jax_mode="JAX") @pytest.mark.parametrize("view_x", [None, (-1,), slice(-4, -1, None)]) @@ -72,9 +69,8 @@ def step(xtm3, xtm1, ytm4, ytm2): if view_y: ys = ys[view_y] - fg = FunctionGraph([x0, y0], [xs, ys]) test_input_vals = [np.full((3,), np.e), np.full((4,), np.pi)] - compare_jax_and_py(fg, test_input_vals, jax_mode="JAX") + compare_jax_and_py([x0, y0], [xs, ys], test_input_vals, jax_mode="JAX") @pytest.mark.parametrize("view", [None, (-2,), slice(None, None, 2)]) @@ -90,12 +86,11 @@ def test_scan_nit_sot(view): ) if view: ys = ys[view] - fg = FunctionGraph([xs], [ys]) test_input_vals = [rng.normal(size=10)] # We need to remove pushout rewrites, or the whole scan would just be # converted to an Elemwise on xs jax_fn, _ = compare_jax_and_py( - fg, test_input_vals, jax_mode=get_mode("JAX").excluding("scan_pushout") + [xs], [ys], test_input_vals, jax_mode=get_mode("JAX").excluding("scan_pushout") ) scan_nodes = [ node for node in jax_fn.maker.fgraph.apply_nodes if isinstance(node.op, Scan) @@ -112,8 +107,7 @@ def test_scan_mit_mot(): n_steps=10, ) grads_wrt_xs = pt.grad(ys.sum(), wrt=xs) - fg = FunctionGraph([xs], [grads_wrt_xs]) - compare_jax_and_py(fg, [np.arange(10)]) + compare_jax_and_py([xs], [grads_wrt_xs], [np.arange(10)]) def test_scan_update(): @@ -192,8 +186,7 @@ def test_scan_while(): n_steps=100, ) - fg = FunctionGraph([], [xs]) - compare_jax_and_py(fg, []) + compare_jax_and_py([], [xs], []) def test_scan_SEIR(): @@ -257,11 +250,6 @@ def seir_one_step(ct0, dt0, st0, et0, it0, logp_c, logp_d, beta, gamma, delta): logp_c_all.name = "C_t_logp" logp_d_all.name = "D_t_logp" - out_fg = FunctionGraph( - [at_C, at_D, st0, et0, it0, logp_c, logp_d, beta, gamma, delta], - [st, et, it, logp_c_all, logp_d_all], - ) - s0, e0, i0 = 100, 50, 25 logp_c0 = np.array(0.0, dtype=config.floatX) logp_d0 = np.array(0.0, dtype=config.floatX) @@ -283,7 +271,12 @@ def seir_one_step(ct0, dt0, st0, et0, it0, logp_c, logp_d, beta, gamma, delta): gamma_val, delta_val, ] - compare_jax_and_py(out_fg, test_input_vals, jax_mode="JAX") + compare_jax_and_py( + [at_C, at_D, st0, et0, it0, logp_c, logp_d, beta, gamma, delta], + [st, et, it, logp_c_all, logp_d_all], + test_input_vals, + jax_mode="JAX", + ) def test_scan_mitsot_with_nonseq(): @@ -313,10 +306,8 @@ def input_step_fn(y_tm1, y_tm3, a): y_scan_pt.name = "y" y_scan_pt.owner.inputs[0].name = "y_all" - out_fg = FunctionGraph([a_pt], [y_scan_pt]) - test_input_vals = [np.array(10.0).astype(config.floatX)] - compare_jax_and_py(out_fg, test_input_vals, jax_mode="JAX") + compare_jax_and_py([a_pt], [y_scan_pt], test_input_vals, jax_mode="JAX") @pytest.mark.parametrize("x0_func", [dvector, dmatrix]) @@ -343,9 +334,8 @@ def test_nd_scan_sit_sot(x0_func, A_func): ) A_val = np.eye(k, dtype=config.floatX) - fg = FunctionGraph([x0, A], [xs]) test_input_vals = [x0_val, A_val] - compare_jax_and_py(fg, test_input_vals, jax_mode="JAX") + compare_jax_and_py([x0, A], [xs], test_input_vals, jax_mode="JAX") def test_nd_scan_sit_sot_with_seq(): @@ -366,9 +356,8 @@ def test_nd_scan_sit_sot_with_seq(): x_val = np.arange(n_steps * k, dtype=config.floatX).reshape(n_steps, k) A_val = np.eye(k, dtype=config.floatX) - fg = FunctionGraph([x, A], [xs]) test_input_vals = [x_val, A_val] - compare_jax_and_py(fg, test_input_vals, jax_mode="JAX") + compare_jax_and_py([x, A], [xs], test_input_vals, jax_mode="JAX") def test_nd_scan_mit_sot(): @@ -384,13 +373,12 @@ def test_nd_scan_mit_sot(): n_steps=10, ) - fg = FunctionGraph([x0, A, B], [xs]) x0_val = np.arange(9, dtype=config.floatX).reshape(3, 3) A_val = np.eye(3, dtype=config.floatX) B_val = np.eye(3, dtype=config.floatX) test_input_vals = [x0_val, A_val, B_val] - compare_jax_and_py(fg, test_input_vals, jax_mode="JAX") + compare_jax_and_py([x0, A, B], [xs], test_input_vals, jax_mode="JAX") def test_nd_scan_sit_sot_with_carry(): @@ -409,12 +397,11 @@ def step(x, A): mode=get_mode("JAX"), ) - fg = FunctionGraph([x0, A], xs) x0_val = np.arange(3, dtype=config.floatX) A_val = np.eye(3, dtype=config.floatX) test_input_vals = [x0_val, A_val] - compare_jax_and_py(fg, test_input_vals, jax_mode="JAX") + compare_jax_and_py([x0, A], xs, test_input_vals, jax_mode="JAX") def test_default_mode_excludes_incompatible_rewrites(): @@ -422,8 +409,7 @@ def test_default_mode_excludes_incompatible_rewrites(): A = matrix("A") B = matrix("B") out, _ = scan(lambda a, b: a @ b, outputs_info=[A], non_sequences=[B], n_steps=2) - fg = FunctionGraph([A, B], [out]) - compare_jax_and_py(fg, [np.eye(3), np.eye(3)], jax_mode="JAX") + compare_jax_and_py([A, B], [out], [np.eye(3), np.eye(3)], jax_mode="JAX") def test_dynamic_sequence_length(): diff --git a/tests/link/jax/test_shape.py b/tests/link/jax/test_shape.py index 6eec401578..085f67f411 100644 --- a/tests/link/jax/test_shape.py +++ b/tests/link/jax/test_shape.py @@ -4,7 +4,6 @@ import pytensor.tensor as pt from pytensor.compile.ops import DeepCopyOp, ViewOp from pytensor.configdefaults import config -from pytensor.graph.fg import FunctionGraph from pytensor.tensor.shape import Shape, Shape_i, Unbroadcast, reshape from pytensor.tensor.type import iscalar, vector from tests.link.jax.test_basic import compare_jax_and_py @@ -13,29 +12,27 @@ def test_jax_shape_ops(): x_np = np.zeros((20, 3)) x = Shape()(pt.as_tensor_variable(x_np)) - x_fg = FunctionGraph([], [x]) - compare_jax_and_py(x_fg, [], must_be_device_array=False) + compare_jax_and_py([], [x], [], must_be_device_array=False) x = Shape_i(1)(pt.as_tensor_variable(x_np)) - x_fg = FunctionGraph([], [x]) - compare_jax_and_py(x_fg, [], must_be_device_array=False) + compare_jax_and_py([], [x], [], must_be_device_array=False) def test_jax_specify_shape(): in_pt = pt.matrix("in") x = pt.specify_shape(in_pt, (4, None)) - x_fg = FunctionGraph([in_pt], [x]) - compare_jax_and_py(x_fg, [np.ones((4, 5)).astype(config.floatX)]) + compare_jax_and_py([in_pt], [x], [np.ones((4, 5)).astype(config.floatX)]) # When used to assert two arrays have similar shapes in_pt = pt.matrix("in") shape_pt = pt.matrix("shape") x = pt.specify_shape(in_pt, shape_pt.shape) - x_fg = FunctionGraph([in_pt, shape_pt], [x]) + compare_jax_and_py( - x_fg, + [in_pt, shape_pt], + [x], [np.ones((4, 5)).astype(config.floatX), np.ones((4, 5)).astype(config.floatX)], ) @@ -43,20 +40,17 @@ def test_jax_specify_shape(): def test_jax_Reshape_constant(): a = vector("a") x = reshape(a, (2, 2)) - x_fg = FunctionGraph([a], [x]) - compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)]) + compare_jax_and_py([a], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)]) def test_jax_Reshape_concrete_shape(): """JAX should compile when a concrete value is passed for the `shape` parameter.""" a = vector("a") x = reshape(a, a.shape) - x_fg = FunctionGraph([a], [x]) - compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)]) + compare_jax_and_py([a], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)]) x = reshape(a, (a.shape[0] // 2, a.shape[0] // 2)) - x_fg = FunctionGraph([a], [x]) - compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)]) + compare_jax_and_py([a], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)]) @pytest.mark.xfail( @@ -66,23 +60,20 @@ def test_jax_Reshape_shape_graph_input(): a = vector("a") shape_pt = iscalar("b") x = reshape(a, (shape_pt, shape_pt)) - x_fg = FunctionGraph([a, shape_pt], [x]) - compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2]) + compare_jax_and_py( + [a, shape_pt], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2] + ) def test_jax_compile_ops(): x = DeepCopyOp()(pt.as_tensor_variable(1.1)) - x_fg = FunctionGraph([], [x]) - - compare_jax_and_py(x_fg, []) + compare_jax_and_py([], [x], []) x_np = np.zeros((20, 1, 1)) x = Unbroadcast(0, 2)(pt.as_tensor_variable(x_np)) - x_fg = FunctionGraph([], [x]) - compare_jax_and_py(x_fg, []) + compare_jax_and_py([], [x], []) x = ViewOp()(pt.as_tensor_variable(x_np)) - x_fg = FunctionGraph([], [x]) - compare_jax_and_py(x_fg, []) + compare_jax_and_py([], [x], []) diff --git a/tests/link/jax/test_slinalg.py b/tests/link/jax/test_slinalg.py index 3320eb9e73..2656b0fd04 100644 --- a/tests/link/jax/test_slinalg.py +++ b/tests/link/jax/test_slinalg.py @@ -6,7 +6,6 @@ import pytensor.tensor as pt from pytensor.configdefaults import config -from pytensor.graph.fg import FunctionGraph from pytensor.tensor import nlinalg as pt_nlinalg from pytensor.tensor import slinalg as pt_slinalg from pytensor.tensor import subtensor as pt_subtensor @@ -30,13 +29,11 @@ def test_jax_basic(): out = pt_subtensor.inc_subtensor(out[0, 1], 2.0) out = out[:5, :3] - out_fg = FunctionGraph([x, y], [out]) - test_input_vals = [ np.tile(np.arange(10), (10, 1)).astype(config.floatX), np.tile(np.arange(10, 20), (10, 1)).astype(config.floatX), ] - _, [jax_res] = compare_jax_and_py(out_fg, test_input_vals) + _, [jax_res] = compare_jax_and_py([x, y], [out], test_input_vals) # Confirm that the `Subtensor` slice operations are correct assert jax_res.shape == (5, 3) @@ -46,19 +43,17 @@ def test_jax_basic(): assert jax_res[0, 1] == -8.0 out = clip(x, y, 5) - out_fg = FunctionGraph([x, y], [out]) - compare_jax_and_py(out_fg, test_input_vals) + compare_jax_and_py([x, y], [out], test_input_vals) out = pt.diagonal(x, 0) - out_fg = FunctionGraph([x], [out]) compare_jax_and_py( - out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)] + [x], [out], [np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)] ) out = pt_slinalg.cholesky(x) - out_fg = FunctionGraph([x], [out]) compare_jax_and_py( - out_fg, + [x], + [out], [ (np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype( config.floatX @@ -68,9 +63,9 @@ def test_jax_basic(): # not sure why this isn't working yet with lower=False out = pt_slinalg.Cholesky(lower=False)(x) - out_fg = FunctionGraph([x], [out]) compare_jax_and_py( - out_fg, + [x], + [out], [ (np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype( config.floatX @@ -79,9 +74,9 @@ def test_jax_basic(): ) out = pt_slinalg.solve(x, b) - out_fg = FunctionGraph([x, b], [out]) compare_jax_and_py( - out_fg, + [x, b], + [out], [ np.eye(10).astype(config.floatX), np.arange(10).astype(config.floatX), @@ -89,19 +84,17 @@ def test_jax_basic(): ) out = pt.diag(b) - out_fg = FunctionGraph([b], [out]) - compare_jax_and_py(out_fg, [np.arange(10).astype(config.floatX)]) + compare_jax_and_py([b], [out], [np.arange(10).astype(config.floatX)]) out = pt_nlinalg.det(x) - out_fg = FunctionGraph([x], [out]) compare_jax_and_py( - out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)] + [x], [out], [np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)] ) out = pt_nlinalg.matrix_inverse(x) - out_fg = FunctionGraph([x], [out]) compare_jax_and_py( - out_fg, + [x], + [out], [ (np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype( config.floatX @@ -124,9 +117,9 @@ def test_jax_SolveTriangular(trans, lower, check_finite): lower=lower, check_finite=check_finite, ) - out_fg = FunctionGraph([x, b], [out]) compare_jax_and_py( - out_fg, + [x, b], + [out], [ np.eye(10).astype(config.floatX), np.arange(10).astype(config.floatX), @@ -141,10 +134,10 @@ def test_jax_block_diag(): D = matrix("D") out = pt_slinalg.block_diag(A, B, C, D) - out_fg = FunctionGraph([A, B, C, D], [out]) compare_jax_and_py( - out_fg, + [A, B, C, D], + [out], [ np.random.normal(size=(5, 5)).astype(config.floatX), np.random.normal(size=(3, 3)).astype(config.floatX), @@ -158,9 +151,10 @@ def test_jax_block_diag_blockwise(): A = pt.tensor3("A") B = pt.tensor3("B") out = pt_slinalg.block_diag(A, B) - out_fg = FunctionGraph([A, B], [out]) + compare_jax_and_py( - out_fg, + [A, B], + [out], [ np.random.normal(size=(5, 5, 5)).astype(config.floatX), np.random.normal(size=(5, 3, 3)).astype(config.floatX), @@ -174,11 +168,11 @@ def test_jax_eigvalsh(lower): B = matrix("B") out = pt_slinalg.eigvalsh(A, B, lower=lower) - out_fg = FunctionGraph([A, B], [out]) with pytest.raises(NotImplementedError): compare_jax_and_py( - out_fg, + [A, B], + [out], [ np.array( [[6, 3, 1, 5], [3, 0, 5, 1], [1, 5, 6, 2], [5, 1, 2, 2]] @@ -189,7 +183,8 @@ def test_jax_eigvalsh(lower): ], ) compare_jax_and_py( - out_fg, + [A, B], + [out], [ np.array([[6, 3, 1, 5], [3, 0, 5, 1], [1, 5, 6, 2], [5, 1, 2, 2]]).astype( config.floatX @@ -207,11 +202,11 @@ def test_jax_solve_discrete_lyapunov( A = pt.tensor(name="A", shape=shape) B = pt.tensor(name="B", shape=shape) out = pt_slinalg.solve_discrete_lyapunov(A, B, method=method) - out_fg = FunctionGraph([A, B], [out]) atol = rtol = 1e-8 if config.floatX == "float64" else 1e-3 compare_jax_and_py( - out_fg, + [A, B], + [out], [ np.random.normal(size=shape).astype(config.floatX), np.random.normal(size=shape).astype(config.floatX), diff --git a/tests/link/jax/test_sort.py b/tests/link/jax/test_sort.py index c0eb4ff06e..5f6362be14 100644 --- a/tests/link/jax/test_sort.py +++ b/tests/link/jax/test_sort.py @@ -1,7 +1,6 @@ import numpy as np import pytest -from pytensor.graph import FunctionGraph from pytensor.tensor import matrix from pytensor.tensor.sort import argsort, sort from tests.link.jax.test_basic import compare_jax_and_py @@ -12,6 +11,5 @@ def test_sort(func, axis): x = matrix("x", shape=(2, 2), dtype="float64") out = func(x, axis=axis) - fgraph = FunctionGraph([x], [out]) arr = np.array([[1.0, 4.0], [5.0, 2.0]]) - compare_jax_and_py(fgraph, [arr]) + compare_jax_and_py([x], [out], [arr]) diff --git a/tests/link/jax/test_sparse.py b/tests/link/jax/test_sparse.py index c53aa301af..f5e4da84c5 100644 --- a/tests/link/jax/test_sparse.py +++ b/tests/link/jax/test_sparse.py @@ -5,7 +5,6 @@ import pytensor.sparse as ps import pytensor.tensor as pt from pytensor import function -from pytensor.graph import FunctionGraph from tests.link.jax.test_basic import compare_jax_and_py @@ -50,8 +49,7 @@ def test_sparse_dot_constant_sparse(x_type, y_type, op): test_values.append(y_test) dot_pt = op(x_pt, y_pt) - fgraph = FunctionGraph(inputs, [dot_pt]) - compare_jax_and_py(fgraph, test_values, jax_mode="JAX") + compare_jax_and_py(inputs, [dot_pt], test_values, jax_mode="JAX") def test_sparse_dot_non_const_raises(): diff --git a/tests/link/jax/test_subtensor.py b/tests/link/jax/test_subtensor.py index 489fbb010e..9e326102cd 100644 --- a/tests/link/jax/test_subtensor.py +++ b/tests/link/jax/test_subtensor.py @@ -21,55 +21,55 @@ def test_jax_Subtensor_constant(): # Basic indices out_pt = x_pt[1, 2, 0] assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) - out_fg = FunctionGraph([x_pt], [out_pt]) - compare_jax_and_py(out_fg, [x_np]) + + compare_jax_and_py([x_pt], [out_pt], [x_np]) out_pt = x_pt[1:, 1, :] assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) - out_fg = FunctionGraph([x_pt], [out_pt]) - compare_jax_and_py(out_fg, [x_np]) + + compare_jax_and_py([x_pt], [out_pt], [x_np]) out_pt = x_pt[:2, 1, :] assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) - out_fg = FunctionGraph([x_pt], [out_pt]) - compare_jax_and_py(out_fg, [x_np]) + + compare_jax_and_py([x_pt], [out_pt], [x_np]) out_pt = x_pt[1:2, 1, :] assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) - out_fg = FunctionGraph([x_pt], [out_pt]) - compare_jax_and_py(out_fg, [x_np]) + + compare_jax_and_py([x_pt], [out_pt], [x_np]) # Advanced indexing out_pt = pt_subtensor.advanced_subtensor1(x_pt, [1, 2]) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor1) - out_fg = FunctionGraph([x_pt], [out_pt]) - compare_jax_and_py(out_fg, [x_np]) + + compare_jax_and_py([x_pt], [out_pt], [x_np]) out_pt = x_pt[[1, 2], [2, 3]] assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor) - out_fg = FunctionGraph([x_pt], [out_pt]) - compare_jax_and_py(out_fg, [x_np]) + + compare_jax_and_py([x_pt], [out_pt], [x_np]) # Advanced and basic indexing out_pt = x_pt[[1, 2], :] assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor) - out_fg = FunctionGraph([x_pt], [out_pt]) - compare_jax_and_py(out_fg, [x_np]) + + compare_jax_and_py([x_pt], [out_pt], [x_np]) out_pt = x_pt[[1, 2], :, [3, 4]] assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor) - out_fg = FunctionGraph([x_pt], [out_pt]) - compare_jax_and_py(out_fg, [x_np]) + + compare_jax_and_py([x_pt], [out_pt], [x_np]) # Flipping out_pt = x_pt[::-1] - out_fg = FunctionGraph([x_pt], [out_pt]) - compare_jax_and_py(out_fg, [x_np]) + + compare_jax_and_py([x_pt], [out_pt], [x_np]) # Boolean indexing should work if indexes are constant out_pt = x_pt[np.random.binomial(1, 0.5, size=(3, 4, 5)).astype(bool)] - out_fg = FunctionGraph([x_pt], [out_pt]) - compare_jax_and_py(out_fg, [x_np]) + + compare_jax_and_py([x_pt], [out_pt], [x_np]) @pytest.mark.xfail(reason="`a` should be specified as static when JIT-compiling") @@ -78,8 +78,8 @@ def test_jax_Subtensor_dynamic(): x = pt.arange(3) out_pt = x[:a] assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) - out_fg = FunctionGraph([a], [out_pt]) - compare_jax_and_py(out_fg, [1]) + + compare_jax_and_py([a], [out_pt], [1]) def test_jax_Subtensor_dynamic_boolean_mask(): @@ -90,11 +90,9 @@ def test_jax_Subtensor_dynamic_boolean_mask(): out_pt = x_pt[x_pt < 0] assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor) - out_fg = FunctionGraph([x_pt], [out_pt]) - x_pt_test = np.arange(-5, 5) with pytest.raises(NonConcreteBooleanIndexError): - compare_jax_and_py(out_fg, [x_pt_test]) + compare_jax_and_py([x_pt], [out_pt], [x_pt_test]) def test_jax_Subtensor_boolean_mask_reexpressible(): @@ -110,8 +108,10 @@ def test_jax_Subtensor_boolean_mask_reexpressible(): """ x_pt = pt.matrix("x") out_pt = x_pt[x_pt < 0].sum() - out_fg = FunctionGraph([x_pt], [out_pt]) - compare_jax_and_py(out_fg, [np.arange(25).reshape(5, 5).astype(config.floatX)]) + + compare_jax_and_py( + [x_pt], [out_pt], [np.arange(25).reshape(5, 5).astype(config.floatX)] + ) def test_boolean_indexing_sum_not_applicable(): @@ -136,19 +136,19 @@ def test_jax_IncSubtensor(): st_pt = pt.as_tensor_variable(np.array(-10.0, dtype=config.floatX)) out_pt = pt_subtensor.set_subtensor(x_pt[1, 2, 3], st_pt) assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor) - out_fg = FunctionGraph([], [out_pt]) - compare_jax_and_py(out_fg, []) + + compare_jax_and_py([], [out_pt], []) st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX)) out_pt = pt_subtensor.set_subtensor(x_pt[:2, 0, 0], st_pt) assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor) - out_fg = FunctionGraph([], [out_pt]) - compare_jax_and_py(out_fg, []) + + compare_jax_and_py([], [out_pt], []) out_pt = pt_subtensor.set_subtensor(x_pt[0, 1:3, 0], st_pt) assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor) - out_fg = FunctionGraph([], [out_pt]) - compare_jax_and_py(out_fg, []) + + compare_jax_and_py([], [out_pt], []) # "Set" advanced indices st_pt = pt.as_tensor_variable( @@ -156,39 +156,39 @@ def test_jax_IncSubtensor(): ) out_pt = pt_subtensor.set_subtensor(x_pt[np.r_[0, 2]], st_pt) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) - out_fg = FunctionGraph([], [out_pt]) - compare_jax_and_py(out_fg, []) + + compare_jax_and_py([], [out_pt], []) st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX)) out_pt = pt_subtensor.set_subtensor(x_pt[[0, 2], 0, 0], st_pt) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) - out_fg = FunctionGraph([], [out_pt]) - compare_jax_and_py(out_fg, []) + + compare_jax_and_py([], [out_pt], []) # "Set" boolean indices mask_pt = pt.constant(x_np > 0) out_pt = pt_subtensor.set_subtensor(x_pt[mask_pt], 0.0) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) - out_fg = FunctionGraph([], [out_pt]) - compare_jax_and_py(out_fg, []) + + compare_jax_and_py([], [out_pt], []) # "Increment" basic indices st_pt = pt.as_tensor_variable(np.array(-10.0, dtype=config.floatX)) out_pt = pt_subtensor.inc_subtensor(x_pt[1, 2, 3], st_pt) assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor) - out_fg = FunctionGraph([], [out_pt]) - compare_jax_and_py(out_fg, []) + + compare_jax_and_py([], [out_pt], []) st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX)) out_pt = pt_subtensor.inc_subtensor(x_pt[:2, 0, 0], st_pt) assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor) - out_fg = FunctionGraph([], [out_pt]) - compare_jax_and_py(out_fg, []) + + compare_jax_and_py([], [out_pt], []) out_pt = pt_subtensor.set_subtensor(x_pt[0, 1:3, 0], st_pt) assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor) - out_fg = FunctionGraph([], [out_pt]) - compare_jax_and_py(out_fg, []) + + compare_jax_and_py([], [out_pt], []) # "Increment" advanced indices st_pt = pt.as_tensor_variable( @@ -196,33 +196,33 @@ def test_jax_IncSubtensor(): ) out_pt = pt_subtensor.inc_subtensor(x_pt[np.r_[0, 2]], st_pt) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) - out_fg = FunctionGraph([], [out_pt]) - compare_jax_and_py(out_fg, []) + + compare_jax_and_py([], [out_pt], []) st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX)) out_pt = pt_subtensor.inc_subtensor(x_pt[[0, 2], 0, 0], st_pt) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) - out_fg = FunctionGraph([], [out_pt]) - compare_jax_and_py(out_fg, []) + + compare_jax_and_py([], [out_pt], []) # "Increment" boolean indices mask_pt = pt.constant(x_np > 0) out_pt = pt_subtensor.set_subtensor(x_pt[mask_pt], 1.0) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) - out_fg = FunctionGraph([], [out_pt]) - compare_jax_and_py(out_fg, []) + + compare_jax_and_py([], [out_pt], []) st_pt = pt.as_tensor_variable(x_np[[0, 2], 0, :3]) out_pt = pt_subtensor.set_subtensor(x_pt[[0, 2], 0, :3], st_pt) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) - out_fg = FunctionGraph([], [out_pt]) - compare_jax_and_py(out_fg, []) + + compare_jax_and_py([], [out_pt], []) st_pt = pt.as_tensor_variable(x_np[[0, 2], 0, :3]) out_pt = pt_subtensor.inc_subtensor(x_pt[[0, 2], 0, :3], st_pt) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) - out_fg = FunctionGraph([], [out_pt]) - compare_jax_and_py(out_fg, []) + + compare_jax_and_py([], [out_pt], []) def test_jax_IncSubtensor_boolean_indexing_reexpressible(): @@ -243,14 +243,14 @@ def test_jax_IncSubtensor_boolean_indexing_reexpressible(): mask_pt = pt.as_tensor(x_pt) > 0 out_pt = pt_subtensor.set_subtensor(x_pt[mask_pt], 0.0) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) - out_fg = FunctionGraph([x_pt], [out_pt]) - compare_jax_and_py(out_fg, [x_np]) + + compare_jax_and_py([x_pt], [out_pt], [x_np]) mask_pt = pt.as_tensor(x_pt) > 0 out_pt = pt_subtensor.inc_subtensor(x_pt[mask_pt], 1.0) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) - out_fg = FunctionGraph([x_pt], [out_pt]) - compare_jax_and_py(out_fg, [x_np]) + + compare_jax_and_py([x_pt], [out_pt], [x_np]) def test_boolean_indexing_set_or_inc_not_applicable(): diff --git a/tests/link/jax/test_tensor_basic.py b/tests/link/jax/test_tensor_basic.py index 75ca673d78..46f7fd7375 100644 --- a/tests/link/jax/test_tensor_basic.py +++ b/tests/link/jax/test_tensor_basic.py @@ -10,8 +10,6 @@ import pytensor import pytensor.tensor.basic as ptb from pytensor.configdefaults import config -from pytensor.graph.fg import FunctionGraph -from pytensor.graph.op import get_test_value from pytensor.tensor.type import iscalar, matrix, scalar, vector from tests.link.jax.test_basic import compare_jax_and_py from tests.tensor.test_basic import check_alloc_runtime_broadcast @@ -19,38 +17,31 @@ def test_jax_Alloc(): x = ptb.alloc(0.0, 2, 3) - x_fg = FunctionGraph([], [x]) - _, [jax_res] = compare_jax_and_py(x_fg, []) + _, [jax_res] = compare_jax_and_py([], [x], []) assert jax_res.shape == (2, 3) x = ptb.alloc(1.1, 2, 3) - x_fg = FunctionGraph([], [x]) - compare_jax_and_py(x_fg, []) + compare_jax_and_py([], [x], []) x = ptb.AllocEmpty("float32")(2, 3) - x_fg = FunctionGraph([], [x]) def compare_shape_dtype(x, y): - (x,) = x - (y,) = y - return x.shape == y.shape and x.dtype == y.dtype + np.testing.assert_array_equal(x, y, strict=True) - compare_jax_and_py(x_fg, [], assert_fn=compare_shape_dtype) + compare_jax_and_py([], [x], [], assert_fn=compare_shape_dtype) a = scalar("a") x = ptb.alloc(a, 20) - x_fg = FunctionGraph([a], [x]) - compare_jax_and_py(x_fg, [10.0]) + compare_jax_and_py([a], [x], [10.0]) a = vector("a") x = ptb.alloc(a, 20, 10) - x_fg = FunctionGraph([a], [x]) - compare_jax_and_py(x_fg, [np.ones(10, dtype=config.floatX)]) + compare_jax_and_py([a], [x], [np.ones(10, dtype=config.floatX)]) def test_alloc_runtime_broadcast(): @@ -59,34 +50,31 @@ def test_alloc_runtime_broadcast(): def test_jax_MakeVector(): x = ptb.make_vector(1, 2, 3) - x_fg = FunctionGraph([], [x]) - compare_jax_and_py(x_fg, []) + compare_jax_and_py([], [x], []) def test_arange(): out = ptb.arange(1, 10, 2) - fgraph = FunctionGraph([], [out]) - compare_jax_and_py(fgraph, []) + + compare_jax_and_py([], [out], []) def test_arange_of_shape(): x = vector("x") out = ptb.arange(1, x.shape[-1], 2) - fgraph = FunctionGraph([x], [out]) - compare_jax_and_py(fgraph, [np.zeros((5,))], jax_mode="JAX") + compare_jax_and_py([x], [out], [np.zeros((5,))], jax_mode="JAX") def test_arange_nonconcrete(): """JAX cannot JIT-compile `jax.numpy.arange` when arguments are not concrete values.""" a = scalar("a") - a.tag.test_value = 10 + a_test_value = 10 out = ptb.arange(a) with pytest.raises(NotImplementedError): - fgraph = FunctionGraph([a], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + compare_jax_and_py([a], [out], [a_test_value]) def test_jax_Join(): @@ -94,16 +82,17 @@ def test_jax_Join(): b = matrix("b") x = ptb.join(0, a, b) - x_fg = FunctionGraph([a, b], [x]) compare_jax_and_py( - x_fg, + [a, b], + [x], [ np.c_[[1.0, 2.0, 3.0]].astype(config.floatX), np.c_[[4.0, 5.0, 6.0]].astype(config.floatX), ], ) compare_jax_and_py( - x_fg, + [a, b], + [x], [ np.c_[[1.0, 2.0, 3.0]].astype(config.floatX), np.c_[[4.0, 5.0]].astype(config.floatX), @@ -111,16 +100,17 @@ def test_jax_Join(): ) x = ptb.join(1, a, b) - x_fg = FunctionGraph([a, b], [x]) compare_jax_and_py( - x_fg, + [a, b], + [x], [ np.c_[[1.0, 2.0, 3.0]].astype(config.floatX), np.c_[[4.0, 5.0, 6.0]].astype(config.floatX), ], ) compare_jax_and_py( - x_fg, + [a, b], + [x], [ np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX), np.c_[[5.0, 6.0]].astype(config.floatX), @@ -132,9 +122,9 @@ class TestJaxSplit: def test_basic(self): a = matrix("a") a_splits = ptb.split(a, splits_size=[1, 2, 3], n_splits=3, axis=0) - fg = FunctionGraph([a], a_splits) compare_jax_and_py( - fg, + [a], + a_splits, [ np.zeros((6, 4)).astype(config.floatX), ], @@ -142,9 +132,9 @@ def test_basic(self): a = matrix("a", shape=(6, None)) a_splits = ptb.split(a, splits_size=[2, a.shape[0] - 2], n_splits=2, axis=0) - fg = FunctionGraph([a], a_splits) compare_jax_and_py( - fg, + [a], + a_splits, [ np.zeros((6, 4)).astype(config.floatX), ], @@ -207,15 +197,14 @@ def test_jax_split_not_supported(self): def test_jax_eye(): """Tests jaxification of the Eye operator""" out = ptb.eye(3) - out_fg = FunctionGraph([], [out]) - compare_jax_and_py(out_fg, []) + compare_jax_and_py([], [out], []) def test_tri(): out = ptb.tri(10, 10, 0) - fgraph = FunctionGraph([], [out]) - compare_jax_and_py(fgraph, []) + + compare_jax_and_py([], [out], []) @pytest.mark.skipif( @@ -230,14 +219,13 @@ def test_tri_nonconcrete(): scalar("n", dtype="int64"), scalar("k", dtype="int64"), ) - m.tag.test_value = 10 - n.tag.test_value = 10 - k.tag.test_value = 0 + m_test_value = 10 + n_test_value = 10 + k_test_value = 0 out = ptb.tri(m, n, k) # The actual error the user will see should be jax.errors.ConcretizationTypeError, but # the error handler raises an Attribute error first, so that's what this test needs to pass with pytest.raises(AttributeError): - fgraph = FunctionGraph([m, n, k], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + compare_jax_and_py([m, n, k], [out], [m_test_value, n_test_value, k_test_value]) diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index f0f73ca74d..4857d2f932 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -1,6 +1,6 @@ import contextlib import inspect -from collections.abc import Callable, Sequence +from collections.abc import Callable, Iterable from typing import TYPE_CHECKING, Any from unittest import mock @@ -21,10 +21,8 @@ from pytensor.compile.function import function from pytensor.compile.mode import Mode from pytensor.compile.ops import ViewOp -from pytensor.compile.sharedvalue import SharedVariable -from pytensor.graph.basic import Apply, Constant -from pytensor.graph.fg import FunctionGraph -from pytensor.graph.op import Op, get_test_value +from pytensor.graph.basic import Apply, Variable +from pytensor.graph.op import Op from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.graph.type import Type from pytensor.ifelse import ifelse @@ -39,7 +37,6 @@ if TYPE_CHECKING: from pytensor.graph.basic import Variable - from pytensor.tensor import TensorLike class MyType(Type): @@ -128,11 +125,6 @@ def perform(self, node, inputs, outputs): rng = np.random.default_rng(42849) -def set_test_value(x, v): - x.tag.test_value = v - return x - - def compare_shape_dtype(x, y): return x.shape == y.shape and x.dtype == y.dtype @@ -225,28 +217,30 @@ def py_global_numba_func(func): def compare_numba_and_py( - fgraph: FunctionGraph | tuple[Sequence["Variable"], Sequence["Variable"]], - inputs: Sequence["TensorLike"], - assert_fn: Callable | None = None, + graph_inputs: Iterable[Variable], + graph_outputs: Variable | Iterable[Variable], + test_inputs: Iterable, *, + assert_fn: Callable | None = None, numba_mode=numba_mode, py_mode=py_mode, updates=None, inplace: bool = False, eval_obj_mode: bool = True, ) -> tuple[Callable, Any]: - """Function to compare python graph output and Numba compiled output for testing equality + """Function to compare python function output and Numba compiled output for testing equality - In the tests below computational graphs are defined in PyTensor. These graphs are then passed to - this function which then compiles the graphs in both Numba and python, runs the calculation - in both and checks if the results are the same + The inputs and outputs are then passed to this function which then compiles the given function in both + numba and python, runs the calculation in both and checks if the results are the same Parameters ---------- - fgraph - `FunctionGraph` or tuple(inputs, outputs) to compare. - inputs - Numeric inputs to be passed to the compiled graphs. + graph_inputs: + Symbolic inputs to the graph + graph_outputs: + Symbolic outputs of the graph + test_inputs + Numerical inputs with which to evaluate the graph. assert_fn Assert function used to check for equality between python and Numba. If not provided uses `np.testing.assert_allclose`. @@ -267,42 +261,38 @@ def assert_fn(x, y): x, y ) - if isinstance(fgraph, FunctionGraph): - fn_inputs = fgraph.inputs - fn_outputs = fgraph.outputs - else: - fn_inputs, fn_outputs = fgraph - - fn_inputs = [i for i in fn_inputs if not isinstance(i, SharedVariable)] + if any(inp.owner is not None for inp in graph_inputs): + raise ValueError("Inputs must be root variables") pytensor_py_fn = function( - fn_inputs, fn_outputs, mode=py_mode, accept_inplace=True, updates=updates + graph_inputs, graph_outputs, mode=py_mode, accept_inplace=True, updates=updates ) - test_inputs = (inp.copy() for inp in inputs) if inplace else inputs - py_res = pytensor_py_fn(*test_inputs) + test_inputs_copy = (inp.copy() for inp in test_inputs) if inplace else test_inputs + py_res = pytensor_py_fn(*test_inputs_copy) # Get some coverage (and catch errors in python mode before unreadable numba ones) if eval_obj_mode: - test_inputs = (inp.copy() for inp in inputs) if inplace else inputs - eval_python_only(fn_inputs, fn_outputs, test_inputs, mode=numba_mode) + test_inputs_copy = ( + (inp.copy() for inp in test_inputs) if inplace else test_inputs + ) + eval_python_only(graph_inputs, graph_outputs, test_inputs_copy, mode=numba_mode) pytensor_numba_fn = function( - fn_inputs, - fn_outputs, + graph_inputs, + graph_outputs, mode=numba_mode, accept_inplace=True, updates=updates, ) + test_inputs_copy = (inp.copy() for inp in test_inputs) if inplace else test_inputs + numba_res = pytensor_numba_fn(*test_inputs_copy) - test_inputs = (inp.copy() for inp in inputs) if inplace else inputs - numba_res = pytensor_numba_fn(*test_inputs) - - if len(fn_outputs) > 1: + if isinstance(graph_outputs, tuple | list): for j, p in zip(numba_res, py_res, strict=True): assert_fn(j, p) else: - assert_fn(numba_res[0], py_res[0]) + assert_fn(numba_res, py_res) return pytensor_numba_fn, numba_res @@ -380,53 +370,53 @@ def test_create_numba_signature(v, expected, force_scalar): ) def test_Shape(x, i): g = Shape()(pt.as_tensor_variable(x)) - g_fg = FunctionGraph([], [g]) - compare_numba_and_py(g_fg, []) + compare_numba_and_py([], [g], []) g = Shape_i(i)(pt.as_tensor_variable(x)) - g_fg = FunctionGraph([], [g]) - compare_numba_and_py(g_fg, []) + compare_numba_and_py([], [g], []) @pytest.mark.parametrize( "v, shape, ndim", [ - (set_test_value(pt.vector(), np.array([4], dtype=config.floatX)), (), 0), - (set_test_value(pt.vector(), np.arange(4, dtype=config.floatX)), (2, 2), 2), + ((pt.vector(), np.array([4], dtype=config.floatX)), ((), None), 0), + ((pt.vector(), np.arange(4, dtype=config.floatX)), ((2, 2), None), 2), ( - set_test_value(pt.vector(), np.arange(4, dtype=config.floatX)), - set_test_value(pt.lvector(), np.array([2, 2], dtype="int64")), + (pt.vector(), np.arange(4, dtype=config.floatX)), + (pt.lvector(), np.array([2, 2], dtype="int64")), 2, ), ], ) def test_Reshape(v, shape, ndim): + v, v_test_value = v + shape, shape_test_value = shape + g = Reshape(ndim)(v, shape) - g_fg = FunctionGraph(outputs=[g]) + inputs = [v] if not isinstance(shape, Variable) else [v, shape] + test_values = ( + [v_test_value] + if not isinstance(shape, Variable) + else [v_test_value, shape_test_value] + ) compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + inputs, + [g], + test_values, ) def test_Reshape_scalar(): v = pt.vector() - v.tag.test_value = np.array([1.0], dtype=config.floatX) + v_test_value = np.array([1.0], dtype=config.floatX) g = Reshape(1)(v[0], (1,)) - g_fg = FunctionGraph(outputs=[g]) + compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [v], + g, + [v_test_value], ) @@ -434,53 +424,44 @@ def test_Reshape_scalar(): "v, shape, fails", [ ( - set_test_value(pt.matrix(), np.array([[1.0]], dtype=config.floatX)), + (pt.matrix(), np.array([[1.0]], dtype=config.floatX)), (1, 1), False, ), ( - set_test_value(pt.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)), + (pt.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)), (1, 1), True, ), ( - set_test_value(pt.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)), + (pt.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)), (1, None), False, ), ], ) def test_SpecifyShape(v, shape, fails): + v, v_test_value = v g = SpecifyShape()(v, *shape) - g_fg = FunctionGraph(outputs=[g]) cm = contextlib.suppress() if not fails else pytest.raises(AssertionError) + with cm: compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [v], + [g], + [v_test_value], ) -@pytest.mark.parametrize( - "v", - [ - set_test_value(pt.vector(), np.arange(4, dtype=config.floatX)), - ], -) -def test_ViewOp(v): +def test_ViewOp(): + v = pt.vector() + v_test_value = np.arange(4, dtype=config.floatX) g = ViewOp()(v) - g_fg = FunctionGraph(outputs=[g]) + compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [v], + [g], + [v_test_value], ) @@ -489,20 +470,16 @@ def test_ViewOp(v): [ ( [ - set_test_value( - pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX) - ), - set_test_value(pt.lmatrix(), rng.poisson(size=(2, 3))), + (pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)), + (pt.lmatrix(), rng.poisson(size=(2, 3))), ], MySingleOut, UserWarning, ), ( [ - set_test_value( - pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX) - ), - set_test_value(pt.lmatrix(), rng.poisson(size=(2, 3))), + (pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)), + (pt.lmatrix(), rng.poisson(size=(2, 3))), ], MyMultiOut, UserWarning, @@ -510,38 +487,32 @@ def test_ViewOp(v): ], ) def test_perform(inputs, op, exc): + inputs, test_values = zip(*inputs, strict=True) g = op()(*inputs) if isinstance(g, list): - g_fg = FunctionGraph(outputs=g) + outputs = g else: - g_fg = FunctionGraph(outputs=[g]) + outputs = [g] cm = contextlib.suppress() if exc is None else pytest.warns(exc) with cm: compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + inputs, + outputs, + test_values, ) def test_perform_params(): """This tests for `Op.perform` implementations that require the `params` arguments.""" - x = pt.vector() - x.tag.test_value = np.array([1.0, 2.0], dtype=config.floatX) + x = pt.vector(shape=(2,)) + x_test_value = np.array([1.0, 2.0], dtype=config.floatX) out = assert_op(x, np.array(True)) - if not isinstance(out, list | tuple): - out = [out] - - out_fg = FunctionGraph([x], out) - compare_numba_and_py(out_fg, [get_test_value(i) for i in out_fg.inputs]) + compare_numba_and_py([x], out, [x_test_value]) def test_perform_type_convert(): @@ -552,59 +523,50 @@ def test_perform_type_convert(): """ x = pt.vector() - x.tag.test_value = np.array([1.0, 2.0], dtype=config.floatX) + x_test_value = np.array([1.0, 2.0], dtype=config.floatX) out = assert_op(x.sum(), np.array(True)) - if not isinstance(out, list | tuple): - out = [out] - - out_fg = FunctionGraph([x], out) - compare_numba_and_py(out_fg, [get_test_value(i) for i in out_fg.inputs]) + compare_numba_and_py([x], out, [x_test_value]) @pytest.mark.parametrize( "x, y, exc", [ ( - set_test_value(pt.matrix(), rng.random(size=(3, 2)).astype(config.floatX)), - set_test_value(pt.vector(), rng.random(size=(2,)).astype(config.floatX)), + (pt.matrix(), rng.random(size=(3, 2)).astype(config.floatX)), + (pt.vector(), rng.random(size=(2,)).astype(config.floatX)), None, ), ( - set_test_value( - pt.matrix(dtype="float64"), rng.random(size=(3, 2)).astype("float64") - ), - set_test_value( - pt.vector(dtype="float32"), rng.random(size=(2,)).astype("float32") - ), + (pt.matrix(dtype="float64"), rng.random(size=(3, 2)).astype("float64")), + (pt.vector(dtype="float32"), rng.random(size=(2,)).astype("float32")), None, ), ( - set_test_value(pt.lmatrix(), rng.poisson(size=(3, 2))), - set_test_value(pt.fvector(), rng.random(size=(2,)).astype("float32")), + (pt.lmatrix(), rng.poisson(size=(3, 2))), + (pt.fvector(), rng.random(size=(2,)).astype("float32")), None, ), ( - set_test_value(pt.lvector(), rng.random(size=(2,)).astype(np.int64)), - set_test_value(pt.lvector(), rng.random(size=(2,)).astype(np.int64)), + (pt.lvector(), rng.random(size=(2,)).astype(np.int64)), + (pt.lvector(), rng.random(size=(2,)).astype(np.int64)), None, ), ], ) def test_Dot(x, y, exc): + x, x_test_value = x + y, y_test_value = y + g = ptm.Dot()(x, y) - g_fg = FunctionGraph(outputs=[g]) cm = contextlib.suppress() if exc is None else pytest.warns(exc) with cm: compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [x, y], + [g], + [x_test_value, y_test_value], ) @@ -612,44 +574,41 @@ def test_Dot(x, y, exc): "x, exc", [ ( - set_test_value(ps.float64(), np.array(0.0, dtype="float64")), + (ps.float64(), np.array(0.0, dtype="float64")), None, ), ( - set_test_value(ps.float64(), np.array(-32.0, dtype="float64")), + (ps.float64(), np.array(-32.0, dtype="float64")), None, ), ( - set_test_value(ps.float64(), np.array(-40.0, dtype="float64")), + (ps.float64(), np.array(-40.0, dtype="float64")), None, ), ( - set_test_value(ps.float64(), np.array(32.0, dtype="float64")), + (ps.float64(), np.array(32.0, dtype="float64")), None, ), ( - set_test_value(ps.float64(), np.array(40.0, dtype="float64")), + (ps.float64(), np.array(40.0, dtype="float64")), None, ), ( - set_test_value(ps.int64(), np.array(32, dtype="int64")), + (ps.int64(), np.array(32, dtype="int64")), None, ), ], ) def test_Softplus(x, exc): + x, x_test_value = x g = psm.Softplus(ps.upgrade_to_float)(x) - g_fg = FunctionGraph(outputs=[g]) cm = contextlib.suppress() if exc is None else pytest.warns(exc) with cm: compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [x], + [g], + [x_test_value], ) @@ -657,22 +616,22 @@ def test_Softplus(x, exc): "x, y, exc", [ ( - set_test_value( + ( pt.dtensor3(), rng.random(size=(2, 3, 3)).astype("float64"), ), - set_test_value( + ( pt.dtensor3(), rng.random(size=(2, 3, 3)).astype("float64"), ), None, ), ( - set_test_value( + ( pt.dtensor3(), rng.random(size=(2, 3, 3)).astype("float64"), ), - set_test_value( + ( pt.ltensor3(), rng.poisson(size=(2, 3, 3)).astype("int64"), ), @@ -681,22 +640,17 @@ def test_Softplus(x, exc): ], ) def test_BatchedDot(x, y, exc): - g = blas.BatchedDot()(x, y) + x, x_test_value = x + y, y_test_value = y - if isinstance(g, list): - g_fg = FunctionGraph(outputs=g) - else: - g_fg = FunctionGraph(outputs=[g]) + g = blas.BatchedDot()(x, y) cm = contextlib.suppress() if exc is None else pytest.warns(exc) with cm: compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [x, y], + g, + [x_test_value, y_test_value], ) @@ -767,15 +721,15 @@ def test_shared_updates(): [ ([], lambda: np.array(True), np.r_[1, 2, 3], np.r_[-1, -2, -3]), ( - [set_test_value(pt.dscalar(), np.array(0.2, dtype=np.float64))], + [(pt.dscalar(), np.array(0.2, dtype=np.float64))], lambda x: x < 0.5, np.r_[1, 2, 3], np.r_[-1, -2, -3], ), ( [ - set_test_value(pt.dscalar(), np.array(0.3, dtype=np.float64)), - set_test_value(pt.dscalar(), np.array(0.5, dtype=np.float64)), + (pt.dscalar(), np.array(0.3, dtype=np.float64)), + (pt.dscalar(), np.array(0.5, dtype=np.float64)), ], lambda x, y: x > y, x, @@ -783,8 +737,8 @@ def test_shared_updates(): ), ( [ - set_test_value(pt.dvector(), np.array([0.3, 0.1], dtype=np.float64)), - set_test_value(pt.dvector(), np.array([0.5, 0.9], dtype=np.float64)), + (pt.dvector(), np.array([0.3, 0.1], dtype=np.float64)), + (pt.dvector(), np.array([0.5, 0.9], dtype=np.float64)), ], lambda x, y: pt.all(x > y), x, @@ -792,8 +746,8 @@ def test_shared_updates(): ), ( [ - set_test_value(pt.dvector(), np.array([0.3, 0.1], dtype=np.float64)), - set_test_value(pt.dvector(), np.array([0.5, 0.9], dtype=np.float64)), + (pt.dvector(), np.array([0.3, 0.1], dtype=np.float64)), + (pt.dvector(), np.array([0.5, 0.9], dtype=np.float64)), ], lambda x, y: pt.all(x > y), [x, 2 * x], @@ -801,8 +755,8 @@ def test_shared_updates(): ), ( [ - set_test_value(pt.dvector(), np.array([0.5, 0.9], dtype=np.float64)), - set_test_value(pt.dvector(), np.array([0.3, 0.1], dtype=np.float64)), + (pt.dvector(), np.array([0.5, 0.9], dtype=np.float64)), + (pt.dvector(), np.array([0.3, 0.1], dtype=np.float64)), ], lambda x, y: pt.all(x > y), [x, 2 * x], @@ -811,14 +765,9 @@ def test_shared_updates(): ], ) def test_IfElse(inputs, cond_fn, true_vals, false_vals): + inputs, test_values = zip(*inputs, strict=True) if inputs else ([], []) out = ifelse(cond_fn(*inputs), true_vals, false_vals) - - if not isinstance(out, list): - out = [out] - - out_fg = FunctionGraph(inputs, out) - - compare_numba_and_py(out_fg, [get_test_value(i) for i in out_fg.inputs]) + compare_numba_and_py(inputs, out, test_values) @pytest.mark.xfail(reason="https://github.com/numba/numba/issues/7409") @@ -883,7 +832,7 @@ def test_OpFromGraph(): yv = np.ones((2, 2), dtype=config.floatX) * 3 zv = np.ones((2, 2), dtype=config.floatX) * 5 - compare_numba_and_py(((x, y, z), (out,)), [xv, yv, zv]) + compare_numba_and_py([x, y, z], [out], [xv, yv, zv]) @pytest.mark.filterwarnings("error") diff --git a/tests/link/numba/test_blockwise.py b/tests/link/numba/test_blockwise.py index ced4185e14..43056f9f56 100644 --- a/tests/link/numba/test_blockwise.py +++ b/tests/link/numba/test_blockwise.py @@ -27,7 +27,8 @@ def test_blockwise(core_op, shape_opt): ) x_test = np.eye(3) * np.arange(1, 6)[:, None, None] compare_numba_and_py( - ([x], outs), + [x], + outs, [x_test], numba_mode=mode, eval_obj_mode=False, diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py index b2ccc1ef1e..eaa0fa951d 100644 --- a/tests/link/numba/test_elemwise.py +++ b/tests/link/numba/test_elemwise.py @@ -11,10 +11,7 @@ from pytensor import config, function from pytensor.compile import get_mode from pytensor.compile.ops import deep_copy_op -from pytensor.compile.sharedvalue import SharedVariable from pytensor.gradient import grad -from pytensor.graph.basic import Constant -from pytensor.graph.fg import FunctionGraph from pytensor.scalar import float64 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum @@ -22,7 +19,6 @@ from tests.link.numba.test_basic import ( compare_numba_and_py, scalar_my_multi_out, - set_test_value, ) from tests.tensor.test_elemwise import ( careduce_benchmark_tester, @@ -116,13 +112,13 @@ def test_Elemwise(inputs, input_vals, output_fn, exc): outputs = output_fn(*inputs) - out_fg = FunctionGraph( - outputs=[outputs] if not isinstance(outputs, list) else outputs - ) - cm = contextlib.suppress() if exc is None else pytest.raises(exc) with cm: - compare_numba_and_py(out_fg, input_vals) + compare_numba_and_py( + inputs, + outputs, + input_vals, + ) @pytest.mark.xfail(reason="Logic had to be reversed due to surprising segfaults") @@ -135,7 +131,7 @@ def test_elemwise_runtime_broadcast(): [ # `{'drop': [], 'shuffle': [], 'augment': [0, 1]}` ( - set_test_value( + ( pt.lscalar(name="a"), np.array(1, dtype=np.int64), ), @@ -144,21 +140,17 @@ def test_elemwise_runtime_broadcast(): # I.e. `a_pt.T` # `{'drop': [], 'shuffle': [1, 0], 'augment': []}` ( - set_test_value( - pt.matrix("a"), np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) - ), + (pt.matrix("a"), np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)), (1, 0), ), # `{'drop': [], 'shuffle': [0, 1], 'augment': [2]}` ( - set_test_value( - pt.matrix("a"), np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) - ), + (pt.matrix("a"), np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)), (1, 0, "x"), ), # `{'drop': [1], 'shuffle': [2, 0], 'augment': [0, 2, 4]}` ( - set_test_value( + ( pt.tensor(dtype=config.floatX, shape=(None, 1, None), name="a"), np.array([[[1.0, 2.0]], [[3.0, 4.0]]], dtype=config.floatX), ), @@ -167,21 +159,21 @@ def test_elemwise_runtime_broadcast(): # I.e. `a_pt.dimshuffle((0,))` # `{'drop': [1], 'shuffle': [0], 'augment': []}` ( - set_test_value( + ( pt.tensor(dtype=config.floatX, shape=(None, 1), name="a"), np.array([[1.0], [2.0], [3.0], [4.0]], dtype=config.floatX), ), (0,), ), ( - set_test_value( + ( pt.tensor(dtype=config.floatX, shape=(None, 1), name="a"), np.array([[1.0], [2.0], [3.0], [4.0]], dtype=config.floatX), ), (0,), ), ( - set_test_value( + ( pt.tensor(dtype=config.floatX, shape=(1, 1, 1), name="a"), np.array([[[1.0]]], dtype=config.floatX), ), @@ -190,15 +182,12 @@ def test_elemwise_runtime_broadcast(): ], ) def test_Dimshuffle(v, new_order): + v, v_test_value = v g = v.dimshuffle(new_order) - g_fg = FunctionGraph(outputs=[g]) compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [v], + [g], + [v_test_value], ) @@ -229,79 +218,68 @@ def test_Dimshuffle_non_contiguous(): axis=axis, dtype=dtype, acc_dtype=acc_dtype )(x), 0, - set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)), + (pt.vector(), np.arange(3, dtype=config.floatX)), ), ( lambda x, axis=None, dtype=None, acc_dtype=None: All(axis)(x), 0, - set_test_value(pt.vector(dtype="bool"), np.array([False, True, False])), + (pt.vector(dtype="bool"), np.array([False, True, False])), ), ( lambda x, axis=None, dtype=None, acc_dtype=None: Any(axis)(x), 0, - set_test_value(pt.vector(dtype="bool"), np.array([False, True, False])), + (pt.vector(dtype="bool"), np.array([False, True, False])), ), ( lambda x, axis=None, dtype=None, acc_dtype=None: Sum( axis=axis, dtype=dtype, acc_dtype=acc_dtype )(x), 0, - set_test_value( - pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) - ), + (pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))), ), ( lambda x, axis=None, dtype=None, acc_dtype=None: Sum( axis=axis, dtype=dtype, acc_dtype=acc_dtype )(x), (0, 1), - set_test_value( - pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) - ), + (pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))), ), ( lambda x, axis=None, dtype=None, acc_dtype=None: Sum( axis=axis, dtype=dtype, acc_dtype=acc_dtype )(x), (1, 0), - set_test_value( - pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) - ), + (pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))), ), ( lambda x, axis=None, dtype=None, acc_dtype=None: Sum( axis=axis, dtype=dtype, acc_dtype=acc_dtype )(x), None, - set_test_value( - pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) - ), + (pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))), ), ( lambda x, axis=None, dtype=None, acc_dtype=None: Sum( axis=axis, dtype=dtype, acc_dtype=acc_dtype )(x), 1, - set_test_value( - pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) - ), + (pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))), ), ( lambda x, axis=None, dtype=None, acc_dtype=None: Prod( axis=axis, dtype=dtype, acc_dtype=acc_dtype )(x), (), # Empty axes would normally be rewritten away, but we want to test it still works - set_test_value( - pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) - ), + (pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))), ), ( lambda x, axis=None, dtype=None, acc_dtype=None: Prod( axis=axis, dtype=dtype, acc_dtype=acc_dtype )(x), None, - set_test_value( - pt.scalar(), np.array(99.0, dtype=config.floatX) + ( + pt.scalar(), + np.array(99.0, dtype=config.floatX), ), # Scalar input would normally be rewritten away, but we want to test it still works ), ( @@ -309,77 +287,62 @@ def test_Dimshuffle_non_contiguous(): axis=axis, dtype=dtype, acc_dtype=acc_dtype )(x), 0, - set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)), + (pt.vector(), np.arange(3, dtype=config.floatX)), ), ( lambda x, axis=None, dtype=None, acc_dtype=None: ProdWithoutZeros( axis=axis, dtype=dtype, acc_dtype=acc_dtype )(x), 0, - set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)), + (pt.vector(), np.arange(3, dtype=config.floatX)), ), ( lambda x, axis=None, dtype=None, acc_dtype=None: Prod( axis=axis, dtype=dtype, acc_dtype=acc_dtype )(x), 0, - set_test_value( - pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) - ), + (pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))), ), ( lambda x, axis=None, dtype=None, acc_dtype=None: Prod( axis=axis, dtype=dtype, acc_dtype=acc_dtype )(x), 1, - set_test_value( - pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) - ), + (pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))), ), ( lambda x, axis=None, dtype=None, acc_dtype=None: Max(axis)(x), None, - set_test_value( - pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) - ), + (pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))), ), ( lambda x, axis=None, dtype=None, acc_dtype=None: Max(axis)(x), None, - set_test_value( - pt.lmatrix(), np.arange(3 * 2, dtype=np.int64).reshape((3, 2)) - ), + (pt.lmatrix(), np.arange(3 * 2, dtype=np.int64).reshape((3, 2))), ), ( lambda x, axis=None, dtype=None, acc_dtype=None: Min(axis)(x), None, - set_test_value( - pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) - ), + (pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))), ), ( lambda x, axis=None, dtype=None, acc_dtype=None: Min(axis)(x), None, - set_test_value( - pt.lmatrix(), np.arange(3 * 2, dtype=np.int64).reshape((3, 2)) - ), + (pt.lmatrix(), np.arange(3 * 2, dtype=np.int64).reshape((3, 2))), ), ], ) def test_CAReduce(careduce_fn, axis, v): + v, v_test_value = v g = careduce_fn(v, axis=axis) - g_fg = FunctionGraph(outputs=[g]) fn, _ = compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [v], + [g], + [v_test_value], ) # Confirm CAReduce is in the compiled function - fn.dprint() + # fn.dprint() [node] = fn.maker.fgraph.apply_nodes assert isinstance(node.op, CAReduce) @@ -387,102 +350,91 @@ def test_CAReduce(careduce_fn, axis, v): def test_scalar_Elemwise_Clip(): a = pt.scalar("a") b = pt.scalar("b") + inputs = [a, b] z = pt.switch(1, a, b) c = pt.clip(z, 1, 3) - c_fg = FunctionGraph(outputs=[c]) - compare_numba_and_py(c_fg, [1, 1]) + compare_numba_and_py(inputs, [c], [1, 1]) @pytest.mark.parametrize( "dy, sm, axis, exc", [ ( - set_test_value( - pt.matrix(), np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX) - ), - set_test_value(pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)), + (pt.matrix(), np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)), + (pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)), None, None, ), ( - set_test_value( - pt.matrix(), np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX) - ), - set_test_value(pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)), + (pt.matrix(), np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)), + (pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)), 0, None, ), ( - set_test_value( - pt.matrix(), np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX) - ), - set_test_value(pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)), + (pt.matrix(), np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)), + (pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)), 1, None, ), ], ) def test_SoftmaxGrad(dy, sm, axis, exc): + dy, dy_test_value = dy + sm, sm_test_value = sm g = SoftmaxGrad(axis=axis)(dy, sm) - g_fg = FunctionGraph(outputs=[g]) cm = contextlib.suppress() if exc is None else pytest.warns(exc) with cm: compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [dy, sm], + [g], + [dy_test_value, sm_test_value], ) def test_SoftMaxGrad_constant_dy(): dy = pt.constant(np.zeros((3,), dtype=config.floatX)) sm = pt.vector(shape=(3,)) + inputs = [sm] g = SoftmaxGrad(axis=None)(dy, sm) - g_fg = FunctionGraph(outputs=[g]) - compare_numba_and_py(g_fg, [np.ones((3,), dtype=config.floatX)]) + compare_numba_and_py(inputs, [g], [np.ones((3,), dtype=config.floatX)]) @pytest.mark.parametrize( "x, axis, exc", [ ( - set_test_value(pt.vector(), rng.random(size=(2,)).astype(config.floatX)), + (pt.vector(), rng.random(size=(2,)).astype(config.floatX)), None, None, ), ( - set_test_value(pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)), + (pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)), None, None, ), ( - set_test_value(pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)), + (pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)), 0, None, ), ], ) def test_Softmax(x, axis, exc): + x, x_test_value = x g = Softmax(axis=axis)(x) - g_fg = FunctionGraph(outputs=[g]) cm = contextlib.suppress() if exc is None else pytest.warns(exc) with cm: compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [x], + [g], + [x_test_value], ) @@ -490,35 +442,32 @@ def test_Softmax(x, axis, exc): "x, axis, exc", [ ( - set_test_value(pt.vector(), rng.random(size=(2,)).astype(config.floatX)), + (pt.vector(), rng.random(size=(2,)).astype(config.floatX)), None, None, ), ( - set_test_value(pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)), + (pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)), 0, None, ), ( - set_test_value(pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)), + (pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)), 1, None, ), ], ) def test_LogSoftmax(x, axis, exc): + x, x_test_value = x g = LogSoftmax(axis=axis)(x) - g_fg = FunctionGraph(outputs=[g]) cm = contextlib.suppress() if exc is None else pytest.warns(exc) with cm: compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [x], + [g], + [x_test_value], ) @@ -526,44 +475,37 @@ def test_LogSoftmax(x, axis, exc): "x, axes, exc", [ ( - set_test_value(pt.dscalar(), np.array(0.0, dtype="float64")), + (pt.dscalar(), np.array(0.0, dtype="float64")), [], None, ), ( - set_test_value(pt.dvector(), rng.random(size=(3,)).astype("float64")), + (pt.dvector(), rng.random(size=(3,)).astype("float64")), [0], None, ), ( - set_test_value(pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")), + (pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")), [0], None, ), ( - set_test_value(pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")), + (pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")), [0, 1], None, ), ], ) def test_Max(x, axes, exc): + x, x_test_value = x g = ptm.Max(axes)(x) - if isinstance(g, list): - g_fg = FunctionGraph(outputs=g) - else: - g_fg = FunctionGraph(outputs=[g]) - cm = contextlib.suppress() if exc is None else pytest.warns(exc) with cm: compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [x], + [g], + [x_test_value], ) @@ -571,44 +513,37 @@ def test_Max(x, axes, exc): "x, axes, exc", [ ( - set_test_value(pt.dscalar(), np.array(0.0, dtype="float64")), + (pt.dscalar(), np.array(0.0, dtype="float64")), [], None, ), ( - set_test_value(pt.dvector(), rng.random(size=(3,)).astype("float64")), + (pt.dvector(), rng.random(size=(3,)).astype("float64")), [0], None, ), ( - set_test_value(pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")), + (pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")), [0], None, ), ( - set_test_value(pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")), + (pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")), [0, 1], None, ), ], ) def test_Argmax(x, axes, exc): + x, x_test_value = x g = ptm.Argmax(axes)(x) - if isinstance(g, list): - g_fg = FunctionGraph(outputs=g) - else: - g_fg = FunctionGraph(outputs=[g]) - cm = contextlib.suppress() if exc is None else pytest.warns(exc) with cm: compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [x], + [g], + [x_test_value], ) @@ -636,7 +571,8 @@ def test_scalar_loop(): with pytest.warns(UserWarning, match="object mode"): compare_numba_and_py( - ([x], [elemwise_loop]), + [x], + [elemwise_loop], (np.array([1, 2, 3], dtype="float64"),), ) diff --git a/tests/link/numba/test_extra_ops.py b/tests/link/numba/test_extra_ops.py index e61862ffdf..e9b6700c63 100644 --- a/tests/link/numba/test_extra_ops.py +++ b/tests/link/numba/test_extra_ops.py @@ -5,11 +5,8 @@ import pytensor.tensor as pt from pytensor import config -from pytensor.compile.sharedvalue import SharedVariable -from pytensor.graph.basic import Constant -from pytensor.graph.fg import FunctionGraph from pytensor.tensor import extra_ops -from tests.link.numba.test_basic import compare_numba_and_py, set_test_value +from tests.link.numba.test_basic import compare_numba_and_py rng = np.random.default_rng(42849) @@ -18,20 +15,17 @@ @pytest.mark.parametrize( "val", [ - set_test_value(pt.lscalar(), np.array(6, dtype="int64")), + (pt.lscalar(), np.array(6, dtype="int64")), ], ) def test_Bartlett(val): + val, test_val = val g = extra_ops.bartlett(val) - g_fg = FunctionGraph(outputs=[g]) compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [val], + g, + [test_val], assert_fn=lambda x, y: np.testing.assert_allclose(x, y, atol=1e-15), ) @@ -40,97 +34,71 @@ def test_Bartlett(val): "val, axis, mode", [ ( - set_test_value( - pt.matrix(), np.arange(3, dtype=config.floatX).reshape((3, 1)) - ), + (pt.matrix(), np.arange(3, dtype=config.floatX).reshape((3, 1))), 1, "add", ), ( - set_test_value( - pt.dtensor3(), np.arange(30, dtype=config.floatX).reshape((2, 3, 5)) - ), + (pt.dtensor3(), np.arange(30, dtype=config.floatX).reshape((2, 3, 5))), -1, "add", ), ( - set_test_value( - pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2)) - ), + (pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))), 0, "add", ), ( - set_test_value( - pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2)) - ), + (pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))), 1, "add", ), ( - set_test_value( - pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2)) - ), + (pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))), None, "add", ), ( - set_test_value( - pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2)) - ), + (pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))), 0, "mul", ), ( - set_test_value( - pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2)) - ), + (pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))), 1, "mul", ), ( - set_test_value( - pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2)) - ), + (pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))), None, "mul", ), ], ) def test_CumOp(val, axis, mode): + val, test_val = val g = extra_ops.CumOp(axis=axis, mode=mode)(val) - g_fg = FunctionGraph(outputs=[g]) compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [val], + g, + [test_val], ) -@pytest.mark.parametrize( - "a, val", - [ - ( - set_test_value(pt.lmatrix(), np.zeros((10, 2), dtype="int64")), - set_test_value(pt.lscalar(), np.array(1, dtype="int64")), - ) - ], -) -def test_FillDiagonal(a, val): +def test_FillDiagonal(): + a = pt.lmatrix("a") + test_a = np.zeros((10, 2), dtype="int64") + + val = pt.lscalar("val") + test_val = np.array(1, dtype="int64") + g = extra_ops.FillDiagonal()(a, val) - g_fg = FunctionGraph(outputs=[g]) compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [a, val], + g, + [test_a, test_val], ) @@ -138,33 +106,32 @@ def test_FillDiagonal(a, val): "a, val, offset", [ ( - set_test_value(pt.lmatrix(), np.zeros((10, 2), dtype="int64")), - set_test_value(pt.lscalar(), np.array(1, dtype="int64")), - set_test_value(pt.lscalar(), np.array(-1, dtype="int64")), + (pt.lmatrix(), np.zeros((10, 2), dtype="int64")), + (pt.lscalar(), np.array(1, dtype="int64")), + (pt.lscalar(), np.array(-1, dtype="int64")), ), ( - set_test_value(pt.lmatrix(), np.zeros((10, 2), dtype="int64")), - set_test_value(pt.lscalar(), np.array(1, dtype="int64")), - set_test_value(pt.lscalar(), np.array(0, dtype="int64")), + (pt.lmatrix(), np.zeros((10, 2), dtype="int64")), + (pt.lscalar(), np.array(1, dtype="int64")), + (pt.lscalar(), np.array(0, dtype="int64")), ), ( - set_test_value(pt.lmatrix(), np.zeros((10, 3), dtype="int64")), - set_test_value(pt.lscalar(), np.array(1, dtype="int64")), - set_test_value(pt.lscalar(), np.array(1, dtype="int64")), + (pt.lmatrix(), np.zeros((10, 3), dtype="int64")), + (pt.lscalar(), np.array(1, dtype="int64")), + (pt.lscalar(), np.array(1, dtype="int64")), ), ], ) def test_FillDiagonalOffset(a, val, offset): + a, test_a = a + val, test_val = val + offset, test_offset = offset g = extra_ops.FillDiagonalOffset()(a, val, offset) - g_fg = FunctionGraph(outputs=[g]) compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [a, val, offset], + g, + [test_a, test_val, test_offset], ) @@ -172,65 +139,56 @@ def test_FillDiagonalOffset(a, val, offset): "arr, shape, mode, order, exc", [ ( - tuple(set_test_value(pt.lscalar(), v) for v in np.array([0])), - set_test_value(pt.lvector(), np.array([2])), + tuple((pt.lscalar(), v) for v in np.array([0])), + (pt.lvector(), np.array([2])), "raise", "C", None, ), ( - tuple(set_test_value(pt.lscalar(), v) for v in np.array([0, 0, 3])), - set_test_value(pt.lvector(), np.array([2, 3, 4])), + tuple((pt.lscalar(), v) for v in np.array([0, 0, 3])), + (pt.lvector(), np.array([2, 3, 4])), "raise", "C", None, ), ( - tuple( - set_test_value(pt.lvector(), v) - for v in np.array([[0, 1], [2, 0], [1, 3]]) - ), - set_test_value(pt.lvector(), np.array([2, 3, 4])), + tuple((pt.lvector(), v) for v in np.array([[0, 1], [2, 0], [1, 3]])), + (pt.lvector(), np.array([2, 3, 4])), "raise", "C", None, ), ( - tuple( - set_test_value(pt.lvector(), v) - for v in np.array([[0, 1], [2, 0], [1, 3]]) - ), - set_test_value(pt.lvector(), np.array([2, 3, 4])), + tuple((pt.lvector(), v) for v in np.array([[0, 1], [2, 0], [1, 3]])), + (pt.lvector(), np.array([2, 3, 4])), "raise", "F", NotImplementedError, ), ( tuple( - set_test_value(pt.lvector(), v) - for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]]) + (pt.lvector(), v) for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]]) ), - set_test_value(pt.lvector(), np.array([2, 3, 4])), + (pt.lvector(), np.array([2, 3, 4])), "raise", "C", ValueError, ), ( tuple( - set_test_value(pt.lvector(), v) - for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]]) + (pt.lvector(), v) for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]]) ), - set_test_value(pt.lvector(), np.array([2, 3, 4])), + (pt.lvector(), np.array([2, 3, 4])), "wrap", "C", None, ), ( tuple( - set_test_value(pt.lvector(), v) - for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]]) + (pt.lvector(), v) for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]]) ), - set_test_value(pt.lvector(), np.array([2, 3, 4])), + (pt.lvector(), np.array([2, 3, 4])), "clip", "C", None, @@ -238,18 +196,16 @@ def test_FillDiagonalOffset(a, val, offset): ], ) def test_RavelMultiIndex(arr, shape, mode, order, exc): - g = extra_ops.RavelMultiIndex(mode, order)(*((*arr, shape))) - g_fg = FunctionGraph(outputs=[g]) + arr, test_arr = zip(*arr, strict=True) + shape, test_shape = shape + g = extra_ops.RavelMultiIndex(mode, order)(*arr, shape) cm = contextlib.suppress() if exc is None else pytest.raises(exc) with cm: compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [*arr, shape], + g, + [*test_arr, test_shape], ) @@ -257,44 +213,42 @@ def test_RavelMultiIndex(arr, shape, mode, order, exc): "x, repeats, axis, exc", [ ( - set_test_value(pt.lscalar(), np.array(1, dtype="int64")), - set_test_value(pt.lscalar(), np.array(0, dtype="int64")), + (pt.lscalar(), np.array(1, dtype="int64")), + (pt.lscalar(), np.array(0, dtype="int64")), None, None, ), ( - set_test_value(pt.lmatrix(), np.zeros((2, 2), dtype="int64")), - set_test_value(pt.lscalar(), np.array(1, dtype="int64")), + (pt.lmatrix(), np.zeros((2, 2), dtype="int64")), + (pt.lscalar(), np.array(1, dtype="int64")), None, None, ), ( - set_test_value(pt.lvector(), np.arange(2, dtype="int64")), - set_test_value(pt.lvector(), np.array([1, 1], dtype="int64")), + (pt.lvector(), np.arange(2, dtype="int64")), + (pt.lvector(), np.array([1, 1], dtype="int64")), None, None, ), ( - set_test_value(pt.lmatrix(), np.zeros((2, 2), dtype="int64")), - set_test_value(pt.lscalar(), np.array(1, dtype="int64")), + (pt.lmatrix(), np.zeros((2, 2), dtype="int64")), + (pt.lscalar(), np.array(1, dtype="int64")), 0, UserWarning, ), ], ) def test_Repeat(x, repeats, axis, exc): + x, test_x = x + repeats, test_repeats = repeats g = extra_ops.Repeat(axis)(x, repeats) - g_fg = FunctionGraph(outputs=[g]) cm = contextlib.suppress() if exc is None else pytest.warns(exc) with cm: compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [x, repeats], + g, + [test_x, test_repeats], ) @@ -302,7 +256,7 @@ def test_Repeat(x, repeats, axis, exc): "x, axis, return_index, return_inverse, return_counts, exc", [ ( - set_test_value(pt.lscalar(), np.array(1, dtype="int64")), + (pt.lscalar(), np.array(1, dtype="int64")), None, False, False, @@ -310,7 +264,7 @@ def test_Repeat(x, repeats, axis, exc): None, ), ( - set_test_value(pt.lvector(), np.array([1, 1, 2], dtype="int64")), + (pt.lvector(), np.array([1, 1, 2], dtype="int64")), None, False, False, @@ -318,7 +272,7 @@ def test_Repeat(x, repeats, axis, exc): None, ), ( - set_test_value(pt.lmatrix(), np.array([[1, 1], [2, 2]], dtype="int64")), + (pt.lmatrix(), np.array([[1, 1], [2, 2]], dtype="int64")), None, False, False, @@ -326,9 +280,7 @@ def test_Repeat(x, repeats, axis, exc): None, ), ( - set_test_value( - pt.lmatrix(), np.array([[1, 1], [1, 1], [2, 2]], dtype="int64") - ), + (pt.lmatrix(), np.array([[1, 1], [1, 1], [2, 2]], dtype="int64")), 0, False, False, @@ -336,9 +288,7 @@ def test_Repeat(x, repeats, axis, exc): UserWarning, ), ( - set_test_value( - pt.lmatrix(), np.array([[1, 1], [1, 1], [2, 2]], dtype="int64") - ), + (pt.lmatrix(), np.array([[1, 1], [1, 1], [2, 2]], dtype="int64")), 0, True, True, @@ -348,22 +298,15 @@ def test_Repeat(x, repeats, axis, exc): ], ) def test_Unique(x, axis, return_index, return_inverse, return_counts, exc): + x, test_x = x g = extra_ops.Unique(return_index, return_inverse, return_counts, axis)(x) - if isinstance(g, list): - g_fg = FunctionGraph(outputs=g) - else: - g_fg = FunctionGraph(outputs=[g]) - cm = contextlib.suppress() if exc is None else pytest.warns(exc) with cm: compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [x], + g, + [test_x], ) @@ -371,19 +314,19 @@ def test_Unique(x, axis, return_index, return_inverse, return_counts, exc): "arr, shape, order, exc", [ ( - set_test_value(pt.lvector(), np.array([9, 15, 1], dtype="int64")), + (pt.lvector(), np.array([9, 15, 1], dtype="int64")), pt.as_tensor([2, 3, 4]), "C", None, ), ( - set_test_value(pt.lvector(), np.array([1, 0], dtype="int64")), + (pt.lvector(), np.array([1, 0], dtype="int64")), pt.as_tensor([2]), "C", None, ), ( - set_test_value(pt.lvector(), np.array([9, 15, 1], dtype="int64")), + (pt.lvector(), np.array([9, 15, 1], dtype="int64")), pt.as_tensor([2, 3, 4]), "F", NotImplementedError, @@ -391,22 +334,15 @@ def test_Unique(x, axis, return_index, return_inverse, return_counts, exc): ], ) def test_UnravelIndex(arr, shape, order, exc): + arr, test_arr = arr g = extra_ops.UnravelIndex(order)(arr, shape) - if isinstance(g, list): - g_fg = FunctionGraph(outputs=g) - else: - g_fg = FunctionGraph(outputs=[g]) - cm = contextlib.suppress() if exc is None else pytest.raises(exc) with cm: compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [arr], + g, + [test_arr], ) @@ -414,18 +350,18 @@ def test_UnravelIndex(arr, shape, order, exc): "a, v, side, sorter, exc", [ ( - set_test_value(pt.vector(), np.array([1.0, 2.0, 3.0], dtype=config.floatX)), - set_test_value(pt.matrix(), rng.random((3, 2)).astype(config.floatX)), + (pt.vector(), np.array([1.0, 2.0, 3.0], dtype=config.floatX)), + (pt.matrix(), rng.random((3, 2)).astype(config.floatX)), "left", None, None, ), pytest.param( - set_test_value( + ( pt.vector(), np.array([0.29769574, 0.71649186, 0.20475563]).astype(config.floatX), ), - set_test_value( + ( pt.matrix(), np.array( [ @@ -440,25 +376,26 @@ def test_UnravelIndex(arr, shape, order, exc): None, ), ( - set_test_value(pt.vector(), np.array([1.0, 2.0, 3.0], dtype=config.floatX)), - set_test_value(pt.matrix(), rng.random((3, 2)).astype(config.floatX)), + (pt.vector(), np.array([1.0, 2.0, 3.0], dtype=config.floatX)), + (pt.matrix(), rng.random((3, 2)).astype(config.floatX)), "right", - set_test_value(pt.lvector(), np.array([0, 2, 1])), + (pt.lvector(), np.array([0, 2, 1])), UserWarning, ), ], ) def test_Searchsorted(a, v, side, sorter, exc): + a, test_a = a + v, test_v = v + if sorter is not None: + sorter, test_sorter = sorter + g = extra_ops.SearchsortedOp(side)(a, v, sorter) - g_fg = FunctionGraph(outputs=[g]) cm = contextlib.suppress() if exc is None else pytest.warns(exc) with cm: compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [a, v] if sorter is None else [a, v, sorter], + g, + [test_a, test_v] if sorter is None else [test_a, test_v, test_sorter], ) diff --git a/tests/link/numba/test_nlinalg.py b/tests/link/numba/test_nlinalg.py index 3dc427cd9c..67bdc6f1a0 100644 --- a/tests/link/numba/test_nlinalg.py +++ b/tests/link/numba/test_nlinalg.py @@ -4,11 +4,8 @@ import pytest import pytensor.tensor as pt -from pytensor.compile.sharedvalue import SharedVariable -from pytensor.graph.basic import Constant -from pytensor.graph.fg import FunctionGraph from pytensor.tensor import nlinalg -from tests.link.numba.test_basic import compare_numba_and_py, set_test_value +from tests.link.numba.test_basic import compare_numba_and_py rng = np.random.default_rng(42849) @@ -18,14 +15,14 @@ "x, exc", [ ( - set_test_value( + ( pt.dmatrix(), (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), ), None, ), ( - set_test_value( + ( pt.lmatrix(), (lambda x: x.T.dot(x))(rng.poisson(size=(3, 3)).astype("int64")), ), @@ -34,18 +31,15 @@ ], ) def test_Det(x, exc): + x, test_x = x g = nlinalg.Det()(x) - g_fg = FunctionGraph(outputs=[g]) cm = contextlib.suppress() if exc is None else pytest.warns(exc) with cm: compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [x], + g, + [test_x], ) @@ -53,14 +47,14 @@ def test_Det(x, exc): "x, exc", [ ( - set_test_value( + ( pt.dmatrix(), (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), ), None, ), ( - set_test_value( + ( pt.lmatrix(), (lambda x: x.T.dot(x))(rng.poisson(size=(3, 3)).astype("int64")), ), @@ -69,18 +63,15 @@ def test_Det(x, exc): ], ) def test_SLogDet(x, exc): + x, test_x = x g = nlinalg.SLogDet()(x) - g_fg = FunctionGraph(outputs=g) cm = contextlib.suppress() if exc is None else pytest.warns(exc) with cm: compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [x], + g, + [test_x], ) @@ -112,21 +103,21 @@ def test_SLogDet(x, exc): "x, exc", [ ( - set_test_value( + ( pt.dmatrix(), (lambda x: x.T.dot(x))(x), ), None, ), ( - set_test_value( + ( pt.dmatrix(), (lambda x: x.T.dot(x))(y), ), None, ), ( - set_test_value( + ( pt.lmatrix(), (lambda x: x.T.dot(x))( rng.integers(1, 10, size=(3, 3)).astype("int64") @@ -137,22 +128,15 @@ def test_SLogDet(x, exc): ], ) def test_Eig(x, exc): + x, test_x = x g = nlinalg.Eig()(x) - if isinstance(g, list): - g_fg = FunctionGraph(outputs=g) - else: - g_fg = FunctionGraph(outputs=[g]) - cm = contextlib.suppress() if exc is None else pytest.warns(exc) with cm: compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [x], + g, + [test_x], ) @@ -160,7 +144,7 @@ def test_Eig(x, exc): "x, uplo, exc", [ ( - set_test_value( + ( pt.dmatrix(), (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), ), @@ -168,7 +152,7 @@ def test_Eig(x, exc): None, ), ( - set_test_value( + ( pt.lmatrix(), (lambda x: x.T.dot(x))( rng.integers(1, 10, size=(3, 3)).astype("int64") @@ -180,22 +164,15 @@ def test_Eig(x, exc): ], ) def test_Eigh(x, uplo, exc): + x, test_x = x g = nlinalg.Eigh(uplo)(x) - if isinstance(g, list): - g_fg = FunctionGraph(outputs=g) - else: - g_fg = FunctionGraph(outputs=[g]) - cm = contextlib.suppress() if exc is None else pytest.warns(exc) with cm: compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [x], + g, + [test_x], ) @@ -204,7 +181,7 @@ def test_Eigh(x, uplo, exc): [ ( nlinalg.MatrixInverse, - set_test_value( + ( pt.dmatrix(), (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), ), @@ -213,7 +190,7 @@ def test_Eigh(x, uplo, exc): ), ( nlinalg.MatrixInverse, - set_test_value( + ( pt.lmatrix(), (lambda x: x.T.dot(x))( rng.integers(1, 10, size=(3, 3)).astype("int64") @@ -224,7 +201,7 @@ def test_Eigh(x, uplo, exc): ), ( nlinalg.MatrixPinv, - set_test_value( + ( pt.dmatrix(), (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), ), @@ -233,7 +210,7 @@ def test_Eigh(x, uplo, exc): ), ( nlinalg.MatrixPinv, - set_test_value( + ( pt.lmatrix(), (lambda x: x.T.dot(x))( rng.integers(1, 10, size=(3, 3)).astype("int64") @@ -245,18 +222,15 @@ def test_Eigh(x, uplo, exc): ], ) def test_matrix_inverses(op, x, exc, op_args): + x, test_x = x g = op(*op_args)(x) - g_fg = FunctionGraph(outputs=[g]) cm = contextlib.suppress() if exc is None else pytest.warns(exc) with cm: compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [x], + g, + [test_x], ) @@ -264,7 +238,7 @@ def test_matrix_inverses(op, x, exc, op_args): "x, mode, exc", [ ( - set_test_value( + ( pt.dmatrix(), (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), ), @@ -272,7 +246,7 @@ def test_matrix_inverses(op, x, exc, op_args): None, ), ( - set_test_value( + ( pt.dmatrix(), (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), ), @@ -280,7 +254,7 @@ def test_matrix_inverses(op, x, exc, op_args): None, ), ( - set_test_value( + ( pt.lmatrix(), (lambda x: x.T.dot(x))( rng.integers(1, 10, size=(3, 3)).astype("int64") @@ -290,7 +264,7 @@ def test_matrix_inverses(op, x, exc, op_args): None, ), ( - set_test_value( + ( pt.lmatrix(), (lambda x: x.T.dot(x))( rng.integers(1, 10, size=(3, 3)).astype("int64") @@ -302,22 +276,15 @@ def test_matrix_inverses(op, x, exc, op_args): ], ) def test_QRFull(x, mode, exc): + x, test_x = x g = nlinalg.QRFull(mode)(x) - if isinstance(g, list): - g_fg = FunctionGraph(outputs=g) - else: - g_fg = FunctionGraph(outputs=[g]) - cm = contextlib.suppress() if exc is None else pytest.warns(exc) with cm: compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [x], + g, + [test_x], ) @@ -325,7 +292,7 @@ def test_QRFull(x, mode, exc): "x, full_matrices, compute_uv, exc", [ ( - set_test_value( + ( pt.dmatrix(), (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), ), @@ -334,7 +301,7 @@ def test_QRFull(x, mode, exc): None, ), ( - set_test_value( + ( pt.dmatrix(), (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), ), @@ -343,7 +310,7 @@ def test_QRFull(x, mode, exc): None, ), ( - set_test_value( + ( pt.lmatrix(), (lambda x: x.T.dot(x))( rng.integers(1, 10, size=(3, 3)).astype("int64") @@ -354,7 +321,7 @@ def test_QRFull(x, mode, exc): None, ), ( - set_test_value( + ( pt.lmatrix(), (lambda x: x.T.dot(x))( rng.integers(1, 10, size=(3, 3)).astype("int64") @@ -367,20 +334,13 @@ def test_QRFull(x, mode, exc): ], ) def test_SVD(x, full_matrices, compute_uv, exc): + x, test_x = x g = nlinalg.SVD(full_matrices, compute_uv)(x) - if isinstance(g, list): - g_fg = FunctionGraph(outputs=g) - else: - g_fg = FunctionGraph(outputs=[g]) - cm = contextlib.suppress() if exc is None else pytest.warns(exc) with cm: compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [x], + g, + [test_x], ) diff --git a/tests/link/numba/test_pad.py b/tests/link/numba/test_pad.py index 11877594d7..437c325d6c 100644 --- a/tests/link/numba/test_pad.py +++ b/tests/link/numba/test_pad.py @@ -3,7 +3,6 @@ import pytensor.tensor as pt from pytensor import config -from pytensor.graph import FunctionGraph from pytensor.tensor.pad import PadMode from tests.link.numba.test_basic import compare_numba_and_py @@ -58,10 +57,10 @@ def test_numba_pad(mode: PadMode, kwargs): x = np.random.normal(size=(3, 3)) res = pt.pad(x_pt, mode=mode, pad_width=3, **kwargs) - res_fg = FunctionGraph([x_pt], [res]) compare_numba_and_py( - res_fg, + [x_pt], + [res], [x], assert_fn=lambda x, y: np.testing.assert_allclose(x, y, rtol=RTOL, atol=ATOL), py_mode="FAST_RUN", diff --git a/tests/link/numba/test_random.py b/tests/link/numba/test_random.py index 1569ea8ae8..f52b1e2800 100644 --- a/tests/link/numba/test_random.py +++ b/tests/link/numba/test_random.py @@ -10,13 +10,9 @@ from pytensor import shared from pytensor.compile.builders import OpFromGraph from pytensor.compile.function import function -from pytensor.compile.sharedvalue import SharedVariable -from pytensor.graph.basic import Constant -from pytensor.graph.fg import FunctionGraph from tests.link.numba.test_basic import ( compare_numba_and_py, numba_mode, - set_test_value, ) from tests.tensor.random.test_basic import ( batched_permutation_tester, @@ -159,11 +155,11 @@ def test_multivariate_normal(): ( ptr.uniform, [ - set_test_value( + ( pt.dscalar(), np.array(1.0, dtype=np.float64), ), - set_test_value( + ( pt.dvector(), np.array([1.0, 2.0], dtype=np.float64), ), @@ -173,15 +169,15 @@ def test_multivariate_normal(): ( ptr.triangular, [ - set_test_value( + ( pt.dscalar(), np.array(-5.0, dtype=np.float64), ), - set_test_value( + ( pt.dscalar(), np.array(1.0, dtype=np.float64), ), - set_test_value( + ( pt.dscalar(), np.array(5.0, dtype=np.float64), ), @@ -191,11 +187,11 @@ def test_multivariate_normal(): ( ptr.lognormal, [ - set_test_value( + ( pt.dvector(), np.array([1.0, 2.0], dtype=np.float64), ), - set_test_value( + ( pt.dscalar(), np.array(1.0, dtype=np.float64), ), @@ -205,11 +201,11 @@ def test_multivariate_normal(): ( ptr.pareto, [ - set_test_value( + ( pt.dvector(), np.array([1.0, 2.0], dtype=np.float64), ), - set_test_value( + ( pt.dvector(), np.array([2.0, 10.0], dtype=np.float64), ), @@ -219,7 +215,7 @@ def test_multivariate_normal(): ( ptr.exponential, [ - set_test_value( + ( pt.dvector(), np.array([1.0, 2.0], dtype=np.float64), ), @@ -229,7 +225,7 @@ def test_multivariate_normal(): ( ptr.weibull, [ - set_test_value( + ( pt.dvector(), np.array([1.0, 2.0], dtype=np.float64), ), @@ -239,11 +235,11 @@ def test_multivariate_normal(): ( ptr.logistic, [ - set_test_value( + ( pt.dvector(), np.array([1.0, 2.0], dtype=np.float64), ), - set_test_value( + ( pt.dscalar(), np.array(1.0, dtype=np.float64), ), @@ -253,7 +249,7 @@ def test_multivariate_normal(): ( ptr.geometric, [ - set_test_value( + ( pt.dvector(), np.array([0.3, 0.4], dtype=np.float64), ), @@ -263,15 +259,15 @@ def test_multivariate_normal(): pytest.param( ptr.hypergeometric, [ - set_test_value( + ( pt.lscalar(), np.array(7, dtype=np.int64), ), - set_test_value( + ( pt.lscalar(), np.array(8, dtype=np.int64), ), - set_test_value( + ( pt.lscalar(), np.array(15, dtype=np.int64), ), @@ -282,11 +278,11 @@ def test_multivariate_normal(): ( ptr.wald, [ - set_test_value( + ( pt.dvector(), np.array([1.0, 2.0], dtype=np.float64), ), - set_test_value( + ( pt.dscalar(), np.array(1.0, dtype=np.float64), ), @@ -296,11 +292,11 @@ def test_multivariate_normal(): ( ptr.laplace, [ - set_test_value( + ( pt.dvector(), np.array([1.0, 2.0], dtype=np.float64), ), - set_test_value( + ( pt.dscalar(), np.array(1.0, dtype=np.float64), ), @@ -310,11 +306,11 @@ def test_multivariate_normal(): ( ptr.binomial, [ - set_test_value( + ( pt.lvector(), np.array([1, 2], dtype=np.int64), ), - set_test_value( + ( pt.dscalar(), np.array(0.9, dtype=np.float64), ), @@ -324,21 +320,21 @@ def test_multivariate_normal(): ( ptr.normal, [ - set_test_value( + ( pt.lvector(), np.array([1, 2], dtype=np.int64), ), - set_test_value( + ( pt.dscalar(), np.array(1.0, dtype=np.float64), ), ], - pt.as_tensor(tuple(set_test_value(pt.lscalar(), v) for v in [3, 2])), + pt.as_tensor([3, 2]), ), ( ptr.poisson, [ - set_test_value( + ( pt.dvector(), np.array([1.0, 2.0], dtype=np.float64), ), @@ -348,11 +344,11 @@ def test_multivariate_normal(): ( ptr.halfnormal, [ - set_test_value( + ( pt.lvector(), np.array([1, 2], dtype=np.int64), ), - set_test_value( + ( pt.dscalar(), np.array(1.0, dtype=np.float64), ), @@ -362,7 +358,7 @@ def test_multivariate_normal(): ( ptr.bernoulli, [ - set_test_value( + ( pt.dvector(), np.array([0.1, 0.9], dtype=np.float64), ), @@ -372,11 +368,11 @@ def test_multivariate_normal(): ( ptr.beta, [ - set_test_value( + ( pt.dvector(), np.array([1.0, 2.0], dtype=np.float64), ), - set_test_value( + ( pt.dscalar(), np.array(1.0, dtype=np.float64), ), @@ -386,11 +382,11 @@ def test_multivariate_normal(): ( ptr._gamma, [ - set_test_value( + ( pt.dvector(), np.array([1.0, 2.0], dtype=np.float64), ), - set_test_value( + ( pt.dvector(), np.array([0.5, 3.0], dtype=np.float64), ), @@ -400,7 +396,7 @@ def test_multivariate_normal(): ( ptr.chisquare, [ - set_test_value( + ( pt.dvector(), np.array([1.0, 2.0], dtype=np.float64), ) @@ -410,11 +406,11 @@ def test_multivariate_normal(): ( ptr.negative_binomial, [ - set_test_value( + ( pt.lvector(), np.array([100, 200], dtype=np.int64), ), - set_test_value( + ( pt.dscalar(), np.array(0.09, dtype=np.float64), ), @@ -424,11 +420,11 @@ def test_multivariate_normal(): ( ptr.vonmises, [ - set_test_value( + ( pt.dvector(), np.array([-0.5, 0.5], dtype=np.float64), ), - set_test_value( + ( pt.dscalar(), np.array(1.0, dtype=np.float64), ), @@ -438,14 +434,14 @@ def test_multivariate_normal(): ( ptr.permutation, [ - set_test_value(pt.dmatrix(), np.eye(5, dtype=np.float64)), + (pt.dmatrix(), np.eye(5, dtype=np.float64)), ], (), ), ( partial(ptr.choice, replace=True), [ - set_test_value(pt.dmatrix(), np.eye(5, dtype=np.float64)), + (pt.dmatrix(), np.eye(5, dtype=np.float64)), ], pt.as_tensor([2]), ), @@ -455,17 +451,15 @@ def test_multivariate_normal(): a, p=p, size=size, replace=True, rng=rng ), [ - set_test_value(pt.dmatrix(), np.eye(3, dtype=np.float64)), - set_test_value( - pt.dvector(), np.array([0.25, 0.5, 0.25], dtype=np.float64) - ), + (pt.dmatrix(), np.eye(3, dtype=np.float64)), + (pt.dvector(), np.array([0.25, 0.5, 0.25], dtype=np.float64)), ], (pt.as_tensor([2, 3])), ), pytest.param( partial(ptr.choice, replace=False), [ - set_test_value(pt.dvector(), np.arange(5, dtype=np.float64)), + (pt.dvector(), np.arange(5, dtype=np.float64)), ], pt.as_tensor([2]), marks=pytest.mark.xfail( @@ -476,7 +470,7 @@ def test_multivariate_normal(): pytest.param( partial(ptr.choice, replace=False), [ - set_test_value(pt.dmatrix(), np.eye(5, dtype=np.float64)), + (pt.dmatrix(), np.eye(5, dtype=np.float64)), ], pt.as_tensor([2]), marks=pytest.mark.xfail( @@ -490,8 +484,8 @@ def test_multivariate_normal(): a, p=p, size=size, replace=False, rng=rng ), [ - set_test_value(pt.vector(), np.arange(5, dtype=np.float64)), - set_test_value( + (pt.vector(), np.arange(5, dtype=np.float64)), + ( pt.dvector(), np.array([0.5, 0.0, 0.25, 0.0, 0.25], dtype=np.float64), ), @@ -504,10 +498,8 @@ def test_multivariate_normal(): a, p=p, size=size, replace=False, rng=rng ), [ - set_test_value(pt.dmatrix(), np.eye(3, dtype=np.float64)), - set_test_value( - pt.dvector(), np.array([0.25, 0.5, 0.25], dtype=np.float64) - ), + (pt.dmatrix(), np.eye(3, dtype=np.float64)), + (pt.dvector(), np.array([0.25, 0.5, 0.25], dtype=np.float64)), ], (), ), @@ -517,10 +509,8 @@ def test_multivariate_normal(): a, p=p, size=size, replace=False, rng=rng ), [ - set_test_value(pt.dmatrix(), np.eye(3, dtype=np.float64)), - set_test_value( - pt.dvector(), np.array([0.25, 0.5, 0.25], dtype=np.float64) - ), + (pt.dmatrix(), np.eye(3, dtype=np.float64)), + (pt.dvector(), np.array([0.25, 0.5, 0.25], dtype=np.float64)), ], (pt.as_tensor([2, 1])), ), @@ -529,17 +519,14 @@ def test_multivariate_normal(): ) def test_aligned_RandomVariable(rv_op, dist_args, size): """Tests for Numba samplers that are one-to-one with PyTensor's/NumPy's samplers.""" + dist_args, test_dist_args = zip(*dist_args, strict=True) rng = shared(np.random.default_rng(29402)) g = rv_op(*dist_args, size=size, rng=rng) - g_fg = FunctionGraph(outputs=[g]) compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + dist_args, + [g], + test_dist_args, eval_obj_mode=False, # No python impl ) @@ -550,11 +537,11 @@ def test_aligned_RandomVariable(rv_op, dist_args, size): ( ptr.cauchy, [ - set_test_value( + ( pt.dvector(), np.array([1.0, 2.0], dtype=np.float64), ), - set_test_value( + ( pt.dscalar(), np.array(1.0, dtype=np.float64), ), @@ -566,11 +553,11 @@ def test_aligned_RandomVariable(rv_op, dist_args, size): ( ptr.gumbel, [ - set_test_value( + ( pt.dvector(), np.array([1.0, 2.0], dtype=np.float64), ), - set_test_value( + ( pt.dscalar(), np.array(1.0, dtype=np.float64), ), @@ -583,18 +570,13 @@ def test_aligned_RandomVariable(rv_op, dist_args, size): ) def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_conv): """Tests for Numba samplers that are not one-to-one with PyTensor's/NumPy's samplers.""" + dist_args, test_dist_args = zip(*dist_args, strict=True) rng = shared(np.random.default_rng(29402)) g = rv_op(*dist_args, size=(2000, *base_size), rng=rng) g_fn = function(dist_args, g, mode=numba_mode) - samples = g_fn( - *[ - i.tag.test_value - for i in g_fn.maker.fgraph.inputs - if not isinstance(i, SharedVariable | Constant) - ] - ) + samples = g_fn(*test_dist_args) - bcast_dist_args = np.broadcast_arrays(*[i.tag.test_value for i in dist_args]) + bcast_dist_args = np.broadcast_arrays(*test_dist_args) for idx in np.ndindex(*base_size): cdf_params = params_conv(*(arg[idx] for arg in bcast_dist_args)) @@ -608,7 +590,7 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_ "a, size, cm", [ pytest.param( - set_test_value( + ( pt.dvector(), np.array([100000, 1, 1], dtype=np.float64), ), @@ -616,7 +598,7 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_ contextlib.suppress(), ), pytest.param( - set_test_value( + ( pt.dmatrix(), np.array( [[100000, 1, 1], [1, 100000, 1], [1, 1, 100000]], @@ -627,7 +609,7 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_ contextlib.suppress(), ), pytest.param( - set_test_value( + ( pt.dmatrix(), np.array( [[100000, 1, 1], [1, 100000, 1], [1, 1, 100000]], @@ -643,13 +625,12 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_ ], ) def test_DirichletRV(a, size, cm): + a, a_val = a rng = shared(np.random.default_rng(29402)) g = ptr.dirichlet(a, size=size, rng=rng) g_fn = function([a], g, mode=numba_mode) with cm: - a_val = a.tag.test_value - all_samples = [] for i in range(1000): samples = g_fn(a_val) diff --git a/tests/link/numba/test_scalar.py b/tests/link/numba/test_scalar.py index 655e507da6..504d2a163c 100644 --- a/tests/link/numba/test_scalar.py +++ b/tests/link/numba/test_scalar.py @@ -5,13 +5,10 @@ import pytensor.scalar.basic as psb import pytensor.tensor as pt from pytensor import config -from pytensor.compile.sharedvalue import SharedVariable -from pytensor.graph.basic import Constant -from pytensor.graph.fg import FunctionGraph from pytensor.scalar.basic import Composite from pytensor.tensor import tensor from pytensor.tensor.elemwise import Elemwise -from tests.link.numba.test_basic import compare_numba_and_py, set_test_value +from tests.link.numba.test_basic import compare_numba_and_py rng = np.random.default_rng(42849) @@ -21,48 +18,43 @@ "x, y", [ ( - set_test_value(pt.lvector(), np.arange(4, dtype="int64")), - set_test_value(pt.dvector(), np.arange(4, dtype="float64")), + (pt.lvector(), np.arange(4, dtype="int64")), + (pt.dvector(), np.arange(4, dtype="float64")), ), ( - set_test_value(pt.dmatrix(), np.arange(4, dtype="float64").reshape((2, 2))), - set_test_value(pt.lscalar(), np.array(4, dtype="int64")), + (pt.dmatrix(), np.arange(4, dtype="float64").reshape((2, 2))), + (pt.lscalar(), np.array(4, dtype="int64")), ), ], ) def test_Second(x, y): + x, x_test = x + y, y_test = y # We use the `Elemwise`-wrapped version of `Second` g = pt.second(x, y) - g_fg = FunctionGraph(outputs=[g]) compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [x, y], + g, + [x_test, y_test], ) @pytest.mark.parametrize( "v, min, max", [ - (set_test_value(pt.scalar(), np.array(10, dtype=config.floatX)), 3.0, 7.0), - (set_test_value(pt.scalar(), np.array(1, dtype=config.floatX)), 3.0, 7.0), - (set_test_value(pt.scalar(), np.array(10, dtype=config.floatX)), 7.0, 3.0), + ((pt.scalar(), np.array(10, dtype=config.floatX)), 3.0, 7.0), + ((pt.scalar(), np.array(1, dtype=config.floatX)), 3.0, 7.0), + ((pt.scalar(), np.array(10, dtype=config.floatX)), 7.0, 3.0), ], ) def test_Clip(v, min, max): + v, v_test = v g = ps.clip(v, min, max) - g_fg = FunctionGraph(outputs=[g]) compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [v], + [g], + [v_test], ) @@ -100,46 +92,39 @@ def test_Clip(v, min, max): def test_Composite(inputs, input_values, scalar_fn): composite_inputs = [ps.ScalarType(config.floatX)(name=i.name) for i in inputs] comp_op = Elemwise(Composite(composite_inputs, [scalar_fn(*composite_inputs)])) - out_fg = FunctionGraph(inputs, [comp_op(*inputs)]) - compare_numba_and_py(out_fg, input_values) + compare_numba_and_py(inputs, [comp_op(*inputs)], input_values) @pytest.mark.parametrize( "v, dtype", [ - (set_test_value(pt.fscalar(), np.array(1.0, dtype="float32")), psb.float64), - (set_test_value(pt.dscalar(), np.array(1.0, dtype="float64")), psb.float32), + ((pt.fscalar(), np.array(1.0, dtype="float32")), psb.float64), + ((pt.dscalar(), np.array(1.0, dtype="float64")), psb.float32), ], ) def test_Cast(v, dtype): + v, v_test = v g = psb.Cast(dtype)(v) - g_fg = FunctionGraph(outputs=[g]) compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [v], + [g], + [v_test], ) @pytest.mark.parametrize( "v, dtype", [ - (set_test_value(pt.iscalar(), np.array(10, dtype="int32")), psb.float64), + ((pt.iscalar(), np.array(10, dtype="int32")), psb.float64), ], ) def test_reciprocal(v, dtype): + v, v_test = v g = psb.reciprocal(v) - g_fg = FunctionGraph(outputs=[g]) compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [v], + [g], + [v_test], ) @@ -156,6 +141,7 @@ def test_isnan(composite): out = pt.isnan(x) compare_numba_and_py( - ([x], [out]), + [x], + [out], [np.array([1, 0], dtype="float64")], ) diff --git a/tests/link/numba/test_scan.py b/tests/link/numba/test_scan.py index 5b9436688b..037155880e 100644 --- a/tests/link/numba/test_scan.py +++ b/tests/link/numba/test_scan.py @@ -5,7 +5,6 @@ import pytensor.tensor as pt from pytensor import config, function, grad from pytensor.compile.mode import Mode, get_mode -from pytensor.graph.fg import FunctionGraph from pytensor.scalar import Log1p from pytensor.scan.basic import scan from pytensor.scan.op import Scan @@ -147,7 +146,7 @@ def test_xit_xot_types( if output_vals is None: compare_numba_and_py( - (sequences + non_sequences, res), input_vals, updates=updates + sequences + non_sequences, res, input_vals, updates=updates ) else: numba_mode = get_mode("NUMBA") @@ -217,10 +216,7 @@ def seir_one_step(ct0, dt0, st0, et0, it0, logp_c, logp_d, beta, gamma, delta): logp_c_all.name = "C_t_logp" logp_d_all.name = "D_t_logp" - out_fg = FunctionGraph( - [pt_C, pt_D, st0, et0, it0, logp_c, logp_d, beta, gamma, delta], - [st, et, it, logp_c_all, logp_d_all], - ) + out = [st, et, it, logp_c_all, logp_d_all] s0, e0, i0 = 100, 50, 25 logp_c0 = np.array(0.0, dtype=config.floatX) @@ -243,21 +239,21 @@ def seir_one_step(ct0, dt0, st0, et0, it0, logp_c, logp_d, beta, gamma, delta): gamma_val, delta_val, ] - scan_fn, _ = compare_numba_and_py(out_fg, test_input_vals) + scan_fn, _ = compare_numba_and_py( + [pt_C, pt_D, st0, et0, it0, logp_c, logp_d, beta, gamma, delta], + out, + test_input_vals, + ) benchmark(scan_fn, *test_input_vals) -@config.change_flags(compute_test_value="raise") def test_scan_tap_output(): a_pt = pt.scalar("a") - a_pt.tag.test_value = 10.0 - b_pt = pt.arange(11).astype(config.floatX) - b_pt.name = "b" + b_pt = pt.vector("b") - c_pt = pt.arange(20, 31, dtype=config.floatX) - c_pt.name = "c" + c_pt = pt.vector("c") def input_step_fn(b, b2, c, x_tm1, y_tm1, y_tm3, a): x_tm1.name = "x_tm1" @@ -301,14 +297,12 @@ def input_step_fn(b, b2, c, x_tm1, y_tm1, y_tm3, a): strict=True, ) - out_fg = FunctionGraph([a_pt, b_pt, c_pt], scan_res) - test_input_vals = [ np.array(10.0).astype(config.floatX), np.arange(11, dtype=config.floatX), np.arange(20, 31, dtype=config.floatX), ] - compare_numba_and_py(out_fg, test_input_vals) + compare_numba_and_py([a_pt, b_pt, c_pt], scan_res, test_input_vals) def test_scan_while(): @@ -323,12 +317,10 @@ def power_of_2(previous_power, max_value): n_steps=1024, ) - out_fg = FunctionGraph([max_value], [values]) - test_input_vals = [ np.array(45).astype(config.floatX), ] - compare_numba_and_py(out_fg, test_input_vals) + compare_numba_and_py([max_value], [values], test_input_vals) def test_scan_multiple_none_output(): @@ -343,11 +335,8 @@ def power_step(prior_result, x): outputs_info=[pt.ones_like(A), None, None], n_steps=3, ) - - out_fg = FunctionGraph([A], result) test_input_vals = (np.array([1.0, 2.0]),) - - compare_numba_and_py(out_fg, test_input_vals) + compare_numba_and_py([A], result, test_input_vals) @pytest.mark.parametrize("n_steps_val", [1, 5]) @@ -372,11 +361,14 @@ def f_pow2(x_tm2, x_tm1): numba_mode = get_mode("NUMBA").including("scan_save_mem") py_mode = Mode("py").including("scan_save_mem") - out_fg = FunctionGraph([init_x, n_steps], [output]) test_input_vals = (state_val, n_steps_val) compare_numba_and_py( - out_fg, test_input_vals, numba_mode=numba_mode, py_mode=py_mode + [init_x, n_steps], + [output], + test_input_vals, + numba_mode=numba_mode, + py_mode=py_mode, ) @@ -410,14 +402,12 @@ def inner_fct(seq, state_old, state_current): numba_mode = get_mode("NUMBA").including("scan_save_mem") py_mode = Mode("py").including("scan_save_mem") - out_fg = FunctionGraph([seq, init_x], g_outs) - seq_val = np.arange(3) init_x_val = np.r_[-2, -1] test_input_vals = (seq_val, init_x_val) compare_numba_and_py( - out_fg, test_input_vals, numba_mode=numba_mode, py_mode=py_mode + [seq, init_x], g_outs, test_input_vals, numba_mode=numba_mode, py_mode=py_mode ) diff --git a/tests/link/numba/test_slinalg.py b/tests/link/numba/test_slinalg.py index 8e49627361..67ddc1daff 100644 --- a/tests/link/numba/test_slinalg.py +++ b/tests/link/numba/test_slinalg.py @@ -9,14 +9,14 @@ import pytensor import pytensor.tensor as pt -from pytensor.graph import FunctionGraph +from pytensor import config from tests import unittest_tools as utt from tests.link.numba.test_basic import compare_numba_and_py numba = pytest.importorskip("numba") -floatX = pytensor.config.floatX +floatX = config.floatX rng = np.random.default_rng(42849) @@ -79,16 +79,21 @@ def A_func(x): A_val = A_val + np.random.normal(size=(5, 5)) * 1j b_val = b_val + np.random.normal(size=b_shape) * 1j - X_np = f(A_func(A_val.copy()), b_val.copy()) + X_np = f(A_func(A_val), b_val) - test_input = transpose_func(A_func(A_val.copy()), trans) + test_input = transpose_func(A_func(A_val), trans) ATOL = 1e-8 if floatX.endswith("64") else 1e-4 RTOL = 1e-8 if floatX.endswith("64") else 1e-4 np.testing.assert_allclose(test_input @ X_np, b_val, atol=ATOL, rtol=RTOL) - compare_numba_and_py(f.maker.fgraph, [A_func(A_val.copy()), b_val.copy()]) + compiled_fgraph = f.maker.fgraph + compare_numba_and_py( + compiled_fgraph.inputs, + compiled_fgraph.outputs, + [A_func(A_val), b_val], + ) @pytest.mark.parametrize( @@ -159,12 +164,10 @@ def test_numba_Cholesky(lower, trans): cov_ = cov chol = pt.linalg.cholesky(cov_, lower=lower) - fg = FunctionGraph(outputs=[chol]) - x = np.array([0.1, 0.2, 0.3]).astype(floatX) val = np.eye(3).astype(floatX) + x[None, :] * x[:, None] - compare_numba_and_py(fg, [val]) + compare_numba_and_py([cov], [chol], [val]) def test_numba_Cholesky_raises_on_nan_input(): @@ -218,8 +221,7 @@ def test_block_diag(): B_val = np.random.normal(size=(3, 3)).astype(floatX) C_val = np.random.normal(size=(2, 2)).astype(floatX) D_val = np.random.normal(size=(4, 4)).astype(floatX) - out_fg = pytensor.graph.FunctionGraph([A, B, C, D], [X]) - compare_numba_and_py(out_fg, [A_val, B_val, C_val, D_val]) + compare_numba_and_py([A, B, C, D], [X], [A_val, B_val, C_val, D_val]) def test_lamch(): @@ -390,7 +392,7 @@ def A_func(x): ) op = f.maker.fgraph.outputs[0].owner.op - compare_numba_and_py(([A, b], [X]), inputs=[A_val, b_val], inplace=True) + compare_numba_and_py([A, b], [X], test_inputs=[A_val, b_val], inplace=True) # Calling this is destructive and will rewrite b_val to be the answer. Store copies of the inputs first. A_val_copy = A_val.copy() diff --git a/tests/link/numba/test_sparse.py b/tests/link/numba/test_sparse.py index 6a01a5db76..3d91ca13a8 100644 --- a/tests/link/numba/test_sparse.py +++ b/tests/link/numba/test_sparse.py @@ -100,4 +100,4 @@ def test_sparse_objmode(): UserWarning, match="Numba will use object mode to run SparseDot's perform method", ): - compare_numba_and_py(((x, y), (out,)), [x_val, y_val]) + compare_numba_and_py([x, y], out, [x_val, y_val]) diff --git a/tests/link/numba/test_subtensor.py b/tests/link/numba/test_subtensor.py index d63445bf77..d28c94f5b5 100644 --- a/tests/link/numba/test_subtensor.py +++ b/tests/link/numba/test_subtensor.py @@ -4,7 +4,6 @@ import pytest import pytensor.tensor as pt -from pytensor.graph import FunctionGraph from pytensor.tensor import as_tensor from pytensor.tensor.subtensor import ( AdvancedIncSubtensor, @@ -44,8 +43,7 @@ def test_Subtensor(x, indices): """Test NumPy's basic indexing.""" out_pt = x[indices] assert isinstance(out_pt.owner.op, Subtensor) - out_fg = FunctionGraph([], [out_pt]) - compare_numba_and_py(out_fg, []) + compare_numba_and_py([], [out_pt], []) @pytest.mark.parametrize( @@ -59,16 +57,14 @@ def test_AdvancedSubtensor1(x, indices): """Test NumPy's advanced indexing in one dimension.""" out_pt = advanced_subtensor1(x, *indices) assert isinstance(out_pt.owner.op, AdvancedSubtensor1) - out_fg = FunctionGraph([], [out_pt]) - compare_numba_and_py(out_fg, []) + compare_numba_and_py([], [out_pt], []) def test_AdvancedSubtensor1_out_of_bounds(): out_pt = advanced_subtensor1(np.arange(3), [4]) assert isinstance(out_pt.owner.op, AdvancedSubtensor1) - out_fg = FunctionGraph([], [out_pt]) with pytest.raises(IndexError): - compare_numba_and_py(out_fg, []) + compare_numba_and_py([], [out_pt], []) @pytest.mark.parametrize( @@ -151,7 +147,6 @@ def test_AdvancedSubtensor(x, indices, objmode_needed): x_pt = x.type() out_pt = x_pt[indices] assert isinstance(out_pt.owner.op, AdvancedSubtensor) - out_fg = FunctionGraph([x_pt], [out_pt]) with ( pytest.warns( UserWarning, @@ -161,7 +156,8 @@ def test_AdvancedSubtensor(x, indices, objmode_needed): else contextlib.nullcontext() ): compare_numba_and_py( - out_fg, + [x_pt], + [out_pt], [x.data], numba_mode=numba_mode.including("specialize"), ) @@ -195,19 +191,16 @@ def test_AdvancedSubtensor(x, indices, objmode_needed): def test_IncSubtensor(x, y, indices): out_pt = set_subtensor(x[indices], y) assert isinstance(out_pt.owner.op, IncSubtensor) - out_fg = FunctionGraph([], [out_pt]) - compare_numba_and_py(out_fg, []) + compare_numba_and_py([], [out_pt], []) out_pt = inc_subtensor(x[indices], y) assert isinstance(out_pt.owner.op, IncSubtensor) - out_fg = FunctionGraph([], [out_pt]) - compare_numba_and_py(out_fg, []) + compare_numba_and_py([], [out_pt], []) x_pt = x.type() out_pt = set_subtensor(x_pt[indices], y, inplace=True) assert isinstance(out_pt.owner.op, IncSubtensor) - out_fg = FunctionGraph([x_pt], [out_pt]) - compare_numba_and_py(out_fg, [x.data]) + compare_numba_and_py([x_pt], [out_pt], [x.data]) @pytest.mark.parametrize( @@ -249,13 +242,11 @@ def test_IncSubtensor(x, y, indices): def test_AdvancedIncSubtensor1(x, y, indices): out_pt = advanced_set_subtensor1(x, y, *indices) assert isinstance(out_pt.owner.op, AdvancedIncSubtensor1) - out_fg = FunctionGraph([], [out_pt]) - compare_numba_and_py(out_fg, []) + compare_numba_and_py([], [out_pt], []) out_pt = advanced_inc_subtensor1(x, y, *indices) assert isinstance(out_pt.owner.op, AdvancedIncSubtensor1) - out_fg = FunctionGraph([], [out_pt]) - compare_numba_and_py(out_fg, []) + compare_numba_and_py([], [out_pt], []) # With symbolic inputs x_pt = x.type() @@ -263,15 +254,13 @@ def test_AdvancedIncSubtensor1(x, y, indices): out_pt = AdvancedIncSubtensor1(inplace=True)(x_pt, y_pt, *indices) assert isinstance(out_pt.owner.op, AdvancedIncSubtensor1) - out_fg = FunctionGraph([x_pt, y_pt], [out_pt]) - compare_numba_and_py(out_fg, [x.data, y.data]) + compare_numba_and_py([x_pt, y_pt], [out_pt], [x.data, y.data]) out_pt = AdvancedIncSubtensor1(set_instead_of_inc=True, inplace=True)( x_pt, y_pt, *indices ) assert isinstance(out_pt.owner.op, AdvancedIncSubtensor1) - out_fg = FunctionGraph([x_pt, y_pt], [out_pt]) - compare_numba_and_py(out_fg, [x.data, y.data]) + compare_numba_and_py([x_pt, y_pt], [out_pt], [x.data, y.data]) @pytest.mark.parametrize( @@ -454,7 +443,7 @@ def test_AdvancedIncSubtensor( if set_requires_objmode else contextlib.nullcontext() ): - fn, _ = compare_numba_and_py(([x_pt, y_pt], [out_pt]), [x, y], numba_mode=mode) + fn, _ = compare_numba_and_py([x_pt, y_pt], out_pt, [x, y], numba_mode=mode) if inplace: # Test updates inplace @@ -474,7 +463,7 @@ def test_AdvancedIncSubtensor( if inc_requires_objmode else contextlib.nullcontext() ): - fn, _ = compare_numba_and_py(([x_pt, y_pt], [out_pt]), [x, y], numba_mode=mode) + fn, _ = compare_numba_and_py([x_pt, y_pt], out_pt, [x, y], numba_mode=mode) if inplace: # Test updates inplace x_orig = x.copy() diff --git a/tests/link/numba/test_tensor_basic.py b/tests/link/numba/test_tensor_basic.py index 95ab5799c1..0eebe115e9 100644 --- a/tests/link/numba/test_tensor_basic.py +++ b/tests/link/numba/test_tensor_basic.py @@ -6,15 +6,11 @@ import pytensor.tensor.basic as ptb from pytensor import config, function from pytensor.compile import get_mode -from pytensor.compile.sharedvalue import SharedVariable -from pytensor.graph.basic import Constant -from pytensor.graph.fg import FunctionGraph from pytensor.scalar import Add from pytensor.tensor.shape import Unbroadcast from tests.link.numba.test_basic import ( compare_numba_and_py, compare_shape_dtype, - set_test_value, ) from tests.tensor.test_basic import check_alloc_runtime_broadcast @@ -31,21 +27,18 @@ [ (0.0, (2, 3)), (1.1, (2, 3)), - (set_test_value(pt.scalar("a"), np.array(10.0, dtype=config.floatX)), (20,)), - (set_test_value(pt.vector("a"), np.ones(10, dtype=config.floatX)), (20, 10)), + ((pt.scalar("a"), np.array(10.0, dtype=config.floatX)), (20,)), + ((pt.vector("a"), np.ones(10, dtype=config.floatX)), (20, 10)), ], ) def test_Alloc(v, shape): + v, v_test = v if isinstance(v, tuple) else (v, None) g = pt.alloc(v, *shape) - g_fg = FunctionGraph(outputs=[g]) _, (numba_res,) = compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [v] if v_test is not None else [], + [g], + [v_test] if v_test is not None else [], ) assert numba_res.shape == shape @@ -57,58 +50,38 @@ def test_alloc_runtime_broadcast(): def test_AllocEmpty(): x = pt.empty((2, 3), dtype="float32") - x_fg = FunctionGraph([], [x]) # We cannot compare the values in the arrays, only the shapes and dtypes - compare_numba_and_py(x_fg, [], assert_fn=compare_shape_dtype) + compare_numba_and_py([], x, [], assert_fn=compare_shape_dtype) -@pytest.mark.parametrize( - "v", [set_test_value(ps.float64(), np.array(1.0, dtype="float64"))] -) -def test_TensorFromScalar(v): +def test_TensorFromScalar(): + v, v_test = ps.float64(), np.array(1.0, dtype="float64") g = ptb.TensorFromScalar()(v) - g_fg = FunctionGraph(outputs=[g]) compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [v], + g, + [v_test], ) -@pytest.mark.parametrize( - "v", - [ - set_test_value(pt.scalar(), np.array(1.0, dtype=config.floatX)), - ], -) -def test_ScalarFromTensor(v): +def test_ScalarFromTensor(): + v, v_test = pt.scalar(), np.array(1.0, dtype=config.floatX) g = ptb.ScalarFromTensor()(v) - g_fg = FunctionGraph(outputs=[g]) compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [v], + g, + [v_test], ) def test_Unbroadcast(): - v = set_test_value(pt.row(), np.array([[1.0, 2.0]], dtype=config.floatX)) + v, v_test = pt.row(), np.array([[1.0, 2.0]], dtype=config.floatX) g = Unbroadcast(0)(v) - g_fg = FunctionGraph(outputs=[g]) compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [v], + g, + [v_test], ) @@ -117,65 +90,52 @@ def test_Unbroadcast(): [ ( ( - set_test_value(pt.scalar(), np.array(1, dtype=config.floatX)), - set_test_value(pt.scalar(), np.array(2, dtype=config.floatX)), - set_test_value(pt.scalar(), np.array(3, dtype=config.floatX)), + (pt.scalar(), np.array(1, dtype=config.floatX)), + (pt.scalar(), np.array(2, dtype=config.floatX)), + (pt.scalar(), np.array(3, dtype=config.floatX)), ), config.floatX, ), ( ( - set_test_value(pt.dscalar(), np.array(1, dtype=np.float64)), - set_test_value(pt.lscalar(), np.array(3, dtype=np.int32)), + (pt.dscalar(), np.array(1, dtype=np.float64)), + (pt.lscalar(), np.array(3, dtype=np.int32)), ), "float64", ), ( - (set_test_value(pt.iscalar(), np.array(1, dtype=np.int32)),), + ((pt.iscalar(), np.array(1, dtype=np.int32)),), "float64", ), ( - (set_test_value(pt.scalar(dtype=bool), True),), + ((pt.scalar(dtype=bool), True),), bool, ), ], ) def test_MakeVector(vals, dtype): + vals, vals_test = zip(*vals, strict=True) g = ptb.MakeVector(dtype)(*vals) - g_fg = FunctionGraph(outputs=[g]) compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + vals, + [g], + vals_test, ) -@pytest.mark.parametrize( - "start, stop, step, dtype", - [ - ( - set_test_value(pt.lscalar(), np.array(1)), - set_test_value(pt.lscalar(), np.array(10)), - set_test_value(pt.lscalar(), np.array(3)), - config.floatX, - ), - ], -) -def test_ARange(start, stop, step, dtype): +def test_ARange(): + start, start_test = pt.lscalar(), np.array(1) + stop, stop_tset = pt.lscalar(), np.array(10) + step, step_test = pt.lscalar(), np.array(3) + dtype = config.floatX + g = ptb.ARange(dtype)(start, stop, step) - g_fg = FunctionGraph(outputs=[g]) compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [start, stop, step], + g, + [start_test, stop_tset, step_test], ) @@ -184,80 +144,60 @@ def test_ARange(start, stop, step, dtype): [ ( ( - set_test_value( - pt.matrix(), rng.normal(size=(1, 2)).astype(config.floatX) - ), - set_test_value( - pt.matrix(), rng.normal(size=(1, 2)).astype(config.floatX) - ), + (pt.matrix(), rng.normal(size=(1, 2)).astype(config.floatX)), + (pt.matrix(), rng.normal(size=(1, 2)).astype(config.floatX)), ), 0, ), ( ( - set_test_value( - pt.matrix(), rng.normal(size=(2, 1)).astype(config.floatX) - ), - set_test_value( - pt.matrix(), rng.normal(size=(3, 1)).astype(config.floatX) - ), + (pt.matrix(), rng.normal(size=(2, 1)).astype(config.floatX)), + (pt.matrix(), rng.normal(size=(3, 1)).astype(config.floatX)), ), 0, ), ( ( - set_test_value( - pt.matrix(), rng.normal(size=(1, 2)).astype(config.floatX) - ), - set_test_value( - pt.matrix(), rng.normal(size=(1, 2)).astype(config.floatX) - ), + (pt.matrix(), rng.normal(size=(1, 2)).astype(config.floatX)), + (pt.matrix(), rng.normal(size=(1, 2)).astype(config.floatX)), ), 1, ), ( ( - set_test_value( - pt.matrix(), rng.normal(size=(2, 2)).astype(config.floatX) - ), - set_test_value( - pt.matrix(), rng.normal(size=(2, 1)).astype(config.floatX) - ), + (pt.matrix(), rng.normal(size=(2, 2)).astype(config.floatX)), + (pt.matrix(), rng.normal(size=(2, 1)).astype(config.floatX)), ), 1, ), ], ) def test_Join(vals, axis): + vals, vals_test = zip(*vals, strict=True) g = pt.join(axis, *vals) - g_fg = FunctionGraph(outputs=[g]) compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + vals, + g, + vals_test, ) def test_Join_view(): - vals = ( - set_test_value(pt.matrix(), rng.normal(size=(2, 2)).astype(config.floatX)), - set_test_value(pt.matrix(), rng.normal(size=(2, 2)).astype(config.floatX)), + vals, vals_test = zip( + *( + (pt.matrix(), rng.normal(size=(2, 2)).astype(config.floatX)), + (pt.matrix(), rng.normal(size=(2, 2)).astype(config.floatX)), + ), + strict=True, ) g = ptb.Join(view=1)(1, *vals) - g_fg = FunctionGraph(outputs=[g]) with pytest.raises(NotImplementedError): compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + vals, + g, + vals_test, ) @@ -267,57 +207,47 @@ def test_Join_view(): ( 0, 0, - set_test_value(pt.vector(), rng.normal(size=20).astype(config.floatX)), - set_test_value(pt.vector(dtype="int64"), []), + (pt.vector(), rng.normal(size=20).astype(config.floatX)), + (pt.vector(dtype="int64"), []), ), ( 5, 0, - set_test_value(pt.vector(), rng.normal(size=5).astype(config.floatX)), - set_test_value( - pt.vector(dtype="int64"), rng.multinomial(5, np.ones(5) / 5) - ), + (pt.vector(), rng.normal(size=5).astype(config.floatX)), + (pt.vector(dtype="int64"), rng.multinomial(5, np.ones(5) / 5)), ), ( 5, 0, - set_test_value(pt.vector(), rng.normal(size=10).astype(config.floatX)), - set_test_value( - pt.vector(dtype="int64"), rng.multinomial(10, np.ones(5) / 5) - ), + (pt.vector(), rng.normal(size=10).astype(config.floatX)), + (pt.vector(dtype="int64"), rng.multinomial(10, np.ones(5) / 5)), ), ( 5, -1, - set_test_value(pt.matrix(), rng.normal(size=(11, 7)).astype(config.floatX)), - set_test_value( - pt.vector(dtype="int64"), rng.multinomial(7, np.ones(5) / 5) - ), + (pt.matrix(), rng.normal(size=(11, 7)).astype(config.floatX)), + (pt.vector(dtype="int64"), rng.multinomial(7, np.ones(5) / 5)), ), ( 5, -2, - set_test_value(pt.matrix(), rng.normal(size=(11, 7)).astype(config.floatX)), - set_test_value( - pt.vector(dtype="int64"), rng.multinomial(11, np.ones(5) / 5) - ), + (pt.matrix(), rng.normal(size=(11, 7)).astype(config.floatX)), + (pt.vector(dtype="int64"), rng.multinomial(11, np.ones(5) / 5)), ), ], ) def test_Split(n_splits, axis, values, sizes): + values, values_test = values + sizes, sizes_test = sizes g = pt.split(values, sizes, n_splits, axis=axis) assert len(g) == n_splits if n_splits == 0: return - g_fg = FunctionGraph(outputs=[g] if n_splits == 1 else g) compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [values, sizes], + g, + [values_test, sizes_test], ) @@ -349,34 +279,27 @@ def test_Split_view(): "val, offset", [ ( - set_test_value( - pt.matrix(), np.arange(10 * 10, dtype=config.floatX).reshape((10, 10)) - ), + (pt.matrix(), np.arange(10 * 10, dtype=config.floatX).reshape((10, 10))), 0, ), ( - set_test_value( - pt.matrix(), np.arange(10 * 10, dtype=config.floatX).reshape((10, 10)) - ), + (pt.matrix(), np.arange(10 * 10, dtype=config.floatX).reshape((10, 10))), -1, ), ( - set_test_value(pt.vector(), np.arange(10, dtype=config.floatX)), + (pt.vector(), np.arange(10, dtype=config.floatX)), 0, ), ], ) def test_ExtractDiag(val, offset): + val, val_test = val g = pt.diag(val, offset) - g_fg = FunctionGraph(outputs=[g]) compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [val], + g, + [val_test], ) @@ -407,30 +330,28 @@ def wrap(x): @pytest.mark.parametrize( "n, m, k, dtype", [ - (set_test_value(pt.lscalar(), np.array(1, dtype=np.int64)), None, 0, None), + ((pt.lscalar(), np.array(1, dtype=np.int64)), None, 0, None), ( - set_test_value(pt.lscalar(), np.array(1, dtype=np.int64)), - set_test_value(pt.lscalar(), np.array(2, dtype=np.int64)), + (pt.lscalar(), np.array(1, dtype=np.int64)), + (pt.lscalar(), np.array(2, dtype=np.int64)), 0, "float32", ), ( - set_test_value(pt.lscalar(), np.array(1, dtype=np.int64)), - set_test_value(pt.lscalar(), np.array(2, dtype=np.int64)), + (pt.lscalar(), np.array(1, dtype=np.int64)), + (pt.lscalar(), np.array(2, dtype=np.int64)), 1, "int64", ), ], ) def test_Eye(n, m, k, dtype): + n, n_test = n + m, m_test = m if m is not None else (None, None) g = pt.eye(n, m, k, dtype=dtype) - g_fg = FunctionGraph(outputs=[g]) compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], + [n, m] if m is not None else [n], + g, + [n_test, m_test] if m is not None else [n_test], ) diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index d5c23c83e4..f080fe70df 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -9,10 +9,10 @@ from pytensor.compile.builders import OpFromGraph from pytensor.compile.function import function from pytensor.compile.mode import PYTORCH, Mode -from pytensor.compile.sharedvalue import SharedVariable, shared +from pytensor.compile.sharedvalue import shared from pytensor.configdefaults import config from pytensor.graph import RewriteDatabaseQuery -from pytensor.graph.basic import Apply +from pytensor.graph.basic import Apply, Variable from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import Op from pytensor.ifelse import ifelse @@ -39,10 +39,10 @@ def compare_pytorch_and_py( - fgraph: FunctionGraph, + graph_inputs: Iterable[Variable], + graph_outputs: Variable | Iterable[Variable], test_inputs: Iterable, assert_fn: Callable | None = None, - must_be_device_array: bool = True, pytorch_mode=pytorch_mode, py_mode=py_mode, ): @@ -50,8 +50,10 @@ def compare_pytorch_and_py( Parameters ---------- - fgraph: FunctionGraph - PyTensor function Graph object + graph_inputs + Symbolic inputs to the graph + graph_outputs: + Symbolic outputs of the graph test_inputs: iter Numerical inputs for testing the function graph assert_fn: func, opt @@ -63,24 +65,22 @@ def compare_pytorch_and_py( if assert_fn is None: assert_fn = partial(np.testing.assert_allclose) - fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)] + if any(inp.owner is not None for inp in graph_inputs): + raise ValueError("Inputs must be root variables") - pytensor_torch_fn = function(fn_inputs, fgraph.outputs, mode=pytorch_mode) + pytensor_torch_fn = function(graph_inputs, graph_outputs, mode=pytorch_mode) pytorch_res = pytensor_torch_fn(*test_inputs) - if isinstance(pytorch_res, list): - assert all(isinstance(res, np.ndarray) for res in pytorch_res) - else: - assert isinstance(pytorch_res, np.ndarray) - - pytensor_py_fn = function(fn_inputs, fgraph.outputs, mode=py_mode) + pytensor_py_fn = function(graph_inputs, graph_outputs, mode=py_mode) py_res = pytensor_py_fn(*test_inputs) - if len(fgraph.outputs) > 1: + if isinstance(graph_outputs, list | tuple): for pytorch_res_i, py_res_i in zip(pytorch_res, py_res, strict=True): + assert not isinstance(pytorch_res_i, torch.Tensor) assert_fn(pytorch_res_i, py_res_i) else: - assert_fn(pytorch_res[0], py_res[0]) + assert not isinstance(pytorch_res, torch.Tensor) + assert_fn(pytorch_res, py_res) return pytensor_torch_fn, pytorch_res @@ -231,7 +231,8 @@ def test_alloc_and_empty(): v = vector("v", shape=(3,), dtype="float64") out = alloc(v, dim0, dim1, 3) compare_pytorch_and_py( - FunctionGraph([v, dim1], [out]), + [v, dim1], + [out], [np.array([1, 2, 3]), np.array(7)], ) @@ -244,7 +245,8 @@ def test_arange(): out = arange(start, stop, step, dtype="int16") compare_pytorch_and_py( - FunctionGraph([start, stop, step], [out]), + [start, stop, step], + [out], [np.array(1), np.array(10), np.array(2)], ) @@ -254,16 +256,18 @@ def test_pytorch_Join(): b = matrix("b") x = ptb.join(0, a, b) - x_fg = FunctionGraph([a, b], [x]) + compare_pytorch_and_py( - x_fg, + [a, b], + [x], [ np.c_[[1.0, 2.0, 3.0]].astype(config.floatX), np.c_[[4.0, 5.0, 6.0]].astype(config.floatX), ], ) compare_pytorch_and_py( - x_fg, + [a, b], + [x], [ np.c_[[1.0, 2.0, 3.0]].astype(config.floatX), np.c_[[4.0, 5.0]].astype(config.floatX), @@ -271,16 +275,18 @@ def test_pytorch_Join(): ) x = ptb.join(1, a, b) - x_fg = FunctionGraph([a, b], [x]) + compare_pytorch_and_py( - x_fg, + [a, b], + [x], [ np.c_[[1.0, 2.0, 3.0]].astype(config.floatX), np.c_[[4.0, 5.0, 6.0]].astype(config.floatX), ], ) compare_pytorch_and_py( - x_fg, + [a, b], + [x], [ np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX), np.c_[[5.0, 6.0]].astype(config.floatX), @@ -309,9 +315,8 @@ def test_eye(dtype): def test_pytorch_MakeVector(): x = ptb.make_vector(1, 2, 3) - x_fg = FunctionGraph([], [x]) - compare_pytorch_and_py(x_fg, []) + compare_pytorch_and_py([], [x], []) def test_pytorch_ifelse(): @@ -320,15 +325,13 @@ def test_pytorch_ifelse(): a = scalar("a") x = ifelse(a < 0.5, tuple(np.r_[p1_vals, p2_vals]), tuple(np.r_[p2_vals, p1_vals])) - x_fg = FunctionGraph([a], x) - compare_pytorch_and_py(x_fg, np.array([0.2], dtype=config.floatX)) + compare_pytorch_and_py([a], x, np.array([0.2], dtype=config.floatX)) a = scalar("a") x = ifelse(a < 0.4, tuple(np.r_[p1_vals, p2_vals]), tuple(np.r_[p2_vals, p1_vals])) - x_fg = FunctionGraph([a], x) - compare_pytorch_and_py(x_fg, np.array([0.5], dtype=config.floatX)) + compare_pytorch_and_py([a], x, np.array([0.5], dtype=config.floatX)) def test_pytorch_OpFromGraph(): @@ -343,8 +346,7 @@ def test_pytorch_OpFromGraph(): yv = np.ones((2, 2), dtype=config.floatX) * 3 zv = np.ones((2, 2), dtype=config.floatX) * 5 - f = FunctionGraph([x, y, z], [out]) - compare_pytorch_and_py(f, [xv, yv, zv]) + compare_pytorch_and_py([x, y, z], [out], [xv, yv, zv]) def test_pytorch_link_references(): @@ -380,15 +382,13 @@ def inner_fn(x): def test_pytorch_scipy(): x = vector("a", shape=(3,)) out = expit(x) - f = FunctionGraph([x], [out]) - compare_pytorch_and_py(f, [np.random.rand(3)]) + compare_pytorch_and_py([x], [out], [np.random.rand(3)]) def test_pytorch_softplus(): x = vector("a", shape=(3,)) out = softplus(x) - f = FunctionGraph([x], [out]) - compare_pytorch_and_py(f, [np.random.rand(3)]) + compare_pytorch_and_py([x], [out], [np.random.rand(3)]) def test_ScalarLoop(): @@ -436,13 +436,15 @@ def test_ScalarLoop_Elemwise_single_carries(): x0 = pt.vector("x0", dtype="float32") state, done = op(n_steps, x0) - f = FunctionGraph([n_steps, x0], [state, done]) args = [ np.array(10).astype("int32"), np.arange(0, 5).astype("float32"), ] compare_pytorch_and_py( - f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6) + [n_steps, x0], + [state, done], + args, + assert_fn=partial(np.testing.assert_allclose, rtol=1e-6), ) @@ -462,14 +464,16 @@ def test_ScalarLoop_Elemwise_multi_carries(): x1 = pt.tensor("c0", dtype="float32", shape=(7, 3, 1)) *states, done = op(n_steps, x0, x1) - f = FunctionGraph([n_steps, x0, x1], [*states, done]) args = [ np.array(10).astype("int32"), np.arange(0, 5).astype("float32"), np.random.rand(7, 3, 1).astype("float32"), ] compare_pytorch_and_py( - f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6) + [n_steps, x0, x1], + [*states, done], + args, + assert_fn=partial(np.testing.assert_allclose, rtol=1e-6), ) @@ -518,6 +522,5 @@ def test_Split(n_splits, axis, values, sizes): assert len(g) == n_splits if n_splits == 0: return - g_fg = FunctionGraph(inputs=[i, s], outputs=[g] if n_splits == 1 else g) - compare_pytorch_and_py(g_fg, [values, sizes]) + compare_pytorch_and_py([i, s], g, [values, sizes]) diff --git a/tests/link/pytorch/test_blas.py b/tests/link/pytorch/test_blas.py index 35f7dd7b6a..4b9fc4d55f 100644 --- a/tests/link/pytorch/test_blas.py +++ b/tests/link/pytorch/test_blas.py @@ -2,7 +2,6 @@ import pytest from pytensor.configdefaults import config -from pytensor.graph.fg import FunctionGraph from pytensor.tensor import blas as pt_blas from pytensor.tensor.type import tensor3 from tests.link.pytorch.test_basic import compare_pytorch_and_py @@ -15,8 +14,8 @@ def test_pytorch_BatchedDot(): b = tensor3("b") b_test = np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2)) out = pt_blas.BatchedDot()(a, b) - fgraph = FunctionGraph([a, b], [out]) - pytensor_pytorch_fn, _ = compare_pytorch_and_py(fgraph, [a_test, b_test]) + + pytensor_pytorch_fn, _ = compare_pytorch_and_py([a, b], [out], [a_test, b_test]) # A dimension mismatch should raise a TypeError for compatibility inputs = [a_test[:-1], b_test] diff --git a/tests/link/pytorch/test_elemwise.py b/tests/link/pytorch/test_elemwise.py index 2a9cf39c99..152b235074 100644 --- a/tests/link/pytorch/test_elemwise.py +++ b/tests/link/pytorch/test_elemwise.py @@ -5,7 +5,6 @@ import pytensor.tensor as pt import pytensor.tensor.math as ptm from pytensor.configdefaults import config -from pytensor.graph.fg import FunctionGraph from pytensor.scalar.basic import ScalarOp, get_scalar_type from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax @@ -20,17 +19,23 @@ def test_pytorch_Dimshuffle(): a_pt = matrix("a") x = a_pt.T - x_fg = FunctionGraph([a_pt], [x]) - compare_pytorch_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]) + + compare_pytorch_and_py( + [a_pt], [x], [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)] + ) x = a_pt.dimshuffle([0, 1, "x"]) - x_fg = FunctionGraph([a_pt], [x]) - compare_pytorch_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]) + + compare_pytorch_and_py( + [a_pt], [x], [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)] + ) a_pt = tensor(dtype=config.floatX, shape=(None, 1)) x = a_pt.dimshuffle((0,)) - x_fg = FunctionGraph([a_pt], [x]) - compare_pytorch_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)]) + + compare_pytorch_and_py( + [a_pt], [x], [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)] + ) def test_multiple_input_output(): @@ -38,24 +43,21 @@ def test_multiple_input_output(): y = vector("y") out = pt.mul(x, y) - fg = FunctionGraph(outputs=[out], clone=False) - compare_pytorch_and_py(fg, [[1.5], [2.5]]) + compare_pytorch_and_py([x, y], [out], [[1.5], [2.5]]) x = vector("x") y = vector("y") div = pt.int_div(x, y) pt_sum = pt.add(y, x) - fg = FunctionGraph(outputs=[div, pt_sum], clone=False) - compare_pytorch_and_py(fg, [[1.5], [2.5]]) + compare_pytorch_and_py([x, y], [div, pt_sum], [[1.5], [2.5]]) def test_pytorch_elemwise(): x = pt.vector("x") out = pt.log(1 - x) - fg = FunctionGraph([x], [out]) - compare_pytorch_and_py(fg, [[0.9, 0.9]]) + compare_pytorch_and_py([x], [out], [[0.9, 0.9]]) @pytest.mark.parametrize("fn", [ptm.sum, ptm.prod, ptm.max, ptm.min]) @@ -81,9 +83,8 @@ def test_pytorch_careduce(fn, axis): ).astype(config.floatX) x = fn(a_pt, axis=axis) - x_fg = FunctionGraph([a_pt], [x]) - compare_pytorch_and_py(x_fg, [test_value]) + compare_pytorch_and_py([a_pt], [x], [test_value]) @pytest.mark.parametrize("fn", [ptm.any, ptm.all]) @@ -93,9 +94,8 @@ def test_pytorch_any_all(fn, axis): test_value = np.array([[True, False, True], [False, True, True]]) x = fn(a_pt, axis=axis) - x_fg = FunctionGraph([a_pt], [x]) - compare_pytorch_and_py(x_fg, [test_value]) + compare_pytorch_and_py([a_pt], [x], [test_value]) @pytest.mark.parametrize("dtype", ["float64", "int64"]) @@ -103,7 +103,6 @@ def test_pytorch_any_all(fn, axis): def test_softmax(axis, dtype): x = matrix("x", dtype=dtype) out = softmax(x, axis=axis) - fgraph = FunctionGraph([x], [out]) test_input = np.arange(6, dtype=config.floatX).reshape(2, 3) if dtype == "int64": @@ -111,9 +110,9 @@ def test_softmax(axis, dtype): NotImplementedError, match="Pytorch Softmax is not currently implemented for non-float types.", ): - compare_pytorch_and_py(fgraph, [test_input]) + compare_pytorch_and_py([x], [out], [test_input]) else: - compare_pytorch_and_py(fgraph, [test_input]) + compare_pytorch_and_py([x], [out], [test_input]) @pytest.mark.parametrize("dtype", ["float64", "int64"]) @@ -121,7 +120,6 @@ def test_softmax(axis, dtype): def test_logsoftmax(axis, dtype): x = matrix("x", dtype=dtype) out = log_softmax(x, axis=axis) - fgraph = FunctionGraph([x], [out]) test_input = np.arange(6, dtype=config.floatX).reshape(2, 3) if dtype == "int64": @@ -129,9 +127,9 @@ def test_logsoftmax(axis, dtype): NotImplementedError, match="Pytorch LogSoftmax is not currently implemented for non-float types.", ): - compare_pytorch_and_py(fgraph, [test_input]) + compare_pytorch_and_py([x], [out], [test_input]) else: - compare_pytorch_and_py(fgraph, [test_input]) + compare_pytorch_and_py([x], [out], [test_input]) @pytest.mark.parametrize("axis", [None, 0, 1]) @@ -141,16 +139,14 @@ def test_softmax_grad(axis): sm = matrix("sm") sm_value = np.arange(6, dtype=config.floatX).reshape(2, 3) out = SoftmaxGrad(axis=axis)(dy, sm) - fgraph = FunctionGraph([dy, sm], [out]) - compare_pytorch_and_py(fgraph, [dy_value, sm_value]) + compare_pytorch_and_py([dy, sm], [out], [dy_value, sm_value]) def test_cast(): x = matrix("x", dtype="float32") out = pt.cast(x, "int32") - fgraph = FunctionGraph([x], [out]) _, [res] = compare_pytorch_and_py( - fgraph, [np.arange(6, dtype="float32").reshape(2, 3)] + [x], [out], [np.arange(6, dtype="float32").reshape(2, 3)] ) assert res.dtype == np.int32 diff --git a/tests/link/pytorch/test_extra_ops.py b/tests/link/pytorch/test_extra_ops.py index c615176a45..2f72f7a908 100644 --- a/tests/link/pytorch/test_extra_ops.py +++ b/tests/link/pytorch/test_extra_ops.py @@ -2,7 +2,6 @@ import pytest import pytensor.tensor as pt -from pytensor.graph import FunctionGraph from tests.link.pytorch.test_basic import compare_pytorch_and_py @@ -31,16 +30,14 @@ def test_pytorch_CumOp(axis, dtype): out = pt.cumprod(a, axis=axis) else: out = pt.cumsum(a, axis=axis) - # Create a PyTensor `FunctionGraph` - fgraph = FunctionGraph([a], [out]) - # Pass the graph and inputs to the testing function - compare_pytorch_and_py(fgraph, [test_value]) + # Pass the inputs and outputs to the testing function + compare_pytorch_and_py([a], [out], [test_value]) # For the second mode of CumOp out = pt.cumprod(a, axis=axis) - fgraph = FunctionGraph([a], [out]) - compare_pytorch_and_py(fgraph, [test_value]) + + compare_pytorch_and_py([a], [out], [test_value]) @pytest.mark.parametrize("axis, repeats", [(0, (1, 2, 3)), (1, (3, 3)), (None, 3)]) @@ -50,8 +47,8 @@ def test_pytorch_Repeat(axis, repeats): test_value = np.arange(6, dtype="float64").reshape((3, 2)) out = pt.repeat(a, repeats, axis=axis) - fgraph = FunctionGraph([a], [out]) - compare_pytorch_and_py(fgraph, [test_value]) + + compare_pytorch_and_py([a], [out], [test_value]) @pytest.mark.parametrize("axis", [None, 0, 1]) @@ -63,8 +60,8 @@ def test_pytorch_Unique_axis(axis): ) out = pt.unique(a, axis=axis) - fgraph = FunctionGraph([a], [out]) - compare_pytorch_and_py(fgraph, [test_value]) + + compare_pytorch_and_py([a], [out], [test_value]) @pytest.mark.parametrize("return_inverse", [False, True]) @@ -86,5 +83,7 @@ def test_pytorch_Unique_params(return_index, return_inverse, return_counts): return_counts=return_counts, axis=0, ) - fgraph = FunctionGraph([a], [out[0] if isinstance(out, list) else out]) - compare_pytorch_and_py(fgraph, [test_value]) + + compare_pytorch_and_py( + [a], [out[0] if isinstance(out, list) else out], [test_value] + ) diff --git a/tests/link/pytorch/test_math.py b/tests/link/pytorch/test_math.py index affca4ad32..9d9f9318a8 100644 --- a/tests/link/pytorch/test_math.py +++ b/tests/link/pytorch/test_math.py @@ -1,7 +1,6 @@ import numpy as np from pytensor.configdefaults import config -from pytensor.graph.fg import FunctionGraph from pytensor.tensor.type import matrix, scalar, vector from tests.link.pytorch.test_basic import compare_pytorch_and_py @@ -20,10 +19,12 @@ def test_pytorch_dot(): # 2D * 2D out = A.dot(A * alpha) + beta * A - fgraph = FunctionGraph([A, alpha, beta], [out]) - compare_pytorch_and_py(fgraph, [A_test, alpha_test, beta_test]) + + compare_pytorch_and_py([A, alpha, beta], [out], [A_test, alpha_test, beta_test]) # 1D * 2D and 1D * 1D out = y.dot(alpha * A).dot(x) + beta * y - fgraph = FunctionGraph([y, x, A, alpha, beta], [out]) - compare_pytorch_and_py(fgraph, [y_test, x_test, A_test, alpha_test, beta_test]) + + compare_pytorch_and_py( + [y, x, A, alpha, beta], [out], [y_test, x_test, A_test, alpha_test, beta_test] + ) diff --git a/tests/link/pytorch/test_nlinalg.py b/tests/link/pytorch/test_nlinalg.py index 55e7c447e3..7e061f7cfc 100644 --- a/tests/link/pytorch/test_nlinalg.py +++ b/tests/link/pytorch/test_nlinalg.py @@ -1,11 +1,8 @@ -from collections.abc import Sequence - import numpy as np import pytest from pytensor.compile.function import function from pytensor.configdefaults import config -from pytensor.graph.fg import FunctionGraph from pytensor.tensor import nlinalg as pt_nla from pytensor.tensor.type import matrix from tests.link.pytorch.test_basic import compare_pytorch_and_py @@ -29,13 +26,12 @@ def matrix_test(): def test_lin_alg_no_params(func, matrix_test): x, test_value = matrix_test - out = func(x) - out_fg = FunctionGraph([x], out if isinstance(out, Sequence) else [out]) + outs = func(x) def assert_fn(x, y): np.testing.assert_allclose(x, y, rtol=1e-3) - compare_pytorch_and_py(out_fg, [test_value], assert_fn=assert_fn) + compare_pytorch_and_py([x], outs, [test_value], assert_fn=assert_fn) @pytest.mark.parametrize( @@ -50,8 +46,8 @@ def assert_fn(x, y): def test_qr(mode, matrix_test): x, test_value = matrix_test outs = pt_nla.qr(x, mode=mode) - out_fg = FunctionGraph([x], outs if isinstance(outs, list) else [outs]) - compare_pytorch_and_py(out_fg, [test_value]) + + compare_pytorch_and_py([x], outs, [test_value]) @pytest.mark.parametrize("compute_uv", [True, False]) @@ -60,18 +56,16 @@ def test_svd(compute_uv, full_matrices, matrix_test): x, test_value = matrix_test out = pt_nla.svd(x, full_matrices=full_matrices, compute_uv=compute_uv) - out_fg = FunctionGraph([x], out if isinstance(out, list) else [out]) - compare_pytorch_and_py(out_fg, [test_value]) + compare_pytorch_and_py([x], out, [test_value]) def test_pinv(): x = matrix("x") x_inv = pt_nla.pinv(x) - fgraph = FunctionGraph([x], [x_inv]) x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) - compare_pytorch_and_py(fgraph, [x_np]) + compare_pytorch_and_py([x], [x_inv], [x_np]) @pytest.mark.parametrize("hermitian", [False, True]) @@ -106,8 +100,7 @@ def test_kron(): y = matrix("y") z = pt_nla.kron(x, y) - fgraph = FunctionGraph([x, y], [z]) x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) y_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) - compare_pytorch_and_py(fgraph, [x_np, y_np]) + compare_pytorch_and_py([x, y], [z], [x_np, y_np]) diff --git a/tests/link/pytorch/test_shape.py b/tests/link/pytorch/test_shape.py index 152aa8ddf3..4bfe6e1a2b 100644 --- a/tests/link/pytorch/test_shape.py +++ b/tests/link/pytorch/test_shape.py @@ -2,7 +2,6 @@ import pytensor.tensor as pt from pytensor.configdefaults import config -from pytensor.graph.fg import FunctionGraph from pytensor.tensor.shape import Shape, Shape_i, Unbroadcast, reshape from pytensor.tensor.type import iscalar, vector from tests.link.pytorch.test_basic import compare_pytorch_and_py @@ -11,29 +10,27 @@ def test_pytorch_shape_ops(): x_np = np.zeros((20, 3)) x = Shape()(pt.as_tensor_variable(x_np)) - x_fg = FunctionGraph([], [x]) - compare_pytorch_and_py(x_fg, [], must_be_device_array=False) + compare_pytorch_and_py([], [x], []) x = Shape_i(1)(pt.as_tensor_variable(x_np)) - x_fg = FunctionGraph([], [x]) - compare_pytorch_and_py(x_fg, [], must_be_device_array=False) + compare_pytorch_and_py([], [x], []) def test_pytorch_specify_shape(): in_pt = pt.matrix("in") x = pt.specify_shape(in_pt, (4, None)) - x_fg = FunctionGraph([in_pt], [x]) - compare_pytorch_and_py(x_fg, [np.ones((4, 5)).astype(config.floatX)]) + compare_pytorch_and_py([in_pt], [x], [np.ones((4, 5)).astype(config.floatX)]) # When used to assert two arrays have similar shapes in_pt = pt.matrix("in") shape_pt = pt.matrix("shape") x = pt.specify_shape(in_pt, shape_pt.shape) - x_fg = FunctionGraph([in_pt, shape_pt], [x]) + compare_pytorch_and_py( - x_fg, + [in_pt, shape_pt], + [x], [np.ones((4, 5)).astype(config.floatX), np.ones((4, 5)).astype(config.floatX)], ) @@ -41,21 +38,22 @@ def test_pytorch_specify_shape(): def test_pytorch_Reshape_constant(): a = vector("a") x = reshape(a, (2, 2)) - x_fg = FunctionGraph([a], [x]) - compare_pytorch_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)]) + + compare_pytorch_and_py([a], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)]) def test_pytorch_Reshape_dynamic(): a = vector("a") shape_pt = iscalar("b") x = reshape(a, (shape_pt, shape_pt)) - x_fg = FunctionGraph([a, shape_pt], [x]) - compare_pytorch_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2]) + + compare_pytorch_and_py( + [a, shape_pt], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2] + ) def test_pytorch_unbroadcast(): x_np = np.zeros((20, 1, 1)) x = Unbroadcast(0, 2)(pt.as_tensor_variable(x_np)) - x_fg = FunctionGraph([], [x]) - compare_pytorch_and_py(x_fg, []) + compare_pytorch_and_py([], [x], []) diff --git a/tests/link/pytorch/test_sort.py b/tests/link/pytorch/test_sort.py index 7912dd4a03..686a455409 100644 --- a/tests/link/pytorch/test_sort.py +++ b/tests/link/pytorch/test_sort.py @@ -1,7 +1,6 @@ import numpy as np import pytest -from pytensor.graph import FunctionGraph from pytensor.tensor import matrix from pytensor.tensor.sort import argsort, sort from tests.link.pytorch.test_basic import compare_pytorch_and_py @@ -12,6 +11,5 @@ def test_sort(func, axis): x = matrix("x", shape=(2, 2), dtype="float64") out = func(x, axis=axis) - fgraph = FunctionGraph([x], [out]) arr = np.array([[1.0, 4.0], [5.0, 2.0]]) - compare_pytorch_and_py(fgraph, [arr]) + compare_pytorch_and_py([x], [out], [arr]) diff --git a/tests/link/pytorch/test_subtensor.py b/tests/link/pytorch/test_subtensor.py index fb2b3390d3..15c32c2824 100644 --- a/tests/link/pytorch/test_subtensor.py +++ b/tests/link/pytorch/test_subtensor.py @@ -6,7 +6,6 @@ import pytensor.scalar as ps import pytensor.tensor as pt from pytensor.configdefaults import config -from pytensor.graph.fg import FunctionGraph from pytensor.tensor import inc_subtensor, set_subtensor from pytensor.tensor import subtensor as pt_subtensor from tests.link.pytorch.test_basic import compare_pytorch_and_py @@ -19,38 +18,33 @@ def test_pytorch_Subtensor(): out_pt = x_pt[1, 2, 0] assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) - out_fg = FunctionGraph([x_pt], [out_pt]) - compare_pytorch_and_py(out_fg, [x_np]) + + compare_pytorch_and_py([x_pt], [out_pt], [x_np]) out_pt = x_pt[1:, 1, :] assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) - out_fg = FunctionGraph([x_pt], [out_pt]) - compare_pytorch_and_py(out_fg, [x_np]) + compare_pytorch_and_py([x_pt], [out_pt], [x_np]) out_pt = x_pt[:2, 1, :] assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) - out_fg = FunctionGraph([x_pt], [out_pt]) - compare_pytorch_and_py(out_fg, [x_np]) + compare_pytorch_and_py([x_pt], [out_pt], [x_np]) out_pt = x_pt[1:2, 1, :] assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) - out_fg = FunctionGraph([x_pt], [out_pt]) - compare_pytorch_and_py(out_fg, [x_np]) + compare_pytorch_and_py([x_pt], [out_pt], [x_np]) # symbolic index a_pt = ps.int64("a") a_np = 1 out_pt = x_pt[a_pt, 2, a_pt:2] assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) - out_fg = FunctionGraph([x_pt, a_pt], [out_pt]) - compare_pytorch_and_py(out_fg, [x_np, a_np]) + compare_pytorch_and_py([x_pt, a_pt], [out_pt], [x_np, a_np]) with pytest.raises( NotImplementedError, match="Negative step sizes are not supported in Pytorch" ): out_pt = x_pt[::-1] - out_fg = FunctionGraph([x_pt], [out_pt]) - compare_pytorch_and_py(out_fg, [x_np]) + compare_pytorch_and_py([x_pt], [out_pt], [x_np]) def test_pytorch_AdvSubtensor(): @@ -60,52 +54,43 @@ def test_pytorch_AdvSubtensor(): out_pt = pt_subtensor.advanced_subtensor1(x_pt, [1, 2]) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor1) - out_fg = FunctionGraph([x_pt], [out_pt]) - compare_pytorch_and_py(out_fg, [x_np]) + compare_pytorch_and_py([x_pt], [out_pt], [x_np]) out_pt = x_pt[[1, 2], [2, 3]] assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor) - out_fg = FunctionGraph([x_pt], [out_pt]) - compare_pytorch_and_py(out_fg, [x_np]) + compare_pytorch_and_py([x_pt], [out_pt], [x_np]) out_pt = x_pt[[1, 2], 1:] assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor) - out_fg = FunctionGraph([x_pt], [out_pt]) - compare_pytorch_and_py(out_fg, [x_np]) + compare_pytorch_and_py([x_pt], [out_pt], [x_np]) out_pt = x_pt[[1, 2], :, [3, 4]] assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor) - out_fg = FunctionGraph([x_pt], [out_pt]) - compare_pytorch_and_py(out_fg, [x_np]) + compare_pytorch_and_py([x_pt], [out_pt], [x_np]) out_pt = x_pt[[1, 2], None] - out_fg = FunctionGraph([x_pt], [out_pt]) - compare_pytorch_and_py(out_fg, [x_np]) + compare_pytorch_and_py([x_pt], [out_pt], [x_np]) a_pt = ps.int64("a") a_np = 2 out_pt = x_pt[[1, a_pt], a_pt] - out_fg = FunctionGraph([x_pt, a_pt], [out_pt]) - compare_pytorch_and_py(out_fg, [x_np, a_np]) + compare_pytorch_and_py([x_pt, a_pt], [out_pt], [x_np, a_np]) # boolean indices out_pt = x_pt[np.random.binomial(1, 0.5, size=(3, 4, 5)).astype(bool)] - out_fg = FunctionGraph([x_pt], [out_pt]) - compare_pytorch_and_py(out_fg, [x_np]) + compare_pytorch_and_py([x_pt], [out_pt], [x_np]) a_pt = pt.tensor3("a", dtype="bool") a_np = np.random.binomial(1, 0.5, size=(3, 4, 5)).astype(bool) out_pt = x_pt[a_pt] - out_fg = FunctionGraph([x_pt, a_pt], [out_pt]) - compare_pytorch_and_py(out_fg, [x_np, a_np]) + compare_pytorch_and_py([x_pt, a_pt], [out_pt], [x_np, a_np]) with pytest.raises( NotImplementedError, match="Negative step sizes are not supported in Pytorch" ): out_pt = x_pt[[1, 2], ::-1] - out_fg = FunctionGraph([x_pt], [out_pt]) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor) - compare_pytorch_and_py(out_fg, [x_np]) + compare_pytorch_and_py([x_pt], [out_pt], [x_np]) @pytest.mark.parametrize("subtensor_op", [set_subtensor, inc_subtensor]) @@ -116,20 +101,17 @@ def test_pytorch_IncSubtensor(subtensor_op): st_pt = pt.as_tensor_variable(np.array(-10.0, dtype=config.floatX)) out_pt = subtensor_op(x_pt[1, 2, 3], st_pt) assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor) - out_fg = FunctionGraph([x_pt], [out_pt]) - compare_pytorch_and_py(out_fg, [x_test]) + compare_pytorch_and_py([x_pt], [out_pt], [x_test]) # Test different type update st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype("float32")) out_pt = subtensor_op(x_pt[:2, 0, 0], st_pt) assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor) - out_fg = FunctionGraph([x_pt], [out_pt]) - compare_pytorch_and_py(out_fg, [x_test]) + compare_pytorch_and_py([x_pt], [out_pt], [x_test]) out_pt = subtensor_op(x_pt[0, 1:3, 0], st_pt) assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor) - out_fg = FunctionGraph([x_pt], [out_pt]) - compare_pytorch_and_py(out_fg, [x_test]) + compare_pytorch_and_py([x_pt], [out_pt], [x_test]) def inc_subtensor_ignore_duplicates(x, y): @@ -150,14 +132,12 @@ def test_pytorch_AvdancedIncSubtensor(advsubtensor_op): ) out_pt = advsubtensor_op(x_pt[np.r_[0, 2]], st_pt) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) - out_fg = FunctionGraph([x_pt], [out_pt]) - compare_pytorch_and_py(out_fg, [x_test]) + compare_pytorch_and_py([x_pt], [out_pt], [x_test]) # Repeated indices out_pt = advsubtensor_op(x_pt[np.r_[0, 0]], st_pt) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) - out_fg = FunctionGraph([x_pt], [out_pt]) - compare_pytorch_and_py(out_fg, [x_test]) + compare_pytorch_and_py([x_pt], [out_pt], [x_test]) # Mixing advanced and basic indexing if advsubtensor_op is inc_subtensor: @@ -168,19 +148,16 @@ def test_pytorch_AvdancedIncSubtensor(advsubtensor_op): st_pt = pt.as_tensor_variable(x_test[[0, 2], 0, :3]) out_pt = advsubtensor_op(x_pt[[0, 0], 0, :3], st_pt) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) - out_fg = FunctionGraph([x_pt], [out_pt]) with expectation: - compare_pytorch_and_py(out_fg, [x_test]) + compare_pytorch_and_py([x_pt], [out_pt], [x_test]) # Test different dtype update st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype("float32")) out_pt = advsubtensor_op(x_pt[[0, 2], 0, 0], st_pt) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) - out_fg = FunctionGraph([x_pt], [out_pt]) - compare_pytorch_and_py(out_fg, [x_test]) + compare_pytorch_and_py([x_pt], [out_pt], [x_test]) # Boolean indices out_pt = advsubtensor_op(x_pt[x_pt > 5], 1.0) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) - out_fg = FunctionGraph([x_pt], [out_pt]) - compare_pytorch_and_py(out_fg, [x_test]) + compare_pytorch_and_py([x_pt], [out_pt], [x_test]) diff --git a/tests/tensor/test_extra_ops.py b/tests/tensor/test_extra_ops.py index 6a93f3c7fd..c387152757 100644 --- a/tests/tensor/test_extra_ops.py +++ b/tests/tensor/test_extra_ops.py @@ -63,11 +63,6 @@ from tests import unittest_tools as utt -def set_test_value(x, v): - x.tag.test_value = v - return x - - def test_cpu_contiguous(): a = fmatrix("a") i = iscalar("i")