Skip to content

Commit 4205b18

Browse files
committed
Fix bug in local_reshape_to_dimshuffle
1 parent b12dc30 commit 4205b18

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

pytensor/tensor/rewriting/shape.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -966,16 +966,15 @@ def local_reshape_to_dimshuffle(fgraph, node):
966966
inp, output_shape = node.inputs
967967
[output] = node.outputs
968968

969-
# Remove any broadcastable dimensions from the input
970-
squeeze_axes = [i for i, bcast in enumerate(inp.type.broadcastable) if bcast]
971-
972969
# Trivial case, all dimensions of input/output are known to be broadcastable:
973970
# there's nothing to reshape
974971
if all(inp.type.broadcastable) or all(output.type.broadcastable):
972+
squeeze_axes = tuple(range(inp.type.ndim))
975973
new_output_shape = []
976974
expand_axes = tuple(range(output.type.ndim))
977975

978976
else:
977+
squeeze_axes = [i for i, bcast in enumerate(inp.type.broadcastable) if bcast]
979978
unpacked_shape = _unpack_shape_vector(output_shape)
980979
new_output_shape = []
981980
expand_axes = []

tests/tensor/rewriting/test_shape.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,16 @@ def test_squeeze_of_alloc(self):
445445
new_out = rewrite_graph(out, include=("canonicalize", "ShapeOpt"))
446446
assert equal_computations([new_out], [pt.alloc(x, 12, 9)], strict_dtype=False)
447447

448+
def test_reshape_implies_size_1_input(self):
449+
x = pt.matrix("x", shape=(None, None))
450+
out = pt.reshape(x, (1, 1, 1))
451+
452+
new_out = rewrite_graph(out, include=("canonicalize",))
453+
new_out.dprint(print_type=True)
454+
assert equal_computations(
455+
[new_out], [x.dimshuffle("x", "x", "x")], strict_dtype=False
456+
)
457+
448458

449459
def test_expand_dims_squeeze_reshape_fusion():
450460
x = pt.tensor("x", shape=(1, 9))

0 commit comments

Comments
 (0)