|
15 | 15 | """Transformer pass to repack circuits avoiding simultaneous operations with different classes."""
|
16 | 16 |
|
17 | 17 | import itertools
|
18 |
| -from typing import TYPE_CHECKING, Type, Callable, Dict, Optional, Union, Iterable, Sequence, List |
| 18 | +from typing import ( |
| 19 | + TYPE_CHECKING, |
| 20 | + Type, |
| 21 | + Callable, |
| 22 | + Dict, |
| 23 | + Iterator, |
| 24 | + Optional, |
| 25 | + Set, |
| 26 | + Union, |
| 27 | + Iterable, |
| 28 | + Sequence, |
| 29 | + List, |
| 30 | +) |
19 | 31 |
|
20 | 32 | from cirq import ops, circuits, protocols, _import
|
21 | 33 | from cirq.transformers import transformer_api
|
@@ -174,6 +186,36 @@ def _stratify_circuit(
|
174 | 186 | return circuits.Circuit(circuits.Moment(moment) for moment in new_moments if moment)
|
175 | 187 |
|
176 | 188 |
|
| 189 | +# TODO: |
| 190 | +# - properly deal with tags_to_ignore |
| 191 | +# - properly deal with measurement/control keys |
| 192 | +# - optimize over stratifying circuit vs. circuit[::-1] |
| 193 | +# - decide: replace the old stratify_circuit method, or add an option for which method to use? |
| 194 | +@transformer_api.transformer(add_deep_support=True) |
| 195 | +def dynamically_stratified_circuit( |
| 196 | + circuit: 'cirq.AbstractCircuit', |
| 197 | + *, |
| 198 | + context: Optional['cirq.TransformerContext'] = None, |
| 199 | + categories: Iterable[Category] = (), |
| 200 | +) -> 'cirq.Circuit': |
| 201 | + """A "dynamic" stratifying method that: |
| 202 | + - Iterates over all operations in topological order. |
| 203 | + - Creates new moments on an as-needed basis. |
| 204 | + - Advances moments up/forward if and when possible to absorb a new operation. |
| 205 | +
|
| 206 | + All of the complexity of this stratifying method is offloaded to the _Strata class. |
| 207 | + """ |
| 208 | + # Normalize categories into classifier functions. |
| 209 | + classifiers = _get_classifiers(circuit, categories) |
| 210 | + |
| 211 | + # Initialize a _Strata object, and add operations to it incrementally. |
| 212 | + strata = _Strata(classifiers) |
| 213 | + for op in circuit.all_operations(): |
| 214 | + strata.add(op) |
| 215 | + |
| 216 | + return circuits.Circuit(stratum.as_moment() for stratum in strata) |
| 217 | + |
| 218 | + |
177 | 219 | def _get_classifiers(
|
178 | 220 | circuit: circuits.AbstractCircuit, categories: Iterable[Category]
|
179 | 221 | ) -> List[Classifier]:
|
@@ -240,3 +282,215 @@ def _get_op_class(op: 'cirq.Operation', classifiers: Sequence[Classifier]) -> in
|
240 | 282 | return dummy_classifier_index
|
241 | 283 | except NameError:
|
242 | 284 | raise ValueError(f"Operation {op} not identified by any classifier")
|
| 285 | + |
| 286 | + |
| 287 | +#################################################################################################### |
| 288 | +# stratifying data structures |
| 289 | + |
| 290 | + |
| 291 | +class _Stratum: |
| 292 | + """A custom cirq.Moment that additionally keeps track of: |
| 293 | + - the time_index that it should occupy in a circuit |
| 294 | + - a set of qubits that are "blocked" by operations "ahead of" this _Stratum |
| 295 | + - an integer "class_index" that identifies the "type" of operations in this _Stratum |
| 296 | + """ |
| 297 | + |
| 298 | + def __init__(self, time_index: int, op: ops.Operation, class_index: int) -> None: |
| 299 | + """Initialize an empty _Stratum with a fixed class_index.""" |
| 300 | + self.time_index = time_index |
| 301 | + self._ops = [op] |
| 302 | + self._class_index = class_index |
| 303 | + |
| 304 | + self._qubits = set(op.qubits) |
| 305 | + self._blocked_qubits: Set['cirq.Qid'] = set() |
| 306 | + |
| 307 | + @property |
| 308 | + def qubits(self) -> Set['cirq.Qid']: |
| 309 | + return self._qubits |
| 310 | + |
| 311 | + @property |
| 312 | + def class_index(self) -> int: |
| 313 | + return self._class_index |
| 314 | + |
| 315 | + def add(self, op: ops.Operation) -> None: |
| 316 | + """Add an operation to this stratum. |
| 317 | +
|
| 318 | + WARNING: For performance reasons, this method does not check whether this stratum can |
| 319 | + accomodate the given op. Add operations at your own peril! |
| 320 | + """ |
| 321 | + self._ops.append(op) |
| 322 | + self._qubits |= set(op.qubits) |
| 323 | + |
| 324 | + def as_moment(self) -> circuits.Moment: |
| 325 | + """Convert this _Stratum into a Moment.""" |
| 326 | + return circuits.Moment(self._ops) |
| 327 | + |
| 328 | + |
| 329 | +class _Strata: |
| 330 | + """A data structure to organize a collection of strata ('_Stratum's). |
| 331 | +
|
| 332 | + The naming and language in this class imagine that strata are organized into a vertical stack, |
| 333 | + with time "increasing" as you go "up". That is, if stratum A precedes stratum B (i.e., |
| 334 | + A.time_index < B.time_index), then stratum A is said to be "below" stratum B, and stratum B is |
| 335 | + said to be "above" stratum A. |
| 336 | +
|
| 337 | + In accordance with this metaphor, we build a '_Strata_ object by adding operations to the stack |
| 338 | + of strata "from above". |
| 339 | + """ |
| 340 | + |
| 341 | + def __init__(self, classifiers: Sequence[Classifier]) -> None: |
| 342 | + self._classifiers = classifiers |
| 343 | + self._strata: List[_Stratum] = [] |
| 344 | + |
| 345 | + # map from qubit --> the last stratum that adresses that qubit |
| 346 | + self._qubit_floor: Dict['cirq.Qid', _Stratum] = {} |
| 347 | + |
| 348 | + # map from a stratum to its index in self._strata |
| 349 | + self._stratum_index: Dict[_Stratum, int] = {} |
| 350 | + |
| 351 | + def __iter__(self) -> Iterator[_Stratum]: |
| 352 | + yield from self._strata |
| 353 | + |
| 354 | + def add(self, op: ops.Operation) -> None: |
| 355 | + """Add an operation to the lowest stratum possible. |
| 356 | +
|
| 357 | + Strategy: |
| 358 | + (1) Find the "op_floor" stratum, i.e., the highest stratum that collides with the op. |
| 359 | + (2) Try to find the lowest stratum that |
| 360 | + (a) is below the op_floor, |
| 361 | + (b) can accomodate the op, and |
| 362 | + (c) can be moved up above the op_floor (without violating causality). |
| 363 | + If such a "below_stratum" exists, move it above the op_floor add the op to it. |
| 364 | + (3) If no below_stratum exists, find the lowest stratum above the op_floor that can |
| 365 | + accomodate the op, and add the op to this "above_stratum". |
| 366 | + (4) If no above_stratum exists either, add the op to a new stratum above everything. |
| 367 | + """ |
| 368 | + op_class = _get_op_class(op, self._classifiers) |
| 369 | + op_floor = self._get_op_floor(op) |
| 370 | + |
| 371 | + if (op_stratum := self._get_below_stratum(op, op_class, op_floor)) is not None: |
| 372 | + if op_floor is not None: |
| 373 | + self._move_stratum_above_floor(op, op_class, op_floor, op_stratum) |
| 374 | + op_stratum.add(op) |
| 375 | + |
| 376 | + elif (op_stratum := self._get_above_stratum(op, op_class, op_floor)) is not None: |
| 377 | + op_stratum.add(op) |
| 378 | + |
| 379 | + else: |
| 380 | + op_stratum = self._get_new_stratum(op, op_class) |
| 381 | + |
| 382 | + self._qubit_floor.update({qubit: op_stratum for qubit in op.qubits}) |
| 383 | + |
| 384 | + def _get_op_floor(self, op: ops.Operation) -> Optional[_Stratum]: |
| 385 | + """Get the highest stratum that collides with this op, if there is any.""" |
| 386 | + candidates = [stratum for qubit in op.qubits if (stratum := self._qubit_floor.get(qubit))] |
| 387 | + return max(candidates, key=lambda stratum: stratum.time_index) if candidates else None |
| 388 | + |
| 389 | + def _get_below_stratum( |
| 390 | + self, op: ops.Operation, op_class: int, op_floor: Optional[_Stratum] |
| 391 | + ) -> Optional[_Stratum]: |
| 392 | + """Get the lowest stratum that: |
| 393 | + (a) is below the op_floor, |
| 394 | + (b) can accomodate the op, and |
| 395 | + (c) can be moved up above the op_floor (without violating causality). |
| 396 | + If no such stratum exists, return None. |
| 397 | + """ |
| 398 | + if op_floor is None: |
| 399 | + return None |
| 400 | + below_stratum = None # initialize the null hypothesis that no below_stratum exists |
| 401 | + |
| 402 | + # Keep track of qubits in the past light cone of the op, which block a candidate |
| 403 | + # below_stratum from being able to move up above the op_floor. |
| 404 | + past_light_cone_qubits = set(op.qubits) |
| 405 | + op_floor_index = self._stratum_index[op_floor] |
| 406 | + |
| 407 | + # Starting from the op_floor, look down/backwards for a candidate below_stratum. |
| 408 | + for stratum in self._strata[op_floor_index::-1]: |
| 409 | + if stratum.class_index != op_class: |
| 410 | + # This stratum cannot accomodate the op, but it might be in op's past light cone. |
| 411 | + if not stratum.qubits.isdisjoint(past_light_cone_qubits): |
| 412 | + past_light_cone_qubits |= stratum.qubits |
| 413 | + else: |
| 414 | + if stratum.qubits.isdisjoint(past_light_cone_qubits): |
| 415 | + # This stratum can accomodate the op, so it is a candidate below_stratum. |
| 416 | + below_stratum = stratum |
| 417 | + else: |
| 418 | + # This stratum collides with the op's past light cone. Corrolaries: |
| 419 | + # (1) This stratum cannot accomodate this op (obvious). |
| 420 | + # (2) No lower stratum can be a candiate below_stratum (less obvious). |
| 421 | + # Hand-wavy proof by contradiction for claim 2: |
| 422 | + # (a) Assume there exists a lower stratum is a candidate for the below_stratum, |
| 423 | + # which means that it does not collide with this op's past light cone. |
| 424 | + # (b) In particular, the lower stratum does not collide with *this* stratum's |
| 425 | + # past light cone, so it can be moved up and merged into this stratum. |
| 426 | + # (c) That contradicts the incremental construction of _Strata, which would |
| 427 | + # have moved the lower stratum up to absorb ops in this stratum when those |
| 428 | + # ops were added to this _Strata object (self). |
| 429 | + # Altogether, our search for a below_stratum is done, so we can stop our |
| 430 | + # backwards search through self._strata. |
| 431 | + break |
| 432 | + |
| 433 | + return below_stratum |
| 434 | + |
| 435 | + def _move_stratum_above_floor( |
| 436 | + self, op: ops.Operation, op_class: int, op_floor: _Stratum, below_stratum: _Stratum |
| 437 | + ) -> None: |
| 438 | + """Move a below_stratum up above the op_floor.""" |
| 439 | + op_floor_index = self._stratum_index[op_floor] |
| 440 | + above_floor_index = op_floor_index + 1 # hack around flake8 false positive (E203) |
| 441 | + below_stratum_index = self._stratum_index[below_stratum] |
| 442 | + |
| 443 | + # Identify all strata in the future light cone of the below_stratum. When we move the |
| 444 | + # below_stratum up above the op_floor, we need to likewise shift all of these strata up in |
| 445 | + # order to preserve causal structure. |
| 446 | + light_cone_strata = [below_stratum] |
| 447 | + light_cone_qubits = below_stratum.qubits |
| 448 | + |
| 449 | + # Keep track of "spectator" strata that are currently above the below_stratum, but are not |
| 450 | + # in its future light cone. |
| 451 | + spectator_strata = [] |
| 452 | + |
| 453 | + start = below_stratum_index + 1 # hack around flake8 false positive (E203) |
| 454 | + for stratum in self._strata[start:above_floor_index]: |
| 455 | + if not stratum.qubits.isdisjoint(light_cone_qubits): |
| 456 | + # This stratum is in the future light cone of the below_stratum. |
| 457 | + light_cone_strata.append(stratum) |
| 458 | + light_cone_qubits |= stratum.qubits |
| 459 | + |
| 460 | + else: |
| 461 | + spectator_strata.append(stratum) |
| 462 | + |
| 463 | + # The light cone strata are going to be moved above this spectator stratum. |
| 464 | + # Shift the indices of strata accordingly. |
| 465 | + self._stratum_index[stratum] -= len(light_cone_strata) |
| 466 | + for stratum in light_cone_strata: |
| 467 | + self._stratum_index[stratum] += 1 |
| 468 | + |
| 469 | + # Shift the entire light cone forward, so that the below_stratum lies above the op_floor. |
| 470 | + # Also shift everything above the op_floor forward by the same amount to ensure that it |
| 471 | + # still lies above the below_stratum. |
| 472 | + strata_to_shift = light_cone_strata + self._strata[above_floor_index:] |
| 473 | + time_index_shift = self._strata[op_floor_index].time_index - below_stratum.time_index + 1 |
| 474 | + for stratum in strata_to_shift: |
| 475 | + stratum.time_index += time_index_shift |
| 476 | + |
| 477 | + # Sort all strata by their time_index. |
| 478 | + self._strata[below_stratum_index:] = spectator_strata + strata_to_shift |
| 479 | + |
| 480 | + def _get_above_stratum( |
| 481 | + self, op: ops.Operation, op_class: int, op_floor: Optional[_Stratum] |
| 482 | + ) -> Optional[_Stratum]: |
| 483 | + """Get the lowest accomodating stratum above the op_floor, if there is any.""" |
| 484 | + start = self._stratum_index[op_floor] + 1 if op_floor is not None else 0 |
| 485 | + for stratum in self._strata[start:]: |
| 486 | + if stratum.class_index == op_class and stratum.qubits.isdisjoint(op.qubits): |
| 487 | + return stratum |
| 488 | + return None |
| 489 | + |
| 490 | + def _get_new_stratum(self, op: ops.Operation, op_class: int) -> _Stratum: |
| 491 | + """Add the given operation to a new stratum above all other strata. Return that stratum.""" |
| 492 | + op_time_index = self._strata[-1].time_index + 1 if self._strata else 0 |
| 493 | + op_stratum = _Stratum(op_time_index, op, op_class) |
| 494 | + self._strata.append(op_stratum) |
| 495 | + self._stratum_index[op_stratum] = len(self._strata) - 1 |
| 496 | + return op_stratum |
0 commit comments