Skip to content

Commit b1e09a9

Browse files
authored
Update cirq.decompose protocol to perform a DFS instead of a BFS on the decomposed OP-TREE (#6116)
* update cirq.decompose protocol to perform a DFS instead of a BFS on the decomposed OP-TREE * Fix mypy error * Use a dataclass instead of kwargs
1 parent ebec38b commit b1e09a9

File tree

1 file changed

+64
-111
lines changed

1 file changed

+64
-111
lines changed

cirq-core/cirq/protocols/decompose_protocol.py

Lines changed: 64 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,14 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
14+
import dataclasses
1515
from typing import (
1616
TYPE_CHECKING,
1717
Any,
1818
Callable,
1919
Dict,
2020
Iterable,
21+
Iterator,
2122
List,
2223
Optional,
2324
overload,
@@ -128,6 +129,60 @@ def _decompose_(self, qubits: Tuple['cirq.Qid', ...]) -> DecomposeResult:
128129
pass
129130

130131

132+
def _try_op_decomposer(val: Any, decomposer: Optional[OpDecomposer]) -> DecomposeResult:
133+
if decomposer is None or not isinstance(val, ops.Operation):
134+
return None
135+
return decomposer(val)
136+
137+
138+
@dataclasses.dataclass(frozen=True)
139+
class _DecomposeArgs:
140+
intercepting_decomposer: Optional[OpDecomposer]
141+
fallback_decomposer: Optional[OpDecomposer]
142+
keep: Optional[Callable[['cirq.Operation'], bool]]
143+
on_stuck_raise: Union[None, Exception, Callable[['cirq.Operation'], Optional[Exception]]]
144+
preserve_structure: bool
145+
146+
147+
def _decompose_dfs(item: Any, args: _DecomposeArgs) -> Iterator['cirq.Operation']:
148+
from cirq.circuits import CircuitOperation, FrozenCircuit
149+
150+
if isinstance(item, ops.Operation):
151+
item_untagged = item.untagged
152+
if args.preserve_structure and isinstance(item_untagged, CircuitOperation):
153+
new_fc = FrozenCircuit(_decompose_dfs(item_untagged.circuit, args))
154+
yield item_untagged.replace(circuit=new_fc).with_tags(*item.tags)
155+
return
156+
if args.keep is not None and args.keep(item):
157+
yield item
158+
return
159+
160+
decomposed = _try_op_decomposer(item, args.intercepting_decomposer)
161+
162+
if decomposed is NotImplemented or decomposed is None:
163+
decomposed = decompose_once(item, default=None)
164+
165+
if decomposed is NotImplemented or decomposed is None:
166+
decomposed = _try_op_decomposer(item, args.fallback_decomposer)
167+
168+
if decomposed is NotImplemented or decomposed is None:
169+
if not isinstance(item, ops.Operation) and isinstance(item, Iterable):
170+
decomposed = item
171+
172+
if decomposed is NotImplemented or decomposed is None:
173+
if args.keep is not None and args.on_stuck_raise is not None:
174+
if isinstance(args.on_stuck_raise, Exception):
175+
raise args.on_stuck_raise
176+
elif callable(args.on_stuck_raise):
177+
error = args.on_stuck_raise(item)
178+
if error is not None:
179+
raise error
180+
yield item
181+
else:
182+
for val in ops.flatten_to_ops(decomposed):
183+
yield from _decompose_dfs(val, args)
184+
185+
131186
def decompose(
132187
val: Any,
133188
*,
@@ -200,55 +255,14 @@ def decompose(
200255
"acceptable to keep."
201256
)
202257

203-
if preserve_structure:
204-
return _decompose_preserving_structure(
205-
val,
206-
intercepting_decomposer=intercepting_decomposer,
207-
fallback_decomposer=fallback_decomposer,
208-
keep=keep,
209-
on_stuck_raise=on_stuck_raise,
210-
)
211-
212-
def try_op_decomposer(val: Any, decomposer: Optional[OpDecomposer]) -> DecomposeResult:
213-
if decomposer is None or not isinstance(val, ops.Operation):
214-
return None
215-
return decomposer(val)
216-
217-
output = []
218-
queue: List[Any] = [val]
219-
while queue:
220-
item = queue.pop(0)
221-
if isinstance(item, ops.Operation) and keep is not None and keep(item):
222-
output.append(item)
223-
continue
224-
225-
decomposed = try_op_decomposer(item, intercepting_decomposer)
226-
227-
if decomposed is NotImplemented or decomposed is None:
228-
decomposed = decompose_once(item, default=None)
229-
230-
if decomposed is NotImplemented or decomposed is None:
231-
decomposed = try_op_decomposer(item, fallback_decomposer)
232-
233-
if decomposed is not NotImplemented and decomposed is not None:
234-
queue[:0] = ops.flatten_to_ops(decomposed)
235-
continue
236-
237-
if not isinstance(item, ops.Operation) and isinstance(item, Iterable):
238-
queue[:0] = ops.flatten_to_ops(item)
239-
continue
240-
241-
if keep is not None and on_stuck_raise is not None:
242-
if isinstance(on_stuck_raise, Exception):
243-
raise on_stuck_raise
244-
elif callable(on_stuck_raise):
245-
error = on_stuck_raise(item)
246-
if error is not None:
247-
raise error
248-
249-
output.append(item)
250-
251-
return output
258+
args = _DecomposeArgs(
259+
intercepting_decomposer=intercepting_decomposer,
260+
fallback_decomposer=fallback_decomposer,
261+
keep=keep,
262+
on_stuck_raise=on_stuck_raise,
263+
preserve_structure=preserve_structure,
264+
)
265+
return [*_decompose_dfs(val, args)]
252266

253267

254268
# pylint: disable=function-redefined
@@ -383,65 +397,4 @@ def _try_decompose_into_operations_and_qubits(
383397
qid_shape_dict[q] = max(qid_shape_dict[q], level)
384398
qubits = sorted(qubit_set)
385399
return result, qubits, tuple(qid_shape_dict[q] for q in qubits)
386-
387400
return None, (), ()
388-
389-
390-
def _decompose_preserving_structure(
391-
val: Any,
392-
*,
393-
intercepting_decomposer: Optional[OpDecomposer] = None,
394-
fallback_decomposer: Optional[OpDecomposer] = None,
395-
keep: Optional[Callable[['cirq.Operation'], bool]] = None,
396-
on_stuck_raise: Union[
397-
None, Exception, Callable[['cirq.Operation'], Optional[Exception]]
398-
] = _value_error_describing_bad_operation,
399-
) -> List['cirq.Operation']:
400-
"""Preserves structure (e.g. subcircuits) while decomposing ops.
401-
402-
This can be used to reduce a circuit to a particular gateset without
403-
increasing its serialization size. See tests for examples.
404-
"""
405-
406-
# This method provides a generated 'keep' to its decompose() calls.
407-
# If the user-provided keep is not set, on_stuck_raise must be unset to
408-
# ensure that failure to decompose does not generate errors.
409-
on_stuck_raise = on_stuck_raise if keep is not None else None
410-
411-
from cirq.circuits import CircuitOperation, FrozenCircuit
412-
413-
visited_fcs = set()
414-
415-
def keep_structure(op: 'cirq.Operation'):
416-
circuit = getattr(op.untagged, 'circuit', None)
417-
if circuit is not None:
418-
return circuit in visited_fcs
419-
if keep is not None and keep(op):
420-
return True
421-
422-
def dps_interceptor(op: 'cirq.Operation'):
423-
if not isinstance(op.untagged, CircuitOperation):
424-
if intercepting_decomposer is None:
425-
return NotImplemented
426-
return intercepting_decomposer(op)
427-
428-
new_fc = FrozenCircuit(
429-
decompose(
430-
op.untagged.circuit,
431-
intercepting_decomposer=dps_interceptor,
432-
fallback_decomposer=fallback_decomposer,
433-
keep=keep_structure,
434-
on_stuck_raise=on_stuck_raise,
435-
)
436-
)
437-
visited_fcs.add(new_fc)
438-
new_co = op.untagged.replace(circuit=new_fc)
439-
return new_co if not op.tags else new_co.with_tags(*op.tags)
440-
441-
return decompose(
442-
val,
443-
intercepting_decomposer=dps_interceptor,
444-
fallback_decomposer=fallback_decomposer,
445-
keep=keep_structure,
446-
on_stuck_raise=on_stuck_raise,
447-
)

0 commit comments

Comments
 (0)