# test_basic.py
from collections.abc import Callable, Iterable
from functools import partial
import numpy as np
import pytest
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function
from pytensor.compile.mode import JAX, Mode
from pytensor.compile.sharedvalue import shared
from pytensor.configdefaults import config
from pytensor.graph import RewriteDatabaseQuery
from pytensor.graph.basic import Apply, Variable
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
from pytensor.ifelse import ifelse
from pytensor.link.jax import JAXLinker
from pytensor.raise_op import assert_op
from pytensor.tensor.type import dscalar, matrices, scalar, vector


@pytest.fixture(scope="module", autouse=True)
def set_pytensor_flags():
    with config.change_flags(cxx="", compute_test_value="ignore"):
        yield


jax = pytest.importorskip("jax")

optimizer = RewriteDatabaseQuery(include=["jax"], exclude=JAX._optimizer.exclude)
jax_mode = Mode(linker=JAXLinker(), optimizer=optimizer)
py_mode = Mode(linker="py", optimizer=None)
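# `jax_mode` compiles graphs through the JAX linker with the JAX rewrite set
# enabled, while `py_mode` runs the unoptimized pure-Python implementation;
# the helper below compares the two for agreement.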


def compare_jax_and_py(
    graph_inputs: Iterable[Variable],
    graph_outputs: Variable | Iterable[Variable],
    test_inputs: Iterable,
    *,
    assert_fn: Callable | None = None,
    must_be_device_array: bool = True,
    jax_mode=jax_mode,
    py_mode=py_mode,
):
    """Compare the JAX-compiled and pure-Python outputs of a graph for equality.

    The graph is compiled twice, once with the JAX linker and once in Python
    mode; both compiled functions are evaluated on ``test_inputs`` and the
    results are checked against each other with ``assert_fn``.

    Parameters
    ----------
    graph_inputs:
        Symbolic inputs to the graph. These must be root variables.
    graph_outputs:
        Symbolic outputs of the graph.
    test_inputs:
        Numerical inputs with which to evaluate the graph.
    assert_fn:
        Assert function used to check for equality between the Python and JAX
        results. Defaults to ``np.testing.assert_allclose`` with ``rtol=1e-4``.
    must_be_device_array:
        If ``True``, check that each JAX result is an instance of
        ``jax.Array``, which indicates that it was actually computed by JAX.

    Returns
    -------
    The JAX-compiled PyTensor function and its result on ``test_inputs``.
    """
    if assert_fn is None:
        assert_fn = partial(np.testing.assert_allclose, rtol=1e-4)

    # Materialize the iterables so they can be traversed more than once below.
    graph_inputs = tuple(graph_inputs)
    test_inputs = tuple(test_inputs)

    if any(inp.owner is not None for inp in graph_inputs):
        raise ValueError("Inputs must be root variables")

    pytensor_jax_fn = function(graph_inputs, graph_outputs, mode=jax_mode)
    jax_res = pytensor_jax_fn(*test_inputs)

    if must_be_device_array:
        if isinstance(jax_res, list):
            assert all(isinstance(res, jax.Array) for res in jax_res)
        else:
            assert isinstance(jax_res, jax.Array)

    pytensor_py_fn = function(graph_inputs, graph_outputs, mode=py_mode)
    py_res = pytensor_py_fn(*test_inputs)

    if isinstance(graph_outputs, list | tuple):
        for j, p in zip(jax_res, py_res, strict=True):
            assert_fn(j, p)
    else:
        assert_fn(jax_res, py_res)

    return pytensor_jax_fn, jax_res
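

# Minimal usage sketch: the inputs must be root variables and the numeric test
# values must match them positionally, e.g.
#
#     x = vector("x")
#     fn, res = compare_jax_and_py([x], [x * 2], [np.arange(3, dtype=config.floatX)])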


def test_jax_FunctionGraph_once():
    """Make sure that an output is only computed once when it's referenced multiple times."""
    from pytensor.link.jax.dispatch import jax_funcify

    x = vector("x")
    y = vector("y")

    class TestOp(Op):
        def __init__(self):
            self.called = 0

        def make_node(self, *args):
            return Apply(self, list(args), [x.type() for x in args])

        def perform(self, node, inputs, outputs):
            for i, inp in enumerate(inputs):
                outputs[i][0] = inp[0]

    @jax_funcify.register(TestOp)
    def jax_funcify_TestOp(op, **kwargs):
        def func(*args, op=op):
            # Count how many times the funcified op actually runs.
            op.called += 1
            return list(args)

        return func

    op1 = TestOp()
    op2 = TestOp()

    q, r = op1(x, y)
    # op1's outputs are referenced in both arguments, but op1 must run only
    # once per evaluation.
    outs = op2(q + r, q + r)

    out_fg = FunctionGraph([x, y], outs, clone=False)
    assert len(out_fg.outputs) == 2

    out_jx = jax_funcify(out_fg)

    x_val = np.r_[1, 2].astype(config.floatX)
    y_val = np.r_[2, 3].astype(config.floatX)

    res = out_jx(x_val, y_val)
    assert len(res) == 2
    assert op1.called == 1
    assert op2.called == 1

    res = out_jx(x_val, y_val)
    assert len(res) == 2
    assert op1.called == 2
    assert op2.called == 2


def test_shared():
    a = shared(np.array([1, 2, 3], dtype=config.floatX))

    pytensor_jax_fn = function([], a, mode="JAX")
    jax_res = pytensor_jax_fn()

    assert isinstance(jax_res, jax.Array)
    np.testing.assert_allclose(jax_res, a.get_value())

    pytensor_jax_fn = function([], a * 2, mode="JAX")
    jax_res = pytensor_jax_fn()

    assert isinstance(jax_res, jax.Array)
    np.testing.assert_allclose(jax_res, a.get_value() * 2)

    # Change the shared value and make sure the JAX-compiled function
    # picks up the update.
    new_a_value = np.array([3, 4, 5], dtype=config.floatX)
    a.set_value(new_a_value)

    jax_res = pytensor_jax_fn()
    assert isinstance(jax_res, jax.Array)
    np.testing.assert_allclose(jax_res, new_a_value * 2)


def test_shared_updates():
    a = shared(0)

    pytensor_jax_fn = function([], a, updates={a: a + 1}, mode="JAX")
    res1, res2 = pytensor_jax_fn(), pytensor_jax_fn()
    assert res1 == 0
    assert res2 == 1
    assert a.get_value() == 2

    a.set_value(5)
    res1, res2 = pytensor_jax_fn(), pytensor_jax_fn()
    assert res1 == 5
    assert res2 == 6
    assert a.get_value() == 7


def test_jax_ifelse():
    true_vals = np.r_[1, 2, 3]
    false_vals = np.r_[-1, -2, -3]

    x = ifelse(np.array(True), true_vals, false_vals)
    compare_jax_and_py([], [x], [])

    a = dscalar("a")
    a_test = np.array(0.2, dtype=config.floatX)
    x = ifelse(a < 0.5, true_vals, false_vals)
    compare_jax_and_py([a], [x], [a_test])


def test_jax_checkandraise():
    p = scalar()
    p.tag.test_value = 0

    res = assert_op(p, p < 1.0)

    with pytest.warns(UserWarning):
        function((p,), res, mode=jax_mode)
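

# OpFromGraph wraps a subgraph in a reusable Op; the JAX backend has to
# translate the inner graph as well, including nested applications like the
# ones below.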
def test_OpFromGraph():
    x, y, z = matrices("xyz")
    ofg_1 = OpFromGraph([x, y], [x + y], inline=False)
    ofg_2 = OpFromGraph([x, y], [x * y, x - y], inline=False)

    o1, o2 = ofg_2(y, z)
    out = ofg_1(x, o1) + o2

    xv = np.ones((2, 2), dtype=config.floatX)
    yv = np.ones((2, 2), dtype=config.floatX) * 3
    zv = np.ones((2, 2), dtype=config.floatX) * 5

    compare_jax_and_py([x, y, z], [out], [xv, yv, zv])