Skip to content

Commit 409a412

Browse files
authored
Disallow empty measurement keys and fix tests using empty keys (#4060)
* Empty keys are illegal now * merge * fix duplicate post-inits
1 parent 9c23053 commit 409a412

File tree

9 files changed

+43
-40
lines changed

9 files changed

+43
-40
lines changed

cirq-core/cirq/ion/ion_device_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def test_validate_measurement_non_adjacent_qubits_ok():
106106
d = ion_device(3)
107107

108108
d.validate_operation(
109-
cirq.GateOperation(cirq.MeasurementGate(2), (cirq.LineQubit(0), cirq.LineQubit(1)))
109+
cirq.GateOperation(cirq.MeasurementGate(2, 'key'), (cirq.LineQubit(0), cirq.LineQubit(1)))
110110
)
111111

112112

cirq-core/cirq/neutral_atoms/neutral_atom_devices_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def test_validate_moment_errors():
179179
m = cirq.Moment(cirq.X.on_each(*(d.qubit_list()[1:])))
180180
with pytest.raises(ValueError, match="Bad number of simultaneous XY gates"):
181181
d.validate_moment(m)
182-
m = cirq.Moment([cirq.MeasurementGate(1).on(q00), cirq.Z.on(q01)])
182+
m = cirq.Moment([cirq.MeasurementGate(1, 'a').on(q00), cirq.Z.on(q01)])
183183
with pytest.raises(
184184
ValueError, match="Measurements can't be simultaneous with other operations"
185185
):

cirq-core/cirq/ops/measurement_gate_test.py

Lines changed: 27 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -29,43 +29,43 @@ def test_eval_repr():
2929

3030
@pytest.mark.parametrize('num_qubits', [1, 2, 4])
3131
def test_measure_init(num_qubits):
32-
assert cirq.MeasurementGate(num_qubits).num_qubits() == num_qubits
32+
assert cirq.MeasurementGate(num_qubits, 'a').num_qubits() == num_qubits
3333
assert cirq.MeasurementGate(num_qubits, key='a').key == 'a'
3434
assert cirq.MeasurementGate(num_qubits, key='a').mkey == cirq.MeasurementKey('a')
3535
assert cirq.MeasurementGate(num_qubits, key=cirq.MeasurementKey('a')).key == 'a'
3636
assert cirq.MeasurementGate(num_qubits, key=cirq.MeasurementKey('a')) == cirq.MeasurementGate(
3737
num_qubits, key='a'
3838
)
39-
assert cirq.MeasurementGate(num_qubits, invert_mask=(True,)).invert_mask == (True,)
40-
assert cirq.qid_shape(cirq.MeasurementGate(num_qubits)) == (2,) * num_qubits
41-
assert cirq.qid_shape(cirq.MeasurementGate(3, qid_shape=(1, 2, 3))) == (1, 2, 3)
42-
assert cirq.qid_shape(cirq.MeasurementGate(qid_shape=(1, 2, 3))) == (1, 2, 3)
39+
assert cirq.MeasurementGate(num_qubits, 'a', invert_mask=(True,)).invert_mask == (True,)
40+
assert cirq.qid_shape(cirq.MeasurementGate(num_qubits, 'a')) == (2,) * num_qubits
41+
assert cirq.qid_shape(cirq.MeasurementGate(3, 'a', qid_shape=(1, 2, 3))) == (1, 2, 3)
42+
assert cirq.qid_shape(cirq.MeasurementGate(key='a', qid_shape=(1, 2, 3))) == (1, 2, 3)
4343
with pytest.raises(ValueError, match='len.* >'):
44-
cirq.MeasurementGate(5, invert_mask=(True,) * 6)
44+
cirq.MeasurementGate(5, 'a', invert_mask=(True,) * 6)
4545
with pytest.raises(ValueError, match='len.* !='):
46-
cirq.MeasurementGate(5, qid_shape=(1, 2))
46+
cirq.MeasurementGate(5, 'a', qid_shape=(1, 2))
47+
with pytest.raises(ValueError, match='cannot be empty'):
48+
cirq.MeasurementGate(2, qid_shape=(1, 2))
4749
with pytest.raises(ValueError, match='Specify either'):
4850
cirq.MeasurementGate()
4951

5052

5153
@pytest.mark.parametrize('num_qubits', [1, 2, 4])
5254
def test_has_stabilizer_effect(num_qubits):
53-
assert cirq.has_stabilizer_effect(cirq.MeasurementGate(num_qubits))
55+
assert cirq.has_stabilizer_effect(cirq.MeasurementGate(num_qubits, 'a'))
5456

5557

5658
def test_measurement_eq():
5759
eq = cirq.testing.EqualsTester()
5860
eq.make_equality_group(
59-
lambda: cirq.MeasurementGate(1, ''),
60-
lambda: cirq.MeasurementGate(1, '', invert_mask=()),
61-
lambda: cirq.MeasurementGate(1, '', qid_shape=(2,)),
61+
lambda: cirq.MeasurementGate(1, 'a'),
62+
lambda: cirq.MeasurementGate(1, 'a', invert_mask=()),
63+
lambda: cirq.MeasurementGate(1, 'a', qid_shape=(2,)),
6264
)
63-
eq.add_equality_group(cirq.MeasurementGate(1, 'a'))
6465
eq.add_equality_group(cirq.MeasurementGate(1, 'a', invert_mask=(True,)))
6566
eq.add_equality_group(cirq.MeasurementGate(1, 'a', invert_mask=(False,)))
6667
eq.add_equality_group(cirq.MeasurementGate(1, 'b'))
6768
eq.add_equality_group(cirq.MeasurementGate(2, 'a'))
68-
eq.add_equality_group(cirq.MeasurementGate(2, ''))
6969
eq.add_equality_group(
7070
cirq.MeasurementGate(3, 'a'), cirq.MeasurementGate(3, 'a', qid_shape=(2, 2, 2))
7171
)
@@ -154,15 +154,14 @@ def test_qudit_measure_quil():
154154

155155
def test_measurement_gate_diagram():
156156
# Shows key.
157-
assert cirq.circuit_diagram_info(cirq.MeasurementGate(1)) == cirq.CircuitDiagramInfo(("M('')",))
158157
assert cirq.circuit_diagram_info(
159158
cirq.MeasurementGate(1, key='test')
160159
) == cirq.CircuitDiagramInfo(("M('test')",))
161160

162161
# Uses known qubit count.
163162
assert (
164163
cirq.circuit_diagram_info(
165-
cirq.MeasurementGate(3),
164+
cirq.MeasurementGate(3, 'a'),
166165
cirq.CircuitDiagramInfoArgs(
167166
known_qubits=None,
168167
known_qubit_count=3,
@@ -171,13 +170,13 @@ def test_measurement_gate_diagram():
171170
qubit_map=None,
172171
),
173172
)
174-
== cirq.CircuitDiagramInfo(("M('')", 'M', 'M'))
173+
== cirq.CircuitDiagramInfo(("M('a')", 'M', 'M'))
175174
)
176175

177176
# Shows invert mask.
178177
assert cirq.circuit_diagram_info(
179-
cirq.MeasurementGate(2, invert_mask=(False, True))
180-
) == cirq.CircuitDiagramInfo(("M('')", "!M"))
178+
cirq.MeasurementGate(2, 'a', invert_mask=(False, True))
179+
) == cirq.CircuitDiagramInfo(("M('a')", "!M"))
181180

182181
# Omits key when it is the default.
183182
a = cirq.NamedQubit('a')
@@ -210,12 +209,12 @@ def test_measurement_gate_diagram():
210209

211210
def test_measurement_channel():
212211
np.testing.assert_allclose(
213-
cirq.kraus(cirq.MeasurementGate(1)),
212+
cirq.kraus(cirq.MeasurementGate(1, 'a')),
214213
(np.array([[1, 0], [0, 0]]), np.array([[0, 0], [0, 1]])),
215214
)
216215
# yapf: disable
217216
np.testing.assert_allclose(
218-
cirq.kraus(cirq.MeasurementGate(2)),
217+
cirq.kraus(cirq.MeasurementGate(2, 'a')),
219218
(np.array([[1, 0, 0, 0],
220219
[0, 0, 0, 0],
221220
[0, 0, 0, 0],
@@ -233,7 +232,7 @@ def test_measurement_channel():
233232
[0, 0, 0, 0],
234233
[0, 0, 0, 1]])))
235234
np.testing.assert_allclose(
236-
cirq.kraus(cirq.MeasurementGate(2, qid_shape=(2, 3))),
235+
cirq.kraus(cirq.MeasurementGate(2, 'a', qid_shape=(2, 3))),
237236
(np.diag([1, 0, 0, 0, 0, 0]),
238237
np.diag([0, 1, 0, 0, 0, 0]),
239238
np.diag([0, 0, 1, 0, 0, 0]),
@@ -248,21 +247,21 @@ def test_measurement_qubit_count_vs_mask_length():
248247
b = cirq.NamedQubit('b')
249248
c = cirq.NamedQubit('c')
250249

251-
_ = cirq.MeasurementGate(num_qubits=1, invert_mask=(True,)).on(a)
252-
_ = cirq.MeasurementGate(num_qubits=2, invert_mask=(True, False)).on(a, b)
253-
_ = cirq.MeasurementGate(num_qubits=3, invert_mask=(True, False, True)).on(a, b, c)
250+
_ = cirq.MeasurementGate(num_qubits=1, key='a', invert_mask=(True,)).on(a)
251+
_ = cirq.MeasurementGate(num_qubits=2, key='a', invert_mask=(True, False)).on(a, b)
252+
_ = cirq.MeasurementGate(num_qubits=3, key='a', invert_mask=(True, False, True)).on(a, b, c)
254253
with pytest.raises(ValueError):
255-
_ = cirq.MeasurementGate(num_qubits=1, invert_mask=(True, False)).on(a)
254+
_ = cirq.MeasurementGate(num_qubits=1, key='a', invert_mask=(True, False)).on(a)
256255
with pytest.raises(ValueError):
257-
_ = cirq.MeasurementGate(num_qubits=3, invert_mask=(True, False, True)).on(a, b)
256+
_ = cirq.MeasurementGate(num_qubits=3, key='a', invert_mask=(True, False, True)).on(a, b)
258257

259258

260259
def test_consistent_protocols():
261260
for n in range(1, 5):
262-
gate = cirq.MeasurementGate(num_qubits=n)
261+
gate = cirq.MeasurementGate(num_qubits=n, key='a')
263262
cirq.testing.assert_implements_consistent_protocols(gate)
264263

265-
gate = cirq.MeasurementGate(num_qubits=n, qid_shape=(3,) * n)
264+
gate = cirq.MeasurementGate(num_qubits=n, key='a', qid_shape=(3,) * n)
266265
cirq.testing.assert_implements_consistent_protocols(gate)
267266

268267

cirq-core/cirq/sim/sparse_simulator_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1226,7 +1226,7 @@ def test_separated_measurements():
12261226
cirq.H(a),
12271227
cirq.H(b),
12281228
cirq.CZ(a, b),
1229-
cirq.measure(a, key=''),
1229+
cirq.measure(a, key='a'),
12301230
cirq.CZ(a, b),
12311231
cirq.H(b),
12321232
cirq.measure(b, key='zero'),

cirq-core/cirq/value/measurement_key.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ class MeasurementKey:
4141
path: Tuple[str, ...] = dataclasses.field(default_factory=tuple)
4242

4343
def __post_init__(self):
44+
if not self.name:
45+
raise ValueError("Measurement key name cannot be empty")
4446
if MEASUREMENT_KEY_SEPARATOR in self.name:
4547
raise ValueError(
4648
f'Invalid key name: {self.name}\n{MEASUREMENT_KEY_SEPARATOR} is not allowed in '

cirq-core/cirq/value/measurement_key_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
def test_empty_init():
2121
with pytest.raises(TypeError, match='required positional argument'):
2222
_ = cirq.MeasurementKey()
23-
mkey = cirq.MeasurementKey('')
24-
assert mkey.name == ''
23+
with pytest.raises(ValueError, match='cannot be empty'):
24+
_ = cirq.MeasurementKey('')
2525

2626

2727
def test_nested_key():

cirq-google/cirq_google/devices/xmon_device_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414

1515
import pytest
1616

17-
import cirq
1817
import cirq_google as cg
18+
import cirq
1919

2020

2121
def square_device(width: int, height: int, holes=()) -> cg.XmonDevice:
@@ -133,7 +133,9 @@ def test_validate_measurement_non_adjacent_qubits_ok():
133133
d = square_device(3, 3)
134134

135135
d.validate_operation(
136-
cirq.GateOperation(cirq.MeasurementGate(2), (cirq.GridQubit(0, 0), cirq.GridQubit(2, 0)))
136+
cirq.GateOperation(
137+
cirq.MeasurementGate(2, 'a'), (cirq.GridQubit(0, 0), cirq.GridQubit(2, 0))
138+
)
137139
)
138140

139141

cirq-ionq/cirq_ionq/ionq_devices_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@
3838
cirq.YY ** 0.5,
3939
cirq.ZZ ** 0.5,
4040
cirq.SWAP,
41-
cirq.MeasurementGate(num_qubits=1),
42-
cirq.MeasurementGate(num_qubits=2),
43-
cirq.MeasurementGate(num_qubits=10),
41+
cirq.MeasurementGate(num_qubits=1, key='a'),
42+
cirq.MeasurementGate(num_qubits=2, key='b'),
43+
cirq.MeasurementGate(num_qubits=10, key='c'),
4444
)
4545

4646

cirq-pasqal/cirq_pasqal/pasqal_device_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,11 @@ def test_decompose_error():
8888

8989
# MeasurementGate is not a GateOperation
9090
with pytest.raises(TypeError):
91-
d.decompose_operation(cirq.ops.MeasurementGate(num_qubits=2))
91+
d.decompose_operation(cirq.ops.MeasurementGate(num_qubits=2, key='a'))
9292
# It has to be made into one
9393
assert d.is_pasqal_device_op(
9494
cirq.ops.GateOperation(
95-
cirq.ops.MeasurementGate(2), [cirq.NamedQubit('q0'), cirq.NamedQubit('q1')]
95+
cirq.ops.MeasurementGate(2, 'b'), [cirq.NamedQubit('q0'), cirq.NamedQubit('q1')]
9696
)
9797
)
9898

0 commit comments

Comments
 (0)