From f673c088d4e2f35ce9d142ee3f4b63cb83b5f2a6 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Tue, 8 Apr 2025 19:57:47 +0200
Subject: [PATCH 1/3] Fix Blockwise infer shape from core Op

Sometimes `_create_dummy_core_node` can create a multi-node graph whose
root inputs are not `node.inputs`. In that case infer_shape may bypass
the intermediate nodes.

This was the case with Subtensor, which introduces `ScalarFromTensor`
nodes but ignores them in its shape graph (for a cleaner graph).
---
 pytensor/tensor/blockwise.py   |  7 ++++---
 tests/tensor/test_blockwise.py | 15 +++++++++++----
 2 files changed, 15 insertions(+), 7 deletions(-)

diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py
index be5e048c77..fe7fe155af 100644
--- a/pytensor/tensor/blockwise.py
+++ b/pytensor/tensor/blockwise.py
@@ -7,7 +7,7 @@
 from pytensor.compile.builders import OpFromGraph
 from pytensor.gradient import DisconnectedType
 from pytensor.graph import FunctionGraph
-from pytensor.graph.basic import Apply, Constant, ancestors
+from pytensor.graph.basic import Apply, Constant, explicit_graph_inputs
 from pytensor.graph.null_type import NullType
 from pytensor.graph.op import Op
 from pytensor.graph.replace import (
@@ -190,7 +190,7 @@ def infer_shape(
         core_op_infer_shape = getattr(self.core_op, "infer_shape", None)
         if core_op_infer_shape is not None:
             dummy_core_node = self._create_dummy_core_node(node.inputs)
-            dummy_core_inputs = dummy_core_node.inputs
+            dummy_core_inputs = tuple(explicit_graph_inputs(dummy_core_node.inputs))
             dummy_fgraph = FunctionGraph(outputs=dummy_core_node.outputs, clone=False)
             core_input_shapes = [
                 input_shape[batch_ndims:] for input_shape in input_shapes
@@ -214,7 +214,8 @@
                     # of the core_node as the value is not constant across batch dims of the Blockwise
                     core_out_dim = core_output_shapes[o][i]
                     if not (
-                        set(dummy_core_inputs) & set(ancestors([core_out_dim]))
+                        set(dummy_core_inputs)
+                        & set(explicit_graph_inputs([core_out_dim]))
                     ):
                         core_out_shape.append(core_out_dim)
                         continue
diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py
index 51862562ac..771ff11ba7 100644
--- a/tests/tensor/test_blockwise.py
+++ b/tests/tensor/test_blockwise.py
@@ -264,9 +264,13 @@ class TestOpWithInferShape(Op):
         def make_node(self, a, b):
             assert a.type.ndim == 1
             assert b.type.ndim == 1
+            # Simulate make_node that introduces operations on inputs
+            a_identity = a.copy()
+            b_identity = b.copy()
+
             c = tensor(shape=(None,))
             d = tensor(shape=(None,))
-            return Apply(self, [a, b], [c, d])
+            return Apply(self, [a_identity, b_identity], [c, d])

         def perform(self, node, inputs, outputs):
             a, b = inputs
@@ -277,9 +281,12 @@ def perform(self, node, inputs, outputs):
         def infer_shape(self, fgraph, node, input_shapes):
             # First output shape depends only on input_shapes
            # Second output shape depends on input values
-            x, y = node.inputs
-            [(x_shape,), (y_shape,)] = input_shapes
-            return (x_shape + y_shape,), (x.sum() + y.sum(),)
+            a_identity, b_identity = node.inputs
+            # Simulate shape depending on original inputs, not the ones that go directly into the node
+            a = a_identity.owner.inputs[0]
+            b = b_identity.owner.inputs[0]
+            [(a_shape,), (b_shape,)] = input_shapes
+            return (a_shape + b_shape,), (a.sum() + b.sum(),)

     blockwise_op = Blockwise(
         core_op=TestOpWithInferShape(), signature="(a),(b)->(c),(d)"
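Note (illustrative, not part of the patch): the sketch below, written
against the public pytensor API, shows why the dummy core node's direct
inputs are unreliable anchors. `Subtensor.make_node` wraps tensor
indices in `ScalarFromTensor`, so the node's inputs sit behind an
intermediate Apply, while `explicit_graph_inputs` walks back to the
true root variables. Variable names are made up for the example.

    import pytensor.tensor as pt
    from pytensor.graph.basic import explicit_graph_inputs

    x = pt.vector("x")
    stop = pt.lscalar("stop")
    out = x[:stop]  # Subtensor wraps `stop` in a ScalarFromTensor node

    # The Subtensor node's direct index input is not the root `stop`:
    print(type(out.owner.inputs[1].owner.op).__name__)  # ScalarFromTensor
    # explicit_graph_inputs recovers the root variables instead:
    print(list(explicit_graph_inputs([out])))  # expected: [x, stop]
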
From 42ca4038babfd4bfdddc493712b99b5293254a78 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Wed, 9 Apr 2025 12:33:36 +0200
Subject: [PATCH 2/3] Vectorize ScalarFromTensor
---
 pytensor/tensor/basic.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py
index 5d6c059c53..e0752f14ea 100644
--- a/pytensor/tensor/basic.py
+++ b/pytensor/tensor/basic.py
@@ -710,6 +710,17 @@ def c_code_cache_version(self):
 scalar_from_tensor = ScalarFromTensor()


+@_vectorize_node.register(ScalarFromTensor)
+def vectorize_scalar_from_tensor(op, node, batch_x):
+    if batch_x.ndim == 0:
+        return scalar_from_tensor(batch_x).owner
+    if batch_x.owner is not None:
+        return batch_x.owner
+
+    # Needed until we fix https://github.com/pymc-devs/pytensor/issues/902
+    return batch_x.copy().owner
+
+
 # to be removed as we get the epydoc routine-documenting thing going
 # -JB 20080924
 def _conversion(real_value: Op, name: str) -> Op:
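Note (illustrative, not part of the patch): a rough usage sketch of the
new dispatch. When the replacement input is still 0-d, the
`ScalarFromTensor` conversion is kept; when it is batched, the
conversion is dropped and the tensor itself passes through (via an
identity `copy()` node while issue #902 is open). Names are made up.

    import pytensor.tensor as pt
    from pytensor.graph.replace import vectorize_node
    from pytensor.tensor.basic import scalar_from_tensor

    x = pt.scalar("x", dtype="int64")
    node = scalar_from_tensor(x).owner  # a ScalarFromTensor Apply node

    # 0-d replacement input: the scalar conversion is preserved
    assert vectorize_node(node, pt.scalar("x2", dtype="int64")).op == node.op

    # 1-d (batched) replacement input: no scalar conversion, tensor output
    batched = vectorize_node(node, pt.vector("bx", dtype="int64"))
    assert batched.outputs[0].type.ndim == 1
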
From 0aeda2107b1cc41a65e1e3e0aea1100be05dabd6 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Tue, 8 Apr 2025 20:24:50 +0200
Subject: [PATCH 3/3] Rewrite away blockwise Subtensor in gradient of Blockwise(Conv1d)
---
 pytensor/tensor/rewriting/blockwise.py | 34 +++++++++++++++++++++++---
 tests/tensor/signal/test_conv.py       | 26 ++++++++++++++++++--
 2 files changed, 55 insertions(+), 5 deletions(-)

diff --git a/pytensor/tensor/rewriting/blockwise.py b/pytensor/tensor/rewriting/blockwise.py
index 49bd5510ae..4d2a3715c3 100644
--- a/pytensor/tensor/rewriting/blockwise.py
+++ b/pytensor/tensor/rewriting/blockwise.py
@@ -14,7 +14,12 @@
     register_stabilize,
 )
 from pytensor.tensor.shape import Reshape
-from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor, Subtensor
+from pytensor.tensor.subtensor import (
+    AdvancedIncSubtensor,
+    AdvancedSubtensor,
+    Subtensor,
+    indices_from_subtensor,
+)


 @node_rewriter([Blockwise])
@@ -216,9 +221,9 @@ def local_blockwise_reshape(fgraph, node):

     Reshape is tricky to vectorize eagerly, because a graph like
     `x.reshape([x.shape[0] * x.shape[1], -1])` has many operations
-    that must be vectorized before we arrize at the reshape operation.
+    that must be vectorized before we arrive at the reshape operation.

-    For the square Reshape case, we must wait for all the intemediate
+    For the square Reshape case, we must wait for all the intermediate
     operations to be lifted as Allocs
     """
     if not isinstance(node.op.core_op, Reshape):
@@ -234,6 +239,29 @@ def local_blockwise_reshape(fgraph, node):
     return [new_out]


+@register_stabilize
+@register_specialize
+@node_rewriter([Blockwise])
+def local_blockwise_of_subtensor(fgraph, node):
+    """Rewrite Blockwise of Subtensor, where the only batch input is the indexed tensor.
+
+    Blockwise(Subtensor{a: b})(x, a, b) -> x[:, a:b] when x has one batch dimension, and a/b have none.
+    """
+    if not isinstance(node.op.core_op, Subtensor):
+        return
+
+    x, *idxs = node.inputs
+    if not all(all(idx.type.broadcastable) for idx in idxs):
+        return
+
+    core_idxs = indices_from_subtensor(
+        [idx.squeeze() for idx in idxs], node.op.core_op.idx_list
+    )
+    # Add empty slices for the batch dims
+    none_slices = (slice(None),) * node.op.batch_ndim(node)
+    return [x[(*none_slices, *core_idxs)]]
+
+
 @node_rewriter(tracks=[Blockwise], inplace=True)
 def blockwise_inplace(fgraph, node):
     blockwise_op = node.op
diff --git a/tests/tensor/signal/test_conv.py b/tests/tensor/signal/test_conv.py
index fe353b18fb..d56d365193 100644
--- a/tests/tensor/signal/test_conv.py
+++ b/tests/tensor/signal/test_conv.py
@@ -4,9 +4,11 @@
 import pytest
 from scipy.signal import convolve as scipy_convolve

-from pytensor import config, function
+from pytensor import config, function, grad
+from pytensor.graph import ancestors, rewrite_graph
 from pytensor.tensor import matrix, vector
-from pytensor.tensor.signal.conv import convolve1d
+from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.signal.conv import Conv1d, convolve1d
 from tests import unittest_tools as utt


@@ -60,3 +62,23 @@ def test_convolve1d_batch_same():

     res = out.eval({x: x_test, y: y_test})
     assert res.shape == (2, 8)
+
+
+@pytest.mark.parametrize("mode", ("full", "valid", "same"))
+def test_convolve1d_batch_graph(mode):
+    """Test that we don't have slow Blockwise Subtensors in the graph of a batched convolve1d."""
+    x = matrix("x")
+    y = matrix("y")
+    out = convolve1d(x, y, mode=mode)
+    grads = grad(out.sum(), wrt=[x, y])
+    final_grads = rewrite_graph(
+        grads, include=("ShapeOpt", "canonicalize", "stabilize", "specialize")
+    )
+
+    blockwise_nodes = [
+        var.owner
+        for var in ancestors(final_grads)
+        if var.owner is not None and isinstance(var.owner.op, Blockwise)
+    ]
+    # Check that any remaining Blockwise is just Conv1d
+    assert all(isinstance(node.op.core_op, Conv1d) for node in blockwise_nodes)
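
Note (illustrative, not part of the patch): an end-to-end sketch of what
the new rewrite buys. Vectorizing a core `Subtensor` graph over a
batched input produces a `Blockwise(Subtensor)`, which the rewrite turns
back into plain indexing with a leading full slice. Variable names and
the chosen rewrite passes are assumptions.

    import pytensor.tensor as pt
    from pytensor.graph import rewrite_graph
    from pytensor.graph.replace import vectorize_graph

    x = pt.vector("x")
    start, stop = pt.lscalar("start"), pt.lscalar("stop")
    core = x[start:stop]  # core Subtensor graph

    xb = pt.matrix("xb")  # one batch dimension
    batched = vectorize_graph(core, replace={x: xb})  # Blockwise(Subtensor)
    rewritten = rewrite_graph(batched, include=("canonicalize", "specialize"))
    # `rewritten` should now be plain indexing, equivalent to
    # xb[:, start:stop], with no Blockwise left in the graph.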