Skip to content

Commit ab0dc94

Browse files
Lazily create TrialResult.final_simulator_state (quantumlib#4317)
* Lazily create TrialResult.final_simulator_state * Change TrialResult only to accept StepResult in constructor * Backwards compatibility for TrialResult initializer * lint * Platform independent test * Update density_matrix_simulator_test.py Co-authored-by: Orion Martin <[email protected]>
1 parent 6d79f4d commit ab0dc94

11 files changed

+249
-106
lines changed

cirq/contrib/quimb/mps_simulator.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -125,23 +125,21 @@ def _create_simulator_trial_result(
125125
self,
126126
params: study.ParamResolver,
127127
measurements: Dict[str, np.ndarray],
128-
final_simulator_state: 'MPSState',
128+
final_step_result: 'MPSSimulatorStepResult',
129129
) -> 'MPSTrialResult':
130130
"""Creates a single trial results with the measurements.
131131
132132
Args:
133-
circuit: The circuit to simulate.
134-
param_resolver: A ParamResolver for determining values of
135-
Symbols.
133+
params: A ParamResolver for determining values of Symbols.
136134
measurements: A dictionary from measurement key (e.g. qubit) to the
137135
actual measurement array.
138-
final_simulator_state: The final state of the simulator.
136+
final_step_result: The final step result of the simulation.
139137
140138
Returns:
141139
A single result.
142140
"""
143141
return MPSTrialResult(
144-
params=params, measurements=measurements, final_simulator_state=final_simulator_state
142+
params=params, measurements=measurements, final_step_result=final_step_result
145143
)
146144

147145

@@ -152,13 +150,15 @@ def __init__(
152150
self,
153151
params: study.ParamResolver,
154152
measurements: Dict[str, np.ndarray],
155-
final_simulator_state: 'MPSState',
153+
final_step_result: 'MPSSimulatorStepResult',
156154
) -> None:
157155
super().__init__(
158-
params=params, measurements=measurements, final_simulator_state=final_simulator_state
156+
params=params, measurements=measurements, final_step_result=final_step_result
159157
)
160158

161-
self.final_state = final_simulator_state
159+
@property
160+
def final_state(self):
161+
return self._final_simulator_state
162162

163163
def __str__(self) -> str:
164164
samples = super().__str__()

cirq/contrib/quimb/mps_simulator_test.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
import itertools
22
import math
3+
from unittest import mock
34

45
import numpy as np
56
import pytest
67
import sympy
78

89
import cirq
9-
from cirq import value
1010
import cirq.contrib.quimb as ccq
1111
import cirq.experiments.google_v2_supremacy_circuit as supremacy_v2
12+
from cirq import value
1213

1314

1415
def assert_same_output_as_dense(circuit, qubit_order, initial_state=0, grouping=None):
@@ -252,7 +253,8 @@ def test_measurement_str():
252253

253254
def test_trial_result_str():
254255
q0 = cirq.LineQubit(0)
255-
final_simulator_state = ccq.mps_simulator.MPSState(
256+
final_step_result = mock.Mock(cirq.StepResult)
257+
final_step_result._simulator_state.return_value = ccq.mps_simulator.MPSState(
256258
qubits=(q0,),
257259
prng=value.parse_random_state(0),
258260
simulation_options=ccq.mps_simulator.MPSOptions(),
@@ -262,7 +264,7 @@ def test_trial_result_str():
262264
ccq.mps_simulator.MPSTrialResult(
263265
params=cirq.ParamResolver({}),
264266
measurements={'m': np.array([[1]])},
265-
final_simulator_state=final_simulator_state,
267+
final_step_result=final_step_result,
266268
)
267269
)
268270
== """measurements: m=1

cirq/sim/clifford/clifford_simulator.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -105,11 +105,11 @@ def _create_simulator_trial_result(
105105
self,
106106
params: study.ParamResolver,
107107
measurements: Dict[str, np.ndarray],
108-
final_simulator_state,
108+
final_step_result: 'CliffordSimulatorStepResult',
109109
):
110110

111111
return CliffordTrialResult(
112-
params=params, measurements=measurements, final_simulator_state=final_simulator_state
112+
params=params, measurements=measurements, final_step_result=final_step_result
113113
)
114114

115115

@@ -118,13 +118,15 @@ def __init__(
118118
self,
119119
params: study.ParamResolver,
120120
measurements: Dict[str, np.ndarray],
121-
final_simulator_state: 'CliffordState',
121+
final_step_result: 'CliffordSimulatorStepResult',
122122
) -> None:
123123
super().__init__(
124-
params=params, measurements=measurements, final_simulator_state=final_simulator_state
124+
params=params, measurements=measurements, final_step_result=final_step_result
125125
)
126126

127-
self.final_state = final_simulator_state
127+
@property
128+
def final_state(self):
129+
return self._final_simulator_state
128130

129131
def __str__(self) -> str:
130132
samples = super().__str__()

cirq/sim/clifford/clifford_simulator_test.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import itertools
2+
from unittest import mock
3+
24
import numpy as np
35
import pytest
46
import sympy
@@ -208,13 +210,14 @@ def test_clifford_state_initial_state():
208210

209211
def test_clifford_trial_result_repr():
210212
q0 = cirq.LineQubit(0)
211-
final_simulator_state = cirq.CliffordState(qubit_map={q0: 0})
213+
final_step_result = mock.Mock(cirq.CliffordSimulatorStepResult)
214+
final_step_result._simulator_state.return_value = cirq.CliffordState(qubit_map={q0: 0})
212215
assert (
213216
repr(
214217
cirq.CliffordTrialResult(
215218
params=cirq.ParamResolver({}),
216219
measurements={'m': np.array([[1]])},
217-
final_simulator_state=final_simulator_state,
220+
final_step_result=final_step_result,
218221
)
219222
)
220223
== "cirq.SimulationTrialResult(params=cirq.ParamResolver({}), "
@@ -225,13 +228,14 @@ def test_clifford_trial_result_repr():
225228

226229
def test_clifford_trial_result_str():
227230
q0 = cirq.LineQubit(0)
228-
final_simulator_state = cirq.CliffordState(qubit_map={q0: 0})
231+
final_step_result = mock.Mock(cirq.CliffordSimulatorStepResult)
232+
final_step_result._simulator_state.return_value = cirq.CliffordState(qubit_map={q0: 0})
229233
assert (
230234
str(
231235
cirq.CliffordTrialResult(
232236
params=cirq.ParamResolver({}),
233237
measurements={'m': np.array([[1]])},
234-
final_simulator_state=final_simulator_state,
238+
final_step_result=final_step_result,
235239
)
236240
)
237241
== "measurements: m=1\n"

cirq/sim/density_matrix_simulator.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -222,10 +222,10 @@ def _create_simulator_trial_result(
222222
self,
223223
params: study.ParamResolver,
224224
measurements: Dict[str, np.ndarray],
225-
final_simulator_state: 'DensityMatrixSimulatorState',
225+
final_step_result: 'DensityMatrixStepResult',
226226
) -> 'DensityMatrixTrialResult':
227227
return DensityMatrixTrialResult(
228-
params=params, measurements=measurements, final_simulator_state=final_simulator_state
228+
params=params, measurements=measurements, final_step_result=final_step_result
229229
)
230230

231231
# TODO(#4209): Deduplicate with identical code in sparse_simulator.
@@ -423,22 +423,27 @@ class DensityMatrixTrialResult(simulator.SimulationTrialResult):
423423
measurement gate.)
424424
final_simulator_state: The final simulator state of the system after the
425425
trial finishes.
426-
final_density_matrix: The final density matrix of the system.
427426
"""
428427

429428
def __init__(
430429
self,
431430
params: study.ParamResolver,
432431
measurements: Dict[str, np.ndarray],
433-
final_simulator_state: DensityMatrixSimulatorState,
432+
final_step_result: DensityMatrixStepResult,
434433
) -> None:
435434
super().__init__(
436-
params=params, measurements=measurements, final_simulator_state=final_simulator_state
437-
)
438-
size = np.prod(protocols.qid_shape(self), dtype=np.int64)
439-
self.final_density_matrix = np.reshape(
440-
final_simulator_state.density_matrix.copy(), (size, size)
435+
params=params, measurements=measurements, final_step_result=final_step_result
441436
)
437+
self._final_density_matrix: Optional[np.ndarray] = None
438+
439+
@property
440+
def final_density_matrix(self):
441+
if self._final_density_matrix is None:
442+
size = np.prod(protocols.qid_shape(self), dtype=np.int64)
443+
self._final_density_matrix = np.reshape(
444+
self._final_simulator_state.density_matrix.copy(), (size, size)
445+
)
446+
return self._final_density_matrix
442447

443448
def _value_equality_values_(self) -> Any:
444449
measurements = {k: v.tolist() for k, v in sorted(self.measurements.items())}

cirq/sim/density_matrix_simulator_test.py

Lines changed: 47 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from typing import Type
15-
from unittest import mock
1614
import itertools
1715
import random
16+
from typing import Type
17+
from unittest import mock
18+
1819
import numpy as np
1920
import pytest
2021
import sympy
@@ -998,61 +999,65 @@ def test_density_matrix_simulator_state_repr():
998999

9991000
def test_density_matrix_trial_result_eq():
10001001
q0 = cirq.LineQubit(0)
1001-
final_simulator_state = cirq.DensityMatrixSimulatorState(
1002+
final_step_result = mock.Mock(cirq.StepResult)
1003+
final_step_result._simulator_state.return_value = cirq.DensityMatrixSimulatorState(
10021004
density_matrix=np.ones((2, 2)) * 0.5, qubit_map={q0: 0}
10031005
)
10041006
eq = cirq.testing.EqualsTester()
10051007
eq.add_equality_group(
10061008
cirq.DensityMatrixTrialResult(
10071009
params=cirq.ParamResolver({}),
10081010
measurements={},
1009-
final_simulator_state=final_simulator_state,
1011+
final_step_result=final_step_result,
10101012
),
10111013
cirq.DensityMatrixTrialResult(
10121014
params=cirq.ParamResolver({}),
10131015
measurements={},
1014-
final_simulator_state=final_simulator_state,
1016+
final_step_result=final_step_result,
10151017
),
10161018
)
10171019
eq.add_equality_group(
10181020
cirq.DensityMatrixTrialResult(
10191021
params=cirq.ParamResolver({'s': 1}),
10201022
measurements={},
1021-
final_simulator_state=final_simulator_state,
1023+
final_step_result=final_step_result,
10221024
)
10231025
)
10241026
eq.add_equality_group(
10251027
cirq.DensityMatrixTrialResult(
10261028
params=cirq.ParamResolver({'s': 1}),
10271029
measurements={'m': np.array([[1]])},
1028-
final_simulator_state=final_simulator_state,
1030+
final_step_result=final_step_result,
10291031
)
10301032
)
10311033

10321034

10331035
def test_density_matrix_trial_result_qid_shape():
10341036
q0, q1 = cirq.LineQubit.range(2)
1037+
final_step_result = mock.Mock(cirq.StepResult)
1038+
final_step_result._simulator_state.return_value = cirq.DensityMatrixSimulatorState(
1039+
density_matrix=np.ones((4, 4)) / 4, qubit_map={q0: 0, q1: 1}
1040+
)
10351041
assert (
10361042
cirq.qid_shape(
10371043
cirq.DensityMatrixTrialResult(
10381044
params=cirq.ParamResolver({}),
10391045
measurements={},
1040-
final_simulator_state=cirq.DensityMatrixSimulatorState(
1041-
density_matrix=np.ones((4, 4)) / 4, qubit_map={q0: 0, q1: 1}
1042-
),
1046+
final_step_result=final_step_result,
10431047
),
10441048
)
10451049
== (2, 2)
10461050
)
10471051
q0, q1 = cirq.LineQid.for_qid_shape((3, 4))
1052+
final_step_result._simulator_state.return_value = cirq.DensityMatrixSimulatorState(
1053+
density_matrix=np.ones((12, 12)) / 12, qubit_map={q0: 0, q1: 1}
1054+
)
10481055
assert (
10491056
cirq.qid_shape(
10501057
cirq.DensityMatrixTrialResult(
10511058
params=cirq.ParamResolver({}),
10521059
measurements={},
1053-
final_simulator_state=cirq.DensityMatrixSimulatorState(
1054-
density_matrix=np.ones((12, 12)) / 12, qubit_map={q0: 0, q1: 1}
1055-
),
1060+
final_step_result=final_step_result,
10561061
),
10571062
)
10581063
== (3, 4)
@@ -1061,15 +1066,16 @@ def test_density_matrix_trial_result_qid_shape():
10611066

10621067
def test_density_matrix_trial_result_repr():
10631068
q0 = cirq.LineQubit(0)
1064-
final_simulator_state = cirq.DensityMatrixSimulatorState(
1069+
final_step_result = mock.Mock(cirq.StepResult)
1070+
final_step_result._simulator_state.return_value = cirq.DensityMatrixSimulatorState(
10651071
density_matrix=np.ones((2, 2)) * 0.5, qubit_map={q0: 0}
10661072
)
10671073
assert (
10681074
repr(
10691075
cirq.DensityMatrixTrialResult(
10701076
params=cirq.ParamResolver({'s': 1}),
10711077
measurements={'m': np.array([[1]])},
1072-
final_simulator_state=final_simulator_state,
1078+
final_step_result=final_step_result,
10731079
)
10741080
)
10751081
== "cirq.DensityMatrixTrialResult("
@@ -1166,11 +1172,12 @@ def test_works_on_pauli_string():
11661172

11671173
def test_density_matrix_trial_result_str():
11681174
q0 = cirq.LineQubit(0)
1169-
final_simulator_state = cirq.DensityMatrixSimulatorState(
1175+
final_step_result = mock.Mock(cirq.StepResult)
1176+
final_step_result._simulator_state.return_value = cirq.DensityMatrixSimulatorState(
11701177
density_matrix=np.ones((2, 2)) * 0.5, qubit_map={q0: 0}
11711178
)
11721179
result = cirq.DensityMatrixTrialResult(
1173-
params=cirq.ParamResolver({}), measurements={}, final_simulator_state=final_simulator_state
1180+
params=cirq.ParamResolver({}), measurements={}, final_step_result=final_step_result
11741181
)
11751182

11761183
# numpy varies whitespace in its representation for different versions
@@ -1527,3 +1534,26 @@ def test_density_matrices_same_with_or_without_split_untangled_states():
15271534
sim = cirq.DensityMatrixSimulator(split_untangled_states=True)
15281535
result2 = sim.simulate(circuit).final_density_matrix
15291536
assert np.allclose(result1, result2)
1537+
1538+
1539+
def test_large_untangled_okay():
1540+
circuit = cirq.Circuit()
1541+
for i in range(59):
1542+
for _ in range(9):
1543+
circuit.append(cirq.X(cirq.LineQubit(i)))
1544+
circuit.append(cirq.measure(cirq.LineQubit(i)))
1545+
1546+
# Validate this can't be allocated with entangled state
1547+
with pytest.raises(MemoryError, match='Unable to allocate'):
1548+
_ = cirq.DensityMatrixSimulator(split_untangled_states=False).simulate(circuit)
1549+
1550+
# Validate a simulation run
1551+
result = cirq.DensityMatrixSimulator(split_untangled_states=True).simulate(circuit)
1552+
assert set(result._final_step_result._qubits) == set(cirq.LineQubit.range(59))
1553+
# _ = result.final_density_matrix hangs (as expected)
1554+
1555+
# Validate a trial run and sampling
1556+
result = cirq.DensityMatrixSimulator(split_untangled_states=True).run(circuit, repetitions=1000)
1557+
assert len(result.measurements) == 59
1558+
assert len(result.measurements['0']) == 1000
1559+
assert (result.measurements['0'] == np.full(1000, 1)).all()

0 commit comments

Comments
 (0)