Commit 7fcaf62

Only do reshapes in tensordot when needed
1 parent 3cff4f5 commit 7fcaf62

File tree: 2 files changed, +86 -34 lines changed

  pytensor/tensor/math.py
  tests/tensor/test_math.py

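Background for the change: tensordot is lowered to a single matrix dot by transposing the summed axes to the end of `a` and the front of `b`, flattening both operands, and reshaping the result back. Before this commit those reshapes were emitted unconditionally, even when the operands were already in matrix or vector form. A minimal NumPy sketch of that classic lowering, for orientation only (the helper name and shapes are illustrative, not PyTensor code):

import numpy as np

def tensordot_via_single_dot(a, b, axes_a, axes_b):
    # Transpose summed axes to the end of `a` and the front of `b`,
    # flatten each operand to a matrix, do one dot, then restore the
    # non-summed dimensions -- the strategy the old code always used.
    non_summed_a = [k for k in range(a.ndim) if k not in axes_a]
    non_summed_b = [k for k in range(b.ndim) if k not in axes_b]
    summed_size = int(np.prod([a.shape[k] for k in axes_a]))
    at = a.transpose(non_summed_a + axes_a).reshape(-1, summed_size)
    bt = b.transpose(axes_b + non_summed_b).reshape(summed_size, -1)
    out_shape = [a.shape[k] for k in non_summed_a] + [b.shape[k] for k in non_summed_b]
    return (at @ bt).reshape(out_shape)

x = np.random.rand(2, 3, 4)
y = np.random.rand(4, 3, 5)
assert np.allclose(
    tensordot_via_single_dot(x, y, [1, 2], [1, 0]),
    np.tensordot(x, y, axes=[[1, 2], [1, 0]]),
)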

Diff for: pytensor/tensor/math.py

+49 -32
@@ -2152,62 +2152,79 @@ def tensordot(
     a = as_tensor_variable(a)
     b = as_tensor_variable(b)
     runtime_shape_a = a.shape
-    bcast_a = a.broadcastable
     static_shape_a = a.type.shape
-    ndim_a = a.ndim
+    ndim_a = a.type.ndim
     runtime_shape_b = b.shape
-    bcast_b = b.broadcastable
     static_shape_b = b.type.shape
-    ndim_b = b.ndim
+    ndim_b = b.type.ndim
     if na != nb:
         raise ValueError(
             "The number of axes supplied for tensordot must be equal for each tensor. "
             f"Got {na} and {nb} respectively."
         )
     axes_a = list(normalize_axis_tuple(axes_a, ndim_a))
     axes_b = list(normalize_axis_tuple(axes_b, ndim_b))
+
+    # The operation is only valid if the original dimensions match in length
+    # The ravelling of the dimensions to coerce the operation into a single dot
+    # could mask such errors, so we add an Assert if needed.
     must_assert_runtime = False
-    for k in range(na):
-        ax_a = axes_a[k]
-        ax_b = axes_b[k]
-        if (bcast_a[ax_a] != bcast_b[ax_b]) or (
+    for ax_a, ax_b in zip(axes_a, axes_b, strict=True):
+        if (
             static_shape_a[ax_a] is not None
             and static_shape_b[ax_b] is not None
             and static_shape_a[ax_a] != static_shape_b[ax_b]
         ):
             raise ValueError(
-                "Input arrays have inconsistent broadcastable pattern or type shape along the axes "
+                "Input arrays have inconsistent type shape along the axes "
                 "that are to be reduced with tensordot."
             )
         elif static_shape_a[ax_a] is None or static_shape_b[ax_b] is None:
             if must_assert_runtime:
                 a = Assert(
                     "Input array shape along reduced axes of tensordot are not equal"
-                )(a, eq(a.shape[ax_a], b.shape[ax_b]))
+                )(a, eq(runtime_shape_a[ax_a], runtime_shape_b[ax_b]))
             must_assert_runtime = True
 
-    # Move the axes to sum over to the end of "a"
-    # and to the front of "b"
-    notin = [k for k in range(ndim_a) if k not in axes_a]
-    newaxes_a = notin + axes_a
-    N2 = 1
-    for axis in axes_a:
-        N2 *= runtime_shape_a[axis]
-    newshape_a = (-1, N2)
-    olda = [runtime_shape_a[axis] for axis in notin]
-
-    notin = [k for k in range(ndim_b) if k not in axes_b]
-    newaxes_b = axes_b + notin
-    N2 = 1
-    for axis in axes_b:
-        N2 *= runtime_shape_b[axis]
-    newshape_b = (N2, -1)
-    oldb = [runtime_shape_b[axis] for axis in notin]
-
-    at = a.transpose(newaxes_a).reshape(newshape_a)
-    bt = b.transpose(newaxes_b).reshape(newshape_b)
-    res = _dot(at, bt)
-    return res.reshape(olda + oldb)
+    # Convert tensordot into a stacked dot product.
+    # We stack the summed axes and the non-summed axes of each tensor separately,
+    # and place the summed axes at the end of a and the beginning of b
+    non_summed_axes_a = [k for k in range(ndim_a) if k not in axes_a]
+    non_summed_dims_a = [runtime_shape_a[axis] for axis in non_summed_axes_a]
+    transpose_axes_a = non_summed_axes_a + axes_a
+    # We only need a reshape when we need to combine summed or non-summed dims
+    # or introduce a new dimension (expand_dims), when doing a non-scalar outer product (axes = 0)
+    a_needs_reshape = (ndim_a != 0) and (
+        (len(non_summed_axes_a) > 1) or (len(axes_a) != 1)
+    )
+
+    non_summed_axes_b = [k for k in range(ndim_b) if k not in axes_b]
+    non_summed_dims_b = [runtime_shape_b[axis] for axis in non_summed_axes_b]
+    transpose_axes_b = axes_b + non_summed_axes_b
+    b_needs_reshape = (ndim_b != 0) and (
+        (len(non_summed_axes_b) > 1) or (len(axes_b) != 1)
+    )
+
+    # summed_size_a and summed_size_b must be the same,
+    # but to facilitate reasoning about useless reshapes we compute both from their shapes
+    at = a.transpose(transpose_axes_a)
+    if a_needs_reshape:
+        non_summed_size_a = variadic_mul(*non_summed_dims_a)
+        summed_size_a = variadic_mul(*[runtime_shape_a[axis] for axis in axes_a])
+        at = at.reshape((non_summed_size_a, summed_size_a))
+
+    bt = b.transpose(transpose_axes_b)
+    if b_needs_reshape:
+        non_summed_size_b = variadic_mul(*non_summed_dims_b)
+        summed_size_b = variadic_mul(*[runtime_shape_b[axis] for axis in axes_b])
+        bt = bt.reshape((summed_size_b, non_summed_size_b))
+
+    res = dot(at, bt)
+
+    if a_needs_reshape or b_needs_reshape:
+        res = res.reshape(non_summed_dims_a + non_summed_dims_b)
+
+    return res
 
 
 def outer(x, y):
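The key new decision in the hunk above is the *_needs_reshape predicate: an operand (and the final result) is only reshaped when several summed or non-summed dimensions have to be merged, or when a non-scalar outer product (axes=0) needs an extra dimension. A standalone restatement, purely for illustration (the needs_reshape helper is hypothetical, not part of the commit):

def needs_reshape(ndim, n_summed_axes):
    # Mirrors a_needs_reshape / b_needs_reshape from the diff: scalars never
    # need a reshape; otherwise one is needed when more than one non-summed
    # dim must be merged, or when the number of summed axes is not exactly 1.
    n_non_summed = ndim - n_summed_axes
    return (ndim != 0) and ((n_non_summed > 1) or (n_summed_axes != 1))

assert not needs_reshape(ndim=2, n_summed_axes=1)  # matrix with one summed axis: plain dot
assert not needs_reshape(ndim=1, n_summed_axes=1)  # vector contraction: plain dot
assert needs_reshape(ndim=3, n_summed_axes=2)      # two summed axes must be ravelled
assert needs_reshape(ndim=2, n_summed_axes=0)      # non-scalar outer product: expand_dims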

Diff for: tests/tensor/test_math.py

+37 -2
@@ -19,7 +19,7 @@
 from pytensor.compile.sharedvalue import shared
 from pytensor.configdefaults import config
 from pytensor.gradient import NullTypeGradError, grad, numeric_grad
-from pytensor.graph.basic import Variable, ancestors, applys_between
+from pytensor.graph.basic import Variable, ancestors, applys_between, equal_computations
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.replace import vectorize_node
 from pytensor.link.c.basic import DualLinker
@@ -2278,7 +2278,7 @@ def test_type_shape(self):
 
         with pytest.raises(
             ValueError,
-            match="Input arrays have inconsistent broadcastable pattern or type shape",
+            match="Input arrays have inconsistent type shape",
         ):
             tensordot(ones(shape=(7, 4)), ones(shape=(7, 4)), axes=1)
 
@@ -2323,6 +2323,41 @@ def test_shape_assert(self, axes, has_assert, values, expected_fail):
         else:
             assert np.allclose(np.tensordot(xv, yv, axes=axes), z.eval({x: xv, y: yv}))
 
+    def test_eager_simplification(self):
+        # Test that cases where tensordot isn't needed, it returns a simple graph
+        scl = tensor(shape=())
+        vec = tensor(shape=(None,))
+        mat = tensor(shape=(None, None))
+
+        # scalar product
+        out = tensordot(scl, scl, axes=[[], []])
+        assert equal_computations([out], [scl * scl])
+
+        # vector-vector product
+        out = tensordot(vec, vec, axes=[[-1], [-1]])
+        assert equal_computations([out], [dot(vec, vec)])
+
+        # matrix-vector product
+        out = tensordot(mat, vec, axes=[[-1], [-1]])
+        assert equal_computations([out], [dot(mat, vec)])
+
+        out = tensordot(mat, vec, axes=[[-2], [-1]])
+        assert equal_computations([out], [dot(mat.T, vec)])
+
+        # vector-matrix product
+        out = tensordot(vec, mat, axes=[[-1], [-2]])
+        assert equal_computations([out], [dot(vec, mat)])
+
+        out = tensordot(vec, mat, axes=[[-1], [-1]])
+        assert equal_computations([out], [dot(vec, mat.T)])
+
+        # matrix-matrix product
+        out = tensordot(mat, mat, axes=[[-1], [-2]])
+        assert equal_computations([out], [dot(mat, mat)])
+
+        out = tensordot(mat, mat, axes=[[-1], [-1]])
+        assert equal_computations([out], [dot(mat, mat.T)])
+
 
 def test_smallest():
     x = dvector()
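The same simplification can be checked interactively; a quick sketch, assuming a PyTensor install with this commit (the printed graph is indicative only):

import pytensor
import pytensor.tensor as pt

mat = pt.tensor(shape=(None, None))
vec = pt.tensor(shape=(None,))

out = pt.tensordot(mat, vec, axes=[[-1], [-1]])
# Per test_eager_simplification above, this should now be just dot(mat, vec),
# with no Reshape nodes in the graph.
pytensor.dprint(out)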
