Skip to content

Commit 6bf8686

Browse files
committed
add draft of new stratifying method
1 parent 79286a1 commit 6bf8686

File tree

1 file changed

+255
-1
lines changed

1 file changed

+255
-1
lines changed

cirq-core/cirq/transformers/stratify.py

Lines changed: 255 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,19 @@
1515
"""Transformer pass to repack circuits avoiding simultaneous operations with different classes."""
1616

1717
import itertools
18-
from typing import TYPE_CHECKING, Type, Callable, Dict, Optional, Union, Iterable, Sequence, List
18+
from typing import (
19+
TYPE_CHECKING,
20+
Type,
21+
Callable,
22+
Dict,
23+
Iterator,
24+
Optional,
25+
Set,
26+
Union,
27+
Iterable,
28+
Sequence,
29+
List,
30+
)
1931

2032
from cirq import ops, circuits, protocols, _import
2133
from cirq.transformers import transformer_api
@@ -174,6 +186,36 @@ def _stratify_circuit(
174186
return circuits.Circuit(circuits.Moment(moment) for moment in new_moments if moment)
175187

176188

189+
# TODO:
190+
# - properly deal with tags_to_ignore
191+
# - properly deal with measurement/control keys
192+
# - optimize over stratifying circuit vs. circuit[::-1]
193+
# - decide: replace the old stratify_circuit method, or add an option for which method to use?
194+
@transformer_api.transformer(add_deep_support=True)
195+
def dynamically_stratified_circuit(
196+
circuit: 'cirq.AbstractCircuit',
197+
*,
198+
context: Optional['cirq.TransformerContext'] = None,
199+
categories: Iterable[Category] = (),
200+
) -> 'cirq.Circuit':
201+
"""A "dynamic" stratifying method that:
202+
- Iterates over all operations in topological order.
203+
- Creates new moments on an as-needed basis.
204+
- Advances moments up/forward if and when possible to absorb a new operation.
205+
206+
All of the complexity of this stratifying method is offloaded to the _Strata class.
207+
"""
208+
# Normalize categories into classifier functions.
209+
classifiers = _get_classifiers(circuit, categories)
210+
211+
# Initialize a _Strata object, and add operations to it incrementally.
212+
strata = _Strata(classifiers)
213+
for op in circuit.all_operations():
214+
strata.add(op)
215+
216+
return circuits.Circuit(stratum.as_moment() for stratum in strata)
217+
218+
177219
def _get_classifiers(
178220
circuit: circuits.AbstractCircuit, categories: Iterable[Category]
179221
) -> List[Classifier]:
@@ -240,3 +282,215 @@ def _get_op_class(op: 'cirq.Operation', classifiers: Sequence[Classifier]) -> in
240282
return dummy_classifier_index
241283
except NameError:
242284
raise ValueError(f"Operation {op} not identified by any classifier")
285+
286+
287+
####################################################################################################
288+
# stratifying data structures
289+
290+
291+
class _Stratum:
292+
"""A custom cirq.Moment that additionally keeps track of:
293+
- the time_index that it should occupy in a circuit
294+
- a set of qubits that are "blocked" by operations "ahead of" this _Stratum
295+
- an integer "class_index" that identifies the "type" of operations in this _Stratum
296+
"""
297+
298+
def __init__(self, time_index: int, op: ops.Operation, class_index: int) -> None:
299+
"""Initialize an empty _Stratum with a fixed class_index."""
300+
self.time_index = time_index
301+
self._ops = [op]
302+
self._class_index = class_index
303+
304+
self._qubits = set(op.qubits)
305+
self._blocked_qubits: Set['cirq.Qid'] = set()
306+
307+
@property
308+
def qubits(self) -> Set['cirq.Qid']:
309+
return self._qubits
310+
311+
@property
312+
def class_index(self) -> int:
313+
return self._class_index
314+
315+
def add(self, op: ops.Operation) -> None:
316+
"""Add an operation to this stratum.
317+
318+
WARNING: For performance reasons, this method does not check whether this stratum can
319+
accomodate the given op. Add operations at your own peril!
320+
"""
321+
self._ops.append(op)
322+
self._qubits |= set(op.qubits)
323+
324+
def as_moment(self) -> circuits.Moment:
325+
"""Convert this _Stratum into a Moment."""
326+
return circuits.Moment(self._ops)
327+
328+
329+
class _Strata:
330+
"""A data structure to organize a collection of strata ('_Stratum's).
331+
332+
The naming and language in this class imagine that strata are organized into a vertical stack,
333+
with time "increasing" as you go "up". That is, if stratum A precedes stratum B (i.e.,
334+
A.time_index < B.time_index), then stratum A is said to be "below" stratum B, and stratum B is
335+
said to be "above" stratum A.
336+
337+
In accordance with this metaphor, we build a '_Strata_ object by adding operations to the stack
338+
of strata "from above".
339+
"""
340+
341+
def __init__(self, classifiers: Sequence[Classifier]) -> None:
342+
self._classifiers = classifiers
343+
self._strata: List[_Stratum] = []
344+
345+
# map from qubit --> the last stratum that adresses that qubit
346+
self._qubit_floor: Dict['cirq.Qid', _Stratum] = {}
347+
348+
# map from a stratum to its index in self._strata
349+
self._stratum_index: Dict[_Stratum, int] = {}
350+
351+
def __iter__(self) -> Iterator[_Stratum]:
352+
yield from self._strata
353+
354+
def add(self, op: ops.Operation) -> None:
355+
"""Add an operation to the lowest stratum possible.
356+
357+
Strategy:
358+
(1) Find the "op_floor" stratum, i.e., the highest stratum that collides with the op.
359+
(2) Try to find the lowest stratum that
360+
(a) is below the op_floor,
361+
(b) can accomodate the op, and
362+
(c) can be moved up above the op_floor (without violating causality).
363+
If such a "below_stratum" exists, move it above the op_floor add the op to it.
364+
(3) If no below_stratum exists, find the lowest stratum above the op_floor that can
365+
accomodate the op, and add the op to this "above_stratum".
366+
(4) If no above_stratum exists either, add the op to a new stratum above everything.
367+
"""
368+
op_class = _get_op_class(op, self._classifiers)
369+
op_floor = self._get_op_floor(op)
370+
371+
if (op_stratum := self._get_below_stratum(op, op_class, op_floor)) is not None:
372+
if op_floor is not None:
373+
self._move_stratum_above_floor(op, op_class, op_floor, op_stratum)
374+
op_stratum.add(op)
375+
376+
elif (op_stratum := self._get_above_stratum(op, op_class, op_floor)) is not None:
377+
op_stratum.add(op)
378+
379+
else:
380+
op_stratum = self._get_new_stratum(op, op_class)
381+
382+
self._qubit_floor.update({qubit: op_stratum for qubit in op.qubits})
383+
384+
def _get_op_floor(self, op: ops.Operation) -> Optional[_Stratum]:
385+
"""Get the highest stratum that collides with this op, if there is any."""
386+
candidates = [stratum for qubit in op.qubits if (stratum := self._qubit_floor.get(qubit))]
387+
return max(candidates, key=lambda stratum: stratum.time_index) if candidates else None
388+
389+
def _get_below_stratum(
390+
self, op: ops.Operation, op_class: int, op_floor: Optional[_Stratum]
391+
) -> Optional[_Stratum]:
392+
"""Get the lowest stratum that:
393+
(a) is below the op_floor,
394+
(b) can accomodate the op, and
395+
(c) can be moved up above the op_floor (without violating causality).
396+
If no such stratum exists, return None.
397+
"""
398+
if op_floor is None:
399+
return None
400+
below_stratum = None # initialize the null hypothesis that no below_stratum exists
401+
402+
# Keep track of qubits in the past light cone of the op, which block a candidate
403+
# below_stratum from being able to move up above the op_floor.
404+
past_light_cone_qubits = set(op.qubits)
405+
op_floor_index = self._stratum_index[op_floor]
406+
407+
# Starting from the op_floor, look down/backwards for a candidate below_stratum.
408+
for stratum in self._strata[op_floor_index::-1]:
409+
if stratum.class_index != op_class:
410+
# This stratum cannot accomodate the op, but it might be in op's past light cone.
411+
if not stratum.qubits.isdisjoint(past_light_cone_qubits):
412+
past_light_cone_qubits |= stratum.qubits
413+
else:
414+
if stratum.qubits.isdisjoint(past_light_cone_qubits):
415+
# This stratum can accomodate the op, so it is a candidate below_stratum.
416+
below_stratum = stratum
417+
else:
418+
# This stratum collides with the op's past light cone. Corrolaries:
419+
# (1) This stratum cannot accomodate this op (obvious).
420+
# (2) No lower stratum can be a candiate below_stratum (less obvious).
421+
# Hand-wavy proof by contradiction for claim 2:
422+
# (a) Assume there exists a lower stratum is a candidate for the below_stratum,
423+
# which means that it does not collide with this op's past light cone.
424+
# (b) In particular, the lower stratum does not collide with *this* stratum's
425+
# past light cone, so it can be moved up and merged into this stratum.
426+
# (c) That contradicts the incremental construction of _Strata, which would
427+
# have moved the lower stratum up to absorb ops in this stratum when those
428+
# ops were added to this _Strata object (self).
429+
# Altogether, our search for a below_stratum is done, so we can stop our
430+
# backwards search through self._strata.
431+
break
432+
433+
return below_stratum
434+
435+
def _move_stratum_above_floor(
436+
self, op: ops.Operation, op_class: int, op_floor: _Stratum, below_stratum: _Stratum
437+
) -> None:
438+
"""Move a below_stratum up above the op_floor."""
439+
op_floor_index = self._stratum_index[op_floor]
440+
above_floor_index = op_floor_index + 1 # hack around flake8 false positive (E203)
441+
below_stratum_index = self._stratum_index[below_stratum]
442+
443+
# Identify all strata in the future light cone of the below_stratum. When we move the
444+
# below_stratum up above the op_floor, we need to likewise shift all of these strata up in
445+
# order to preserve causal structure.
446+
light_cone_strata = [below_stratum]
447+
light_cone_qubits = below_stratum.qubits
448+
449+
# Keep track of "spectator" strata that are currently above the below_stratum, but are not
450+
# in its future light cone.
451+
spectator_strata = []
452+
453+
start = below_stratum_index + 1 # hack around flake8 false positive (E203)
454+
for stratum in self._strata[start:above_floor_index]:
455+
if not stratum.qubits.isdisjoint(light_cone_qubits):
456+
# This stratum is in the future light cone of the below_stratum.
457+
light_cone_strata.append(stratum)
458+
light_cone_qubits |= stratum.qubits
459+
460+
else:
461+
spectator_strata.append(stratum)
462+
463+
# The light cone strata are going to be moved above this spectator stratum.
464+
# Shift the indices of strata accordingly.
465+
self._stratum_index[stratum] -= len(light_cone_strata)
466+
for stratum in light_cone_strata:
467+
self._stratum_index[stratum] += 1
468+
469+
# Shift the entire light cone forward, so that the below_stratum lies above the op_floor.
470+
# Also shift everything above the op_floor forward by the same amount to ensure that it
471+
# still lies above the below_stratum.
472+
strata_to_shift = light_cone_strata + self._strata[above_floor_index:]
473+
time_index_shift = self._strata[op_floor_index].time_index - below_stratum.time_index + 1
474+
for stratum in strata_to_shift:
475+
stratum.time_index += time_index_shift
476+
477+
# Sort all strata by their time_index.
478+
self._strata[below_stratum_index:] = spectator_strata + strata_to_shift
479+
480+
def _get_above_stratum(
481+
self, op: ops.Operation, op_class: int, op_floor: Optional[_Stratum]
482+
) -> Optional[_Stratum]:
483+
"""Get the lowest accomodating stratum above the op_floor, if there is any."""
484+
start = self._stratum_index[op_floor] + 1 if op_floor is not None else 0
485+
for stratum in self._strata[start:]:
486+
if stratum.class_index == op_class and stratum.qubits.isdisjoint(op.qubits):
487+
return stratum
488+
return None
489+
490+
def _get_new_stratum(self, op: ops.Operation, op_class: int) -> _Stratum:
491+
"""Add the given operation to a new stratum above all other strata. Return that stratum."""
492+
op_time_index = self._strata[-1].time_index + 1 if self._strata else 0
493+
op_stratum = _Stratum(op_time_index, op, op_class)
494+
self._strata.append(op_stratum)
495+
self._stratum_index[op_stratum] = len(self._strata) - 1
496+
return op_stratum

0 commit comments

Comments
 (0)