Skip to content

Commit eddb2d9

Browse files
authored
Add caching to value_equality_values decorator for auto generated methods. (#6275)
* Add caching to value_equality_values decorator for auto generated methods. * Fix pylint and formatting errors * Address nits, fix bugs and make PauliSum unhashable
1 parent 0e80fa5 commit eddb2d9

File tree

3 files changed

+15
-4
lines changed

3 files changed

+15
-4
lines changed

cirq-core/cirq/ops/dense_pauli_string.py

+3
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,9 @@ def copy(
570570
def __str__(self) -> str:
571571
return super().__str__() + ' (mutable)'
572572

573+
def _value_equality_values_(self):
574+
return self.coefficient, tuple(PAULI_CHARS[p] for p in self.pauli_mask)
575+
573576
@classmethod
574577
def inline_gaussian_elimination(cls, rows: 'List[MutableDensePauliString]') -> None:
575578
if not rows:

cirq-core/cirq/ops/linear_combinations.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ def _pauli_string_from_unit(unit: UnitPauliStringT, coefficient: Union[int, floa
357357
return PauliString(qubit_pauli_map=dict(unit), coefficient=coefficient)
358358

359359

360-
@value.value_equality(approximate=True)
360+
@value.value_equality(approximate=True, unhashable=True)
361361
class PauliSum:
362362
"""Represents operator defined by linear combination of PauliStrings.
363363

cirq-core/cirq/value/value_equality_attr.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from typing_extensions import Protocol
1919

20-
from cirq import protocols
20+
from cirq import protocols, _compat
2121

2222

2323
class _SupportsValueEquality(Protocol):
@@ -221,13 +221,21 @@ class return the existing class' type.
221221
)
222222
else:
223223
setattr(cls, '_value_equality_values_cls_', lambda self: cls)
224-
setattr(cls, '__hash__', None if unhashable else _value_equality_hash)
224+
cached_values_getter = values_getter if unhashable else _compat.cached_method(values_getter)
225+
setattr(cls, '_value_equality_values_', cached_values_getter)
226+
setattr(cls, '__hash__', None if unhashable else _compat.cached_method(_value_equality_hash))
225227
setattr(cls, '__eq__', _value_equality_eq)
226228
setattr(cls, '__ne__', _value_equality_ne)
227229

228230
if approximate:
229231
if not hasattr(cls, '_value_equality_approximate_values_'):
230-
setattr(cls, '_value_equality_approximate_values_', values_getter)
232+
setattr(cls, '_value_equality_approximate_values_', cached_values_getter)
233+
else:
234+
approx_values_getter = getattr(cls, '_value_equality_approximate_values_')
235+
cached_approx_values_getter = (
236+
approx_values_getter if unhashable else _compat.cached_method(approx_values_getter)
237+
)
238+
setattr(cls, '_value_equality_approximate_values_', cached_approx_values_getter)
231239
setattr(cls, '_approx_eq_', _value_equality_approx_eq)
232240

233241
return cls

0 commit comments

Comments
 (0)