Skip to content

Commit 0f11adc

Browse files
authored
Disallow empty measurement keys and fix tests using empty keys (quantumlib#4060)
* Empty keys are illegal now * merge * fix duplicate post-inits
1 parent ab0dc94 commit 0f11adc

6 files changed

+34
-33
lines changed

cirq/ion/ion_device_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def test_validate_measurement_non_adjacent_qubits_ok():
106106
d = ion_device(3)
107107

108108
d.validate_operation(
109-
cirq.GateOperation(cirq.MeasurementGate(2), (cirq.LineQubit(0), cirq.LineQubit(1)))
109+
cirq.GateOperation(cirq.MeasurementGate(2, 'key'), (cirq.LineQubit(0), cirq.LineQubit(1)))
110110
)
111111

112112

cirq/neutral_atoms/neutral_atom_devices_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def test_validate_moment_errors():
179179
m = cirq.Moment(cirq.X.on_each(*(d.qubit_list()[1:])))
180180
with pytest.raises(ValueError, match="Bad number of simultaneous XY gates"):
181181
d.validate_moment(m)
182-
m = cirq.Moment([cirq.MeasurementGate(1).on(q00), cirq.Z.on(q01)])
182+
m = cirq.Moment([cirq.MeasurementGate(1, 'a').on(q00), cirq.Z.on(q01)])
183183
with pytest.raises(
184184
ValueError, match="Measurements can't be simultaneous with other operations"
185185
):

cirq/ops/measurement_gate_test.py

Lines changed: 27 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -29,43 +29,43 @@ def test_eval_repr():
2929

3030
@pytest.mark.parametrize('num_qubits', [1, 2, 4])
3131
def test_measure_init(num_qubits):
32-
assert cirq.MeasurementGate(num_qubits).num_qubits() == num_qubits
32+
assert cirq.MeasurementGate(num_qubits, 'a').num_qubits() == num_qubits
3333
assert cirq.MeasurementGate(num_qubits, key='a').key == 'a'
3434
assert cirq.MeasurementGate(num_qubits, key='a').mkey == cirq.MeasurementKey('a')
3535
assert cirq.MeasurementGate(num_qubits, key=cirq.MeasurementKey('a')).key == 'a'
3636
assert cirq.MeasurementGate(num_qubits, key=cirq.MeasurementKey('a')) == cirq.MeasurementGate(
3737
num_qubits, key='a'
3838
)
39-
assert cirq.MeasurementGate(num_qubits, invert_mask=(True,)).invert_mask == (True,)
40-
assert cirq.qid_shape(cirq.MeasurementGate(num_qubits)) == (2,) * num_qubits
41-
assert cirq.qid_shape(cirq.MeasurementGate(3, qid_shape=(1, 2, 3))) == (1, 2, 3)
42-
assert cirq.qid_shape(cirq.MeasurementGate(qid_shape=(1, 2, 3))) == (1, 2, 3)
39+
assert cirq.MeasurementGate(num_qubits, 'a', invert_mask=(True,)).invert_mask == (True,)
40+
assert cirq.qid_shape(cirq.MeasurementGate(num_qubits, 'a')) == (2,) * num_qubits
41+
assert cirq.qid_shape(cirq.MeasurementGate(3, 'a', qid_shape=(1, 2, 3))) == (1, 2, 3)
42+
assert cirq.qid_shape(cirq.MeasurementGate(key='a', qid_shape=(1, 2, 3))) == (1, 2, 3)
4343
with pytest.raises(ValueError, match='len.* >'):
44-
cirq.MeasurementGate(5, invert_mask=(True,) * 6)
44+
cirq.MeasurementGate(5, 'a', invert_mask=(True,) * 6)
4545
with pytest.raises(ValueError, match='len.* !='):
46-
cirq.MeasurementGate(5, qid_shape=(1, 2))
46+
cirq.MeasurementGate(5, 'a', qid_shape=(1, 2))
47+
with pytest.raises(ValueError, match='cannot be empty'):
48+
cirq.MeasurementGate(2, qid_shape=(1, 2))
4749
with pytest.raises(ValueError, match='Specify either'):
4850
cirq.MeasurementGate()
4951

5052

5153
@pytest.mark.parametrize('num_qubits', [1, 2, 4])
5254
def test_has_stabilizer_effect(num_qubits):
53-
assert cirq.has_stabilizer_effect(cirq.MeasurementGate(num_qubits))
55+
assert cirq.has_stabilizer_effect(cirq.MeasurementGate(num_qubits, 'a'))
5456

5557

5658
def test_measurement_eq():
5759
eq = cirq.testing.EqualsTester()
5860
eq.make_equality_group(
59-
lambda: cirq.MeasurementGate(1, ''),
60-
lambda: cirq.MeasurementGate(1, '', invert_mask=()),
61-
lambda: cirq.MeasurementGate(1, '', qid_shape=(2,)),
61+
lambda: cirq.MeasurementGate(1, 'a'),
62+
lambda: cirq.MeasurementGate(1, 'a', invert_mask=()),
63+
lambda: cirq.MeasurementGate(1, 'a', qid_shape=(2,)),
6264
)
63-
eq.add_equality_group(cirq.MeasurementGate(1, 'a'))
6465
eq.add_equality_group(cirq.MeasurementGate(1, 'a', invert_mask=(True,)))
6566
eq.add_equality_group(cirq.MeasurementGate(1, 'a', invert_mask=(False,)))
6667
eq.add_equality_group(cirq.MeasurementGate(1, 'b'))
6768
eq.add_equality_group(cirq.MeasurementGate(2, 'a'))
68-
eq.add_equality_group(cirq.MeasurementGate(2, ''))
6969
eq.add_equality_group(
7070
cirq.MeasurementGate(3, 'a'), cirq.MeasurementGate(3, 'a', qid_shape=(2, 2, 2))
7171
)
@@ -154,15 +154,14 @@ def test_qudit_measure_quil():
154154

155155
def test_measurement_gate_diagram():
156156
# Shows key.
157-
assert cirq.circuit_diagram_info(cirq.MeasurementGate(1)) == cirq.CircuitDiagramInfo(("M('')",))
158157
assert cirq.circuit_diagram_info(
159158
cirq.MeasurementGate(1, key='test')
160159
) == cirq.CircuitDiagramInfo(("M('test')",))
161160

162161
# Uses known qubit count.
163162
assert (
164163
cirq.circuit_diagram_info(
165-
cirq.MeasurementGate(3),
164+
cirq.MeasurementGate(3, 'a'),
166165
cirq.CircuitDiagramInfoArgs(
167166
known_qubits=None,
168167
known_qubit_count=3,
@@ -171,13 +170,13 @@ def test_measurement_gate_diagram():
171170
qubit_map=None,
172171
),
173172
)
174-
== cirq.CircuitDiagramInfo(("M('')", 'M', 'M'))
173+
== cirq.CircuitDiagramInfo(("M('a')", 'M', 'M'))
175174
)
176175

177176
# Shows invert mask.
178177
assert cirq.circuit_diagram_info(
179-
cirq.MeasurementGate(2, invert_mask=(False, True))
180-
) == cirq.CircuitDiagramInfo(("M('')", "!M"))
178+
cirq.MeasurementGate(2, 'a', invert_mask=(False, True))
179+
) == cirq.CircuitDiagramInfo(("M('a')", "!M"))
181180

182181
# Omits key when it is the default.
183182
a = cirq.NamedQubit('a')
@@ -210,12 +209,12 @@ def test_measurement_gate_diagram():
210209

211210
def test_measurement_channel():
212211
np.testing.assert_allclose(
213-
cirq.kraus(cirq.MeasurementGate(1)),
212+
cirq.kraus(cirq.MeasurementGate(1, 'a')),
214213
(np.array([[1, 0], [0, 0]]), np.array([[0, 0], [0, 1]])),
215214
)
216215
# yapf: disable
217216
np.testing.assert_allclose(
218-
cirq.kraus(cirq.MeasurementGate(2)),
217+
cirq.kraus(cirq.MeasurementGate(2, 'a')),
219218
(np.array([[1, 0, 0, 0],
220219
[0, 0, 0, 0],
221220
[0, 0, 0, 0],
@@ -233,7 +232,7 @@ def test_measurement_channel():
233232
[0, 0, 0, 0],
234233
[0, 0, 0, 1]])))
235234
np.testing.assert_allclose(
236-
cirq.kraus(cirq.MeasurementGate(2, qid_shape=(2, 3))),
235+
cirq.kraus(cirq.MeasurementGate(2, 'a', qid_shape=(2, 3))),
237236
(np.diag([1, 0, 0, 0, 0, 0]),
238237
np.diag([0, 1, 0, 0, 0, 0]),
239238
np.diag([0, 0, 1, 0, 0, 0]),
@@ -248,21 +247,21 @@ def test_measurement_qubit_count_vs_mask_length():
248247
b = cirq.NamedQubit('b')
249248
c = cirq.NamedQubit('c')
250249

251-
_ = cirq.MeasurementGate(num_qubits=1, invert_mask=(True,)).on(a)
252-
_ = cirq.MeasurementGate(num_qubits=2, invert_mask=(True, False)).on(a, b)
253-
_ = cirq.MeasurementGate(num_qubits=3, invert_mask=(True, False, True)).on(a, b, c)
250+
_ = cirq.MeasurementGate(num_qubits=1, key='a', invert_mask=(True,)).on(a)
251+
_ = cirq.MeasurementGate(num_qubits=2, key='a', invert_mask=(True, False)).on(a, b)
252+
_ = cirq.MeasurementGate(num_qubits=3, key='a', invert_mask=(True, False, True)).on(a, b, c)
254253
with pytest.raises(ValueError):
255-
_ = cirq.MeasurementGate(num_qubits=1, invert_mask=(True, False)).on(a)
254+
_ = cirq.MeasurementGate(num_qubits=1, key='a', invert_mask=(True, False)).on(a)
256255
with pytest.raises(ValueError):
257-
_ = cirq.MeasurementGate(num_qubits=3, invert_mask=(True, False, True)).on(a, b)
256+
_ = cirq.MeasurementGate(num_qubits=3, key='a', invert_mask=(True, False, True)).on(a, b)
258257

259258

260259
def test_consistent_protocols():
261260
for n in range(1, 5):
262-
gate = cirq.MeasurementGate(num_qubits=n)
261+
gate = cirq.MeasurementGate(num_qubits=n, key='a')
263262
cirq.testing.assert_implements_consistent_protocols(gate)
264263

265-
gate = cirq.MeasurementGate(num_qubits=n, qid_shape=(3,) * n)
264+
gate = cirq.MeasurementGate(num_qubits=n, key='a', qid_shape=(3,) * n)
266265
cirq.testing.assert_implements_consistent_protocols(gate)
267266

268267

cirq/sim/sparse_simulator_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1226,7 +1226,7 @@ def test_separated_measurements():
12261226
cirq.H(a),
12271227
cirq.H(b),
12281228
cirq.CZ(a, b),
1229-
cirq.measure(a, key=''),
1229+
cirq.measure(a, key='a'),
12301230
cirq.CZ(a, b),
12311231
cirq.H(b),
12321232
cirq.measure(b, key='zero'),

cirq/value/measurement_key.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ class MeasurementKey:
4141
path: Tuple[str, ...] = dataclasses.field(default_factory=tuple)
4242

4343
def __post_init__(self):
44+
if not self.name:
45+
raise ValueError("Measurement key name cannot be empty")
4446
if MEASUREMENT_KEY_SEPARATOR in self.name:
4547
raise ValueError(
4648
f'Invalid key name: {self.name}\n{MEASUREMENT_KEY_SEPARATOR} is not allowed in '

cirq/value/measurement_key_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
def test_empty_init():
2121
with pytest.raises(TypeError, match='required positional argument'):
2222
_ = cirq.MeasurementKey()
23-
mkey = cirq.MeasurementKey('')
24-
assert mkey.name == ''
23+
with pytest.raises(ValueError, match='cannot be empty'):
24+
_ = cirq.MeasurementKey('')
2525

2626

2727
def test_nested_key():

0 commit comments

Comments
 (0)