Skip to content

Commit 4c7eb61

Browse files
authored
[BUG FIX] Add shapes support to parse_einsum_input (#200)
* [BUG FIX] Add shapes support to parse_einsum_input * Realized more than just having shape is important for converting to numpy * fix lint * Address comments, add test case * Just smoosh args into a list * Make the test make sense * Raise error when non-shapes look passed in but shapes=True * Patch mypy, use pytest
1 parent acfb0ec commit 4c7eb61

File tree

3 files changed

+59
-11
lines changed

3 files changed

+59
-11
lines changed

opt_einsum/contract.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def contract_path(*operands_: Any, **kwargs: Any) -> Tuple[PathType, PathInfo]:
129129
**Parameters:**
130130
131131
- **subscripts** - *(str)* Specifies the subscripts for summation.
132-
- **\\*operands** - *(list of array_like)* hese are the arrays for the operation.
132+
- **\\*operands** - *(list of array_like)* these are the arrays for the operation.
133133
- **use_blas** - *(bool)* Do you use BLAS for valid operations, may use extra memory for more intermediates.
134134
- **optimize** - *(str, list or bool, optional (default: `auto`))* Choose the type of path.
135135
@@ -251,7 +251,7 @@ def contract_path(*operands_: Any, **kwargs: Any) -> Tuple[PathType, PathInfo]:
251251
use_blas = kwargs.pop("use_blas", True)
252252

253253
# Python side parsing
254-
input_subscripts, output_subscript, operands = parser.parse_einsum_input(operands_)
254+
input_subscripts, output_subscript, operands = parser.parse_einsum_input(operands_, shapes=shapes)
255255

256256
# Build a few useful list and sets
257257
input_list = input_subscripts.split(",")

opt_einsum/parser.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"""
44

55
import itertools
6-
from typing import Any, Dict, Iterator, List, Tuple
6+
from typing import Any, Dict, Iterator, List, Tuple, Union
77

88
import numpy as np
99

@@ -224,7 +224,7 @@ def convert_subscripts(old_sub: List[Any], symbol_map: Dict[Any, Any]) -> str:
224224
return new_sub
225225

226226

227-
def convert_interleaved_input(operands: List[Any]) -> Tuple[str, List[Any]]:
227+
def convert_interleaved_input(operands: Union[List[Any], Tuple[Any]]) -> Tuple[str, List[Any]]:
228228
"""Convert 'interleaved' input to standard einsum input."""
229229
tmp_operands = list(operands)
230230
operand_list = []
@@ -262,10 +262,16 @@ def convert_interleaved_input(operands: List[Any]) -> Tuple[str, List[Any]]:
262262
return subscripts, operands
263263

264264

265-
def parse_einsum_input(operands: Any) -> Tuple[str, str, List[ArrayType]]:
265+
def parse_einsum_input(operands: Any, shapes: bool = False) -> Tuple[str, str, List[ArrayType]]:
266266
"""
267267
A reproduction of einsum c side einsum parsing in python.
268268
269+
**Parameters:**
270+
Intakes the same inputs as `contract_path`, but NOT the keyword args. The only
271+
supported keyword argument is:
272+
- **shapes** - *(bool, optional)* Whether ``parse_einsum_input`` should assume arrays (the default) or
273+
array shapes have been supplied.
274+
269275
Returns
270276
-------
271277
input_strings : str
@@ -293,11 +299,21 @@ def parse_einsum_input(operands: Any) -> Tuple[str, str, List[ArrayType]]:
293299

294300
if isinstance(operands[0], str):
295301
subscripts = operands[0].replace(" ", "")
302+
if shapes:
303+
if any([hasattr(o, "shape") for o in operands[1:]]):
304+
raise ValueError(
305+
"shapes is set to True but given at least one operand looks like an array"
306+
" (at least one operand has a shape attribute). "
307+
)
296308
operands = [possibly_convert_to_numpy(x) for x in operands[1:]]
297-
298309
else:
299310
subscripts, operands = convert_interleaved_input(operands)
300311

312+
if shapes:
313+
operand_shapes = operands
314+
else:
315+
operand_shapes = [o.shape for o in operands]
316+
301317
# Check for proper "->"
302318
if ("-" in subscripts) or (">" in subscripts):
303319
invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
@@ -307,7 +323,7 @@ def parse_einsum_input(operands: Any) -> Tuple[str, str, List[ArrayType]]:
307323
# Parse ellipses
308324
if "." in subscripts:
309325
used = subscripts.replace(".", "").replace(",", "").replace("->", "")
310-
ellipse_inds = "".join(gen_unused_symbols(used, max(len(x.shape) for x in operands)))
326+
ellipse_inds = "".join(gen_unused_symbols(used, max(len(x) for x in operand_shapes)))
311327
longest = 0
312328

313329
# Do we have an output to account for?
@@ -325,10 +341,10 @@ def parse_einsum_input(operands: Any) -> Tuple[str, str, List[ArrayType]]:
325341
raise ValueError("Invalid Ellipses.")
326342

327343
# Take into account numerical values
328-
if operands[num].shape == ():
344+
if operand_shapes[num] == ():
329345
ellipse_count = 0
330346
else:
331-
ellipse_count = max(len(operands[num].shape), 1) - (len(sub) - 3)
347+
ellipse_count = max(len(operand_shapes[num]), 1) - (len(sub) - 3)
332348

333349
if ellipse_count > longest:
334350
longest = ellipse_count
@@ -370,6 +386,9 @@ def parse_einsum_input(operands: Any) -> Tuple[str, str, List[ArrayType]]:
370386

371387
# Make sure number operands is equivalent to the number of terms
372388
if len(input_subscripts.split(",")) != len(operands):
373-
raise ValueError("Number of einsum subscripts must be equal to the " "number of operands.")
389+
raise ValueError(
390+
f"Number of einsum subscripts, {len(input_subscripts.split(','))}, must be equal to the "
391+
f"number of operands, {len(operands)}."
392+
)
374393

375394
return input_subscripts, output_subscript, operands

opt_einsum/tests/test_parser.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22
Directly tests various parser utility functions.
33
"""
44

5-
from opt_einsum.parser import get_symbol
5+
from multiprocessing.sharedctypes import Value
6+
import numpy as np
7+
import pytest
8+
from opt_einsum.parser import get_symbol, parse_einsum_input, possibly_convert_to_numpy
69

710

811
def test_get_symbol():
@@ -12,3 +15,29 @@ def test_get_symbol():
1215
assert get_symbol(55295) == "\ud88b"
1316
assert get_symbol(55296) == "\ue000"
1417
assert get_symbol(57343) == "\ue7ff"
18+
19+
20+
def test_parse_einsum_input():
21+
eq = "ab,bc,cd"
22+
ops = [np.random.rand(2, 3), np.random.rand(3, 4), np.random.rand(4, 5)]
23+
input_subscripts, output_subscript, operands = parse_einsum_input([eq, *ops])
24+
assert input_subscripts == eq
25+
assert output_subscript == "ad"
26+
assert operands == ops
27+
28+
29+
def test_parse_einsum_input_shapes_error():
30+
eq = "ab,bc,cd"
31+
ops = [np.random.rand(2, 3), np.random.rand(3, 4), np.random.rand(4, 5)]
32+
33+
with pytest.raises(ValueError):
34+
_ = parse_einsum_input([eq, *ops], shapes=True)
35+
36+
37+
def test_parse_einsum_input_shapes():
38+
eq = "ab,bc,cd"
39+
shps = [(2, 3), (3, 4), (4, 5)]
40+
input_subscripts, output_subscript, operands = parse_einsum_input([eq, *shps], shapes=True)
41+
assert input_subscripts == eq
42+
assert output_subscript == "ad"
43+
assert np.allclose([possibly_convert_to_numpy(shp) for shp in shps], operands)

0 commit comments

Comments
 (0)