Commit 0c398e3 (parent: a0a494a)

Rewrite away blockwise Subtensor in gradient of Blockwise(Conv1d)
File tree: 2 files changed (+55 -5 lines)
Diff for: pytensor/tensor/rewriting/blockwise.py (+31 -3)

@@ -14,7 +14,12 @@
     register_stabilize,
 )
 from pytensor.tensor.shape import Reshape
-from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor, Subtensor
+from pytensor.tensor.subtensor import (
+    AdvancedIncSubtensor,
+    AdvancedSubtensor,
+    Subtensor,
+    indices_from_subtensor,
+)
 
 
 @node_rewriter([Blockwise])
@@ -216,9 +221,9 @@ def local_blockwise_reshape(fgraph, node):
 
     Reshape is tricky to vectorize eagerly, because a graph like
     `x.reshape([x.shape[0] * x.shape[1], -1])` has many operations
-    that must be vectorized before we arrize at the reshape operation.
+    that must be vectorized before we arrive at the reshape operation.
 
-    For the square Reshape case, we must wait for all the intemediate
+    For the square Reshape case, we must wait for all the intermediate
     operations to be lifted as Allocs
     """
     if not isinstance(node.op.core_op, Reshape):
@@ -234,6 +239,29 @@ def local_blockwise_reshape(fgraph, node):
     return [new_out]
 
 
+@register_stabilize
+@register_specialize
+@node_rewriter([Blockwise])
+def local_blockwise_of_subtensor(fgraph, node):
+    """Rewrite Blockwise of Subtensor, where the only batch input is the indexed tensor.
+
+    Blockwise(Subtensor{a: b})(x, a, b) -> x[:, a:b] when x has one batch dimension, and a/b none
+    """
+    if not isinstance(node.op.core_op, Subtensor):
+        return
+
+    x, *idxs = node.inputs
+    if not all(all(idx.type.broadcastable) for idx in idxs):
+        return
+
+    core_idxs = indices_from_subtensor(
+        [idx.squeeze() for idx in idxs], node.op.core_op.idx_list
+    )
+    # Add empty slices for the batch dims
+    none_slices = (slice(None),) * node.op.batch_ndim(node)
+    return [x[(*none_slices, *core_idxs)]]
+
+
 @node_rewriter(tracks=[Blockwise], inplace=True)
 def blockwise_inplace(fgraph, node):
     blockwise_op = node.op
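
A hedged sketch of the helper this rewrite relies on, assuming the behavior implied by its use above: indices_from_subtensor rebuilds the concrete index entries (slices and scalars) from a Subtensor op's idx_list template plus the node's symbolic inputs, after which the rewrite only needs to prepend one empty slice per batch dimension. Variable names here are illustrative, not from the commit.

    # Sketch (assumes current pytensor internals):
    import pytensor.tensor as pt
    from pytensor.tensor.subtensor import indices_from_subtensor

    a, b = pt.iscalar("a"), pt.iscalar("b")
    x = pt.vector("x")
    out = x[a:b]  # a core Subtensor{a:b} node with symbolic bounds
    node = out.owner

    # idx_list holds the slice template; the symbolic scalars are the
    # node's remaining inputs after the indexed tensor itself.
    core_idxs = indices_from_subtensor(node.inputs[1:], node.op.idx_list)

    # core_idxs is a tuple such as (slice(a, b),); indexing a batched
    # tensor with a leading full slice reproduces the Blockwise result
    # without the Blockwise wrapper, as in the rewrite above.
    batched_x = pt.matrix("batched_x")
    batched_out = batched_x[(slice(None), *core_idxs)]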

Diff for: tests/tensor/signal/test_conv.py (+24 -2)

@@ -4,9 +4,11 @@
 import pytest
 from scipy.signal import convolve as scipy_convolve
 
-from pytensor import config, function
+from pytensor import config, function, grad
+from pytensor.graph import ancestors, rewrite_graph
 from pytensor.tensor import matrix, vector
-from pytensor.tensor.signal.conv import convolve1d
+from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.signal.conv import Conv1d, convolve1d
 from tests import unittest_tools as utt
 
 
@@ -60,3 +62,23 @@ def test_convolve1d_batch_same():
 
     res = out.eval({x: x_test, y: y_test})
     assert res.shape == (2, 8)
+
+
+@pytest.mark.parametrize("mode", ("full", "valid", "same"))
+def test_convolve1d_batch_graph(mode):
+    """Test that we don't have slow Blockwise Subtensors in graph of a batched convolve1d"""
+    x = matrix("x")
+    y = matrix("y")
+    out = convolve1d(x, y, mode=mode)
+    grads = grad(out.sum(), wrt=[x, y])
+    final_grads = rewrite_graph(
+        grads, include=("ShapeOpt", "canonicalize", "stabilize", "specialize")
+    )
+
+    blockwise_nodes = [
+        var.owner
+        for var in ancestors(final_grads)
+        if var.owner is not None and isinstance(var.owner.op, Blockwise)
+    ]
+    # Check any Blockwise are just Conv1d
+    assert all(isinstance(node.op.core_op, Conv1d) for node in blockwise_nodes)
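
As a usage note (not in the commit), the same effect can be inspected interactively by debug-printing the rewritten gradient graph; a sketch assuming pytensor.dprint, PyTensor's debug-print helper:

    import pytensor
    from pytensor import grad
    from pytensor.graph import rewrite_graph
    from pytensor.tensor import matrix
    from pytensor.tensor.signal.conv import convolve1d

    x, y = matrix("x"), matrix("y")
    grads = grad(convolve1d(x, y, mode="full").sum(), wrt=[x, y])
    final_grads = rewrite_graph(
        grads, include=("ShapeOpt", "canonicalize", "stabilize", "specialize")
    )
    # After this commit, any Blockwise left in the printout should wrap
    # Conv1d only, with no Blockwise(Subtensor) entries.
    pytensor.dprint(final_grads)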
