
Fix einsum bug #1185


Merged: 3 commits, Feb 3, 2025
25 changes: 20 additions & 5 deletions pytensor/tensor/einsum.py
@@ -410,6 +410,12 @@ def _contraction_list_from_path(
return contraction_list


def _right_to_left_path(n: int) -> tuple[tuple[int, int], ...]:
# Create a right to left contraction path
# if n = 5, out = ((4, 3), (3, 2), (2, 1), (1, 0))
return tuple(pairwise(reversed(range(n))))


def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVariable:
"""
Multiplication and summation of tensors using the Einstein summation convention.
@@ -546,8 +552,6 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
"If you need this functionality open an issue in https://github.com/pymc-devs/pytensor/issues to let us know. "
)

# TODO: Is this doing something clever about unknown shapes?
Member Author


I checked; they are not doing anything clever. They simply set unknown dims to 8 and use whatever comes out of it.
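
Roughly, the behaviour described above corresponds to the following sketch (illustrative only, not code from this PR; the helper name is made up): unknown static dims are filled with a placeholder size such as 8 before asking np.einsum_path for a contraction order.

import numpy as np

def einsum_path_with_placeholder_dims(subscripts, shapes, placeholder=8):
    # Hypothetical helper: replace unknown (None) dims with a fixed placeholder
    # size and let np.einsum_path pick a contraction order from dummy operands.
    filled = [tuple(placeholder if dim is None else dim for dim in shape) for shape in shapes]
    dummy_operands = [np.empty(shape) for shape in filled]
    path, _ = np.einsum_path(subscripts, *dummy_operands, optimize="optimal")
    return tuple(path[1:])  # drop the leading "einsum_path" marker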

# contract_path = _poly_einsum_handlers.get(ty, _default_poly_einsum_handler)
tensor_operands = [as_tensor(operand) for operand in operands]
shapes = [operand.type.shape for operand in tensor_operands]

@@ -565,7 +569,7 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
else:
# By default, we try right to left because we assume that most graphs
# have a lower dimensional rightmost operand
path = tuple(pairwise(reversed(range(len(tensor_operands)))))
path = _right_to_left_path(len(tensor_operands))
contraction_list = _contraction_list_from_path(
subscripts, tensor_operands, path
)
@@ -583,7 +587,18 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
einsum_call=True, # Not part of public API
optimize="optimal",
) # type: ignore
path = tuple(contraction[0] for contraction in contraction_list)
np_path = tuple(contraction[0] for contraction in contraction_list)

if len(np_path) == 1 and len(np_path[0]) > 2:
# When there's nothing to optimize, einsum_path reduces all entries simultaneously instead of doing
# pairwise reductions, which our implementation below demands.
path = _right_to_left_path(len(tensor_operands))
contraction_list = _contraction_list_from_path(
subscripts, tensor_operands, path
)
else:
path = np_path

optimized = True

def removechars(s, chars):
@@ -746,7 +761,7 @@ def filter_singleton_dims(operand, names, other_operand, other_names):
)
else:
raise ValueError(
f"Each step of einsum must have 1 or 2 operands, got {len(operand_indices)}"
f"Each step of einsum must have 1 or 2 operands, got {len(operand_indices)}, {path=}."
)

# the resulting 'operand' with axis labels 'names' should be a permutation of the desired result
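As an aside (not part of the diff), the need for the fallback above can be reproduced with plain NumPy: when there is nothing to optimize, np.einsum_path returns a single step covering all operands at once, which the pairwise implementation cannot consume, while the right-to-left helper yields strictly pairwise steps. A minimal sketch, assuming NumPy and Python 3.10+ for itertools.pairwise:

from itertools import pairwise
import numpy as np

a = np.ones(3)
path, _ = np.einsum_path("i,i,i->i", a, a, a, optimize="optimal")
print(path[1:])  # [(0, 1, 2)]: one contraction over three operands at once

def right_to_left_path(n):
    # Mirrors _right_to_left_path above: strictly pairwise, right-to-left steps
    return tuple(pairwise(reversed(range(n))))

print(right_to_left_path(3))  # ((2, 1), (1, 0))
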
60 changes: 40 additions & 20 deletions tests/tensor/test_einsum.py
@@ -5,13 +5,14 @@
import pytest

import pytensor
import pytensor.tensor as pt
from pytensor import Mode, config, function
from pytensor.graph import FunctionGraph
from pytensor.graph.op import HasInnerGraph
from pytensor.tensor.basic import moveaxis
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.einsum import _delta, _general_dot, _iota, einsum
from pytensor.tensor.shape import Reshape
from pytensor.tensor.type import tensor


# Fail for unexpected warnings in this file
@@ -80,8 +81,8 @@ def test_general_dot():

# X has two batch dims
# Y has one batch dim
x = pt.tensor("x", shape=(5, 4, 2, 11, 13, 3))
y = pt.tensor("y", shape=(4, 13, 5, 7, 11))
x = tensor("x", shape=(5, 4, 2, 11, 13, 3))
y = tensor("y", shape=(4, 13, 5, 7, 11))
out = _general_dot((x, y), tensordot_axes, [(0, 1), (0,)])

fn = pytensor.function([x, y], out)
@@ -135,10 +136,10 @@ def test_einsum_signatures(static_shape_known, signature):
static_shapes = [[None] * len(shape) for shape in shapes]

operands = [
pt.tensor(name, shape=static_shape)
tensor(name, shape=static_shape)
for name, static_shape in zip(ascii_lowercase, static_shapes, strict=False)
]
out = pt.einsum(signature, *operands)
out = einsum(signature, *operands)
assert out.owner.op.optimized == static_shape_known or len(operands) <= 2

rng = np.random.default_rng(37)
@@ -160,8 +161,8 @@ def test_batch_dim():
"x": (7, 3, 5),
"y": (5, 2),
}
x, y = (pt.tensor(name, shape=shape) for name, shape in shapes.items())
out = pt.einsum("mij,jk->mik", x, y)
x, y = (tensor(name, shape=shape) for name, shape in shapes.items())
out = einsum("mij,jk->mik", x, y)

assert out.type.shape == (7, 3, 2)

@@ -195,32 +196,32 @@ def test_einsum_conv():

def test_ellipsis():
rng = np.random.default_rng(159)
x = pt.tensor("x", shape=(3, 5, 7, 11))
y = pt.tensor("y", shape=(3, 5, 11, 13))
x = tensor("x", shape=(3, 5, 7, 11))
y = tensor("y", shape=(3, 5, 11, 13))
x_test = rng.normal(size=x.type.shape).astype(floatX)
y_test = rng.normal(size=y.type.shape).astype(floatX)
expected_out = np.matmul(x_test, y_test)

with pytest.raises(ValueError):
pt.einsum("mp,pn->mn", x, y)
einsum("mp,pn->mn", x, y)

out = pt.einsum("...mp,...pn->...mn", x, y)
out = einsum("...mp,...pn->...mn", x, y)
np.testing.assert_allclose(
out.eval({x: x_test, y: y_test}), expected_out, atol=ATOL, rtol=RTOL
)

# Put batch axes in the middle
new_x = pt.moveaxis(x, -2, 0)
new_y = pt.moveaxis(y, -2, 0)
out = pt.einsum("m...p,p...n->m...n", new_x, new_y)
new_x = moveaxis(x, -2, 0)
new_y = moveaxis(y, -2, 0)
out = einsum("m...p,p...n->m...n", new_x, new_y)
np.testing.assert_allclose(
out.eval({x: x_test, y: y_test}),
expected_out.transpose(-2, 0, 1, -1),
atol=ATOL,
rtol=RTOL,
)

out = pt.einsum("m...p,p...n->mn", new_x, new_y)
out = einsum("m...p,p...n->mn", new_x, new_y)
np.testing.assert_allclose(
out.eval({x: x_test, y: y_test}), expected_out.sum((0, 1)), atol=ATOL, rtol=RTOL
)
@@ -236,9 +237,9 @@ def test_broadcastable_dims():
# can lead to suboptimal paths. We check we issue a warning for the following example:
# https://github.com/dgasmith/opt_einsum/issues/220
rng = np.random.default_rng(222)
a = pt.tensor("a", shape=(32, 32, 32))
b = pt.tensor("b", shape=(1000, 32))
c = pt.tensor("c", shape=(1, 32))
a = tensor("a", shape=(32, 32, 32))
b = tensor("b", shape=(1000, 32))
c = tensor("c", shape=(1, 32))

a_test = rng.normal(size=a.type.shape).astype(floatX)
b_test = rng.normal(size=b.type.shape).astype(floatX)
@@ -248,11 +249,11 @@ def test_broadcastable_dims():
with pytest.warns(
UserWarning, match="This can result in a suboptimal contraction path"
):
suboptimal_out = pt.einsum("ijk,bj,bk->i", a, b, c)
suboptimal_out = einsum("ijk,bj,bk->i", a, b, c)
assert not [set(p) for p in suboptimal_out.owner.op.path] == [{0, 2}, {0, 1}]

# If we use a distinct letter we get the optimal path
optimal_out = pt.einsum("ijk,bj,ck->i", a, b, c)
optimal_out = einsum("ijk,bj,ck->i", a, b, c)
assert [set(p) for p in optimal_out.owner.op.path] == [{0, 2}, {0, 1}]

suboptimal_eval = suboptimal_out.eval({a: a_test, b: b_test, c: c_test})
@@ -261,3 +262,22 @@ def test_broadcastable_dims():
atol = 1e-12 if config.floatX == "float64" else 1e-2
np.testing.assert_allclose(suboptimal_eval, np_eval, atol=atol)
np.testing.assert_allclose(optimal_eval, np_eval, atol=atol)


@pytest.mark.parametrize("static_length", [False, True])
def test_threeway_mul(static_length):
# Regression test for https://github.com/pymc-devs/pytensor/issues/1184
# x, y, z = vectors("x", "y", "z")
sh = (3,) if static_length else (None,)
x = tensor("x", shape=sh)
y = tensor("y", shape=sh)
z = tensor("z", shape=sh)
out = einsum("..., ..., ... -> ...", x, y, z)

x_test = np.ones((3,), dtype=x.dtype)
y_test = x_test + 1
z_test = x_test + 2
np.testing.assert_allclose(
out.eval({x: x_test, y: y_test, z: z_test}),
np.full((3,), fill_value=6),
)
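
For reference, the fix can also be exercised from user code through the public pt.einsum API (a quick sketch, assuming a build that includes this change):

import numpy as np
import pytensor.tensor as pt

x, y, z = (pt.vector(name) for name in "xyz")
out = pt.einsum("..., ..., ... -> ...", x, y, z)  # previously failed on the single (0, 1, 2) step

x_test = np.ones(3, dtype=x.dtype)
print(out.eval({x: x_test, y: x_test + 1, z: x_test + 2}))  # [6. 6. 6.]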