Skip to content

Commit 89a5323

Browse files
authored
More numpy types (#5683)
12 errors left Part of #3767
1 parent 5dd3a94 commit 89a5323

14 files changed

+38
-33
lines changed

cirq-core/cirq/circuits/circuit.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
Any,
3131
Callable,
3232
Mapping,
33+
MutableSequence,
3334
cast,
3435
Dict,
3536
FrozenSet,
@@ -1462,7 +1463,7 @@ def concat_ragged(
14621463

14631464
# Allocate a buffer large enough to append and prepend all the circuits.
14641465
pad_len = sum(len(c) for c in circuits) - n_acc
1465-
buffer = np.zeros(shape=pad_len * 2 + n_acc, dtype=object)
1466+
buffer: MutableSequence['cirq.Moment'] = [cirq.Moment()] * (pad_len * 2 + n_acc)
14661467

14671468
# Put the initial circuit in the center of the buffer.
14681469
offset = pad_len
@@ -1601,7 +1602,11 @@ def _overlap_collision_time(
16011602

16021603

16031604
def _concat_ragged_helper(
1604-
c1_offset: int, n1: int, buf: np.ndarray, c2: Sequence['cirq.Moment'], align: 'cirq.Alignment'
1605+
c1_offset: int,
1606+
n1: int,
1607+
buf: MutableSequence['cirq.Moment'],
1608+
c2: Sequence['cirq.Moment'],
1609+
align: 'cirq.Alignment',
16051610
) -> Tuple[int, int]:
16061611
n2 = len(c2)
16071612
shift = _overlap_collision_time(buf[c1_offset : c1_offset + n1], c2, align)
@@ -2369,7 +2374,7 @@ def _resolve_parameters_(
23692374
return Circuit(resolved_moments)
23702375

23712376
@property
2372-
def moments(self):
2377+
def moments(self) -> Sequence['cirq.Moment']:
23732378
return self._moments
23742379

23752380
def with_noise(self, noise: 'cirq.NOISE_MODEL_LIKE') -> 'cirq.Circuit':

cirq-core/cirq/circuits/circuit_operation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -589,7 +589,7 @@ def repeat(
589589
# As CircuitOperation is immutable, this can safely return the original.
590590
return self
591591

592-
expected_repetition_id_length = abs(repetitions)
592+
expected_repetition_id_length: int = np.abs(repetitions)
593593

594594
if repetition_ids is None:
595595
if self.use_repetition_ids:

cirq-core/cirq/contrib/quimb/state_vector.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def circuit_to_tensors(
7676

7777
for moment in circuit.moments:
7878
for op in moment.operations:
79-
assert op.gate._has_unitary_()
79+
assert cirq.has_unitary(op.gate)
8080
start_inds = [f'i{qubit_frontier[q]}_q{q}' for q in op.qubits]
8181
for q in op.qubits:
8282
qubit_frontier[q] += 1

cirq-core/cirq/ops/pauli_string.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -744,7 +744,8 @@ def _expectation_from_density_matrix_no_validation(
744744

745745
while any(result.shape):
746746
result = np.trace(result, axis1=0, axis2=len(result.shape) // 2)
747-
return result * self.coefficient
747+
748+
return float(result * self.coefficient)
748749

749750
def zip_items(
750751
self, other: 'cirq.PauliString[TKey]'

cirq-core/cirq/protocols/circuit_diagram_info_protocol.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ def format_radians(self, radians: Union[sympy.Basic, int, float]) -> str:
274274
return '0'
275275
if radians == -np.pi:
276276
return '-' + unit
277-
if self.precision is not None:
277+
if self.precision is not None and not isinstance(radians, sympy.Basic):
278278
quantity = self.format_real(radians / np.pi)
279279
return quantity + unit
280280
return repr(radians)

cirq-core/cirq/protocols/trace_distance_bound.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Any, TypeVar, Optional, Sequence
15+
from typing import Any, TypeVar, Optional, Sequence, Union
1616

1717
import numpy as np
1818
from typing_extensions import Protocol
@@ -109,7 +109,7 @@ def _strat_distance_from_unitary(val: Any) -> Optional[float]:
109109
return trace_distance_from_angle_list(np.angle(np.linalg.eigvals(u)))
110110

111111

112-
def trace_distance_from_angle_list(angle_list: Sequence[float]) -> float:
112+
def trace_distance_from_angle_list(angle_list: Union[Sequence[float], np.ndarray]) -> float:
113113
"""Given a list of arguments of the eigenvalues of a unitary matrix,
114114
calculates the trace distance bound of the unitary effect.
115115

cirq-core/cirq/qis/clifford_tableau.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from cirq import protocols
2020
from cirq._compat import proper_repr
2121
from cirq.qis import quantum_state_representation
22-
from cirq.value import big_endian_int_to_digits, linear_dict
22+
from cirq.value import big_endian_int_to_digits, linear_dict, random_state
2323

2424
if TYPE_CHECKING:
2525
import cirq
@@ -509,7 +509,7 @@ def destabilizers(self) -> List['cirq.DensePauliString']:
509509
generators above generate the full Pauli group on n qubits."""
510510
return [self._row_to_dense_pauli(i) for i in range(self.n)]
511511

512-
def _measure(self, q, prng: np.random.RandomState = np.random) -> int:
512+
def _measure(self, q, prng: np.random.RandomState) -> int:
513513
"""Performs a projective measurement on the q'th qubit.
514514
515515
Returns: the result (0 or 1) of the measurement.
@@ -651,6 +651,6 @@ def apply_global_phase(self, coefficient: linear_dict.Scalar):
651651
pass
652652

653653
def measure(
654-
self, axes: Sequence[int], seed: Optional['cirq.RANDOM_STATE_OR_SEED_LIKE'] = None
654+
self, axes: Sequence[int], seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None
655655
) -> List[int]:
656-
return [self._measure(axis, seed) for axis in axes]
656+
return [self._measure(axis, random_state.parse_random_state(seed)) for axis in axes]

cirq-core/cirq/sim/clifford/stabilizer_state_ch_form.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import cirq
1919
from cirq import protocols, qis, value
20-
from cirq.value import big_endian_int_to_digits
20+
from cirq.value import big_endian_int_to_digits, random_state
2121

2222

2323
@value.value_equality
@@ -388,7 +388,7 @@ def apply_global_phase(self, coefficient: value.Scalar):
388388
def measure(
389389
self, axes: Sequence[int], seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None
390390
) -> List[int]:
391-
return [self._measure(axis, seed) for axis in axes]
391+
return [self._measure(axis, random_state.parse_random_state(seed)) for axis in axes]
392392

393393

394394
def _phase(exponent, global_shift):

cirq-core/cirq/sim/density_matrix_utils.py

+8-9
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ def sample_density_matrix(
6868
qid_shape = (2,) * num_qubits
6969
else:
7070
_validate_density_matrix_qid_shape(density_matrix, qid_shape)
71-
num_qubits = len(qid_shape)
7271
meas_shape = _indices_shape(qid_shape, indices)
7372

7473
if repetitions == 0 or len(indices) == 0:
@@ -139,16 +138,16 @@ def measure_density_matrix(
139138
qid_shape = (2,) * num_qubits
140139
else:
141140
_validate_density_matrix_qid_shape(density_matrix, qid_shape)
142-
num_qubits = len(qid_shape)
143141
meas_shape = _indices_shape(qid_shape, indices)
144142

145-
arrout: np.ndarray = (
146-
np.copy(density_matrix)
147-
if out is None
148-
else density_matrix
149-
if out is density_matrix
150-
else (np.copyto(dst=out, src=density_matrix), out)[-1]
151-
)
143+
arrout: np.ndarray
144+
if out is None:
145+
arrout = np.copy(density_matrix)
146+
elif out is density_matrix:
147+
arrout = density_matrix
148+
else:
149+
np.copyto(dst=out, src=density_matrix)
150+
arrout = out
152151

153152
if len(indices) == 0:
154153
return ([], arrout)

cirq-core/cirq/sim/state_vector_simulation_state.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""Objects and methods for acting efficiently on a state vector."""
15-
1615
from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, TYPE_CHECKING, Union
1716

1817
import numpy as np
@@ -225,13 +224,12 @@ def prepare_into_buffer(k: int):
225224
e.reshape(shape * 2).astype(self._state_vector.dtype) for e in kraus_operators
226225
]
227226
p = prng.random()
228-
weight = None
229-
fallback_weight = 0
227+
fallback_weight = 0.0
230228
fallback_weight_index = 0
231-
index = None
229+
232230
for index in range(len(kraus_tensors)):
233231
prepare_into_buffer(index)
234-
weight = np.linalg.norm(self._buffer) ** 2
232+
weight = float(np.linalg.norm(self._buffer) ** 2)
235233

236234
if weight > fallback_weight:
237235
fallback_weight_index = index

cirq-core/cirq/transformers/heuristic_decompositions/gate_tabulation_math_utils.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Union, Sequence, Optional
44

55
import numpy as np
6+
from cirq.value import random_state
67

78
_RealArraylike = Union[np.ndarray, float]
89

@@ -58,7 +59,7 @@ def random_qubit_unitary(
5859
rng: Random number generator to be used in sampling. Default is
5960
numpy.random.
6061
"""
61-
real_rng: np.random.RandomState = np.random if rng is None else rng
62+
real_rng = random_state.parse_random_state(rng)
6263

6364
theta = np.arcsin(np.sqrt(real_rng.rand(*shape)))
6465
phi_d = real_rng.rand(*shape) * np.pi * 2

cirq-core/cirq/value/type_alias.py

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
from typing import Union
16+
1617
import sympy
1718

1819
from cirq._doc import document

cirq-core/cirq/vis/state_histogram.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,8 @@ def plot_state_histogram(
9393
tick_label, values = zip(*sorted(data.items()))
9494
else:
9595
values = np.array(data)
96-
if not tick_label:
97-
tick_label = np.arange(len(values))
96+
if tick_label is None:
97+
tick_label = [str(i) for i in range(len(values))]
9898
ax.bar(np.arange(len(values)), values, tick_label=tick_label)
9999
ax.set_xlabel(xlabel)
100100
ax.set_ylabel(ylabel)

cirq-google/cirq_google/engine/virtual_engine_factory.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ def create_default_noisy_quantum_virtual_machine(
393393
try: # coverage: ignore
394394
import qsimcirq # type: ignore
395395

396-
simulator_class = qsimcirq.Simulator # coverage: ignore
396+
simulator_class = qsimcirq.QSimSimulator # coverage: ignore
397397
except ImportError:
398398
simulator_class = cirq.Simulator # coverage: ignore
399399

0 commit comments

Comments
 (0)