3
3
"""
4
4
5
5
import itertools
6
- from typing import Any , Dict , Iterator , List , Tuple
6
+ from typing import Any , Dict , Iterator , List , Tuple , Union
7
7
8
8
import numpy as np
9
9
@@ -224,7 +224,7 @@ def convert_subscripts(old_sub: List[Any], symbol_map: Dict[Any, Any]) -> str:
224
224
return new_sub
225
225
226
226
227
- def convert_interleaved_input (operands : List [Any ]) -> Tuple [str , List [Any ]]:
227
+ def convert_interleaved_input (operands : Union [ List [Any ], Tuple [ Any ] ]) -> Tuple [str , List [Any ]]:
228
228
"""Convert 'interleaved' input to standard einsum input."""
229
229
tmp_operands = list (operands )
230
230
operand_list = []
@@ -262,10 +262,16 @@ def convert_interleaved_input(operands: List[Any]) -> Tuple[str, List[Any]]:
262
262
return subscripts , operands
263
263
264
264
265
- def parse_einsum_input (operands : Any ) -> Tuple [str , str , List [ArrayType ]]:
265
+ def parse_einsum_input (operands : Any , shapes : bool = False ) -> Tuple [str , str , List [ArrayType ]]:
266
266
"""
267
267
A reproduction of einsum c side einsum parsing in python.
268
268
269
+ **Parameters:**
270
+ Intakes the same inputs as `contract_path`, but NOT the keyword args. The only
271
+ supported keyword argument is:
272
+ - **shapes** - *(bool, optional)* Whether ``parse_einsum_input`` should assume arrays (the default) or
273
+ array shapes have been supplied.
274
+
269
275
Returns
270
276
-------
271
277
input_strings : str
@@ -293,11 +299,21 @@ def parse_einsum_input(operands: Any) -> Tuple[str, str, List[ArrayType]]:
293
299
294
300
if isinstance (operands [0 ], str ):
295
301
subscripts = operands [0 ].replace (" " , "" )
302
+ if shapes :
303
+ if any ([hasattr (o , "shape" ) for o in operands [1 :]]):
304
+ raise ValueError (
305
+ "shapes is set to True but given at least one operand looks like an array"
306
+ " (at least one operand has a shape attribute). "
307
+ )
296
308
operands = [possibly_convert_to_numpy (x ) for x in operands [1 :]]
297
-
298
309
else :
299
310
subscripts , operands = convert_interleaved_input (operands )
300
311
312
+ if shapes :
313
+ operand_shapes = operands
314
+ else :
315
+ operand_shapes = [o .shape for o in operands ]
316
+
301
317
# Check for proper "->"
302
318
if ("-" in subscripts ) or (">" in subscripts ):
303
319
invalid = (subscripts .count ("-" ) > 1 ) or (subscripts .count (">" ) > 1 )
@@ -307,7 +323,7 @@ def parse_einsum_input(operands: Any) -> Tuple[str, str, List[ArrayType]]:
307
323
# Parse ellipses
308
324
if "." in subscripts :
309
325
used = subscripts .replace ("." , "" ).replace ("," , "" ).replace ("->" , "" )
310
- ellipse_inds = "" .join (gen_unused_symbols (used , max (len (x . shape ) for x in operands )))
326
+ ellipse_inds = "" .join (gen_unused_symbols (used , max (len (x ) for x in operand_shapes )))
311
327
longest = 0
312
328
313
329
# Do we have an output to account for?
@@ -325,10 +341,10 @@ def parse_einsum_input(operands: Any) -> Tuple[str, str, List[ArrayType]]:
325
341
raise ValueError ("Invalid Ellipses." )
326
342
327
343
# Take into account numerical values
328
- if operands [num ]. shape == ():
344
+ if operand_shapes [num ] == ():
329
345
ellipse_count = 0
330
346
else :
331
- ellipse_count = max (len (operands [num ]. shape ), 1 ) - (len (sub ) - 3 )
347
+ ellipse_count = max (len (operand_shapes [num ]), 1 ) - (len (sub ) - 3 )
332
348
333
349
if ellipse_count > longest :
334
350
longest = ellipse_count
@@ -370,6 +386,9 @@ def parse_einsum_input(operands: Any) -> Tuple[str, str, List[ArrayType]]:
370
386
371
387
# Make sure number operands is equivalent to the number of terms
372
388
if len (input_subscripts .split ("," )) != len (operands ):
373
- raise ValueError ("Number of einsum subscripts must be equal to the " "number of operands." )
389
+ raise ValueError (
390
+ f"Number of einsum subscripts, { len (input_subscripts .split (',' ))} , must be equal to the "
391
+ f"number of operands, { len (operands )} ."
392
+ )
374
393
375
394
return input_subscripts , output_subscript , operands
0 commit comments