Skip to content

Commit 9eeb83d

Browse files
daxfohlrht
authored andcommitted
Fix state vector factorization validation (quantumlib#5076)
The `validate` block here did not account for the possibility of a phase shift. It was just luck that `H(q0), I(q1)` factorization in the test did not cause one. Also there's no need to transpose the axes here, just compare `t1` and `t2` directly.
1 parent 05c2db2 commit 9eeb83d

File tree

2 files changed

+3
-5
lines changed

2 files changed

+3
-5
lines changed

cirq-core/cirq/linalg/transformations.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -594,9 +594,7 @@ def factor_state_vector(
594594
remainder = remainder / np.sum(abs(remainder) ** 2) ** 0.5
595595
if validate:
596596
t2 = state_vector_kronecker_product(extracted, remainder)
597-
axes2 = list(axes) + [i for i in range(t1.ndim) if i not in axes]
598-
t3 = transpose_state_vector_to_axis_order(t2, axes2)
599-
if not np.allclose(t3, t, atol=atol):
597+
if not predicates.allclose_up_to_global_phase(t2, t1, atol=atol):
600598
raise ValueError('The tensor cannot be factored by the requested axes')
601599
return extracted, remainder
602600

cirq-core/cirq/sim/state_vector_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -379,10 +379,10 @@ def test_step_result_bloch_vector():
379379

380380
def test_factor_validation():
381381
args = cirq.Simulator()._create_act_on_args(0, qubits=cirq.LineQubit.range(2))
382-
args.apply_operation(cirq.H(cirq.LineQubit(0)))
382+
args.apply_operation(cirq.H(cirq.LineQubit(0)) ** 0.7)
383383
t = args.create_merged_state().target_tensor
384384
cirq.linalg.transformations.factor_state_vector(t, [0])
385-
cirq.linalg.transformations.factor_state_vector(t, [1], atol=1e-2)
385+
cirq.linalg.transformations.factor_state_vector(t, [1])
386386
args.apply_operation(cirq.CNOT(cirq.LineQubit(0), cirq.LineQubit(1)))
387387
t = args.create_merged_state().target_tensor
388388
with pytest.raises(ValueError, match='factor'):

0 commit comments

Comments
 (0)