Skip to content

Commit fc35a4f

Browse files
authored
Add assert_decompose_ends_at_default_gateset consistency test (quantumlib#5079)
* Add assert_decompose_ends_at_default_gateset consistency test * Refactor assert_decompose_ends_at_default_gateset to provide hook to ignore gates without known decompositions
1 parent 4a648cc commit fc35a4f

File tree

4 files changed

+88
-0
lines changed

4 files changed

+88
-0
lines changed

cirq/protocols/decompose_protocol.py

+9
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,15 @@
4747
DecomposeResult = Union[None, NotImplementedType, 'cirq.OP_TREE']
4848
OpDecomposer = Callable[['cirq.Operation'], DecomposeResult]
4949

50+
DECOMPOSE_TARGET_GATESET = ops.Gateset(
51+
ops.XPowGate,
52+
ops.YPowGate,
53+
ops.ZPowGate,
54+
ops.CZPowGate,
55+
ops.MeasurementGate,
56+
ops.GlobalPhaseGate,
57+
)
58+
5059

5160
def _value_error_describing_bad_operation(op: 'cirq.Operation') -> ValueError:
5261
return ValueError(f"Operation doesn't satisfy the given `keep` but can't be decomposed: {op!r}")

cirq/testing/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
)
3434

3535
from cirq.testing.consistent_decomposition import (
36+
assert_decompose_ends_at_default_gateset,
3637
assert_decompose_is_consistent_with_unitary,
3738
)
3839

cirq/testing/consistent_decomposition.py

+17
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,20 @@ def assert_decompose_is_consistent_with_unitary(val: Any, ignoring_global_phase:
4747
else:
4848
# coverage: ignore
4949
np.testing.assert_allclose(actual, expected, atol=1e-8)
50+
51+
52+
def _known_gate_with_no_decomposition(val: Any):
53+
"""Checks whether `val` is a known gate with no default decomposition to default gateset."""
54+
return False
55+
56+
57+
def assert_decompose_ends_at_default_gateset(val: Any):
58+
"""Asserts that cirq.decompose(val) ends at default cirq gateset or a known gate."""
59+
if _known_gate_with_no_decomposition(val):
60+
return # coverage: ignore
61+
args = () if isinstance(val, ops.Operation) else (tuple(devices.LineQid.for_gate(val)),)
62+
dec_once = protocols.decompose_once(val, [val(*args[0]) if args else val], *args)
63+
for op in [*ops.flatten_to_ops(protocols.decompose(d) for d in dec_once)]:
64+
assert _known_gate_with_no_decomposition(op.gate) or (
65+
op in protocols.decompose_protocol.DECOMPOSE_TARGET_GATESET
66+
), f'{val} decomposed to {op}, which is not part of default cirq target gateset.'

cirq/testing/consistent_decomposition_test.py

+61
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import pytest
1616

1717
import numpy as np
18+
import sympy
1819

1920
import cirq
2021

@@ -49,3 +50,63 @@ def test_assert_decompose_is_consistent_with_unitary():
4950
cirq.testing.assert_decompose_is_consistent_with_unitary(
5051
BadGateDecompose().on(cirq.NamedQubit('q'))
5152
)
53+
54+
55+
class GateDecomposesToDefaultGateset(cirq.Gate):
56+
def _num_qubits_(self):
57+
return 2
58+
59+
def _decompose_(self, qubits):
60+
return [GoodGateDecompose().on(qubits[0]), BadGateDecompose().on(qubits[1])]
61+
62+
63+
class GateDecomposeDoesNotEndInDefaultGateset(cirq.Gate):
64+
def _num_qubits_(self):
65+
return 4
66+
67+
def _decompose_(self, qubits):
68+
yield GateDecomposeNotImplemented().on_each(*qubits)
69+
70+
71+
class GateDecomposeNotImplemented(cirq.SingleQubitGate):
72+
def _decompose_(self, qubits):
73+
return NotImplemented
74+
75+
76+
class ParameterizedGate(cirq.SingleQubitGate):
77+
def _num_qubits_(self):
78+
return 2
79+
80+
def _decompose_(self, qubits):
81+
yield cirq.X(qubits[0]) ** sympy.Symbol("x")
82+
yield cirq.Y(qubits[1]) ** sympy.Symbol("y")
83+
84+
85+
def test_assert_decompose_ends_at_default_gateset():
86+
87+
cirq.testing.assert_decompose_ends_at_default_gateset(GateDecomposesToDefaultGateset())
88+
cirq.testing.assert_decompose_ends_at_default_gateset(
89+
GateDecomposesToDefaultGateset().on(*cirq.LineQubit.range(2))
90+
)
91+
92+
cirq.testing.assert_decompose_ends_at_default_gateset(ParameterizedGate())
93+
cirq.testing.assert_decompose_ends_at_default_gateset(
94+
ParameterizedGate().on(*cirq.LineQubit.range(2))
95+
)
96+
97+
with pytest.raises(AssertionError):
98+
cirq.testing.assert_decompose_ends_at_default_gateset(GateDecomposeNotImplemented())
99+
100+
with pytest.raises(AssertionError):
101+
cirq.testing.assert_decompose_ends_at_default_gateset(
102+
GateDecomposeNotImplemented().on(cirq.NamedQubit('q'))
103+
)
104+
with pytest.raises(AssertionError):
105+
cirq.testing.assert_decompose_ends_at_default_gateset(
106+
GateDecomposeDoesNotEndInDefaultGateset()
107+
)
108+
109+
with pytest.raises(AssertionError):
110+
cirq.testing.assert_decompose_ends_at_default_gateset(
111+
GateDecomposeDoesNotEndInDefaultGateset().on(*cirq.LineQubit.range(4))
112+
)

0 commit comments

Comments
 (0)