Skip to content

Removes Function Graph, set_test_value/ get_test_value #1107

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions pytensor/link/numba/dispatch/slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,17 @@

B_is_1d = B.ndim == 1

if not overwrite_b:
B_copy = _copy_to_fortran_order(B)
else:
if overwrite_b:
B_copy = B
else:
if B_is_1d:
# _copy_to_fortran_order does nothing with vectors
B_copy = np.copy(B)

Check warning on line 134 in pytensor/link/numba/dispatch/slinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/slinalg.py#L134

Added line #L134 was not covered by tests
else:
B_copy = _copy_to_fortran_order(B)

Check warning on line 136 in pytensor/link/numba/dispatch/slinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/slinalg.py#L136

Added line #L136 was not covered by tests

if B_is_1d:
B_copy = np.expand_dims(B, -1)
B_copy = np.expand_dims(B_copy, -1)

Check warning on line 139 in pytensor/link/numba/dispatch/slinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/slinalg.py#L139

Added line #L139 was not covered by tests

NRHS = 1 if B_is_1d else int(B_copy.shape[-1])

Expand Down
51 changes: 24 additions & 27 deletions tests/link/jax/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function
from pytensor.compile.mode import JAX, Mode
from pytensor.compile.sharedvalue import SharedVariable, shared
from pytensor.compile.sharedvalue import shared
from pytensor.configdefaults import config
from pytensor.graph import RewriteDatabaseQuery
from pytensor.graph.basic import Apply
from pytensor.graph.basic import Apply, Variable
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op, get_test_value
from pytensor.graph.op import Op
from pytensor.ifelse import ifelse
from pytensor.link.jax import JAXLinker
from pytensor.raise_op import assert_op
Expand All @@ -34,25 +34,28 @@ def set_pytensor_flags():


def compare_jax_and_py(
fgraph: FunctionGraph,
graph_inputs: Iterable[Variable],
graph_outputs: Variable | Iterable[Variable],
test_inputs: Iterable,
*,
assert_fn: Callable | None = None,
must_be_device_array: bool = True,
jax_mode=jax_mode,
py_mode=py_mode,
):
"""Function to compare python graph output and jax compiled output for testing equality
"""Function to compare python function output and jax compiled output for testing equality

In the tests below computational graphs are defined in PyTensor. These graphs are then passed to
this function which then compiles the graphs in both jax and python, runs the calculation
in both and checks if the results are the same
The inputs and outputs are then passed to this function which then compiles the given function in both
jax and python, runs the calculation in both and checks if the results are the same

Parameters
----------
fgraph: FunctionGraph
PyTensor function Graph object
graph_inputs:
Symbolic inputs to the graph
outputs:
Symbolic outputs of the graph
test_inputs: iter
Numerical inputs for testing the function graph
Numerical inputs for testing the function.
assert_fn: func, opt
Assert function used to check for equality between python and jax. If not
provided uses np.testing.assert_allclose
Expand All @@ -68,8 +71,10 @@ def compare_jax_and_py(
if assert_fn is None:
assert_fn = partial(np.testing.assert_allclose, rtol=1e-4)

fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)]
pytensor_jax_fn = function(fn_inputs, fgraph.outputs, mode=jax_mode)
if any(inp.owner is not None for inp in graph_inputs):
raise ValueError("Inputs must be root variables")

pytensor_jax_fn = function(graph_inputs, graph_outputs, mode=jax_mode)
jax_res = pytensor_jax_fn(*test_inputs)

if must_be_device_array:
Expand All @@ -78,10 +83,10 @@ def compare_jax_and_py(
else:
assert isinstance(jax_res, jax.Array)

pytensor_py_fn = function(fn_inputs, fgraph.outputs, mode=py_mode)
pytensor_py_fn = function(graph_inputs, graph_outputs, mode=py_mode)
py_res = pytensor_py_fn(*test_inputs)

if len(fgraph.outputs) > 1:
if isinstance(graph_outputs, list | tuple):
for j, p in zip(jax_res, py_res, strict=True):
assert_fn(j, p)
else:
Expand Down Expand Up @@ -187,16 +192,14 @@ def test_jax_ifelse():
false_vals = np.r_[-1, -2, -3]

x = ifelse(np.array(True), true_vals, false_vals)
x_fg = FunctionGraph([], [x])

compare_jax_and_py(x_fg, [])
compare_jax_and_py([], [x], [])

a = dscalar("a")
a.tag.test_value = np.array(0.2, dtype=config.floatX)
a_test = np.array(0.2, dtype=config.floatX)
x = ifelse(a < 0.5, true_vals, false_vals)
x_fg = FunctionGraph([a], [x]) # I.e. False

compare_jax_and_py(x_fg, [get_test_value(i) for i in x_fg.inputs])
compare_jax_and_py([a], [x], [a_test])


def test_jax_checkandraise():
Expand All @@ -209,22 +212,16 @@ def test_jax_checkandraise():
function((p,), res, mode=jax_mode)


def set_test_value(x, v):
x.tag.test_value = v
return x


def test_OpFromGraph():
x, y, z = matrices("xyz")
ofg_1 = OpFromGraph([x, y], [x + y], inline=False)
ofg_2 = OpFromGraph([x, y], [x * y, x - y], inline=False)

o1, o2 = ofg_2(y, z)
out = ofg_1(x, o1) + o2
out_fg = FunctionGraph([x, y, z], [out])

xv = np.ones((2, 2), dtype=config.floatX)
yv = np.ones((2, 2), dtype=config.floatX) * 3
zv = np.ones((2, 2), dtype=config.floatX) * 5

compare_jax_and_py(out_fg, [xv, yv, zv])
compare_jax_and_py([x, y, z], [out], [xv, yv, zv])
13 changes: 5 additions & 8 deletions tests/link/jax/test_blas.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
from pytensor.compile.function import function
from pytensor.compile.mode import Mode
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.link.jax import JAXLinker
from pytensor.tensor import blas as pt_blas
Expand All @@ -16,21 +14,20 @@
def test_jax_BatchedDot():
# tensor3 . tensor3
a = tensor3("a")
a.tag.test_value = (
a_test_value = (
np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3))
)
b = tensor3("b")
b.tag.test_value = (
b_test_value = (
np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2))
)
out = pt_blas.BatchedDot()(a, b)
fgraph = FunctionGraph([a, b], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
compare_jax_and_py([a, b], [out], [a_test_value, b_test_value])

# A dimension mismatch should raise a TypeError for compatibility
inputs = [get_test_value(a)[:-1], get_test_value(b)]
inputs = [a_test_value[:-1], b_test_value]
opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_mode = Mode(JAXLinker(), opts)
pytensor_jax_fn = function(fgraph.inputs, fgraph.outputs, mode=jax_mode)
pytensor_jax_fn = function([a, b], [out], mode=jax_mode)
with pytest.raises(TypeError):
pytensor_jax_fn(*inputs)
4 changes: 1 addition & 3 deletions tests/link/jax/test_blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import pytest

from pytensor import config
from pytensor.graph import FunctionGraph
from pytensor.tensor import tensor
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import Dot, matmul
Expand Down Expand Up @@ -32,8 +31,7 @@ def test_matmul(matmul_op):

out = matmul_op(a, b)
assert isinstance(out.owner.op, Blockwise)
fg = FunctionGraph([a, b], [out])
fn, _ = compare_jax_and_py(fg, test_values)
fn, _ = compare_jax_and_py([a, b], [out], test_values)

# Check we are not adding any unnecessary stuff
jaxpr = str(jax.make_jaxpr(fn.vm.jit_fn)(*test_values))
Expand Down
7 changes: 2 additions & 5 deletions tests/link/jax/test_einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import pytest

import pytensor.tensor as pt
from pytensor.graph import FunctionGraph
from tests.link.jax.test_basic import compare_jax_and_py


Expand All @@ -22,8 +21,7 @@ def test_jax_einsum():
}
x_pt, y_pt, z_pt = (pt.tensor(name, shape=shape) for name, shape in shapes.items())
out = pt.einsum(subscripts, x_pt, y_pt, z_pt)
fg = FunctionGraph([x_pt, y_pt, z_pt], [out])
compare_jax_and_py(fg, [x, y, z])
compare_jax_and_py([x_pt, y_pt, z_pt], [out], [x, y, z])


def test_ellipsis_einsum():
Expand All @@ -34,5 +32,4 @@ def test_ellipsis_einsum():
x_pt = pt.tensor("x", shape=x.shape)
y_pt = pt.tensor("y", shape=y.shape)
out = pt.einsum(subscripts, x_pt, y_pt)
fg = FunctionGraph([x_pt, y_pt], [out])
compare_jax_and_py(fg, [x, y])
compare_jax_and_py([x_pt, y_pt], [out], [x, y])
56 changes: 23 additions & 33 deletions tests/link/jax/test_elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
import pytensor.tensor as pt
from pytensor.compile import get_mode
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.tensor import elemwise as pt_elemwise
from pytensor.tensor.math import all as pt_all
from pytensor.tensor.math import prod
Expand All @@ -26,87 +24,81 @@ def test_jax_Dimshuffle():
a_pt = matrix("a")

x = a_pt.T
x_fg = FunctionGraph([a_pt], [x])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)])
compare_jax_and_py(
[a_pt], [x], [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]
)

x = a_pt.dimshuffle([0, 1, "x"])
x_fg = FunctionGraph([a_pt], [x])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)])
compare_jax_and_py(
[a_pt], [x], [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]
)

a_pt = tensor(dtype=config.floatX, shape=(None, 1))
x = a_pt.dimshuffle((0,))
x_fg = FunctionGraph([a_pt], [x])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
compare_jax_and_py([a_pt], [x], [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])

a_pt = tensor(dtype=config.floatX, shape=(None, 1))
x = pt_elemwise.DimShuffle(input_ndim=2, new_order=(0,))(a_pt)
x_fg = FunctionGraph([a_pt], [x])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
compare_jax_and_py([a_pt], [x], [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])


def test_jax_CAReduce():
a_pt = vector("a")
a_pt.tag.test_value = np.r_[1, 2, 3].astype(config.floatX)

x = pt_sum(a_pt, axis=None)
x_fg = FunctionGraph([a_pt], [x])

compare_jax_and_py(x_fg, [np.r_[1, 2, 3].astype(config.floatX)])
compare_jax_and_py([a_pt], [x], [np.r_[1, 2, 3].astype(config.floatX)])

a_pt = matrix("a")
a_pt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)

x = pt_sum(a_pt, axis=0)
x_fg = FunctionGraph([a_pt], [x])

compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
compare_jax_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])

x = pt_sum(a_pt, axis=1)
x_fg = FunctionGraph([a_pt], [x])

compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
compare_jax_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])

a_pt = matrix("a")
a_pt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)

x = prod(a_pt, axis=0)
x_fg = FunctionGraph([a_pt], [x])

compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
compare_jax_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])

x = pt_all(a_pt)
x_fg = FunctionGraph([a_pt], [x])

compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
compare_jax_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])


@pytest.mark.parametrize("axis", [None, 0, 1])
def test_softmax(axis):
x = matrix("x")
x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
x_test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
out = softmax(x, axis=axis)
fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
compare_jax_and_py([x], [out], [x_test_value])


@pytest.mark.parametrize("axis", [None, 0, 1])
def test_logsoftmax(axis):
x = matrix("x")
x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
x_test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
out = log_softmax(x, axis=axis)
fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

compare_jax_and_py([x], [out], [x_test_value])


@pytest.mark.parametrize("axis", [None, 0, 1])
def test_softmax_grad(axis):
dy = matrix("dy")
dy.tag.test_value = np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)
dy_test_value = np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)
sm = matrix("sm")
sm.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
sm_test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
out = SoftmaxGrad(axis=axis)(dy, sm)
fgraph = FunctionGraph([dy, sm], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

compare_jax_and_py([dy, sm], [out], [dy_test_value, sm_test_value])


@pytest.mark.parametrize("size", [(10, 10), (1000, 1000)])
Expand Down Expand Up @@ -134,6 +126,4 @@ def test_logsumexp_benchmark(size, axis, benchmark):
def test_multiple_input_multiply():
x, y, z = vectors("xyz")
out = pt.mul(x, y, z)

fg = FunctionGraph(outputs=[out], clone=False)
compare_jax_and_py(fg, [[1.5], [2.5], [3.5]])
compare_jax_and_py([x, y, z], [out], test_inputs=[[1.5], [2.5], [3.5]])
Loading