Skip to content

Commit 72ca41e

Browse files
authored
Fix empty tensordot. (#256)
1 parent ff0ed8e commit 72ca41e

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

sparse/coo/common.py

+8
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from functools import reduce, wraps
2+
from itertools import chain
23
import operator
34
import warnings
45
from collections.abc import Iterable
@@ -148,6 +149,13 @@ def tensordot(a, b, axes=2):
148149
newshape_b = (N2, -1)
149150
oldb = [bs[axis] for axis in notin]
150151

152+
if any(dim == 0 for dim in chain(newshape_a, newshape_b)):
153+
res = asCOO(np.empty(olda + oldb), check=False)
154+
if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
155+
res = res.todense()
156+
157+
return res
158+
151159
at = a.transpose(newaxes_a).reshape(newshape_a)
152160
bt = b.transpose(newaxes_b).reshape(newshape_b)
153161
res = _dot(at, bt)

tests/test_coo.py

+9
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,15 @@ def test_tensordot(a_shape, b_shape, axes):
324324
# assert isinstance(sparse.tensordot(a, sb, axes), COO)
325325

326326

327+
def test_tensordot_empty():
328+
x1 = np.empty((0, 0, 0))
329+
x2 = np.empty((0, 0, 0))
330+
s1 = sparse.COO.from_numpy(x1)
331+
s2 = sparse.COO.from_numpy(x2)
332+
333+
assert_eq(np.tensordot(x1, x2), sparse.tensordot(s1, s2))
334+
335+
327336
@pytest.mark.parametrize('a_shape, b_shape', [
328337
((3, 1, 6, 5), (2, 1, 4, 5, 6)),
329338
((2, 1, 4, 5, 6), (3, 1, 6, 5)),

0 commit comments

Comments
 (0)