Skip to content

Commit 3d645b7

Browse files
Enable serialization of non-GridQubits. (#2966)
* Enable serialization of non-GridQubits. * Lint fixes for serializer. * LineQubit proto IDs must be digits-only. * Add coverage. * named_qubit_* -> qubit_* * Fix lint error.
1 parent 6bdcb1c commit 3d645b7

File tree

6 files changed

+106
-10
lines changed

6 files changed

+106
-10
lines changed

cirq/google/api/v2/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
grid_qubit_from_proto_id,
2626
line_qubit_from_proto_id,
2727
named_qubit_from_proto_id,
28+
qubit_from_proto_id,
2829
qubit_to_proto_id,
2930
)
3031

cirq/google/api/v2/program.py

+36
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,42 @@ def qubit_to_proto_id(q: 'cirq.Qid') -> str:
4141
type(q)))
4242

4343

44+
def qubit_from_proto_id(proto_id: str) -> 'cirq.Qid':
45+
"""Return a `cirq.Qid` for a proto id.
46+
47+
Proto IDs of the form {int}_{int} are parsed as GridQubits.
48+
49+
Proto IDs of the form {int} are parsed as LineQubits.
50+
51+
All other proto IDs are parsed as NamedQubits. Note that this will happily
52+
accept any string; for circuits which explicitly use Grid or LineQubits,
53+
prefer one of the specialized methods below.
54+
55+
Args:
56+
proto_id: The id to convert.
57+
58+
Returns:
59+
A `cirq.Qid` corresponding to the proto id.
60+
"""
61+
num_coords = len(proto_id.split('_'))
62+
if num_coords == 2:
63+
try:
64+
grid_q = grid_qubit_from_proto_id(proto_id)
65+
return grid_q
66+
except ValueError:
67+
pass # Not a grid qubit.
68+
elif num_coords == 1:
69+
try:
70+
line_q = line_qubit_from_proto_id(proto_id)
71+
return line_q
72+
except ValueError:
73+
pass # Not a line qubit.
74+
75+
# named_qubit_from_proto has no failure condition.
76+
named_q = named_qubit_from_proto_id(proto_id)
77+
return named_q
78+
79+
4480
def grid_qubit_from_proto_id(proto_id: str) -> 'cirq.GridQubit':
4581
"""Parse a proto id to a `cirq.GridQubit`.
4682

cirq/google/api/v2/program_test.py

+14
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,17 @@ def test_line_qubit_from_proto_id_invalid():
7272

7373
def test_named_qubit_from_proto_id():
7474
assert v2.named_qubit_from_proto_id('a') == cirq.NamedQubit('a')
75+
76+
77+
def test_generic_qubit_from_proto_id():
78+
assert v2.qubit_from_proto_id('1_2') == cirq.GridQubit(1, 2)
79+
assert v2.qubit_from_proto_id('1') == cirq.LineQubit(1)
80+
assert v2.qubit_from_proto_id('a') == cirq.NamedQubit('a')
81+
82+
# Despite the fact that int(1_2_3) = 123, only pure numbers are parsed into
83+
# LineQubits.
84+
assert v2.qubit_from_proto_id('1_2_3') == cirq.NamedQubit('1_2_3')
85+
86+
# All non-int-parseable names are converted to NamedQubits.
87+
assert v2.qubit_from_proto_id('a') == cirq.NamedQubit('a')
88+
assert v2.qubit_from_proto_id('1_b') == cirq.NamedQubit('1_b')

cirq/google/op_deserializer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def from_proto(self,
9494
*,
9595
arg_function_language: str = '') -> 'cirq.Operation':
9696
"""Turns a cirq.google.api.v2.Operation proto into a GateOperation."""
97-
qubits = [v2.grid_qubit_from_proto_id(q.id) for q in proto.qubits]
97+
qubits = [v2.qubit_from_proto_id(q.id) for q in proto.qubits]
9898
args = self._args_from_proto(
9999
proto, arg_function_language=arg_function_language)
100100
if self.num_qubits_param is not None:

cirq/google/op_serializer.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313
# limitations under the License.
1414

1515
from dataclasses import dataclass
16-
from typing import (Callable, cast, List, Optional, Type, TypeVar, Union,
16+
from typing import (Callable, List, Optional, Type, TypeVar, Union,
1717
TYPE_CHECKING)
1818

1919
import numpy as np
2020

21-
from cirq import devices, ops
21+
from cirq import ops
2222
from cirq.google.api import v2
2323
from cirq.google import arg_func_langs
2424
from cirq.google.arg_func_langs import _arg_to_proto
@@ -104,8 +104,6 @@ def to_proto(
104104
) -> Optional[v2.program_pb2.Operation]:
105105
"""Returns the cirq.google.api.v2.Operation message as a proto dict."""
106106

107-
if not all(isinstance(qubit, devices.GridQubit) for qubit in op.qubits):
108-
raise ValueError('All qubits must be GridQubits')
109107
gate = op.gate
110108
if not isinstance(gate, self.gate_type):
111109
raise ValueError(
@@ -120,8 +118,7 @@ def to_proto(
120118

121119
msg.gate.id = self.serialized_gate_id
122120
for qubit in op.qubits:
123-
msg.qubits.add().id = v2.qubit_to_proto_id(
124-
cast(devices.GridQubit, qubit))
121+
msg.qubits.add().id = v2.qubit_to_proto_id(qubit)
125122
for arg in self.args:
126123
value = self._value_from_gate(op, arg)
127124
if value is not None:

cirq/google/op_serializer_test.py

+51-3
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ def test_to_proto_unsupported_type():
258258
serializer.to_proto(GateWithProperty(b's')(q))
259259

260260

261-
def test_to_proto_unsupported_qubit_type():
261+
def test_to_proto_named_qubit_supported():
262262
serializer = cg.GateOpSerializer(gate_type=GateWithProperty,
263263
serialized_gate_id='my_gate',
264264
args=[
@@ -268,8 +268,56 @@ def test_to_proto_unsupported_qubit_type():
268268
op_getter='val')
269269
])
270270
q = cirq.NamedQubit('a')
271-
with pytest.raises(ValueError, match='GridQubit'):
272-
serializer.to_proto(GateWithProperty(1.0)(q))
271+
arg_value = 1.0
272+
result = serializer.to_proto(GateWithProperty(arg_value)(q))
273+
274+
expected = op_proto({
275+
'gate': {
276+
'id': 'my_gate'
277+
},
278+
'args': {
279+
'my_val': {
280+
'arg_value': {
281+
'float_value': arg_value
282+
}
283+
}
284+
},
285+
'qubits': [{
286+
'id': 'a'
287+
}]
288+
})
289+
assert result == expected
290+
291+
292+
def test_to_proto_line_qubit_supported():
293+
serializer = cg.GateOpSerializer(gate_type=GateWithProperty,
294+
serialized_gate_id='my_gate',
295+
args=[
296+
cg.SerializingArg(
297+
serialized_name='my_val',
298+
serialized_type=float,
299+
op_getter='val')
300+
])
301+
q = cirq.LineQubit('10')
302+
arg_value = 1.0
303+
result = serializer.to_proto(GateWithProperty(arg_value)(q))
304+
305+
expected = op_proto({
306+
'gate': {
307+
'id': 'my_gate'
308+
},
309+
'args': {
310+
'my_val': {
311+
'arg_value': {
312+
'float_value': arg_value
313+
}
314+
}
315+
},
316+
'qubits': [{
317+
'id': '10'
318+
}]
319+
})
320+
assert result == expected
273321

274322

275323
def test_to_proto_required_but_not_present():

0 commit comments

Comments
 (0)