Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache Circuit properties between mutations #6322

Merged
merged 6 commits into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 73 additions & 15 deletions cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,28 +188,20 @@ def _from_moments(cls: Type[CIRCUIT_TYPE], moments: Iterable['cirq.Moment']) ->
def moments(self) -> Sequence['cirq.Moment']:
pass

@abc.abstractmethod
def freeze(self) -> 'cirq.FrozenCircuit':
"""Creates a FrozenCircuit from this circuit.

If 'self' is a FrozenCircuit, the original object is returned.
"""
from cirq.circuits import FrozenCircuit

if isinstance(self, FrozenCircuit):
return self

return FrozenCircuit(self, strategy=InsertStrategy.EARLIEST)

@abc.abstractmethod
def unfreeze(self, copy: bool = True) -> 'cirq.Circuit':
"""Creates a Circuit from this circuit.

Args:
copy: If True and 'self' is a Circuit, returns a copy that circuit.
"""
if isinstance(self, Circuit):
return Circuit.copy(self) if copy else self

return Circuit(self, strategy=InsertStrategy.EARLIEST)

def __bool__(self):
return bool(self.moments)
Expand Down Expand Up @@ -822,6 +814,9 @@ def has_measurements(self):
"""
return protocols.is_measurement(self)

def _is_measurement_(self) -> bool:
return any(protocols.is_measurement(op) for op in self.all_operations())

def are_all_measurements_terminal(self) -> bool:
"""Whether all measurement gates are at the end of the circuit.

Expand Down Expand Up @@ -1383,8 +1378,7 @@ def save_qasm(
self._to_qasm_output(header, precision, qubit_order).save(file_path)

def _json_dict_(self):
ret = protocols.obj_to_dict_helper(self, ['moments'])
return ret
return protocols.obj_to_dict_helper(self, ['moments'])

@classmethod
def _from_json_dict_(cls, moments, **kwargs):
Expand Down Expand Up @@ -1759,6 +1753,16 @@ def __init__(
circuit.
"""
self._moments: List['cirq.Moment'] = []

# Implementation note: the following cached properties are set lazily and then
# invalidated and reset to None in `self._mutated()`, which is called any time
# `self._moments` is changed.
self._all_qubits: Optional[FrozenSet['cirq.Qid']] = None
self._frozen: Optional['cirq.FrozenCircuit'] = None
self._is_measurement: Optional[bool] = None
self._is_parameterized: Optional[bool] = None
self._parameter_names: Optional[AbstractSet[str]] = None

flattened_contents = tuple(ops.flatten_to_ops_or_moments(contents))
if all(isinstance(c, Moment) for c in flattened_contents):
self._moments[:] = cast(Iterable[Moment], flattened_contents)
Expand All @@ -1769,6 +1773,14 @@ def __init__(
else:
self.append(flattened_contents, strategy=strategy)

def _mutated(self) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this also include _is_measurement and _parameter_names?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. Fixed.

"""Clear cached properties in response to this circuit being mutated."""
self._all_qubits = None
self._frozen = None
self._is_measurement = None
self._is_parameterized = None
self._parameter_names = None

@classmethod
def _from_moments(cls, moments: Iterable['cirq.Moment']) -> 'Circuit':
new_circuit = Circuit()
Expand Down Expand Up @@ -1831,6 +1843,41 @@ def _load_contents_with_earliest_strategy(self, contents: 'cirq.OP_TREE'):
def __copy__(self) -> 'cirq.Circuit':
return self.copy()

def freeze(self) -> 'cirq.FrozenCircuit':
"""Gets a frozen version of this circuit.

Repeated calls to `.freeze()` will return the same FrozenCircuit
instance as long as this circuit is not mutated.
"""
from cirq.circuits.frozen_circuit import FrozenCircuit

if self._frozen is None:
self._frozen = FrozenCircuit.from_moments(*self._moments)
return self._frozen

def unfreeze(self, copy: bool = True) -> 'cirq.Circuit':
return self.copy() if copy else self

def all_qubits(self) -> FrozenSet['cirq.Qid']:
if self._all_qubits is None:
self._all_qubits = super().all_qubits()
return self._all_qubits

def _is_measurement_(self) -> bool:
if self._is_measurement is None:
self._is_measurement = super()._is_measurement_()
return self._is_measurement

def _is_parameterized_(self) -> bool:
if self._is_parameterized is None:
self._is_parameterized = super()._is_parameterized_()
return self._is_parameterized

def _parameter_names_(self) -> AbstractSet[str]:
if self._parameter_names is None:
self._parameter_names = super()._parameter_names_()
return self._parameter_names

def copy(self) -> 'Circuit':
"""Return a copy of this circuit."""
copied_circuit = Circuit()
Expand All @@ -1856,11 +1903,13 @@ def __setitem__(self, key, value):
raise TypeError('Can only assign Moments into Circuits.')

self._moments[key] = value
self._mutated()

# pylint: enable=function-redefined

def __delitem__(self, key: Union[int, slice]):
del self._moments[key]
self._mutated()

def __iadd__(self, other):
self.append(other)
Expand Down Expand Up @@ -1889,6 +1938,7 @@ def __imul__(self, repetitions: _INT_TYPE):
if not isinstance(repetitions, (int, np.integer)):
return NotImplemented
self._moments *= int(repetitions)
self._mutated()
return self

def __mul__(self, repetitions: _INT_TYPE):
Expand Down Expand Up @@ -2032,6 +2082,7 @@ def _pick_or_create_inserted_op_moment_index(

if strategy is InsertStrategy.NEW or strategy is InsertStrategy.NEW_THEN_INLINE:
self._moments.insert(splitter_index, Moment())
self._mutated()
return splitter_index

if strategy is InsertStrategy.INLINE:
Expand Down Expand Up @@ -2099,6 +2150,7 @@ def insert(
k = max(k, p + 1)
if strategy is InsertStrategy.NEW_THEN_INLINE:
strategy = InsertStrategy.INLINE
self._mutated()
return k

def insert_into_range(self, operations: 'cirq.OP_TREE', start: int, end: int) -> int:
Expand Down Expand Up @@ -2135,6 +2187,7 @@ def insert_into_range(self, operations: 'cirq.OP_TREE', start: int, end: int) ->

self._moments[i] = self._moments[i].with_operation(op)
op_index += 1
self._mutated()

if op_index >= len(flat_ops):
return end
Expand Down Expand Up @@ -2180,6 +2233,7 @@ def _push_frontier(
if n_new_moments > 0:
insert_index = min(late_frontier.values())
self._moments[insert_index:insert_index] = [Moment()] * n_new_moments
self._mutated()
for q in update_qubits:
if early_frontier.get(q, 0) > insert_index:
early_frontier[q] += n_new_moments
Expand All @@ -2206,13 +2260,12 @@ def _insert_operations(
if len(operations) != len(insertion_indices):
raise ValueError('operations and insertion_indices must have the same length.')
self._moments += [Moment() for _ in range(1 + max(insertion_indices) - len(self))]
self._mutated()
moment_to_ops: Dict[int, List['cirq.Operation']] = defaultdict(list)
for op_index, moment_index in enumerate(insertion_indices):
moment_to_ops[moment_index].append(operations[op_index])
for moment_index, new_ops in moment_to_ops.items():
self._moments[moment_index] = Moment(
self._moments[moment_index].operations + tuple(new_ops)
)
self._moments[moment_index] = self._moments[moment_index].with_operations(*new_ops)

def insert_at_frontier(
self,
Expand Down Expand Up @@ -2274,6 +2327,7 @@ def batch_remove(self, removals: Iterable[Tuple[int, 'cirq.Operation']]) -> None
old_op for old_op in copy._moments[i].operations if op != old_op
)
self._moments = copy._moments
self._mutated()

def batch_replace(
self, replacements: Iterable[Tuple[int, 'cirq.Operation', 'cirq.Operation']]
Expand All @@ -2298,6 +2352,7 @@ def batch_replace(
old_op if old_op != op else new_op for old_op in copy._moments[i].operations
)
self._moments = copy._moments
self._mutated()

def batch_insert_into(self, insert_intos: Iterable[Tuple[int, 'cirq.OP_TREE']]) -> None:
"""Inserts operations into empty spaces in existing moments.
Expand All @@ -2318,6 +2373,7 @@ def batch_insert_into(self, insert_intos: Iterable[Tuple[int, 'cirq.OP_TREE']])
for i, insertions in insert_intos:
copy._moments[i] = copy._moments[i].with_operations(insertions)
self._moments = copy._moments
self._mutated()

def batch_insert(self, insertions: Iterable[Tuple[int, 'cirq.OP_TREE']]) -> None:
"""Applies a batched insert operation to the circuit.
Expand Down Expand Up @@ -2352,6 +2408,7 @@ def batch_insert(self, insertions: Iterable[Tuple[int, 'cirq.OP_TREE']]) -> None
if next_index > insert_index:
shift += next_index - insert_index
self._moments = copy._moments
self._mutated()

def append(
self,
Expand Down Expand Up @@ -2382,6 +2439,7 @@ def clear_operations_touching(
for k in moment_indices:
if 0 <= k < len(self._moments):
self._moments[k] = self._moments[k].without_operations_touching(qubits)
self._mutated()

@property
def moments(self) -> Sequence['cirq.Moment']:
Expand Down
94 changes: 94 additions & 0 deletions cirq-core/cirq/circuits/circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4533,6 +4533,100 @@ def test_freeze_not_relocate_moments():
assert [mc is fc for mc, fc in zip(c, f)] == [True, True]


def test_freeze_is_cached():
q = cirq.q(0)
c = cirq.Circuit(cirq.X(q), cirq.measure(q))
f0 = c.freeze()
f1 = c.freeze()
assert f1 is f0

c.append(cirq.Y(q))
f2 = c.freeze()
f3 = c.freeze()
assert f2 is not f1
assert f3 is f2

c[-1] = cirq.Moment(cirq.Y(q))
f4 = c.freeze()
f5 = c.freeze()
assert f4 is not f3
assert f5 is f4


@pytest.mark.parametrize(
"circuit, mutate",
[
(
cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))),
lambda c: c.__setitem__(0, cirq.Moment(cirq.Y(cirq.q(0)))),
),
(cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.__delitem__(0)),
(cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.__imul__(2)),
(
cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))),
lambda c: c.insert(1, cirq.Y(cirq.q(0))),
),
(
cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))),
lambda c: c.insert_into_range([cirq.Y(cirq.q(1)), cirq.M(cirq.q(1))], 0, 2),
),
(
cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))),
lambda c: c.insert_at_frontier([cirq.Y(cirq.q(0)), cirq.Y(cirq.q(1))], 1),
),
(
cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))),
lambda c: c.batch_replace([(0, cirq.X(cirq.q(0)), cirq.Y(cirq.q(0)))]),
),
(
cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0), cirq.q(1))),
lambda c: c.batch_insert_into([(0, cirq.X(cirq.q(1)))]),
),
(
cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))),
lambda c: c.batch_insert([(1, cirq.Y(cirq.q(0)))]),
),
(
cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))),
lambda c: c.clear_operations_touching([cirq.q(0)], [0]),
),
],
)
def test_mutation_clears_cached_attributes(circuit, mutate):
cached_attributes = [
"_all_qubits",
"_frozen",
"_is_measurement",
"_is_parameterized",
"_parameter_names",
]

for attr in cached_attributes:
assert getattr(circuit, attr) is None, f"{attr=} is not None"

# Check that attributes are cached after getting them.
qubits = circuit.all_qubits()
frozen = circuit.freeze()
is_measurement = cirq.is_measurement(circuit)
is_parameterized = cirq.is_parameterized(circuit)
parameter_names = cirq.parameter_names(circuit)

for attr in cached_attributes:
assert getattr(circuit, attr) is not None, f"{attr=} is None"

# Check that getting again returns same object.
assert circuit.all_qubits() is qubits
assert circuit.freeze() is frozen
assert cirq.is_measurement(circuit) is is_measurement
assert cirq.is_parameterized(circuit) is is_parameterized
assert cirq.parameter_names(circuit) is parameter_names

# Check that attributes are cleared after mutation.
mutate(circuit)
for attr in cached_attributes:
assert getattr(circuit, attr) is None, f"{attr=} is not None"


def test_factorize_one_factor():
circuit = cirq.Circuit()
q0, q1, q2 = cirq.LineQubit.range(3)
Expand Down
6 changes: 6 additions & 0 deletions cirq-core/cirq/circuits/frozen_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ def _from_moments(cls, moments: Iterable['cirq.Moment']) -> 'FrozenCircuit':
def moments(self) -> Sequence['cirq.Moment']:
return self._moments

def freeze(self) -> 'cirq.FrozenCircuit':
return self

def unfreeze(self, copy: bool = True) -> 'cirq.Circuit':
return Circuit.from_moments(*self)

@property
def tags(self) -> Tuple[Hashable, ...]:
"""Returns a tuple of the Circuit's tags."""
Expand Down