15
15
"""Target gateset used for compiling circuits to Sycamore + 1-q rotations + measurement gates."""
16
16
17
17
import itertools
18
- from typing import List , Optional , Sequence
18
+ from typing import cast , List , Optional , Sequence
19
19
20
20
import cirq
21
21
from cirq .protocols .decompose_protocol import DecomposeResult
@@ -33,6 +33,7 @@ def merge_swap_rzz_and_2q_unitaries(
33
33
context : Optional ['cirq.TransformerContext' ] = None ,
34
34
merged_swap_rzz_tag : str = "_merged_swap_rzz" ,
35
35
merged_2q_component_tag : str = "_merged_2q_unitaries" ,
36
+ intermediate_result_tag : Optional [str ] = None ,
36
37
) -> 'cirq.Circuit' :
37
38
"""Merges 2-qubit connected components and adjacent `cirq.SWAP` and `cirq.ZZPowGate` gates.
38
39
@@ -50,6 +51,8 @@ def merge_swap_rzz_and_2q_unitaries(
50
51
`cirq.SWAP` and `cirq.ZZPowGate`s.
51
52
merged_2q_component_tag: Tag to apply on newly introduced circuit operations wrapping
52
53
connected components of 1 and 2 qubit unitaries.
54
+ intermediate_result_tag: If specified, the tag is added to newly introduced both the newly
55
+ introduced circuit operations encapsulating swap_rzz or 2q connected component.
53
56
54
57
Returns:
55
58
Copy of the transformed input circuit.
@@ -71,19 +74,34 @@ def merge_func_swap_rzz(
71
74
return False
72
75
73
76
tags_to_ignore = context .tags_to_ignore if context else ()
77
+ deep = context .deep if context else False
74
78
circuit = cirq .merge_operations_to_circuit_op (
75
79
circuit ,
76
80
merge_func_swap_rzz ,
77
81
tags_to_ignore = tags_to_ignore ,
78
82
merged_circuit_op_tag = merged_swap_rzz_tag ,
83
+ deep = deep ,
79
84
)
80
85
81
- return cirq .merge_k_qubit_unitaries_to_circuit_op (
86
+ circuit = cirq .merge_k_qubit_unitaries_to_circuit_op (
82
87
circuit ,
83
88
k = 2 ,
84
- tags_to_ignore = tags_to_ignore + (merged_swap_rzz_tag ,),
89
+ tags_to_ignore = tuple ( tags_to_ignore ) + (merged_swap_rzz_tag ,),
85
90
merged_circuit_op_tag = merged_2q_component_tag ,
86
- ).unfreeze (copy = False )
91
+ deep = deep ,
92
+ )
93
+
94
+ if intermediate_result_tag is not None :
95
+ merged_cop_tags = {merged_swap_rzz_tag , merged_2q_component_tag }
96
+ circuit = cirq .map_operations (
97
+ circuit ,
98
+ map_func = lambda op , _ : op
99
+ if merged_cop_tags .isdisjoint (op .tags )
100
+ else op .with_tags (cast (str , intermediate_result_tag )),
101
+ tags_to_ignore = tags_to_ignore ,
102
+ deep = True ,
103
+ )
104
+ return circuit .unfreeze (copy = False )
87
105
88
106
89
107
class SycamoreTargetGateset (cirq .TwoQubitCompilationTargetGateset ):
@@ -122,7 +140,10 @@ def preprocess_transformers(self) -> List[cirq.TRANSFORMER]:
122
140
cirq .expand_composite ,
123
141
no_decomp = lambda op : cirq .num_qubits (op ) <= self .num_qubits ,
124
142
),
125
- merge_swap_rzz_and_2q_unitaries ,
143
+ _create_transformer_with_kwargs (
144
+ merge_swap_rzz_and_2q_unitaries ,
145
+ intermediate_result_tag = self ._intermediate_result_tag ,
146
+ ),
126
147
]
127
148
128
149
def _decompose_two_qubit_operation (self , op : cirq .Operation , _ ) -> DecomposeResult :
0 commit comments