Skip to content

Commit d3526df

Browse files
committed
add option to choose stratification method
1 parent 6bf8686 commit d3526df

File tree

1 file changed

+92
-62
lines changed

1 file changed

+92
-62
lines changed

cirq-core/cirq/transformers/stratify.py

Lines changed: 92 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
Callable,
2222
Dict,
2323
Iterator,
24+
Literal,
2425
Optional,
2526
Set,
2627
Union,
@@ -51,6 +52,7 @@
5152
def stratified_circuit(
5253
circuit: 'cirq.AbstractCircuit',
5354
*,
55+
method: Literal["dynamic", "static"] = "dynamic",
5456
context: Optional['cirq.TransformerContext'] = None,
5557
categories: Iterable[Category] = (),
5658
) -> 'cirq.Circuit':
@@ -76,55 +78,113 @@ def stratified_circuit(
7678
# Normalize categories into classifier functions.
7779
classifiers = _get_classifiers(circuit, categories)
7880

79-
# Try the algorithm with each permutation of the classifiers.
81+
if method == "static":
82+
return _statically_stratify_circuit(circuit, classifiers, context)
83+
return _dynamically_stratify_circuit(circuit, classifiers, context)
84+
85+
86+
StratifyingMethod = Callable[
87+
[circuits.AbstractCircuit, Sequence[Classifier], 'cirq.TransformerContext'],
88+
circuits.AbstractCircuit,
89+
]
90+
91+
92+
def _optimize_statifying_direction(stratifying_func: StratifyingMethod) -> StratifyingMethod:
93+
"""Decorator to optimize over stratifying a circuit left-to-right vs. right-to-left."""
94+
95+
def optimized_stratifying_func(
96+
circuit: circuits.AbstractCircuit,
97+
classifiers: Sequence[Classifier],
98+
context: 'cirq.TransformerContext',
99+
) -> 'cirq.Circuit':
100+
forward_circuit = stratifying_func(circuit, classifiers, context)
101+
backward_circuit = stratifying_func(circuit[::-1], classifiers, context)
102+
if len(forward_circuit) <= len(backward_circuit):
103+
return forward_circuit
104+
return backward_circuit[::-1]
105+
106+
return optimized_stratifying_func
107+
108+
109+
# TODO:
110+
# - properly deal with tags_to_ignore
111+
# - properly deal with measurement/control keys
112+
@_optimize_statifying_direction
113+
def _dynamically_stratify_circuit(
114+
circuit: 'cirq.AbstractCircuit',
115+
*,
116+
context: Optional['cirq.TransformerContext'] = None,
117+
categories: Iterable[Category] = (),
118+
) -> 'cirq.Circuit':
119+
"""A "dynamic" stratifying method that:
120+
- Iterates over all operations in topological order.
121+
- Creates new moments on an as-needed basis.
122+
- Advances moments up/forward if and when possible to absorb a new operation.
123+
124+
All of the complexity of this stratifying method is offloaded to the _Strata class.
125+
126+
Args:
127+
circuit: The circuit to break out into homogeneous moments. Will not be edited.
128+
classifiers: A list of rules to align the circuit. Must be exhaustive, i.e. all operations
129+
will be caught by one of the processors.
130+
context: `cirq.TransformerContext` storing common configurable options for transformers.
131+
132+
Returns:
133+
The stratified circuit.
134+
"""
135+
# Normalize categories into classifier functions.
136+
classifiers = _get_classifiers(circuit, categories)
137+
138+
# Initialize a _Strata object, and add operations to it incrementally.
139+
strata = _Strata(classifiers)
140+
for op in circuit.all_operations():
141+
strata.add(op)
142+
143+
return circuits.Circuit(stratum.as_moment() for stratum in strata)
144+
145+
146+
@_optimize_statifying_direction
147+
def _statically_stratify_circuit(
148+
circuit: circuits.AbstractCircuit,
149+
classifiers: Sequence[Classifier],
150+
context: 'cirq.TransformerContext',
151+
) -> 'cirq.Circuit':
152+
"""A "static" stratifying method that:
153+
- Enforces that a fixed stratification structure, e.g. moments of type [A, B, C, A, B, C, ...].
154+
- Places each operation into the earliest moment that can accomodate it.
155+
- Optimizes over the order of the classifiers, returning the shortest circuit found.
156+
157+
Args:
158+
circuit: The circuit to break out into homogeneous moments. Will not be edited.
159+
classifiers: A list of rules to align the circuit. Must be exhaustive, i.e. all operations
160+
will be caught by one of the processors.
161+
context: `cirq.TransformerContext` storing common configurable options for transformers.
162+
163+
Returns:
164+
The stratified circuit.
165+
"""
80166
smallest_depth = protocols.num_qubits(circuit) * len(circuit) + 1
81167
shortest_stratified_circuit = circuits.Circuit()
82-
reversed_circuit = circuit[::-1]
83168
for ordered_classifiers in itertools.permutations(classifiers):
84-
solution = _stratify_circuit(
169+
solution = _statically_stratify_fixed_circuit(
85170
circuit,
86171
classifiers=ordered_classifiers,
87172
context=context or transformer_api.TransformerContext(),
88173
)
89174
if len(solution) < smallest_depth:
90175
shortest_stratified_circuit = solution
91176
smallest_depth = len(solution)
92-
93-
# Do the same thing, except this time in reverse. This helps for some
94-
# circuits because it inserts operations at the end instead of at the
95-
# beginning.
96-
solution = _stratify_circuit(
97-
reversed_circuit,
98-
classifiers=ordered_classifiers,
99-
context=context or transformer_api.TransformerContext(),
100-
)[::-1]
101-
if len(solution) < smallest_depth:
102-
shortest_stratified_circuit = solution
103-
smallest_depth = len(solution)
104-
105177
return shortest_stratified_circuit
106178

107179

108-
def _stratify_circuit(
180+
def _statically_stratify_fixed_circuit(
109181
circuit: circuits.AbstractCircuit,
110-
*,
111-
context: 'cirq.TransformerContext',
112182
classifiers: Sequence[Classifier],
183+
context: 'cirq.TransformerContext',
113184
) -> 'cirq.Circuit':
114-
"""Performs the stratification by iterating through the operations in the
115-
circuit and using the given classifiers to align them.
116-
117-
Tagged Operations marked with any of `context.tags_to_ignore` are treated as separate
118-
categories and left in their original moments without stratification.
119-
120-
Args:
121-
circuit: The circuit to break out into homogeneous moments. Will not be edited.
122-
context: `cirq.TransformerContext` storing common configurable options for transformers.
123-
classifiers: A list of rules to align the circuit. Must be exhaustive, i.e. all operations
124-
will be caught by one of the processors.
185+
"""Helper function for '_statically_stratify_circuit'.
125186
126-
Returns:
127-
The stratified circuit.
187+
Stratifies a circuit without optimizing over the order of classifiers.
128188
"""
129189
num_classes = len(classifiers) + 1 # include one "extra" category for ignored operations
130190
new_moments: List[List['cirq.Operation']] = []
@@ -186,36 +246,6 @@ def _stratify_circuit(
186246
return circuits.Circuit(circuits.Moment(moment) for moment in new_moments if moment)
187247

188248

189-
# TODO:
190-
# - properly deal with tags_to_ignore
191-
# - properly deal with measurement/control keys
192-
# - optimize over stratifying circuit vs. circuit[::-1]
193-
# - decide: replace the old stratify_circuit method, or add an option for which method to use?
194-
@transformer_api.transformer(add_deep_support=True)
195-
def dynamically_stratified_circuit(
196-
circuit: 'cirq.AbstractCircuit',
197-
*,
198-
context: Optional['cirq.TransformerContext'] = None,
199-
categories: Iterable[Category] = (),
200-
) -> 'cirq.Circuit':
201-
"""A "dynamic" stratifying method that:
202-
- Iterates over all operations in topological order.
203-
- Creates new moments on an as-needed basis.
204-
- Advances moments up/forward if and when possible to absorb a new operation.
205-
206-
All of the complexity of this stratifying method is offloaded to the _Strata class.
207-
"""
208-
# Normalize categories into classifier functions.
209-
classifiers = _get_classifiers(circuit, categories)
210-
211-
# Initialize a _Strata object, and add operations to it incrementally.
212-
strata = _Strata(classifiers)
213-
for op in circuit.all_operations():
214-
strata.add(op)
215-
216-
return circuits.Circuit(stratum.as_moment() for stratum in strata)
217-
218-
219249
def _get_classifiers(
220250
circuit: circuits.AbstractCircuit, categories: Iterable[Category]
221251
) -> List[Classifier]:

0 commit comments

Comments
 (0)