diff --git a/cirq/__init__.py b/cirq/__init__.py index e2a0bc31ced..067d7f26cfa 100644 --- a/cirq/__init__.py +++ b/cirq/__init__.py @@ -412,6 +412,7 @@ has_channel, has_mixture, has_mixture_channel, + has_stabilizer_effect, has_unitary, inverse, is_measurement, diff --git a/cirq/ops/common_gates.py b/cirq/ops/common_gates.py index e9bf7a9e5e0..5b2380d1b57 100644 --- a/cirq/ops/common_gates.py +++ b/cirq/ops/common_gates.py @@ -191,6 +191,11 @@ def _phase_by_(self, phase_turns, qubit_index): exponent=self._exponent, phase_exponent=phase_turns * 2) + def _has_stabilizer_effect_(self) -> Optional[bool]: + if self._is_parameterized_(): + return None + return self.exponent % 1 == 0 + def __str__(self) -> str: if self._global_shift == -0.5: if self._exponent == 1: @@ -330,6 +335,11 @@ def _phase_by_(self, phase_turns, qubit_index): exponent=self._exponent, phase_exponent=0.5 + phase_turns * 2) + def _has_stabilizer_effect_(self) -> Optional[bool]: + if self._is_parameterized_(): + return None + return self.exponent % 1 == 0 + def __str__(self) -> str: if self._global_shift == -0.5: if self._exponent == 1: @@ -474,6 +484,11 @@ def _pauli_expansion_(self) -> value.LinearDict[str]: def _phase_by_(self, phase_turns: float, qubit_index: int): return self + def _has_stabilizer_effect_(self) -> Optional[bool]: + if self._is_parameterized_(): + return None + return self.exponent % 0.5 == 0 + def _circuit_diagram_info_(self, args: 'cirq.CircuitDiagramInfoArgs' ) -> Union[str, 'protocols.CircuitDiagramInfo']: if self._global_shift == -0.5: @@ -662,6 +677,11 @@ def _qasm_(self, args: 'cirq.QasmArgs', 'rx({1:half_turns}) {3};\n' 'ry({2:half_turns}) {3};\n', 0.25, self._exponent, -0.25, qubits[0]) + def _has_stabilizer_effect_(self) -> Optional[bool]: + if self._is_parameterized_(): + return None + return self.exponent % 1 == 0 + def __str__(self): if self._exponent == 1: return 'H' @@ -794,6 +814,11 @@ def _qasm_(self, args: 'cirq.QasmArgs', args.validate_version('2.0') return args.format('cz {0},{1};\n', qubits[0], qubits[1]) + def _has_stabilizer_effect_(self) -> Optional[bool]: + if self._is_parameterized_(): + return None + return self.exponent % 1 == 0 + def __str__(self) -> str: if self._exponent == 1: return 'CZ' @@ -954,6 +979,11 @@ def _qasm_(self, args: 'cirq.QasmArgs', args.validate_version('2.0') return args.format('cx {0},{1};\n', qubits[0], qubits[1]) + def _has_stabilizer_effect_(self) -> Optional[bool]: + if self._is_parameterized_(): + return None + return self.exponent % 1 == 0 + def __str__(self) -> str: if self._exponent == 1: return 'CNOT' diff --git a/cirq/ops/common_gates_test.py b/cirq/ops/common_gates_test.py index d15cec294b3..76ade394cf8 100644 --- a/cirq/ops/common_gates_test.py +++ b/cirq/ops/common_gates_test.py @@ -530,6 +530,78 @@ def test_rz_unitary(): np.array([[1j, 0], [0, -1j]])) +def test_x_stabilizer(): + gate = cirq.X + assert cirq.has_stabilizer_effect(gate) + assert not cirq.has_stabilizer_effect(gate**0.5) + assert cirq.has_stabilizer_effect(gate**0) + assert not cirq.has_stabilizer_effect(gate**-0.5) + assert cirq.has_stabilizer_effect(gate**4) + assert not cirq.has_stabilizer_effect(gate**1.2) + foo = sympy.Symbol('foo') + assert not cirq.has_stabilizer_effect(gate**foo) + + +def test_y_stabilizer(): + gate = cirq.Y + assert cirq.has_stabilizer_effect(gate) + assert not cirq.has_stabilizer_effect(gate**0.5) + assert cirq.has_stabilizer_effect(gate**0) + assert not cirq.has_stabilizer_effect(gate**-0.5) + assert cirq.has_stabilizer_effect(gate**4) + assert not cirq.has_stabilizer_effect(gate**1.2) + foo = sympy.Symbol('foo') + assert not cirq.has_stabilizer_effect(gate**foo) + + +def test_z_stabilizer(): + gate = cirq.Z + assert cirq.has_stabilizer_effect(gate) + assert cirq.has_stabilizer_effect(gate**0.5) + assert cirq.has_stabilizer_effect(gate**0) + assert cirq.has_stabilizer_effect(gate**-0.5) + assert cirq.has_stabilizer_effect(gate**4) + assert not cirq.has_stabilizer_effect(gate**1.2) + foo = sympy.Symbol('foo') + assert not cirq.has_stabilizer_effect(gate**foo) + + +def test_h_stabilizer(): + gate = cirq.H + assert cirq.has_stabilizer_effect(gate) + assert not cirq.has_stabilizer_effect(gate**0.5) + assert cirq.has_stabilizer_effect(gate**0) + assert not cirq.has_stabilizer_effect(gate**-0.5) + assert cirq.has_stabilizer_effect(gate**4) + assert not cirq.has_stabilizer_effect(gate**1.2) + foo = sympy.Symbol('foo') + assert not cirq.has_stabilizer_effect(gate**foo) + + +def test_cz_stabilizer(): + gate = cirq.CZ + assert cirq.has_stabilizer_effect(gate) + assert not cirq.has_stabilizer_effect(gate**0.5) + assert cirq.has_stabilizer_effect(gate**0) + assert not cirq.has_stabilizer_effect(gate**-0.5) + assert cirq.has_stabilizer_effect(gate**4) + assert not cirq.has_stabilizer_effect(gate**1.2) + foo = sympy.Symbol('foo') + assert not cirq.has_stabilizer_effect(gate**foo) + + +def test_cnot_stabilizer(): + gate = cirq.CNOT + assert cirq.has_stabilizer_effect(gate) + assert not cirq.has_stabilizer_effect(gate**0.5) + assert cirq.has_stabilizer_effect(gate**0) + assert not cirq.has_stabilizer_effect(gate**-0.5) + assert cirq.has_stabilizer_effect(gate**4) + assert not cirq.has_stabilizer_effect(gate**1.2) + foo = sympy.Symbol('foo') + assert not cirq.has_stabilizer_effect(gate**foo) + + @pytest.mark.parametrize('rads', (-1, -0.3, 0.1, 1)) def test_deprecated_rxyz_rotations(rads): with capture_logging(): diff --git a/cirq/ops/raw_types.py b/cirq/ops/raw_types.py index b791bc43114..6f81a3ee772 100644 --- a/cirq/ops/raw_types.py +++ b/cirq/ops/raw_types.py @@ -294,6 +294,9 @@ def controlled(self, def _backwards_compatibility_num_qubits(self) -> int: return protocols.num_qubits(self) + def _has_stabilizer_effect_(self) -> Optional[bool]: + return NotImplemented + @value.alternative(requires='_num_qubits_', implementation=_backwards_compatibility_num_qubits) def num_qubits(self) -> int: diff --git a/cirq/protocols/__init__.py b/cirq/protocols/__init__.py index ceb41fb32d1..cfd0932f662 100644 --- a/cirq/protocols/__init__.py +++ b/cirq/protocols/__init__.py @@ -60,6 +60,8 @@ equal_up_to_global_phase, SupportsEqualUpToGlobalPhase, ) +from cirq.protocols.has_stabilizer_effect_protocol import ( + has_stabilizer_effect,) from cirq.protocols.has_unitary_protocol import ( has_unitary, SupportsExplicitHasUnitary, diff --git a/cirq/protocols/has_stabilizer_effect_protocol.py b/cirq/protocols/has_stabilizer_effect_protocol.py new file mode 100644 index 00000000000..606f286b167 --- /dev/null +++ b/cirq/protocols/has_stabilizer_effect_protocol.py @@ -0,0 +1,61 @@ +# Copyright 2018 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import ( + Any, + Optional, +) + + +def has_stabilizer_effect(val: Any) -> bool: + """ + Returns whether the input has a stabilizer effect. Currently only limits to + Pauli, H, S, CNOT and CZ gates and their Operations. Does not attempt to + decompose a gate into supported gates. For e.g. iSWAP or X**0.5 gate will + return False. + """ + strats = [ + _strat_has_stabilizer_effect_from_has_stabilizer_effect, + _strat_has_stabilizer_effect_from_gate + ] + for strat in strats: + result = strat(val) + if result is not None: + return result + + # If you can't determine if it has stabilizer effect, it does not. + return False + + +def _strat_has_stabilizer_effect_from_has_stabilizer_effect(val: Any + ) -> Optional[bool]: + """ + Attempts to infer whether val has stabilzer effect via its + _has_stabilizer_effect_ method. + """ + if hasattr(val, '_has_stabilizer_effect_'): + result = val._has_stabilizer_effect_() + if result is not NotImplemented and result is not None: + return result + return None + + +def _strat_has_stabilizer_effect_from_gate(val: Any) -> Optional[bool]: + """ + Attempts to infer whether val has stabilzer effect via the value of + _has_stabilizer_effect_ method of its constituent gate. + """ + if hasattr(val, 'gate'): + return _strat_has_stabilizer_effect_from_has_stabilizer_effect(val.gate) + return None diff --git a/cirq/protocols/has_stabilizer_effect_protocol_test.py b/cirq/protocols/has_stabilizer_effect_protocol_test.py new file mode 100644 index 00000000000..c6e2f0025be --- /dev/null +++ b/cirq/protocols/has_stabilizer_effect_protocol_test.py @@ -0,0 +1,113 @@ +# Copyright 2018 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import cirq + + +class No: + pass + + +class No1: + + def _has_stabilizer_effect_(self): + return NotImplemented + + +class No2: + + def _has_stabilizer_effect_(self): + return None + + +class No3: + + def _has_stabilizer_effect_(self): + return False + + +class Yes: + + def _has_stabilizer_effect_(self): + return True + + +class EmptyOp(cirq.Operation): + """A trivial operation.""" + + @property + def qubits(self): + # coverage: ignore + return () + + def with_qubits(self, *new_qubits): + # coverage: ignore + return self + + +class NoOp(EmptyOp): + + @property + def gate(self): + return No() + + +class NoOp1(EmptyOp): + + @property + def gate(self): + return No1() + + +class NoOp2(EmptyOp): + + @property + def gate(self): + return No2() + + +class NoOp3(EmptyOp): + + @property + def gate(self): + return No3() + + +class YesOp(EmptyOp): + + @property + def gate(self): + return Yes() + + +def test_inconclusive(): + assert not cirq.has_stabilizer_effect(object()) + assert not cirq.has_stabilizer_effect('boo') + assert not cirq.has_stabilizer_effect(cirq.SingleQubitGate()) + assert not cirq.has_stabilizer_effect(No()) + assert not cirq.has_stabilizer_effect(NoOp()) + + +def test_via_has_stabilizer_effect_method(): + assert not cirq.has_stabilizer_effect(No1()) + assert not cirq.has_stabilizer_effect(No2()) + assert not cirq.has_stabilizer_effect(No3()) + assert cirq.has_stabilizer_effect(Yes()) + + +def test_via_gate_of_op(): + assert cirq.has_stabilizer_effect(YesOp()) + assert not cirq.has_stabilizer_effect(NoOp1()) + assert not cirq.has_stabilizer_effect(NoOp2()) + assert not cirq.has_stabilizer_effect(NoOp3()) diff --git a/cirq/sim/clifford/clifford_simulator.py b/cirq/sim/clifford/clifford_simulator.py index b3da45da4d9..7ea192380c4 100644 --- a/cirq/sim/clifford/clifford_simulator.py +++ b/cirq/sim/clifford/clifford_simulator.py @@ -52,10 +52,6 @@ def __init__(self, seed: value.RANDOM_STATE_LIKE = None): self.init = True self._prng = value.parse_random_state(seed) - @staticmethod - def get_supported_gates() -> List['cirq.Gate']: - return [cirq.X, cirq.Y, cirq.Z, cirq.H, cirq.S, cirq.CNOT, cirq.CZ] - def _base_iterator(self, circuit: circuits.Circuit, qubit_order: ops.QubitOrderOrList, initial_state: int ) -> Iterator['cirq.CliffordSimulatorStepResult']: diff --git a/cirq/sim/mux.py b/cirq/sim/mux.py index 0d74988bb9b..610e8adf528 100644 --- a/cirq/sim/mux.py +++ b/cirq/sim/mux.py @@ -39,11 +39,9 @@ def _is_clifford_circuit(program: 'cirq.Circuit') -> bool: - supported_ops = clifford_simulator.CliffordSimulator.get_supported_gates() - # TODO: Have this method check the decomposition of the circuit into - # clifford operations. - return all(op.gate in supported_ops or protocols.is_measurement(op) - for op in program.all_operations()) + return all( + protocols.has_stabilizer_effect(op) or protocols.is_measurement(op) + for op in program.all_operations()) def sample(program: 'cirq.Circuit', diff --git a/docs/api.rst b/docs/api.rst index 246ee59a31a..bb23cf4745b 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -327,6 +327,7 @@ the magic methods that can be implemented. cirq.has_channel cirq.has_mixture cirq.has_mixture_channel + cirq.has_stabilizer_effect cirq.has_unitary cirq.inverse cirq.is_measurement