14
14
15
15
"""Utility methods for transforming matrices or vectors."""
16
16
17
- from typing import Tuple , Optional , Sequence , List , Union , TypeVar
17
+ from typing import Tuple , Optional , Sequence , List , Union
18
18
19
19
import numpy as np
20
20
26
26
# of type np.ndarray to ensure the method has the correct type signature in that
27
27
# case. It is checked for using `is`, so it won't have a false positive if the
28
28
# user provides a different np.array([]) value.
29
- RaiseValueErrorIfNotProvided = np .array ([]) # type: np.ndarray
30
-
31
- TDefault = TypeVar ('TDefault' )
29
+ RaiseValueErrorIfNotProvided : np .ndarray = np .array ([])
32
30
33
31
34
32
def reflection_matrix_pow (reflection_matrix : np .ndarray , exponent : float ):
@@ -326,6 +324,10 @@ def partial_trace(tensor: np.ndarray, keep_indices: List[int]) -> np.ndarray:
326
324
return np .einsum (tensor , left_indices + right_indices )
327
325
328
326
327
+ class EntangledStateError (ValueError ):
328
+ """Raised when a product state is expected, but an entangled state is provided."""
329
+
330
+
329
331
def partial_trace_of_state_vector_as_mixture (
330
332
state_vector : np .ndarray , keep_indices : List [int ], * , atol : Union [int , float ] = 1e-8
331
333
) -> Tuple [Tuple [float , np .ndarray ], ...]:
@@ -357,9 +359,13 @@ def partial_trace_of_state_vector_as_mixture(
357
359
"""
358
360
359
361
# Attempt to do efficient state factoring.
360
- state = sub_state_vector (state_vector , keep_indices , default = None , atol = atol )
361
- if state is not None :
362
+ try :
363
+ state = sub_state_vector (
364
+ state_vector , keep_indices , default = RaiseValueErrorIfNotProvided , atol = atol
365
+ )
362
366
return ((1.0 , state ),)
367
+ except EntangledStateError :
368
+ pass
363
369
364
370
# Fall back to a (non-unique) mixture representation.
365
371
keep_dims = 1 << len (keep_indices )
@@ -382,7 +388,7 @@ def sub_state_vector(
382
388
state_vector : np .ndarray ,
383
389
keep_indices : List [int ],
384
390
* ,
385
- default : TDefault = RaiseValueErrorIfNotProvided ,
391
+ default : np . ndarray = RaiseValueErrorIfNotProvided ,
386
392
atol : Union [int , float ] = 1e-8 ,
387
393
) -> np .ndarray :
388
394
r"""Attempts to factor a state vector into two parts and return one of them.
@@ -424,8 +430,10 @@ def sub_state_vector(
424
430
425
431
Raises:
426
432
ValueError: if the `state_vector` is not of the correct shape or the
427
- indices are not a valid subset of the input `state_vector`'s indices, or
428
- the result of factoring is not a pure state.
433
+ indices are not a valid subset of the input `state_vector`'s indices
434
+ EntangledStateError: If the result of factoring is not a pure state and
435
+ `default` is not provided.
436
+
429
437
"""
430
438
431
439
if not np .log2 (state_vector .size ).is_integer ():
@@ -471,7 +479,7 @@ def sub_state_vector(
471
479
if default is not RaiseValueErrorIfNotProvided :
472
480
return default
473
481
474
- raise ValueError (
482
+ raise EntangledStateError (
475
483
"Input state vector could not be factored into pure state over "
476
484
"indices {}" .format (keep_indices )
477
485
)
0 commit comments