Skip to content

Commit 3b425ec

Browse files
committed
Implement JAX dot with sparse constants
Non-constant sparse inputs can't be handled because JAX does not allow Scipy sparse matrices as inputs. We could implement a BCOO type explicitly but this would be JAX exclusive, and the user would need to use it from the get go, meaning such graphs would not be compatible with other backends.
1 parent 98de246 commit 3b425ec

File tree

3 files changed

+114
-0
lines changed

3 files changed

+114
-0
lines changed

Diff for: pytensor/link/jax/dispatch/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,6 @@
1212
import pytensor.link.jax.dispatch.random
1313
import pytensor.link.jax.dispatch.elemwise
1414
import pytensor.link.jax.dispatch.scan
15+
import pytensor.link.jax.dispatch.sparse
1516

1617
# isort: on

Diff for: pytensor/link/jax/dispatch/sparse.py

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import jax.experimental.sparse as jsp
2+
from scipy.sparse import spmatrix
3+
4+
from pytensor.graph.basic import Constant
5+
from pytensor.link.jax.dispatch import jax_funcify, jax_typify
6+
from pytensor.sparse.basic import Dot, StructuredDot
7+
from pytensor.sparse.type import SparseTensorType
8+
9+
10+
@jax_typify.register(spmatrix)
11+
def jax_typify_spmatrix(matrix, dtype=None, **kwargs):
12+
# Note: This changes the type of the constants from CSR/CSC to BCOO
13+
# We could add BCOO as a PyTensor type but this would only be useful for JAX graphs
14+
# and it would break the premise of one graph -> multiple backends.
15+
# The same situation happens with RandomGenerators...
16+
return jsp.BCOO.from_scipy_sparse(matrix)
17+
18+
19+
@jax_funcify.register(Dot)
20+
@jax_funcify.register(StructuredDot)
21+
def jax_funcify_sparse_dot(op, node, **kwargs):
22+
for input in node.inputs:
23+
if isinstance(input.type, SparseTensorType) and not isinstance(input, Constant):
24+
raise NotImplementedError(
25+
"JAX sparse dot only implemented for constant sparse inputs"
26+
)
27+
28+
if isinstance(node.outputs[0].type, SparseTensorType):
29+
raise NotImplementedError("JAX sparse dot only implemented for dense outputs")
30+
31+
@jsp.sparsify
32+
def sparse_dot(x, y):
33+
out = x @ y
34+
if isinstance(out, jsp.BCOO):
35+
out = out.todense()
36+
return out
37+
38+
return sparse_dot

Diff for: tests/link/jax/test_sparse.py

+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import numpy as np
2+
import pytest
3+
import scipy.sparse
4+
5+
import pytensor.sparse as ps
6+
import pytensor.tensor as pt
7+
from pytensor import function
8+
from pytensor.graph import FunctionGraph
9+
from tests.link.jax.test_basic import compare_jax_and_py
10+
11+
12+
@pytest.mark.parametrize(
13+
"op, x_type, y_type",
14+
[
15+
(ps.dot, pt.vector, ps.matrix),
16+
(ps.dot, pt.matrix, ps.matrix),
17+
(ps.dot, ps.matrix, pt.vector),
18+
(ps.dot, ps.matrix, pt.matrix),
19+
# structured_dot only allows matrix @ matrix
20+
(ps.structured_dot, pt.matrix, ps.matrix),
21+
(ps.structured_dot, ps.matrix, pt.matrix),
22+
],
23+
)
24+
def test_sparse_dot_constant_sparse(x_type, y_type, op):
25+
inputs = []
26+
test_values = []
27+
28+
if x_type is ps.matrix:
29+
x_sp = scipy.sparse.random(5, 40, density=0.25, format="csr", dtype="float32")
30+
x_pt = ps.as_sparse_variable(x_sp, name="x")
31+
else:
32+
x_pt = x_type("x", dtype="float32")
33+
if x_pt.ndim == 1:
34+
x_test = np.arange(40, dtype="float32")
35+
else:
36+
x_test = np.arange(5 * 40, dtype="float32").reshape(5, 40)
37+
inputs.append(x_pt)
38+
test_values.append(x_test)
39+
40+
if y_type is ps.matrix:
41+
y_sp = scipy.sparse.random(40, 3, density=0.25, format="csc", dtype="float32")
42+
y_pt = ps.as_sparse_variable(y_sp, name="y")
43+
else:
44+
y_pt = y_type("y", dtype="float32")
45+
if y_pt.ndim == 1:
46+
y_test = np.arange(40, dtype="float32")
47+
else:
48+
y_test = np.arange(40 * 3, dtype="float32").reshape(40, 3)
49+
inputs.append(y_pt)
50+
test_values.append(y_test)
51+
52+
dot_pt = op(x_pt, y_pt)
53+
fgraph = FunctionGraph(inputs, [dot_pt])
54+
compare_jax_and_py(fgraph, test_values)
55+
56+
57+
def test_sparse_dot_non_const_raises():
58+
x_pt = pt.vector("x")
59+
60+
y_sp = scipy.sparse.random(40, 3, density=0.25, format="csc", dtype="float32")
61+
y_pt = ps.as_sparse_variable(y_sp, name="y").type()
62+
63+
out = ps.dot(x_pt, y_pt)
64+
65+
msg = "JAX sparse dot only implemented for constant sparse inputs"
66+
67+
with pytest.raises(NotImplementedError, match=msg):
68+
function([x_pt, y_pt], out, mode="JAX")
69+
70+
y_pt_shared = ps.shared(y_sp, name="y")
71+
72+
out = ps.dot(x_pt, y_pt_shared)
73+
74+
with pytest.raises(NotImplementedError, match=msg):
75+
function([x_pt], out, mode="JAX")

0 commit comments

Comments
 (0)