Skip to content

Remove unnecessary state copy #5469

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 8, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions cirq-core/cirq/sim/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,10 +595,11 @@ def simulate_sweep_iter(
possible parameter resolver.
"""
qubit_order = ops.QubitOrder.as_qubit_order(qubit_order)
for param_resolver in study.to_resolvers(params):
resolvers = list(study.to_resolvers(params))
for i, param_resolver in enumerate(resolvers):
state = (
initial_state.copy()
if isinstance(initial_state, SimulationStateBase)
if isinstance(initial_state, SimulationStateBase) and i < len(resolvers) - 1
else initial_state
)
all_step_results = self.simulate_moment_steps(
Expand Down
42 changes: 34 additions & 8 deletions cirq-core/cirq/sim/simulator_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@


class CountingState(cirq.qis.QuantumStateRepresentation):
def __init__(self, data, gate_count=0, measurement_count=0):
def __init__(self, data, gate_count=0, measurement_count=0, copy_count=0):
self.data = data
self.gate_count = gate_count
self.measurement_count = measurement_count
self.copy_count = copy_count

def measure(
self, axes: Sequence[int], seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None
Expand All @@ -38,21 +39,22 @@ def kron(self: 'CountingState', other: 'CountingState') -> 'CountingState':
self.data,
self.gate_count + other.gate_count,
self.measurement_count + other.measurement_count,
self.copy_count + other.copy_count,
)

def factor(
self: 'CountingState', axes: Sequence[int], *, validate=True, atol=1e-07
) -> Tuple['CountingState', 'CountingState']:
return CountingState(self.data, self.gate_count, self.measurement_count), CountingState(
self.data
)
return CountingState(
self.data, self.gate_count, self.measurement_count, self.copy_count
), CountingState(self.data)

def reindex(self: 'CountingState', axes: Sequence[int]) -> 'CountingState':
return self.copy()
return CountingState(self.data, self.gate_count, self.measurement_count, self.copy_count)

def copy(self, deep_copy_buffers: bool = True) -> 'CountingState':
return CountingState(
data=self.data, gate_count=self.gate_count, measurement_count=self.measurement_count
self.data, self.gate_count, self.measurement_count, self.copy_count + 1
)


Expand All @@ -79,6 +81,10 @@ def gate_count(self):
def measurement_count(self):
return self._state.measurement_count

@property
def copy_count(self):
return self._state.copy_count


class SplittableCountingSimulationState(CountingSimulationState):
@property
Expand Down Expand Up @@ -171,19 +177,22 @@ def test_simulate_empty_circuit():
r = sim.simulate(cirq.Circuit())
assert r._final_simulator_state.gate_count == 0
assert r._final_simulator_state.measurement_count == 0
assert r._final_simulator_state.copy_count == 0


def test_simulate_one_gate_circuit():
sim = CountingSimulator()
r = sim.simulate(cirq.Circuit(cirq.X(q0)))
assert r._final_simulator_state.gate_count == 1
assert r._final_simulator_state.copy_count == 0


def test_simulate_one_measurement_circuit():
sim = CountingSimulator()
r = sim.simulate(cirq.Circuit(cirq.measure(q0)))
assert r._final_simulator_state.gate_count == 0
assert r._final_simulator_state.measurement_count == 1
assert r._final_simulator_state.copy_count == 0


def test_empty_circuit_simulation_has_moment():
Expand All @@ -196,13 +205,26 @@ def test_noise_applied():
sim = CountingSimulator(noise=cirq.X)
r = sim.simulate(cirq.Circuit(cirq.X(q0)))
assert r._final_simulator_state.gate_count == 2
assert r._final_simulator_state.copy_count == 0


def test_noise_applied_measurement_gate():
sim = CountingSimulator(noise=cirq.X)
r = sim.simulate(cirq.Circuit(cirq.measure(q0)))
assert r._final_simulator_state.gate_count == 1
assert r._final_simulator_state.measurement_count == 1
assert r._final_simulator_state.copy_count == 0


def test_parameterized_copies_all_but_last():
sim = CountingSimulator()
n = 4
rs = sim.simulate_sweep(cirq.Circuit(cirq.X(q0) ** 'a'), [{'a': i} for i in range(n)])
for i in range(n):
r = rs[i]
assert r._final_simulator_state.gate_count == 1
assert r._final_simulator_state.measurement_count == 0
assert r._final_simulator_state.copy_count == 0 if i == n - 1 else 1


def test_cannot_act():
Expand Down Expand Up @@ -382,14 +404,18 @@ def _has_unitary_(self):
op1 = TestOp(has_unitary=True)
op2 = TestOp(has_unitary=True)
circuit = cirq.Circuit(op1, cirq.XPowGate(exponent=sympy.Symbol('a'))(q), op2)
simulator.simulate_sweep(program=circuit, params=params)
rs = simulator.simulate_sweep(program=circuit, params=params)
assert rs[0]._final_simulator_state.copy_count == 1
assert rs[1]._final_simulator_state.copy_count == 0
assert op1.count == 1
assert op2.count == 2

op1 = TestOp(has_unitary=False)
op2 = TestOp(has_unitary=False)
circuit = cirq.Circuit(op1, cirq.XPowGate(exponent=sympy.Symbol('a'))(q), op2)
simulator.simulate_sweep(program=circuit, params=params)
rs = simulator.simulate_sweep(program=circuit, params=params)
assert rs[0]._final_simulator_state.copy_count == 1
assert rs[1]._final_simulator_state.copy_count == 0
assert op1.count == 2
assert op2.count == 2

Expand Down