Skip to content

Commit a14e2df

Browse files
committed
Addressed Michael's comments
1 parent ad6dd90 commit a14e2df

File tree

8 files changed

+157
-60
lines changed

8 files changed

+157
-60
lines changed

cirq-google/cirq_google/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@
5757
from cirq_google.devices import (
5858
Bristlecone,
5959
Foxtail,
60-
GoogleDevice,
6160
GoogleNoiseProperties,
61+
GridDevice,
6262
NoiseModelFromGoogleNoiseProperties,
6363
SerializableDevice,
6464
Sycamore,

cirq-google/cirq_google/devices/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from cirq_google.devices.known_devices import Bristlecone, Foxtail, Sycamore, Sycamore23
2121

22-
from cirq_google.devices.google_device import GoogleDevice
22+
from cirq_google.devices.grid_device import GridDevice
2323

2424
from cirq_google.devices.serializable_device import SerializableDevice
2525

cirq-google/cirq_google/devices/google_device.py renamed to cirq-google/cirq_google/devices/grid_device.py

+82-19
Original file line numberDiff line numberDiff line change
@@ -12,27 +12,89 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""Device object representing Google devices."""
15+
"""Device object representing Google devices with a grid qubit layout."""
1616

17-
from typing import Any, Set, Tuple, cast
17+
from typing import Any, List, Set, Tuple, cast
1818
import cirq
1919
from cirq_google.api import v2
2020

2121

22+
def _get_qubit_pairs(proto: v2.device_pb2.DeviceSpecification) -> List[Tuple[cirq.Qid, cirq.Qid]]:
23+
"""Constructs a list of qubit pairs based on targets of symmetric 2-qubit gates in the proto."""
24+
25+
# While the `GateSpecification` proto message contains qubit target references, they are
26+
# ignored here because the following assumptions make them unnecessary currently:
27+
# * All valid qubit pairs work for all two-qubit gates.
28+
# * All valid qubits work for all single-qubit gates.
29+
# * Measurement gate can always be applied to all subset of qubits.
30+
return [
31+
(_qid_from_str(target.ids[0]), _qid_from_str(target.ids[1]))
32+
for ts in proto.valid_targets
33+
for target in ts.targets
34+
if len(target.ids) == 2 and ts.target_ordering == v2.device_pb2.TargetSet.SYMMETRIC
35+
]
36+
37+
38+
def _validate_device_specification(proto: v2.device_pb2.DeviceSpecification) -> None:
39+
"""Validates the DeviceSpecification proto.
40+
41+
Args:
42+
proto: The DeviceSpecification proto to validate.
43+
44+
Raises:
45+
ValueError: If the DeviceSpecification is invalid.
46+
47+
"""
48+
49+
for target_set in proto.valid_targets:
50+
51+
# Check for unknown qubits in targets.
52+
for target in target_set.targets:
53+
for target_id in target.ids:
54+
if target_id not in proto.valid_qubits:
55+
raise ValueError(
56+
f"Invalid DeviceSpecification: valid_targets contain qubit '{target_id}'"
57+
" which is not in valid_qubits."
58+
)
59+
60+
# Symmetric and asymmetric targets should not have repeated qubits.
61+
if (
62+
target_set.target_ordering == v2.device_pb2.TargetSet.SYMMETRIC
63+
or target_set.target_ordering == v2.device_pb2.TargetSet.ASYMMETRIC
64+
):
65+
for target in target_set.targets:
66+
if len(target.ids) > len(set(target.ids)):
67+
raise ValueError(
68+
f"Invalid DeviceSpecification: the target set '{target_set.name}' is either"
69+
" SYMMETRIC or ASYMMETRIC but has a target which contains repeated qubits:"
70+
f" {target.ids}."
71+
)
72+
73+
# A SUBSET_PERMUTATION target should contain exactly one qubit.
74+
if target_set.target_ordering == v2.device_pb2.TargetSet.SUBSET_PERMUTATION:
75+
for target in target_set.targets:
76+
if len(target.ids) != 1:
77+
raise ValueError(
78+
f"Invalid DeviceSpecification: the target set '{target_set.name}' is of"
79+
" type SUBSET_PERMUTATION but contains a target which does not have exactly"
80+
f" 1 qubit: {target.ids}."
81+
)
82+
83+
2284
@cirq.value_equality
23-
class GoogleDevice(cirq.Device):
24-
"""Device object representing Google devices.
85+
class GridDevice(cirq.Device):
86+
"""Device object representing Google devices with a grid qubit layout.
2587
2688
For end users, instances of this class are typically accessed via
2789
`Engine.get_processor('processor_name').get_device()`.
2890
2991
This class is compliant with the core `cirq.Device` abstraction. In particular:
3092
* Device information is captured in the `metadata` property.
31-
* An instance of `GoogleDevice` can be used to validate circuits, moments, and operations.
93+
* An instance of `GridDevice` can be used to validate circuits, moments, and operations.
3294
3395
Example use cases:
3496
35-
* Get an instance of a Google device.
97+
* Get an instance of a Google grid device.
3698
>>> device = cirq_google.get_engine().get_processor('weber').get_device()
3799
38100
* Print the grid layout of the device.
@@ -78,15 +140,15 @@ class GoogleDevice(cirq.Device):
78140
"""
79141

80142
def __init__(self, metadata: cirq.GridDeviceMetadata):
81-
"""Creates a GoogleDevice object.
143+
"""Creates a GridDevice object.
82144
83145
This constructor typically should not be used directly. Use `from_proto()` instead.
84146
"""
85147
self._metadata = metadata
86148

87149
@classmethod
88-
def from_proto(cls, proto: v2.device_pb2.DeviceSpecification) -> 'GoogleDevice':
89-
"""Create a `GoogleDevice` from a DeviceSpecification proto.
150+
def from_proto(cls, proto: v2.device_pb2.DeviceSpecification) -> 'GridDevice':
151+
"""Create a `GridDevice` from a DeviceSpecification proto.
90152
91153
This class only supports `cirq.GridQubit`s and `cirq.NamedQubit`s. If a
92154
`DeviceSpecification.valid_qubits` string is in the form `<int>_<int>`, it is parsed as a
@@ -99,6 +161,8 @@ def from_proto(cls, proto: v2.device_pb2.DeviceSpecification) -> 'GoogleDevice':
99161
ValueError: If the given `DeviceSpecification` is invalid.
100162
"""
101163

164+
_validate_device_specification(proto)
165+
102166
# Create qubit set
103167
all_qubits = [_qid_from_str(q) for q in proto.valid_qubits]
104168

@@ -110,7 +174,6 @@ def from_proto(cls, proto: v2.device_pb2.DeviceSpecification) -> 'GoogleDevice':
110174
# * All valid qubits work for all single-qubit gates.
111175
# * Measurement gate can always be applied to all subset of qubits.
112176
#
113-
# TODO(#5169) Consider adding the reversed pair, depending on the issue's solution.
114177
qubit_pairs = [
115178
(_qid_from_str(target.ids[0]), _qid_from_str(target.ids[1]))
116179
for ts in proto.valid_targets
@@ -128,10 +191,10 @@ def from_proto(cls, proto: v2.device_pb2.DeviceSpecification) -> 'GoogleDevice':
128191
except ValueError as ve:
129192
raise ValueError("DeviceSpecification is invalid.") from ve
130193

131-
return GoogleDevice(metadata)
194+
return GridDevice(metadata)
132195

133196
@property
134-
def metadata(self):
197+
def metadata(self) -> cirq.GridDeviceMetadata:
135198
"""Get metadata information for the device."""
136199
return self._metadata
137200

@@ -157,8 +220,10 @@ def validate_operation(self, operation: cirq.Operation) -> None:
157220
if q not in self._metadata.qubit_set:
158221
raise ValueError(f'Qubit not on device: {q!r}')
159222

160-
# TODO(#5169) May need to check the reverse pair depending on the issue's solution.
161-
if len(operation.qubits) == 2 and tuple(operation.qubits) not in self._metadata.qubit_pairs:
223+
if (
224+
len(operation.qubits) == 2
225+
and frozenset(operation.qubits) not in self._metadata.qubit_pairs
226+
):
162227
raise ValueError(f'Qubit pair is not valid on device: {operation.qubits!r}')
163228

164229
def __str__(self) -> str:
@@ -177,7 +242,7 @@ def __str__(self) -> str:
177242

178243
# Find pairs that are connected by two-qubit gates.
179244
Pair = Tuple[cirq.GridQubit, cirq.GridQubit]
180-
pairs = sorted({cast(Pair, pair) for pair in self._metadata.qubit_pairs})
245+
pairs = sorted({cast(Pair, tuple(pair)) for pair in self._metadata.qubit_pairs})
181246

182247
# Draw lines between connected pairs. Limit to horizontal/vertical
183248
# lines since that is all the diagram drawer can handle.
@@ -199,12 +264,10 @@ def _repr_pretty_(self, p: Any, cycle: bool) -> None:
199264
p.text(repr(self) if cycle else str(self))
200265

201266
def __repr__(self) -> str:
202-
return f'cirq_google.GoogleDevice({repr(self._metadata)})'
267+
return f'cirq_google.GridDevice({repr(self._metadata)})'
203268

204269
def _json_dict_(self):
205-
return {
206-
'metadata': self._metadata,
207-
}
270+
return {'metadata': self._metadata}
208271

209272
@classmethod
210273
def _from_json_dict_(cls, metadata, **kwargs):

cirq-google/cirq_google/devices/google_device_test.py renamed to cirq-google/cirq_google/devices/grid_device_test.py

+51-28
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,20 @@ def _create_device_spec_with_horizontal_couplings():
3333
# x -- x
3434

3535
grid_qubits = [cirq.GridQubit(i, j) for i in range(GRID_HEIGHT) for j in range(2)]
36+
3637
spec = v2.device_pb2.DeviceSpecification()
3738
spec.valid_qubits.extend([v2.qubit_to_proto_id(q) for q in grid_qubits])
3839
grid_targets = spec.valid_targets.add()
3940
grid_targets.name = '2_qubit_targets'
4041
grid_targets.target_ordering = v2.device_pb2.TargetSet.SYMMETRIC
41-
for row in range(GRID_HEIGHT):
42+
for row in range(int(GRID_HEIGHT / 2)):
4243
new_target = grid_targets.targets.add()
4344
new_target.ids.extend([v2.qubit_to_proto_id(cirq.GridQubit(row, j)) for j in range(2)])
45+
for row in range(int(GRID_HEIGHT / 2), GRID_HEIGHT):
46+
# Flip the qubit pair order for the second half of qubits
47+
# to verify GridDevice properly handles pair symmetry.
48+
new_target = grid_targets.targets.add()
49+
new_target.ids.extend([v2.qubit_to_proto_id(cirq.GridQubit(row, 1 - j)) for j in range(2)])
4450
gate = spec.valid_gates.add()
4551
gate.syc.SetInParent()
4652
gate.gate_duration_picos = 12000
@@ -102,22 +108,36 @@ def _create_device_spec_with_invalid_qubit_in_qubit_pair() -> v2.device_pb2.Devi
102108
return spec
103109

104110

105-
def test_google_device_from_proto_and_validation():
111+
def _create_device_spec_with_invalid_subset_permutation_target() -> v2.device_pb2.DeviceSpecification:
112+
q_proto_ids = [v2.qubit_to_proto_id(cirq.GridQubit(0, i)) for i in range(2)]
113+
114+
spec = v2.device_pb2.DeviceSpecification()
115+
spec.valid_qubits.extend(q_proto_ids)
116+
targets = spec.valid_targets.add()
117+
targets.name = 'test_targets'
118+
targets.target_ordering = v2.device_pb2.TargetSet.SUBSET_PERMUTATION
119+
new_target = targets.targets.add()
120+
new_target.ids.extend(q_proto_ids) # should only have 1 qubit instead
121+
122+
return spec
123+
124+
125+
def test_grid_device_from_proto_and_validation():
106126
grid_qubits, spec = _create_device_spec_with_horizontal_couplings()
107127

108-
device = cirq_google.GoogleDevice.from_proto(spec)
128+
device = cirq_google.GridDevice.from_proto(spec)
109129

110130
assert len(device.metadata.qubit_set) == len(grid_qubits)
111131
assert device.metadata.qubit_set == frozenset(grid_qubits)
112132
assert all(
113-
(cirq.GridQubit(row, 0), cirq.GridQubit(row, 1)) in device.metadata.qubit_pairs
133+
frozenset((cirq.GridQubit(row, 0), cirq.GridQubit(row, 1))) in device.metadata.qubit_pairs
114134
for row in range(GRID_HEIGHT)
115135
)
116136

117137

118-
def test_google_device_validate_operations_positive():
138+
def test_grid_device_validate_operations_positive():
119139
grid_qubits, spec = _create_device_spec_with_horizontal_couplings()
120-
device = cirq_google.GoogleDevice.from_proto(spec)
140+
device = cirq_google.GridDevice.from_proto(spec)
121141

122142
for q in grid_qubits:
123143
device.validate_operation(cirq.X(q))
@@ -129,9 +149,9 @@ def test_google_device_validate_operations_positive():
129149
# TODO(#5050) verify validate_operations gateset support
130150

131151

132-
def test_google_device_validate_operations_negative():
152+
def test_grid_device_validate_operations_negative():
133153
grid_qubits, spec = _create_device_spec_with_horizontal_couplings()
134-
device = cirq_google.GoogleDevice.from_proto(spec)
154+
device = cirq_google.GridDevice.from_proto(spec)
135155

136156
q = cirq.GridQubit(10, 10)
137157
with pytest.raises(ValueError, match='Qubit not on device'):
@@ -145,31 +165,34 @@ def test_google_device_validate_operations_negative():
145165
# TODO(#5050) verify validate_operations gateset errors
146166

147167

148-
@pytest.mark.parametrize(
149-
'spec',
150-
[
151-
# TODO(#5050) implement once gateset support is implemented
152-
# _create_device_spec_with_missing_gate_durations(),
153-
_create_device_spec_with_qubit_pair_self_loops(),
154-
_create_device_spec_with_invalid_qubit_in_qubit_pair(),
155-
],
156-
)
157-
def test_google_device_invalid_device_spec(spec):
158-
with pytest.raises(ValueError, match='DeviceSpecification is invalid'):
159-
cirq_google.GoogleDevice.from_proto(spec)
168+
def test_grid_device_invalid_qubit_in_qubit_pair():
169+
with pytest.raises(ValueError, match='which is not in valid_qubits'):
170+
cirq_google.GridDevice.from_proto(_create_device_spec_with_invalid_qubit_in_qubit_pair())
171+
172+
173+
def test_grid_device_invalid_target_self_loops():
174+
with pytest.raises(ValueError, match='contains repeated qubits'):
175+
cirq_google.GridDevice.from_proto(_create_device_spec_with_qubit_pair_self_loops())
176+
177+
178+
def test_grid_device_invalid_subset_permutation_target():
179+
with pytest.raises(ValueError, match='does not have exactly 1 qubit'):
180+
cirq_google.GridDevice.from_proto(
181+
_create_device_spec_with_invalid_subset_permutation_target()
182+
)
160183

161184

162-
def test_google_device_repr_json():
185+
def test_grid_device_repr_json():
163186
_, spec = _create_device_spec_with_horizontal_couplings()
164-
device = cirq_google.GoogleDevice.from_proto(spec)
187+
device = cirq_google.GridDevice.from_proto(spec)
165188

166189
assert eval(repr(device)) == device
167190
assert cirq.read_json(json_text=cirq.to_json(device)) == device
168191

169192

170-
def test_google_device_str_grid_qubits():
193+
def test_grid_device_str_grid_qubits():
171194
_, spec = _create_device_spec_with_all_couplings()
172-
device = cirq_google.GoogleDevice.from_proto(spec)
195+
device = cirq_google.GridDevice.from_proto(spec)
173196

174197
assert (
175198
str(device)
@@ -191,17 +214,17 @@ def test_google_device_str_grid_qubits():
191214

192215

193216
@pytest.mark.parametrize('cycle,func', [(False, str), (True, repr)])
194-
def test_google_device_repr_pretty(cycle, func):
217+
def test_grid_device_repr_pretty(cycle, func):
195218
_, spec = _create_device_spec_with_all_couplings()
196-
device = cirq_google.GoogleDevice.from_proto(spec)
219+
device = cirq_google.GridDevice.from_proto(spec)
197220
printer = mock.Mock()
198221
device._repr_pretty_(printer, cycle)
199222
printer.text.assert_called_once_with(func(device))
200223

201224

202-
def test_serializable_device_str_named_qubits():
225+
def test_grid_device_str_named_qubits():
203226
q_proto_id = v2.qubit_to_proto_id(cirq.NamedQubit('q'))
204227
spec = v2.device_pb2.DeviceSpecification()
205228
spec.valid_qubits.extend([q_proto_id])
206-
device = cirq_google.GoogleDevice.from_proto(spec)
229+
device = cirq_google.GridDevice.from_proto(spec)
207230
assert device.__class__.__name__ in str(device)

cirq-google/cirq_google/json_resolver_cache.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def _class_resolver_dictionary() -> Dict[str, ObjectFactory]:
3333
'GoogleNoiseProperties': cirq_google.GoogleNoiseProperties,
3434
'SycamoreGate': cirq_google.SycamoreGate,
3535
'GateTabulation': cirq_google.GateTabulation,
36-
'GoogleDevice': cirq_google.GoogleDevice,
36+
'GridDevice': cirq_google.GridDevice,
3737
'PhysicalZTag': cirq_google.PhysicalZTag,
3838
'FSimGateFamily': cirq_google.FSimGateFamily,
3939
'FloquetPhasedFSimCalibrationOptions': cirq_google.FloquetPhasedFSimCalibrationOptions,

cirq-google/cirq_google/json_test_data/GoogleDevice.repr

-1
This file was deleted.

0 commit comments

Comments
 (0)