Fix bug in infer_shape of Blockwise(Subtensor) #1353

Merged: 3 commits, Apr 9, 2025
11 changes: 11 additions & 0 deletions pytensor/tensor/basic.py
@@ -710,6 +710,17 @@ def c_code_cache_version(self):
scalar_from_tensor = ScalarFromTensor()


@_vectorize_node.register(ScalarFromTensor)
def vectorize_scalar_from_tensor(op, node, batch_x):
if batch_x.ndim == 0:
return scalar_from_tensor(batch_x).owner
if batch_x.owner is not None:
return batch_x.owner

# Needed until we fix https://github.com/pymc-devs/pytensor/issues/902
return batch_x.copy().owner


# to be removed as we get the epydoc routine-documenting thing going
# -JB 20080924
def _conversion(real_value: Op, name: str) -> Op:
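A minimal sketch of what the new ScalarFromTensor dispatch above does, assuming pytensor's public vectorize_graph helper (the snippet itself is not part of the diff): when scalar_from_tensor is vectorized with a batched (non-0d) input, the scalar conversion is dropped and the batched tensor is passed through (as a copy until pymc-devs/pytensor#902 is fixed).

import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph
from pytensor.tensor.basic import scalar_from_tensor

x = pt.scalar("x")                   # 0d tensor input of the core graph
s = scalar_from_tensor(x)            # core graph: 0d tensor -> scalar

xb = pt.vector("xb")                 # the same input with one batch dimension
sb = vectorize_graph(s, replace={x: xb})
# sb is simply (a copy of) xb; no ScalarFromTensor node is left in the batched graph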
7 changes: 4 additions & 3 deletions pytensor/tensor/blockwise.py
@@ -7,7 +7,7 @@
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import DisconnectedType
from pytensor.graph import FunctionGraph
from pytensor.graph.basic import Apply, Constant, ancestors
from pytensor.graph.basic import Apply, Constant, explicit_graph_inputs
from pytensor.graph.null_type import NullType
from pytensor.graph.op import Op
from pytensor.graph.replace import (
@@ -190,7 +190,7 @@ def infer_shape(
core_op_infer_shape = getattr(self.core_op, "infer_shape", None)
if core_op_infer_shape is not None:
dummy_core_node = self._create_dummy_core_node(node.inputs)
dummy_core_inputs = dummy_core_node.inputs
dummy_core_inputs = tuple(explicit_graph_inputs(dummy_core_node.inputs))
dummy_fgraph = FunctionGraph(outputs=dummy_core_node.outputs, clone=False)
core_input_shapes = [
input_shape[batch_ndims:] for input_shape in input_shapes
@@ -214,7 +214,8 @@ def infer_shape(
# of the core_node as the value is not constant across batch dims of the Blockwise
core_out_dim = core_output_shapes[o][i]
if not (
set(dummy_core_inputs) & set(ancestors([core_out_dim]))
set(dummy_core_inputs)
& set(explicit_graph_inputs([core_out_dim]))
):
core_out_shape.append(core_out_dim)
continue
34 changes: 31 additions & 3 deletions pytensor/tensor/rewriting/blockwise.py
@@ -14,7 +14,12 @@
register_stabilize,
)
from pytensor.tensor.shape import Reshape
from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor, Subtensor
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedSubtensor,
Subtensor,
indices_from_subtensor,
)


@node_rewriter([Blockwise])
@@ -216,9 +221,9 @@ def local_blockwise_reshape(fgraph, node):

Reshape is tricky to vectorize eagerly, because a graph like
`x.reshape([x.shape[0] * x.shape[1], -1])` has many operations
that must be vectorized before we arrize at the reshape operation.
that must be vectorized before we arrive at the reshape operation.

For the square Reshape case, we must wait for all the intemediate
For the square Reshape case, we must wait for all the intermediate
operations to be lifted as Allocs
"""
if not isinstance(node.op.core_op, Reshape):
@@ -234,6 +239,29 @@ def local_blockwise_reshape(fgraph, node):
return [new_out]


@register_stabilize
@register_specialize
@node_rewriter([Blockwise])
def local_blockwise_of_subtensor(fgraph, node):
"""Rewrite Blockwise of Subtensor, where the only batch input is the indexed tensor.

Blockwise(Subtensor{a: b})(x, a, b) -> x[:, a:b] when x has one batch dimension, and a/b none
"""
if not isinstance(node.op.core_op, Subtensor):
return

x, *idxs = node.inputs
if not all(all(idx.type.broadcastable) for idx in idxs):
return

core_idxs = indices_from_subtensor(
[idx.squeeze() for idx in idxs], node.op.core_op.idx_list
)
# Add empty slices for the batch dims
none_slices = (slice(None),) * node.op.batch_ndim(node)
return [x[(*none_slices, *core_idxs)]]


@node_rewriter(tracks=[Blockwise], inplace=True)
def blockwise_inplace(fgraph, node):
blockwise_op = node.op
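A hedged sanity check of the equivalence that the new local_blockwise_of_subtensor rewrite above relies on, written in plain NumPy (not part of the PR): slicing every batch entry with the same non-batched bounds gives the same result as indexing the batched tensor with an extra leading full slice.

import numpy as np

x = np.arange(12).reshape(3, 4)    # one batch dimension of size 3
a, b = 1, 3                        # core slice bounds, shared across the batch

looped = np.stack([x[i, a:b] for i in range(x.shape[0])])  # what Blockwise would compute
direct = x[:, a:b]                                         # what the rewrite produces
assert np.array_equal(looped, direct)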
26 changes: 24 additions & 2 deletions tests/tensor/signal/test_conv.py
@@ -4,9 +4,11 @@
import pytest
from scipy.signal import convolve as scipy_convolve

from pytensor import config, function
from pytensor import config, function, grad
from pytensor.graph import ancestors, rewrite_graph
from pytensor.tensor import matrix, vector
from pytensor.tensor.signal.conv import convolve1d
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.signal.conv import Conv1d, convolve1d
from tests import unittest_tools as utt


@@ -60,3 +62,23 @@ def test_convolve1d_batch_same():

res = out.eval({x: x_test, y: y_test})
assert res.shape == (2, 8)


@pytest.mark.parametrize("mode", ("full", "valid", "same"))
def test_convolve1d_batch_graph(mode):
"""Test that we don't have slow Blockwise Subtensors in graph of a batched convolve1d"""
x = matrix("x")
y = matrix("y")
out = convolve1d(x, y, mode=mode)
grads = grad(out.sum(), wrt=[x, y])
final_grads = rewrite_graph(
grads, include=("ShapeOpt", "canonicalize", "stabilize", "specialize")
)

blockwise_nodes = [
var.owner
for var in ancestors(final_grads)
if var.owner is not None and isinstance(var.owner.op, Blockwise)
]
# Check any Blockwise are just Conv1d
assert all(isinstance(node.op.core_op, Conv1d) for node in blockwise_nodes)
15 changes: 11 additions & 4 deletions tests/tensor/test_blockwise.py
@@ -264,9 +264,13 @@ class TestOpWithInferShape(Op):
def make_node(self, a, b):
assert a.type.ndim == 1
assert b.type.ndim == 1
# Simulate make_node that introduces operations on inputs
a_identity = a.copy()
Review comment (Member):
Why does copying the input simulate an operation on the inputs? The Apply takes the copies as inputs, so any intermediate operations (the f in a_identity = f(a)) would be lost, right?

Reply (Member, PR author):
They would be lost after rewrites, but Blockwise.infer_shape acts on the returned graph immediately.

What Blockwise is trying to figure out is whether the core shape of the Op depends on the values of the inputs, or can be deduced from their shapes alone. For that it calls infer_shape on the core op with dummy core variables and then checks whether those variables appear in the returned shape graph. If they do, it can't really use that graph, because the shape may vary across iterations (say, a Blockwise'd slice Subtensor with batched start points).

If they are not used in the shape graph, then only the core shapes are needed, and the graph is fine to use. This would be the case for a Blockwise(Dirichlet), where only the core shape (the length of alpha), but not its values, is needed. (We don't Blockwise RVs, but you get the idea.)

Anyway, the logic to figure out whether the core values are needed was to create a dummy node and then check whether the node's inputs appeared in the shape graph. This failed when the dummy node didn't actually use the dummy variables directly (because make_node added extra nodes, like DimShuffle or ScalarFromTensor in the case that actually failed). The identity (copy) here is an easy way to exercise this without changing anything else in the test.
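A small illustration of the failure mode described above, assuming pytensor's public ancestors and explicit_graph_inputs helpers (the snippet is not part of the PR):

import pytensor.tensor as pt
from pytensor.graph.basic import ancestors, explicit_graph_inputs

a = pt.vector("a")           # dummy core input, as created by Blockwise.infer_shape
a_identity = a.copy()        # extra node introduced by the core op's make_node
shape_expr = a.sum()         # core infer_shape returns a graph built on `a` itself

# Old check: intersect the node's own inputs with ancestors of the shape graph.
# a_identity is not an ancestor of a.sum(), so the value dependence is missed.
print(bool({a_identity} & set(ancestors([shape_expr]))))   # False

# New check: walk both sides back to their true graph inputs first.
node_inputs = set(explicit_graph_inputs([a_identity]))
shape_inputs = set(explicit_graph_inputs([shape_expr]))
print(bool(node_inputs & shape_inputs))                    # True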

b_identity = b.copy()

c = tensor(shape=(None,))
d = tensor(shape=(None,))
return Apply(self, [a, b], [c, d])
return Apply(self, [a_identity, b_identity], [c, d])

def perform(self, node, inputs, outputs):
a, b = inputs
@@ -277,9 +281,12 @@ def perform(self, node, inputs, outputs):
def infer_shape(self, fgraph, node, input_shapes):
# First output shape depends only on input_shapes
# Second output shape depends on input values
x, y = node.inputs
[(x_shape,), (y_shape,)] = input_shapes
return (x_shape + y_shape,), (x.sum() + y.sum(),)
a_identity, b_identity = node.inputs
# Simulate shape depending on original inputs, not the ones that go directly into the node
a = a_identity.owner.inputs[0]
b = b_identity.owner.inputs[0]
[(a_shape,), (b_shape,)] = input_shapes
return (a_shape + b_shape,), (a.sum() + b.sum(),)

blockwise_op = Blockwise(
core_op=TestOpWithInferShape(), signature="(a),(b)->(c),(d)"