Skip to content

Commit 9e94b9e

Browse files
tanujkhattarMichaelBroughton
authored andcommitted
Add merge_operations and merge_moments transformer primitives (quantumlib#4708)
* merge_operations and merge_moments transformer primitives * Refactor to use features compatible with python3.6 * Add complexity tests for merge_operations * Add iteration info to the docstring
1 parent 45fd829 commit 9e94b9e

File tree

4 files changed

+267
-0
lines changed

4 files changed

+267
-0
lines changed

cirq-core/cirq/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,8 @@
339339
map_moments,
340340
map_operations,
341341
map_operations_and_unroll,
342+
merge_moments,
343+
merge_operations,
342344
merge_single_qubit_gates_into_phased_x_z,
343345
merge_single_qubit_gates_into_phxz,
344346
MergeInteractions,

cirq-core/cirq/optimizers/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@
101101
map_moments,
102102
map_operations,
103103
map_operations_and_unroll,
104+
merge_moments,
105+
merge_operations,
104106
unroll_circuit_op,
105107
unroll_circuit_op_greedy_earliest,
106108
unroll_circuit_op_greedy_frontier,

cirq-core/cirq/optimizers/transformer_primitives.py

+117
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
Callable,
2121
Dict,
2222
Hashable,
23+
List,
2324
Optional,
2425
Sequence,
2526
TYPE_CHECKING,
@@ -129,6 +130,122 @@ def map_operations_and_unroll(
129130
return unroll_circuit_op(map_operations(circuit, map_func))
130131

131132

133+
def merge_operations(
134+
circuit: CIRCUIT_TYPE,
135+
merge_func: Callable[[ops.Operation, ops.Operation], Optional[ops.Operation]],
136+
) -> CIRCUIT_TYPE:
137+
"""Merges operations in a circuit by calling `merge_func` iteratively on operations.
138+
139+
Two operations op1 and op2 are merge-able if
140+
- There is no other operations between op1 and op2 in the circuit
141+
- is_subset(op1.qubits, op2.qubits) or is_subset(op2.qubits, op1.qubits)
142+
143+
The `merge_func` is a callable which, given two merge-able operations
144+
op1 and op2, decides whether they should be merged into a single operation
145+
or not. If not, it should return None, else it should return the single merged
146+
operations `op`.
147+
148+
The method iterates on the input circuit moment-by-moment from left to right and attempts
149+
to repeatedly merge each operation in the latest moment with all the corresponding merge-able
150+
operations to it's left.
151+
152+
If op1 and op2 are merged, both op1 and op2 are deleted from the circuit and
153+
the resulting `merged_op` is inserted at the index corresponding to the larger
154+
of op1/op2. If both op1 and op2 act on the same number of qubits, `merged_op` is
155+
inserted in the smaller moment index to minimize circuit depth.
156+
157+
The number of calls to `merge_func` is O(N), where N = Total no. of operations, because:
158+
- Every time the `merge_func` returns a new operation, the number of operations in the
159+
circuit reduce by 1 and hence this can happen at most O(N) times
160+
- Every time the `merge_func` returns None, the current operation is inserted into the
161+
frontier and we go on to process the next operation, which can also happen at-most
162+
O(N) times.
163+
164+
Args:
165+
circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
166+
merge_func: Callable to determine whether two merge-able operations in the circuit should
167+
be merged. If the operations can be merged, the callable should return the merged
168+
operation, else None.
169+
170+
Returns:
171+
Copy of input circuit with merged operations.
172+
173+
Raises:
174+
ValueError if the merged operation acts on new qubits outside the set of qubits
175+
corresponding to the original operations to be merged.
176+
"""
177+
178+
def apply_merge_func(op1: ops.Operation, op2: ops.Operation) -> Optional[ops.Operation]:
179+
new_op = merge_func(op1, op2)
180+
qubit_set = frozenset(op1.qubits + op2.qubits)
181+
if new_op is not None and not qubit_set.issuperset(new_op.qubits):
182+
raise ValueError(
183+
f"Merged operation {new_op} must act on a subset of qubits of "
184+
f"original operations {op1} and {op2}"
185+
)
186+
return new_op
187+
188+
ret_circuit = circuits.Circuit()
189+
for current_moment in circuit:
190+
new_moment = ops.Moment()
191+
for op in current_moment:
192+
op_qs = set(op.qubits)
193+
idx = ret_circuit.prev_moment_operating_on(tuple(op_qs))
194+
if idx is not None and op_qs.issubset(ret_circuit[idx][op_qs].operations[0].qubits):
195+
# Case-1: Try to merge op with the larger operation on the left.
196+
left_op = ret_circuit[idx][op_qs].operations[0]
197+
new_op = apply_merge_func(left_op, op)
198+
if new_op is not None:
199+
ret_circuit.batch_replace([(idx, left_op, new_op)])
200+
else:
201+
new_moment = new_moment.with_operation(op)
202+
continue
203+
204+
while idx is not None and len(op_qs) > 0:
205+
# Case-2: left_ops will merge right into `op` whenever possible.
206+
for left_op in ret_circuit[idx][op_qs].operations:
207+
is_merged = False
208+
if op_qs.issuperset(left_op.qubits):
209+
# Try to merge left_op into op
210+
new_op = apply_merge_func(left_op, op)
211+
if new_op is not None:
212+
ret_circuit.batch_remove([(idx, left_op)])
213+
op, is_merged = new_op, True
214+
if not is_merged:
215+
op_qs -= frozenset(left_op.qubits)
216+
idx = ret_circuit.prev_moment_operating_on(tuple(op_qs))
217+
new_moment = new_moment.with_operation(op)
218+
ret_circuit += new_moment
219+
return _to_target_circuit_type(ret_circuit, circuit)
220+
221+
222+
def merge_moments(
223+
circuit: CIRCUIT_TYPE,
224+
merge_func: Callable[[ops.Moment, ops.Moment], Optional[ops.Moment]],
225+
) -> CIRCUIT_TYPE:
226+
"""Merges adjacent moments, one by one from left to right, by calling `merge_func(m1, m2)`.
227+
228+
Args:
229+
circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
230+
merge_func: Callable to determine whether two adjacent moments in the circuit should be
231+
merged. If the moments can be merged, the callable should return the merged moment,
232+
else None.
233+
234+
Returns:
235+
Copy of input circuit with merged moments.
236+
"""
237+
if not circuit:
238+
return circuit
239+
merged_moments: List[ops.Moment] = [circuit[0]]
240+
for current_moment in circuit[1:]:
241+
merged_moment = merge_func(merged_moments[-1], current_moment)
242+
if not merged_moment:
243+
merged_moments.append(current_moment)
244+
else:
245+
merged_moments[-1] = merged_moment
246+
return _create_target_circuit_type(merged_moments, circuit)
247+
248+
132249
def _check_circuit_op(op, tags_to_check: Optional[Sequence[Hashable]]):
133250
return isinstance(op.untagged, circuits.CircuitOperation) and (
134251
tags_to_check is None or any(tag in op.tags for tag in tags_to_check)

cirq-core/cirq/optimizers/transformer_primitives_test.py

+146
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from typing import Optional
1516
import pytest
17+
1618
import cirq
1719
from cirq.optimizers.transformer_primitives import MAPPED_CIRCUIT_OP_TAG
1820

@@ -198,3 +200,147 @@ def test_map_moments_drop_empty_moments():
198200
c = cirq.Circuit(cirq.Moment(op), cirq.Moment(), cirq.Moment(op))
199201
c_mapped = cirq.map_moments(c, lambda m, i: [] if len(m) == 0 else [m])
200202
cirq.testing.assert_same_circuits(c_mapped, cirq.Circuit(c[0], c[0]))
203+
204+
205+
def test_merge_moments():
206+
q = cirq.LineQubit.range(3)
207+
c_orig = cirq.Circuit(
208+
cirq.Z.on_each(q[0], q[1]),
209+
cirq.Z.on_each(q[1], q[2]),
210+
cirq.Z.on_each(q[1], q[0]),
211+
strategy=cirq.InsertStrategy.NEW_THEN_INLINE,
212+
)
213+
c_orig = cirq.Circuit(c_orig, cirq.CCX(*q), c_orig)
214+
cirq.testing.assert_has_diagram(
215+
c_orig,
216+
'''
217+
0: ───Z───────Z───@───Z───────Z───
218+
219+
1: ───Z───Z───Z───@───Z───Z───Z───
220+
221+
2: ───────Z───────X───────Z───────
222+
''',
223+
)
224+
225+
def merge_func(m1: cirq.Moment, m2: cirq.Moment) -> Optional[cirq.Moment]:
226+
def is_z_moment(m):
227+
return all(op.gate == cirq.Z for op in m)
228+
229+
if not (is_z_moment(m1) and is_z_moment(m2)):
230+
return None
231+
qubits = m1.qubits | m2.qubits
232+
233+
def mul(op1, op2):
234+
return (op1 or op2) if not (op1 and op2) else cirq.decompose_once(op1 * op2)
235+
236+
return cirq.Moment(mul(m1.operation_at(q), m2.operation_at(q)) for q in qubits)
237+
238+
cirq.testing.assert_has_diagram(
239+
cirq.merge_moments(c_orig, merge_func),
240+
'''
241+
0: ───────@───────
242+
243+
1: ───Z───@───Z───
244+
245+
2: ───Z───X───Z───
246+
''',
247+
)
248+
249+
250+
def test_merge_moments_empty_circuit():
251+
def fail_if_called_func(*_):
252+
assert False
253+
254+
c = cirq.Circuit()
255+
assert cirq.merge_moments(c, fail_if_called_func) is c
256+
257+
258+
def test_merge_operations_raises():
259+
q = cirq.LineQubit.range(3)
260+
c = cirq.Circuit(cirq.CZ(*q[:2]), cirq.X(q[0]))
261+
with pytest.raises(ValueError, match='must act on a subset of qubits'):
262+
cirq.merge_operations(c, lambda *_: cirq.X(q[2]))
263+
264+
265+
def test_merge_operations_nothing_to_merge():
266+
def fail_if_called_func(*_):
267+
assert False
268+
269+
# Empty Circuit.
270+
c = cirq.Circuit()
271+
assert cirq.merge_operations(c, fail_if_called_func) == c
272+
# Single moment
273+
q = cirq.LineQubit.range(3)
274+
c += cirq.Moment(cirq.CZ(*q[:2]))
275+
assert cirq.merge_operations(c, fail_if_called_func) == c
276+
# Multi moment with disjoint operations + global phase operation.
277+
c += cirq.Moment(cirq.X(q[2]), cirq.GlobalPhaseOperation(1j))
278+
assert cirq.merge_operations(c, fail_if_called_func) == c
279+
280+
281+
def test_merge_operations_merges_connected_component():
282+
q = cirq.LineQubit.range(3)
283+
c_orig = cirq.Circuit(
284+
cirq.Moment(cirq.H.on_each(*q)),
285+
cirq.CNOT(q[0], q[2]),
286+
cirq.CNOT(*q[0:2]),
287+
cirq.H(q[0]),
288+
cirq.CZ(*q[:2]),
289+
cirq.X(q[0]),
290+
cirq.Y(q[1]),
291+
cirq.CNOT(*q[0:2]),
292+
cirq.CNOT(*q[1:3]),
293+
cirq.X(q[0]),
294+
cirq.Y(q[1]),
295+
cirq.CNOT(*q[:2]),
296+
strategy=cirq.InsertStrategy.NEW,
297+
)
298+
cirq.testing.assert_has_diagram(
299+
c_orig,
300+
'''
301+
0: ───H───@───@───H───@───X───────@───────X───────@───
302+
│ │ │ │ │
303+
1: ───H───┼───X───────@───────Y───X───@───────Y───X───
304+
│ │
305+
2: ───H───X───────────────────────────X───────────────
306+
''',
307+
)
308+
309+
def merge_func(op1, op2):
310+
"""Artificial example where a CZ will absorb any merge-able operation."""
311+
for op in [op1, op2]:
312+
if op.gate == cirq.CZ:
313+
return op
314+
return None
315+
316+
c_new = cirq.merge_operations(c_orig, merge_func)
317+
cirq.testing.assert_has_diagram(
318+
c_new,
319+
'''
320+
0: ───H───@───────────@───────────────────────────@───
321+
│ │ │
322+
1: ───────┼───────────@───────────────@───────Y───X───
323+
│ │
324+
2: ───H───X───────────────────────────X───────────────''',
325+
)
326+
327+
328+
@pytest.mark.parametrize("op_density", [0.1, 0.5, 0.9])
329+
def test_merge_operations_complexity(op_density):
330+
prng = cirq.value.parse_random_state(11011)
331+
circuit = cirq.testing.random_circuit(20, 500, op_density, random_state=prng)
332+
for merge_func in [
333+
lambda _, __: None,
334+
lambda op1, _: op1,
335+
lambda _, op2: op2,
336+
lambda op1, op2: prng.choice([op1, op2, None]),
337+
]:
338+
339+
def wrapped_merge_func(op1, op2):
340+
wrapped_merge_func.num_function_calls += 1
341+
return merge_func(op1, op2)
342+
343+
wrapped_merge_func.num_function_calls = 0
344+
_ = cirq.merge_operations(circuit, wrapped_merge_func)
345+
total_operations = len([*circuit.all_operations()])
346+
assert wrapped_merge_func.num_function_calls <= 2 * total_operations

0 commit comments

Comments
 (0)