Skip to content

Commit e3fbd98

Browse files
authored
Optimize qid comparisons (#6386)
Review: @NoureldinYosri
1 parent 2f3c1e2 commit e3fbd98

File tree

3 files changed

+98
-23
lines changed

3 files changed

+98
-23
lines changed

cirq-core/cirq/devices/grid_qubit.py

+34-8
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,15 @@ class _BaseGridQid(ops.Qid):
3333
_row: int
3434
_col: int
3535
_dimension: int
36+
_comp_key: Optional[Tuple[int, int]] = None
3637
_hash: Optional[int] = None
3738

3839
def __hash__(self) -> int:
3940
if self._hash is None:
4041
self._hash = hash((self._row, self._col, self._dimension))
4142
return self._hash
4243

43-
def __eq__(self, other):
44+
def __eq__(self, other) -> bool:
4445
# Explicitly implemented for performance (vs delegating to Qid).
4546
if isinstance(other, _BaseGridQid):
4647
return self is other or (
@@ -50,7 +51,7 @@ def __eq__(self, other):
5051
)
5152
return NotImplemented
5253

53-
def __ne__(self, other):
54+
def __ne__(self, other) -> bool:
5455
# Explicitly implemented for performance (vs delegating to Qid).
5556
if isinstance(other, _BaseGridQid):
5657
return self is not other and (
@@ -60,8 +61,38 @@ def __ne__(self, other):
6061
)
6162
return NotImplemented
6263

64+
def __lt__(self, other) -> bool:
65+
# Explicitly implemented for performance (vs delegating to Qid).
66+
if isinstance(other, _BaseGridQid):
67+
k0, k1 = self._comparison_key(), other._comparison_key()
68+
return k0 < k1 or (k0 == k1 and self._dimension < other._dimension)
69+
return super().__lt__(other)
70+
71+
def __le__(self, other) -> bool:
72+
# Explicitly implemented for performance (vs delegating to Qid).
73+
if isinstance(other, _BaseGridQid):
74+
k0, k1 = self._comparison_key(), other._comparison_key()
75+
return k0 < k1 or (k0 == k1 and self._dimension <= other._dimension)
76+
return super().__le__(other)
77+
78+
def __ge__(self, other) -> bool:
79+
# Explicitly implemented for performance (vs delegating to Qid).
80+
if isinstance(other, _BaseGridQid):
81+
k0, k1 = self._comparison_key(), other._comparison_key()
82+
return k0 > k1 or (k0 == k1 and self._dimension >= other._dimension)
83+
return super().__ge__(other)
84+
85+
def __gt__(self, other) -> bool:
86+
# Explicitly implemented for performance (vs delegating to Qid).
87+
if isinstance(other, _BaseGridQid):
88+
k0, k1 = self._comparison_key(), other._comparison_key()
89+
return k0 > k1 or (k0 == k1 and self._dimension > other._dimension)
90+
return super().__gt__(other)
91+
6392
def _comparison_key(self):
64-
return self._row, self._col
93+
if self._comp_key is None:
94+
self._comp_key = self._row, self._col
95+
return self._comp_key
6596

6697
@property
6798
def row(self) -> int:
@@ -359,11 +390,6 @@ def __getnewargs__(self):
359390
def _with_row_col(self, row: int, col: int) -> 'GridQubit':
360391
return GridQubit(row, col)
361392

362-
def _cmp_tuple(self):
363-
cls = GridQid if type(self) is GridQubit else type(self)
364-
# Must be same as Qid._cmp_tuple but with cls in place of type(self).
365-
return (cls.__name__, repr(cls), self._comparison_key(), self.dimension)
366-
367393
@staticmethod
368394
def square(diameter: int, top: int = 0, left: int = 0) -> List['GridQubit']:
369395
"""Returns a square of GridQubits.

cirq-core/cirq/devices/line_qubit.py

+34-8
Original file line numberDiff line numberDiff line change
@@ -37,20 +37,52 @@ def __hash__(self) -> int:
3737
self._hash = hash((self._x, self._dimension))
3838
return self._hash
3939

40-
def __eq__(self, other):
40+
def __eq__(self, other) -> bool:
4141
# Explicitly implemented for performance (vs delegating to Qid).
4242
if isinstance(other, _BaseLineQid):
4343
return self is other or (self._x == other._x and self._dimension == other._dimension)
4444
return NotImplemented
4545

46-
def __ne__(self, other):
46+
def __ne__(self, other) -> bool:
4747
# Explicitly implemented for performance (vs delegating to Qid).
4848
if isinstance(other, _BaseLineQid):
4949
return self is not other and (
5050
self._x != other._x or self._dimension != other._dimension
5151
)
5252
return NotImplemented
5353

54+
def __lt__(self, other) -> bool:
55+
# Explicitly implemented for performance (vs delegating to Qid).
56+
if isinstance(other, _BaseLineQid):
57+
return self._x < other._x or (
58+
self._x == other._x and self._dimension < other._dimension
59+
)
60+
return super().__lt__(other)
61+
62+
def __le__(self, other) -> bool:
63+
# Explicitly implemented for performance (vs delegating to Qid).
64+
if isinstance(other, _BaseLineQid):
65+
return self._x < other._x or (
66+
self._x == other._x and self._dimension <= other._dimension
67+
)
68+
return super().__le__(other)
69+
70+
def __ge__(self, other) -> bool:
71+
# Explicitly implemented for performance (vs delegating to Qid).
72+
if isinstance(other, _BaseLineQid):
73+
return self._x > other._x or (
74+
self._x == other._x and self._dimension >= other._dimension
75+
)
76+
return super().__ge__(other)
77+
78+
def __gt__(self, other) -> bool:
79+
# Explicitly implemented for performance (vs delegating to Qid).
80+
if isinstance(other, _BaseLineQid):
81+
return self._x > other._x or (
82+
self._x == other._x and self._dimension > other._dimension
83+
)
84+
return super().__gt__(other)
85+
5486
def _comparison_key(self):
5587
return self._x
5688

@@ -279,12 +311,6 @@ def __getnewargs__(self):
279311
def _with_x(self, x: int) -> 'LineQubit':
280312
return LineQubit(x)
281313

282-
def _cmp_tuple(self):
283-
cls = LineQid if type(self) is LineQubit else type(self)
284-
# Must be the same as Qid._cmp_tuple but with cls in place of
285-
# type(self).
286-
return (cls.__name__, repr(cls), self._comparison_key(), self._dimension)
287-
288314
@staticmethod
289315
def range(*range_args) -> List['LineQubit']:
290316
"""Returns a range of line qubits.

cirq-core/cirq/ops/named_qubit.py

+30-7
Original file line numberDiff line numberDiff line change
@@ -37,22 +37,50 @@ def __hash__(self) -> int:
3737
self._hash = hash((self._name, self._dimension))
3838
return self._hash
3939

40-
def __eq__(self, other):
40+
def __eq__(self, other) -> bool:
4141
# Explicitly implemented for performance (vs delegating to Qid).
4242
if isinstance(other, _BaseNamedQid):
4343
return self is other or (
4444
self._name == other._name and self._dimension == other._dimension
4545
)
4646
return NotImplemented
4747

48-
def __ne__(self, other):
48+
def __ne__(self, other) -> bool:
4949
# Explicitly implemented for performance (vs delegating to Qid).
5050
if isinstance(other, _BaseNamedQid):
5151
return self is not other and (
5252
self._name != other._name or self._dimension != other._dimension
5353
)
5454
return NotImplemented
5555

56+
def __lt__(self, other) -> bool:
57+
# Explicitly implemented for performance (vs delegating to Qid).
58+
if isinstance(other, _BaseNamedQid):
59+
k0, k1 = self._comparison_key(), other._comparison_key()
60+
return k0 < k1 or (k0 == k1 and self._dimension < other._dimension)
61+
return super().__lt__(other)
62+
63+
def __le__(self, other) -> bool:
64+
# Explicitly implemented for performance (vs delegating to Qid).
65+
if isinstance(other, _BaseNamedQid):
66+
k0, k1 = self._comparison_key(), other._comparison_key()
67+
return k0 < k1 or (k0 == k1 and self._dimension <= other._dimension)
68+
return super().__le__(other)
69+
70+
def __ge__(self, other) -> bool:
71+
# Explicitly implemented for performance (vs delegating to Qid).
72+
if isinstance(other, _BaseNamedQid):
73+
k0, k1 = self._comparison_key(), other._comparison_key()
74+
return k0 > k1 or (k0 == k1 and self._dimension >= other._dimension)
75+
return super().__ge__(other)
76+
77+
def __gt__(self, other) -> bool:
78+
# Explicitly implemented for performance (vs delegating to Qid).
79+
if isinstance(other, _BaseNamedQid):
80+
k0, k1 = self._comparison_key(), other._comparison_key()
81+
return k0 > k1 or (k0 == k1 and self._dimension > other._dimension)
82+
return super().__gt__(other)
83+
5684
def _comparison_key(self):
5785
if self._comp_key is None:
5886
self._comp_key = _pad_digits(self._name)
@@ -174,11 +202,6 @@ def __getnewargs__(self):
174202
"""Returns a tuple of args to pass to __new__ when unpickling."""
175203
return (self._name,)
176204

177-
def _cmp_tuple(self):
178-
cls = NamedQid if type(self) is NamedQubit else type(self)
179-
# Must be same as Qid._cmp_tuple but with cls in place of type(self).
180-
return (cls.__name__, repr(cls), self._comparison_key(), self._dimension)
181-
182205
def __str__(self) -> str:
183206
return self._name
184207

0 commit comments

Comments
 (0)