 import pytensor.tensor as pt
 from pytensor.compile import get_mode
 from pytensor.configdefaults import config
-from pytensor.graph.fg import FunctionGraph
-from pytensor.graph.op import get_test_value
 from pytensor.tensor import elemwise as pt_elemwise
 from pytensor.tensor.math import all as pt_all
 from pytensor.tensor.math import prod
@@ -26,87 +24,81 @@ def test_jax_Dimshuffle():
     a_pt = matrix("a")

     x = a_pt.T
-    x_fg = FunctionGraph([a_pt], [x])
-    compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)])
+    compare_jax_and_py(
+        [a_pt], [x], [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]
+    )

     x = a_pt.dimshuffle([0, 1, "x"])
-    x_fg = FunctionGraph([a_pt], [x])
-    compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)])
+    compare_jax_and_py(
+        [a_pt], [x], [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]
+    )

     a_pt = tensor(dtype=config.floatX, shape=(None, 1))
     x = a_pt.dimshuffle((0,))
-    x_fg = FunctionGraph([a_pt], [x])
-    compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
+    compare_jax_and_py([a_pt], [x], [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])

     a_pt = tensor(dtype=config.floatX, shape=(None, 1))
     x = pt_elemwise.DimShuffle(input_ndim=2, new_order=(0,))(a_pt)
-    x_fg = FunctionGraph([a_pt], [x])
-    compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
+    compare_jax_and_py([a_pt], [x], [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])


 def test_jax_CAReduce():
     a_pt = vector("a")
     a_pt.tag.test_value = np.r_[1, 2, 3].astype(config.floatX)

     x = pt_sum(a_pt, axis=None)
-    x_fg = FunctionGraph([a_pt], [x])

-    compare_jax_and_py(x_fg, [np.r_[1, 2, 3].astype(config.floatX)])
+    compare_jax_and_py([a_pt], [x], [np.r_[1, 2, 3].astype(config.floatX)])

     a_pt = matrix("a")
     a_pt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)

     x = pt_sum(a_pt, axis=0)
-    x_fg = FunctionGraph([a_pt], [x])

-    compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
+    compare_jax_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])

     x = pt_sum(a_pt, axis=1)
-    x_fg = FunctionGraph([a_pt], [x])

-    compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
+    compare_jax_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])

     a_pt = matrix("a")
     a_pt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)

     x = prod(a_pt, axis=0)
-    x_fg = FunctionGraph([a_pt], [x])

-    compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
+    compare_jax_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])

     x = pt_all(a_pt)
-    x_fg = FunctionGraph([a_pt], [x])

-    compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
+    compare_jax_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])


 @pytest.mark.parametrize("axis", [None, 0, 1])
 def test_softmax(axis):
     x = matrix("x")
-    x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
+    x_test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
     out = softmax(x, axis=axis)
-    fgraph = FunctionGraph([x], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+    compare_jax_and_py([x], [out], [x_test_value])


 @pytest.mark.parametrize("axis", [None, 0, 1])
 def test_logsoftmax(axis):
     x = matrix("x")
-    x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
+    x_test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
     out = log_softmax(x, axis=axis)
-    fgraph = FunctionGraph([x], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+
+    compare_jax_and_py([x], [out], [x_test_value])


 @pytest.mark.parametrize("axis", [None, 0, 1])
 def test_softmax_grad(axis):
     dy = matrix("dy")
-    dy.tag.test_value = np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)
+    dy_test_value = np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)
     sm = matrix("sm")
-    sm.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
+    sm_test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
     out = SoftmaxGrad(axis=axis)(dy, sm)
-    fgraph = FunctionGraph([dy, sm], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+
+    compare_jax_and_py([dy, sm], [out], [dy_test_value, sm_test_value])


 @pytest.mark.parametrize("size", [(10, 10), (1000, 1000)])
@@ -134,6 +126,4 @@ def test_logsumexp_benchmark(size, axis, benchmark):
 def test_multiple_input_multiply():
     x, y, z = vectors("xyz")
     out = pt.mul(x, y, z)
-
-    fg = FunctionGraph(outputs=[out], clone=False)
-    compare_jax_and_py(fg, [[1.5], [2.5], [3.5]])
+    compare_jax_and_py([x, y, z], [out], test_inputs=[[1.5], [2.5], [3.5]])
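
For reference, the refactored calls above pass the graph inputs, graph outputs, and test-point values straight to compare_jax_and_py, which now builds the FunctionGraph internally instead of requiring each test to construct one. Below is a minimal sketch of such a helper, assuming only PyTensor's public function and get_mode APIs; the actual helper in the test suite may differ in its mode setup, tolerances, and extra keyword arguments.

import numpy as np

from pytensor import function
from pytensor.compile import get_mode


def compare_jax_and_py(graph_inputs, graph_outputs, test_inputs):
    # Sketch only: compile the same graph twice, once with the JAX backend
    # and once with an unoptimized Python backend as the reference.
    jax_fn = function(graph_inputs, graph_outputs, mode=get_mode("JAX"))
    py_fn = function(graph_inputs, graph_outputs, mode=get_mode("FAST_COMPILE"))

    # Evaluate both compiled functions at the same test point and check
    # that the two backends agree output by output.
    for jax_res, py_res in zip(jax_fn(*test_inputs), py_fn(*test_inputs)):
        np.testing.assert_allclose(jax_res, py_res, rtol=1e-4)

Dropping the explicit FunctionGraph construction (and the get_test_value indirection) from every test is what makes the two pytensor.graph imports at the top of the diff removable.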