Skip to content

Commit 26dbabc

Browse files
authored
Speed up hashing for GridQubit, LineQubit, and NamedQubit (#6350)
Review: @dstrain115
1 parent 3c81961 commit 26dbabc

File tree

4 files changed

+172
-99
lines changed

4 files changed

+172
-99
lines changed

cirq-core/cirq/devices/grid_qubit.py

+49-38
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
import numpy as np
2121

22-
from cirq import _compat, ops, protocols
22+
from cirq import ops, protocols
2323

2424
if TYPE_CHECKING:
2525
import cirq
@@ -29,9 +29,43 @@
2929
class _BaseGridQid(ops.Qid):
3030
"""The Base class for `GridQid` and `GridQubit`."""
3131

32-
def __init__(self, row: int, col: int):
33-
self._row = row
34-
self._col = col
32+
_row: int
33+
_col: int
34+
_dimension: int
35+
_hash: Optional[int] = None
36+
37+
def __getstate__(self):
38+
# Don't save hash when pickling; see #3777.
39+
state = self.__dict__
40+
if "_hash" in state:
41+
state = state.copy()
42+
del state["_hash"]
43+
return state
44+
45+
def __hash__(self) -> int:
46+
if self._hash is None:
47+
self._hash = hash((self._row, self._col, self._dimension))
48+
return self._hash
49+
50+
def __eq__(self, other):
51+
# Explicitly implemented for performance (vs delegating to Qid).
52+
if isinstance(other, _BaseGridQid):
53+
return (
54+
self._row == other._row
55+
and self._col == other._col
56+
and self._dimension == other._dimension
57+
)
58+
return NotImplemented
59+
60+
def __ne__(self, other):
61+
# Explicitly implemented for performance (vs delegating to Qid).
62+
if isinstance(other, _BaseGridQid):
63+
return (
64+
self._row != other._row
65+
or self._col != other._col
66+
or self._dimension != other._dimension
67+
)
68+
return NotImplemented
3569

3670
def _comparison_key(self):
3771
return self._row, self._col
@@ -44,6 +78,10 @@ def row(self) -> int:
4478
def col(self) -> int:
4579
return self._col
4680

81+
@property
82+
def dimension(self) -> int:
83+
return self._dimension
84+
4785
def with_dimension(self, dimension: int) -> 'GridQid':
4886
return GridQid(self._row, self._col, dimension=dimension)
4987

@@ -149,13 +187,10 @@ def __init__(self, row: int, col: int, *, dimension: int) -> None:
149187
dimension: The dimension of the qid's Hilbert space, i.e.
150188
the number of quantum levels.
151189
"""
152-
super().__init__(row, col)
153-
self._dimension = dimension
154190
self.validate_dimension(dimension)
155-
156-
@property
157-
def dimension(self):
158-
return self._dimension
191+
self._row = row
192+
self._col = col
193+
self._dimension = dimension
159194

160195
def _with_row_col(self, row: int, col: int) -> 'GridQid':
161196
return GridQid(row, col, dimension=self.dimension)
@@ -288,35 +323,11 @@ class GridQubit(_BaseGridQid):
288323
cirq.GridQubit(5, 4)
289324
"""
290325

291-
def __getstate__(self):
292-
# Don't save hash when pickling; see #3777.
293-
state = self.__dict__
294-
hash_key = _compat._method_cache_name(self.__hash__)
295-
if hash_key in state:
296-
state = state.copy()
297-
del state[hash_key]
298-
return state
299-
300-
@_compat.cached_method
301-
def __hash__(self) -> int:
302-
# Explicitly cached for performance (vs delegating to Qid).
303-
return super().__hash__()
326+
_dimension = 2
304327

305-
def __eq__(self, other):
306-
# Explicitly implemented for performance (vs delegating to Qid).
307-
if isinstance(other, GridQubit):
308-
return self._row == other._row and self._col == other._col
309-
return NotImplemented
310-
311-
def __ne__(self, other):
312-
# Explicitly implemented for performance (vs delegating to Qid).
313-
if isinstance(other, GridQubit):
314-
return self._row != other._row or self._col != other._col
315-
return NotImplemented
316-
317-
@property
318-
def dimension(self) -> int:
319-
return 2
328+
def __init__(self, row: int, col: int) -> None:
329+
self._row = row
330+
self._col = col
320331

321332
def _with_row_col(self, row: int, col: int):
322333
return GridQubit(row, col)

cirq-core/cirq/devices/grid_qubit_test.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import pytest
2020

2121
import cirq
22-
from cirq import _compat
2322

2423

2524
def test_init():
@@ -45,8 +44,7 @@ def test_pickled_hash():
4544
q = cirq.GridQubit(3, 4)
4645
q_bad = cirq.GridQubit(3, 4)
4746
_ = hash(q_bad) # compute hash to ensure it is cached.
48-
hash_key = _compat._method_cache_name(cirq.GridQubit.__hash__)
49-
setattr(q_bad, hash_key, getattr(q_bad, hash_key) + 1)
47+
q_bad._hash = q_bad._hash + 1
5048
assert q_bad == q
5149
assert hash(q_bad) != hash(q)
5250
data = pickle.dumps(q_bad)

cirq-core/cirq/devices/line_qubit.py

+67-40
Original file line numberDiff line numberDiff line change
@@ -27,19 +27,48 @@
2727
class _BaseLineQid(ops.Qid):
2828
"""The base class for `LineQid` and `LineQubit`."""
2929

30-
def __init__(self, x: int) -> None:
31-
"""Initializes a line qubit at the given x coordinate."""
32-
self._x = x
30+
_x: int
31+
_dimension: int
32+
_hash: Optional[int] = None
33+
34+
def __getstate__(self):
35+
# Don't save hash when pickling; see #3777.
36+
state = self.__dict__
37+
if "_hash" in state:
38+
state = state.copy()
39+
del state["_hash"]
40+
return state
41+
42+
def __hash__(self) -> int:
43+
if self._hash is None:
44+
self._hash = hash((self._x, self._dimension))
45+
return self._hash
46+
47+
def __eq__(self, other):
48+
# Explicitly implemented for performance (vs delegating to Qid).
49+
if isinstance(other, _BaseLineQid):
50+
return self._x == other._x and self._dimension == other._dimension
51+
return NotImplemented
52+
53+
def __ne__(self, other):
54+
# Explicitly implemented for performance (vs delegating to Qid).
55+
if isinstance(other, _BaseLineQid):
56+
return self._x != other._x or self._dimension != other._dimension
57+
return NotImplemented
3358

3459
def _comparison_key(self):
35-
return self.x
60+
return self._x
3661

3762
@property
3863
def x(self) -> int:
3964
return self._x
4065

66+
@property
67+
def dimension(self) -> int:
68+
return self._dimension
69+
4170
def with_dimension(self, dimension: int) -> 'LineQid':
42-
return LineQid(self.x, dimension)
71+
return LineQid(self._x, dimension)
4372

4473
def is_adjacent(self, other: 'cirq.Qid') -> bool:
4574
"""Determines if two qubits are adjacent line qubits.
@@ -49,49 +78,45 @@ def is_adjacent(self, other: 'cirq.Qid') -> bool:
4978
5079
Returns: True iff other and self are adjacent.
5180
"""
52-
return isinstance(other, _BaseLineQid) and abs(self.x - other.x) == 1
81+
return isinstance(other, _BaseLineQid) and abs(self._x - other._x) == 1
5382

5483
def neighbors(self, qids: Optional[Iterable[ops.Qid]] = None) -> Set['_BaseLineQid']:
5584
"""Returns qubits that are potential neighbors to this LineQubit
5685
5786
Args:
5887
qids: optional Iterable of qubits to constrain neighbors to.
5988
"""
60-
neighbors = set()
61-
for q in [self - 1, self + 1]:
62-
if qids is None or q in qids:
63-
neighbors.add(q)
64-
return neighbors
89+
return {q for q in [self - 1, self + 1] if qids is None or q in qids}
6590

6691
@abc.abstractmethod
6792
def _with_x(self, x: int) -> Self:
6893
"""Returns a qubit with the same type but a different value of `x`."""
6994

7095
def __add__(self, other: Union[int, Self]) -> Self:
7196
if isinstance(other, _BaseLineQid):
72-
if self.dimension != other.dimension:
97+
if self._dimension != other._dimension:
7398
raise TypeError(
7499
"Can only add LineQids with identical dimension. "
75-
f"Got {self.dimension} and {other.dimension}"
100+
f"Got {self._dimension} and {other._dimension}"
76101
)
77-
return self._with_x(x=self.x + other.x)
102+
return self._with_x(x=self._x + other._x)
78103
if not isinstance(other, int):
79104
raise TypeError(f"Can only add ints and {type(self).__name__}. Instead was {other}")
80-
return self._with_x(self.x + other)
105+
return self._with_x(self._x + other)
81106

82107
def __sub__(self, other: Union[int, Self]) -> Self:
83108
if isinstance(other, _BaseLineQid):
84-
if self.dimension != other.dimension:
109+
if self._dimension != other._dimension:
85110
raise TypeError(
86111
"Can only subtract LineQids with identical dimension. "
87-
f"Got {self.dimension} and {other.dimension}"
112+
f"Got {self._dimension} and {other._dimension}"
88113
)
89-
return self._with_x(x=self.x - other.x)
114+
return self._with_x(x=self._x - other._x)
90115
if not isinstance(other, int):
91116
raise TypeError(
92117
f"Can only subtract ints and {type(self).__name__}. Instead was {other}"
93118
)
94-
return self._with_x(self.x - other)
119+
return self._with_x(self._x - other)
95120

96121
def __radd__(self, other: int) -> Self:
97122
return self + other
@@ -100,16 +125,16 @@ def __rsub__(self, other: int) -> Self:
100125
return -self + other
101126

102127
def __neg__(self) -> Self:
103-
return self._with_x(-self.x)
128+
return self._with_x(-self._x)
104129

105130
def __complex__(self) -> complex:
106-
return complex(self.x)
131+
return complex(self._x)
107132

108133
def __float__(self) -> float:
109-
return float(self.x)
134+
return float(self._x)
110135

111136
def __int__(self) -> int:
112-
return int(self.x)
137+
return int(self._x)
113138

114139

115140
class LineQid(_BaseLineQid):
@@ -137,16 +162,12 @@ def __init__(self, x: int, dimension: int) -> None:
137162
dimension: The dimension of the qid's Hilbert space, i.e.
138163
the number of quantum levels.
139164
"""
140-
super().__init__(x)
141-
self._dimension = dimension
142165
self.validate_dimension(dimension)
143-
144-
@property
145-
def dimension(self):
146-
return self._dimension
166+
self._x = x
167+
self._dimension = dimension
147168

148169
def _with_x(self, x: int) -> 'LineQid':
149-
return LineQid(x, dimension=self.dimension)
170+
return LineQid(x, dimension=self._dimension)
150171

151172
@staticmethod
152173
def range(*range_args, dimension: int) -> List['LineQid']:
@@ -192,15 +213,15 @@ def for_gate(val: Any, start: int = 0, step: int = 1) -> List['LineQid']:
192213
return LineQid.for_qid_shape(qid_shape(val), start=start, step=step)
193214

194215
def __repr__(self) -> str:
195-
return f"cirq.LineQid({self.x}, dimension={self.dimension})"
216+
return f"cirq.LineQid({self._x}, dimension={self._dimension})"
196217

197218
def __str__(self) -> str:
198-
return f"q({self.x}) (d={self.dimension})"
219+
return f"q({self._x}) (d={self._dimension})"
199220

200221
def _circuit_diagram_info_(
201222
self, args: 'cirq.CircuitDiagramInfoArgs'
202223
) -> 'cirq.CircuitDiagramInfo':
203-
return protocols.CircuitDiagramInfo(wire_symbols=(f"{self.x} (d={self.dimension})",))
224+
return protocols.CircuitDiagramInfo(wire_symbols=(f"{self._x} (d={self._dimension})",))
204225

205226
def _json_dict_(self) -> Dict[str, Any]:
206227
return protocols.obj_to_dict_helper(self, ['x', 'dimension'])
@@ -223,9 +244,15 @@ class LineQubit(_BaseLineQid):
223244
224245
"""
225246

226-
@property
227-
def dimension(self) -> int:
228-
return 2
247+
_dimension = 2
248+
249+
def __init__(self, x: int) -> None:
250+
"""Initializes a line qubit at the given x coordinate.
251+
252+
Args:
253+
x: The x coordinate.
254+
"""
255+
self._x = x
229256

230257
def _with_x(self, x: int) -> 'LineQubit':
231258
return LineQubit(x)
@@ -234,7 +261,7 @@ def _cmp_tuple(self):
234261
cls = LineQid if type(self) is LineQubit else type(self)
235262
# Must be the same as Qid._cmp_tuple but with cls in place of
236263
# type(self).
237-
return (cls.__name__, repr(cls), self._comparison_key(), self.dimension)
264+
return (cls.__name__, repr(cls), self._comparison_key(), self._dimension)
238265

239266
@staticmethod
240267
def range(*range_args) -> List['LineQubit']:
@@ -249,15 +276,15 @@ def range(*range_args) -> List['LineQubit']:
249276
return [LineQubit(i) for i in range(*range_args)]
250277

251278
def __repr__(self) -> str:
252-
return f"cirq.LineQubit({self.x})"
279+
return f"cirq.LineQubit({self._x})"
253280

254281
def __str__(self) -> str:
255-
return f"q({self.x})"
282+
return f"q({self._x})"
256283

257284
def _circuit_diagram_info_(
258285
self, args: 'cirq.CircuitDiagramInfoArgs'
259286
) -> 'cirq.CircuitDiagramInfo':
260-
return protocols.CircuitDiagramInfo(wire_symbols=(f"{self.x}",))
287+
return protocols.CircuitDiagramInfo(wire_symbols=(f"{self._x}",))
261288

262289
def _json_dict_(self) -> Dict[str, Any]:
263290
return protocols.obj_to_dict_helper(self, ['x'])

0 commit comments

Comments
 (0)