Skip to content

Commit 9f3034c

Browse files
authored
Replace isinstance(op, GateOperation) checks in cirq_google optimizers to support other operation types. (#4459)
Proliferation of `isinstance(op, GateOperation)` checks results in many inconsistencies due to different available operation types like `ControlledOperations` and `TaggedOperations`. This PR fixes #4152 and is a first step towards fixing #3556 Note that `TaggedOperations` which were earlier ignored by the optimizers would now be considered, and hence this is potentially a breaking change if people were implicitly relying on TaggedOperations not getting compiled by the optimizers. Since the optimizer doesn't document / test this behavior, I consider it to be a bug rather than a feature and an explicit `NoCompile` tag should be implemented as part of #4253 This PR is blocked on submitting #4167 (tests will stop failing once the PR is submitted and this rebased). Update: This is now ready for review.
1 parent bd2e63c commit 9f3034c

6 files changed

+81
-20
lines changed

cirq-google/cirq_google/devices/xmon_device.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def validate_gate(self, gate: cirq.Gate):
9898
raise ValueError(f'Unsupported gate type: {gate!r}')
9999

100100
def validate_operation(self, operation: cirq.Operation):
101-
if not isinstance(operation, cirq.GateOperation):
101+
if operation.gate is None:
102102
raise ValueError(f'Unsupported operation: {operation!r}')
103103

104104
self.validate_gate(operation.gate)

cirq-google/cirq_google/devices/xmon_device_test.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -153,20 +153,39 @@ def test_validate_operation_existing_qubits():
153153
d.validate_operation(cirq.CZ(cirq.GridQubit(1, 0), cirq.GridQubit(1, 1)))
154154

155155

156-
def test_validate_operation_supported_gate():
156+
class MyGate(cirq.Gate):
157+
def num_qubits(self):
158+
return 1
159+
160+
161+
q = cirq.GridQubit.rect(1, 3)
162+
matrix_gate = cirq.MatrixGate(cirq.testing.random_unitary(2))
163+
164+
165+
@pytest.mark.parametrize(
166+
'op,is_valid',
167+
[
168+
(cirq.Z(cirq.GridQubit(0, 0)), True),
169+
(cirq.Z(cirq.GridQubit(0, 0)).with_tags('test_tag'), True),
170+
(
171+
cirq.Z(cirq.GridQubit(0, 0)).with_tags('test_tag').controlled_by(cirq.GridQubit(0, 1)),
172+
True,
173+
),
174+
(
175+
cirq.Z(cirq.GridQubit(0, 0)).controlled_by(cirq.GridQubit(0, 1)).with_tags('test_tag'),
176+
True,
177+
),
178+
(NotImplementedOperation(), False),
179+
(MyGate()(cirq.GridQubit(0, 0)), False),
180+
],
181+
)
182+
def test_validate_operation_supported_gate(op, is_valid):
157183
d = square_device(3, 3)
158-
159-
class MyGate(cirq.Gate):
160-
def num_qubits(self):
161-
return 1
162-
163-
d.validate_operation(cirq.GateOperation(cirq.Z, [cirq.GridQubit(0, 0)]))
164-
165-
assert MyGate().num_qubits() == 1
166-
with pytest.raises(ValueError):
167-
d.validate_operation(cirq.GateOperation(MyGate(), [cirq.GridQubit(0, 0)]))
168-
with pytest.raises(ValueError):
169-
d.validate_operation(NotImplementedOperation())
184+
if is_valid:
185+
d.validate_operation(op)
186+
else:
187+
with pytest.raises(ValueError):
188+
d.validate_operation(op)
170189

171190

172191
def test_validate_circuit_repeat_measurement_keys():

cirq-google/cirq_google/optimizers/convert_to_sycamore_gates.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def _convert_one(self, op: cirq.Operation) -> cirq.OP_TREE:
113113
"""
114114
if len(op.qubits) == 1:
115115
return _phased_x_z_ops(cirq.unitary(op, None), op.qubits[0])
116-
elif len(op.qubits) == 2 and isinstance(op, cirq.GateOperation):
116+
elif len(op.qubits) == 2:
117117
return known_two_q_operations_to_sycamore_operations(
118118
op.qubits[0], op.qubits[1], op, self.tabulation
119119
)
@@ -139,7 +139,7 @@ def on_stuck_raise(bad):
139139
def optimization_at(
140140
self, circuit: cirq.Circuit, index: int, op: cirq.Operation
141141
) -> Optional[cirq.PointOptimizationSummary]:
142-
if not isinstance(op, cirq.GateOperation):
142+
if op.gate is None:
143143
return None
144144

145145
gate = op.gate
@@ -151,7 +151,7 @@ def optimization_at(
151151
next_index = circuit.next_moment_operating_on(op.qubits, index + 1)
152152
if next_index is not None:
153153
ops_in_front = list({circuit.operation_at(q, next_index) for q in op.qubits})
154-
if len(ops_in_front) == 1 and isinstance(ops_in_front[0], cirq.GateOperation):
154+
if len(ops_in_front) == 1 and ops_in_front[0] is not None:
155155
gate2 = ops_in_front[0].gate
156156
else:
157157
next_index = 0

cirq-google/cirq_google/optimizers/convert_to_sycamore_gates_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,3 +278,23 @@ def test_sycamore_invalid_tabulation():
278278
sycamore_tabulation = {}
279279
with pytest.raises(ValueError):
280280
cgoc.ConvertToSycamoreGates(sycamore_tabulation)
281+
282+
283+
q = cirq.GridQubit.rect(1, 3)
284+
matrix_gate = cirq.MatrixGate(cirq.testing.random_unitary(2))
285+
286+
287+
@pytest.mark.parametrize(
288+
'op, is_valid',
289+
[
290+
(cirq.CircuitOperation(cirq.FrozenCircuit(matrix_gate(q[0]))), False),
291+
(matrix_gate(q[0]), True),
292+
(matrix_gate(q[0]).with_tags('test_tags'), True),
293+
(matrix_gate(q[0]).controlled_by(q[1]), True),
294+
(matrix_gate(q[0]).controlled_by(q[1]).with_tags('test_tags'), True),
295+
(matrix_gate(q[0]).with_tags('test_tags').controlled_by(q[1]), True),
296+
],
297+
)
298+
def test_supported_operation(op, is_valid):
299+
c = cirq.Circuit(op)
300+
assert (cirq_google.ConvertToSycamoreGates().optimization_at(c, 0, op) is not None) == is_valid

cirq-google/cirq_google/optimizers/convert_to_xmon_gates.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def _is_native_xmon_op(self, op: cirq.Operation) -> bool:
6565
"""
6666
from cirq_google.devices import XmonDevice
6767

68-
return isinstance(op, cirq.GateOperation) and XmonDevice.is_supported_gate(op.gate)
68+
return op.gate is not None and XmonDevice.is_supported_gate(op.gate)
6969

7070
def convert(self, op: cirq.Operation) -> List[cirq.Operation]:
7171
def on_stuck_raise(bad):
@@ -86,6 +86,9 @@ def on_stuck_raise(bad):
8686
def optimization_at(
8787
self, circuit: cirq.Circuit, index: int, op: cirq.Operation
8888
) -> Optional[cirq.PointOptimizationSummary]:
89+
if op.gate is None:
90+
return None
91+
8992
converted = self.convert(op)
9093
if len(converted) == 1 and converted[0] is op:
9194
return None

cirq-google/cirq_google/optimizers/convert_to_xmon_gates_test.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,27 @@ def test_avoids_infinite_cycle_when_matrix_available():
4545
cirq.protocols.decompose(c)
4646

4747

48+
q = cirq.GridQubit.rect(1, 3)
49+
matrix_gate = cirq.MatrixGate(cirq.testing.random_unitary(2))
50+
51+
4852
def test_bad_operation():
49-
qubits = cirq.GridQubit.rect(1, 3)
50-
c = cirq.Circuit(NonNativeGate().on(qubits[0]))
53+
c = cirq.Circuit(NonNativeGate().on(q[0]))
5154
with pytest.raises(TypeError):
5255
cirq_google.ConvertToXmonGates().optimize_circuit(c)
56+
57+
58+
@pytest.mark.parametrize(
59+
'op, is_valid',
60+
[
61+
(cirq.CircuitOperation(cirq.FrozenCircuit(matrix_gate(q[0]))), False),
62+
(matrix_gate(q[0]), True),
63+
(matrix_gate(q[0]).with_tags('test_tags'), True),
64+
(matrix_gate(q[0]).controlled_by(q[1]), True),
65+
(matrix_gate(q[0]).controlled_by(q[1]).with_tags('test_tags'), True),
66+
(matrix_gate(q[0]).with_tags('test_tags').controlled_by(q[1]), True),
67+
],
68+
)
69+
def test_supported_operation(op, is_valid):
70+
c = cirq.Circuit(op)
71+
assert (cirq_google.ConvertToXmonGates().optimization_at(c, 0, op) is not None) == is_valid

0 commit comments

Comments
 (0)