diff --git a/cirq-core/cirq/linalg/transformations.py b/cirq-core/cirq/linalg/transformations.py index 12322fe60cf..44b7977de01 100644 --- a/cirq-core/cirq/linalg/transformations.py +++ b/cirq-core/cirq/linalg/transformations.py @@ -396,7 +396,7 @@ def sub_state_vector( keep_indices: List[int], *, default: np.ndarray = RaiseValueErrorIfNotProvided, - atol: Union[int, float] = 1e-8, + atol: Union[int, float] = 1e-6, ) -> np.ndarray: r"""Attempts to factor a state vector into two parts and return one of them. diff --git a/cirq-core/cirq/linalg/transformations_test.py b/cirq-core/cirq/linalg/transformations_test.py index db13af73c12..d01e34cb898 100644 --- a/cirq-core/cirq/linalg/transformations_test.py +++ b/cirq-core/cirq/linalg/transformations_test.py @@ -609,3 +609,22 @@ def test_to_special(): su = cirq.to_special(u) assert not cirq.is_special_unitary(u) assert cirq.is_special_unitary(su) + + +def test_default_tolerance(): + a, b = cirq.LineQubit.range(2) + final_state_vector = ( + cirq.Simulator() + .simulate( + cirq.Circuit( + cirq.H(a), + cirq.H(b), + cirq.CZ(a, b), + cirq.measure(a), + ) + ) + .final_state_vector.reshape((2, 2)) + ) + # Here, we do NOT specify the default tolerance. It is merely to check that the default value + # is reasonable. + cirq.sub_state_vector(final_state_vector, [0])