14
14
15
15
"""Transformers to rewrite a circuit using gates from a given target gateset."""
16
16
17
- from typing import Optional , Callable , TYPE_CHECKING
17
+ from typing import Optional , Callable , Hashable , Sequence , TYPE_CHECKING
18
18
19
+ from cirq import circuits
19
20
from cirq .protocols import decompose_protocol as dp
20
21
from cirq .transformers import transformer_api , transformer_primitives
21
22
@@ -38,6 +39,7 @@ def _decompose_operations_to_target_gateset(
38
39
gateset : Optional ['cirq.Gateset' ] = None ,
39
40
decomposer : Callable [['cirq.Operation' , int ], dp .DecomposeResult ] = lambda * _ : NotImplemented ,
40
41
ignore_failures : bool = True ,
42
+ tags_to_decompose : Sequence [Hashable ] = (),
41
43
) -> 'cirq.Circuit' :
42
44
"""Decomposes every operation to `gateset` using `cirq.decompose` and `decomposer`.
43
45
@@ -56,6 +58,8 @@ def _decompose_operations_to_target_gateset(
56
58
- `None` or `NotImplemented` if does not know how to decompose a given `op`.
57
59
ignore_failures: If set, operations that fail to convert are left unchanged. If not set,
58
60
conversion failures raise a ValueError.
61
+ tags_to_decompose: `cirq.CircuitOperation`s tagged with any of `tags_to_decompose` will
62
+ be decomposed even if context.deep is True.
59
63
60
64
Returns:
61
65
An equivalent circuit containing gates accepted by `gateset`.
@@ -65,6 +69,13 @@ def _decompose_operations_to_target_gateset(
65
69
"""
66
70
67
71
def map_func (op : 'cirq.Operation' , moment_index : int ):
72
+ if (
73
+ context
74
+ and context .deep
75
+ and isinstance (op .untagged , circuits .CircuitOperation )
76
+ and set (op .tags ).isdisjoint (tags_to_decompose )
77
+ ):
78
+ return op
68
79
return dp .decompose (
69
80
op ,
70
81
intercepting_decomposer = lambda o : decomposer (o , moment_index ),
@@ -77,7 +88,10 @@ def map_func(op: 'cirq.Operation', moment_index: int):
77
88
)
78
89
79
90
return transformer_primitives .map_operations_and_unroll (
80
- circuit , map_func , tags_to_ignore = context .tags_to_ignore if context else ()
91
+ circuit ,
92
+ map_func ,
93
+ tags_to_ignore = context .tags_to_ignore if context else (),
94
+ deep = context .deep if context else False ,
81
95
).unfreeze (copy = False )
82
96
83
97
@@ -122,6 +136,7 @@ def optimize_for_target_gateset(
122
136
gateset = gateset ,
123
137
decomposer = gateset .decompose_to_target_gateset ,
124
138
ignore_failures = ignore_failures ,
139
+ tags_to_decompose = (gateset ._intermediate_result_tag ,),
125
140
)
126
141
127
142
for transformer in gateset .postprocess_transformers :
0 commit comments