
Commit 8e5e8a4

Only do reshapes in tensordot when needed
1 parent 65b96c1 commit 8e5e8a4
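
For context, a small standalone sketch of the rule the updated tensordot uses to decide whether an operand has to be reshaped before the underlying dot (the helper name and example axes below are illustrative, not part of the commit):

```python
def needs_reshape(ndim: int, summed_axes: list[int], non_summed_axes: list[int]) -> bool:
    # Mirrors the condition added in pytensor/tensor/math.py: a reshape is only
    # needed to merge several summed or several non-summed dimensions, or to
    # insert a dimension for a non-scalar outer product (no summed axes at all).
    return ndim != 0 and (len(non_summed_axes) > 1 or len(summed_axes) != 1)


# matrix @ vector over one axis: nothing to merge, so no reshape is required
print(needs_reshape(ndim=2, summed_axes=[1], non_summed_axes=[0]))      # False
# contracting two axes of a 3-D tensor: the summed axes must be merged first
print(needs_reshape(ndim=3, summed_axes=[1, 2], non_summed_axes=[0]))   # True
```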

File tree

2 files changed: +86 -34 lines

pytensor/tensor/math.py
tests/tensor/test_math.py

Diff for: pytensor/tensor/math.py (+49 -32)
@@ -2158,62 +2158,79 @@ def tensordot(
     a = as_tensor_variable(a)
     b = as_tensor_variable(b)
     runtime_shape_a = a.shape
-    bcast_a = a.broadcastable
     static_shape_a = a.type.shape
-    ndim_a = a.ndim
+    ndim_a = a.type.ndim
     runtime_shape_b = b.shape
-    bcast_b = b.broadcastable
     static_shape_b = b.type.shape
-    ndim_b = b.ndim
+    ndim_b = b.type.ndim
     if na != nb:
         raise ValueError(
             "The number of axes supplied for tensordot must be equal for each tensor. "
             f"Got {na} and {nb} respectively."
         )
     axes_a = list(normalize_axis_tuple(axes_a, ndim_a))
     axes_b = list(normalize_axis_tuple(axes_b, ndim_b))
+
+    # The operation is only valid if the original dimensions match in length
+    # The ravelling of the dimensions to coerce the operation into a single dot
+    # could mask such errors, so we add an Assert if needed.
     must_assert_runtime = False
-    for k in range(na):
-        ax_a = axes_a[k]
-        ax_b = axes_b[k]
-        if (bcast_a[ax_a] != bcast_b[ax_b]) or (
+    for ax_a, ax_b in zip(axes_a, axes_b, strict=True):
+        if (
             static_shape_a[ax_a] is not None
             and static_shape_b[ax_b] is not None
             and static_shape_a[ax_a] != static_shape_b[ax_b]
         ):
             raise ValueError(
-                "Input arrays have inconsistent broadcastable pattern or type shape along the axes "
+                "Input arrays have inconsistent type shape along the axes "
                 "that are to be reduced with tensordot."
             )
         elif static_shape_a[ax_a] is None or static_shape_b[ax_b] is None:
             if must_assert_runtime:
                 a = Assert(
                     "Input array shape along reduced axes of tensordot are not equal"
-                )(a, eq(a.shape[ax_a], b.shape[ax_b]))
+                )(a, eq(runtime_shape_a[ax_a], runtime_shape_b[ax_b]))
             must_assert_runtime = True
 
-    # Move the axes to sum over to the end of "a"
-    # and to the front of "b"
-    notin = [k for k in range(ndim_a) if k not in axes_a]
-    newaxes_a = notin + axes_a
-    N2 = 1
-    for axis in axes_a:
-        N2 *= runtime_shape_a[axis]
-    newshape_a = (-1, N2)
-    olda = [runtime_shape_a[axis] for axis in notin]
-
-    notin = [k for k in range(ndim_b) if k not in axes_b]
-    newaxes_b = axes_b + notin
-    N2 = 1
-    for axis in axes_b:
-        N2 *= runtime_shape_b[axis]
-    newshape_b = (N2, -1)
-    oldb = [runtime_shape_b[axis] for axis in notin]
-
-    at = a.transpose(newaxes_a).reshape(newshape_a)
-    bt = b.transpose(newaxes_b).reshape(newshape_b)
-    res = _dot(at, bt)
-    return res.reshape(olda + oldb)
+    # Convert tensordot into a stacked dot product.
+    # We stack the summed axes and the non-summed axes of each tensor separately,
+    # and place the summed axes at the end of a and the beginning of b
+    non_summed_axes_a = [k for k in range(ndim_a) if k not in axes_a]
+    non_summed_dims_a = [runtime_shape_a[axis] for axis in non_summed_axes_a]
+    transpose_axes_a = non_summed_axes_a + axes_a
+    # We only need a reshape when we need to combine summed or non-summed dims
+    # or introduce a new dimension (expand_dims) when doing a non-scalar outer product (len(axes) = 0)
+    a_needs_reshape = (ndim_a != 0) and (
+        (len(non_summed_axes_a) > 1) or (len(axes_a) != 1)
+    )
+
+    non_summed_axes_b = [k for k in range(ndim_b) if k not in axes_b]
+    non_summed_dims_b = [runtime_shape_b[axis] for axis in non_summed_axes_b]
+    transpose_axes_b = axes_b + non_summed_axes_b
+    b_needs_reshape = (ndim_b != 0) and (
+        (len(non_summed_axes_b) > 1) or (len(axes_b) != 1)
+    )
+
+    # summed_size_a and summed_size_b must be the same,
+    # but to facilitate reasoning about useless reshapes we compute both from their shapes
+    at = a.transpose(transpose_axes_a)
+    if a_needs_reshape:
+        non_summed_size_a = variadic_mul(*non_summed_dims_a)
+        summed_size_a = variadic_mul(*[runtime_shape_a[axis] for axis in axes_a])
+        at = at.reshape((non_summed_size_a, summed_size_a))
+
+    bt = b.transpose(transpose_axes_b)
+    if b_needs_reshape:
+        non_summed_size_b = variadic_mul(*non_summed_dims_b)
+        summed_size_b = variadic_mul(*[runtime_shape_b[axis] for axis in axes_b])
+        bt = bt.reshape((summed_size_b, non_summed_size_b))
+
+    res = dot(at, bt)
+
+    if a_needs_reshape or b_needs_reshape:
+        res = res.reshape(non_summed_dims_a + non_summed_dims_b)
+
+    return res
 
 
 def outer(x, y):
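
As a quick sanity sketch of the intended effect (an illustration, not part of the commit, assuming PyTensor is importable; the variable names below are made up): a single-axis matrix-vector contraction should now build a plain dot with no reshape nodes, while the numerical result still matches numpy.tensordot.

```python
import numpy as np
import pytensor.tensor as pt
from pytensor.graph.basic import equal_computations

mat = pt.tensor("mat", shape=(None, None))
vec = pt.tensor("vec", shape=(None,))

# One summed axis and one non-summed axis per operand: no reshape should be built,
# so the graph is expected to be exactly dot(mat, vec) after this change.
out = pt.tensordot(mat, vec, axes=[[-1], [-1]])
assert equal_computations([out], [pt.dot(mat, vec)])

# The result still agrees numerically with NumPy's tensordot.
mat_val = np.random.normal(size=(3, 4)).astype(mat.dtype)
vec_val = np.random.normal(size=(4,)).astype(vec.dtype)
np.testing.assert_allclose(
    out.eval({mat: mat_val, vec: vec_val}),
    np.tensordot(mat_val, vec_val, axes=[[-1], [-1]]),
    rtol=1e-6,
)
```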

Diff for: tests/tensor/test_math.py (+37 -2)
@@ -19,7 +19,7 @@
 from pytensor.compile.sharedvalue import shared
 from pytensor.configdefaults import config
 from pytensor.gradient import NullTypeGradError, grad, numeric_grad
-from pytensor.graph.basic import Variable, ancestors, applys_between
+from pytensor.graph.basic import Variable, ancestors, applys_between, equal_computations
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.replace import vectorize_node
 from pytensor.link.c.basic import DualLinker
@@ -2278,7 +2278,7 @@ def test_type_shape(self):
 
         with pytest.raises(
             ValueError,
-            match="Input arrays have inconsistent broadcastable pattern or type shape",
+            match="Input arrays have inconsistent type shape",
         ):
             tensordot(ones(shape=(7, 4)), ones(shape=(7, 4)), axes=1)
 
@@ -2323,6 +2323,41 @@ def test_shape_assert(self, axes, has_assert, values, expected_fail):
         else:
             assert np.allclose(np.tensordot(xv, yv, axes=axes), z.eval({x: xv, y: yv}))
 
+    def test_eager_simplification(self):
+        # Test that in cases where tensordot isn't needed, it returns a simple graph
+        scl = tensor(shape=())
+        vec = tensor(shape=(None,))
+        mat = tensor(shape=(None, None))
+
+        # scalar product
+        out = tensordot(scl, scl, axes=[[], []])
+        assert equal_computations([out], [scl * scl])
+
+        # vector-vector product
+        out = tensordot(vec, vec, axes=[[-1], [-1]])
+        assert equal_computations([out], [dot(vec, vec)])
+
+        # matrix-vector product
+        out = tensordot(mat, vec, axes=[[-1], [-1]])
+        assert equal_computations([out], [dot(mat, vec)])
+
+        out = tensordot(mat, vec, axes=[[-2], [-1]])
+        assert equal_computations([out], [dot(mat.T, vec)])
+
+        # vector-matrix product
+        out = tensordot(vec, mat, axes=[[-1], [-2]])
+        assert equal_computations([out], [dot(vec, mat)])
+
+        out = tensordot(vec, mat, axes=[[-1], [-1]])
+        assert equal_computations([out], [dot(vec, mat.T)])
+
+        # matrix-matrix product
+        out = tensordot(mat, mat, axes=[[-1], [-2]])
+        assert equal_computations([out], [dot(mat, mat)])
+
+        out = tensordot(mat, mat, axes=[[-1], [-1]])
+        assert equal_computations([out], [dot(mat, mat.T)])
+
 
 def test_smallest():
     x = dvector()
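
A complementary sketch (again an illustration, not part of the commit): contracting two axes still goes through the reshape path, since the summed dimensions have to be merged, and the result should continue to agree with NumPy.

```python
import numpy as np
import pytensor.tensor as pt

x3 = pt.tensor("x3", shape=(None, None, None))
y2 = pt.tensor("y2", shape=(None, None))

# Two summed axes have to be merged into one, so a reshape is still expected here.
out = pt.tensordot(x3, y2, axes=[[1, 2], [0, 1]])

x_val = np.random.normal(size=(2, 3, 4)).astype(x3.dtype)
y_val = np.random.normal(size=(3, 4)).astype(y2.dtype)
np.testing.assert_allclose(
    out.eval({x3: x_val, y2: y_val}),
    np.tensordot(x_val, y_val, axes=[[1, 2], [0, 1]]),
    rtol=1e-6,
)
```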
