Skip to content

Setup for disabling state_vector copy #5324

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 17 additions & 17 deletions cirq-core/cirq/circuits/circuit_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def test_recursive_params():

# First example should behave like an X when simulated
result = cirq.Simulator().simulate(cirq.Circuit(circuitop), param_resolver=outer_params)
assert np.allclose(result.state_vector(), [0, 1])
assert np.allclose(result.state_vector(copy=False), [0, 1])


@pytest.mark.parametrize('add_measurements', [True, False])
Expand Down Expand Up @@ -343,9 +343,9 @@ def test_repeat_zero_times(add_measurements, use_repetition_ids, initial_reps):
subcircuit.freeze(), repetitions=initial_reps, use_repetition_ids=use_repetition_ids
)
result = cirq.Simulator().simulate(cirq.Circuit(op))
assert np.allclose(result.state_vector(), [0, 1] if initial_reps % 2 else [1, 0])
assert np.allclose(result.state_vector(copy=False), [0, 1] if initial_reps % 2 else [1, 0])
result = cirq.Simulator().simulate(cirq.Circuit(op**0))
assert np.allclose(result.state_vector(), [1, 0])
assert np.allclose(result.state_vector(copy=False), [1, 0])


def test_no_repetition_ids():
Expand Down Expand Up @@ -375,13 +375,13 @@ def test_parameterized_repeat():
assert cirq.parameter_names(op) == {'a'}
assert not cirq.has_unitary(op)
result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 0})
assert np.allclose(result.state_vector(), [1, 0])
assert np.allclose(result.state_vector(copy=False), [1, 0])
result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 1})
assert np.allclose(result.state_vector(), [0, 1])
assert np.allclose(result.state_vector(copy=False), [0, 1])
result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 2})
assert np.allclose(result.state_vector(), [1, 0])
assert np.allclose(result.state_vector(copy=False), [1, 0])
result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': -1})
assert np.allclose(result.state_vector(), [0, 1])
assert np.allclose(result.state_vector(copy=False), [0, 1])
with pytest.raises(TypeError, match='Only integer or sympy repetitions are allowed'):
cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 1.5})
with pytest.raises(ValueError, match='Circuit contains ops whose symbols were not specified'):
Expand All @@ -390,13 +390,13 @@ def test_parameterized_repeat():
assert cirq.parameter_names(op) == {'a'}
assert not cirq.has_unitary(op)
result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 0})
assert np.allclose(result.state_vector(), [1, 0])
assert np.allclose(result.state_vector(copy=False), [1, 0])
result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 1})
assert np.allclose(result.state_vector(), [0, 1])
assert np.allclose(result.state_vector(copy=False), [0, 1])
result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 2})
assert np.allclose(result.state_vector(), [1, 0])
assert np.allclose(result.state_vector(copy=False), [1, 0])
result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': -1})
assert np.allclose(result.state_vector(), [0, 1])
assert np.allclose(result.state_vector(copy=False), [0, 1])
with pytest.raises(TypeError, match='Only integer or sympy repetitions are allowed'):
cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 1.5})
with pytest.raises(ValueError, match='Circuit contains ops whose symbols were not specified'):
Expand All @@ -405,11 +405,11 @@ def test_parameterized_repeat():
assert cirq.parameter_names(op) == {'a', 'b'}
assert not cirq.has_unitary(op)
result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 1, 'b': 1})
assert np.allclose(result.state_vector(), [0, 1])
assert np.allclose(result.state_vector(copy=False), [0, 1])
result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 2, 'b': 1})
assert np.allclose(result.state_vector(), [1, 0])
assert np.allclose(result.state_vector(copy=False), [1, 0])
result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 1, 'b': 2})
assert np.allclose(result.state_vector(), [1, 0])
assert np.allclose(result.state_vector(copy=False), [1, 0])
with pytest.raises(TypeError, match='Only integer or sympy repetitions are allowed'):
cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 1.5, 'b': 1})
with pytest.raises(ValueError, match='Circuit contains ops whose symbols were not specified'):
Expand All @@ -418,11 +418,11 @@ def test_parameterized_repeat():
assert cirq.parameter_names(op) == {'a', 'b'}
assert not cirq.has_unitary(op)
result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 1, 'b': 1})
assert np.allclose(result.state_vector(), [1, 0])
assert np.allclose(result.state_vector(copy=False), [1, 0])
result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 1.5, 'b': 1})
assert np.allclose(result.state_vector(), [0, 1])
assert np.allclose(result.state_vector(copy=False), [0, 1])
result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 1, 'b': 1.5})
assert np.allclose(result.state_vector(), [0, 1])
assert np.allclose(result.state_vector(copy=False), [0, 1])
with pytest.raises(TypeError, match='Only integer or sympy repetitions are allowed'):
cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 1.5, 'b': 1.5})
with pytest.raises(ValueError, match='Circuit contains ops whose symbols were not specified'):
Expand Down
6 changes: 4 additions & 2 deletions cirq-core/cirq/contrib/quantum_volume/quantum_volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,14 @@ def compute_heavy_set(circuit: cirq.Circuit) -> List[int]:
# output is defined in terms of probabilities, where our wave function is in
# terms of amplitudes. We convert it by using the Born rule: squaring each
# amplitude and taking their absolute value
median = np.median(np.abs(results.state_vector() ** 2))
median = np.median(np.abs(results.state_vector(copy=False) ** 2))

# The output wave function is a vector from the result value (big-endian) to
# the probability of that bit-string. Return all of the bit-string
# values that have a probability greater than the median.
return [idx for idx, amp in enumerate(results.state_vector()) if np.abs(amp**2) > median]
return [
idx for idx, amp in enumerate(results.state_vector(copy=False)) if np.abs(amp**2) > median
]


@dataclass
Expand Down
3 changes: 2 additions & 1 deletion cirq-core/cirq/experiments/grid_parallel_two_qubit_xeb.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,8 @@ def _get_xeb_result(
while moment_index < 2 * depth:
step_result = next(step_results)
moment_index += 1
amplitudes = step_result.state_vector()
# copy=False is safe because state_vector_to_probabilities will copy anyways
amplitudes = step_result.state_vector(copy=False)
probabilities = value.state_vector_to_probabilities(amplitudes)
_, counts = np.unique(measurements, return_counts=True)
empirical_probs = counts / len(measurements)
Expand Down
3 changes: 2 additions & 1 deletion cirq-core/cirq/experiments/xeb_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ def __call__(self, task: _Simulate2qXEBTask) -> List[Dict[str, Any]]:
if cycle_depth not in cycle_depths:
continue

psi = step_result.state_vector()
# copy=False is safe because state_vector_to_probabilities will copy anyways
psi = step_result.state_vector(copy=False)
pure_probs = value.state_vector_to_probabilities(psi)

records += [
Expand Down
6 changes: 5 additions & 1 deletion cirq-core/cirq/ops/boolean_hamiltonian_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,11 @@ def test_circuit(boolean_str):

circuit.append(hamiltonian_gate.on(*qubits))

phi = cirq.Simulator().simulate(circuit, qubit_order=qubits, initial_state=0).state_vector()
phi = (
cirq.Simulator()
.simulate(circuit, qubit_order=qubits, initial_state=0)
.state_vector(copy=False)
)
actual = np.arctan2(phi.real, phi.imag) - math.pi / 2.0 > 0.0

# Compare the two:
Expand Down
32 changes: 23 additions & 9 deletions cirq-core/cirq/ops/common_gates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1088,7 +1088,7 @@ def test_xpow_dim_3():

sim = cirq.Simulator()
circuit = cirq.Circuit([x(cirq.LineQid(0, 3)) ** 0.5] * 6)
svs = [step.state_vector() for step in sim.simulate_moment_steps(circuit)]
svs = [step.state_vector(copy=True) for step in sim.simulate_moment_steps(circuit)]
# fmt: off
expected = [
[0.67, 0.67, 0.33],
Expand Down Expand Up @@ -1116,7 +1116,7 @@ def test_xpow_dim_4():

sim = cirq.Simulator()
circuit = cirq.Circuit([x(cirq.LineQid(0, 4)) ** 0.5] * 8)
svs = [step.state_vector() for step in sim.simulate_moment_steps(circuit)]
svs = [step.state_vector(copy=True) for step in sim.simulate_moment_steps(circuit)]
# fmt: off
expected = [
[0.65, 0.65, 0.27, 0.27],
Expand Down Expand Up @@ -1147,11 +1147,15 @@ def test_zpow_dim_3():

sim = cirq.Simulator()
circuit = cirq.Circuit([z(cirq.LineQid(0, 3)) ** 0.5] * 6)
svs = [step.state_vector() for step in sim.simulate_moment_steps(circuit, initial_state=0)]
svs = [
step.state_vector(copy=True) for step in sim.simulate_moment_steps(circuit, initial_state=0)
]
expected = [[1, 0, 0]] * 6
assert np.allclose((svs), expected)

svs = [step.state_vector() for step in sim.simulate_moment_steps(circuit, initial_state=1)]
svs = [
step.state_vector(copy=True) for step in sim.simulate_moment_steps(circuit, initial_state=1)
]
# fmt: off
expected = [
[0, L**0.5, 0],
Expand All @@ -1164,7 +1168,9 @@ def test_zpow_dim_3():
# fmt: on
assert np.allclose((svs), expected)

svs = [step.state_vector() for step in sim.simulate_moment_steps(circuit, initial_state=2)]
svs = [
step.state_vector(copy=True) for step in sim.simulate_moment_steps(circuit, initial_state=2)
]
# fmt: off
expected = [
[0, 0, L],
Expand Down Expand Up @@ -1192,11 +1198,15 @@ def test_zpow_dim_4():

sim = cirq.Simulator()
circuit = cirq.Circuit([z(cirq.LineQid(0, 4)) ** 0.5] * 8)
svs = [step.state_vector() for step in sim.simulate_moment_steps(circuit, initial_state=0)]
svs = [
step.state_vector(copy=True) for step in sim.simulate_moment_steps(circuit, initial_state=0)
]
expected = [[1, 0, 0, 0]] * 8
assert np.allclose((svs), expected)

svs = [step.state_vector() for step in sim.simulate_moment_steps(circuit, initial_state=1)]
svs = [
step.state_vector(copy=True) for step in sim.simulate_moment_steps(circuit, initial_state=1)
]
# fmt: off
expected = [
[0, 1j**0.5, 0, 0],
Expand All @@ -1211,7 +1221,9 @@ def test_zpow_dim_4():
# fmt: on
assert np.allclose(svs, expected)

svs = [step.state_vector() for step in sim.simulate_moment_steps(circuit, initial_state=2)]
svs = [
step.state_vector(copy=True) for step in sim.simulate_moment_steps(circuit, initial_state=2)
]
# fmt: off
expected = [
[0, 0, 1j, 0],
Expand All @@ -1226,7 +1238,9 @@ def test_zpow_dim_4():
# fmt: on
assert np.allclose(svs, expected)

svs = [step.state_vector() for step in sim.simulate_moment_steps(circuit, initial_state=3)]
svs = [
step.state_vector(copy=True) for step in sim.simulate_moment_steps(circuit, initial_state=3)
]
# fmt: off
expected = [
[0, 0, 0, 1j**1.5],
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/sim/mux.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def final_state_vector(
param_resolver=param_resolver,
)

return result.state_vector()
return result.state_vector(copy=False)


def sample_sweep(
Expand Down
4 changes: 3 additions & 1 deletion cirq-core/cirq/sim/simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,9 @@ def _kraus_(self):
cirq.Circuit(Reset11To00().on(*cirq.LineQubit.range(2))), initial_state=k
)
np.testing.assert_allclose(
out.state_vector(), cirq.one_hot(index=k % 3, shape=4, dtype=np.complex64), atol=1e-8
out.state_vector(copy=False),
cirq.one_hot(index=k % 3, shape=4, dtype=np.complex64),
atol=1e-8,
)


Expand Down
10 changes: 8 additions & 2 deletions cirq-core/cirq/sim/sparse_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import numpy as np

from cirq import ops
from cirq import _compat, ops
from cirq._compat import deprecated_parameter
from cirq.sim import simulator, state_vector, state_vector_simulator, state_vector_simulation_state

Expand Down Expand Up @@ -246,7 +246,7 @@ def __init__(
self._dtype = dtype
self._state_vector: Optional[np.ndarray] = None

def state_vector(self, copy: bool = True):
def state_vector(self, copy: Optional[bool] = None):
"""Return the state vector at this point in the computation.

The state is returned in the computational basis with these basis
Expand Down Expand Up @@ -279,6 +279,12 @@ def state_vector(self, copy: bool = True):
parameters from the state vector and store then using False
can speed up simulation by eliminating a memory copy.
"""
if copy is None:
_compat._warn_or_error(
"Starting in v0.16, state_vector will not copy the state by default. "
"Explicitly set copy=True to copy the state."
)
copy = True
if self._state_vector is None:
self._state_vector = np.array([1])
state = self._merged_sim_state
Expand Down
39 changes: 25 additions & 14 deletions cirq-core/cirq/sim/sparse_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,9 +550,20 @@ def test_simulate_moment_steps(dtype: Type[np.number], split: bool):
simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split)
for i, step in enumerate(simulator.simulate_moment_steps(circuit)):
if i == 0:
np.testing.assert_almost_equal(step.state_vector(), np.array([0.5] * 4))
np.testing.assert_almost_equal(step.state_vector(copy=True), np.array([0.5] * 4))
else:
np.testing.assert_almost_equal(step.state_vector(), np.array([1, 0, 0, 0]))
np.testing.assert_almost_equal(step.state_vector(copy=True), np.array([1, 0, 0, 0]))


def test_simulate_moment_steps_implicit_copy_deprecated():
q0 = cirq.LineQubit(0)
simulator = cirq.Simulator()
steps = list(simulator.simulate_moment_steps(cirq.Circuit(cirq.X(q0))))

with cirq.testing.assert_deprecated(
"state_vector will not copy the state by default", deadline="v0.16"
):
_ = steps[0].state_vector()


@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
Expand All @@ -563,7 +574,7 @@ def test_simulate_moment_steps_empty_circuit(dtype: Type[np.number], split: bool
step = None
for step in simulator.simulate_moment_steps(circuit):
pass
assert np.allclose(step.state_vector(), np.array([1]))
assert np.allclose(step.state_vector(copy=True), np.array([1]))
assert not step.qubit_map


Expand Down Expand Up @@ -599,10 +610,10 @@ def test_simulate_moment_steps_intermediate_measurement(dtype: Type[np.number],
result = int(step.measurements['q(0)'][0])
expected = np.zeros(2)
expected[result] = 1
np.testing.assert_almost_equal(step.state_vector(), expected)
np.testing.assert_almost_equal(step.state_vector(copy=True), expected)
if i == 2:
expected = np.array([np.sqrt(0.5), np.sqrt(0.5) * (-1) ** result])
np.testing.assert_almost_equal(step.state_vector(), expected)
np.testing.assert_almost_equal(step.state_vector(copy=True), expected)


@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
Expand Down Expand Up @@ -710,8 +721,8 @@ def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs):

initial_state = np.array([np.sqrt(0.5), np.sqrt(0.5)], dtype=np.complex64)
result = simulator.simulate(circuit, initial_state=initial_state)
np.testing.assert_array_almost_equal(result.state_vector(), initial_state)
assert not initial_state is result.state_vector()
np.testing.assert_array_almost_equal(result.state_vector(copy=False), initial_state)
assert not initial_state is result.state_vector(copy=False)


def test_does_not_modify_initial_state():
Expand All @@ -735,7 +746,7 @@ def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs):
result = simulator.simulate(circuit, initial_state=initial_state)
np.testing.assert_array_almost_equal(np.array([1, 0], dtype=np.complex64), initial_state)
np.testing.assert_array_almost_equal(
result.state_vector(), np.array([0, 1], dtype=np.complex64)
result.state_vector(copy=False), np.array([0, 1], dtype=np.complex64)
)


Expand Down Expand Up @@ -787,7 +798,7 @@ def test_simulates_composite():
np.testing.assert_allclose(
c.final_state_vector(ignore_terminal_measurements=False, dtype=np.complex64), expected
)
np.testing.assert_allclose(cirq.Simulator().simulate(c).state_vector(), expected)
np.testing.assert_allclose(cirq.Simulator().simulate(c).state_vector(copy=False), expected)


def test_simulate_measurement_inversions():
Expand All @@ -804,15 +815,15 @@ def test_works_on_pauli_string_phasor():
a, b = cirq.LineQubit.range(2)
c = cirq.Circuit(np.exp(0.5j * np.pi * cirq.X(a) * cirq.X(b)))
sim = cirq.Simulator()
result = sim.simulate(c).state_vector()
result = sim.simulate(c).state_vector(copy=False)
np.testing.assert_allclose(result.reshape(4), np.array([0, 0, 0, 1j]), atol=1e-8)


def test_works_on_pauli_string():
a, b = cirq.LineQubit.range(2)
c = cirq.Circuit(cirq.X(a) * cirq.X(b))
sim = cirq.Simulator()
result = sim.simulate(c).state_vector()
result = sim.simulate(c).state_vector(copy=False)
np.testing.assert_allclose(result.reshape(4), np.array([0, 0, 0, 1]), atol=1e-8)


Expand Down Expand Up @@ -1322,9 +1333,9 @@ def test_final_state_vector_is_not_last_object():
initial_state = np.array([1, 0], dtype=np.complex64)
circuit = cirq.Circuit(cirq.wait(q))
result = sim.simulate(circuit, initial_state=initial_state)
assert result.state_vector() is not initial_state
assert not np.shares_memory(result.state_vector(), initial_state)
np.testing.assert_equal(result.state_vector(), initial_state)
assert result.state_vector(copy=False) is not initial_state
assert not np.shares_memory(result.state_vector(copy=False), initial_state)
np.testing.assert_equal(result.state_vector(copy=False), initial_state)


def test_deterministic_gate_noise():
Expand Down
Loading