|
13 | 13 | # limitations under the License.
|
14 | 14 | """Protocol for determining commutativity."""
|
15 | 15 |
|
16 |
| -from typing import Any, TypeVar, Union |
| 16 | +from typing import Any, overload, TypeVar, Union |
17 | 17 |
|
18 | 18 | import numpy as np
|
19 | 19 | from typing_extensions import Protocol
|
|
26 | 26 | # whether or not the caller provided a 'default' argument.
|
27 | 27 | # It is checked for using `is`, so it won't have a false positive if the user
|
28 | 28 | # provides a different np.array([]) value.
|
29 |
| -RaiseTypeErrorIfNotProvided = np.array([]) |
| 29 | +RaiseTypeErrorIfNotProvided = object() |
30 | 30 |
|
31 | 31 | TDefault = TypeVar('TDefault')
|
32 | 32 |
|
@@ -73,13 +73,21 @@ def _commutes_(self, other: Any, *, atol: float) -> Union[None, bool, NotImpleme
|
73 | 73 | """
|
74 | 74 |
|
75 | 75 |
|
| 76 | +@overload |
| 77 | +def commutes(v1: Any, v2: Any, *, atol: Union[int, float] = 1e-8) -> bool: |
| 78 | + ... |
| 79 | + |
| 80 | + |
| 81 | +@overload |
76 | 82 | def commutes(
|
77 |
| - v1: Any, |
78 |
| - v2: Any, |
79 |
| - *, |
80 |
| - atol: Union[int, float] = 1e-8, |
81 |
| - default: Union[bool, TDefault] = RaiseTypeErrorIfNotProvided, |
| 83 | + v1: Any, v2: Any, *, atol: Union[int, float] = 1e-8, default: TDefault |
82 | 84 | ) -> Union[bool, TDefault]:
|
| 85 | + ... |
| 86 | + |
| 87 | + |
| 88 | +def commutes( |
| 89 | + v1: Any, v2: Any, *, atol: Union[int, float] = 1e-8, default: Any = RaiseTypeErrorIfNotProvided |
| 90 | +) -> Any: |
83 | 91 | """Determines whether two values commute.
|
84 | 92 |
|
85 | 93 | This is determined by any one of the following techniques:
|
|
0 commit comments