Skip to content

Commit 4a0956b

Browse files
authored
Fail fast on measurements in has_unitary (quantumlib#5020)
* Fail fast on measurements in has_unitary * Unit test now also uses circuit * Implement has_unitary instead * Remove unused imports * Unit test that would blow up the memory
1 parent 627e12b commit 4a0956b

File tree

5 files changed

+40
-0
lines changed

5 files changed

+40
-0
lines changed

cirq/ops/measurement_gate.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,9 @@ def key(self, key: Union[str, 'cirq.MeasurementKey']):
8484
def _qid_shape_(self) -> Tuple[int, ...]:
8585
return self._qid_shape
8686

87+
def _has_unitary_(self) -> bool:
88+
return False
89+
8790
def with_key(self, key: Union[str, 'cirq.MeasurementKey']) -> 'MeasurementGate':
8891
"""Creates a measurement gate with a new key but otherwise identical."""
8992
if key == self.key:

cirq/ops/measurement_gate_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ def test_measure_init(num_qubits):
5959
cirq.MeasurementGate()
6060

6161

62+
def test_measurement_has_unitary_returns_false():
63+
gate = cirq.MeasurementGate(1, 'a')
64+
assert not cirq.has_unitary(gate)
65+
66+
6267
@pytest.mark.parametrize('num_qubits', [1, 2, 4])
6368
def test_has_stabilizer_effect(num_qubits):
6469
assert cirq.has_stabilizer_effect(cirq.MeasurementGate(num_qubits, 'a'))

cirq/ops/pauli_measurement_gate.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ def key(self, key: Union[str, 'cirq.MeasurementKey']) -> None:
8383
def _qid_shape_(self) -> Tuple[int, ...]:
8484
return (2,) * len(self._observable)
8585

86+
def _has_unitary_(self) -> bool:
87+
return False
88+
8689
def with_key(self, key: Union[str, 'cirq.MeasurementKey']) -> 'PauliMeasurementGate':
8790
"""Creates a pauli measurement gate with a new key but otherwise identical."""
8891
if key == self.key:

cirq/ops/pauli_measurement_gate_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ def test_init(observable, key):
4747
assert cirq.qid_shape(g) == (2,) * len(observable)
4848

4949

50+
def test_measurement_has_unitary_returns_false():
51+
gate = cirq.PauliMeasurementGate([cirq.X], 'a')
52+
assert not cirq.has_unitary(gate)
53+
54+
5055
def test_measurement_eq():
5156
eq = cirq.testing.EqualsTester()
5257
eq.make_equality_group(

cirq/protocols/has_unitary_protocol_test.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import numpy as np
16+
import pytest
1617

1718
import cirq
1819

@@ -26,6 +27,29 @@ class No:
2627
assert not cirq.has_unitary(No())
2728

2829

30+
@pytest.mark.parametrize(
31+
'measurement_gate',
32+
(
33+
cirq.MeasurementGate(1, 'a'),
34+
cirq.PauliMeasurementGate([cirq.X], 'a'),
35+
),
36+
)
37+
def test_fail_fast_measure(measurement_gate):
38+
assert not cirq.has_unitary(measurement_gate)
39+
40+
qubit = cirq.NamedQubit('q0')
41+
circuit = cirq.Circuit()
42+
circuit += measurement_gate(qubit)
43+
circuit += cirq.H(qubit)
44+
assert not cirq.has_unitary(circuit)
45+
46+
47+
def test_fail_fast_measure_large_memory():
48+
num_qubits = 100
49+
measurement_op = cirq.MeasurementGate(num_qubits, 'a').on(*cirq.LineQubit.range(num_qubits))
50+
assert not cirq.has_unitary(measurement_op)
51+
52+
2953
def test_via_unitary():
3054
class No1:
3155
def _unitary_(self):

0 commit comments

Comments
 (0)