Skip to content

Commit 45624ff

Browse files
authored
Add support for deep=True flag in cg.optimized_for_sycamore and cg.SycamoreTargetGateset transformers (#5126)
- Adds support for deep=True flag in `sycamore_gateset.merge_swap_rzz_and_2q_unitaries` transformer - Updates `cg.optimized_for_sycamore` to call `cirq.optimize_for_target_gateset` with `deep=True` by default, such that the method preserves circuit structure by default (which corresponds to its old behavior). - Fixes #5039
1 parent cdd3f8c commit 45624ff

File tree

4 files changed

+103
-5
lines changed

4 files changed

+103
-5
lines changed

cirq-google/cirq_google/optimizers/optimize_for_sycamore.py

+1
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def optimized_for_sycamore(
105105
copy = cirq.optimize_for_target_gateset(
106106
circuit,
107107
gateset=_TARGET_GATESETS[optimizer_type](tolerance, tabulation),
108+
context=cirq.TransformerContext(deep=True),
108109
)
109110
copy = cirq.merge_single_qubit_gates_to_phxz(copy, atol=tolerance)
110111
copy = cirq.eject_phased_paulis(copy, atol=tolerance)

cirq-google/cirq_google/optimizers/optimize_for_sycamore_test.py

+27
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,30 @@ def test_assert_new_device_deprecated():
134134
_ = cg.optimized_for_sycamore(
135135
circuit0, optimizer_type='sqrt_iswap', new_device=TestDevice()
136136
)
137+
138+
139+
@pytest.mark.parametrize(
140+
'optimizer_type, two_qubit_gate_type',
141+
[('sycamore', cg.SycamoreGate), ('sqrt_iswap', cirq.ISwapPowGate), ('xmon', cirq.CZPowGate)],
142+
)
143+
def test_circuit_operation_conversion(optimizer_type, two_qubit_gate_type):
144+
q0, q1 = cirq.LineQubit.range(2)
145+
subcircuit = cirq.FrozenCircuit(cirq.X(q0), cirq.SWAP(q0, q1))
146+
circuit = cirq.Circuit(cirq.CircuitOperation(subcircuit))
147+
converted_circuit = cg.optimized_for_sycamore(circuit, optimizer_type=optimizer_type)
148+
# Verify that the CircuitOperation was preserved.
149+
ops = list(converted_circuit.all_operations())
150+
assert isinstance(ops[0], cirq.CircuitOperation)
151+
# Verify that the contents of the CircuitOperation were optimized.
152+
converted_subcircuit = cg.optimized_for_sycamore(
153+
subcircuit.unfreeze(), optimizer_type=optimizer_type
154+
)
155+
assert len(
156+
[*converted_subcircuit.findall_operations_with_gate_type(two_qubit_gate_type)]
157+
) == len([*ops[0].circuit.findall_operations_with_gate_type(two_qubit_gate_type)])
158+
cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
159+
ops[0].circuit, converted_subcircuit, atol=1e-8
160+
)
161+
cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
162+
circuit, converted_circuit, atol=1e-8
163+
)

cirq-google/cirq_google/transformers/target_gatesets/sycamore_gateset.py

+26-5
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"""Target gateset used for compiling circuits to Sycamore + 1-q rotations + measurement gates."""
1616

1717
import itertools
18-
from typing import List, Optional, Sequence
18+
from typing import cast, List, Optional, Sequence
1919

2020
import cirq
2121
from cirq.protocols.decompose_protocol import DecomposeResult
@@ -33,6 +33,7 @@ def merge_swap_rzz_and_2q_unitaries(
3333
context: Optional['cirq.TransformerContext'] = None,
3434
merged_swap_rzz_tag: str = "_merged_swap_rzz",
3535
merged_2q_component_tag: str = "_merged_2q_unitaries",
36+
intermediate_result_tag: Optional[str] = None,
3637
) -> 'cirq.Circuit':
3738
"""Merges 2-qubit connected components and adjacent `cirq.SWAP` and `cirq.ZZPowGate` gates.
3839
@@ -50,6 +51,8 @@ def merge_swap_rzz_and_2q_unitaries(
5051
`cirq.SWAP` and `cirq.ZZPowGate`s.
5152
merged_2q_component_tag: Tag to apply on newly introduced circuit operations wrapping
5253
connected components of 1 and 2 qubit unitaries.
54+
intermediate_result_tag: If specified, the tag is added to newly introduced both the newly
55+
introduced circuit operations encapsulating swap_rzz or 2q connected component.
5356
5457
Returns:
5558
Copy of the transformed input circuit.
@@ -71,19 +74,34 @@ def merge_func_swap_rzz(
7174
return False
7275

7376
tags_to_ignore = context.tags_to_ignore if context else ()
77+
deep = context.deep if context else False
7478
circuit = cirq.merge_operations_to_circuit_op(
7579
circuit,
7680
merge_func_swap_rzz,
7781
tags_to_ignore=tags_to_ignore,
7882
merged_circuit_op_tag=merged_swap_rzz_tag,
83+
deep=deep,
7984
)
8085

81-
return cirq.merge_k_qubit_unitaries_to_circuit_op(
86+
circuit = cirq.merge_k_qubit_unitaries_to_circuit_op(
8287
circuit,
8388
k=2,
84-
tags_to_ignore=tags_to_ignore + (merged_swap_rzz_tag,),
89+
tags_to_ignore=tuple(tags_to_ignore) + (merged_swap_rzz_tag,),
8590
merged_circuit_op_tag=merged_2q_component_tag,
86-
).unfreeze(copy=False)
91+
deep=deep,
92+
)
93+
94+
if intermediate_result_tag is not None:
95+
merged_cop_tags = {merged_swap_rzz_tag, merged_2q_component_tag}
96+
circuit = cirq.map_operations(
97+
circuit,
98+
map_func=lambda op, _: op
99+
if merged_cop_tags.isdisjoint(op.tags)
100+
else op.with_tags(cast(str, intermediate_result_tag)),
101+
tags_to_ignore=tags_to_ignore,
102+
deep=True,
103+
)
104+
return circuit.unfreeze(copy=False)
87105

88106

89107
class SycamoreTargetGateset(cirq.TwoQubitCompilationTargetGateset):
@@ -122,7 +140,10 @@ def preprocess_transformers(self) -> List[cirq.TRANSFORMER]:
122140
cirq.expand_composite,
123141
no_decomp=lambda op: cirq.num_qubits(op) <= self.num_qubits,
124142
),
125-
merge_swap_rzz_and_2q_unitaries,
143+
_create_transformer_with_kwargs(
144+
merge_swap_rzz_and_2q_unitaries,
145+
intermediate_result_tag=self._intermediate_result_tag,
146+
),
126147
]
127148

128149
def _decompose_two_qubit_operation(self, op: cirq.Operation, _) -> DecomposeResult:

cirq-google/cirq_google/transformers/target_gatesets/sycamore_gateset_test.py

+49
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,55 @@ def test_merge_swap_rzz_and_2q_unitaries_raises_if_tags_sames():
9797
)
9898

9999

100+
def test_merge_swap_rzz_and_2q_unitaries_deep():
101+
q = cirq.LineQubit.range(3)
102+
swap_rzz = cirq.FrozenCircuit(cirq.SWAP(*q[:2]), cirq.ZZ(*q[:2]) ** 0.5)
103+
rzz_swap = cirq.FrozenCircuit(cirq.ZZ(*q[1:]) ** 0.25, cirq.SWAP(*q[1:]))
104+
x_cnot_x = cirq.FrozenCircuit(cirq.X(q[0]), cirq.CNOT(*q[:2]), cirq.X(q[0]))
105+
x_cz_x = cirq.FrozenCircuit(cirq.X(q[2]), cirq.CZ(*q[1:]), cirq.X(q[2]))
106+
c_orig = cirq.Circuit(
107+
cirq.CircuitOperation(swap_rzz).repeat(3).with_tags("ignore"),
108+
cirq.CircuitOperation(rzz_swap).repeat(5).with_tags("preserve_tag"),
109+
cirq.CircuitOperation(x_cnot_x).repeat(7).with_tags("ignore"),
110+
cirq.CircuitOperation(x_cz_x).repeat(9).with_tags("preserve_tag"),
111+
cirq.CircuitOperation(
112+
cirq.FrozenCircuit(
113+
[swap_rzz, rzz_swap, x_cnot_x, x_cz_x],
114+
cirq.Moment(cirq.Y(qq).with_tags("ignore") for qq in q),
115+
)
116+
),
117+
)
118+
t_swap_rzz = "_merged_swap_rzz_tag"
119+
t_2q_cmp = "_merged_2q_unitaries_component"
120+
t_all = "_intermediate_result_tag_apply_to_all"
121+
122+
def _wrap_cop(c: cirq.FrozenCircuit, *tags) -> cirq.FrozenCircuit:
123+
return cirq.FrozenCircuit(cirq.CircuitOperation(c).with_tags(*tags, t_all))
124+
125+
c_expected = cirq.Circuit(
126+
cirq.CircuitOperation(swap_rzz).repeat(3).with_tags("ignore"),
127+
cirq.CircuitOperation(_wrap_cop(rzz_swap, t_swap_rzz)).repeat(5).with_tags("preserve_tag"),
128+
cirq.CircuitOperation(x_cnot_x).repeat(7).with_tags("ignore"),
129+
cirq.CircuitOperation(_wrap_cop(x_cz_x, t_2q_cmp)).repeat(9).with_tags("preserve_tag"),
130+
cirq.CircuitOperation(
131+
cirq.FrozenCircuit(
132+
[_wrap_cop(swap_rzz, t_swap_rzz), _wrap_cop(rzz_swap, t_swap_rzz)],
133+
[_wrap_cop(x_cnot_x, t_2q_cmp), _wrap_cop(x_cz_x, t_2q_cmp)],
134+
cirq.Moment(cirq.Y(qq).with_tags("ignore") for qq in q),
135+
)
136+
),
137+
)
138+
context = cirq.TransformerContext(tags_to_ignore=["ignore"], deep=True)
139+
c_new = sycamore_gateset.merge_swap_rzz_and_2q_unitaries(
140+
c_orig,
141+
context=context,
142+
merged_swap_rzz_tag=t_swap_rzz,
143+
merged_2q_component_tag=t_2q_cmp,
144+
intermediate_result_tag=t_all,
145+
)
146+
cirq.testing.assert_same_circuits(cirq.drop_empty_moments(c_new, context=context), c_expected)
147+
148+
100149
def test_sycamore_gateset_compiles_swap_zz():
101150
qubits = cirq.LineQubit.range(3)
102151

0 commit comments

Comments
 (0)