Skip to content

Commit 67017a6

Browse files
committed
Fix bug in local_reshape_to_dimshuffle
1 parent b12dc30 commit 67017a6

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

pytensor/tensor/rewriting/shape.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -966,16 +966,15 @@ def local_reshape_to_dimshuffle(fgraph, node):
966966
inp, output_shape = node.inputs
967967
[output] = node.outputs
968968

969-
# Remove any broadcastable dimensions from the input
970-
squeeze_axes = [i for i, bcast in enumerate(inp.type.broadcastable) if bcast]
971-
972969
# Trivial case, all dimensions of input/output are known to be broadcastable:
973970
# there's nothing to reshape
974971
if all(inp.type.broadcastable) or all(output.type.broadcastable):
972+
squeeze_axes = tuple(range(inp.type.ndim))
975973
new_output_shape = []
976974
expand_axes = tuple(range(output.type.ndim))
977975

978976
else:
977+
squeeze_axes = [i for i, bcast in enumerate(inp.type.broadcastable) if bcast]
979978
unpacked_shape = _unpack_shape_vector(output_shape)
980979
new_output_shape = []
981980
expand_axes = []

tests/tensor/rewriting/test_shape.py

+9
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,15 @@ def test_squeeze_of_alloc(self):
445445
new_out = rewrite_graph(out, include=("canonicalize", "ShapeOpt"))
446446
assert equal_computations([new_out], [pt.alloc(x, 12, 9)], strict_dtype=False)
447447

448+
def test_reshape_implies_size_1_input(self):
449+
x = pt.matrix("x", shape=(None, None))
450+
out = pt.reshape(x, (1, 1, 1))
451+
452+
new_out = rewrite_graph(out, include=("canonicalize",))
453+
assert equal_computations(
454+
[new_out], [x.dimshuffle("x", "x", "x")], strict_dtype=False
455+
)
456+
448457

449458
def test_expand_dims_squeeze_reshape_fusion():
450459
x = pt.tensor("x", shape=(1, 9))

0 commit comments

Comments
 (0)