Skip to content

Commit 3b315cd

Browse files
authored
Optimize is_measurement protocol (#4140)
* Optimize is_measurement protocol * Optimize is_measurement protocol * No shortcut in circuit * Change from recursive to iterative * add test and remove vals_to_decompose arg * FrozenCircuit uses evaluates fully for memoization * Coverage for tagged non-GateOperation
1 parent 8e7e302 commit 3b315cd

10 files changed

+154
-15
lines changed

cirq-core/cirq/circuits/circuit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -757,7 +757,7 @@ def findall_operations_with_gate_type(
757757
yield index, gate_op, cast(T_DESIRED_GATE_TYPE, gate_op.gate)
758758

759759
def has_measurements(self):
760-
return any(self.findall_operations(protocols.is_measurement))
760+
return protocols.is_measurement(self)
761761

762762
def are_all_measurements_terminal(self) -> bool:
763763
"""Whether all measurement gates are at the end of the circuit."""

cirq-core/cirq/circuits/circuit_operation.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,9 @@ def _default_repetition_ids(self) -> Optional[List[str]]:
160160
def _qid_shape_(self) -> Tuple[int, ...]:
161161
return tuple(q.dimension for q in self.qubits)
162162

163+
def _is_measurement_(self) -> bool:
164+
return self.circuit._is_measurement_()
165+
163166
def _measurement_keys_(self) -> AbstractSet[str]:
164167
circuit_keys = [
165168
value.MeasurementKey.parse_serialized(key_str)

cirq-core/cirq/circuits/circuit_operation_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,16 @@ def test_repetitions_and_ids_length_mismatch():
8181
_ = cirq.CircuitOperation(circuit, repetitions=2, repetition_ids=['a', 'b', 'c'])
8282

8383

84+
def test_is_measurement_memoization():
85+
a = cirq.LineQubit(0)
86+
circuit = cirq.FrozenCircuit(cirq.measure(a, key='m'))
87+
c_op = cirq.CircuitOperation(circuit)
88+
assert circuit._has_measurements is None
89+
# Memoize `_has_measurements` in the circuit.
90+
assert cirq.is_measurement(c_op)
91+
assert circuit._has_measurements is True
92+
93+
8494
def test_invalid_measurement_keys():
8595
a = cirq.LineQubit(0)
8696
circuit = cirq.FrozenCircuit(cirq.measure(a, key='m'))

cirq-core/cirq/circuits/frozen_circuit.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,11 @@ def _unitary_(self) -> Union[np.ndarray, NotImplementedType]:
109109
self._unitary = super()._unitary_()
110110
return self._unitary
111111

112+
def _is_measurement_(self) -> bool:
113+
if self._has_measurements is None:
114+
self._has_measurements = protocols.is_measurement(self.unfreeze())
115+
return self._has_measurements
116+
112117
def all_qubits(self) -> FrozenSet['cirq.Qid']:
113118
if self._all_qubits is None:
114119
self._all_qubits = super().all_qubits()

cirq-core/cirq/ops/gate_operation.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,13 @@ def _kraus_(self) -> Union[Tuple[np.ndarray], NotImplementedType]:
209209
return getter()
210210
return NotImplemented
211211

212+
def _is_measurement_(self) -> Optional[bool]:
213+
getter = getattr(self.gate, '_is_measurement_', None)
214+
if getter is not None:
215+
return getter()
216+
# Let the protocol handle the fallback.
217+
return NotImplemented
218+
212219
def _measurement_key_(self) -> Optional[str]:
213220
getter = getattr(self.gate, '_measurement_key_', None)
214221
if getter is not None:

cirq-core/cirq/ops/measurement_gate.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,9 @@ def full_invert_mask(self):
122122
mask += (False,) * deficit
123123
return mask
124124

125+
def _is_measurement_(self) -> bool:
126+
return True
127+
125128
def _measurement_key_(self):
126129
return self.key
127130

cirq-core/cirq/ops/raw_types.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -664,6 +664,12 @@ def _kraus_(self) -> Union[Tuple[np.ndarray], NotImplementedType]:
664664
def _measurement_key_(self) -> str:
665665
return protocols.measurement_key(self.sub_operation, NotImplemented)
666666

667+
def _is_measurement_(self) -> bool:
668+
sub = getattr(self.sub_operation, "_is_measurement_", None)
669+
if sub is not None:
670+
return sub()
671+
return NotImplemented
672+
667673
def _is_parameterized_(self) -> bool:
668674
return protocols.is_parameterized(self.sub_operation) or any(
669675
protocols.is_parameterized(tag) for tag in self.tags

cirq-core/cirq/ops/raw_types_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,14 +431,19 @@ def test_tagged_operation():
431431
assert op.gate == cirq.X
432432
assert op.with_qubits(q2) == cirq.X(q2).with_tags('tag1')
433433
assert op.with_qubits(q2).qubits == (q2,)
434+
assert not cirq.is_measurement(op)
434435

435436

436437
def test_tagged_measurement():
438+
assert not cirq.is_measurement(cirq.GlobalPhaseOperation(coefficient=-1.0).with_tags('tag0'))
439+
437440
a = cirq.LineQubit(0)
438441
op = cirq.measure(a, key='m').with_tags('tag')
442+
assert cirq.is_measurement(op)
439443

440444
remap_op = cirq.with_measurement_key_mapping(op, {'m': 'k'})
441445
assert remap_op.tags == ('tag',)
446+
assert cirq.is_measurement(remap_op)
442447
assert cirq.measurement_keys(remap_op) == {'k'}
443448
assert cirq.with_measurement_key_mapping(op, {'x': 'k'}) == op
444449

cirq-core/cirq/protocols/measurement_key_protocol.py

Lines changed: 81 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
"""Protocol for object that have measurement keys."""
1515

16-
from typing import AbstractSet, Any, Dict, Iterable, Tuple
16+
from typing import AbstractSet, Any, Dict, List, Iterable, Optional, Tuple
1717

1818
from typing_extensions import Protocol
1919

@@ -46,6 +46,10 @@ class SupportsMeasurementKey(Protocol):
4646
conditional on the measurement outcome being $k$.
4747
"""
4848

49+
@doc_private
50+
def _is_measurement_(self) -> str:
51+
"""Return if this object is (or contains) a measurement."""
52+
4953
@doc_private
5054
def _measurement_key_(self) -> str:
5155
"""Return the key that will be used to identify this measurement.
@@ -106,6 +110,22 @@ def measurement_key(val: Any, default: Any = RaiseTypeErrorIfNotProvided):
106110
raise TypeError(f"Object of type '{type(val)}' had no measurement keys.")
107111

108112

113+
def _measurement_keys_from_magic_methods(val: Any) -> Optional[AbstractSet[str]]:
114+
"""Uses the measurement key related magic methods to get the keys for this object."""
115+
116+
getter = getattr(val, '_measurement_keys_', None)
117+
result = NotImplemented if getter is None else getter()
118+
if result is not NotImplemented and result is not None:
119+
return set(result)
120+
121+
getter = getattr(val, '_measurement_key_', None)
122+
result = NotImplemented if getter is None else getter()
123+
if result is not NotImplemented and result is not None:
124+
return {result}
125+
126+
return result
127+
128+
109129
def measurement_keys(val: Any, *, allow_decompose: bool = True) -> AbstractSet[str]:
110130
"""Gets the measurement keys of measurements within the given value.
111131
@@ -122,15 +142,9 @@ def measurement_keys(val: Any, *, allow_decompose: bool = True) -> AbstractSet[s
122142
The measurement keys of the value. If the value has no measurement,
123143
the result is the empty tuple.
124144
"""
125-
getter = getattr(val, '_measurement_keys_', None)
126-
result = NotImplemented if getter is None else getter()
127-
if result is not NotImplemented and result is not None:
128-
return set(result)
129-
130-
getter = getattr(val, '_measurement_key_', None)
131-
result = NotImplemented if getter is None else getter()
145+
result = _measurement_keys_from_magic_methods(val)
132146
if result is not NotImplemented and result is not None:
133-
return {result}
147+
return result
134148

135149
if allow_decompose:
136150
operations, _, _ = _try_decompose_into_operations_and_qubits(val)
@@ -140,13 +154,66 @@ def measurement_keys(val: Any, *, allow_decompose: bool = True) -> AbstractSet[s
140154
return set()
141155

142156

143-
def is_measurement(val: Any) -> bool:
144-
"""Determines whether or not the given value is a measurement.
157+
def _is_measurement_from_magic_method(val: Any) -> Optional[bool]:
158+
"""Uses `is_measurement` magic method to determine if this object is a measurement."""
159+
getter = getattr(val, '_is_measurement_', None)
160+
return NotImplemented if getter is None else getter()
161+
145162

146-
Measurements are identified by the fact that `cirq.measurement_keys` returns
147-
a non-empty result for them.
163+
def _is_any_measurement(vals: List[Any], allow_decompose: bool) -> bool:
164+
"""Given a list of objects, returns True if any of them is a measurement.
165+
166+
If `allow_decompose` is True, decomposes the objects and runs the measurement checks on the
167+
constituent decomposed operations. But a decompose operation is only called if all cheaper
168+
checks are done. A BFS for searching measurements, where "depth" is each level of decompose.
169+
"""
170+
vals_to_decompose = [] # type: List[Any]
171+
while vals:
172+
val = vals.pop(0)
173+
result = _is_measurement_from_magic_method(val)
174+
if result is not NotImplemented:
175+
if result is True:
176+
return True
177+
if result is False:
178+
# Do not try any other strategies if `val` was explicitly marked as
179+
# "not measurement".
180+
continue
181+
182+
keys = _measurement_keys_from_magic_methods(val)
183+
if keys is not NotImplemented and bool(keys) is True:
184+
return True
185+
186+
if allow_decompose:
187+
vals_to_decompose.append(val)
188+
189+
# If vals has finished iterating over, keep decomposing from vals_to_decompose until vals
190+
# is populated with something.
191+
while not vals:
192+
if not vals_to_decompose:
193+
# Nothing left to process, this is not a measurement.
194+
return False
195+
operations, _, _ = _try_decompose_into_operations_and_qubits(vals_to_decompose.pop(0))
196+
if operations:
197+
# Reverse the decomposed operations because measurements are typically at later
198+
# moments.
199+
vals = operations[::-1]
200+
201+
return False
202+
203+
204+
def is_measurement(val: Any, allow_decompose: bool = True) -> bool:
205+
"""Determines whether or not the given value is a measurement (or contains one).
206+
207+
Measurements are identified by the fact that any of them may have an `_is_measurement_` method
208+
or `cirq.measurement_keys` returns a non-empty result for them.
209+
210+
Args:
211+
val: The value which to evaluate.
212+
allow_decompose: Defaults to True. When true, composite operations that
213+
don't directly specify their `_is_measurement_` property will be decomposed in
214+
order to find any measurements keys within the decomposed operations.
148215
"""
149-
return bool(measurement_keys(val))
216+
return _is_any_measurement([val], allow_decompose)
150217

151218

152219
def with_measurement_key_mapping(val: Any, key_map: Dict[str, str]):

cirq-core/cirq/protocols/measurement_key_protocol_test.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class ReturnsStr:
2222
def _measurement_key_(self):
2323
return 'door locker'
2424

25+
assert cirq.is_measurement(ReturnsStr())
2526
assert cirq.measurement_key(ReturnsStr()) == 'door locker'
2627

2728
assert cirq.measurement_key(ReturnsStr(), None) == 'door locker'
@@ -84,6 +85,36 @@ def qubits(self):
8485
assert not cirq.is_measurement(NotImplementedOperation())
8586

8687

88+
def test_measurement_without_key():
89+
class MeasurementWithoutKey:
90+
def _is_measurement_(self):
91+
return True
92+
93+
with pytest.raises(TypeError, match='no measurement keys'):
94+
_ = cirq.measurement_key(MeasurementWithoutKey())
95+
96+
assert cirq.is_measurement(MeasurementWithoutKey())
97+
98+
99+
def test_non_measurement_with_key():
100+
class NonMeasurementGate(cirq.Gate):
101+
def _is_measurement_(self):
102+
return False
103+
104+
def _decompose_(self, qubits):
105+
# Decompose should not be called by `is_measurement`
106+
assert False
107+
108+
def _measurement_key_(self):
109+
# `measurement_key`` should not be called by `is_measurement`
110+
assert False
111+
112+
def num_qubits(self) -> int:
113+
return 2 # coverage: ignore
114+
115+
assert not cirq.is_measurement(NonMeasurementGate())
116+
117+
87118
def test_measurement_keys():
88119
class Composite(cirq.Gate):
89120
def _decompose_(self, qubits):
@@ -102,8 +133,10 @@ def num_qubits(self) -> int:
102133
return 1
103134

104135
a, b = cirq.LineQubit.range(2)
136+
assert cirq.is_measurement(Composite())
105137
assert cirq.measurement_keys(Composite()) == {'inner1', 'inner2'}
106138
assert cirq.measurement_keys(Composite().on(a, b)) == {'inner1', 'inner2'}
139+
assert not cirq.is_measurement(Composite(), allow_decompose=False)
107140
assert cirq.measurement_keys(Composite(), allow_decompose=False) == set()
108141
assert cirq.measurement_keys(Composite().on(a, b), allow_decompose=False) == set()
109142

0 commit comments

Comments
 (0)