Skip to content

Commit 7ecb9f8

Browse files
committed
Fix vectorized shape
1 parent 861f95c commit 7ecb9f8

File tree

2 files changed

+66
-16
lines changed

2 files changed

+66
-16
lines changed

pytensor/tensor/shape.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import pytensor
99
from pytensor.gradient import DisconnectedType
1010
from pytensor.graph.basic import Apply, Variable
11-
from pytensor.graph.replace import _vectorize_node, _vectorize_not_needed
11+
from pytensor.graph.replace import _vectorize_node
1212
from pytensor.graph.type import HasShape
1313
from pytensor.link.c.op import COp
1414
from pytensor.link.c.params_type import ParamsType
@@ -155,7 +155,20 @@ def _get_vector_length_Shape(op, var):
155155
return var.owner.inputs[0].type.ndim
156156

157157

158-
_vectorize_node.register(Shape, _vectorize_not_needed)
158+
@_vectorize_node.register(Shape)
159+
def vectorize_shape(op, node, batched_x):
160+
from pytensor.tensor.extra_ops import broadcast_to
161+
162+
[old_x] = node.inputs
163+
core_ndims = old_x.type.ndim
164+
batch_ndims = batched_x.type.ndim - core_ndims
165+
batched_x_shape = shape(batched_x)
166+
if not batch_ndims:
167+
return batched_x_shape.owner
168+
else:
169+
batch_shape = batched_x_shape[:batch_ndims]
170+
core_shape = batched_x_shape[batch_ndims:]
171+
return broadcast_to(core_shape, (*batch_shape, core_ndims)).owner
159172

160173

161174
def shape_tuple(x: TensorVariable) -> tuple[Variable, ...]:

tests/tensor/test_shape.py

+51-14
Original file line numberDiff line numberDiff line change
@@ -709,29 +709,66 @@ def test_shape_tuple():
709709

710710

711711
class TestVectorize:
712+
@pytensor.config.change_flags(cxx="") # For faster eval
712713
def test_shape(self):
713-
vec = tensor(shape=(None,))
714-
mat = tensor(shape=(None, None))
715-
714+
vec = tensor(shape=(None,), dtype="float64")
715+
mat = tensor(shape=(None, None), dtype="float64")
716716
node = shape(vec).owner
717-
vect_node = vectorize_node(node, mat)
718-
assert equal_computations(vect_node.outputs, [shape(mat)])
719717

718+
[vect_out] = vectorize_node(node, mat).outputs
719+
assert equal_computations(
720+
[vect_out], [broadcast_to(mat.shape[1:], (*mat.shape[:1], 1))]
721+
)
722+
723+
mat_test_value = np.ones((5, 3))
724+
ref_fn = np.vectorize(lambda vec: np.asarray(vec.shape), signature="(vec)->(1)")
725+
np.testing.assert_array_equal(
726+
vect_out.eval({mat: mat_test_value}),
727+
ref_fn(mat_test_value),
728+
)
729+
730+
mat = tensor(shape=(None, None), dtype="float64")
731+
tns = tensor(shape=(None, None, None, None), dtype="float64")
732+
node = shape(mat).owner
733+
[vect_out] = vectorize_node(node, tns).outputs
734+
assert equal_computations(
735+
[vect_out], [broadcast_to(tns.shape[2:], (*tns.shape[:2], 2))]
736+
)
737+
738+
tns_test_value = np.ones((4, 6, 5, 3))
739+
ref_fn = np.vectorize(
740+
lambda vec: np.asarray(vec.shape), signature="(m1,m2)->(2)"
741+
)
742+
np.testing.assert_array_equal(
743+
vect_out.eval({tns: tns_test_value}),
744+
ref_fn(tns_test_value),
745+
)
746+
747+
@pytensor.config.change_flags(cxx="") # For faster eval
720748
def test_reshape(self):
721749
x = scalar("x", dtype=int)
722-
vec = tensor(shape=(None,))
723-
mat = tensor(shape=(None, None))
750+
vec = tensor(shape=(None,), dtype="float64")
751+
mat = tensor(shape=(None, None), dtype="float64")
724752

725-
shape = (2, x)
753+
shape = (-1, x)
726754
node = reshape(vec, shape).owner
727-
vect_node = vectorize_node(node, mat, shape)
728-
assert equal_computations(
729-
vect_node.outputs, [reshape(mat, (*mat.shape[:1], 2, x))]
755+
756+
[vect_out] = vectorize_node(node, mat, shape).outputs
757+
assert equal_computations([vect_out], [reshape(mat, (*mat.shape[:1], -1, x))])
758+
759+
x_test_value = 2
760+
mat_test_value = np.ones((5, 6))
761+
ref_fn = np.vectorize(
762+
lambda x, vec: vec.reshape(-1, x), signature="(),(vec1)->(mat1,mat2)"
763+
)
764+
np.testing.assert_array_equal(
765+
vect_out.eval({x: x_test_value, mat: mat_test_value}),
766+
ref_fn(x_test_value, mat_test_value),
730767
)
731768

732-
new_shape = (5, 2, x)
733-
vect_node = vectorize_node(node, mat, new_shape)
734-
assert equal_computations(vect_node.outputs, [reshape(mat, new_shape)])
769+
new_shape = (5, -1, x)
770+
[vect_out] = vectorize_node(node, mat, new_shape).outputs
771+
assert equal_computations([vect_out], [reshape(mat, new_shape)])
735772

736773
with pytest.raises(NotImplementedError):
737774
vectorize_node(node, vec, broadcast_to(as_tensor([5, 2, x]), (2, 3)))

0 commit comments

Comments
 (0)