diff --git a/cirq/ops/common_gates.py b/cirq/ops/common_gates.py index 811ca46d4ad..f6c18bc1937 100644 --- a/cirq/ops/common_gates.py +++ b/cirq/ops/common_gates.py @@ -103,6 +103,17 @@ def _act_on_(self, args: Any): tableau.xs[:, q] ^= tableau.zs[:, q] return True + if isinstance(args, clifford.ActOnStabilizerCHFormArgs): + if protocols.is_parameterized(self) or self.exponent % 0.5 != 0: + return NotImplemented + assert all( + gate._act_on_(args) for gate in # type: ignore + [H, ZPowGate(exponent=self._exponent), H]) + # Adjust the global phase based on the global_shift parameter. + args.state.omega *= np.exp(1j * np.pi * self.global_shift * + self.exponent) + return True + return NotImplemented def in_su2(self) -> 'XPowGate': @@ -322,6 +333,30 @@ def _act_on_(self, args: Any): tableau.xs[:, q].copy()) return True + if isinstance(args, clifford.ActOnStabilizerCHFormArgs): + if protocols.is_parameterized(self) or self.exponent % 0.5 != 0: + return NotImplemented + effective_exponent = self._exponent % 2 + state = args.state + if effective_exponent == 0.5: + assert all( + gate._act_on_(args) # type: ignore + for gate in [ZPowGate(), H]) + state.omega *= (1 + 1j) / (2**0.5) # type: ignore + elif effective_exponent == 1: + assert all( + gate._act_on_(args) for gate in # type: ignore + [ZPowGate(), H, ZPowGate(), H]) + state.omega *= 1j # type: ignore + elif effective_exponent == 1.5: + assert all( + gate._act_on_(args) # type: ignore + for gate in [H, ZPowGate()]) + state.omega *= (1 - 1j) / (2**0.5) # type: ignore + # Adjust the global phase based on the global_shift parameter. + args.state.omega *= np.exp(1j * np.pi * self.global_shift * + self.exponent) + return True return NotImplemented def in_su2(self) -> 'YPowGate': @@ -490,6 +525,22 @@ def _act_on_(self, args: Any): tableau.zs[:, q] ^= tableau.xs[:, q] return True + if isinstance(args, clifford.ActOnStabilizerCHFormArgs): + if protocols.is_parameterized(self) or self.exponent % 0.5 != 0: + return NotImplemented + q = args.axes[0] + effective_exponent = self._exponent % 2 + state = args.state + for _ in range(int(effective_exponent * 2)): + # Prescription for S left multiplication. + # Reference: https://arxiv.org/abs/1808.00128 Proposition 4 end + state.M[q, :] ^= state.G[q, :] + state.gamma[q] = (state.gamma[q] - 1) % 4 + # Adjust the global phase based on the global_shift parameter. + args.state.omega *= np.exp(1j * np.pi * self.global_shift * + self.exponent) + return True + return NotImplemented def _decompose_into_clifford_with_qubits_(self, qubits): @@ -756,18 +807,43 @@ def _act_on_(self, args: Any): from cirq.sim import clifford if isinstance(args, clifford.ActOnCliffordTableauArgs): - if protocols.is_parameterized(self) or self.exponent % 0.5 != 0: + if protocols.is_parameterized(self) or self.exponent % 1 != 0: return NotImplemented tableau = args.tableau q = args.axes[0] - if self._exponent % 1 != 0: - return NotImplemented if self._exponent % 2 == 1: (tableau.xs[:, q], tableau.zs[:, q]) = (tableau.zs[:, q].copy(), tableau.xs[:, q].copy()) tableau.rs[:] ^= (tableau.xs[:, q] & tableau.zs[:, q]) return True + if isinstance(args, clifford.ActOnStabilizerCHFormArgs): + if protocols.is_parameterized(self) or self.exponent % 1 != 0: + return NotImplemented + q = args.axes[0] + state = args.state + if self._exponent % 2 == 1: + # Prescription for H left multiplication + # Reference: https://arxiv.org/abs/1808.00128 + # Equations 48, 49 and Proposition 4 + t = state.s ^ (state.G[q, :] & state.v) + u = state.s ^ (state.F[q, :] & + (~state.v)) ^ (state.M[q, :] & state.v) + + alpha = sum(state.G[q, :] & (~state.v) & state.s) % 2 + beta = sum(state.M[q, :] & (~state.v) & state.s) + beta += sum(state.F[q, :] & state.v & state.M[q, :]) + beta += sum(state.F[q, :] & state.v & state.s) + beta %= 2 + + delta = (state.gamma[q] + 2 * (alpha + beta)) % 4 + + state.update_sum(t, u, delta=delta, alpha=alpha) + # Adjust the global phase based on the global_shift parameter. + args.state.omega *= np.exp(1j * np.pi * self.global_shift * + self.exponent) + return True + return NotImplemented def _decompose_(self, qubits): @@ -900,6 +976,22 @@ def _act_on_(self, args: Any): tableau.rs[:] ^= (tableau.xs[:, q2] & tableau.zs[:, q2]) return True + if isinstance(args, clifford.ActOnStabilizerCHFormArgs): + if protocols.is_parameterized(self) or self.exponent % 1 != 0: + return NotImplemented + q1 = args.axes[0] + q2 = args.axes[1] + state = args.state + if self._exponent % 2 == 1: + # Prescription for CZ left multiplication. + # Reference: https://arxiv.org/abs/1808.00128 Proposition 4 end + state.M[q1, :] ^= state.G[q2, :] + state.M[q2, :] ^= state.G[q1, :] + # Adjust the global phase based on the global_shift parameter. + args.state.omega *= np.exp(1j * np.pi * self.global_shift * + self.exponent) + return True + return NotImplemented def _pauli_expansion_(self) -> value.LinearDict[str]: @@ -1098,6 +1190,26 @@ def _act_on_(self, args: Any): tableau.zs[:, q1] ^= tableau.zs[:, q2] return True + if isinstance(args, clifford.ActOnStabilizerCHFormArgs): + if protocols.is_parameterized(self) or self.exponent % 1 != 0: + return NotImplemented + q1 = args.axes[0] + q2 = args.axes[1] + state = args.state + if self._exponent % 2 == 1: + # Prescription for CX left multiplication. + # Reference: https://arxiv.org/abs/1808.00128 Proposition 4 end + state.gamma[q1] = ( + state.gamma[q1] + state.gamma[q2] + 2 * + (sum(state.M[q1, :] & state.F[q2, :]) % 2)) % 4 + state.G[q2, :] ^= state.G[q1, :] + state.F[q1, :] ^= state.F[q2, :] + state.M[q1, :] ^= state.M[q2, :] + # Adjust the global phase based on the global_shift parameter. + args.state.omega *= np.exp(1j * np.pi * self.global_shift * + self.exponent) + return True + return NotImplemented def _pauli_expansion_(self) -> value.LinearDict[str]: diff --git a/cirq/ops/common_gates_test.py b/cirq/ops/common_gates_test.py index 11c6cfd5c43..80cd16ff498 100644 --- a/cirq/ops/common_gates_test.py +++ b/cirq/ops/common_gates_test.py @@ -290,7 +290,7 @@ def test_h_str(): assert str(cirq.H**0.5) == 'H^0.5' -def test_x_act_on(): +def test_x_act_on_tableau(): with pytest.raises(TypeError, match="Failed to act"): cirq.act_on(cirq.X, object()) original_tableau = cirq.CliffordTableau(num_qubits=5, initial_state=31) @@ -326,14 +326,21 @@ def test_x_act_on(): cirq.act_on(cirq.X**foo, args) -class PhaserGate(cirq.SingleQubitGate): +class iZGate(cirq.SingleQubitGate): """Equivalent to an iZ gate without _act_on_ defined on it.""" def _unitary_(self): return np.array([[1j, 0], [0, -1j]]) -def test_y_act_on(): +class MinusOnePhaseGate(cirq.SingleQubitGate): + """Equivalent to a -1 global phase without _act_on_ defined on it.""" + + def _unitary_(self): + return np.array([[-1, 0], [0, -1]]) + + +def test_y_act_on_tableau(): with pytest.raises(TypeError, match="Failed to act"): cirq.act_on(cirq.Y, object()) original_tableau = cirq.CliffordTableau(num_qubits=5, initial_state=31) @@ -348,18 +355,18 @@ def test_y_act_on(): cirq.act_on(cirq.Y**0.5, args, allow_decompose=False) cirq.act_on(cirq.Y**0.5, args, allow_decompose=False) - cirq.act_on(PhaserGate(), args) + cirq.act_on(iZGate(), args) assert args.log_of_measurement_results == {} assert args.tableau == flipped_tableau cirq.act_on(cirq.Y, args, allow_decompose=False) - cirq.act_on(PhaserGate(), args, allow_decompose=True) + cirq.act_on(iZGate(), args, allow_decompose=True) assert args.log_of_measurement_results == {} assert args.tableau == original_tableau cirq.act_on(cirq.Y**3.5, args, allow_decompose=False) cirq.act_on(cirq.Y**3.5, args, allow_decompose=False) - cirq.act_on(PhaserGate(), args) + cirq.act_on(iZGate(), args) assert args.log_of_measurement_results == {} assert args.tableau == flipped_tableau @@ -372,9 +379,11 @@ def test_y_act_on(): cirq.act_on(cirq.Y**foo, args) -def test_z_h_act_on(): +def test_z_h_act_on_tableau(): with pytest.raises(TypeError, match="Failed to act"): - cirq.act_on(cirq.Y, object()) + cirq.act_on(cirq.Z, object()) + with pytest.raises(TypeError, match="Failed to act"): + cirq.act_on(cirq.H, object()) original_tableau = cirq.CliffordTableau(num_qubits=5, initial_state=31) flipped_tableau = cirq.CliffordTableau(num_qubits=5, initial_state=23) @@ -417,18 +426,16 @@ def test_z_h_act_on(): with pytest.raises(TypeError, match="Failed to act action on state"): cirq.act_on(cirq.Z**foo, args) - foo = sympy.Symbol('foo') with pytest.raises(TypeError, match="Failed to act action on state"): cirq.act_on(cirq.H**foo, args) - foo = sympy.Symbol('foo') with pytest.raises(TypeError, match="Failed to act action on state"): cirq.act_on(cirq.H**1.5, args) -def test_cx_act_on(): +def test_cx_act_on_tableau(): with pytest.raises(TypeError, match="Failed to act"): - cirq.act_on(cirq.Y, object()) + cirq.act_on(cirq.CX, object()) original_tableau = cirq.CliffordTableau(num_qubits=5, initial_state=31) args = cirq.ActOnCliffordTableauArgs( @@ -471,7 +478,7 @@ def test_cx_act_on(): cirq.act_on(cirq.CX**1.5, args) -def test_cz_act_on(): +def test_cz_act_on_tableau(): with pytest.raises(TypeError, match="Failed to act"): cirq.act_on(cirq.Y, object()) original_tableau = cirq.CliffordTableau(num_qubits=5, initial_state=31) @@ -516,38 +523,102 @@ def test_cz_act_on(): cirq.act_on(cirq.CZ**1.5, args) +foo = sympy.Symbol('foo') + + +@pytest.mark.parametrize('input_gate_sequence, outcome', [ + ([cirq.X**foo], 'Error'), + ([cirq.X**0.25], 'Error'), + ([cirq.X**4], 'Original'), + ([cirq.X**0.5, cirq.X**0.5], 'Flipped'), + ([cirq.X], 'Flipped'), + ([cirq.X**3.5, cirq.X**3.5], 'Flipped'), + ([cirq.Y**foo], 'Error'), + ([cirq.Y**0.25], 'Error'), + ([cirq.Y**4], 'Original'), + ([cirq.Y**0.5, cirq.Y**0.5, iZGate()], 'Flipped'), + ([cirq.Y, iZGate()], 'Flipped'), + ([cirq.Y**3.5, cirq.Y**3.5, iZGate()], 'Flipped'), + ([cirq.Z**foo], 'Error'), + ([cirq.H**foo], 'Error'), + ([cirq.H**1.5], 'Error'), + ([cirq.Z**4], 'Original'), + ([cirq.H**4], 'Original'), + ([cirq.H, cirq.S, cirq.S, cirq.H], 'Flipped'), + ([cirq.H, cirq.Z, cirq.H], 'Flipped'), + ([cirq.H, cirq.Z**3.5, cirq.Z**3.5, cirq.H], 'Flipped'), + ([cirq.CX**foo], 'Error'), + ([cirq.CX**1.5], 'Error'), + ([cirq.CX**4], 'Original'), + ([cirq.CX], 'Flipped'), + ([cirq.CZ**foo], 'Error'), + ([cirq.CZ**1.5], 'Error'), + ([cirq.CZ**4], 'Original'), + ([cirq.CZ, MinusOnePhaseGate()], 'Original'), +]) +def test_act_on_ch_form(input_gate_sequence, outcome): + original_state = cirq.StabilizerStateChForm(num_qubits=5, initial_state=31) + num_qubits = cirq.num_qubits(input_gate_sequence[0]) + if num_qubits == 1: + axes = [1] + else: + assert num_qubits == 2 + axes = [0, 1] + args = cirq.ActOnStabilizerCHFormArgs(state=original_state.copy(), + axes=axes) + + flipped_state = cirq.StabilizerStateChForm(num_qubits=5, initial_state=23) + + if outcome == 'Error': + with pytest.raises(TypeError, match="Failed to act action on state"): + for input_gate in input_gate_sequence: + cirq.act_on(input_gate, args) + return + + for input_gate in input_gate_sequence: + cirq.act_on(input_gate, args) + + if outcome == 'Original': + np.testing.assert_allclose(args.state.state_vector(), + original_state.state_vector()) + + if outcome == 'Flipped': + np.testing.assert_allclose(args.state.state_vector(), + flipped_state.state_vector()) + + @pytest.mark.parametrize( - 'input_gate', + 'input_gate, assert_implemented', [ - cirq.X, - cirq.Y, - cirq.Z, - cirq.X**0.5, - cirq.Y**0.5, - cirq.Z**0.5, - cirq.X**3.5, - cirq.Y**3.5, - cirq.Z**3.5, - cirq.X**4, - cirq.Y**4, - cirq.Z**4, - cirq.H, - cirq.CX, - cirq.CZ, - cirq.H**4, - cirq.CX**4, - cirq.CZ**4, - # Gates not supported by CliffordTableau should not fail too. - cirq.X**0.25, - cirq.Y**0.25, - cirq.Z**0.25, - cirq.H**0.5, - cirq.CX**0.5, - cirq.CZ**0.5 + (cirq.X, True), + (cirq.Y, True), + (cirq.Z, True), + (cirq.X**0.5, True), + (cirq.Y**0.5, True), + (cirq.Z**0.5, True), + (cirq.X**3.5, True), + (cirq.Y**3.5, True), + (cirq.Z**3.5, True), + (cirq.X**4, True), + (cirq.Y**4, True), + (cirq.Z**4, True), + (cirq.H, True), + (cirq.CX, True), + (cirq.CZ, True), + (cirq.H**4, True), + (cirq.CX**4, True), + (cirq.CZ**4, True), + # Unsupported gates should not fail too. + (cirq.X**0.25, False), + (cirq.Y**0.25, False), + (cirq.Z**0.25, False), + (cirq.H**0.5, False), + (cirq.CX**0.5, False), + (cirq.CZ**0.5, False), ]) -def test_act_on_clifford_tableau(input_gate): - cirq.testing.assert_act_on_clifford_tableau_effect_matches_unitary( - input_gate) +def test_act_on_consistency(input_gate, assert_implemented): + cirq.testing.assert_all_implemented_act_on_effects_match_unitary( + input_gate, assert_implemented, assert_implemented) def test_runtime_types_of_rot_gates(): diff --git a/cirq/ops/measurement_gate_test.py b/cirq/ops/measurement_gate_test.py index dc14ade53ca..7bbd1fed3f9 100644 --- a/cirq/ops/measurement_gate_test.py +++ b/cirq/ops/measurement_gate_test.py @@ -278,7 +278,7 @@ def test_act_on_clifford_tableau(): a, b = cirq.LineQubit.range(2) m = cirq.measure(a, b, key='out', invert_mask=(True,)) # The below assertion does not fail since it ignores non-unitary operations - cirq.testing.assert_act_on_clifford_tableau_effect_matches_unitary(m) + cirq.testing.assert_all_implemented_act_on_effects_match_unitary(m) with pytest.raises(TypeError, match="Failed to act"): cirq.act_on(m, object()) diff --git a/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py b/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py index 6fd1007ea06..328bd0f5007 100644 --- a/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py +++ b/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py @@ -12,10 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable +from typing import Any, Iterable, TYPE_CHECKING +import numpy as np + +from cirq.ops import common_gates, pauli_gates +from cirq.ops.clifford_gate import SingleQubitCliffordGate +from cirq.protocols import has_unitary, num_qubits, unitary from cirq.sim.clifford.stabilizer_state_ch_form import StabilizerStateChForm +if TYPE_CHECKING: + import cirq + from typing import Optional + class ActOnStabilizerCHFormArgs: """Wrapper around a stabilizer state in CH form for the act_on protocol. @@ -27,7 +36,6 @@ class ActOnStabilizerCHFormArgs: def __init__(self, state: StabilizerStateChForm, axes: Iterable[int]): """Initializes with the given state and the axes for the operation. - Args: state: The StabilizerStateChForm to act on. Operations are expected to perform inplace edits of this object. @@ -36,3 +44,52 @@ def __init__(self, state: StabilizerStateChForm, axes: Iterable[int]): """ self.state = state self.axes = tuple(axes) + + def _act_on_fallback_(self, action: Any, allow_decompose: bool): + strats = [] + if allow_decompose: + strats.append( + _strat_act_on_stabilizer_ch_form_from_single_qubit_decompose) + for strat in strats: + result = strat(action, self) + if result is True: + return True + assert result is NotImplemented, str(result) + + return NotImplemented + + +def _strat_act_on_stabilizer_ch_form_from_single_qubit_decompose( + val: Any, args: 'cirq.ActOnStabilizerCHFormArgs') -> bool: + if num_qubits(val) == 1: + if not has_unitary(val): + return NotImplemented + u = unitary(val) + clifford_gate = SingleQubitCliffordGate.from_unitary(u) + if clifford_gate is not None: + # Gather the effective unitary applied so as to correct for the + # global phase later. + final_unitary = np.eye(2) + for axis, quarter_turns in clifford_gate.decompose_rotation(): + gate = None # type: Optional[cirq.Gate] + if axis == pauli_gates.X: + gate = common_gates.XPowGate(exponent=quarter_turns / 2) + assert gate._act_on_(args) + elif axis == pauli_gates.Y: + gate = common_gates.YPowGate(exponent=quarter_turns / 2) + assert gate._act_on_(args) + else: + assert axis == pauli_gates.Z + gate = common_gates.ZPowGate(exponent=quarter_turns / 2) + assert gate._act_on_(args) + + final_unitary = np.matmul(unitary(gate), final_unitary) + + # Find the entry with the largest magnitude in the input unitary. + k = max(np.ndindex(*u.shape), key=lambda t: abs(u[t])) + # Correct the global phase that wasn't conserved in the above + # decomposition. + args.state.omega *= u[k] / final_unitary[k] + return True + + return NotImplemented diff --git a/cirq/sim/clifford/act_on_stabilizer_ch_form_args_test.py b/cirq/sim/clifford/act_on_stabilizer_ch_form_args_test.py index ef7ab0bbfb8..708b485ea6d 100644 --- a/cirq/sim/clifford/act_on_stabilizer_ch_form_args_test.py +++ b/cirq/sim/clifford/act_on_stabilizer_ch_form_args_test.py @@ -46,3 +46,45 @@ def _act_on_(self, args): cirq.act_on(CustomGate(), args) np.testing.assert_allclose(state.gamma, [0, 1, 0]) + + +def test_unitary_fallback_y(): + + class UnitaryYGate(cirq.Gate): + + def num_qubits(self) -> int: + return 1 + + def _unitary_(self): + return np.array([[0, -1j], [1j, 0]]) + + original_state = cirq.StabilizerStateChForm(num_qubits=3) + + args = cirq.ActOnStabilizerCHFormArgs(state=original_state.copy(), axes=[1]) + cirq.act_on(UnitaryYGate(), args) + expected_args = cirq.ActOnStabilizerCHFormArgs(state=original_state.copy(), + axes=[1]) + cirq.act_on(cirq.Y, expected_args) + np.testing.assert_allclose(args.state.state_vector(), + expected_args.state.state_vector()) + + +def test_unitary_fallback_h(): + + class UnitaryHGate(cirq.Gate): + + def num_qubits(self) -> int: + return 1 + + def _unitary_(self): + return np.array([[1, 1], [1, -1]]) / (2**0.5) + + original_state = cirq.StabilizerStateChForm(num_qubits=3) + + args = cirq.ActOnStabilizerCHFormArgs(state=original_state.copy(), axes=[1]) + cirq.act_on(UnitaryHGate(), args) + expected_args = cirq.ActOnStabilizerCHFormArgs(state=original_state.copy(), + axes=[1]) + cirq.act_on(cirq.H, expected_args) + np.testing.assert_allclose(args.state.state_vector(), + expected_args.state.state_vector()) diff --git a/cirq/testing/__init__.py b/cirq/testing/__init__.py index 348df2c79b7..e6ad58e6e18 100644 --- a/cirq/testing/__init__.py +++ b/cirq/testing/__init__.py @@ -28,7 +28,7 @@ ) from cirq.testing.consistent_act_on import ( - assert_act_on_clifford_tableau_effect_matches_unitary,) + assert_all_implemented_act_on_effects_match_unitary,) from cirq.testing.consistent_phase_by import ( assert_phase_by_is_consistent_with_unitary,) diff --git a/cirq/testing/consistent_act_on.py b/cirq/testing/consistent_act_on.py index cab9ee988e7..5d609255da1 100644 --- a/cirq/testing/consistent_act_on.py +++ b/cirq/testing/consistent_act_on.py @@ -1,4 +1,4 @@ -# Copyright 2019 The Cirq Developers +# Copyright 2020 The Cirq Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,14 +22,26 @@ from cirq.ops.dense_pauli_string import DensePauliString from cirq import protocols from cirq.sim import act_on_state_vector_args, final_state_vector -from cirq.sim.clifford import act_on_clifford_tableau_args, clifford_tableau +from cirq.sim.clifford import (act_on_clifford_tableau_args, clifford_tableau, + stabilizer_state_ch_form, + act_on_stabilizer_ch_form_args) def state_vector_has_stabilizer(state_vector: np.ndarray, stabilizer: DensePauliString) -> bool: - """Checks that the stabilizer does not modify the value of the - state_vector, including the global phase. Does not mutate the input - state_vector.""" + """Checks that the state_vector is stabilized by the given stabilizer. + + The stabilizer should not modify the value of the state_vector, up to the + global phase. + + Args: + state_vector: An input state vector. Is not mutated by this function. + stabilizer: A potential stabilizer of the above state_vector as a + DensePauliString. + + Returns: + Whether the stabilizer stabilizes the supplied state. + """ args = act_on_state_vector_args.ActOnStateVectorArgs( target_tensor=state_vector.copy(), @@ -41,10 +53,25 @@ def state_vector_has_stabilizer(state_vector: np.ndarray, return np.allclose(args.target_tensor, state_vector) -def assert_act_on_clifford_tableau_effect_matches_unitary(val: Any) -> None: - """Checks that act_on with CliffordTableau generates stabilizers that - stabilize the final state vector. Does not work with Operations or Gates - expecting non-qubit Qids.""" +def assert_all_implemented_act_on_effects_match_unitary( + val: Any, + assert_tableau_implemented: bool = False, + assert_ch_form_implemented: bool = False) -> None: + """Uses val's effect on final_state_vector to check act_on(val)'s behavior. + + Checks that act_on with CliffordTableau or StabilizerStateCHForm behaves + consistently with act_on through final state vector. Does not work with + Operations or Gates expecting non-qubit Qids. If either of the + assert_*_implmented args is true, fails if the corresponding method is not + implemented for the test circuit. + + Args: + val: A gate or operation that may be an input to protocols.act_on. + assert_tableau_implemented: asserts that protocols.act_on() works with + val and ActOnCliffordTableauArgs inputs. + assert_ch_form_implemented: asserts that protocols.act_on() works with + val and ActOnStabilizerStateChFormArgs inputs. + """ # pylint: disable=unused-variable __tracebackhide__ = True @@ -54,6 +81,11 @@ def assert_act_on_clifford_tableau_effect_matches_unitary(val: Any) -> None: if protocols.is_parameterized(val) or not protocols.has_unitary( val) or protocols.qid_shape(val) != (2,) * num_qubits_val: + if assert_tableau_implemented or assert_ch_form_implemented: + assert False, ("Could not assert if any act_on methods were " + "implemented. Operating on qudits or with a " + "non-unitary or parameterized operation is " + "unsupported.\n\nval: {!r}".format(val)) return None qubits = LineQubit.range(protocols.num_qubits(val) * 2) @@ -70,26 +102,48 @@ def assert_act_on_clifford_tableau_effect_matches_unitary(val: Any) -> None: else: circuit.append(val.with_qubits(*qubits[:num_qubits_val])) - tableau = _final_clifford_tableau(circuit, qubit_map) - if tableau is None: - return None - state_vector = np.reshape(final_state_vector(circuit, qubit_order=qubits), protocols.qid_shape(qubits)) - assert all( - state_vector_has_stabilizer(state_vector, stab) - for stab in tableau.stabilizers()), ( - "act_on clifford tableau is not consistent with " - "final_state_vector simulation.\n\nval: {!r}".format(val)) + tableau = _final_clifford_tableau(circuit, qubit_map) + if tableau is None: + assert not assert_tableau_implemented, ("Failed to generate final " + "tableau for the test circuit." + "\n\nval: {!r}".format(val)) + else: + assert all( + state_vector_has_stabilizer(state_vector, stab) + for stab in tableau.stabilizers()), ( + "act_on clifford tableau is not consistent with " + "final_state_vector simulation.\n\nval: {!r}".format(val)) + + stabilizer_ch_form = _final_stabilizer_state_ch_form(circuit, qubit_map) + if stabilizer_ch_form is None: + assert not assert_ch_form_implemented, ("Failed to generate final " + "stabilizer state CH form " + "for the test circuit." + "\n\nval: {!r}".format(val)) + else: + np.testing.assert_allclose(np.reshape(stabilizer_ch_form.state_vector(), + protocols.qid_shape(qubits)), + state_vector, + atol=1e-07) def _final_clifford_tableau(circuit: Circuit, qubit_map ) -> Optional[clifford_tableau.CliffordTableau]: - """Initializes a CliffordTableau with default args for the given qubits and - evolves it by having each operation act on the tableau. Returns None if any - of the operation can not act on a CliffordTableau, returns the tableau - otherwise.""" + """Evolves a default CliffordTableau through the input circuit. + + Initializes a CliffordTableau with default args for the given qubits and + evolves it by having each operation act on the tableau. + + Args: + circuit: An input circuit that acts on the zero state + qubit_map: A map from qid to the qubit index for the above circuit + + Returns: + None if any of the operations can not act on a CliffordTableau, returns + the tableau otherwise.""" tableau = clifford_tableau.CliffordTableau(len(qubit_map)) for op in circuit.all_operations(): @@ -104,3 +158,32 @@ def _final_clifford_tableau(circuit: Circuit, qubit_map except TypeError: return None return tableau + + +def _final_stabilizer_state_ch_form( + circuit: Circuit, + qubit_map) -> Optional[stabilizer_state_ch_form.StabilizerStateChForm]: + """Evolves a default StabilizerStateChForm through the input circuit. + + Initializes a StabilizerStateChForm with default args for the given qubits + and evolves it by having each operation act on the state. + + Args: + circuit: An input circuit that acts on the zero state + qubit_map: A map from qid to the qubit index for the above circuit + + Returns: + None if any of the operations can not act on a StabilizerStateChForm, + returns the StabilizerStateChForm otherwise.""" + + stabilizer_ch_form = stabilizer_state_ch_form.StabilizerStateChForm( + len(qubit_map)) + for op in circuit.all_operations(): + try: + args = act_on_stabilizer_ch_form_args.ActOnStabilizerCHFormArgs( + state=stabilizer_ch_form, + axes=[qubit_map[qid] for qid in op.qubits]) + protocols.act_on(op, args, allow_decompose=True) + except TypeError: + return None + return stabilizer_ch_form diff --git a/cirq/testing/consistent_act_on_test.py b/cirq/testing/consistent_act_on_test.py index 80916dd43b3..a6790a27ea4 100644 --- a/cirq/testing/consistent_act_on_test.py +++ b/cirq/testing/consistent_act_on_test.py @@ -49,13 +49,47 @@ def _act_on_(self, args: Any): return NotImplemented +class UnimplementedGate(cirq.TwoQubitGate): + pass + + +class UnimplementedUnitaryGate(cirq.TwoQubitGate): + + def _unitary_(self): + return np.array([[0, 0, 0, 1], [0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, + 0]]) + + def test_assert_act_on_clifford_tableau_effect_matches_unitary(): - cirq.testing.assert_act_on_clifford_tableau_effect_matches_unitary( - GoodGate()) - cirq.testing.assert_act_on_clifford_tableau_effect_matches_unitary( + cirq.testing.assert_all_implemented_act_on_effects_match_unitary(GoodGate()) + cirq.testing.assert_all_implemented_act_on_effects_match_unitary( GoodGate().on(cirq.LineQubit(1))) with pytest.raises(AssertionError, match='act_on clifford tableau is not consistent with ' 'final_state_vector simulation.'): - cirq.testing.assert_act_on_clifford_tableau_effect_matches_unitary( + cirq.testing.assert_all_implemented_act_on_effects_match_unitary( BadGate()) + + cirq.testing.assert_all_implemented_act_on_effects_match_unitary( + UnimplementedGate()) + with pytest.raises( + AssertionError, + match='Could not assert if any act_on methods were implemented'): + cirq.testing.assert_all_implemented_act_on_effects_match_unitary( + UnimplementedGate(), assert_tableau_implemented=True) + with pytest.raises( + AssertionError, + match='Could not assert if any act_on methods were implemented'): + cirq.testing.assert_all_implemented_act_on_effects_match_unitary( + UnimplementedGate(), assert_ch_form_implemented=True) + + cirq.testing.assert_all_implemented_act_on_effects_match_unitary( + UnimplementedUnitaryGate()) + with pytest.raises(AssertionError, + match='Failed to generate final tableau'): + cirq.testing.assert_all_implemented_act_on_effects_match_unitary( + UnimplementedUnitaryGate(), assert_tableau_implemented=True) + with pytest.raises(AssertionError, + match='Failed to generate final stabilizer state'): + cirq.testing.assert_all_implemented_act_on_effects_match_unitary( + UnimplementedUnitaryGate(), assert_ch_form_implemented=True) diff --git a/cirq/testing/consistent_protocols.py b/cirq/testing/consistent_protocols.py index 8d6f1f75362..b36e7f602fc 100644 --- a/cirq/testing/consistent_protocols.py +++ b/cirq/testing/consistent_protocols.py @@ -20,7 +20,7 @@ from cirq import ops, protocols, value from cirq.testing.consistent_act_on import ( - assert_act_on_clifford_tableau_effect_matches_unitary) + assert_all_implemented_act_on_effects_match_unitary) from cirq.testing.circuit_compare import (assert_has_consistent_apply_unitary, assert_has_consistent_qid_shape) from cirq.testing.consistent_decomposition import ( @@ -138,7 +138,7 @@ def _assert_meets_standards_helper(val: Any, *, ignoring_global_phase: bool, assert_specifies_has_unitary_if_unitary(val) assert_has_consistent_qid_shape(val) assert_has_consistent_apply_unitary(val) - assert_act_on_clifford_tableau_effect_matches_unitary(val) + assert_all_implemented_act_on_effects_match_unitary(val) assert_qasm_is_consistent_with_unitary(val) assert_has_consistent_trace_distance_bound(val) assert_decompose_is_consistent_with_unitary(val, diff --git a/rtd_docs/api.rst b/rtd_docs/api.rst index 904178514ac..7b1f446ced0 100644 --- a/rtd_docs/api.rst +++ b/rtd_docs/api.rst @@ -592,7 +592,7 @@ operation. cirq.LinearDict cirq.PeriodicValue cirq.testing.DEFAULT_GATE_DOMAIN - cirq.testing.assert_act_on_clifford_tableau_effect_matches_unitary + cirq.testing.assert_all_implemented_act_on_effects_match_unitary cirq.testing.assert_allclose_up_to_global_phase cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent cirq.testing.assert_commutes_magic_method_consistent_with_unitaries