Skip to content

Commit 0482b47

Browse files
authored
Make _commutes_ consistent (#5217)
- Requires atol to be a named parameter. - Also changes atol to be uniformly float around the codebase. (not sure why it would be int, are people using an atol=1?) - Technically a breaking change, but it's unlikely people are using this widely as most commutes do not even use atol. Fixes: #3695
1 parent 15d2f2e commit 0482b47

File tree

10 files changed

+22
-18
lines changed

10 files changed

+22
-18
lines changed

Diff for: cirq-core/cirq/circuits/moment.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -574,9 +574,7 @@ def cleanup_key(key: Any) -> Any:
574574

575575
return diagram.render()
576576

577-
def _commutes_(
578-
self, other: Any, *, atol: Union[int, float] = 1e-8
579-
) -> Union[bool, NotImplementedType]:
577+
def _commutes_(self, other: Any, *, atol: float = 1e-8) -> Union[bool, NotImplementedType]:
580578
"""Determines whether Moment commutes with the Operation.
581579
582580
Args:

Diff for: cirq-core/cirq/contrib/acquaintance/permutation.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -164,9 +164,7 @@ def __repr__(self) -> str:
164164
def _value_equality_values_(self) -> Any:
165165
return (self.swap_gate,)
166166

167-
def _commutes_(
168-
self, other: Any, atol: Union[int, float] = 1e-8
169-
) -> Union[bool, NotImplementedType]:
167+
def _commutes_(self, other: Any, *, atol: float = 1e-8) -> Union[bool, NotImplementedType]:
170168
if (
171169
isinstance(other, ops.Gate)
172170
and isinstance(other, ops.InterchangeableQubitsGate)

Diff for: cirq-core/cirq/ops/clifford_gate.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,7 @@ def __pow__(self, exponent) -> 'SingleQubitCliffordGate':
383383

384384
return SingleQubitCliffordGate.from_clifford_tableau(self.clifford_tableau.inverse())
385385

386-
def _commutes_(self, other: Any, atol: float) -> Union[bool, NotImplementedType]:
386+
def _commutes_(self, other: Any, *, atol: float = 1e-8) -> Union[bool, NotImplementedType]:
387387
if isinstance(other, SingleQubitCliffordGate):
388388
return self.commutes_with_single_qubit_gate(other)
389389
if isinstance(other, Pauli):
@@ -838,7 +838,9 @@ def __pow__(self, exponent) -> 'CliffordGate':
838838
def __repr__(self) -> str:
839839
return f"Clifford Gate with Tableau:\n {self.clifford_tableau._str_full_()}"
840840

841-
def _commutes_(self, other: Any, atol: float) -> Union[bool, NotImplementedType, None]:
841+
def _commutes_(
842+
self, other: Any, *, atol: float = 1e-8
843+
) -> Union[bool, NotImplementedType, None]:
842844
# Note even if we assume two gates define the tabluea based on the same qubit order,
843845
# the following approach cannot judge it:
844846
# self.clifford_tableau.then(other.clifford_tableau) == other.clifford_tableau.then(

Diff for: cirq-core/cirq/ops/common_gates.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -662,7 +662,7 @@ def __repr__(self) -> str:
662662
)
663663

664664
def _commutes_on_qids_(
665-
self, qids: 'Sequence[cirq.Qid]', other: Any, atol: float
665+
self, qids: 'Sequence[cirq.Qid]', other: Any, *, atol: float = 1e-8
666666
) -> Union[bool, NotImplementedType, None]:
667667
from cirq.ops.parity_gates import ZZPowGate
668668

Diff for: cirq-core/cirq/ops/dense_pauli_string.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,9 @@ def __repr__(self) -> str:
358358
f'coefficient={proper_repr(self.coefficient)})'
359359
)
360360

361-
def _commutes_(self, other: Any, atol: float) -> Union[bool, NotImplementedType, None]:
361+
def _commutes_(
362+
self, other: Any, *, atol: float = 1e-8
363+
) -> Union[bool, NotImplementedType, None]:
362364
if isinstance(other, BaseDensePauliString):
363365
n = min(len(self.pauli_mask), len(other.pauli_mask))
364366
phase = _vectorized_pauli_mul_phase(self.pauli_mask[:n], other.pauli_mask[:n])

Diff for: cirq-core/cirq/ops/gate_operation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def _unitary_(self) -> Union[np.ndarray, NotImplementedType]:
199199
return NotImplemented
200200

201201
def _commutes_(
202-
self, other: Any, atol: Union[int, float] = 1e-8
202+
self, other: Any, *, atol: float = 1e-8
203203
) -> Union[bool, NotImplementedType, None]:
204204
commutes = self.gate._commutes_on_qids_(self.qubits, other, atol=atol)
205205
if commutes is not NotImplemented:

Diff for: cirq-core/cirq/ops/pauli_gates.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ def __init__(self, index: int, name: str) -> None:
5353
def num_qubits(self):
5454
return 1
5555

56-
def _commutes_(self, other: Any, atol: float) -> Union[bool, NotImplementedType, None]:
56+
def _commutes_(
57+
self, other: Any, *, atol: float = 1e-8
58+
) -> Union[bool, NotImplementedType, None]:
5759
if not isinstance(other, Pauli):
5860
return NotImplemented
5961
return self is other

Diff for: cirq-core/cirq/ops/pauli_string.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,7 @@ def zip_paulis(
678678
return (paulis for qubit, paulis in self.zip_items(other))
679679

680680
def _commutes_(
681-
self, other: Any, *, atol: Union[int, float] = 1e-8
681+
self, other: Any, *, atol: float = 1e-8
682682
) -> Union[bool, NotImplementedType, None]:
683683
if not isinstance(other, PauliString):
684684
return NotImplemented

Diff for: cirq-core/cirq/ops/raw_types.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -399,11 +399,13 @@ def _qid_shape_(self) -> Tuple[int, ...]:
399399
"""
400400

401401
def _commutes_on_qids_(
402-
self, qids: 'Sequence[cirq.Qid]', other: Any, atol: float
402+
self, qids: 'Sequence[cirq.Qid]', other: Any, *, atol: float = 1e-8
403403
) -> Union[bool, NotImplementedType, None]:
404404
return NotImplemented
405405

406-
def _commutes_(self, other: Any, atol: float) -> Union[None, NotImplementedType, bool]:
406+
def _commutes_(
407+
self, other: Any, *, atol: float = 1e-8
408+
) -> Union[None, NotImplementedType, bool]:
407409
if not isinstance(other, Gate):
408410
return NotImplemented
409411
if protocols.qid_shape(self) != protocols.qid_shape(other):
@@ -567,7 +569,7 @@ def validate_args(self, qubits: Sequence['cirq.Qid']):
567569
_validate_qid_shape(self, qubits)
568570

569571
def _commutes_(
570-
self, other: Any, *, atol: Union[int, float] = 1e-8
572+
self, other: Any, *, atol: float = 1e-8
571573
) -> Union[bool, NotImplementedType, None]:
572574
"""Determine if this Operation commutes with the object"""
573575
if not isinstance(other, Operation):
@@ -771,7 +773,7 @@ def _unitary_(self) -> Union[np.ndarray, NotImplementedType]:
771773
return protocols.unitary(self.sub_operation, NotImplemented)
772774

773775
def _commutes_(
774-
self, other: Any, *, atol: Union[int, float] = 1e-8
776+
self, other: Any, *, atol: float = 1e-8
775777
) -> Union[bool, NotImplementedType, None]:
776778
return protocols.commutes(self.sub_operation, other, atol=atol)
777779

Diff for: cirq-core/cirq/protocols/commutes_protocol.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class SupportsCommutes(Protocol):
3535
"""An object that can determine commutation relationships vs others."""
3636

3737
@doc_private
38-
def _commutes_(self, other: Any, atol: float) -> Union[None, bool, NotImplementedType]:
38+
def _commutes_(self, other: Any, *, atol: float) -> Union[None, bool, NotImplementedType]:
3939
r"""Determines if this object commutes with the other object.
4040
4141
Can return None to indicate the commutation relationship is

0 commit comments

Comments
 (0)