Skip to content

Commit 7ab1b73

Browse files
authored
Remove unnecessary state copy (quantumlib#5469)
Removes the copy when simulating params. quantumlib#3494 (comment) for discussion. @95-martin-orion
1 parent eb9f1a7 commit 7ab1b73

File tree

2 files changed

+37
-10
lines changed

2 files changed

+37
-10
lines changed

cirq/sim/simulator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -595,10 +595,11 @@ def simulate_sweep_iter(
595595
possible parameter resolver.
596596
"""
597597
qubit_order = ops.QubitOrder.as_qubit_order(qubit_order)
598-
for param_resolver in study.to_resolvers(params):
598+
resolvers = list(study.to_resolvers(params))
599+
for i, param_resolver in enumerate(resolvers):
599600
state = (
600601
initial_state.copy()
601-
if isinstance(initial_state, SimulationStateBase)
602+
if isinstance(initial_state, SimulationStateBase) and i < len(resolvers) - 1
602603
else initial_state
603604
)
604605
all_step_results = self.simulate_moment_steps(

cirq/sim/simulator_base_test.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,11 @@
2222

2323

2424
class CountingState(cirq.qis.QuantumStateRepresentation):
25-
def __init__(self, data, gate_count=0, measurement_count=0):
25+
def __init__(self, data, gate_count=0, measurement_count=0, copy_count=0):
2626
self.data = data
2727
self.gate_count = gate_count
2828
self.measurement_count = measurement_count
29+
self.copy_count = copy_count
2930

3031
def measure(
3132
self, axes: Sequence[int], seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None
@@ -38,21 +39,22 @@ def kron(self: 'CountingState', other: 'CountingState') -> 'CountingState':
3839
self.data,
3940
self.gate_count + other.gate_count,
4041
self.measurement_count + other.measurement_count,
42+
self.copy_count + other.copy_count,
4143
)
4244

4345
def factor(
4446
self: 'CountingState', axes: Sequence[int], *, validate=True, atol=1e-07
4547
) -> Tuple['CountingState', 'CountingState']:
46-
return CountingState(self.data, self.gate_count, self.measurement_count), CountingState(
47-
self.data
48-
)
48+
return CountingState(
49+
self.data, self.gate_count, self.measurement_count, self.copy_count
50+
), CountingState(self.data)
4951

5052
def reindex(self: 'CountingState', axes: Sequence[int]) -> 'CountingState':
51-
return self.copy()
53+
return CountingState(self.data, self.gate_count, self.measurement_count, self.copy_count)
5254

5355
def copy(self, deep_copy_buffers: bool = True) -> 'CountingState':
5456
return CountingState(
55-
data=self.data, gate_count=self.gate_count, measurement_count=self.measurement_count
57+
self.data, self.gate_count, self.measurement_count, self.copy_count + 1
5658
)
5759

5860

@@ -79,6 +81,10 @@ def gate_count(self):
7981
def measurement_count(self):
8082
return self._state.measurement_count
8183

84+
@property
85+
def copy_count(self):
86+
return self._state.copy_count
87+
8288

8389
class SplittableCountingSimulationState(CountingSimulationState):
8490
@property
@@ -171,19 +177,22 @@ def test_simulate_empty_circuit():
171177
r = sim.simulate(cirq.Circuit())
172178
assert r._final_simulator_state.gate_count == 0
173179
assert r._final_simulator_state.measurement_count == 0
180+
assert r._final_simulator_state.copy_count == 0
174181

175182

176183
def test_simulate_one_gate_circuit():
177184
sim = CountingSimulator()
178185
r = sim.simulate(cirq.Circuit(cirq.X(q0)))
179186
assert r._final_simulator_state.gate_count == 1
187+
assert r._final_simulator_state.copy_count == 0
180188

181189

182190
def test_simulate_one_measurement_circuit():
183191
sim = CountingSimulator()
184192
r = sim.simulate(cirq.Circuit(cirq.measure(q0)))
185193
assert r._final_simulator_state.gate_count == 0
186194
assert r._final_simulator_state.measurement_count == 1
195+
assert r._final_simulator_state.copy_count == 0
187196

188197

189198
def test_empty_circuit_simulation_has_moment():
@@ -196,13 +205,26 @@ def test_noise_applied():
196205
sim = CountingSimulator(noise=cirq.X)
197206
r = sim.simulate(cirq.Circuit(cirq.X(q0)))
198207
assert r._final_simulator_state.gate_count == 2
208+
assert r._final_simulator_state.copy_count == 0
199209

200210

201211
def test_noise_applied_measurement_gate():
202212
sim = CountingSimulator(noise=cirq.X)
203213
r = sim.simulate(cirq.Circuit(cirq.measure(q0)))
204214
assert r._final_simulator_state.gate_count == 1
205215
assert r._final_simulator_state.measurement_count == 1
216+
assert r._final_simulator_state.copy_count == 0
217+
218+
219+
def test_parameterized_copies_all_but_last():
220+
sim = CountingSimulator()
221+
n = 4
222+
rs = sim.simulate_sweep(cirq.Circuit(cirq.X(q0) ** 'a'), [{'a': i} for i in range(n)])
223+
for i in range(n):
224+
r = rs[i]
225+
assert r._final_simulator_state.gate_count == 1
226+
assert r._final_simulator_state.measurement_count == 0
227+
assert r._final_simulator_state.copy_count == 0 if i == n - 1 else 1
206228

207229

208230
def test_cannot_act():
@@ -382,14 +404,18 @@ def _has_unitary_(self):
382404
op1 = TestOp(has_unitary=True)
383405
op2 = TestOp(has_unitary=True)
384406
circuit = cirq.Circuit(op1, cirq.XPowGate(exponent=sympy.Symbol('a'))(q), op2)
385-
simulator.simulate_sweep(program=circuit, params=params)
407+
rs = simulator.simulate_sweep(program=circuit, params=params)
408+
assert rs[0]._final_simulator_state.copy_count == 1
409+
assert rs[1]._final_simulator_state.copy_count == 0
386410
assert op1.count == 1
387411
assert op2.count == 2
388412

389413
op1 = TestOp(has_unitary=False)
390414
op2 = TestOp(has_unitary=False)
391415
circuit = cirq.Circuit(op1, cirq.XPowGate(exponent=sympy.Symbol('a'))(q), op2)
392-
simulator.simulate_sweep(program=circuit, params=params)
416+
rs = simulator.simulate_sweep(program=circuit, params=params)
417+
assert rs[0]._final_simulator_state.copy_count == 1
418+
assert rs[1]._final_simulator_state.copy_count == 0
393419
assert op1.count == 2
394420
assert op2.count == 2
395421

0 commit comments

Comments
 (0)