@@ -2158,62 +2158,79 @@ def tensordot(
     a = as_tensor_variable(a)
     b = as_tensor_variable(b)
     runtime_shape_a = a.shape
-    bcast_a = a.broadcastable
     static_shape_a = a.type.shape
-    ndim_a = a.ndim
+    ndim_a = a.type.ndim
     runtime_shape_b = b.shape
-    bcast_b = b.broadcastable
     static_shape_b = b.type.shape
-    ndim_b = b.ndim
+    ndim_b = b.type.ndim
     if na != nb:
         raise ValueError(
             "The number of axes supplied for tensordot must be equal for each tensor. "
             f"Got {na} and {nb} respectively."
         )
     axes_a = list(normalize_axis_tuple(axes_a, ndim_a))
     axes_b = list(normalize_axis_tuple(axes_b, ndim_b))
+
+    # The operation is only valid if the original dimensions match in length
+    # The ravelling of the dimensions to coerce the operation into a single dot
+    # could mask such errors, so we add an Assert if needed.
     must_assert_runtime = False
-    for k in range(na):
-        ax_a = axes_a[k]
-        ax_b = axes_b[k]
-        if (bcast_a[ax_a] != bcast_b[ax_b]) or (
+    for ax_a, ax_b in zip(axes_a, axes_b, strict=True):
+        if (
             static_shape_a[ax_a] is not None
             and static_shape_b[ax_b] is not None
             and static_shape_a[ax_a] != static_shape_b[ax_b]
         ):
             raise ValueError(
-                "Input arrays have inconsistent broadcastable pattern or type shape along the axes "
+                "Input arrays have inconsistent type shape along the axes "
                 "that are to be reduced with tensordot."
             )
         elif static_shape_a[ax_a] is None or static_shape_b[ax_b] is None:
             if must_assert_runtime:
                 a = Assert(
                     "Input array shape along reduced axes of tensordot are not equal"
-                )(a, eq(a.shape[ax_a], b.shape[ax_b]))
+                )(a, eq(runtime_shape_a[ax_a], runtime_shape_b[ax_b]))
             must_assert_runtime = True

-    # Move the axes to sum over to the end of "a"
-    # and to the front of "b"
-    notin = [k for k in range(ndim_a) if k not in axes_a]
-    newaxes_a = notin + axes_a
-    N2 = 1
-    for axis in axes_a:
-        N2 *= runtime_shape_a[axis]
-    newshape_a = (-1, N2)
-    olda = [runtime_shape_a[axis] for axis in notin]
-
-    notin = [k for k in range(ndim_b) if k not in axes_b]
-    newaxes_b = axes_b + notin
-    N2 = 1
-    for axis in axes_b:
-        N2 *= runtime_shape_b[axis]
-    newshape_b = (N2, -1)
-    oldb = [runtime_shape_b[axis] for axis in notin]
-
-    at = a.transpose(newaxes_a).reshape(newshape_a)
-    bt = b.transpose(newaxes_b).reshape(newshape_b)
-    res = _dot(at, bt)
-    return res.reshape(olda + oldb)
+    # Convert tensordot into a stacked dot product.
+    # We stack the summed axes and the non-summed axes of each tensor separately,
+    # and place the summed axes at the end of a and the beginning of b
+    non_summed_axes_a = [k for k in range(ndim_a) if k not in axes_a]
+    non_summed_dims_a = [runtime_shape_a[axis] for axis in non_summed_axes_a]
+    transpose_axes_a = non_summed_axes_a + axes_a
+    # We only need a reshape when we need to combine summed or non-summed dims
+    # or introduce a new dimension (expand_dims) when doing a non-scalar outer product (len(axes) = 0)
+    a_needs_reshape = (ndim_a != 0) and (
+        (len(non_summed_axes_a) > 1) or (len(axes_a) != 1)
+    )
+
+    non_summed_axes_b = [k for k in range(ndim_b) if k not in axes_b]
+    non_summed_dims_b = [runtime_shape_b[axis] for axis in non_summed_axes_b]
+    transpose_axes_b = axes_b + non_summed_axes_b
+    b_needs_reshape = (ndim_b != 0) and (
+        (len(non_summed_axes_b) > 1) or (len(axes_b) != 1)
+    )
+
+    # summed_size_a and summed_size_b must be the same,
+    # but to facilitate reasoning about useless reshapes we compute both from their shapes
+    at = a.transpose(transpose_axes_a)
+    if a_needs_reshape:
+        non_summed_size_a = variadic_mul(*non_summed_dims_a)
+        summed_size_a = variadic_mul(*[runtime_shape_a[axis] for axis in axes_a])
+        at = at.reshape((non_summed_size_a, summed_size_a))
+
+    bt = b.transpose(transpose_axes_b)
+    if b_needs_reshape:
+        non_summed_size_b = variadic_mul(*non_summed_dims_b)
+        summed_size_b = variadic_mul(*[runtime_shape_b[axis] for axis in axes_b])
+        bt = bt.reshape((summed_size_b, non_summed_size_b))
+
+    res = dot(at, bt)
+
+    if a_needs_reshape or b_needs_reshape:
+        res = res.reshape(non_summed_dims_a + non_summed_dims_b)
+
+    return res


 def outer(x, y):
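
A quick way to see why both the old and the new formulation are equivalent to a tensordot is the standalone NumPy sketch below. It is illustrative only and not part of the commit: the array shapes and axes are made up, and it uses plain numpy (transpose/reshape/@, np.prod, np.testing) rather than the pytensor graph operations used in the diff.

import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=(2, 3, 4))
b = rng.normal(size=(4, 3, 5))
# Sum a's axes 1 and 2 (sizes 3 and 4) against b's axes 1 and 0 (sizes 3 and 4).
axes_a, axes_b = [1, 2], [1, 0]

non_summed_a = [k for k in range(a.ndim) if k not in axes_a]
non_summed_b = [k for k in range(b.ndim) if k not in axes_b]

# Move the summed axes to the end of `a` and to the front of `b`, then collapse
# each group into one dimension so the contraction becomes a single matrix product.
at = a.transpose(non_summed_a + axes_a).reshape(-1, np.prod([a.shape[k] for k in axes_a]))
bt = b.transpose(axes_b + non_summed_b).reshape(np.prod([b.shape[k] for k in axes_b]), -1)

# Restore the non-summed dimensions of `a` followed by those of `b`.
res = (at @ bt).reshape(
    [a.shape[k] for k in non_summed_a] + [b.shape[k] for k in non_summed_b]
)

np.testing.assert_allclose(res, np.tensordot(a, b, axes=(axes_a, axes_b)))

The committed code follows the same recipe but skips the reshapes when a group already consists of exactly one axis (the common matrix-matrix and matrix-vector cases), which is what the a_needs_reshape / b_needs_reshape flags track, and it builds the products of runtime shapes with variadic_mul rather than a Python loop.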