From 19df7fd8e771488e931c486d41a1427fd3b55e96 Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Tue, 9 May 2023 20:30:48 +0000 Subject: [PATCH 1/9] GridDevice serialization refactor. --- .../cirq_google/devices/grid_device.py | 366 +++++++++++------- .../cirq_google/devices/grid_device_test.py | 130 ++++--- .../cirq_google/devices/known_devices.py | 6 +- 3 files changed, 302 insertions(+), 200 deletions(-) diff --git a/cirq-google/cirq_google/devices/grid_device.py b/cirq-google/cirq_google/devices/grid_device.py index 1560ebfa83e..b9831b8f74b 100644 --- a/cirq-google/cirq_google/devices/grid_device.py +++ b/cirq-google/cirq_google/devices/grid_device.py @@ -30,6 +30,7 @@ ) import re import warnings +from dataclasses import dataclass, field import cirq from cirq_google import ops @@ -40,14 +41,7 @@ # Gate family constants used in various parts of GridDevice logic. -_SYC_GATE_FAMILY = cirq.GateFamily(ops.SYC) -_SQRT_ISWAP_GATE_FAMILY = cirq.GateFamily(cirq.SQRT_ISWAP) -_SQRT_ISWAP_INV_GATE_FAMILY = cirq.GateFamily(cirq.SQRT_ISWAP_INV) -_CZ_GATE_FAMILY = cirq.GateFamily(cirq.CZ) _PHASED_XZ_GATE_FAMILY = cirq.GateFamily(cirq.PhasedXZGate) -_VIRTUAL_ZPOW_GATE_FAMILY = cirq.GateFamily(cirq.ZPowGate, tags_to_ignore=[ops.PhysicalZTag()]) -_PHYSICAL_ZPOW_GATE_FAMILY = cirq.GateFamily(cirq.ZPowGate, tags_to_accept=[ops.PhysicalZTag()]) -_COUPLER_PULSE_GATE_FAMILY = cirq.GateFamily(experimental_ops.CouplerPulse) _MEASUREMENT_GATE_FAMILY = cirq.GateFamily(cirq.MeasurementGate) _WAIT_GATE_FAMILY = cirq.GateFamily(cirq.WaitGate) @@ -74,6 +68,75 @@ _VARIADIC_GATE_FAMILIES = [_MEASUREMENT_GATE_FAMILY, _WAIT_GATE_FAMILY] +GateOrFamily = Union[Type[cirq.Gate], cirq.Gate, cirq.GateFamily] + + +@dataclass +class _GateRepresentations: + """Contains equivalent representations of a gate in both DeviceSpecification and GridDevice.""" + + """The name of gate type in `GateSpecification`.""" + gate_spec_name: str + """Gate representations to be included in generated gatesets and gate durations.""" + primary_forms: List[GateOrFamily] + """GateFamilies which match all other valid gate representations.""" + additional_forms: List[cirq.GateFamily] = field(default_factory=list) + """`primary_forms` (as GateFamilies) + `additional_forms`""" + all_forms: List[cirq.GateFamily] = field(init=False) + + def __post_init__(self): + self.all_forms = [ + gof if isinstance(gof, cirq.GateFamily) else cirq.GateFamily(gof) + for gof in self.primary_forms + ] + self.additional_forms + + +"""Valid gates for a GridDevice.""" +_GATES: List[_GateRepresentations] = [ + _GateRepresentations( + gate_spec_name='syc', + primary_forms=[_SYC_FSIM_GATE_FAMILY], + additional_forms=[cirq.GateFamily(ops.SYC)], + ), + _GateRepresentations( + gate_spec_name='sqrt_iswap', + primary_forms=[_SQRT_ISWAP_FSIM_GATE_FAMILY], + additional_forms=[cirq.GateFamily(cirq.SQRT_ISWAP)], + ), + _GateRepresentations( + gate_spec_name='sqrt_iswap_inv', + primary_forms=[_SQRT_ISWAP_INV_FSIM_GATE_FAMILY], + additional_forms=[cirq.GateFamily(cirq.SQRT_ISWAP_INV)], + ), + _GateRepresentations( + gate_spec_name='cz', + primary_forms=[_CZ_FSIM_GATE_FAMILY], + additional_forms=[cirq.GateFamily(cirq.CZ)], + ), + _GateRepresentations( + gate_spec_name='phased_xz', + primary_forms=[cirq.PhasedXZGate, cirq.XPowGate, cirq.YPowGate, cirq.PhasedXPowGate], + ), + _GateRepresentations( + gate_spec_name='virtual_zpow', + primary_forms=[cirq.GateFamily(cirq.ZPowGate, tags_to_ignore=[ops.PhysicalZTag()])], + ), + _GateRepresentations( + gate_spec_name='physical_zpow', + primary_forms=[cirq.GateFamily(cirq.ZPowGate, tags_to_accept=[ops.PhysicalZTag()])], + ), + _GateRepresentations( + gate_spec_name='coupler_pulse', primary_forms=[experimental_ops.CouplerPulse] + ), + _GateRepresentations(gate_spec_name='meas', primary_forms=[cirq.MeasurementGate]), + _GateRepresentations(gate_spec_name='wait', primary_forms=[cirq.WaitGate]), +] + + +def _in_or_equals(g: GateOrFamily, gate_family: cirq.GateFamily): + return g == gate_family or g in gate_family + + def _validate_device_specification(proto: v2.device_pb2.DeviceSpecification) -> None: """Raises a ValueError if the `DeviceSpecification` proto is invalid.""" @@ -93,7 +156,6 @@ def _validate_device_specification(proto: v2.device_pb2.DeviceSpecification) -> qubit_set.add(q_name) for target_set in proto.valid_targets: - # Check for unknown qubits in targets. for target in target_set.targets: for target_id in target.ids: @@ -120,41 +182,63 @@ def _validate_device_specification(proto: v2.device_pb2.DeviceSpecification) -> raise ValueError("Invalid DeviceSpecification: target_ordering cannot be ASYMMETRIC.") -def _build_gateset_and_gate_durations( +def _serialize_gateset_and_gate_durations( + out: v2.device_pb2.DeviceSpecification, + gateset: cirq.Gateset, + gate_durations: Mapping['cirq.GateFamily', 'cirq.Duration'], +) -> v2.device_pb2.DeviceSpecification: + """Serializes the given gateset and gate durations to DeviceSpecification.""" + + gate_specs = {} + for gate_family in gateset.gates: + gate_spec = v2.device_pb2.GateSpecification() + gate_rep = next( + (gr for gr in _GATES for gf in gr.all_forms if _in_or_equals(gate_family, gf)), None + ) + if gate_rep is None: + raise ValueError(f'Unrecognized gate: {gate_family}.') + + # Set gate + getattr(gate_spec, gate_rep.gate_spec_name).SetInParent() + + # Set gate duration + is_duration_set = False + for gf in gate_rep.all_forms: + if gf not in gate_durations: + continue + + gate_duration_picos = int(gate_durations[gf].total_picos()) + if is_duration_set and gate_duration_picos != gate_spec.gate_duration_picos: + raise ValueError( + f'Multiple gate families in the following list exist in the gate duration dict, and they are expected to have the same duration value: {gate_rep.all_forms}' + ) + is_duration_set = True + gate_spec.gate_duration_picos = gate_duration_picos + + # GateSpecification dedup. Multiple gates or GateFamilies in the gateset could map to the + # same GateSpecification. + gate_name = gate_spec.WhichOneof('gate') + gate_specs[gate_name] = gate_spec + + # Sort by gate name to keep valid_gates stable. + out.valid_gates.extend([v for _, v in sorted(gate_specs.items(), key=lambda item: item[0])]) + + return out + + +def _deserialize_gateset_and_gate_durations( proto: v2.device_pb2.DeviceSpecification, ) -> Tuple[cirq.Gateset, Mapping[cirq.GateFamily, cirq.Duration]]: - """Extracts gate set and gate duration information from the given DeviceSpecification proto.""" + """Deserializes gateset and gate duration from DeviceSpecification.""" - gates_list: List[Union[Type[cirq.Gate], cirq.Gate, cirq.GateFamily]] = [] + gates_list: List[GateOrFamily] = [] gate_durations: Dict[cirq.GateFamily, cirq.Duration] = {} - # TODO(#5050) Describe how to add/remove gates. - for gate_spec in proto.valid_gates: gate_name = gate_spec.WhichOneof('gate') - cirq_gates: List[Union[Type[cirq.Gate], cirq.Gate, cirq.GateFamily]] = [] - - if gate_name == 'syc': - cirq_gates = [_SYC_FSIM_GATE_FAMILY] - elif gate_name == 'sqrt_iswap': - cirq_gates = [_SQRT_ISWAP_FSIM_GATE_FAMILY] - elif gate_name == 'sqrt_iswap_inv': - cirq_gates = [_SQRT_ISWAP_INV_FSIM_GATE_FAMILY] - elif gate_name == 'cz': - cirq_gates = [_CZ_FSIM_GATE_FAMILY] - elif gate_name == 'phased_xz': - cirq_gates = [cirq.PhasedXZGate, cirq.XPowGate, cirq.YPowGate, cirq.PhasedXPowGate] - elif gate_name == 'virtual_zpow': - cirq_gates = [_VIRTUAL_ZPOW_GATE_FAMILY] - elif gate_name == 'physical_zpow': - cirq_gates = [_PHYSICAL_ZPOW_GATE_FAMILY] - elif gate_name == 'coupler_pulse': - cirq_gates = [experimental_ops.CouplerPulse] - elif gate_name == 'meas': - cirq_gates = [cirq.MeasurementGate] - elif gate_name == 'wait': - cirq_gates = [cirq.WaitGate] - else: + + gate_rep = next((gr for gr in _GATES if gr.gate_spec_name == gate_name), None) + if gate_rep is None: # coverage: ignore warnings.warn( f"The DeviceSpecification contains the gate '{gate_name}' which is not recognized" @@ -163,11 +247,8 @@ def _build_gateset_and_gate_durations( ) continue - gates_list.extend(cirq_gates) - - # TODO(#5050) Allow different gate representations of the same gate to be looked up in - # gate_durations. - for g in cirq_gates: + gates_list.extend(gate_rep.primary_forms) + for g in gate_rep.primary_forms: if not isinstance(g, cirq.GateFamily): g = cirq.GateFamily(g) gate_durations[g] = cirq.Duration(picos=gate_spec.gate_duration_picos) @@ -316,20 +397,21 @@ class GridDevice(cirq.Device): https://github.com/quantumlib/Cirq/blob/master/cirq-google/cirq_google/api/v2/device.proto ) is the main specification for device information surfaced by the Quantum Computing Service. - Thus, this class is should be instantiated using a `DeviceSpecification` proto via the + Thus, this class should typically be instantiated using a `DeviceSpecification` proto via the `from_proto()` class method. """ def __init__(self, metadata: cirq.GridDeviceMetadata): """Creates a GridDevice object. - This constructor typically should not be used directly. Use `from_proto()` instead. + This constructor should not be used directly outside the class implementation. Use + `from_proto()` or `from_device_information()` instead. """ self._metadata = metadata @classmethod def from_proto(cls, proto: v2.device_pb2.DeviceSpecification) -> 'GridDevice': - """Create a `GridDevice` from a `DeviceSpecification` proto. + """Deserializes the `DeviceSpecification` to a `GridDevice`. Args: proto: The `DeviceSpecification` proto describing a Google device. @@ -357,7 +439,7 @@ def from_proto(cls, proto: v2.device_pb2.DeviceSpecification) -> 'GridDevice': if len(target.ids) == 2 and ts.target_ordering == v2.device_pb2.TargetSet.SYMMETRIC ] - gateset, gate_durations = _build_gateset_and_gate_durations(proto) + gateset, gate_durations = _deserialize_gateset_and_gate_durations(proto) try: metadata = cirq.GridDeviceMetadata( @@ -373,6 +455,99 @@ def from_proto(cls, proto: v2.device_pb2.DeviceSpecification) -> 'GridDevice': return GridDevice(metadata) + def to_proto( + self, out: Optional[v2.device_pb2.DeviceSpecification] = None + ) -> v2.device_pb2.DeviceSpecification: + """Serializes the GridDevice to a DeviceSpecification. + + Args: + out: Optional DeviceSpecification to be populated. Fields are populated in-place. + + Returns: + The populated DeviceSpecification if out is specified, or the newly created + DeviceSpecification. + """ + qubits = self._metadata.qubit_set + pairs = [tuple(sorted(pair)) for pair in self._metadata.qubit_pairs] + gateset = self._metadata.gateset + gate_durations = self._metadata.gate_durations + + if out is None: + out = v2.device_pb2.DeviceSpecification() + + # If fields are already filled (i.e. as part of the old DeviceSpecification format), leave them + # as is. Fields populated in the new format do not conflict with how they were populated in the + # old format. + # TODO(#5050) remove empty checks below once deprecated fields in DeviceSpecification are + # removed. + + if len(out.valid_qubits) == 0: + known_devices.populate_qubits_in_device_proto(qubits, out) + if len(out.valid_targets) == 0: + known_devices.populate_qubit_pairs_in_device_proto(pairs, out) + _serialize_gateset_and_gate_durations( + out, gateset, {} if gate_durations is None else gate_durations + ) + _validate_device_specification(out) + + return out + + @classmethod + def from_device_information( + cls, + *, + qubit_pairs: Collection[Tuple[cirq.GridQubit, cirq.GridQubit]], + gateset: cirq.Gateset, + gate_durations: Optional[Mapping['cirq.GateFamily', 'cirq.Duration']] = None, + ) -> 'GridDevice': + """Constructs a GridDevice using the device information provided. + + This is a convenience method for constructing a GridDevice given partial gateset and + gate_duration information: for every distinct gate, only one representation needs to be in + gateset and gate_duration. The remaining representations will be automatically generated. + + For example, if the input gateset contains only `cirq.PhasedXZGate`, and the input + gate_durations is `{cirq.GateFamily(cirq.PhasedXZGate): cirq.Duration(picos=3)}`, + `GridDevice.metadata.gateset` will be + + ``` + cirq.Gateset(cirq.PhasedXZGate, cirq.XPowGate, cirq.YPowGate, cirq.PhasedXPowGate) + ``` + + and `GridDevice.metadata.gate_durations` will be + + ``` + { + cirq.GateFamily(cirq.PhasedXZGate): cirq.Duration(picos=3), + cirq.GateFamily(cirq.XPowGate): cirq.Duration(picos=3), + cirq.GateFamily(cirq.YPowGate): cirq.Duration(picos=3), + cirq.GateFamily(cirq.PhasedXPowGate): cirq.Duration(picos=3), + } + ``` + + This method reduces the complexity of constructing `GridDevice` on server side by requiring + only the bare essential device information. + + Args: + qubit_pairs: Collection of bidirectional qubit couplings available on the device. + gateset: The gate set supported by the device. + gate_durations: Optional mapping from gates supported by the device to their timing + estimates. Not every gate is required to have an associated duration. + out: If set, device information will be serialized into this DeviceSpecification. + + Raises: + ValueError: If a pair contains two identical qubits. + ValueError: If `gateset` contains invalid GridDevice gates. + ValueError: If `gate_durations` contains keys which are not in `gateset`. + ValueError: If multiple gate families in gate_durations can + represent a particular gate, but they have different durations. + """ + metadata = cirq.GridDeviceMetadata( + qubit_pairs=qubit_pairs, gateset=gateset, gate_durations=gate_durations + ) + incomplete_device = GridDevice(metadata) + return GridDevice.from_proto(incomplete_device.to_proto()) + @property def metadata(self) -> cirq.GridDeviceMetadata: """Get metadata information for the device.""" @@ -456,104 +631,3 @@ def _from_json_dict_(cls, metadata, **kwargs): def _value_equality_values_(self): return self._metadata - - -def _set_gate_in_gate_spec( - gate_spec: v2.device_pb2.GateSpecification, gate_family: cirq.GateFamily -) -> None: - if gate_family == _SYC_GATE_FAMILY or gate_family == _SYC_FSIM_GATE_FAMILY: - gate_spec.syc.SetInParent() - elif gate_family == _SQRT_ISWAP_GATE_FAMILY or gate_family == _SQRT_ISWAP_FSIM_GATE_FAMILY: - gate_spec.sqrt_iswap.SetInParent() - elif ( - gate_family == _SQRT_ISWAP_INV_GATE_FAMILY - or gate_family == _SQRT_ISWAP_INV_FSIM_GATE_FAMILY - ): - gate_spec.sqrt_iswap_inv.SetInParent() - elif gate_family == _CZ_GATE_FAMILY or gate_family == _CZ_FSIM_GATE_FAMILY: - gate_spec.cz.SetInParent() - elif gate_family == _PHASED_XZ_GATE_FAMILY: - gate_spec.phased_xz.SetInParent() - elif gate_family == _VIRTUAL_ZPOW_GATE_FAMILY: - gate_spec.virtual_zpow.SetInParent() - elif gate_family == _PHYSICAL_ZPOW_GATE_FAMILY: - gate_spec.physical_zpow.SetInParent() - elif gate_family == _COUPLER_PULSE_GATE_FAMILY: - gate_spec.coupler_pulse.SetInParent() - elif gate_family == _MEASUREMENT_GATE_FAMILY: - gate_spec.meas.SetInParent() - elif gate_family == _WAIT_GATE_FAMILY: - gate_spec.wait.SetInParent() - else: - raise ValueError(f'Unrecognized gate {gate_family}.') - - -def _create_device_specification_proto( - *, - qubits: Collection[cirq.GridQubit], - pairs: Collection[Tuple[cirq.GridQubit, cirq.GridQubit]], - gateset: cirq.Gateset, - gate_durations: Optional[Mapping['cirq.GateFamily', 'cirq.Duration']] = None, - out: Optional[v2.device_pb2.DeviceSpecification] = None, -) -> v2.device_pb2.DeviceSpecification: - """Serializes the given device information into a DeviceSpecification proto. - - EXPERIMENTAL: DeviceSpecification serialization API may change. - - This function does not serialize a `GridDevice`. Instead, it only takes a subset of device - information sufficient to populate the `DeviceSpecification` proto. This reduces the complexity - of constructing `DeviceSpecification` and `GridDevice` on server side by requiring only the bare - essential device information. - - Args: - qubits: Collection of qubits available on the device. - pairs: Collection of bidirectional qubit couplings available on the device. - gateset: The gate set supported by the device. - gate_durations: Optional mapping from gates supported by the device to their timing - estimates. Not every gate is required to have an associated duration. - out: If set, device information will be serialized into this DeviceSpecification. - - Raises: - ValueError: If a qubit in `pairs` is not part of `qubits`. - ValueError: If a pair contains two identical qubits. - ValueError: If `gate_durations` contains keys which are not in `gateset`. - ValueError: If `gateset` contains a gate which is not recognized by DeviceSpecification. - """ - - if gate_durations is not None: - extra_gate_families = (gate_durations.keys() | gateset.gates) - gateset.gates - if extra_gate_families: - raise ValueError( - 'Gate durations contain keys which are not part of the gateset:' - f' {extra_gate_families}' - ) - - if out is None: - out = v2.device_pb2.DeviceSpecification() - - # If fields are already filled (i.e. as part of the old DeviceSpecification format), leave them - # as is. Fields populated in the new format do not conflict with how they were populated in the - # old format. - # TODO(#5050) remove empty checks below once deprecated fields in DeviceSpecification are - # removed. - - if len(out.valid_qubits) == 0: - known_devices.populate_qubits_in_device_proto(qubits, out) - - if len(out.valid_targets) == 0: - known_devices.populate_qubit_pairs_in_device_proto(pairs, out) - - gate_specs = [] - for gate_family in gateset.gates: - gate_spec = v2.device_pb2.GateSpecification() - _set_gate_in_gate_spec(gate_spec, gate_family) - if gate_durations is not None and gate_family in gate_durations: - gate_spec.gate_duration_picos = int(gate_durations[gate_family].total_picos()) - gate_specs.append(gate_spec) - - # Sort by gate name to keep valid_gates stable. - out.valid_gates.extend(sorted(gate_specs, key=lambda s: s.WhichOneof('gate'))) - - _validate_device_specification(out) - - return out diff --git a/cirq-google/cirq_google/devices/grid_device_test.py b/cirq-google/cirq_google/devices/grid_device_test.py index 029c6088edc..00b6f3f82d6 100644 --- a/cirq-google/cirq_google/devices/grid_device_test.py +++ b/cirq-google/cirq_google/devices/grid_device_test.py @@ -17,7 +17,6 @@ import unittest.mock as mock import pytest -from google.protobuf import text_format import cirq import cirq_google @@ -440,8 +439,8 @@ def test_grid_device_repr_pretty(cycle, func): printer.text.assert_called_once_with(func(device)) -def test_to_proto(): - device_info, expected_spec = _create_device_spec_with_horizontal_couplings() +def test_device_from_device_information_equals_device_from_proto(): + device_info, spec = _create_device_spec_with_horizontal_couplings() # The set of gates in gate_durations are consistent with what's generated in # _create_device_spec_with_horizontal_couplings() @@ -465,69 +464,56 @@ def test_to_proto(): cirq.GateFamily(cirq.ops.wait_gate.WaitGate): base_duration * 9, } - spec = grid_device._create_device_specification_proto( - qubits=device_info.grid_qubits, - pairs=device_info.qubit_pairs, + device_from_information = cirq_google.GridDevice.from_device_information( + qubit_pairs=device_info.qubit_pairs, gateset=cirq.Gateset(*gate_durations.keys()), gate_durations=gate_durations, ) - assert text_format.MessageToString(spec) == text_format.MessageToString(expected_spec) + assert device_from_information == cirq_google.GridDevice.from_proto(spec) @pytest.mark.parametrize( - 'error_match, qubits, qubit_pairs, gateset, gate_durations', + 'error_match, qubit_pairs, gateset, gate_durations', [ ( - 'Gate durations contain keys which are not part of the gateset', - [cirq.GridQubit(0, 0)], - [], - cirq.Gateset(cirq.CZ), - {cirq.GateFamily(cirq.SQRT_ISWAP): 1_000}, + 'Self loop encountered in qubit', + [(cirq.GridQubit(0, 0), cirq.GridQubit(0, 0))], + cirq.Gateset(), + None, ), - ('not in the GridQubit form', [cirq.NamedQubit('q0_0')], [], cirq.Gateset(), None), ( - 'valid_targets contain .* which is not in valid_qubits', - [cirq.GridQubit(0, 0)], + 'Unrecognized gate', [(cirq.GridQubit(0, 0), cirq.GridQubit(0, 1))], - cirq.Gateset(), + cirq.Gateset(cirq.H), None, ), ( - 'has a target which contains repeated qubits', - [cirq.GridQubit(0, 0)], - [(cirq.GridQubit(0, 0), cirq.GridQubit(0, 0))], - cirq.Gateset(), - None, + 'Some gate_durations keys are not found in gateset', + [(cirq.GridQubit(0, 0), cirq.GridQubit(0, 1))], + cirq.Gateset(cirq.CZ), + {cirq.GateFamily(cirq.SQRT_ISWAP): cirq.Duration(picos=1_000)}, + ), + ( + 'Multiple gate families .* expected to have the same duration value', + [(cirq.GridQubit(0, 0), cirq.GridQubit(0, 1))], + cirq.Gateset(cirq.PhasedXZGate, cirq.XPowGate), + { + cirq.GateFamily(cirq.PhasedXZGate): cirq.Duration(picos=1_000), + cirq.GateFamily(cirq.XPowGate): cirq.Duration(picos=2_000), + }, ), - ('Unrecognized gate', [cirq.GridQubit(0, 0)], [], cirq.Gateset(cirq.H), None), ], ) -def test_to_proto_invalid_input(error_match, qubits, qubit_pairs, gateset, gate_durations): +def test_from_device_information_invalid_input(error_match, qubit_pairs, gateset, gate_durations): with pytest.raises(ValueError, match=error_match): - grid_device._create_device_specification_proto( - qubits=qubits, pairs=qubit_pairs, gateset=gateset, gate_durations=gate_durations + grid_device.GridDevice.from_device_information( + qubit_pairs=qubit_pairs, gateset=gateset, gate_durations=gate_durations ) -def test_to_proto_empty(): - spec = grid_device._create_device_specification_proto( - # Qubits are always expected to be set - qubits=[cirq.GridQubit(0, i) for i in range(5)], - pairs=[], - gateset=cirq.Gateset(), - gate_durations=None, - ) - device = cirq_google.GridDevice.from_proto(spec) - - assert len(device.metadata.qubit_set) == 5 - assert len(device.metadata.qubit_pairs) == 0 - assert device.metadata.gateset == cirq.Gateset() - assert device.metadata.gate_durations is None - - -def test_to_proto_fsim_gate_family(): - """Verifies that FSimGateFamilies are serialized correctly.""" +def test_from_device_information_fsim_gate_family(): + """Verifies that FSimGateFamilies are recognized correctly.""" gateset = cirq.Gateset( cirq_google.FSimGateFamily(gates_to_accept=[cirq_google.SYC]), @@ -536,11 +522,55 @@ def test_to_proto_fsim_gate_family(): cirq_google.FSimGateFamily(gates_to_accept=[cirq.CZ]), ) - spec = grid_device._create_device_specification_proto( - qubits=[cirq.GridQubit(0, 0)], pairs=(), gateset=gateset + device = grid_device.GridDevice.from_device_information( + qubit_pairs=[(cirq.GridQubit(0, 0), cirq.GridQubit(0, 1))], gateset=gateset ) - assert any(gate_spec.HasField('syc') for gate_spec in spec.valid_gates) - assert any(gate_spec.HasField('sqrt_iswap') for gate_spec in spec.valid_gates) - assert any(gate_spec.HasField('sqrt_iswap_inv') for gate_spec in spec.valid_gates) - assert any(gate_spec.HasField('cz') for gate_spec in spec.valid_gates) + assert gateset.gates.issubset(device.metadata.gateset.gates) + + +def test_from_device_information_empty(): + device = grid_device.GridDevice.from_device_information( + qubit_pairs=[], gateset=cirq.Gateset(), gate_durations=None + ) + + assert len(device.metadata.qubit_set) == 0 + assert len(device.metadata.qubit_pairs) == 0 + assert device.metadata.gateset == cirq.Gateset() + assert device.metadata.gate_durations is None + + +def test_to_proto(): + device_info, expected_spec = _create_device_spec_with_horizontal_couplings() + + # The set of gates in gate_durations are consistent with what's generated in + # _create_device_spec_with_horizontal_couplings() + base_duration = cirq.Duration(picos=1_000) + gate_durations = { + cirq.GateFamily(cirq_google.SYC): base_duration * 0, + cirq.GateFamily(cirq.SQRT_ISWAP): base_duration * 1, + cirq.GateFamily(cirq.SQRT_ISWAP_INV): base_duration * 2, + cirq.GateFamily(cirq.CZ): base_duration * 3, + cirq.GateFamily(cirq.ops.phased_x_z_gate.PhasedXZGate): base_duration * 4, + cirq.GateFamily( + cirq.ops.common_gates.ZPowGate, tags_to_ignore=[cirq_google.PhysicalZTag()] + ): base_duration + * 5, + cirq.GateFamily( + cirq.ops.common_gates.ZPowGate, tags_to_accept=[cirq_google.PhysicalZTag()] + ): base_duration + * 6, + cirq.GateFamily(cirq_google.experimental.ops.coupler_pulse.CouplerPulse): base_duration * 7, + cirq.GateFamily(cirq.ops.measurement_gate.MeasurementGate): base_duration * 8, + cirq.GateFamily(cirq.ops.wait_gate.WaitGate): base_duration * 9, + } + + spec = cirq_google.GridDevice.from_device_information( + qubit_pairs=device_info.qubit_pairs, + gateset=cirq.Gateset(*gate_durations.keys()), + gate_durations=gate_durations, + ).to_proto() + + assert cirq_google.GridDevice.from_proto(spec) == cirq_google.GridDevice.from_proto( + expected_spec + ) diff --git a/cirq-google/cirq_google/devices/known_devices.py b/cirq-google/cirq_google/devices/known_devices.py index c4f1b58b875..0de9c626b3b 100644 --- a/cirq-google/cirq_google/devices/known_devices.py +++ b/cirq-google/cirq_google/devices/known_devices.py @@ -57,7 +57,6 @@ def _create_grid_device_from_diagram( ascii_grid: str, gateset: cirq.Gateset, gate_durations: Optional[Dict['cirq.GateFamily', 'cirq.Duration']] = None, - out: Optional[device_pb2.DeviceSpecification] = None, ) -> grid_device.GridDevice: """Parse ASCIIart device layout into a GridDevice instance. @@ -80,10 +79,9 @@ def _create_grid_device_from_diagram( if neighbor > qubit and neighbor in qubit_set: pairs.append((qubit, cast(cirq.GridQubit, neighbor))) - device_specification = grid_device._create_device_specification_proto( - qubits=qubits, pairs=pairs, gateset=gateset, gate_durations=gate_durations, out=out + return grid_device.GridDevice.from_device_information( + qubit_pairs=pairs, gateset=gateset, gate_durations=gate_durations ) - return grid_device.GridDevice.from_proto(device_specification) def populate_qubits_in_device_proto( From 34c3fe87c2966b529a3ba362878a34cd5a375948 Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Fri, 12 May 2023 21:41:21 +0000 Subject: [PATCH 2/9] Fixed type and lint errors --- cirq-google/cirq_google/devices/grid_device.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/cirq-google/cirq_google/devices/grid_device.py b/cirq-google/cirq_google/devices/grid_device.py index b9831b8f74b..009f45c34d0 100644 --- a/cirq-google/cirq_google/devices/grid_device.py +++ b/cirq-google/cirq_google/devices/grid_device.py @@ -134,7 +134,7 @@ def __post_init__(self): def _in_or_equals(g: GateOrFamily, gate_family: cirq.GateFamily): - return g == gate_family or g in gate_family + return (isinstance(g, cirq.GateFamily) and g == gate_family) or g in gate_family def _validate_device_specification(proto: v2.device_pb2.DeviceSpecification) -> None: @@ -210,7 +210,8 @@ def _serialize_gateset_and_gate_durations( gate_duration_picos = int(gate_durations[gf].total_picos()) if is_duration_set and gate_duration_picos != gate_spec.gate_duration_picos: raise ValueError( - f'Multiple gate families in the following list exist in the gate duration dict, and they are expected to have the same duration value: {gate_rep.all_forms}' + 'Multiple gate families in the following list exist in the gate duration dict,' + f' and they are expected to have the same duration value: {gate_rep.all_forms}' ) is_duration_set = True gate_spec.gate_duration_picos = gate_duration_picos @@ -468,16 +469,18 @@ def to_proto( DeviceSpecification. """ qubits = self._metadata.qubit_set - pairs = [tuple(sorted(pair)) for pair in self._metadata.qubit_pairs] + pairs: List[Tuple[cirq.GridQubit, cirq.GridQubit]] = [ + tuple(sorted(pair)) for pair in self._metadata.qubit_pairs + ] gateset = self._metadata.gateset gate_durations = self._metadata.gate_durations if out is None: out = v2.device_pb2.DeviceSpecification() - # If fields are already filled (i.e. as part of the old DeviceSpecification format), leave them - # as is. Fields populated in the new format do not conflict with how they were populated in the - # old format. + # If fields are already filled (i.e. as part of the old DeviceSpecification format), leave + # them as is. Fields populated in the new format do not conflict with how they were + # populated in the old format. # TODO(#5050) remove empty checks below once deprecated fields in DeviceSpecification are # removed. From 4837660ed7d9aa2d415bd632c381abd5926d1e85 Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Sat, 13 May 2023 00:03:47 +0000 Subject: [PATCH 3/9] Fixed more mypy errors --- cirq-google/cirq_google/devices/grid_device.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/cirq-google/cirq_google/devices/grid_device.py b/cirq-google/cirq_google/devices/grid_device.py index 009f45c34d0..6c69083d26f 100644 --- a/cirq-google/cirq_google/devices/grid_device.py +++ b/cirq-google/cirq_google/devices/grid_device.py @@ -134,7 +134,12 @@ def __post_init__(self): def _in_or_equals(g: GateOrFamily, gate_family: cirq.GateFamily): - return (isinstance(g, cirq.GateFamily) and g == gate_family) or g in gate_family + if isinstance(g, cirq.GateFamily): + return g == gate_family + elif isinstance(g, cirq.Gate): + return g in gate_family + else: # Gate type + return cirq.GateFamily(g) == gate_family def _validate_device_specification(proto: v2.device_pb2.DeviceSpecification) -> None: @@ -469,8 +474,10 @@ def to_proto( DeviceSpecification. """ qubits = self._metadata.qubit_set - pairs: List[Tuple[cirq.GridQubit, cirq.GridQubit]] = [ - tuple(sorted(pair)) for pair in self._metadata.qubit_pairs + unordered_pairs = [tuple(pair_set) for pair_set in self._metadata.qubit_pairs] + pairs = [ + (pair[0], pair[1]) if pair[0] <= pair[1] else (pair[1], pair[0]) + for pair in unordered_pairs ] gateset = self._metadata.gateset gate_durations = self._metadata.gate_durations From d5561f2aa95eaa16cd112810eb187ed1d97c2150 Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Wed, 17 May 2023 01:02:40 +0000 Subject: [PATCH 4/9] Addressed nits and small fixes --- .../cirq_google/devices/grid_device.py | 62 ++++++++++--------- 1 file changed, 33 insertions(+), 29 deletions(-) diff --git a/cirq-google/cirq_google/devices/grid_device.py b/cirq-google/cirq_google/devices/grid_device.py index 6c69083d26f..a8a17f1d099 100644 --- a/cirq-google/cirq_google/devices/grid_device.py +++ b/cirq-google/cirq_google/devices/grid_device.py @@ -73,15 +73,18 @@ @dataclass class _GateRepresentations: - """Contains equivalent representations of a gate in both DeviceSpecification and GridDevice.""" + """Contains equivalent representations of a gate in both DeviceSpecification and GridDevice. + + Attributes: + gate_spec_name: The name of gate type in `GateSpecification`. + primary forms: Gate representations to be included in generated gatesets and gate durations. + additional_forms: GateFamilies which match all other valid gate representations. + all_forms: `primary_forms` (as GateFamilies) + `additional_forms` + """ - """The name of gate type in `GateSpecification`.""" gate_spec_name: str - """Gate representations to be included in generated gatesets and gate durations.""" primary_forms: List[GateOrFamily] - """GateFamilies which match all other valid gate representations.""" additional_forms: List[cirq.GateFamily] = field(default_factory=list) - """`primary_forms` (as GateFamilies) + `additional_forms`""" all_forms: List[cirq.GateFamily] = field(init=False) def __post_init__(self): @@ -190,11 +193,11 @@ def _validate_device_specification(proto: v2.device_pb2.DeviceSpecification) -> def _serialize_gateset_and_gate_durations( out: v2.device_pb2.DeviceSpecification, gateset: cirq.Gateset, - gate_durations: Mapping['cirq.GateFamily', 'cirq.Duration'], + gate_durations: Mapping[cirq.GateFamily, cirq.Duration], ) -> v2.device_pb2.DeviceSpecification: """Serializes the given gateset and gate durations to DeviceSpecification.""" - gate_specs = {} + gate_specs: Dict[str, v2.device_pb2.GateSpecification] = {} for gate_family in gateset.gates: gate_spec = v2.device_pb2.GateSpecification() gate_rep = next( @@ -202,32 +205,31 @@ def _serialize_gateset_and_gate_durations( ) if gate_rep is None: raise ValueError(f'Unrecognized gate: {gate_family}.') + gate_name = gate_rep.gate_spec_name # Set gate - getattr(gate_spec, gate_rep.gate_spec_name).SetInParent() + getattr(gate_spec, gate_name).SetInParent() # Set gate duration - is_duration_set = False - for gf in gate_rep.all_forms: - if gf not in gate_durations: - continue - - gate_duration_picos = int(gate_durations[gf].total_picos()) - if is_duration_set and gate_duration_picos != gate_spec.gate_duration_picos: - raise ValueError( - 'Multiple gate families in the following list exist in the gate duration dict,' - f' and they are expected to have the same duration value: {gate_rep.all_forms}' - ) - is_duration_set = True - gate_spec.gate_duration_picos = gate_duration_picos + gate_durations_picos = { + int(gate_durations[gf].total_picos()) + for gf in gate_rep.all_forms + if gf in gate_durations + } + if len(gate_durations_picos) > 1: + raise ValueError( + 'Multiple gate families in the following list exist in the gate duration dict,' + f' and they are expected to have the same duration value: {gate_rep.all_forms}' + ) + elif len(gate_durations_picos) == 1: + gate_spec.gate_duration_picos = gate_durations_picos.pop() # GateSpecification dedup. Multiple gates or GateFamilies in the gateset could map to the # same GateSpecification. - gate_name = gate_spec.WhichOneof('gate') gate_specs[gate_name] = gate_spec # Sort by gate name to keep valid_gates stable. - out.valid_gates.extend([v for _, v in sorted(gate_specs.items(), key=lambda item: item[0])]) + out.valid_gates.extend(v for _, v in sorted(gate_specs.items())) return out @@ -475,10 +477,7 @@ def to_proto( """ qubits = self._metadata.qubit_set unordered_pairs = [tuple(pair_set) for pair_set in self._metadata.qubit_pairs] - pairs = [ - (pair[0], pair[1]) if pair[0] <= pair[1] else (pair[1], pair[0]) - for pair in unordered_pairs - ] + pairs = sorted((q0, q1) if q0 <= q1 else (q1, q0) for q0, q1 in unordered_pairs) gateset = self._metadata.gateset gate_durations = self._metadata.gate_durations @@ -491,9 +490,9 @@ def to_proto( # TODO(#5050) remove empty checks below once deprecated fields in DeviceSpecification are # removed. - if len(out.valid_qubits) == 0: + if not out.valid_qubits: known_devices.populate_qubits_in_device_proto(qubits, out) - if len(out.valid_targets) == 0: + if not out.valid_targets: known_devices.populate_qubit_pairs_in_device_proto(pairs, out) _serialize_gateset_and_gate_durations( out, gateset, {} if gate_durations is None else gate_durations @@ -556,6 +555,11 @@ def from_device_information( qubit_pairs=qubit_pairs, gateset=gateset, gate_durations=gate_durations ) incomplete_device = GridDevice(metadata) + # incomplete_device may have incomplete gateset and gate durations information, as described + # in the docstring. + # To generate the full gateset and gate durations, we rely on the device deserialization + # logic by first serializing then deserializing the fake device, to ensure that the + # resulting device is consistent with one that is deserialized from DeviceSpecification. return GridDevice.from_proto(incomplete_device.to_proto()) @property From 6ca5cd6d6a20b8e2856f5b85d949d965270f284b Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Wed, 17 May 2023 19:09:57 +0000 Subject: [PATCH 5/9] Fix coverage checks; made from_device_information experimental --- .../cirq_google/devices/grid_device.py | 4 +++- .../cirq_google/devices/grid_device_test.py | 23 +++++++++++++++---- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/cirq-google/cirq_google/devices/grid_device.py b/cirq-google/cirq_google/devices/grid_device.py index a8a17f1d099..6212431067b 100644 --- a/cirq-google/cirq_google/devices/grid_device.py +++ b/cirq-google/cirq_google/devices/grid_device.py @@ -502,7 +502,7 @@ def to_proto( return out @classmethod - def from_device_information( + def _from_device_information( cls, *, qubit_pairs: Collection[Tuple[cirq.GridQubit, cirq.GridQubit]], @@ -511,6 +511,8 @@ def from_device_information( ) -> 'GridDevice': """Constructs a GridDevice using the device information provided. + EXPERIMENTAL: this method may have changes which are not backward compatible in the future. + This is a convenience method for constructing a GridDevice given partial gateset and gate_duration information: for every distinct gate, only one representation needs to be in gateset and gate_duration. The remaining representations will be automatically generated. diff --git a/cirq-google/cirq_google/devices/grid_device_test.py b/cirq-google/cirq_google/devices/grid_device_test.py index 00b6f3f82d6..ebb900ac813 100644 --- a/cirq-google/cirq_google/devices/grid_device_test.py +++ b/cirq-google/cirq_google/devices/grid_device_test.py @@ -442,8 +442,25 @@ def test_grid_device_repr_pretty(cycle, func): def test_device_from_device_information_equals_device_from_proto(): device_info, spec = _create_device_spec_with_horizontal_couplings() - # The set of gates in gate_durations are consistent with what's generated in + # The set of gates in gateset and gate durations are consistent with what's generated in # _create_device_spec_with_horizontal_couplings() + gateset = cirq.Gateset( + cirq_google.SYC, + cirq.SQRT_ISWAP, + cirq.SQRT_ISWAP_INV, + cirq.CZ, + cirq.ops.phased_x_z_gate.PhasedXZGate, + cirq.GateFamily( + cirq.ops.common_gates.ZPowGate, tags_to_ignore=[cirq_google.PhysicalZTag()] + ), + cirq.GateFamily( + cirq.ops.common_gates.ZPowGate, tags_to_accept=[cirq_google.PhysicalZTag()] + ), + cirq_google.experimental.ops.coupler_pulse.CouplerPulse, + cirq.ops.measurement_gate.MeasurementGate, + cirq.ops.wait_gate.WaitGate, + ) + base_duration = cirq.Duration(picos=1_000) gate_durations = { cirq.GateFamily(cirq_google.SYC): base_duration * 0, @@ -465,9 +482,7 @@ def test_device_from_device_information_equals_device_from_proto(): } device_from_information = cirq_google.GridDevice.from_device_information( - qubit_pairs=device_info.qubit_pairs, - gateset=cirq.Gateset(*gate_durations.keys()), - gate_durations=gate_durations, + qubit_pairs=device_info.qubit_pairs, gateset=gateset, gate_durations=gate_durations ) assert device_from_information == cirq_google.GridDevice.from_proto(spec) From b84e74c075fd55b16999ed0131da0024d16aafc7 Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Wed, 17 May 2023 19:24:14 +0000 Subject: [PATCH 6/9] Renamed _GateRepresentation fields to clarify purpose --- .../cirq_google/devices/grid_device.py | 77 +++++++++++-------- 1 file changed, 45 insertions(+), 32 deletions(-) diff --git a/cirq-google/cirq_google/devices/grid_device.py b/cirq-google/cirq_google/devices/grid_device.py index 6212431067b..19b5eba17a4 100644 --- a/cirq-google/cirq_google/devices/grid_device.py +++ b/cirq-google/cirq_google/devices/grid_device.py @@ -30,7 +30,7 @@ ) import re import warnings -from dataclasses import dataclass, field +from dataclasses import dataclass import cirq from cirq_google import ops @@ -77,62 +77,74 @@ class _GateRepresentations: Attributes: gate_spec_name: The name of gate type in `GateSpecification`. - primary forms: Gate representations to be included in generated gatesets and gate durations. - additional_forms: GateFamilies which match all other valid gate representations. - all_forms: `primary_forms` (as GateFamilies) + `additional_forms` + deserialized_forms: Gate representations to be included when the corresponding + `GateSpecification` gate type is deserialized into gatesets and gate durations. + serializable_forms: GateFamilies used to check whether a given gate can be serialized to the + gate type in this _GateRepresentation. """ gate_spec_name: str - primary_forms: List[GateOrFamily] - additional_forms: List[cirq.GateFamily] = field(default_factory=list) - all_forms: List[cirq.GateFamily] = field(init=False) - - def __post_init__(self): - self.all_forms = [ - gof if isinstance(gof, cirq.GateFamily) else cirq.GateFamily(gof) - for gof in self.primary_forms - ] + self.additional_forms + deserialized_forms: List[GateOrFamily] + serializable_forms: List[cirq.GateFamily] """Valid gates for a GridDevice.""" _GATES: List[_GateRepresentations] = [ _GateRepresentations( gate_spec_name='syc', - primary_forms=[_SYC_FSIM_GATE_FAMILY], - additional_forms=[cirq.GateFamily(ops.SYC)], + deserialized_forms=[_SYC_FSIM_GATE_FAMILY], + serializable_forms=[_SYC_FSIM_GATE_FAMILY, cirq.GateFamily(ops.SYC)], ), _GateRepresentations( gate_spec_name='sqrt_iswap', - primary_forms=[_SQRT_ISWAP_FSIM_GATE_FAMILY], - additional_forms=[cirq.GateFamily(cirq.SQRT_ISWAP)], + deserialized_forms=[_SQRT_ISWAP_FSIM_GATE_FAMILY], + serializable_forms=[_SQRT_ISWAP_FSIM_GATE_FAMILY, cirq.GateFamily(cirq.SQRT_ISWAP)], ), _GateRepresentations( gate_spec_name='sqrt_iswap_inv', - primary_forms=[_SQRT_ISWAP_INV_FSIM_GATE_FAMILY], - additional_forms=[cirq.GateFamily(cirq.SQRT_ISWAP_INV)], + deserialized_forms=[_SQRT_ISWAP_INV_FSIM_GATE_FAMILY], + serializable_forms=[_SQRT_ISWAP_INV_FSIM_GATE_FAMILY, cirq.GateFamily(cirq.SQRT_ISWAP_INV)], ), _GateRepresentations( gate_spec_name='cz', - primary_forms=[_CZ_FSIM_GATE_FAMILY], - additional_forms=[cirq.GateFamily(cirq.CZ)], + deserialized_forms=[_CZ_FSIM_GATE_FAMILY], + serializable_forms=[_CZ_FSIM_GATE_FAMILY, cirq.GateFamily(cirq.CZ)], ), _GateRepresentations( gate_spec_name='phased_xz', - primary_forms=[cirq.PhasedXZGate, cirq.XPowGate, cirq.YPowGate, cirq.PhasedXPowGate], + deserialized_forms=[cirq.PhasedXZGate, cirq.XPowGate, cirq.YPowGate, cirq.PhasedXPowGate], + serializable_forms=[ + cirq.GateFamily(cirq.PhasedXZGate), + cirq.GateFamily(cirq.XPowGate), + cirq.GateFamily(cirq.YPowGate), + cirq.GateFamily(cirq.PhasedXPowGate), + ], ), _GateRepresentations( gate_spec_name='virtual_zpow', - primary_forms=[cirq.GateFamily(cirq.ZPowGate, tags_to_ignore=[ops.PhysicalZTag()])], + deserialized_forms=[cirq.GateFamily(cirq.ZPowGate, tags_to_ignore=[ops.PhysicalZTag()])], + serializable_forms=[cirq.GateFamily(cirq.ZPowGate, tags_to_ignore=[ops.PhysicalZTag()])], ), _GateRepresentations( gate_spec_name='physical_zpow', - primary_forms=[cirq.GateFamily(cirq.ZPowGate, tags_to_accept=[ops.PhysicalZTag()])], + deserialized_forms=[cirq.GateFamily(cirq.ZPowGate, tags_to_accept=[ops.PhysicalZTag()])], + serializable_forms=[cirq.GateFamily(cirq.ZPowGate, tags_to_accept=[ops.PhysicalZTag()])], + ), + _GateRepresentations( + gate_spec_name='coupler_pulse', + deserialized_forms=[experimental_ops.CouplerPulse], + serializable_forms=[cirq.GateFamily(experimental_ops.CouplerPulse)], + ), + _GateRepresentations( + gate_spec_name='meas', + deserialized_forms=[cirq.MeasurementGate], + serializable_forms=[cirq.GateFamily(cirq.MeasurementGate)], ), _GateRepresentations( - gate_spec_name='coupler_pulse', primary_forms=[experimental_ops.CouplerPulse] + gate_spec_name='wait', + deserialized_forms=[cirq.WaitGate], + serializable_forms=[cirq.GateFamily(cirq.WaitGate)], ), - _GateRepresentations(gate_spec_name='meas', primary_forms=[cirq.MeasurementGate]), - _GateRepresentations(gate_spec_name='wait', primary_forms=[cirq.WaitGate]), ] @@ -201,7 +213,8 @@ def _serialize_gateset_and_gate_durations( for gate_family in gateset.gates: gate_spec = v2.device_pb2.GateSpecification() gate_rep = next( - (gr for gr in _GATES for gf in gr.all_forms if _in_or_equals(gate_family, gf)), None + (gr for gr in _GATES for gf in gr.serializable_forms if _in_or_equals(gate_family, gf)), + None, ) if gate_rep is None: raise ValueError(f'Unrecognized gate: {gate_family}.') @@ -213,13 +226,13 @@ def _serialize_gateset_and_gate_durations( # Set gate duration gate_durations_picos = { int(gate_durations[gf].total_picos()) - for gf in gate_rep.all_forms + for gf in gate_rep.serializable_forms if gf in gate_durations } if len(gate_durations_picos) > 1: raise ValueError( 'Multiple gate families in the following list exist in the gate duration dict,' - f' and they are expected to have the same duration value: {gate_rep.all_forms}' + f' and they are expected to have the same duration value: {gate_rep.serializable_forms}' ) elif len(gate_durations_picos) == 1: gate_spec.gate_duration_picos = gate_durations_picos.pop() @@ -255,8 +268,8 @@ def _deserialize_gateset_and_gate_durations( ) continue - gates_list.extend(gate_rep.primary_forms) - for g in gate_rep.primary_forms: + gates_list.extend(gate_rep.deserialized_forms) + for g in gate_rep.deserialized_forms: if not isinstance(g, cirq.GateFamily): g = cirq.GateFamily(g) gate_durations[g] = cirq.Duration(picos=gate_spec.gate_duration_picos) From adb6b61232b13ba86fdb3e72fc1f6262df84e664 Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Wed, 17 May 2023 21:39:09 +0000 Subject: [PATCH 7/9] Updated _from_device_information() callsites --- cirq-google/cirq_google/devices/grid_device.py | 2 +- cirq-google/cirq_google/devices/grid_device_test.py | 10 +++++----- cirq-google/cirq_google/devices/known_devices.py | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/cirq-google/cirq_google/devices/grid_device.py b/cirq-google/cirq_google/devices/grid_device.py index 19b5eba17a4..cb58008b675 100644 --- a/cirq-google/cirq_google/devices/grid_device.py +++ b/cirq-google/cirq_google/devices/grid_device.py @@ -426,7 +426,7 @@ def __init__(self, metadata: cirq.GridDeviceMetadata): """Creates a GridDevice object. This constructor should not be used directly outside the class implementation. Use - `from_proto()` or `from_device_information()` instead. + `from_proto()` instead. """ self._metadata = metadata diff --git a/cirq-google/cirq_google/devices/grid_device_test.py b/cirq-google/cirq_google/devices/grid_device_test.py index ebb900ac813..a7307ef2dbc 100644 --- a/cirq-google/cirq_google/devices/grid_device_test.py +++ b/cirq-google/cirq_google/devices/grid_device_test.py @@ -481,7 +481,7 @@ def test_device_from_device_information_equals_device_from_proto(): cirq.GateFamily(cirq.ops.wait_gate.WaitGate): base_duration * 9, } - device_from_information = cirq_google.GridDevice.from_device_information( + device_from_information = cirq_google.GridDevice._from_device_information( qubit_pairs=device_info.qubit_pairs, gateset=gateset, gate_durations=gate_durations ) @@ -522,7 +522,7 @@ def test_device_from_device_information_equals_device_from_proto(): ) def test_from_device_information_invalid_input(error_match, qubit_pairs, gateset, gate_durations): with pytest.raises(ValueError, match=error_match): - grid_device.GridDevice.from_device_information( + grid_device.GridDevice._from_device_information( qubit_pairs=qubit_pairs, gateset=gateset, gate_durations=gate_durations ) @@ -537,7 +537,7 @@ def test_from_device_information_fsim_gate_family(): cirq_google.FSimGateFamily(gates_to_accept=[cirq.CZ]), ) - device = grid_device.GridDevice.from_device_information( + device = grid_device.GridDevice._from_device_information( qubit_pairs=[(cirq.GridQubit(0, 0), cirq.GridQubit(0, 1))], gateset=gateset ) @@ -545,7 +545,7 @@ def test_from_device_information_fsim_gate_family(): def test_from_device_information_empty(): - device = grid_device.GridDevice.from_device_information( + device = grid_device.GridDevice._from_device_information( qubit_pairs=[], gateset=cirq.Gateset(), gate_durations=None ) @@ -580,7 +580,7 @@ def test_to_proto(): cirq.GateFamily(cirq.ops.wait_gate.WaitGate): base_duration * 9, } - spec = cirq_google.GridDevice.from_device_information( + spec = cirq_google.GridDevice._from_device_information( qubit_pairs=device_info.qubit_pairs, gateset=cirq.Gateset(*gate_durations.keys()), gate_durations=gate_durations, diff --git a/cirq-google/cirq_google/devices/known_devices.py b/cirq-google/cirq_google/devices/known_devices.py index 0de9c626b3b..9ffe4d01336 100644 --- a/cirq-google/cirq_google/devices/known_devices.py +++ b/cirq-google/cirq_google/devices/known_devices.py @@ -79,7 +79,7 @@ def _create_grid_device_from_diagram( if neighbor > qubit and neighbor in qubit_set: pairs.append((qubit, cast(cirq.GridQubit, neighbor))) - return grid_device.GridDevice.from_device_information( + return grid_device.GridDevice._from_device_information( qubit_pairs=pairs, gateset=gateset, gate_durations=gate_durations ) From ca35e7dd8839414e372322bd666c4797915e00ec Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Wed, 17 May 2023 21:54:04 +0000 Subject: [PATCH 8/9] Fix lint error --- cirq-google/cirq_google/devices/grid_device.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cirq-google/cirq_google/devices/grid_device.py b/cirq-google/cirq_google/devices/grid_device.py index cb58008b675..ec2f75030a5 100644 --- a/cirq-google/cirq_google/devices/grid_device.py +++ b/cirq-google/cirq_google/devices/grid_device.py @@ -231,8 +231,8 @@ def _serialize_gateset_and_gate_durations( } if len(gate_durations_picos) > 1: raise ValueError( - 'Multiple gate families in the following list exist in the gate duration dict,' - f' and they are expected to have the same duration value: {gate_rep.serializable_forms}' + 'Multiple gate families in the following list exist in the gate duration dict, and ' + f'they are expected to have the same duration value: {gate_rep.serializable_forms}' ) elif len(gate_durations_picos) == 1: gate_spec.gate_duration_picos = gate_durations_picos.pop() From 6a46d7c1aec8b1761ece73fb52fcdb4b3132e5d4 Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Wed, 17 May 2023 23:43:54 +0000 Subject: [PATCH 9/9] Remove unnecessary _in_or_equals() check. Gateset.gates is a collection of GateFamilies, so most branches in the check don't apply. --- cirq-google/cirq_google/devices/grid_device.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/cirq-google/cirq_google/devices/grid_device.py b/cirq-google/cirq_google/devices/grid_device.py index ec2f75030a5..73f65484f33 100644 --- a/cirq-google/cirq_google/devices/grid_device.py +++ b/cirq-google/cirq_google/devices/grid_device.py @@ -148,15 +148,6 @@ class _GateRepresentations: ] -def _in_or_equals(g: GateOrFamily, gate_family: cirq.GateFamily): - if isinstance(g, cirq.GateFamily): - return g == gate_family - elif isinstance(g, cirq.Gate): - return g in gate_family - else: # Gate type - return cirq.GateFamily(g) == gate_family - - def _validate_device_specification(proto: v2.device_pb2.DeviceSpecification) -> None: """Raises a ValueError if the `DeviceSpecification` proto is invalid.""" @@ -213,8 +204,7 @@ def _serialize_gateset_and_gate_durations( for gate_family in gateset.gates: gate_spec = v2.device_pb2.GateSpecification() gate_rep = next( - (gr for gr in _GATES for gf in gr.serializable_forms if _in_or_equals(gate_family, gf)), - None, + (gr for gr in _GATES for gf in gr.serializable_forms if gf == gate_family), None ) if gate_rep is None: raise ValueError(f'Unrecognized gate: {gate_family}.')