@@ -2152,62 +2152,79 @@ def tensordot(
     a = as_tensor_variable(a)
     b = as_tensor_variable(b)
     runtime_shape_a = a.shape
-    bcast_a = a.broadcastable
     static_shape_a = a.type.shape
-    ndim_a = a.ndim
+    ndim_a = a.type.ndim
     runtime_shape_b = b.shape
-    bcast_b = b.broadcastable
     static_shape_b = b.type.shape
-    ndim_b = b.ndim
+    ndim_b = b.type.ndim
     if na != nb:
         raise ValueError(
             "The number of axes supplied for tensordot must be equal for each tensor. "
             f"Got {na} and {nb} respectively."
         )
     axes_a = list(normalize_axis_tuple(axes_a, ndim_a))
     axes_b = list(normalize_axis_tuple(axes_b, ndim_b))
+
+    # The operation is only valid if the original dimensions match in length
+    # The ravelling of the dimensions to coerce the operation into a single dot
+    # could mask such errors, so we add an Assert if needed.
     must_assert_runtime = False
-    for k in range(na):
-        ax_a = axes_a[k]
-        ax_b = axes_b[k]
-        if (bcast_a[ax_a] != bcast_b[ax_b]) or (
+    for ax_a, ax_b in zip(axes_a, axes_b, strict=True):
+        if (
             static_shape_a[ax_a] is not None
             and static_shape_b[ax_b] is not None
             and static_shape_a[ax_a] != static_shape_b[ax_b]
         ):
             raise ValueError(
-                "Input arrays have inconsistent broadcastable pattern or type shape along the axes "
+                "Input arrays have inconsistent type shape along the axes "
                 "that are to be reduced with tensordot."
             )
         elif static_shape_a[ax_a] is None or static_shape_b[ax_b] is None:
            if must_assert_runtime:
                 a = Assert(
                     "Input array shape along reduced axes of tensordot are not equal"
-                )(a, eq(a.shape[ax_a], b.shape[ax_b]))
+                )(a, eq(runtime_shape_a[ax_a], runtime_shape_b[ax_b]))
             must_assert_runtime = True

-    # Move the axes to sum over to the end of "a"
-    # and to the front of "b"
-    notin = [k for k in range(ndim_a) if k not in axes_a]
-    newaxes_a = notin + axes_a
-    N2 = 1
-    for axis in axes_a:
-        N2 *= runtime_shape_a[axis]
-    newshape_a = (-1, N2)
-    olda = [runtime_shape_a[axis] for axis in notin]
-
-    notin = [k for k in range(ndim_b) if k not in axes_b]
-    newaxes_b = axes_b + notin
-    N2 = 1
-    for axis in axes_b:
-        N2 *= runtime_shape_b[axis]
-    newshape_b = (N2, -1)
-    oldb = [runtime_shape_b[axis] for axis in notin]
-
-    at = a.transpose(newaxes_a).reshape(newshape_a)
-    bt = b.transpose(newaxes_b).reshape(newshape_b)
-    res = _dot(at, bt)
-    return res.reshape(olda + oldb)
+    # Convert tensordot into a stacked dot product.
+    # We stack the summed axes and the non-summed axes of each tensor separately,
+    # and place the summed axes at the end of a and the beginning of b
+    non_summed_axes_a = [k for k in range(ndim_a) if k not in axes_a]
+    non_summed_dims_a = [runtime_shape_a[axis] for axis in non_summed_axes_a]
+    transpose_axes_a = non_summed_axes_a + axes_a
+    # We only need a reshape when we need to combine summed or non-summed dims
+    # or introduce a new dimension (expand_dims), when doing a non-scalar outer product (axes = 0)
+    a_needs_reshape = (ndim_a != 0) and (
+        (len(non_summed_axes_a) > 1) or (len(axes_a) != 1)
+    )
+
+    non_summed_axes_b = [k for k in range(ndim_b) if k not in axes_b]
+    non_summed_dims_b = [runtime_shape_b[axis] for axis in non_summed_axes_b]
+    transpose_axes_b = axes_b + non_summed_axes_b
+    b_needs_reshape = (ndim_b != 0) and (
+        (len(non_summed_axes_b) > 1) or (len(axes_b) != 1)
+    )
+
+    # summed_size_a and summed_size_b must be the same,
+    # but to facilitate reasoning about useless reshapes we compute both from their shapes
+    at = a.transpose(transpose_axes_a)
+    if a_needs_reshape:
+        non_summed_size_a = variadic_mul(*non_summed_dims_a)
+        summed_size_a = variadic_mul(*[runtime_shape_a[axis] for axis in axes_a])
+        at = at.reshape((non_summed_size_a, summed_size_a))
+
+    bt = b.transpose(transpose_axes_b)
+    if b_needs_reshape:
+        non_summed_size_b = variadic_mul(*non_summed_dims_b)
+        summed_size_b = variadic_mul(*[runtime_shape_b[axis] for axis in axes_b])
+        bt = bt.reshape((summed_size_b, non_summed_size_b))
+
+    res = dot(at, bt)
+
+    if a_needs_reshape or b_needs_reshape:
+        res = res.reshape(non_summed_dims_a + non_summed_dims_b)
+
+    return res


 def outer(x, y):
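Below is a minimal NumPy sketch, separate from the PR itself, of the transpose/reshape/dot reduction that the new code builds symbolically: the summed axes are moved to the end of `a` and the front of `b`, each operand is flattened to a matrix, a single dot is taken, and the non-summed dimensions are restored afterwards. The helper name `tensordot_via_dot` is hypothetical and the example omits the static-shape checks and the reshape-avoidance logic of the actual implementation.

# Hedged sketch of the stacked-dot reduction (NumPy stand-in for the symbolic ops).
import numpy as np

def tensordot_via_dot(a, b, axes_a, axes_b):
    # Axes that are not summed over keep their order in the output.
    non_summed_axes_a = [k for k in range(a.ndim) if k not in axes_a]
    non_summed_axes_b = [k for k in range(b.ndim) if k not in axes_b]
    non_summed_dims_a = [a.shape[k] for k in non_summed_axes_a]
    non_summed_dims_b = [b.shape[k] for k in non_summed_axes_b]
    # Non-summed dims become the rows of `at`, summed dims its columns.
    at = a.transpose(non_summed_axes_a + list(axes_a)).reshape(
        (int(np.prod(non_summed_dims_a)), -1)
    )
    # Summed dims become the rows of `bt`, non-summed dims its columns.
    bt = b.transpose(list(axes_b) + non_summed_axes_b).reshape(
        (-1, int(np.prod(non_summed_dims_b)))
    )
    # One matrix product, then unravel back to the non-summed dimensions.
    return (at @ bt).reshape(non_summed_dims_a + non_summed_dims_b)

rng = np.random.default_rng(0)
a = rng.normal(size=(2, 3, 4))
b = rng.normal(size=(4, 3, 5))
np.testing.assert_allclose(
    tensordot_via_dot(a, b, axes_a=(1, 2), axes_b=(1, 0)),
    np.tensordot(a, b, axes=[(1, 2), (1, 0)]),
)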