Skip to content

Commit 267c2ee

Browse files
authored
Fix more mypy --next type errors (#5392)
1 parent d563992 commit 267c2ee

File tree

7 files changed

+34
-27
lines changed

7 files changed

+34
-27
lines changed

cirq-core/cirq/study/resolver.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,11 @@ def value_of(
157157
if isinstance(value, sympy.Pow) and len(value.args) == 2:
158158
base = self.value_of(value.args[0], recursive)
159159
exponent = self.value_of(value.args[1], recursive)
160+
# Casts because numpy can handle expressions (by delegating to __pow__), but does
161+
# not have signature that will support this.
160162
if isinstance(base, numbers.Number):
161-
return np.float_power(base, exponent)
162-
return np.power(base, exponent)
163+
return np.float_power(cast(complex, base), cast(complex, exponent))
164+
return np.power(cast(complex, base), cast(complex, exponent))
163165

164166
if not isinstance(value, sympy.Basic):
165167
# No known way to resolve this variable, return unchanged.

cirq-core/cirq/transformers/eject_phased_paulis.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ def _try_get_known_phased_pauli(
321321
elif (
322322
isinstance(gate, ops.PhasedXZGate)
323323
and not protocols.is_parameterized(gate.z_exponent)
324-
and np.isclose(gate.z_exponent, 0)
324+
and np.isclose(float(gate.z_exponent), 0)
325325
):
326326
e = gate.x_exponent
327327
p = gate.axis_phase_exponent
@@ -336,9 +336,12 @@ def _try_get_known_z_half_turns(
336336
g = op.gate
337337
if (
338338
isinstance(g, ops.PhasedXZGate)
339-
and np.isclose(g.x_exponent, 0)
340-
and np.isclose(g.axis_phase_exponent, 0)
339+
and not protocols.is_parameterized(g.x_exponent)
340+
and not protocols.is_parameterized(g.axis_phase_exponent)
341+
and np.isclose(float(g.x_exponent), 0)
342+
and np.isclose(float(g.axis_phase_exponent), 0)
341343
):
344+
342345
h = g.z_exponent
343346
elif isinstance(g, ops.ZPowGate):
344347
h = g.exponent

cirq-core/cirq/transformers/heuristic_decompositions/gate_tabulation_math_utils.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -58,17 +58,17 @@ def random_qubit_unitary(
5858
rng: Random number generator to be used in sampling. Default is
5959
numpy.random.
6060
"""
61-
rng = np.random if rng is None else rng
61+
real_rng: np.random.RandomState = np.random if rng is None else rng
6262

63-
theta = np.arcsin(np.sqrt(rng.rand(*shape)))
64-
phi_d = rng.rand(*shape) * np.pi * 2
65-
phi_o = rng.rand(*shape) * np.pi * 2
63+
theta = np.arcsin(np.sqrt(real_rng.rand(*shape)))
64+
phi_d = real_rng.rand(*shape) * np.pi * 2
65+
phi_o = real_rng.rand(*shape) * np.pi * 2
6666

6767
out = _single_qubit_unitary(theta, phi_d, phi_o)
6868

6969
if randomize_global_phase:
7070
out = np.moveaxis(out, (-2, -1), (0, 1))
71-
out *= np.exp(1j * np.pi * 2 * rng.rand(*shape))
71+
out *= np.exp(1j * np.pi * 2 * real_rng.rand(*shape))
7272
out = np.moveaxis(out, (0, 1), (-2, -1))
7373
return out
7474

cirq-core/cirq/transformers/heuristic_decompositions/two_qubit_gate_tabulation.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"""Attempt to tabulate single qubit gates required to generate a target 2Q gate
1616
with a product A k A."""
1717
from functools import reduce
18-
from typing import Tuple, Sequence, List, NamedTuple
18+
from typing import List, NamedTuple, Sequence, Tuple
1919

2020
from dataclasses import dataclass
2121
import numpy as np
@@ -100,7 +100,7 @@ def compile_two_qubit_gate(self, unitary: np.ndarray) -> TwoQubitGateTabulationR
100100
unitary = np.asarray(unitary)
101101
kak_vec = cirq.kak_vector(unitary, check_preconditions=False)
102102
infidelities = kak_vector_infidelity(kak_vec, self.kak_vecs, ignore_equivalent_vectors=True)
103-
nearest_ind = infidelities.argmin()
103+
nearest_ind = int(infidelities.argmin())
104104

105105
success = infidelities[nearest_ind] < self.max_expected_infidelity
106106

@@ -483,13 +483,13 @@ def two_qubit_gate_product_tabulation(
483483
else:
484484
missed_points.append(missing_vec)
485485

486-
kak_vecs = np.array(kak_vecs)
486+
kak_vecs_arr = np.array(kak_vecs)
487487
summary += (
488488
f'\nFraction of Weyl chamber reached with 2 gates and 3 gates '
489489
f'(after patchup)'
490-
f': {(len(kak_vecs) - 1) / num_mesh_points :.3f}'
490+
f': {(len(kak_vecs_arr) - 1) / num_mesh_points :.3f}'
491491
)
492492

493493
return TwoQubitGateTabulation(
494-
base_gate, kak_vecs, sq_cycles, max_infidelity, summary, tuple(missed_points)
494+
base_gate, kak_vecs_arr, sq_cycles, max_infidelity, summary, tuple(missed_points)
495495
)

cirq-core/cirq/value/duration.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,8 @@ def __init__(
8888
else:
8989
raise TypeError(f'Not a `cirq.DURATION_LIKE`: {repr(value)}.')
9090

91-
self._picos: Union[float, int, sympy.Expr] = (
92-
picos + nanos * 1000 + micros * 1000_000 + millis * 1000_000_000
93-
)
94-
if isinstance(self._picos, np.number):
95-
self._picos = float(self._picos)
91+
val = picos + nanos * 1000 + micros * 1000_000 + millis * 1000_000_000
92+
self._picos: _NUMERIC_OUTPUT_TYPE = float(val) if isinstance(val, np.number) else val
9693

9794
def _is_parameterized_(self) -> bool:
9895
return protocols.is_parameterized(self._picos)

cirq-core/cirq/vis/histogram.py

+11-6
Original file line numberDiff line numberDiff line change
@@ -86,15 +86,16 @@ def integrated_histogram(
8686
if isinstance(data, Mapping):
8787
data = list(data.values())
8888

89-
data = [d for d in data if not np.isnan(d)]
90-
n = len(data)
89+
float_data = [float(d) for d in data if not np.isnan(float(d))]
90+
91+
n = len(float_data)
9192

9293
if not show_zero:
9394
bin_values = np.linspace(0, 1, n + 1)
94-
parameter_values = sorted(np.concatenate(([0], data)))
95+
parameter_values = sorted(np.concatenate(([0], float_data)))
9596
else:
9697
bin_values = np.linspace(0, 1, n)
97-
parameter_values = sorted(data)
98+
parameter_values = sorted(float_data)
9899
plot_options = {"where": 'post', "color": 'b', "linestyle": '-', "lw": 1.0, "ms": 0.0}
99100
plot_options.update(kwargs)
100101

@@ -127,15 +128,19 @@ def integrated_histogram(
127128

128129
if median_line:
129130
set_line(
130-
np.median(data),
131+
np.median(float_data),
131132
linestyle='--',
132133
color=plot_options['color'],
133134
alpha=0.5,
134135
label=median_label,
135136
)
136137
if mean_line:
137138
set_line(
138-
np.mean(data), linestyle='-.', color=plot_options['color'], alpha=0.5, label=mean_label
139+
np.mean(float_data),
140+
linestyle='-.',
141+
color=plot_options['color'],
142+
alpha=0.5,
143+
label=mean_label,
139144
)
140145
if show_plot:
141146
fig.show()

cirq-core/cirq/vis/state_histogram.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def plot_state_histogram(
9292
elif isinstance(data, collections.Counter):
9393
tick_label, values = zip(*sorted(data.items()))
9494
else:
95-
values = data
95+
values = np.array(data)
9696
if not tick_label:
9797
tick_label = np.arange(len(values))
9898
ax.bar(np.arange(len(values)), values, tick_label=tick_label)

0 commit comments

Comments
 (0)