diff --git a/cirq-core/cirq/ops/controlled_gate.py b/cirq-core/cirq/ops/controlled_gate.py index 776f472b09b..a43c180c976 100644 --- a/cirq-core/cirq/ops/controlled_gate.py +++ b/cirq-core/cirq/ops/controlled_gate.py @@ -12,15 +12,35 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import AbstractSet, Any, cast, Collection, Dict, Optional, Sequence, Tuple, Union +from typing import ( + AbstractSet, + Any, + cast, + Collection, + Dict, + List, + Optional, + Sequence, + Tuple, + Union, + TYPE_CHECKING, +) import numpy as np -import cirq -from cirq import protocols, value -from cirq.ops import raw_types, controlled_operation as cop +from cirq import protocols, value, _import +from cirq.ops import raw_types, controlled_operation as cop, matrix_gates from cirq.type_workarounds import NotImplementedType +if TYPE_CHECKING: + import cirq + +controlled_gate_decomposition = _import.LazyLoader( + 'controlled_gate_decomposition', globals(), 'cirq.transformers.analytical_decompositions' +) +common_gates = _import.LazyLoader('common_gates', globals(), 'cirq.ops') +line_qubit = _import.LazyLoader('line_qubit', globals(), 'cirq.devices') + @value.value_equality class ControlledGate(raw_types.Gate): @@ -100,17 +120,50 @@ def num_controls(self) -> int: return len(self.control_qid_shape) def _qid_shape_(self) -> Tuple[int, ...]: - return self.control_qid_shape + cirq.qid_shape(self.sub_gate) + return self.control_qid_shape + protocols.qid_shape(self.sub_gate) def _decompose_(self, qubits): + if ( + protocols.has_unitary(self.sub_gate) + and protocols.num_qubits(self.sub_gate) == 1 + and self._qid_shape_() == (2,) * len(self._qid_shape_()) + ): + control_qubits = list(qubits[: self.num_controls()]) + invert_ops: List['cirq.Operation'] = [] + for cvals, cqbit in zip(self.control_values, qubits[: self.num_controls()]): + if set(cvals) == {0}: + invert_ops.append(common_gates.X(cqbit)) + elif set(cvals) == {0, 1}: + control_qubits.remove(cqbit) + decomposed_ops = controlled_gate_decomposition.decompose_multi_controlled_rotation( + protocols.unitary(self.sub_gate), control_qubits, qubits[-1] + ) + return invert_ops + decomposed_ops + invert_ops + + if isinstance(self.sub_gate, common_gates.CZPowGate): + z_sub_gate = common_gates.ZPowGate( + exponent=self.sub_gate.exponent, global_shift=self.sub_gate.global_shift + ) + controlled_z = ControlledGate( + sub_gate=z_sub_gate, + num_controls=self.num_controls() + 1, + control_values=self.control_values + (1,), + control_qid_shape=self.control_qid_shape + (2,), + ) + return protocols.decompose_once_with_qubits(controlled_z, qubits, NotImplemented) + + if isinstance(self.sub_gate, matrix_gates.MatrixGate): + # Default decompositions of 2/3 qubit `cirq.MatrixGate` ignores global phase, which is + # local phase in the controlled variant and hence cannot be ignored. + return NotImplemented + result = protocols.decompose_once_with_qubits( self.sub_gate, qubits[self.num_controls() :], NotImplemented ) - if result is NotImplemented: return NotImplemented - decomposed = [] + decomposed: List['cirq.Operation'] = [] for op in result: decomposed.append( cop.ControlledOperation(qubits[: self.num_controls()], op, self.control_values) @@ -135,7 +188,7 @@ def _value_equality_values_(self): ) def _apply_unitary_(self, args: 'protocols.ApplyUnitaryArgs') -> np.ndarray: - qubits = cirq.LineQid.for_gate(self) + qubits = line_qubit.LineQid.for_gate(self) op = self.sub_gate.on(*qubits[self.num_controls() :]) c_op = cop.ControlledOperation(qubits[: self.num_controls()], op, self.control_values) return protocols.apply_unitary(c_op, args, default=NotImplemented) @@ -144,7 +197,7 @@ def _has_unitary_(self) -> bool: return protocols.has_unitary(self.sub_gate) def _unitary_(self) -> Union[np.ndarray, NotImplementedType]: - qubits = cirq.LineQid.for_gate(self) + qubits = line_qubit.LineQid.for_gate(self) op = self.sub_gate.on(*qubits[self.num_controls() :]) c_op = cop.ControlledOperation(qubits[: self.num_controls()], op, self.control_values) @@ -154,7 +207,7 @@ def _has_mixture_(self) -> bool: return protocols.has_mixture(self.sub_gate) def _mixture_(self) -> Union[np.ndarray, NotImplementedType]: - qubits = cirq.LineQid.for_gate(self) + qubits = line_qubit.LineQid.for_gate(self) op = self.sub_gate.on(*qubits[self.num_controls() :]) c_op = cop.ControlledOperation(qubits[: self.num_controls()], op, self.control_values) return protocols.mixture(c_op, default=NotImplemented) diff --git a/cirq-core/cirq/ops/controlled_gate_test.py b/cirq-core/cirq/ops/controlled_gate_test.py index 596ebb8147d..adb8a1f4882 100644 --- a/cirq-core/cirq/ops/controlled_gate_test.py +++ b/cirq-core/cirq/ops/controlled_gate_test.py @@ -46,12 +46,12 @@ def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs) -> Union[np.ndarray, NotI zero = seed * a + (0, Ellipsis) one = seed * a + (1, Ellipsis) result = np.zeros(args.target_tensor.shape, args.target_tensor.dtype) - result[zero] = args.target_tensor[zero] * 2 + args.target_tensor[one] * 3 - result[one] = args.target_tensor[zero] * 5 + args.target_tensor[one] * 7 + result[zero] = (args.target_tensor[zero] + args.target_tensor[one]) * np.sqrt(0.5) + result[one] = (args.target_tensor[zero] - args.target_tensor[one]) * np.sqrt(0.5) return result def _unitary_(self): - return np.array([[2, 3], [5, 7]]) + return np.array([[1, 1], [1, -1]]) * np.sqrt(0.5) def __eq__(self, other): return isinstance(other, type(self)) @@ -331,8 +331,14 @@ def test_unitary(): GateUsingWorkspaceForApplyUnitary(), GateAllocatingNewSpaceForResult(), cirq.IdentityGate(qid_shape=(3, 4)), + cirq.ControlledGate( + cirq.XXPowGate(exponent=0.25, global_shift=-0.5), + num_controls=2, + control_values=(1, (1, 0)), + ), # Single qudit gate with dimension 4. - cirq.MatrixGate(np.kron(*(cirq.unitary(cirq.H),) * 2)), + cirq.MatrixGate(np.kron(*(cirq.unitary(cirq.H),) * 2), qid_shape=(4,)), + cirq.MatrixGate(cirq.testing.random_unitary(4, random_state=1234)), ], ) def test_controlled_gate_is_consistent(gate: cirq.Gate): diff --git a/cirq-core/cirq/ops/controlled_operation.py b/cirq-core/cirq/ops/controlled_operation.py index 151eff60874..d0035940341 100644 --- a/cirq-core/cirq/ops/controlled_operation.py +++ b/cirq-core/cirq/ops/controlled_operation.py @@ -93,11 +93,17 @@ def with_qubits(self, *new_qubits): ) def _decompose_(self): + result = protocols.decompose_once_with_qubits(self.gate, self.qubits, NotImplemented) + if result is not NotImplemented: + return result + result = protocols.decompose_once(self.sub_operation, NotImplemented) if result is NotImplemented: return NotImplemented - return [ControlledOperation(self.controls, op, self.control_values) for op in result] + return [ + op.controlled_by(*self.controls, control_values=self.control_values) for op in result + ] def _value_equality_values_(self): return (frozenset(zip(self.controls, self.control_values)), self.sub_operation) diff --git a/cirq-core/cirq/ops/controlled_operation_test.py b/cirq-core/cirq/ops/controlled_operation_test.py index 8ce6be78dc5..b6ed22d0a8a 100644 --- a/cirq-core/cirq/ops/controlled_operation_test.py +++ b/cirq-core/cirq/ops/controlled_operation_test.py @@ -47,12 +47,12 @@ def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs) -> Union[np.ndarray, NotI zero = seed * a + (0, Ellipsis) one = seed * a + (1, Ellipsis) result = np.zeros(args.target_tensor.shape, args.target_tensor.dtype) - result[zero] = args.target_tensor[zero] * 2 + args.target_tensor[one] * 3 - result[one] = args.target_tensor[zero] * 5 + args.target_tensor[one] * 7 + result[zero] = (args.target_tensor[zero] + args.target_tensor[one]) * np.sqrt(0.5) + result[one] = (args.target_tensor[zero] - args.target_tensor[one]) * np.sqrt(0.5) return result def _unitary_(self): - return np.array([[2, 3], [5, 7]]) + return np.array([[1, 1], [1, -1]]) * np.sqrt(0.5) def __eq__(self, other): return isinstance(other, type(self)) @@ -323,7 +323,26 @@ def test_controlled_operation_is_consistent(gate: cirq.GateOperation): cb3 = cb.with_dimension(3) cgate = cirq.ControlledOperation([cb3], gate, control_values=[(0, 2)]) - cirq.testing.assert_implements_consistent_protocols(cgate) + cirq.testing.assert_implements_consistent_protocols( + cgate, ignore_decompose_to_default_gateset=True + ) + + +def test_controlled_circuit_operation_has_consistent_decomposition(): + op = cirq.CircuitOperation( + cirq.FrozenCircuit( + cirq.XXPowGate(exponent=0.25, global_shift=-0.5).on(*cirq.LineQubit.range(2)) + ) + ) + cb = cirq.NamedQubit('ctr') + cop = cirq.ControlledOperation([cb], op) + cirq.testing.assert_implements_consistent_protocols(cop, exponents=(-1, 1, 2)) + + cop = cirq.ControlledOperation([cb], op, control_values=[0]) + cirq.testing.assert_implements_consistent_protocols(cop, exponents=(-1, 1, 2)) + + cop = cirq.ControlledOperation([cb], op, control_values=[(0, 1)]) + cirq.testing.assert_implements_consistent_protocols(cop, exponents=(-1, 1, 2)) @pytest.mark.parametrize('resolve_fn', [cirq.resolve_parameters, cirq.resolve_parameters_once]) diff --git a/cirq-core/cirq/ops/matrix_gates.py b/cirq-core/cirq/ops/matrix_gates.py index 4513d644566..ff97607049a 100644 --- a/cirq-core/cirq/ops/matrix_gates.py +++ b/cirq-core/cirq/ops/matrix_gates.py @@ -18,13 +18,23 @@ import numpy as np -from cirq import linalg, protocols +from cirq import linalg, protocols, _import from cirq._compat import proper_repr from cirq.ops import raw_types if TYPE_CHECKING: import cirq +single_qubit_decompositions = _import.LazyLoader( + 'single_qubit_decompositions', globals(), 'cirq.transformers.analytical_decompositions' +) +two_qubit_to_cz = _import.LazyLoader( + 'two_qubit_to_cz', globals(), 'cirq.transformers.analytical_decompositions' +) +three_qubit_decomposition = _import.LazyLoader( + 'three_qubit_decomposition', globals(), 'cirq.transformers.analytical_decompositions' +) + class MatrixGate(raw_types.Gate): """A unitary qubit or qudit gate defined entirely by its matrix.""" @@ -116,6 +126,20 @@ def _phase_by_(self, phase_turns: float, qubit_index: int) -> 'MatrixGate': result[linalg.slice_for_qubits_equal_to([j], 1)] *= np.conj(p) return MatrixGate(matrix=result.reshape(self._matrix.shape), qid_shape=self._qid_shape) + def _decompose_(self, qubits: Tuple['cirq.Qid', ...]): + if self._qid_shape == (2,): + return [ + g.on(qubits[0]) + for g in single_qubit_decompositions.single_qubit_matrix_to_gates(self._matrix) + ] + if self._qid_shape == (2,) * 2: + return two_qubit_to_cz.two_qubit_matrix_to_operations( + *qubits, self._matrix, allow_partial_czs=True + ) + if self._qid_shape == (2,) * 3: + return three_qubit_decomposition.three_qubit_matrix_to_operations(*qubits, self._matrix) + return NotImplemented + def _has_unitary_(self) -> bool: return True diff --git a/cirq-core/cirq/ops/matrix_gates_test.py b/cirq-core/cirq/ops/matrix_gates_test.py index 16d21a45d18..53748830c46 100644 --- a/cirq-core/cirq/ops/matrix_gates_test.py +++ b/cirq-core/cirq/ops/matrix_gates_test.py @@ -279,13 +279,19 @@ def test_str_executes(): def test_one_qubit_consistent(): u = cirq.testing.random_unitary(2) g = cirq.MatrixGate(u) - cirq.testing.assert_implements_consistent_protocols(g) + cirq.testing.assert_implements_consistent_protocols(g, ignoring_global_phase=True) def test_two_qubit_consistent(): u = cirq.testing.random_unitary(4) g = cirq.MatrixGate(u) - cirq.testing.assert_implements_consistent_protocols(g) + cirq.testing.assert_implements_consistent_protocols(g, ignoring_global_phase=True) + + +def test_three_qubit_consistent(): + u = cirq.testing.random_unitary(8) + g = cirq.MatrixGate(u) + cirq.testing.assert_implements_consistent_protocols(g, ignoring_global_phase=True) def test_repr(): diff --git a/cirq-core/cirq/protocols/decompose_protocol.py b/cirq-core/cirq/protocols/decompose_protocol.py index dda931d1772..47aaa839414 100644 --- a/cirq-core/cirq/protocols/decompose_protocol.py +++ b/cirq-core/cirq/protocols/decompose_protocol.py @@ -47,6 +47,15 @@ DecomposeResult = Union[None, NotImplementedType, 'cirq.OP_TREE'] OpDecomposer = Callable[['cirq.Operation'], DecomposeResult] +DECOMPOSE_TARGET_GATESET = ops.Gateset( + ops.XPowGate, + ops.YPowGate, + ops.ZPowGate, + ops.CZPowGate, + ops.MeasurementGate, + ops.GlobalPhaseGate, +) + def _value_error_describing_bad_operation(op: 'cirq.Operation') -> ValueError: return ValueError(f"Operation doesn't satisfy the given `keep` but can't be decomposed: {op!r}") diff --git a/cirq-core/cirq/testing/__init__.py b/cirq-core/cirq/testing/__init__.py index f08cffc00d5..7124f91173e 100644 --- a/cirq-core/cirq/testing/__init__.py +++ b/cirq-core/cirq/testing/__init__.py @@ -33,6 +33,7 @@ ) from cirq.testing.consistent_decomposition import ( + assert_decompose_ends_at_default_gateset, assert_decompose_is_consistent_with_unitary, ) diff --git a/cirq-core/cirq/testing/consistent_decomposition.py b/cirq-core/cirq/testing/consistent_decomposition.py index 1050a48ffdd..81cca2842c5 100644 --- a/cirq-core/cirq/testing/consistent_decomposition.py +++ b/cirq-core/cirq/testing/consistent_decomposition.py @@ -47,3 +47,20 @@ def assert_decompose_is_consistent_with_unitary(val: Any, ignoring_global_phase: else: # coverage: ignore np.testing.assert_allclose(actual, expected, atol=1e-8) + + +def assert_decompose_ends_at_default_gateset(val: Any): + """Ensures that all cirq gate decompositions end at the default cirq gateset.""" + + # pylint: disable=unused-variable + __tracebackhide__ = True + # pylint: enable=unused-variable + if protocols.is_parameterized(val): + return + args = () if isinstance(val, ops.Operation) else (tuple(devices.LineQid.for_gate(val)),) + dec_once = protocols.decompose_once(val, None, *args) + if dec_once is None: + # _decompose_ is NotImplemented, so silently return. + return + dec = [*ops.flatten_to_ops(protocols.decompose(d) for d in dec_once)] + assert all(op in protocols.decompose_protocol.DECOMPOSE_TARGET_GATESET for op in dec) diff --git a/cirq-core/cirq/testing/consistent_decomposition_test.py b/cirq-core/cirq/testing/consistent_decomposition_test.py index e573084254a..781d6fb8cfd 100644 --- a/cirq-core/cirq/testing/consistent_decomposition_test.py +++ b/cirq-core/cirq/testing/consistent_decomposition_test.py @@ -49,3 +49,63 @@ def test_assert_decompose_is_consistent_with_unitary(): cirq.testing.assert_decompose_is_consistent_with_unitary( BadGateDecompose().on(cirq.NamedQubit('q')) ) + + +class GateDecomposesToDefaultGateset(cirq.Gate): + def _num_qubits_(self): + return 2 + + def _decompose_(self, qubits): + return [GoodGateDecompose().on(qubits[0]), BadGateDecompose().on(qubits[1])] + + +class GateDecomposeDoesNotEndInDefaultGateset(cirq.Gate): + def _num_qubits_(self): + return 4 + + def _decompose_(self, qubits): + return cirq.MatrixGate(cirq.testing.random_unitary(16)).on(*qubits) + + +class GateDecomposeNotImplemented(cirq.SingleQubitGate): + def _decompose_(self, qubits): + return NotImplemented + + +class ParameterizedGate(cirq.SingleQubitGate): + def _is_parameterized_(self): + return True + + def _num_qubits_(self): + return 4 + + def _decompose_(self, qubits): + assert False, "Decompose should not be called for parameterized gates." + + +def test_assert_decompose_ends_at_default_gateset(): + + cirq.testing.assert_decompose_ends_at_default_gateset(GateDecomposesToDefaultGateset()) + cirq.testing.assert_decompose_ends_at_default_gateset( + GateDecomposesToDefaultGateset().on(*cirq.LineQubit.range(2)) + ) + + cirq.testing.assert_decompose_ends_at_default_gateset(GateDecomposeNotImplemented()) + cirq.testing.assert_decompose_ends_at_default_gateset( + GateDecomposeNotImplemented().on(cirq.NamedQubit('q')) + ) + + cirq.testing.assert_decompose_ends_at_default_gateset(ParameterizedGate()) + cirq.testing.assert_decompose_ends_at_default_gateset( + ParameterizedGate().on(*cirq.LineQubit.range(4)) + ) + + with pytest.raises(AssertionError): + cirq.testing.assert_decompose_ends_at_default_gateset( + GateDecomposeDoesNotEndInDefaultGateset() + ) + + with pytest.raises(AssertionError): + cirq.testing.assert_decompose_ends_at_default_gateset( + GateDecomposeDoesNotEndInDefaultGateset().on(*cirq.LineQubit.range(4)) + ) diff --git a/cirq-core/cirq/testing/consistent_protocols.py b/cirq-core/cirq/testing/consistent_protocols.py index 51d44a77416..71e2047509b 100644 --- a/cirq-core/cirq/testing/consistent_protocols.py +++ b/cirq-core/cirq/testing/consistent_protocols.py @@ -26,6 +26,7 @@ ) from cirq.testing.consistent_decomposition import ( assert_decompose_is_consistent_with_unitary, + assert_decompose_ends_at_default_gateset, ) from cirq.testing.consistent_phase_by import ( assert_phase_by_is_consistent_with_unitary, @@ -55,6 +56,7 @@ def assert_implements_consistent_protocols( setup_code: str = 'import cirq\nimport numpy as np\nimport sympy', global_vals: Optional[Dict[str, Any]] = None, local_vals: Optional[Dict[str, Any]] = None, + ignore_decompose_to_default_gateset: bool = False, ) -> None: """Checks that a value is internally consistent and has a good __repr__.""" global_vals = global_vals or {} @@ -66,6 +68,7 @@ def assert_implements_consistent_protocols( setup_code=setup_code, global_vals=global_vals, local_vals=local_vals, + ignore_decompose_to_default_gateset=ignore_decompose_to_default_gateset, ) for exponent in exponents: @@ -77,6 +80,7 @@ def assert_implements_consistent_protocols( setup_code=setup_code, global_vals=global_vals, local_vals=local_vals, + ignore_decompose_to_default_gateset=ignore_decompose_to_default_gateset, ) @@ -90,11 +94,12 @@ def assert_eigengate_implements_consistent_protocols( setup_code: str = 'import cirq\nimport numpy as np\nimport sympy', global_vals: Optional[Dict[str, Any]] = None, local_vals: Optional[Dict[str, Any]] = None, + ignore_decompose_to_default_gateset: bool = False, ) -> None: """Checks that an EigenGate subclass is internally consistent and has a good __repr__.""" # pylint: disable=unused-variable - __tracebackhide__ = True + # __tracebackhide__ = True # pylint: enable=unused-variable for exponent in exponents: @@ -105,6 +110,7 @@ def assert_eigengate_implements_consistent_protocols( setup_code=setup_code, global_vals=global_vals, local_vals=local_vals, + ignore_decompose_to_default_gateset=ignore_decompose_to_default_gateset, ) @@ -143,8 +149,9 @@ def _assert_meets_standards_helper( setup_code: str, global_vals: Optional[Dict[str, Any]], local_vals: Optional[Dict[str, Any]], + ignore_decompose_to_default_gateset: bool, ) -> None: - __tracebackhide__ = True # pylint: disable=unused-variable + # __tracebackhide__ = True # pylint: disable=unused-variable assert_consistent_resolve_parameters(val) assert_specifies_has_unitary_if_unitary(val) @@ -154,6 +161,8 @@ def _assert_meets_standards_helper( assert_qasm_is_consistent_with_unitary(val) assert_has_consistent_trace_distance_bound(val) assert_decompose_is_consistent_with_unitary(val, ignoring_global_phase=ignoring_global_phase) + if not ignore_decompose_to_default_gateset: + assert_decompose_ends_at_default_gateset(val) assert_phase_by_is_consistent_with_unitary(val) assert_pauli_expansion_is_consistent_with_unitary(val) assert_equivalent_repr(