diff --git a/cirq-core/cirq/devices/grid_device_metadata.py b/cirq-core/cirq/devices/grid_device_metadata.py index bb18951f0c2..f0740526cad 100644 --- a/cirq-core/cirq/devices/grid_device_metadata.py +++ b/cirq-core/cirq/devices/grid_device_metadata.py @@ -13,7 +13,7 @@ # limitations under the License. """Metadata subtype for 2D Homogenous devices.""" -from typing import TYPE_CHECKING, Optional, FrozenSet, Iterable, Tuple, Dict +from typing import TYPE_CHECKING, cast, Optional, FrozenSet, Iterable, Tuple, Dict import networkx as nx from cirq import value @@ -29,21 +29,21 @@ class GridDeviceMetadata(device.DeviceMetadata): def __init__( self, - qubit_pairs: Iterable[Tuple['cirq.Qid', 'cirq.Qid']], + qubit_pairs: Iterable[Tuple['cirq.GridQubit', 'cirq.GridQubit']], gateset: 'cirq.Gateset', gate_durations: Optional[Dict['cirq.GateFamily', 'cirq.Duration']] = None, - all_qubits: Optional[Iterable['cirq.Qid']] = None, + all_qubits: Optional[Iterable['cirq.GridQubit']] = None, compilation_target_gatesets: Iterable['cirq.CompilationTargetGateset'] = (), ): """Create a GridDeviceMetadata object. - Create a GridDevice which has a well defined set of couplable + Create a grid device which has a well defined set of couplable qubit pairs that have the same two qubit gates available in both coupling directions. Gate times (if provided) are expected to be uniform across all qubits on the device. Args: - qubit_pairs: Iterable of pairs of `cirq.Qid`s representing + qubit_pairs: Iterable of pairs of `cirq.GridQubit`s representing bi-directional couplings. gateset: `cirq.Gateset` indicating gates supported everywhere on the device. @@ -114,7 +114,16 @@ def __init__( self._gate_durations = gate_durations @property - def qubit_pairs(self) -> FrozenSet[FrozenSet['cirq.Qid']]: + def qubit_set(self) -> FrozenSet['cirq.GridQubit']: + """Returns the set of grid qubits on the device. + + Returns: + Frozenset of qubits on device. + """ + return cast(FrozenSet['cirq.GridQubit'], super().qubit_set) + + @property + def qubit_pairs(self) -> FrozenSet[FrozenSet['cirq.GridQubit']]: """Returns the set of all couple-able qubits on the device. Each element in the outer frozenset is a 2-element frozenset representing a bidirectional @@ -123,7 +132,7 @@ def qubit_pairs(self) -> FrozenSet[FrozenSet['cirq.Qid']]: return self._qubit_pairs @property - def isolated_qubits(self) -> FrozenSet['cirq.Qid']: + def isolated_qubits(self) -> FrozenSet['cirq.GridQubit']: """Returns the set of all isolated qubits on the device (if appliable).""" return self._isolated_qubits diff --git a/cirq-google/cirq_google/devices/serializable_device.py b/cirq-google/cirq_google/devices/serializable_device.py index 0e420d0cf44..e30ef162dc7 100644 --- a/cirq-google/cirq_google/devices/serializable_device.py +++ b/cirq-google/cirq_google/devices/serializable_device.py @@ -101,15 +101,22 @@ def __init__( self.qubits = qubits self.gate_definitions = gate_definitions has_subcircuit_support: bool = cirq.FrozenCircuit in gate_definitions + self._metadata = cirq.GridDeviceMetadata( - qubit_pairs=[ - (pair[0], pair[1]) - for gate_defs in gate_definitions.values() - for gate_def in gate_defs - if gate_def.number_of_qubits == 2 - for pair in gate_def.target_set - if len(pair) == 2 and pair[0] < pair[1] - ], + qubit_pairs=cast( + List[Tuple[cirq.GridQubit, cirq.GridQubit]], + [ + (pair[0], pair[1]) + for gate_defs in gate_definitions.values() + for gate_def in gate_defs + if gate_def.number_of_qubits == 2 + for pair in gate_def.target_set + if len(pair) == 2 + and pair[0] < pair[1] + and isinstance(pair[0], cirq.GridQubit) + and isinstance(pair[1], cirq.GridQubit) + ], + ), gateset=cirq.Gateset( *(g for g in gate_definitions.keys() if issubclass(g, cirq.Gate)), cirq.GlobalPhaseGate, diff --git a/cirq-google/cirq_google/line/placement/anneal.py b/cirq-google/cirq_google/line/placement/anneal.py index 75b5ceefa6c..78591b5f5d8 100644 --- a/cirq-google/cirq_google/line/placement/anneal.py +++ b/cirq-google/cirq_google/line/placement/anneal.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import cast, Callable, List, Optional, Tuple, Set, Any, TYPE_CHECKING +from typing import Callable, List, Optional, Tuple, Set, Any, TYPE_CHECKING import numpy as np @@ -37,7 +37,7 @@ def __init__(self, device: 'cirq_google.GridDevice', seed=None) -> None: device: Chip description. seed: Optional seed value for random number generator. """ - self._c = cast(Set[cirq.GridQubit], device.metadata.qubit_set) + self._c = device.metadata.qubit_set self._c_adj = chip_as_adjacency_list(device) self._rand = np.random.RandomState(seed) diff --git a/cirq-google/cirq_google/line/placement/chip.py b/cirq-google/cirq_google/line/placement/chip.py index 8bf18d911d6..90d8755ff21 100644 --- a/cirq-google/cirq_google/line/placement/chip.py +++ b/cirq-google/cirq_google/line/placement/chip.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import cast, Dict, List, Set, Tuple, TYPE_CHECKING +from typing import Dict, List, Tuple, TYPE_CHECKING import cirq @@ -86,7 +86,7 @@ def chip_as_adjacency_list( Map from nodes to list of qubits which represent all the neighbours of given qubit. """ - c_set = cast(Set[cirq.GridQubit], device.metadata.qubit_set) + c_set = device.metadata.qubit_set c_adj: Dict[cirq.GridQubit, List[cirq.GridQubit]] = {} for n in c_set: c_adj[n] = [] diff --git a/cirq-google/cirq_google/line/placement/greedy.py b/cirq-google/cirq_google/line/placement/greedy.py index cbd96df5c2c..af248cab73d 100644 --- a/cirq-google/cirq_google/line/placement/greedy.py +++ b/cirq-google/cirq_google/line/placement/greedy.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import cast, Dict, List, Optional, Set, TYPE_CHECKING +from typing import Dict, List, Optional, Set, TYPE_CHECKING import abc import collections @@ -304,7 +304,7 @@ def place_line(self, device: 'cirq_google.GridDevice', length: int) -> GridQubit if not device.metadata.qubit_set: return GridQubitLineTuple() - start: GridQubit = cast(GridQubit, min(device.metadata.qubit_set)) + start: GridQubit = min(device.metadata.qubit_set) sequences: List[LineSequence] = [] greedy_search: Dict[str, List[GreedySequenceSearch]] = { 'minimal_connectivity': [_PickFewestNeighbors(device, start)],