Skip to content

Commit f246032

Browse files
Consistent final_state_vector (quantumlib#5267)
Fixes quantumlib#3693. This pushes `Circuit.final_state_vector` and `mux.final_state_vector` towards a shared API of: ```python def final_state_vector( program: 'cirq.CIRCUIT_LIKE', *, initial_state: 'cirq.STATE_VECTOR_LIKE' = 0, param_resolver: 'cirq.ParamResolverOrSimilarType' = None, qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT, ignore_terminal_measurements: bool = False, dtype: Type[np.number] = np.complex64, seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, ) -> 'np.ndarray': ``` Additionally, calls to `Circuit.final_state_vector` will now pass through the simulator instead of using `apply_unitary`. Notable changes in `Circuit`: - `qubits_that_should_be_present` is deprecated - Use of positional args is deprecated - The default of `ignore_terminal_measurements` will change to `False` in v0.16 - The default of `dtype` has directly changed to `np.complex64` Notable changes in `mux`: - Addition of `ignore_terminal_measurements` parameter Other changes: - Added `drop_terminal_measurements` transformer
1 parent d630a83 commit f246032

12 files changed

+339
-125
lines changed

cirq/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,7 @@
366366
dephase_measurements,
367367
drop_empty_moments,
368368
drop_negligible_operations,
369+
drop_terminal_measurements,
369370
eject_phased_paulis,
370371
eject_z,
371372
expand_composite,

cirq/circuits/circuit.py

+67-60
Original file line numberDiff line numberDiff line change
@@ -1041,89 +1041,96 @@ def _superoperator_(self) -> np.ndarray:
10411041
circuit_superoperator = moment_superoperator @ circuit_superoperator
10421042
return circuit_superoperator
10431043

1044+
@_compat.deprecated_parameter(
1045+
deadline='v0.16',
1046+
fix='Inject identity operators to include untouched qubits.',
1047+
parameter_desc='qubits_that_should_be_present',
1048+
match=lambda args, kwargs: 'qubits_that_should_be_present' in kwargs,
1049+
)
1050+
@_compat.deprecated_parameter(
1051+
deadline='v0.16',
1052+
fix='Only use keyword arguments.',
1053+
parameter_desc='positional args',
1054+
match=lambda args, kwargs: len(args) > 1,
1055+
)
10441056
def final_state_vector(
10451057
self,
1058+
# TODO(v0.16): Force kwargs and match order found in:
1059+
# cirq-core/cirq/sim/mux.py:final_state_vector
10461060
initial_state: 'cirq.STATE_VECTOR_LIKE' = 0,
10471061
qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT,
10481062
qubits_that_should_be_present: Iterable['cirq.Qid'] = (),
1049-
ignore_terminal_measurements: bool = True,
1050-
dtype: Type[np.number] = np.complex128,
1063+
ignore_terminal_measurements: Optional[bool] = None,
1064+
dtype: Optional[Type[np.number]] = None,
1065+
param_resolver: 'cirq.ParamResolverOrSimilarType' = None,
1066+
seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
10511067
) -> np.ndarray:
1052-
"""Left-multiplies a state vector by the circuit's unitary effect.
1053-
1054-
A circuit's "unitary effect" is the unitary matrix produced by
1055-
multiplying together all of its gates' unitary matrices. A circuit
1056-
with non-unitary gates (such as measurement or parameterized gates) does
1057-
not have a well-defined unitary effect, and the method will fail if such
1058-
operations are present.
1068+
"""Returns the state vector resulting from acting operations on a state.
10591069
1060-
For convenience, terminal measurements are automatically ignored
1061-
instead of causing a failure. Set the `ignore_terminal_measurements`
1062-
argument to False to disable this behavior.
1063-
1064-
This method is equivalent to left-multiplying the input state by
1065-
`cirq.unitary(circuit)` but it's computed in a more efficient
1066-
way.
1070+
This is equivalent to calling cirq.final_state_vector with the same
1071+
arguments and this circuit as the "program".
10671072
10681073
Args:
1069-
initial_state: The input state for the circuit. This can be a list
1070-
of qudit values, a big endian int encoding the qudit values,
1071-
a vector of amplitudes, or a tensor of amplitudes.
1072-
1073-
When this is an int, it refers to a computational
1074-
basis state (e.g. 5 means initialize to ``|5⟩ = |...000101⟩``).
1075-
1076-
If this is a vector of amplitudes (a flat numpy array of the
1077-
correct length for the system) or a tensor of amplitudes (a
1078-
numpy array whose shape equals this circuit's `qid_shape`), it
1079-
directly specifies the initial state's amplitudes. The vector
1080-
type must be convertible to the given `dtype` argument.
1081-
qubit_order: Determines how qubits are ordered when passing matrices
1082-
into np.kron.
1074+
initial_state: If an int, the state is set to the computational
1075+
basis state corresponding to this state. Otherwise if this
1076+
is a np.ndarray it is the full initial state. In this case it
1077+
must be the correct size, be normalized (an L2 norm of 1), and
1078+
be safely castable to an appropriate dtype for the simulator.
1079+
qubit_order: Determines the canonical ordering of the qubits. This
1080+
is often used in specifying the initial state, i.e. the
1081+
ordering of the computational basis states.
10831082
qubits_that_should_be_present: Qubits that may or may not appear
10841083
in operations within the circuit, but that should be included
10851084
regardless when generating the matrix.
10861085
ignore_terminal_measurements: When set, measurements at the end of
10871086
the circuit are ignored instead of causing the method to
10881087
fail.
1089-
dtype: The numpy dtype for the returned unitary. Defaults to
1090-
np.complex128. Specifying np.complex64 will run faster at the
1091-
cost of precision. `dtype` must be a complex np.dtype, unless
1092-
all operations in the circuit have unitary matrices with
1093-
exclusively real coefficients (e.g. an H + TOFFOLI circuit).
1088+
dtype: The `numpy.dtype` used by the simulation. Typically one of
1089+
`numpy.complex64` or `numpy.complex128`.
1090+
param_resolver: Parameters to run with the program.
1091+
seed: The random seed to use for this simulator.
10941092
10951093
Returns:
1096-
A (possibly gigantic) numpy array storing the superposition that
1097-
came out of the circuit for the given input state.
1094+
The state vector resulting from applying the given unitary
1095+
operations to the desired initial state. Specifically, a numpy
1096+
array containing the amplitudes in np.kron order, where the
1097+
order of arguments to kron is determined by the qubit order
1098+
argument (which defaults to just sorting the qubits that are
1099+
present into an ascending order).
10981100
10991101
Raises:
1100-
ValueError: The circuit contains measurement gates that are not
1101-
ignored.
1102-
TypeError: The circuit contains gates that don't have a known
1103-
unitary matrix, e.g. gates parameterized by a Symbol.
1102+
ValueError: If the program doesn't have a well defined final state
1103+
because it has non-unitary gates.
11041104
"""
1105+
if ignore_terminal_measurements is None:
1106+
if self.has_measurements():
1107+
_compat._warn_or_error(
1108+
'`ignore_terminal_measurements` will default to False in v0.16. '
1109+
'To drop terminal measurements, please explicitly include '
1110+
'`ignore_terminal_measurements=True` when calling this method.'
1111+
)
1112+
ignore_terminal_measurements = True
1113+
1114+
if dtype is None:
1115+
_compat._warn_or_error(
1116+
'`dtype` will default to np.complex64 in v0.16. '
1117+
'To use the previous default, please explicitly include '
1118+
'`dtype=np.complex128` when calling this method.'
1119+
)
1120+
dtype = np.complex128
11051121

1106-
if not ignore_terminal_measurements and any(
1107-
protocols.is_measurement(op) for op in self.all_operations()
1108-
):
1109-
raise ValueError('Circuit contains a measurement.')
1110-
1111-
if not self.are_all_measurements_terminal():
1112-
raise ValueError('Circuit contains a non-terminal measurement.')
1113-
1114-
qs = ops.QubitOrder.as_qubit_order(qubit_order).order_for(
1115-
self.all_qubits().union(qubits_that_should_be_present)
1116-
)
1117-
1118-
# Force qubits to have dimension at least 2 for backwards compatibility.
1119-
qid_shape = self.qid_shape(qubit_order=qs)
1120-
state_len = np.prod(qid_shape, dtype=np.int64)
1122+
from cirq.sim.mux import final_state_vector
11211123

1122-
state = qis.to_valid_state_vector(initial_state, qid_shape=qid_shape, dtype=dtype).reshape(
1123-
qid_shape
1124+
program = Circuit(cirq.I(q) for q in qubits_that_should_be_present) + self
1125+
return final_state_vector(
1126+
program,
1127+
initial_state=initial_state,
1128+
param_resolver=param_resolver,
1129+
qubit_order=qubit_order,
1130+
ignore_terminal_measurements=ignore_terminal_measurements,
1131+
dtype=dtype,
1132+
seed=seed,
11241133
)
1125-
result = _apply_unitary_circuit(self, state, qs, dtype)
1126-
return result.reshape((state_len,))
11271134

11281135
def to_text_diagram(
11291136
self,

0 commit comments

Comments
 (0)