Skip to content

Commit 259dd91

Browse files
authored
Fix most numpy type errors in cirq/ops (#3997)
Using `check/mypy --next | grep cirq/ops` this fixes almost all of the issues raised with simple modifications. There is a looming problem with places where we use `numbers.Complex`, which is a nice generalization for Union[int, float, complex] but does not play nice with numpy type information. Towards #3767
1 parent aaa1385 commit 259dd91

12 files changed

+37
-35
lines changed

cirq/ops/common_channels.py

+12-11
Original file line numberDiff line numberDiff line change
@@ -194,10 +194,10 @@ def _json_dict_(self) -> Dict[str, Any]:
194194
def _approx_eq_(self, other: Any, atol: float) -> bool:
195195
return (
196196
self.num_qubits == other.num_qubits
197-
and np.isclose(self.p_i, other.p_i, atol=atol)
198-
and np.isclose(self.p_x, other.p_x, atol=atol)
199-
and np.isclose(self.p_y, other.p_y, atol=atol)
200-
and np.isclose(self.p_z, other.p_z, atol=atol)
197+
and np.isclose(self.p_i, other.p_i, atol=atol).item()
198+
and np.isclose(self.p_x, other.p_x, atol=atol).item()
199+
and np.isclose(self.p_y, other.p_y, atol=atol).item()
200+
and np.isclose(self.p_z, other.p_z, atol=atol).item()
201201
)
202202

203203

@@ -356,7 +356,7 @@ def _json_dict_(self) -> Dict[str, Any]:
356356
return protocols.obj_to_dict_helper(self, ['p', 'n_qubits'])
357357

358358
def _approx_eq_(self, other: Any, atol: float) -> bool:
359-
return np.isclose(self.p, other.p, atol=atol) and self.n_qubits == other.n_qubits
359+
return np.isclose(self.p, other.p, atol=atol).item() and self.n_qubits == other.n_qubits
360360

361361

362362
def depolarize(p: float, n_qubits: int = 1) -> DepolarizingChannel:
@@ -498,8 +498,9 @@ def _json_dict_(self) -> Dict[str, Any]:
498498
return protocols.obj_to_dict_helper(self, ['p', 'gamma'])
499499

500500
def _approx_eq_(self, other: Any, atol: float) -> bool:
501-
return np.isclose(self.gamma, other.gamma, atol=atol) and np.isclose(
502-
self.p, other.p, atol=atol
501+
return (
502+
np.isclose(self.gamma, other.gamma, atol=atol).item()
503+
and np.isclose(self.p, other.p, atol=atol).item()
503504
)
504505

505506

@@ -628,7 +629,7 @@ def _json_dict_(self) -> Dict[str, Any]:
628629
return protocols.obj_to_dict_helper(self, ['gamma'])
629630

630631
def _approx_eq_(self, other: Any, atol: float) -> bool:
631-
return np.isclose(self.gamma, other.gamma, atol=atol)
632+
return np.isclose(self.gamma, other.gamma, atol=atol).item()
632633

633634

634635
def amplitude_damp(gamma: float) -> AmplitudeDampingChannel:
@@ -863,7 +864,7 @@ def _json_dict_(self) -> Dict[str, Any]:
863864
return protocols.obj_to_dict_helper(self, ['gamma'])
864865

865866
def _approx_eq_(self, other: Any, atol: float) -> bool:
866-
return np.isclose(self._gamma, other._gamma, atol=atol)
867+
return np.isclose(self._gamma, other._gamma, atol=atol).item()
867868

868869

869870
def phase_damp(gamma: float) -> PhaseDampingChannel:
@@ -973,7 +974,7 @@ def _json_dict_(self) -> Dict[str, Any]:
973974
return protocols.obj_to_dict_helper(self, ['p'])
974975

975976
def _approx_eq_(self, other: Any, atol: float) -> bool:
976-
return np.isclose(self.p, other.p, atol=atol)
977+
return np.isclose(self.p, other.p, atol=atol).item()
977978

978979

979980
def _phase_flip_Z() -> common_gates.ZPowGate:
@@ -1129,7 +1130,7 @@ def _json_dict_(self) -> Dict[str, Any]:
11291130
return protocols.obj_to_dict_helper(self, ['p'])
11301131

11311132
def _approx_eq_(self, other: Any, atol: float) -> bool:
1132-
return np.isclose(self._p, other._p, atol=atol)
1133+
return np.isclose(self._p, other._p, atol=atol).item()
11331134

11341135

11351136
def _bit_flip(p: float) -> BitFlipChannel:

cirq/ops/controlled_operation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def _extend_matrix(self, sub_matrix: np.ndarray) -> np.ndarray:
137137
for control_vals in itertools.product(*self.control_values):
138138
active = (*(v for v in control_vals), *(slice(None),) * sub_n) * 2
139139
tensor[active] = sub_tensor
140-
return tensor.reshape((np.prod(qid_shape, dtype=int),) * 2)
140+
return tensor.reshape((np.prod(qid_shape, dtype=int).item(),) * 2)
141141

142142
def _unitary_(self) -> Union[np.ndarray, NotImplementedType]:
143143
sub_matrix = protocols.unitary(self.sub_operation, None)

cirq/ops/dense_pauli_string.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ def frozen(self) -> 'DensePauliString':
374374
def copy(
375375
self,
376376
coefficient: Optional[complex] = None,
377-
pauli_mask: Union[None, Iterable[int], np.ndarray] = None,
377+
pauli_mask: Union[None, str, Iterable[int], np.ndarray] = None,
378378
) -> 'DensePauliString':
379379
if pauli_mask is None and (coefficient is None or coefficient == self.coefficient):
380380
return self
@@ -446,7 +446,7 @@ def __imul__(self, other):
446446
def copy(
447447
self,
448448
coefficient: Optional[complex] = None,
449-
pauli_mask: Union[None, Iterable[int], np.ndarray] = None,
449+
pauli_mask: Union[None, str, Iterable[int], np.ndarray] = None,
450450
) -> 'MutableDensePauliString':
451451
return MutableDensePauliString(
452452
coefficient=self.coefficient if coefficient is None else coefficient,
@@ -549,5 +549,5 @@ def _vectorized_pauli_mul_phase(
549549
t -= 1
550550

551551
# Result is i raised to the sum of the per-term phase exponents.
552-
s = int(np.sum(t, dtype=np.uint8) & 3)
552+
s = int(np.sum(t, dtype=np.uint8).item() & 3)
553553
return 1j ** s

cirq/ops/diagonal_gate.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
passed as a list.
1919
"""
2020

21-
from typing import AbstractSet, Any, Tuple, Iterator, List, Sequence, TYPE_CHECKING, Union
21+
from typing import AbstractSet, Any, Iterator, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
22+
2223
import numpy as np
2324
import sympy
2425

@@ -30,7 +31,7 @@
3031
import cirq
3132

3233

33-
def _fast_walsh_hadamard_transform(a: Tuple[Any, ...]) -> np.array:
34+
def _fast_walsh_hadamard_transform(a: Tuple[Any, ...]) -> np.ndarray:
3435
"""Fast Walsh–Hadamard Transform of an array."""
3536
h = 1
3637
a_ = np.array(a)
@@ -99,7 +100,7 @@ def _resolve_parameters_(
99100
def _has_unitary_(self) -> bool:
100101
return not self._is_parameterized_()
101102

102-
def _unitary_(self) -> np.ndarray:
103+
def _unitary_(self) -> Optional[np.ndarray]:
103104
if self._is_parameterized_():
104105
return None
105106
return np.diag([np.exp(1j * angle) for angle in self._diag_angles_radians])

cirq/ops/eigen_gate.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def _diagram_exponent(
188188

189189
# Canonicalize the rounded exponent into (-period/2, period/2].
190190
if args.precision is not None:
191-
result = np.around(result, args.precision)
191+
result = np.around(result, args.precision).item()
192192
h = diagram_period / 2
193193
if not (-h < result <= h):
194194
result = h - result

cirq/ops/identity.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def _has_unitary_(self) -> bool:
7373
return True
7474

7575
def _unitary_(self) -> np.ndarray:
76-
return np.identity(np.prod(self._qid_shape, dtype=int))
76+
return np.identity(np.prod(self._qid_shape, dtype=int).item())
7777

7878
def _apply_unitary_(self, args: 'protocols.ApplyUnitaryArgs') -> Optional[np.ndarray]:
7979
return args.target_tensor

cirq/ops/pauli_interaction_gate.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -112,14 +112,16 @@ def _decompose_(self, qubits: Sequence['cirq.Qid']) -> 'cirq.OP_TREE':
112112
def _circuit_diagram_info_(
113113
self, args: 'cirq.CircuitDiagramInfoArgs'
114114
) -> 'cirq.CircuitDiagramInfo':
115-
labels = cast(
116-
Dict[pauli_gates.Pauli, np.ndarray],
117-
{pauli_gates.X: 'X', pauli_gates.Y: 'Y', pauli_gates.Z: '@'},
118-
)
115+
labels: Dict['cirq.Pauli', str] = {
116+
pauli_gates.X: 'X',
117+
pauli_gates.Y: 'Y',
118+
pauli_gates.Z: '@',
119+
}
119120
l0 = labels[self.pauli0]
120121
l1 = labels[self.pauli1]
121122
# Add brackets around letter if inverted
122-
l0, l1 = (f'(-{l})' if inv else l for l, inv in ((l0, self.invert0), (l1, self.invert1)))
123+
l0 = f'(-{l0})' if self.invert0 else l0
124+
l1 = f'(-{l1})' if self.invert1 else l1
123125
return protocols.CircuitDiagramInfo(
124126
wire_symbols=(l0, l1), exponent=self._diagram_exponent(args)
125127
)

cirq/ops/pauli_sum_exponential.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def matrix(self) -> np.ndarray:
104104
"""
105105
if protocols.is_parameterized(self._exponent):
106106
raise ValueError("Exponent should not parameterized.")
107-
ret = 1
107+
ret = np.ones(1)
108108
for pauli_string_exp in self:
109109
ret = np.kron(ret, protocols.unitary(pauli_string_exp))
110110
return ret

cirq/ops/phased_x_z_gate.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def _value_equality_values_(self):
108108
)
109109

110110
@staticmethod
111-
def from_matrix(mat: np.array) -> 'cirq.PhasedXZGate':
111+
def from_matrix(mat: np.ndarray) -> 'cirq.PhasedXZGate':
112112
pre_phase, rotation, post_phase = linalg.deconstruct_single_qubit_matrix_into_angles(mat)
113113
pre_phase /= np.pi
114114
post_phase /= np.pi

cirq/ops/raw_types.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,7 @@ def _commutes_(
552552

553553
m12 = protocols.unitary_protocol.unitary(circuit12, default=None)
554554
m21 = protocols.unitary_protocol.unitary(circuit21, default=None)
555-
if m12 is None:
555+
if m12 is None or m21 is None:
556556
return NotImplemented
557557

558558
return np.allclose(m12, m21, atol=atol)

cirq/ops/two_qubit_diagonal_gate.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def _resolve_parameters_(
6464
def _has_unitary_(self) -> bool:
6565
return not self._is_parameterized_()
6666

67-
def _unitary_(self) -> np.ndarray:
67+
def _unitary_(self) -> Optional[np.ndarray]:
6868
if self._is_parameterized_():
6969
return None
7070
return np.diag([np.exp(1j * angle) for angle in self._diag_angles_radians])

cirq/qis/states.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -682,7 +682,7 @@ def validate_normalized_state_vector(
682682
state_vector: np.ndarray,
683683
*, # Force keyword arguments
684684
qid_shape: Tuple[int, ...],
685-
dtype: Optional[Type[np.number]] = None,
685+
dtype: Optional[np.dtype] = None,
686686
atol: float = 1e-7,
687687
) -> None:
688688
"""Checks that the given state vector is valid.
@@ -754,7 +754,7 @@ def to_valid_density_matrix(
754754
num_qubits: Optional[int] = None,
755755
*, # Force keyword arguments
756756
qid_shape: Optional[Tuple[int, ...]] = None,
757-
dtype: Optional[Type[np.number]] = None,
757+
dtype: Optional[np.dtype] = None,
758758
atol: float = 1e-7,
759759
) -> np.ndarray:
760760
"""Verifies the density_matrix_rep is valid and converts it to ndarray form.
@@ -888,9 +888,7 @@ def one_hot(
888888
return result
889889

890890

891-
def eye_tensor(
892-
half_shape: Tuple[int, ...], *, dtype: Type[np.number] # Force keyword args
893-
) -> np.array:
891+
def eye_tensor(half_shape: Tuple[int, ...], *, dtype: np.dtype) -> np.ndarray:
894892
"""Returns an identity matrix reshaped into a tensor.
895893
896894
Args:
@@ -902,6 +900,6 @@ def eye_tensor(
902900
Returns:
903901
The created numpy array with shape `half_shape + half_shape`.
904902
"""
905-
identity = np.eye(np.prod(half_shape, dtype=int), dtype=dtype)
903+
identity = np.eye(np.prod(half_shape, dtype=int).item(), dtype=dtype)
906904
identity.shape = half_shape * 2
907905
return identity

0 commit comments

Comments
 (0)