Skip to content

Commit c28646f

Browse files
daxfohlCirqBot
andauthored
Sample independent qubit sets without merging state space (#4110)
* split * Allow config param split_entangled_states * default split to off * ensure consistent_act_on circuits have a qubit. * lint * lint * mps * lint * lint * run sparse by default * fix tests * fix tests * fix tests * most of sparse and dm * clifford * sim_base * sim_base * mps * turn off on experiments with rounding error * fix tests * fix tests * fix testsCreate base step result * clifford * mps * mps * mps * tableau * test simulator * test simulator * Update simulator_base.py * Drop mps/join * Fix clifford extract * lint * simplify index * Add qubits to base class * Fix clifford sampling * Fix _sim_state_values * fix tostring tests, format * remove split/join from ch-form * remove split/join from ch-form * push merged state to base layer * lint * mypy * mypy * mypy * Add default arg for zero qubit circuits * Have last repetition reuse original state repr * Remove cast * Split all pure initial states by default * Detangle on reset channels * docstrings * docstrings * docstrings * docstrings * fix merge * lint * Add unit test for integer states * format * Add tests for splitting and joining * remove unnecessary qubits param * Clean up default args * Fix failing test * Add ActOnArgsContainer * Add ActOnArgsContainer * Clean up tests * Clean up tests * Clean up tests * format * Fix tests and coverage * Add OperationTarget interface * Fix unit tests * mypy, lint, mocks, coverage * coverage * lint, tests * lint, tests * mypy * mypy, tests * remove test code * test * dead code * mocks * add log to container * fix logs * dead code * unit test * unit test * dead code * operationtarget samples * StepResultBase * Mock, format * EmptyActOnArgs * EmptyActOnArgs * simplify dummyargs * lint * Add [] to actonargs * rename _create_act_on_arg * coverage * coverage * Default sparse sim to split=false * format * Default sparse sim to split=false * Default density matrix sim to split=false * lint * lint * lint * lint * address review comments * lint * Defaults back to split=false * add error if setting state when split is enabled * Unit tests * coverage * coverage * coverage * docs * conflicts * conflicts * cover * Add qubits to bb84 * mergedsimstate private * q_set * default to split=True * Allow set_state * Allow set_state * format * fix merge * fix merge * maintain order in sampling for determinicity. * Pydoc fixes * revert bb48 num_qubits change * fix docstrings for set_state error * Remove duplicate sample declaration from ActOnArgs * Remove unnecessary split_untangled_states=True * Reduce atol of dm/sv test * Add test for sim_state propagation from step_result * Add test for sim_state propagation from step_result Co-authored-by: Cirq Bot <[email protected]>
1 parent 5ca4ab6 commit c28646f

27 files changed

+347
-278
lines changed

cirq-core/cirq/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,7 @@
403403
StateVectorStepResult,
404404
StateVectorTrialResult,
405405
StepResult,
406+
StepResultBase,
406407
)
407408

408409
from cirq.study import (

cirq-core/cirq/contrib/quimb/mps_simulator.py

Lines changed: 31 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -117,12 +117,9 @@ def _create_partial_act_on_args(
117117

118118
def _create_step_result(
119119
self,
120-
sim_state: 'MPSState',
121-
qubit_map: Dict['cirq.Qid', int],
120+
sim_state: 'cirq.OperationTarget[MPSState]',
122121
):
123-
return MPSSimulatorStepResult(
124-
measurements=sim_state.log_of_measurement_results, state=sim_state
125-
)
122+
return MPSSimulatorStepResult(sim_state)
126123

127124
def _create_simulator_trial_result(
128125
self,
@@ -169,22 +166,22 @@ def __str__(self) -> str:
169166
return f'measurements: {samples}\noutput state: {final}'
170167

171168

172-
class MPSSimulatorStepResult(simulator.StepResult['MPSState']):
169+
class MPSSimulatorStepResult(simulator_base.StepResultBase['MPSState', 'MPSState']):
173170
"""A `StepResult` that can perform measurements."""
174171

175-
def __init__(self, state, measurements):
172+
def __init__(
173+
self,
174+
sim_state: 'cirq.OperationTarget[MPSState]',
175+
):
176176
"""Results of a step of the simulator.
177177
Attributes:
178-
state: A MPSState
179-
measurements: A dictionary from measurement gate key to measurement
180-
results, ordered by the qubits that the measurement operates on.
181-
qubit_map: A map from the Qubits in the Circuit to the the index
182-
of this qubit for a canonical ordering. This canonical ordering
183-
is used to define the state vector (see the state_vector()
184-
method).
178+
sim_state: The qubit:ActOnArgs lookup for this step.
185179
"""
186-
self.measurements = measurements
187-
self.state = state.copy()
180+
super().__init__(sim_state)
181+
182+
@property
183+
def state(self):
184+
return self._merged_sim_state
188185

189186
def __str__(self) -> str:
190187
def bitstring(vals):
@@ -204,24 +201,6 @@ def bitstring(vals):
204201
def _simulator_state(self):
205202
return self.state
206203

207-
def sample(
208-
self,
209-
qubits: List[ops.Qid],
210-
repetitions: int = 1,
211-
seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
212-
) -> np.ndarray:
213-
214-
measurements: List[int] = []
215-
216-
for _ in range(repetitions):
217-
measurements.append(
218-
self.state.perform_measurement(
219-
qubits, value.parse_random_state(seed), collapse_state_vector=False
220-
)
221-
)
222-
223-
return np.array(measurements, dtype=int)
224-
225204

226205
@value.value_equality
227206
class MPSState(ActOnArgs):
@@ -537,3 +516,21 @@ def perform_measurement(
537516
def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]:
538517
"""Measures the axes specified by the simulator."""
539518
return self.perform_measurement(qubits, self.prng)
519+
520+
def sample(
521+
self,
522+
qubits: Sequence[ops.Qid],
523+
repetitions: int = 1,
524+
seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
525+
) -> np.ndarray:
526+
527+
measurements: List[List[int]] = []
528+
529+
for _ in range(repetitions):
530+
measurements.append(
531+
self.perform_measurement(
532+
qubits, value.parse_random_state(seed), collapse_state_vector=False
533+
)
534+
)
535+
536+
return np.array(measurements, dtype=int)

cirq-core/cirq/contrib/quimb/mps_simulator_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -274,11 +274,11 @@ def test_trial_result_str():
274274

275275
def test_empty_step_result():
276276
q0 = cirq.LineQubit(0)
277-
state = ccq.mps_simulator.MPSState(qubits=(q0,), prng=value.parse_random_state(0))
278-
step_result = ccq.mps_simulator.MPSSimulatorStepResult(state, measurements={'0': [1]})
277+
sim = ccq.mps_simulator.MPSSimulator()
278+
step_result = next(sim.simulate_moment_steps(cirq.Circuit(cirq.measure(q0))))
279279
assert (
280280
str(step_result)
281-
== """0=1
281+
== """0=0
282282
TensorNetwork([
283283
Tensor(shape=(2,), inds=('i_0',), tags=set()),
284284
])"""

cirq-core/cirq/experiments/grid_parallel_two_qubit_xeb_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def test_estimate_parallel_two_qubit_xeb_fidelity_on_grid_no_noise(tmpdir):
2020
two_qubit_gate = cirq.ISWAP ** 0.5
2121
cycles = [5, 10, 15]
2222
data_collection_id = collect_grid_parallel_two_qubit_xeb_data(
23-
sampler=cirq.Simulator(seed=34310),
23+
sampler=cirq.Simulator(seed=34310, split_untangled_states=False),
2424
qubits=qubits,
2525
two_qubit_gate=two_qubit_gate,
2626
num_circuits=2,
@@ -53,7 +53,9 @@ def test_estimate_parallel_two_qubit_xeb_fidelity_on_grid_depolarizing(tmpdir):
5353
cycles = [5, 10, 15]
5454
e = 0.01
5555
data_collection_id = collect_grid_parallel_two_qubit_xeb_data(
56-
sampler=cirq.DensityMatrixSimulator(noise=cirq.depolarize(e), seed=65008),
56+
sampler=cirq.DensityMatrixSimulator(
57+
noise=cirq.depolarize(e), seed=65008, split_untangled_states=False
58+
),
5759
qubits=qubits,
5860
two_qubit_gate=two_qubit_gate,
5961
num_circuits=2,

cirq-core/cirq/experiments/single_qubit_readout_calibration_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def __init__(self, p0: float, p1: float, seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE'
3131
self.p0 = p0
3232
self.p1 = p1
3333
self.prng = cirq.value.parse_random_state(seed)
34-
self.simulator = cirq.Simulator(seed=self.prng)
34+
self.simulator = cirq.Simulator(seed=self.prng, split_untangled_states=False)
3535

3636
def run_sweep(
3737
self,

cirq-core/cirq/protocols/act_on_protocol_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ def copy(self):
3737
def _act_on_fallback_(self, action, qubits, allow_decompose):
3838
return self.fallback_result
3939

40+
def sample(self, qubits, repetitions=1, seed=None):
41+
pass
42+
4043

4144
op = cirq.X(cirq.LineQubit(0))
4245

cirq-core/cirq/sim/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
)
6565

6666
from cirq.sim.simulator_base import (
67+
StepResultBase,
6768
SimulatorBase,
6869
)
6970

cirq-core/cirq/sim/act_on_args.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,7 @@ def __init__(
6060
axes: The indices of axes corresponding to the qubits that the
6161
operation is supposed to act upon.
6262
log_of_measurement_results: A mutable object that measurements are
63-
being recorded into. Edit it easily by calling
64-
`ActOnStateVectorArgs.record_measurement_result`.
63+
being recorded into.
6564
"""
6665
if prng is None:
6766
prng = cast(np.random.RandomState, np.random)
@@ -72,7 +71,7 @@ def __init__(
7271
if log_of_measurement_results is None:
7372
log_of_measurement_results = {}
7473
self._qubits = tuple(qubits)
75-
self.qubit_map = {q: i for i, q in enumerate(self.qubits)}
74+
self.qubit_map = {q: i for i, q in enumerate(qubits)}
7675
self._axes = tuple(axes)
7776
self.prng = prng
7877
self._log_of_measurement_results = log_of_measurement_results
@@ -89,9 +88,9 @@ def measure(self, qubits: Sequence['cirq.Qid'], key: str, invert_mask: Sequence[
8988
"""
9089
bits = self._perform_measurement(qubits)
9190
corrected = [bit ^ (bit < 2 and mask) for bit, mask in zip(bits, invert_mask)]
92-
if key in self.log_of_measurement_results:
91+
if key in self._log_of_measurement_results:
9392
raise ValueError(f"Measurement already logged to key {key!r}")
94-
self.log_of_measurement_results[key] = corrected
93+
self._log_of_measurement_results[key] = corrected
9594

9695
def get_axes(self, qubits: Sequence['cirq.Qid']) -> List[int]:
9796
return [self.qubit_map[q] for q in qubits]

cirq-core/cirq/sim/act_on_args_container.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,14 @@
2020
Sequence,
2121
Optional,
2222
Iterator,
23-
Tuple,
2423
Any,
24+
Tuple,
25+
Set,
26+
List,
2527
)
2628

29+
import numpy as np
30+
2731
from cirq import ops
2832
from cirq.sim.operation_target import OperationTarget
2933
from cirq.sim.simulator import (
@@ -122,6 +126,26 @@ def qubits(self) -> Tuple['cirq.Qid', ...]:
122126
def log_of_measurement_results(self) -> Dict[str, Any]:
123127
return self._log_of_measurement_results
124128

129+
def sample(
130+
self,
131+
qubits: List[ops.Qid],
132+
repetitions: int = 1,
133+
seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
134+
) -> np.ndarray:
135+
columns = []
136+
selected_order: List[ops.Qid] = []
137+
q_set = set(qubits)
138+
for v in dict.fromkeys(self.args.values()):
139+
qs = [q for q in v.qubits if q in q_set]
140+
if any(qs):
141+
column = v.sample(qs, repetitions, seed)
142+
columns.append(column)
143+
selected_order += qs
144+
stacked = np.column_stack(columns)
145+
qubit_map = {q: i for i, q in enumerate(selected_order)}
146+
index_order = [qubit_map[q] for q in qubits]
147+
return stacked[:, index_order]
148+
125149
def __getitem__(self, item: Optional['cirq.Qid']) -> TActOnArgs:
126150
return self.args[item]
127151

cirq-core/cirq/sim/act_on_args_container_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ def transpose_to_qubit_order(self, qubits: Sequence['cirq.Qid']) -> 'EmptyActOnA
6464
logs=self.log_of_measurement_results,
6565
)
6666

67+
def sample(self, qubits, repetitions=1, seed=None):
68+
pass
69+
6770

6871
q0, q1 = qs2 = cirq.LineQubit.range(2)
6972

cirq-core/cirq/sim/act_on_args_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ def __init__(self):
2525
def copy(self):
2626
pass
2727

28+
def sample(self, qubits, repetitions=1, seed=None):
29+
pass
30+
2831
def _perform_measurement(self, qubits):
2932
return [5, 3]
3033

cirq-core/cirq/sim/act_on_density_matrix_args.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,7 @@ def __init__(
8282
prng: The pseudo random number generator to use for probabilistic
8383
effects.
8484
log_of_measurement_results: A mutable object that measurements are
85-
being recorded into. Edit it easily by calling
86-
`ActOnStateVectorArgs.record_measurement_result`.
85+
being recorded into.
8786
axes: The indices of axes corresponding to the qubits that the
8887
operation is supposed to act upon.
8988
"""
@@ -197,6 +196,21 @@ def transpose_to_qubit_order(
197196
log_of_measurement_results=self.log_of_measurement_results,
198197
)
199198

199+
def sample(
200+
self,
201+
qubits: Sequence['cirq.Qid'],
202+
repetitions: int = 1,
203+
seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
204+
) -> np.ndarray:
205+
indices = [self.qubit_map[q] for q in qubits]
206+
return sim.sample_density_matrix(
207+
self.target_tensor,
208+
indices,
209+
qid_shape=tuple(q.dimension for q in self.qubits),
210+
repetitions=repetitions,
211+
seed=seed,
212+
)
213+
200214

201215
def _strat_apply_channel_to_state(
202216
action: Any, args: ActOnDensityMatrixArgs, qubits: Sequence['cirq.Qid']

cirq-core/cirq/sim/act_on_state_vector_args.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,12 @@ def _rewrite_deprecated_args(args, kwargs):
4141
class ActOnStateVectorArgs(ActOnArgs):
4242
"""State and context for an operation acting on a state vector.
4343
44-
There are three common ways to act on this object:
44+
There are two common ways to act on this object:
4545
4646
1. Directly edit the `target_tensor` property, which is storing the state
4747
vector of the quantum system as a numpy array with one axis per qudit.
4848
2. Overwrite the `available_buffer` property with the new state vector, and
4949
then pass `available_buffer` into `swap_target_tensor_for`.
50-
3. Call `record_measurement_result(key, val)` to log a measurement result.
5150
"""
5251

5352
@deprecated_parameter(
@@ -84,8 +83,7 @@ def __init__(
8483
prng: The pseudo random number generator to use for probabilistic
8584
effects.
8685
log_of_measurement_results: A mutable object that measurements are
87-
being recorded into. Edit it easily by calling
88-
`ActOnStateVectorArgs.record_measurement_result`.
86+
being recorded into.
8987
axes: The indices of axes corresponding to the qubits that the
9088
operation is supposed to act upon.
9189
"""
@@ -255,6 +253,21 @@ def transpose_to_qubit_order(self, qubits: Sequence['cirq.Qid']) -> 'cirq.ActOnS
255253
)
256254
return new_args
257255

256+
def sample(
257+
self,
258+
qubits: Sequence['cirq.Qid'],
259+
repetitions: int = 1,
260+
seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
261+
) -> np.ndarray:
262+
indices = [self.qubit_map[q] for q in qubits]
263+
return sim.sample_state_vector(
264+
self.target_tensor,
265+
indices,
266+
qid_shape=tuple(q.dimension for q in self.qubits),
267+
repetitions=repetitions,
268+
seed=seed,
269+
)
270+
258271

259272
def _strat_act_on_state_vector_from_apply_unitary(
260273
unitary_value: Any,

cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,9 @@ def _rewrite_deprecated_args(args, kwargs):
4444

4545
class ActOnCliffordTableauArgs(ActOnArgs):
4646
"""State and context for an operation acting on a clifford tableau.
47-
There are two common ways to act on this object:
48-
1. Directly edit the `tableau` property, which is storing the clifford
49-
tableau of the quantum system with one axis per qubit.
50-
2. Call `record_measurement_result(key, val)` to log a measurement result.
47+
48+
To act on this object, directly edit the `tableau` property, which is
49+
storing the density matrix of the quantum system with one axis per qubit.
5150
"""
5251

5352
@deprecated_parameter(
@@ -77,8 +76,7 @@ def __init__(
7776
prng: The pseudo random number generator to use for probabilistic
7877
effects.
7978
log_of_measurement_results: A mutable object that measurements are
80-
being recorded into. Edit it easily by calling
81-
`ActOnCliffordTableauArgs.record_measurement_result`.
79+
being recorded into.
8280
axes: The indices of axes corresponding to the qubits that the
8381
operation is supposed to act upon.
8482
"""
@@ -111,6 +109,15 @@ def copy(self) -> 'cirq.ActOnCliffordTableauArgs':
111109
log_of_measurement_results=self.log_of_measurement_results.copy(),
112110
)
113111

112+
def sample(
113+
self,
114+
qubits: Sequence['cirq.Qid'],
115+
repetitions: int = 1,
116+
seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
117+
) -> np.ndarray:
118+
# Unnecessary for now but can be added later if there is a use case.
119+
raise NotImplementedError()
120+
114121

115122
def _strat_act_on_clifford_tableau_from_single_qubit_decompose(
116123
val: Any, args: 'cirq.ActOnCliffordTableauArgs', qubits: Sequence['cirq.Qid']

0 commit comments

Comments
 (0)