Skip to content

Commit 6a0c381

Browse files
authored
Dedicated method for creating circuit from op tree with EARLIEST strategy (quantumlib#5332)
We keep track of the latest moment that contains each qubit and measurement key. Then we know where to put each operation in constant time. We don't create the actual Moments until the very end, when we know where everything goes. Also added explicit key protocol impls for EigenGate, preempting the protocol from attempting a bunch of fallback options. On my laptop this speeds up creating circuits with EARLIEST strategy by almost infinity percent. (On my laptop, the circuit in the new test goes from 29.00s on master to 0.13s here).
1 parent 09ae9f3 commit 6a0c381

File tree

4 files changed

+112
-10
lines changed

4 files changed

+112
-10
lines changed

cirq/circuits/circuit.py

+92-9
Original file line numberDiff line numberDiff line change
@@ -1709,7 +1709,95 @@ def __init__(
17091709
"""
17101710
self._moments: List['cirq.Moment'] = []
17111711
with _compat.block_overlapping_deprecation('.*'):
1712-
self.append(contents, strategy=strategy)
1712+
if strategy == InsertStrategy.EARLIEST:
1713+
self._load_contents_with_earliest_strategy(contents)
1714+
else:
1715+
self.append(contents, strategy=strategy)
1716+
1717+
def _load_contents_with_earliest_strategy(self, contents: 'cirq.OP_TREE'):
1718+
"""Optimized algorithm to load contents quickly.
1719+
1720+
The default algorithm appends operations one-at-a-time, letting them
1721+
fall back until they encounter a moment they cannot commute with. This
1722+
is slow because it requires re-checking for conflicts at each moment.
1723+
1724+
Here, we instead keep track of the greatest moment that contains each
1725+
qubit, measurement key, and control key, and append the operation to
1726+
the moment after the maximum of these. This avoids having to check each
1727+
moment.
1728+
1729+
Args:
1730+
contents: The initial list of moments and operations defining the
1731+
circuit. You can also pass in operations, lists of operations,
1732+
or generally anything meeting the `cirq.OP_TREE` contract.
1733+
Non-moment entries will be inserted according to the EARLIEST
1734+
insertion strategy.
1735+
"""
1736+
# These are dicts from the qubit/key to the greatest moment index that has it. It is safe
1737+
# to default to `-1`, as that is interpreted as meaning the zeroth index onward does not
1738+
# have this value.
1739+
qubit_indexes: Dict['cirq.Qid', int] = defaultdict(lambda: -1)
1740+
mkey_indexes: Dict['cirq.MeasurementKey', int] = defaultdict(lambda: -1)
1741+
ckey_indexes: Dict['cirq.MeasurementKey', int] = defaultdict(lambda: -1)
1742+
1743+
# We also maintain the dict from moment index to moments/ops that go into it, for use when
1744+
# building the actual moments at the end.
1745+
op_lists_by_index: Dict[int, List['cirq.Operation']] = defaultdict(list)
1746+
moments_by_index: Dict[int, 'cirq.Moment'] = {}
1747+
1748+
# For keeping track of length of the circuit thus far.
1749+
length = 0
1750+
1751+
# "mop" means current moment-or-operation
1752+
for mop in ops.flatten_to_ops_or_moments(contents):
1753+
mop_qubits = mop.qubits
1754+
mop_mkeys = protocols.measurement_key_objs(mop)
1755+
mop_ckeys = protocols.control_keys(mop)
1756+
1757+
# Both branches define `i`, the moment index at which to place the mop.
1758+
if isinstance(mop, Moment):
1759+
# We always append moment to the end, to be consistent with `self.append`
1760+
i = length
1761+
moments_by_index[i] = mop
1762+
else:
1763+
# Initially we define `i` as the greatest moment index that has a conflict. `-1` is
1764+
# the initial conflict, and we search for larger ones. Once we get the largest one,
1765+
# we increment i by 1 to set the placement index.
1766+
i = -1
1767+
1768+
# Look for the maximum conflict; i.e. a moment that has a qubit the same as one of
1769+
# this op's qubits, that has a measurement or control key the same as one of this
1770+
# op's measurement keys, or that has a measurement key the same as one of this op's
1771+
# control keys. (Control keys alone can commute past each other). The `ifs` are
1772+
# logically unnecessary but seem to make this slightly faster.
1773+
if mop_qubits:
1774+
i = max(i, *[qubit_indexes[q] for q in mop_qubits])
1775+
if mop_mkeys:
1776+
i = max(i, *[mkey_indexes[k] for k in mop_mkeys])
1777+
i = max(i, *[ckey_indexes[k] for k in mop_mkeys])
1778+
if mop_ckeys:
1779+
i = max(i, *[mkey_indexes[k] for k in mop_ckeys])
1780+
i += 1
1781+
op_lists_by_index[i].append(mop)
1782+
1783+
# Update our dicts with data from the latest mop placement. Note `i` will always be
1784+
# greater than the existing value for all of these, by construction, so there is no
1785+
# need to do a `max(i, existing)`.
1786+
for q in mop_qubits:
1787+
qubit_indexes[q] = i
1788+
for k in mop_mkeys:
1789+
mkey_indexes[k] = i
1790+
for k in mop_ckeys:
1791+
ckey_indexes[k] = i
1792+
length = max(length, i + 1)
1793+
1794+
# Finally, once everything is placed, we can construct and append the actual moments for
1795+
# each index.
1796+
for i in range(length):
1797+
if i in moments_by_index:
1798+
self._moments.append(moments_by_index[i].with_operations(op_lists_by_index[i]))
1799+
else:
1800+
self._moments.append(Moment(op_lists_by_index[i]))
17131801

17141802
def __copy__(self) -> 'cirq.Circuit':
17151803
return self.copy()
@@ -1888,11 +1976,11 @@ def earliest_available_moment(
18881976
moment = self._moments[k]
18891977
if moment.operates_on(op_qubits):
18901978
return last_available
1891-
moment_measurement_keys = protocols.measurement_key_objs(moment)
1979+
moment_measurement_keys = moment._measurement_key_objs_()
18921980
if (
18931981
not op_measurement_keys.isdisjoint(moment_measurement_keys)
18941982
or not op_control_keys.isdisjoint(moment_measurement_keys)
1895-
or not protocols.control_keys(moment).isdisjoint(op_measurement_keys)
1983+
or not moment._control_keys_().isdisjoint(op_measurement_keys)
18961984
):
18971985
return last_available
18981986
if self._can_add_op_at(k, op):
@@ -1973,14 +2061,9 @@ def insert(
19732061
Raises:
19742062
ValueError: Bad insertion strategy.
19752063
"""
1976-
moments_and_operations = list(
1977-
ops.flatten_to_ops_or_moments(
1978-
ops.transform_op_tree(moment_or_operation_tree, preserve_moments=True)
1979-
)
1980-
)
19812064
# limit index to 0..len(self._moments), also deal with indices smaller 0
19822065
k = max(min(index if index >= 0 else len(self._moments) + index, len(self._moments)), 0)
1983-
for moment_or_op in moments_and_operations:
2066+
for moment_or_op in ops.flatten_to_ops_or_moments(moment_or_operation_tree):
19842067
if isinstance(moment_or_op, Moment):
19852068
self._moments.insert(k, moment_or_op)
19862069
k += 1

cirq/circuits/circuit_test.py

+16
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
import itertools
1515
import os
16+
import time
1617
from collections import defaultdict
1718
from random import randint, random, sample, randrange
1819
from typing import Iterator, Optional, Tuple, TYPE_CHECKING
@@ -4644,3 +4645,18 @@ def _circuit_diagram_info_(self, args) -> str:
46444645
└────────┘
46454646
""",
46464647
)
4648+
4649+
4650+
def test_create_speed():
4651+
# Added in https://github.com/quantumlib/Cirq/pull/5332
4652+
# Previously this took ~30s to run. Now it should take ~150ms. However the coverage test can
4653+
# run this slowly, so allowing 2 sec to account for things like that. Feel free to increase the
4654+
# buffer time or delete the test entirely if it ends up causing flakes.
4655+
qs = 100
4656+
moments = 500
4657+
xs = [cirq.X(cirq.LineQubit(i)) for i in range(qs)]
4658+
opa = [xs[i] for i in range(qs) for _ in range(moments)]
4659+
t = time.perf_counter()
4660+
c = cirq.Circuit(opa)
4661+
assert len(c) == moments
4662+
assert time.perf_counter() - t < 2

cirq/circuits/moment.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def operates_on(self, qubits: Iterable['cirq.Qid']) -> bool:
132132
Returns:
133133
Whether this moment has operations involving the qubits.
134134
"""
135-
return bool(set(qubits) & self.qubits)
135+
return not self._qubits.isdisjoint(qubits)
136136

137137
def operation_at(self, qubit: raw_types.Qid) -> Optional['cirq.Operation']:
138138
"""Returns the operation on a certain qubit for the moment.

cirq/ops/eigen_gate.py

+3
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,9 @@ def _equal_up_to_global_phase_(self, other, atol):
393393
def _json_dict_(self) -> Dict[str, Any]:
394394
return protocols.obj_to_dict_helper(self, ['exponent', 'global_shift'])
395395

396+
def _measurement_key_objs_(self):
397+
return frozenset()
398+
396399

397400
def _lcm(vals: Iterable[int]) -> int:
398401
t = 1

0 commit comments

Comments
 (0)