Skip to content

Commit a2fd902

Browse files
authored
Fix the final mypy --next errors (#5760)
Will push changes to github actions after this is in. * `np.einsum` is missing a type signature, resulting in some type ignores (filed issue at numpy numpy/numpy#21978) * In practice one can use numpy types like np.double for parameters, but currently we don't support that. Filed #5758 to consider this. * Numpy `tolist` does not return a type signature of `List`. wat? Part of #3767
1 parent 9e61e0e commit a2fd902

File tree

11 files changed

+41
-23
lines changed

11 files changed

+41
-23
lines changed

cirq-core/cirq/circuits/circuit_operation.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,14 @@
2020
import math
2121
from typing import (
2222
Callable,
23-
Mapping,
24-
Sequence,
23+
cast,
2524
Dict,
2625
FrozenSet,
2726
Iterator,
2827
List,
28+
Mapping,
2929
Optional,
30+
Sequence,
3031
Tuple,
3132
TYPE_CHECKING,
3233
Union,
@@ -78,7 +79,7 @@ class CircuitOperation(ops.Operation):
7879
def __init__(
7980
self,
8081
circuit: 'cirq.FrozenCircuit',
81-
repetitions: int = 1,
82+
repetitions: INT_TYPE = 1,
8283
qubit_map: Optional[Dict['cirq.Qid', 'cirq.Qid']] = None,
8384
measurement_key_map: Optional[Dict[str, str]] = None,
8485
param_resolver: Optional[study.ParamResolverOrSimilarType] = None,
@@ -790,4 +791,8 @@ def _resolve_parameters_(
790791
self, resolver: 'cirq.ParamResolver', recursive: bool
791792
) -> 'cirq.CircuitOperation':
792793
resolved = self.with_params(resolver.param_dict, recursive)
793-
return resolved.replace(repetitions=resolver.value_of(self.repetitions, recursive))
794+
# repetitions can resolve to a float, but this is ok since constructor converts to
795+
# nearby int.
796+
return resolved.replace(
797+
repetitions=resolver.value_of(cast('cirq.TParamVal', self.repetitions), recursive)
798+
)

cirq-core/cirq/contrib/qasm_import/_parser.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
import functools
1515
import operator
16-
from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Sequence, Union
16+
from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Union, TYPE_CHECKING
1717

1818
import numpy as np
1919
import sympy
@@ -25,6 +25,10 @@
2525
from cirq.contrib.qasm_import.exception import QasmException
2626

2727

28+
if TYPE_CHECKING:
29+
import cirq
30+
31+
2832
class Qasm:
2933
"""Qasm stores the final result of the Qasm parsing."""
3034

@@ -115,7 +119,10 @@ def on(
115119
# used as arguments, we generate reg_size GateOperations via iterating
116120
# through each qubit of the registers 0 to n-1 and use the same one
117121
# qubit from the "single-qubit registers" for each operation.
118-
op_qubits = cast(Sequence[Sequence[ops.Qid]], functools.reduce(np.broadcast, args))
122+
op_qubits = functools.reduce(
123+
cast(Callable[[List['cirq.Qid'], List['cirq.Qid']], List['cirq.Qid']], np.broadcast),
124+
args,
125+
)
119126
for qubits in op_qubits:
120127
if isinstance(qubits, ops.Qid):
121128
yield final_gate.on(qubits)

cirq-core/cirq/experiments/readout_confusion_matrix.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,11 @@ def _get_vars(self, qubit_pattern: Sequence['cirq.Qid']) -> List[int]:
175175
def _confusion_matrix(self, qubits: Sequence['cirq.Qid']) -> np.ndarray:
176176
ein_input = []
177177
for qs, cm in zip(self.measure_qubits, self.confusion_matrices):
178-
ein_input += [cm.reshape((2, 2) * len(qs)), self._get_vars(qs)]
178+
ein_input.extend([cm.reshape((2, 2) * len(qs)), self._get_vars(qs)])
179179
ein_out = self._get_vars(qubits)
180-
ret = np.einsum(*ein_input, ein_out).reshape((2 ** len(qubits),) * 2)
180+
181+
# TODO(#5757): remote type ignore when numpy has proper override signature.
182+
ret = np.einsum(*ein_input, ein_out).reshape((2 ** len(qubits),) * 2) # type: ignore
181183
return ret / ret.sum(axis=1)
182184

183185
def confusion_matrix(self, qubits: Optional[Sequence['cirq.Qid']] = None) -> np.ndarray:

cirq-core/cirq/linalg/decompositions.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
Iterable,
2424
List,
2525
Optional,
26-
Sequence,
2726
Set,
2827
Tuple,
2928
TYPE_CHECKING,
@@ -637,7 +636,7 @@ def scatter_plot_normalized_kak_interaction_coefficients(
637636
ax = fig.add_subplot(1, 1, 1, projection='3d')
638637

639638
def coord_transform(
640-
pts: Sequence[Tuple[float, float, float]]
639+
pts: Union[List[Tuple[int, int, int]], np.ndarray]
641640
) -> Tuple[Iterable[float], Iterable[float], Iterable[float]]:
642641
if len(pts) == 0:
643642
return [], [], []

cirq-core/cirq/linalg/predicates.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,10 @@ def is_cptp(*, kraus_ops: Sequence[np.ndarray], rtol: float = 1e-5, atol: float
161161
atol: The absolute tolerance on equality.
162162
"""
163163
sum_ndarray = cast(np.ndarray, sum(matrix.T.conj() @ matrix for matrix in kraus_ops))
164-
return np.allclose(sum_ndarray, np.eye(*sum_ndarray.shape), rtol=rtol, atol=atol)
164+
# Explicitly pull out shapes and don't use tuple to avoid confusing numpy type overrides.
165+
return np.allclose(
166+
sum_ndarray, np.eye(sum_ndarray.shape[0], sum_ndarray.shape[1]), rtol=rtol, atol=atol
167+
)
165168

166169

167170
def matrix_commutes(

cirq-core/cirq/linalg/transformations.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def targeted_left_multiply(
151151

152152
all_indices = set(input_indices + data_indices + tuple(output_indices))
153153

154+
# TODO(#5757): remote type ignore when numpy has proper override signature.
154155
return np.einsum(
155156
left_matrix,
156157
input_indices,
@@ -164,7 +165,7 @@ def targeted_left_multiply(
164165
# And this is workaround for *another* bug!
165166
# Supposed to be able to just say 'old=old'.
166167
**({'out': out} if out is not None else {}),
167-
)
168+
) # type: ignore
168169

169170

170171
def targeted_conjugate_about(
@@ -333,7 +334,8 @@ def partial_trace(tensor: np.ndarray, keep_indices: Sequence[int]) -> np.ndarray
333334
keep_map = dict(zip(keep_indices, sorted(keep_indices)))
334335
left_indices = [keep_map[i] if i in keep_set else i for i in range(ndim)]
335336
right_indices = [ndim + i if i in keep_set else i for i in left_indices]
336-
return np.einsum(tensor, left_indices + right_indices)
337+
# TODO(#5757): remote type ignore when numpy has proper override signature.
338+
return np.einsum(tensor, left_indices + right_indices) # type: ignore
337339

338340

339341
class EntangledStateError(ValueError):

cirq-core/cirq/qis/states.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -677,13 +677,14 @@ def density_matrix_from_state_vector(
677677
sum_inds = np.array(range(n_qubits))
678678
sum_inds[indices] += n_qubits
679679

680+
# TODO(#5757): remote type ignore when numpy has proper override signature.
680681
rho = np.einsum(
681682
state_vector,
682683
list(range(n_qubits)),
683684
np.conj(state_vector),
684-
sum_inds.tolist(),
685-
indices + sum_inds[indices].tolist(),
686-
)
685+
cast(List, sum_inds.tolist()),
686+
indices + cast(List, sum_inds[indices].tolist()),
687+
) # type: ignore
687688
new_shape = np.prod([shape[i] for i in indices], dtype=np.int64)
688689

689690
return rho.reshape((new_shape, new_shape))

cirq-core/cirq/sim/state_vector.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -352,5 +352,4 @@ def _probs(state: np.ndarray, indices: Sequence[int], qid_shape: Tuple[int, ...]
352352
probs = np.sum(probs, axis=tuple(range(1, len(probs.shape))))
353353

354354
# To deal with rounding issues, ensure that the probabilities sum to 1.
355-
probs /= np.sum(probs)
356-
return probs
355+
return probs / np.sum(probs)

cirq-core/cirq/study/resolver.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
"""Resolves ParameterValues to assigned values."""
1616
import numbers
17-
from typing import Any, Dict, Iterator, Mapping, Optional, TYPE_CHECKING, Union, cast
17+
from typing import Any, cast, Dict, Iterator, Mapping, Optional, TYPE_CHECKING, Union
1818

1919
import numpy as np
2020
import sympy

cirq-google/cirq_google/serialization/circuit_serializer_test.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def circuit_proto(json: Dict, qubits: List[str]):
5353

5454
X_PROTO = op_proto({'xpowgate': {'exponent': {'float_value': 1.0}}, 'qubit_constant_index': [0]})
5555

56-
56+
# TODO(#5758): Add support for numpy types to `TParamVal`.
5757
OPERATIONS = [
5858
(cirq.X(Q0), X_PROTO),
5959
(
@@ -69,11 +69,11 @@ def circuit_proto(json: Dict, qubits: List[str]):
6969
op_proto({'xpowgate': {'exponent': {'float_value': 0.125}}, 'qubit_constant_index': [0]}),
7070
),
7171
(
72-
cirq.XPowGate(exponent=np.double(0.125))(Q1),
72+
cirq.XPowGate(exponent=np.double(0.125))(Q1), # type: ignore
7373
op_proto({'xpowgate': {'exponent': {'float_value': 0.125}}, 'qubit_constant_index': [0]}),
7474
),
7575
(
76-
cirq.XPowGate(exponent=np.short(1))(Q1),
76+
cirq.XPowGate(exponent=np.short(1))(Q1), # type: ignore
7777
op_proto({'xpowgate': {'exponent': {'float_value': 1.0}}, 'qubit_constant_index': [0]}),
7878
),
7979
(

examples/direct_fidelity_estimation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -419,7 +419,7 @@ def direct_fidelity_estimation(
419419
measured_pauli_traces = pauli_traces
420420
else:
421421
# Otherwise, randomly sample as per probability.
422-
measured_pauli_traces = np.random.choice(pauli_traces, size=len(pauli_traces), p=p)
422+
measured_pauli_traces = np.random.choice(pauli_traces, size=len(pauli_traces), p=p).tolist()
423423

424424
trial_results: List[Result] = []
425425
for pauli_trace in measured_pauli_traces:

0 commit comments

Comments
 (0)