diff --git a/pytensor/tensor/einsum.py b/pytensor/tensor/einsum.py
index e6bc613b00..cba40ec6f8 100644
--- a/pytensor/tensor/einsum.py
+++ b/pytensor/tensor/einsum.py
@@ -410,6 +410,12 @@ def _contraction_list_from_path(
     return contraction_list
 
 
+def _right_to_left_path(n: int) -> tuple[tuple[int, int], ...]:
+    # Create a right-to-left contraction path,
+    # e.g. n = 5 gives ((4, 3), (3, 2), (2, 1), (1, 0))
+    return tuple(pairwise(reversed(range(n))))
+
+
 def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVariable:
     """
     Multiplication and summation of tensors using the Einstein summation convention.
@@ -546,8 +552,6 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
             "If you need this functionality open an issue in https://github.com/pymc-devs/pytensor/issues to let us know. "
         )
 
-    # TODO: Is this doing something clever about unknown shapes?
-    # contract_path = _poly_einsum_handlers.get(ty, _default_poly_einsum_handler)
     tensor_operands = [as_tensor(operand) for operand in operands]
     shapes = [operand.type.shape for operand in tensor_operands]
 
@@ -565,7 +569,7 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
         else:
             # By default, we try right to left because we assume that most graphs
             # have a lower dimensional rightmost operand
-            path = tuple(pairwise(reversed(range(len(tensor_operands)))))
+            path = _right_to_left_path(len(tensor_operands))
             contraction_list = _contraction_list_from_path(
                 subscripts, tensor_operands, path
             )
@@ -583,7 +587,18 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
             einsum_call=True,  # Not part of public API
             optimize="optimal",
         )  # type: ignore
-        path = tuple(contraction[0] for contraction in contraction_list)
+        np_path = tuple(contraction[0] for contraction in contraction_list)
+
+        if len(np_path) == 1 and len(np_path[0]) > 2:
+            # When there is nothing to optimize, the contract_path call above returns a single contraction
+            # over all operands at once, instead of the pairwise reductions the implementation below requires.
+            path = _right_to_left_path(len(tensor_operands))
+            contraction_list = _contraction_list_from_path(
+                subscripts, tensor_operands, path
+            )
+        else:
+            path = np_path
+
         optimized = True
 
     def removechars(s, chars):
@@ -746,7 +761,7 @@ def filter_singleton_dims(operand, names, other_operand, other_names):
             )
         else:
             raise ValueError(
-                f"Each step of einsum must have 1 or 2 operands, got {len(operand_indices)}"
+                f"Each step of einsum must have 1 or 2 operands, got {len(operand_indices)}, {path=}."
             )
 
         # the resulting 'operand' with axis labels 'names' should be a permutation of the desired result
diff --git a/tests/tensor/test_einsum.py b/tests/tensor/test_einsum.py
index 426ed13dcd..ba8e354518 100644
--- a/tests/tensor/test_einsum.py
+++ b/tests/tensor/test_einsum.py
@@ -5,13 +5,14 @@
 import pytest
 
 import pytensor
-import pytensor.tensor as pt
 from pytensor import Mode, config, function
 from pytensor.graph import FunctionGraph
 from pytensor.graph.op import HasInnerGraph
+from pytensor.tensor.basic import moveaxis
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.einsum import _delta, _general_dot, _iota, einsum
 from pytensor.tensor.shape import Reshape
+from pytensor.tensor.type import tensor
 
 
 # Fail for unexpected warnings in this file
@@ -80,8 +81,8 @@ def test_general_dot():
 
     # X has two batch dims
     # Y has one batch dim
-    x = pt.tensor("x", shape=(5, 4, 2, 11, 13, 3))
-    y = pt.tensor("y", shape=(4, 13, 5, 7, 11))
+    x = tensor("x", shape=(5, 4, 2, 11, 13, 3))
+    y = tensor("y", shape=(4, 13, 5, 7, 11))
     out = _general_dot((x, y), tensordot_axes, [(0, 1), (0,)])
 
     fn = pytensor.function([x, y], out)
@@ -135,10 +136,10 @@ def test_einsum_signatures(static_shape_known, signature):
         static_shapes = [[None] * len(shape) for shape in shapes]
 
     operands = [
-        pt.tensor(name, shape=static_shape)
+        tensor(name, shape=static_shape)
         for name, static_shape in zip(ascii_lowercase, static_shapes, strict=False)
     ]
-    out = pt.einsum(signature, *operands)
+    out = einsum(signature, *operands)
     assert out.owner.op.optimized == static_shape_known or len(operands) <= 2
 
     rng = np.random.default_rng(37)
@@ -160,8 +161,8 @@ def test_batch_dim():
         "x": (7, 3, 5),
         "y": (5, 2),
     }
-    x, y = (pt.tensor(name, shape=shape) for name, shape in shapes.items())
-    out = pt.einsum("mij,jk->mik", x, y)
+    x, y = (tensor(name, shape=shape) for name, shape in shapes.items())
+    out = einsum("mij,jk->mik", x, y)
 
     assert out.type.shape == (7, 3, 2)
 
@@ -195,24 +196,24 @@ def test_einsum_conv():
 
 def test_ellipsis():
     rng = np.random.default_rng(159)
-    x = pt.tensor("x", shape=(3, 5, 7, 11))
-    y = pt.tensor("y", shape=(3, 5, 11, 13))
+    x = tensor("x", shape=(3, 5, 7, 11))
+    y = tensor("y", shape=(3, 5, 11, 13))
     x_test = rng.normal(size=x.type.shape).astype(floatX)
     y_test = rng.normal(size=y.type.shape).astype(floatX)
     expected_out = np.matmul(x_test, y_test)
 
     with pytest.raises(ValueError):
-        pt.einsum("mp,pn->mn", x, y)
+        einsum("mp,pn->mn", x, y)
 
-    out = pt.einsum("...mp,...pn->...mn", x, y)
+    out = einsum("...mp,...pn->...mn", x, y)
     np.testing.assert_allclose(
         out.eval({x: x_test, y: y_test}), expected_out, atol=ATOL, rtol=RTOL
     )
 
     # Put batch axes in the middle
-    new_x = pt.moveaxis(x, -2, 0)
-    new_y = pt.moveaxis(y, -2, 0)
-    out = pt.einsum("m...p,p...n->m...n", new_x, new_y)
+    new_x = moveaxis(x, -2, 0)
+    new_y = moveaxis(y, -2, 0)
+    out = einsum("m...p,p...n->m...n", new_x, new_y)
     np.testing.assert_allclose(
         out.eval({x: x_test, y: y_test}),
         expected_out.transpose(-2, 0, 1, -1),
@@ -220,7 +221,7 @@ def test_ellipsis():
         rtol=RTOL,
     )
 
-    out = pt.einsum("m...p,p...n->mn", new_x, new_y)
+    out = einsum("m...p,p...n->mn", new_x, new_y)
     np.testing.assert_allclose(
         out.eval({x: x_test, y: y_test}), expected_out.sum((0, 1)), atol=ATOL, rtol=RTOL
     )
@@ -236,9 +237,9 @@ def test_broadcastable_dims():
     # can lead to suboptimal paths. We check we issue a warning for the following example:
     # https://github.com/dgasmith/opt_einsum/issues/220
     rng = np.random.default_rng(222)
-    a = pt.tensor("a", shape=(32, 32, 32))
-    b = pt.tensor("b", shape=(1000, 32))
-    c = pt.tensor("c", shape=(1, 32))
+    a = tensor("a", shape=(32, 32, 32))
+    b = tensor("b", shape=(1000, 32))
+    c = tensor("c", shape=(1, 32))
 
     a_test = rng.normal(size=a.type.shape).astype(floatX)
     b_test = rng.normal(size=b.type.shape).astype(floatX)
@@ -248,11 +249,11 @@ def test_broadcastable_dims():
     with pytest.warns(
         UserWarning, match="This can result in a suboptimal contraction path"
    ):
-        suboptimal_out = pt.einsum("ijk,bj,bk->i", a, b, c)
+        suboptimal_out = einsum("ijk,bj,bk->i", a, b, c)
     assert not [set(p) for p in suboptimal_out.owner.op.path] == [{0, 2}, {0, 1}]
 
     # If we use a distinct letter we get the optimal path
-    optimal_out = pt.einsum("ijk,bj,ck->i", a, b, c)
+    optimal_out = einsum("ijk,bj,ck->i", a, b, c)
     assert [set(p) for p in optimal_out.owner.op.path] == [{0, 2}, {0, 1}]
 
     suboptimal_eval = suboptimal_out.eval({a: a_test, b: b_test, c: c_test})
@@ -261,3 +262,22 @@ def test_broadcastable_dims():
     atol = 1e-12 if config.floatX == "float64" else 1e-2
     np.testing.assert_allclose(suboptimal_eval, np_eval, atol=atol)
     np.testing.assert_allclose(optimal_eval, np_eval, atol=atol)
+
+
+@pytest.mark.parametrize("static_length", [False, True])
+def test_threeway_mul(static_length):
+    # Regression test for https://github.com/pymc-devs/pytensor/issues/1184
+    # x, y, z = vectors("x", "y", "z")
+    sh = (3,) if static_length else (None,)
+    x = tensor("x", shape=sh)
+    y = tensor("y", shape=sh)
+    z = tensor("z", shape=sh)
+    out = einsum("..., ..., ... -> ...", x, y, z)
+
+    x_test = np.ones((3,), dtype=x.dtype)
+    y_test = x_test + 1
+    z_test = x_test + 2
+    np.testing.assert_allclose(
+        out.eval({x: x_test, y: y_test, z: z_test}),
+        np.full((3,), fill_value=6),
+    )
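
A small standalone sketch of the corner case the np_path fallback above handles (illustrative only, not part of the patch; it assumes opt_einsum is installed and uses the same contract_path helper the einsum implementation calls):

from itertools import pairwise

import numpy as np
from opt_einsum import contract_path

x, y, z = (np.ones(3) for _ in range(3))

# A pure elementwise product leaves nothing to contract, so the "optimal" path
# comes back as a single step that lists all three operands at once.
_, contraction_list = contract_path(
    "i,i,i->i", x, y, z, einsum_call=True, optimize="optimal"
)
print(tuple(contraction[0] for contraction in contraction_list))  # one step, three operands

# The right-to-left fallback instead yields the pairwise steps the inner-graph
# construction expects, matching _right_to_left_path(3).
print(tuple(pairwise(reversed(range(3)))))  # ((2, 1), (1, 0))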