@@ -534,8 +534,6 @@ def test_wavefunction_partial_trace_as_mixture_invalid_input():
534
534
535
535
536
536
def mixtures_equal (m1 , m2 , atol = 1e-7 ):
537
- if len (m1 ) != len (m2 ):
538
- return False
539
537
for (p1 , v1 ), (p2 , v2 ) in zip (m1 , m2 ):
540
538
if not (cirq .approx_eq (p1 , p2 , atol = atol ) and
541
539
cirq .equal_up_to_global_phase (v1 , v2 , atol = atol )):
@@ -582,25 +580,25 @@ def test_wavefunction_partial_trace_as_mixture_pure_result():
582
580
((1.0 , b ),))
583
581
assert mixtures_equal (
584
582
cirq .wavefunction_partial_trace_as_mixture (state , [5 , 6 , 7 , 8 ],
585
- atol = 1e-8 ),
586
- ((1.0 , c ),))
583
+ atol = 1e-8 ), ((1.0 , c ),))
587
584
588
585
# Return mixture will defer to numpy.linalg.eigh's builtin tolerance.
589
586
state = np .array ([1 , 0 , 0 , 1 ]) / np .sqrt (2 )
587
+ truth = ((0.5 , np .array ([1 , 0 ])), (0.5 , np .array ([0 , 1 ])))
590
588
assert mixtures_equal (
591
589
cirq .wavefunction_partial_trace_as_mixture (state , [1 ], atol = 1e-20 ),
592
- (( 0.5 , np . array ([ 1 , 0 ])), ( 0.5 , np . array ([ 0 , 1 ]))) , atol = 1e-15 )
590
+ truth , atol = 1e-15 )
593
591
assert not mixtures_equal (
594
592
cirq .wavefunction_partial_trace_as_mixture (state , [1 ], atol = 1e-20 ),
595
- (( 0.5 , np . array ([ 1 , 0 ])), ( 0.5 , np . array ([ 0 , 1 ]))) , atol = 1e-16 )
593
+ truth , atol = 1e-16 )
596
594
597
595
598
596
def test_wavefunction_partial_trace_as_mixture_mixed_result ():
599
597
state = np .array ([1 , 0 , 0 , 1 ]) / np .sqrt (2 )
600
598
truth = ((0.5 , np .array ([1 , 0 ])), (0.5 , np .array ([0 , 1 ])))
601
599
for q1 in [0 , 1 ]:
602
- mixture = cirq .wavefunction_partial_trace_as_mixture (
603
- state . reshape ( 2 , 2 ), [ q1 ], atol = 1e-8 )
600
+ mixture = cirq .wavefunction_partial_trace_as_mixture (state , [ q1 ],
601
+ atol = 1e-8 )
604
602
assert mixtures_equal (mixture , truth )
605
603
606
604
state = np .array ([0 , 1 , 1 , 0 , 1 , 0 , 0 , 0 ]).reshape (2 , 2 , 2 ) / np .sqrt (3 )
0 commit comments