Skip to content

Commit 0ffe54f

Browse files
pavoljuhasmaffoo
authored andcommitted
Fix typing of the protocols.commutes() function (quantumlib#5651)
Return bool when the `default` argument is not specified. Otherwise allow return type to be the same as the `default` argument type. Co-authored-by: Matthew Neeley <[email protected]>
1 parent e54d259 commit 0ffe54f

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

cirq-core/cirq/protocols/commutes_protocol.py

+15-7
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
"""Protocol for determining commutativity."""
1515

16-
from typing import Any, TypeVar, Union
16+
from typing import Any, overload, TypeVar, Union
1717

1818
import numpy as np
1919
from typing_extensions import Protocol
@@ -26,7 +26,7 @@
2626
# whether or not the caller provided a 'default' argument.
2727
# It is checked for using `is`, so it won't have a false positive if the user
2828
# provides a different np.array([]) value.
29-
RaiseTypeErrorIfNotProvided = np.array([])
29+
RaiseTypeErrorIfNotProvided = object()
3030

3131
TDefault = TypeVar('TDefault')
3232

@@ -73,13 +73,21 @@ def _commutes_(self, other: Any, *, atol: float) -> Union[None, bool, NotImpleme
7373
"""
7474

7575

76+
@overload
77+
def commutes(v1: Any, v2: Any, *, atol: Union[int, float] = 1e-8) -> bool:
78+
...
79+
80+
81+
@overload
7682
def commutes(
77-
v1: Any,
78-
v2: Any,
79-
*,
80-
atol: Union[int, float] = 1e-8,
81-
default: Union[bool, TDefault] = RaiseTypeErrorIfNotProvided,
83+
v1: Any, v2: Any, *, atol: Union[int, float] = 1e-8, default: TDefault
8284
) -> Union[bool, TDefault]:
85+
...
86+
87+
88+
def commutes(
89+
v1: Any, v2: Any, *, atol: Union[int, float] = 1e-8, default: Any = RaiseTypeErrorIfNotProvided
90+
) -> Any:
8391
"""Determines whether two values commute.
8492
8593
This is determined by any one of the following techniques:

0 commit comments

Comments
 (0)