
Only do reshapes in tensordot when needed #1202


Merged · 1 commit · Feb 17, 2025
pytensor/tensor/math.py: 81 changes (49 additions, 32 deletions)

@@ -2158,62 +2158,79 @@ def tensordot(
     a = as_tensor_variable(a)
     b = as_tensor_variable(b)
     runtime_shape_a = a.shape
-    bcast_a = a.broadcastable
     static_shape_a = a.type.shape
-    ndim_a = a.ndim
+    ndim_a = a.type.ndim
     runtime_shape_b = b.shape
-    bcast_b = b.broadcastable
     static_shape_b = b.type.shape
-    ndim_b = b.ndim
+    ndim_b = b.type.ndim
     if na != nb:
         raise ValueError(
             "The number of axes supplied for tensordot must be equal for each tensor. "
             f"Got {na} and {nb} respectively."
         )
     axes_a = list(normalize_axis_tuple(axes_a, ndim_a))
     axes_b = list(normalize_axis_tuple(axes_b, ndim_b))

     # The operation is only valid if the original dimensions match in length
     # The ravelling of the dimensions to coerce the operation into a single dot
     # could mask such errors, so we add an Assert if needed.
     must_assert_runtime = False
-    for k in range(na):
-        ax_a = axes_a[k]
-        ax_b = axes_b[k]
-        if (bcast_a[ax_a] != bcast_b[ax_b]) or (
+    for ax_a, ax_b in zip(axes_a, axes_b, strict=True):
+        if (
             static_shape_a[ax_a] is not None
             and static_shape_b[ax_b] is not None
             and static_shape_a[ax_a] != static_shape_b[ax_b]
         ):
             raise ValueError(
-                "Input arrays have inconsistent broadcastable pattern or type shape along the axes "
+                "Input arrays have inconsistent type shape along the axes "
                 "that are to be reduced with tensordot."
             )
         elif static_shape_a[ax_a] is None or static_shape_b[ax_b] is None:
             if must_assert_runtime:
                 a = Assert(
                     "Input array shape along reduced axes of tensordot are not equal"
-                )(a, eq(a.shape[ax_a], b.shape[ax_b]))
+                )(a, eq(runtime_shape_a[ax_a], runtime_shape_b[ax_b]))
             must_assert_runtime = True

-    # Move the axes to sum over to the end of "a"
-    # and to the front of "b"
-    notin = [k for k in range(ndim_a) if k not in axes_a]
-    newaxes_a = notin + axes_a
-    N2 = 1
-    for axis in axes_a:
-        N2 *= runtime_shape_a[axis]
-    newshape_a = (-1, N2)
-    olda = [runtime_shape_a[axis] for axis in notin]
-
-    notin = [k for k in range(ndim_b) if k not in axes_b]
-    newaxes_b = axes_b + notin
-    N2 = 1
-    for axis in axes_b:
-        N2 *= runtime_shape_b[axis]
-    newshape_b = (N2, -1)
-    oldb = [runtime_shape_b[axis] for axis in notin]
-
-    at = a.transpose(newaxes_a).reshape(newshape_a)
-    bt = b.transpose(newaxes_b).reshape(newshape_b)
-    res = _dot(at, bt)
-    return res.reshape(olda + oldb)
+    # Convert tensordot into a stacked dot product.
+    # We stack the summed axes and the non-summed axes of each tensor separately,
+    # and place the summed axes at the end of a and the beginning of b
+    non_summed_axes_a = [k for k in range(ndim_a) if k not in axes_a]
+    non_summed_dims_a = [runtime_shape_a[axis] for axis in non_summed_axes_a]
+    transpose_axes_a = non_summed_axes_a + axes_a
+    # We only need a reshape when we need to combine summed or non-summed dims
+    # or introduce a new dimension (expand_dims) when doing a non-scalar outer product (len(axes) = 0)
+    a_needs_reshape = (ndim_a != 0) and (
+        (len(non_summed_axes_a) > 1) or (len(axes_a) != 1)
+    )
+
+    non_summed_axes_b = [k for k in range(ndim_b) if k not in axes_b]
+    non_summed_dims_b = [runtime_shape_b[axis] for axis in non_summed_axes_b]
+    transpose_axes_b = axes_b + non_summed_axes_b
+    b_needs_reshape = (ndim_b != 0) and (
+        (len(non_summed_axes_b) > 1) or (len(axes_b) != 1)
+    )
+
+    # summed_size_a and summed_size_b must be the same,
+    # but to facilitate reasoning about useless reshapes we compute both from their shapes
+    at = a.transpose(transpose_axes_a)
+    if a_needs_reshape:
+        non_summed_size_a = variadic_mul(*non_summed_dims_a)
+        summed_size_a = variadic_mul(*[runtime_shape_a[axis] for axis in axes_a])
+        at = at.reshape((non_summed_size_a, summed_size_a))
+
+    bt = b.transpose(transpose_axes_b)
+    if b_needs_reshape:
+        non_summed_size_b = variadic_mul(*non_summed_dims_b)
+        summed_size_b = variadic_mul(*[runtime_shape_b[axis] for axis in axes_b])
+        bt = bt.reshape((summed_size_b, non_summed_size_b))
+
+    res = dot(at, bt)
+
+    if a_needs_reshape or b_needs_reshape:
+        res = res.reshape(non_summed_dims_a + non_summed_dims_b)
+
+    return res


 def outer(x, y):
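Why the old reshapes were often redundant: for a plain matrix-vector contraction over the last axes, collapsing each operand to 2-D is a no-op on `a` and a spurious expand_dims on `b`. A minimal NumPy sketch of the same transpose/reshape/dot decomposition (the sizes are illustrative, not taken from the PR):

```python
import numpy as np

rng = np.random.default_rng(42)
a = rng.normal(size=(3, 4))  # one non-summed axis, one summed axis
b = rng.normal(size=(4,))    # one summed axis, no non-summed axes

# Old path: always collapse both operands to 2-D, dot, then reshape back.
at_old = a.transpose((0, 1)).reshape((-1, 4))  # (3, 4) -> (3, 4): a no-op
bt_old = b.transpose((0,)).reshape((4, -1))    # (4,) -> (4, 1): spurious expand_dims
res_old = (at_old @ bt_old).reshape((3,))      # (3, 1) -> (3,)

# New path: each operand has at most one summed and one non-summed axis,
# so a_needs_reshape and b_needs_reshape are both False and the emitted
# graph is just dot(a, b).
res_new = a @ b

assert np.allclose(res_old, res_new)
assert np.allclose(res_new, np.tensordot(a, b, axes=[[-1], [-1]]))
```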
tests/tensor/test_math.py: 39 changes (37 additions, 2 deletions)

@@ -19,7 +19,7 @@
 from pytensor.compile.sharedvalue import shared
 from pytensor.configdefaults import config
 from pytensor.gradient import NullTypeGradError, grad, numeric_grad
-from pytensor.graph.basic import Variable, ancestors, applys_between
+from pytensor.graph.basic import Variable, ancestors, applys_between, equal_computations
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.replace import vectorize_node
 from pytensor.link.c.basic import DualLinker
@@ -2278,7 +2278,7 @@ def test_type_shape(self):

         with pytest.raises(
             ValueError,
-            match="Input arrays have inconsistent broadcastable pattern or type shape",
+            match="Input arrays have inconsistent type shape",
         ):
             tensordot(ones(shape=(7, 4)), ones(shape=(7, 4)), axes=1)

@@ -2323,6 +2323,41 @@ def test_shape_assert(self, axes, has_assert, values, expected_fail):
         else:
             assert np.allclose(np.tensordot(xv, yv, axes=axes), z.eval({x: xv, y: yv}))

+    def test_eager_simplification(self):
+        # Test that in cases where tensordot isn't needed, it returns a simple graph
+        scl = tensor(shape=())
+        vec = tensor(shape=(None,))
+        mat = tensor(shape=(None, None))
+
+        # scalar product
+        out = tensordot(scl, scl, axes=[[], []])
+        assert equal_computations([out], [scl * scl])
+
+        # vector-vector product
+        out = tensordot(vec, vec, axes=[[-1], [-1]])
+        assert equal_computations([out], [dot(vec, vec)])
+
+        # matrix-vector product
+        out = tensordot(mat, vec, axes=[[-1], [-1]])
+        assert equal_computations([out], [dot(mat, vec)])
+
+        out = tensordot(mat, vec, axes=[[-2], [-1]])
+        assert equal_computations([out], [dot(mat.T, vec)])
+
+        # vector-matrix product
+        out = tensordot(vec, mat, axes=[[-1], [-2]])
+        assert equal_computations([out], [dot(vec, mat)])
+
+        out = tensordot(vec, mat, axes=[[-1], [-1]])
+        assert equal_computations([out], [dot(vec, mat.T)])
+
+        # matrix-matrix product
+        out = tensordot(mat, mat, axes=[[-1], [-2]])
+        assert equal_computations([out], [dot(mat, mat)])
+
+        out = tensordot(mat, mat, axes=[[-1], [-1]])
+        assert equal_computations([out], [dot(mat, mat.T)])
+

 def test_smallest():
     x = dvector()
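For contrast, a sketch of a case where a reshape is still required: contracting a 3-D tensor with a matrix over two axes means `len(axes_a) != 1`, so `a_needs_reshape` is True and the summed dims must be folded into one. Again a NumPy approximation of the decomposition (it folds `b` to a vector for simplicity, rather than the 2-D shape the PyTensor graph would use):

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=(2, 3, 4))  # non-summed axis: 0; summed axes: 1 and 2
b = rng.normal(size=(3, 4))     # both axes summed

# transpose_axes_a = [0, 1, 2] is a no-op here, but the two summed dims
# (3, 4) of `a` must still be folded into a single length-12 axis.
at = a.reshape((2, 3 * 4))
bt = b.reshape((3 * 4,))  # folded to a vector for simplicity
res = at @ bt             # shape (2,)

assert np.allclose(res, np.tensordot(a, b, axes=[[1, 2], [0, 1]]))
```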