Skip to content

Commit 9c710fb

Browse files
authored
Fix most numpy type errors in cirq/linalg (#4000)
Using `check/mypy --next | grep cirq/linalg` this fixes all the problems. #3767
1 parent 845836a commit 9c710fb

8 files changed

+59
-43
lines changed

cirq/linalg/combinators.py

+13-10
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,15 @@
1515
"""Utility methods for combining matrices."""
1616

1717
import functools
18-
from typing import Union, Type
18+
from typing import Union, TYPE_CHECKING
1919

2020
import numpy as np
2121

2222
from cirq._doc import document
2323

24+
if TYPE_CHECKING:
25+
from numpy.typing import DTypeLike, ArrayLike
26+
2427

2528
def kron(*factors: Union[np.ndarray, complex, float], shape_len: int = 2) -> np.ndarray:
2629
"""Computes the kronecker product of a sequence of values.
@@ -104,7 +107,7 @@ def kron_with_controls(*factors: Union[np.ndarray, complex, float]) -> np.ndarra
104107
return product
105108

106109

107-
def dot(*values: Union[float, complex, np.ndarray]) -> Union[float, complex, np.ndarray]:
110+
def dot(*values: 'ArrayLike') -> np.ndarray:
108111
"""Computes the dot/matrix product of a sequence of values.
109112
110113
Performs the computation in serial order without regard to the matrix
@@ -117,20 +120,20 @@ def dot(*values: Union[float, complex, np.ndarray]) -> Union[float, complex, np.
117120
Returns:
118121
The resulting value or matrix.
119122
"""
123+
if len(values) == 0:
124+
raise ValueError("cirq.dot must be called with arguments")
125+
126+
if len(values) == 1:
127+
# note: it's important that we copy input arrays.
128+
return np.array(values[0])
120129

121-
if len(values) <= 1:
122-
if len(values) == 0:
123-
raise ValueError("cirq.dot must be called with arguments")
124-
if isinstance(values[0], np.ndarray):
125-
return np.array(values[0])
126-
return values[0]
127-
result = values[0]
130+
result = np.asarray(values[0])
128131
for value in values[1:]:
129132
result = np.dot(result, value)
130133
return result
131134

132135

133-
def _merge_dtypes(dtype1: Type[np.number], dtype2: Type[np.number]) -> Type[np.number]:
136+
def _merge_dtypes(dtype1: 'DTypeLike', dtype2: 'DTypeLike') -> np.dtype:
134137
return (np.zeros(0, dtype1) + np.zeros(0, dtype2)).dtype
135138

136139

cirq/linalg/decompositions.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def _group_similar(items: List[T], comparer: Callable[[T, T], bool]) -> List[Lis
121121

122122
def unitary_eig(
123123
matrix: np.ndarray, check_preconditions: bool = True, atol: float = 1e-8
124-
) -> Tuple[np.array, np.ndarray]:
124+
) -> Tuple[np.ndarray, np.ndarray]:
125125
"""Gives the guaranteed unitary eigendecomposition of a normal matrix.
126126
127127
All hermitian and unitary matrices are normal matrices. This method was
@@ -337,7 +337,7 @@ def _unitary_(self) -> np.ndarray:
337337

338338
def __str__(self) -> str:
339339
axis_terms = '+'.join(
340-
'{:.3g}*{}'.format(e, a) if e < 0.9999 else a
340+
f'{e:.3g}*{a}' if e < 0.9999 else a
341341
for e, a in zip(self.axis, ['X', 'Y', 'Z'])
342342
if abs(e) >= 1e-8
343343
).replace('+-', '-')
@@ -648,11 +648,13 @@ def coord_transform(
648648

649649
# parse input and extract KAK vector
650650
if not isinstance(interactions, np.ndarray):
651-
interactions = [
651+
interactions_extracted: List[np.ndarray] = [
652652
a if isinstance(a, np.ndarray) else protocols.unitary(a) for a in interactions
653653
]
654+
else:
655+
interactions_extracted = [interactions]
654656

655-
points = kak_vector(interactions) * 4 / np.pi
657+
points = kak_vector(interactions_extracted) * 4 / np.pi
656658

657659
ax.scatter(*coord_transform(points), **kwargs)
658660
ax.set_xlim(0, +1)

cirq/linalg/decompositions_test.py

+3
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,9 @@ def test_scatter_plot_normalized_kak_interaction_coefficients():
568568
)
569569
assert ax2 is ax
570570

571+
ax3 = cirq.scatter_plot_normalized_kak_interaction_coefficients(data[1], ax=ax)
572+
assert ax3 is ax
573+
571574

572575
def _vector_kron(first: np.ndarray, second: np.ndarray) -> np.ndarray:
573576
"""Vectorized implementation of kron for square matrices."""

cirq/linalg/diagonalize.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -183,9 +183,9 @@ def bidiagonalize_real_matrix_pair_with_symmetric_products(
183183
raise ValueError('mat1 must be real.')
184184
if np.any(np.imag(mat2) != 0):
185185
raise ValueError('mat2 must be real.')
186-
if not predicates.is_hermitian(mat1.dot(mat2.T), rtol=rtol, atol=atol):
186+
if not predicates.is_hermitian(np.dot(mat1, mat2.T), rtol=rtol, atol=atol):
187187
raise ValueError('mat1 @ mat2.T must be symmetric.')
188-
if not predicates.is_hermitian(mat1.T.dot(mat2), rtol=rtol, atol=atol):
188+
if not predicates.is_hermitian(np.dot(mat1.T, mat2), rtol=rtol, atol=atol):
189189
raise ValueError('mat1.T @ mat2 must be symmetric.')
190190

191191
# Use SVD to bi-diagonalize the first matrix.
@@ -200,7 +200,7 @@ def bidiagonalize_real_matrix_pair_with_symmetric_products(
200200
base_diag = base_diag[:rank, :rank]
201201

202202
# Try diagonalizing the second matrix with the same factors as the first.
203-
semi_corrected = base_left.T.dot(np.real(mat2)).dot(base_right.T)
203+
semi_corrected = combinators.dot(base_left.T, np.real(mat2), base_right.T)
204204

205205
# Fix up the part of the second matrix's diagonalization that's matched
206206
# against non-zero diagonal entries in the first matrix's diagonalization
@@ -218,15 +218,15 @@ def bidiagonalize_real_matrix_pair_with_symmetric_products(
218218
# Merge the fixup factors into the initial diagonalization.
219219
left_adjust = combinators.block_diag(overlap_adjust, extra_left_adjust)
220220
right_adjust = combinators.block_diag(overlap_adjust.T, extra_right_adjust)
221-
left = left_adjust.T.dot(base_left.T)
222-
right = base_right.T.dot(right_adjust.T)
221+
left = np.dot(left_adjust.T, base_left.T)
222+
right = np.dot(base_right.T, right_adjust.T)
223223

224224
return left, right
225225

226226

227227
def bidiagonalize_unitary_with_special_orthogonals(
228228
mat: np.ndarray, *, rtol: float = 1e-5, atol: float = 1e-8, check_preconditions: bool = True
229-
) -> Tuple[np.ndarray, np.array, np.ndarray]:
229+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
230230
"""Finds orthogonal matrices L, R such that L @ matrix @ R is diagonal.
231231
232232
Args:

cirq/linalg/operator_spaces.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414

1515
"""Utilities for manipulating linear operators as elements of vector space."""
16-
1716
from typing import Dict, Tuple
1817

1918
import numpy as np
@@ -32,7 +31,7 @@
3231

3332
def kron_bases(*bases: Dict[str, np.ndarray], repeat: int = 1) -> Dict[str, np.ndarray]:
3433
"""Creates tensor product of bases."""
35-
product_basis = {'': 1}
34+
product_basis = {'': np.ones(1)}
3635
for basis in bases * repeat:
3736
product_basis = {
3837
name1 + name2: np.kron(matrix1, matrix2)
@@ -98,14 +97,14 @@ def pow_pauli_combination(
9897
if exponent == 0:
9998
return 1, 0, 0, 0
10099

101-
v = np.sqrt(ax * ax + ay * ay + az * az)
102-
s = np.power(ai + v, exponent)
103-
t = np.power(ai - v, exponent)
100+
v = np.sqrt(ax * ax + ay * ay + az * az).item()
101+
s = (ai + v) ** exponent
102+
t = (ai - v) ** exponent
104103

105104
ci = (s + t) / 2
106105
if s == t:
107106
# v is near zero, only one term in binomial expansion survives
108-
cxyz = exponent * np.power(ai, exponent - 1)
107+
cxyz = exponent * ai ** (exponent - 1)
109108
else:
110109
# v is non-zero, account for all terms of binomial expansion
111110
cxyz = (s - t) / 2

cirq/linalg/predicates.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from cirq import value
2121

2222

23-
def is_diagonal(matrix: np.ndarray, *, atol: float = 1e-8) -> bool:
23+
def is_diagonal(matrix: np.ndarray, *, atol: float = 1e-8) -> np.bool_:
2424
"""Determines if a matrix is a approximately diagonal.
2525
2626
A matrix is diagonal if i!=j implies m[i,j]==0.
@@ -72,7 +72,7 @@ def is_orthogonal(matrix: np.ndarray, *, rtol: float = 1e-5, atol: float = 1e-8)
7272
"""
7373
return (
7474
matrix.shape[0] == matrix.shape[1]
75-
and np.all(np.imag(matrix) == 0)
75+
and np.all(np.imag(matrix) == 0).item()
7676
and np.allclose(matrix.dot(matrix.T), np.eye(matrix.shape[0]), rtol=rtol, atol=atol)
7777
)
7878

cirq/linalg/tolerance.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,15 @@
1414

1515
"""Utility for testing approximate equality of matrices and scalars within
1616
tolerances."""
17-
from typing import Union, Iterable
17+
from typing import Union, Iterable, TYPE_CHECKING
1818

1919
import numpy as np
2020

21+
if TYPE_CHECKING:
22+
from numpy.typing import ArrayLike
2123

22-
def all_near_zero(
23-
a: Union[float, complex, Iterable[float], np.ndarray], *, atol: float = 1e-8
24-
) -> bool:
24+
25+
def all_near_zero(a: 'ArrayLike', *, atol: float = 1e-8) -> np.bool_:
2526
"""Checks if the tensor's elements are all near zero.
2627
2728
Args:
@@ -33,7 +34,7 @@ def all_near_zero(
3334

3435
def all_near_zero_mod(
3536
a: Union[float, complex, Iterable[float], np.ndarray], period: float, *, atol: float = 1e-8
36-
) -> bool:
37+
) -> np.bool_:
3738
"""Checks if the tensor's elements are all near multiples of the period.
3839
3940
Args:

cirq/linalg/transformations.py

+18-10
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
"""Utility methods for transforming matrices or vectors."""
1616

17-
from typing import Tuple, Optional, Sequence, List, Union, TypeVar
17+
from typing import Tuple, Optional, Sequence, List, Union
1818

1919
import numpy as np
2020

@@ -26,9 +26,7 @@
2626
# of type np.ndarray to ensure the method has the correct type signature in that
2727
# case. It is checked for using `is`, so it won't have a false positive if the
2828
# user provides a different np.array([]) value.
29-
RaiseValueErrorIfNotProvided = np.array([]) # type: np.ndarray
30-
31-
TDefault = TypeVar('TDefault')
29+
RaiseValueErrorIfNotProvided: np.ndarray = np.array([])
3230

3331

3432
def reflection_matrix_pow(reflection_matrix: np.ndarray, exponent: float):
@@ -326,6 +324,10 @@ def partial_trace(tensor: np.ndarray, keep_indices: List[int]) -> np.ndarray:
326324
return np.einsum(tensor, left_indices + right_indices)
327325

328326

327+
class EntangledStateError(ValueError):
328+
"""Raised when a product state is expected, but an entangled state is provided."""
329+
330+
329331
def partial_trace_of_state_vector_as_mixture(
330332
state_vector: np.ndarray, keep_indices: List[int], *, atol: Union[int, float] = 1e-8
331333
) -> Tuple[Tuple[float, np.ndarray], ...]:
@@ -357,9 +359,13 @@ def partial_trace_of_state_vector_as_mixture(
357359
"""
358360

359361
# Attempt to do efficient state factoring.
360-
state = sub_state_vector(state_vector, keep_indices, default=None, atol=atol)
361-
if state is not None:
362+
try:
363+
state = sub_state_vector(
364+
state_vector, keep_indices, default=RaiseValueErrorIfNotProvided, atol=atol
365+
)
362366
return ((1.0, state),)
367+
except EntangledStateError:
368+
pass
363369

364370
# Fall back to a (non-unique) mixture representation.
365371
keep_dims = 1 << len(keep_indices)
@@ -382,7 +388,7 @@ def sub_state_vector(
382388
state_vector: np.ndarray,
383389
keep_indices: List[int],
384390
*,
385-
default: TDefault = RaiseValueErrorIfNotProvided,
391+
default: np.ndarray = RaiseValueErrorIfNotProvided,
386392
atol: Union[int, float] = 1e-8,
387393
) -> np.ndarray:
388394
r"""Attempts to factor a state vector into two parts and return one of them.
@@ -424,8 +430,10 @@ def sub_state_vector(
424430
425431
Raises:
426432
ValueError: if the `state_vector` is not of the correct shape or the
427-
indices are not a valid subset of the input `state_vector`'s indices, or
428-
the result of factoring is not a pure state.
433+
indices are not a valid subset of the input `state_vector`'s indices
434+
EntangledStateError: If the result of factoring is not a pure state and
435+
`default` is not provided.
436+
429437
"""
430438

431439
if not np.log2(state_vector.size).is_integer():
@@ -471,7 +479,7 @@ def sub_state_vector(
471479
if default is not RaiseValueErrorIfNotProvided:
472480
return default
473481

474-
raise ValueError(
482+
raise EntangledStateError(
475483
"Input state vector could not be factored into pure state over "
476484
"indices {}".format(keep_indices)
477485
)

0 commit comments

Comments
 (0)