Skip to content

Commit 0ef302f

Browse files
authored
Changed simulators fallback to decompose_once and removed ancilla support from DensityMatrixSimulator (#6127)
* Fix bugs in strat_act_on_from_apply_decompose and improve support for qubit allocation within decompose * Revert unrelated mypy change * Fix mypy types and remove context argument from strat_act_on_from_apply_decompose * Fix mypy error * Update docstrings
1 parent ebc52d5 commit 0ef302f

4 files changed

+85
-156
lines changed

cirq-core/cirq/sim/density_matrix_simulation_state.py

-16
Original file line numberDiff line numberDiff line change
@@ -285,22 +285,6 @@ def __init__(
285285
)
286286
super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data)
287287

288-
def add_qubits(self, qubits: Sequence['cirq.Qid']):
289-
ret = super().add_qubits(qubits)
290-
return (
291-
self.kronecker_product(type(self)(qubits=qubits), inplace=True)
292-
if ret is NotImplemented
293-
else ret
294-
)
295-
296-
def remove_qubits(self, qubits: Sequence['cirq.Qid']):
297-
ret = super().remove_qubits(qubits)
298-
if ret is not NotImplemented:
299-
return ret
300-
extracted, remainder = self.factor(qubits)
301-
remainder._state._density_matrix *= extracted._state._density_matrix.reshape(-1)[0]
302-
return remainder
303-
304288
def _act_on_fallback_(
305289
self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True
306290
) -> bool:

cirq-core/cirq/sim/density_matrix_simulation_state_test.py

-12
Original file line numberDiff line numberDiff line change
@@ -123,15 +123,3 @@ def test_initial_state_bad_shape():
123123
cirq.DensityMatrixSimulationState(
124124
qubits=qubits, initial_state=np.full((2, 2, 2, 2), 1 / 4), dtype=np.complex64
125125
)
126-
127-
128-
def test_remove_qubits():
129-
"""Test the remove_qubits method."""
130-
q1 = cirq.LineQubit(0)
131-
q2 = cirq.LineQubit(1)
132-
state = cirq.DensityMatrixSimulationState(qubits=[q1, q2])
133-
134-
new_state = state.remove_qubits([q1])
135-
136-
assert len(new_state.qubits) == 1
137-
assert q1 not in new_state.qubits

cirq-core/cirq/sim/simulation_state.py

+33-34
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
List,
2424
Optional,
2525
Sequence,
26+
Set,
2627
TypeVar,
2728
TYPE_CHECKING,
2829
Tuple,
@@ -31,8 +32,8 @@
3132

3233
import numpy as np
3334

34-
from cirq import protocols, value
35-
from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits
35+
from cirq import ops, protocols, value
36+
3637
from cirq.sim.simulation_state_base import SimulationStateBase
3738

3839
TState = TypeVar('TState', bound='cirq.QuantumStateRepresentation')
@@ -166,35 +167,35 @@ def create_merged_state(self) -> Self:
166167
"""Creates a final merged state."""
167168
return self
168169

169-
def add_qubits(self: Self, qubits: Sequence['cirq.Qid']):
170-
"""Add qubits to a new state space and take the kron product.
171-
172-
Note that only Density Matrix and State Vector simulators
173-
override this function.
170+
def add_qubits(self: Self, qubits: Sequence['cirq.Qid']) -> Self:
171+
"""Add `qubits` in the `|0>` state to a new state space and take the kron product.
174172
175173
Args:
176174
qubits: Sequence of qubits to be added.
177175
178176
Returns:
179177
NotImplemented: If the subclass does not implement this method.
180-
181-
Raises:
182-
ValueError: If a qubit being added is already tracked.
178+
Self: A `cirq.SimulationState` with qubits added or `self` if there are no qubits to
179+
add.
183180
"""
184-
if any(q in self.qubits for q in qubits):
185-
raise ValueError(f"Qubit to add {qubits} should not already be tracked.")
181+
if not qubits:
182+
return self
186183
return NotImplemented
187184

188185
def remove_qubits(self: Self, qubits: Sequence['cirq.Qid']) -> Self:
189-
"""Remove qubits from the state space.
186+
"""Remove `qubits` from the state space.
187+
188+
The qubits to be removed should be untangled from rest of the system and in the |0> state.
190189
191190
Args:
192-
qubits: Sequence of qubits to be added.
191+
qubits: Sequence of qubits to be removed.
193192
194193
Returns:
195-
A new Simulation State with qubits removed. Or
196-
`self` if there are no qubits to remove."""
197-
if qubits is None or not qubits:
194+
NotImplemented: If the subclass does not implement this method.
195+
Self: A `cirq.SimulationState` with qubits removed or `self` if there are no qubits to
196+
remove.
197+
"""
198+
if not qubits:
198199
return self
199200
return NotImplemented
200201

@@ -325,25 +326,23 @@ def can_represent_mixed_states(self) -> bool:
325326
def strat_act_on_from_apply_decompose(
326327
val: Any, args: 'cirq.SimulationState', qubits: Sequence['cirq.Qid']
327328
) -> bool:
328-
operations, qubits1, _ = _try_decompose_into_operations_and_qubits(val)
329-
if operations is None:
329+
if isinstance(val, ops.Gate):
330+
decomposed = protocols.decompose_once_with_qubits(val, qubits, flatten=False, default=None)
331+
else:
332+
decomposed = protocols.decompose_once(val, flatten=False, default=None)
333+
if decomposed is None:
330334
return NotImplemented
331-
assert len(qubits1) == len(qubits)
332-
all_qubits = frozenset([q for op in operations for q in op.qubits])
333-
qubit_map = dict(zip(all_qubits, all_qubits))
334-
qubit_map.update(dict(zip(qubits1, qubits)))
335-
new_ancilla = tuple(q for q in sorted(all_qubits.difference(qubits)) if q not in args.qubits)
336-
args = args.add_qubits(new_ancilla)
337-
if args is NotImplemented:
338-
return NotImplemented
339-
for operation in operations:
340-
operation = operation.with_qubits(*[qubit_map[q] for q in operation.qubits])
335+
all_ancilla: Set['cirq.Qid'] = set()
336+
for operation in ops.flatten_to_ops(decomposed):
337+
curr_ancilla = tuple(q for q in operation.qubits if q not in args.qubits)
338+
args = args.add_qubits(curr_ancilla)
339+
if args is NotImplemented:
340+
return NotImplemented
341+
all_ancilla.update(curr_ancilla)
341342
protocols.act_on(operation, args)
342-
args = args.remove_qubits(new_ancilla)
343-
if args is NotImplemented: # coverage: ignore
344-
raise TypeError( # coverage: ignore
345-
f"{type(args)} implements `add_qubits` but not `remove_qubits`." # coverage: ignore
346-
) # coverage: ignore
343+
args = args.remove_qubits(tuple(all_ancilla))
344+
if args is NotImplemented:
345+
raise TypeError(f"{type(args)} implements add_qubits but not remove_qubits.")
347346
return True
348347

349348

cirq-core/cirq/sim/simulation_state_test.py

+52-94
Original file line numberDiff line numberDiff line change
@@ -42,61 +42,26 @@ def _act_on_fallback_(
4242
) -> bool:
4343
return True
4444

45-
46-
class AncillaZ(cirq.Gate):
47-
def __init__(self, exponent=1):
48-
self._exponent = exponent
49-
50-
def num_qubits(self) -> int:
51-
return 1
52-
53-
def _decompose_(self, qubits):
54-
ancilla = cirq.NamedQubit('Ancilla')
55-
yield cirq.CX(qubits[0], ancilla)
56-
yield cirq.Z(ancilla) ** self._exponent
57-
yield cirq.CX(qubits[0], ancilla)
58-
59-
60-
class AncillaH(cirq.Gate):
61-
def __init__(self, exponent=1):
62-
self._exponent = exponent
63-
64-
def num_qubits(self) -> int:
65-
return 1
66-
67-
def _decompose_(self, qubits):
68-
ancilla = cirq.NamedQubit('Ancilla')
69-
yield cirq.H(ancilla) ** self._exponent
70-
yield cirq.CX(ancilla, qubits[0])
71-
yield cirq.H(ancilla) ** self._exponent
72-
73-
74-
class AncillaY(cirq.Gate):
75-
def __init__(self, exponent=1):
76-
self._exponent = exponent
77-
78-
def num_qubits(self) -> int:
79-
return 1
80-
81-
def _decompose_(self, qubits):
82-
ancilla = cirq.NamedQubit('Ancilla')
83-
yield cirq.Y(ancilla) ** self._exponent
84-
yield cirq.CX(ancilla, qubits[0])
85-
yield cirq.Y(ancilla) ** self._exponent
45+
def add_qubits(self, qubits):
46+
ret = super().add_qubits(qubits)
47+
return self if NotImplemented else ret
8648

8749

8850
class DelegatingAncillaZ(cirq.Gate):
89-
def __init__(self, exponent=1):
51+
def __init__(self, exponent=1, measure_ancilla: bool = False):
9052
self._exponent = exponent
53+
self._measure_ancilla = measure_ancilla
9154

9255
def num_qubits(self) -> int:
9356
return 1
9457

9558
def _decompose_(self, qubits):
9659
a = cirq.NamedQubit('a')
9760
yield cirq.CX(qubits[0], a)
98-
yield AncillaZ(self._exponent).on(a)
61+
yield PhaseUsingCleanAncilla(self._exponent).on(a)
9962
yield cirq.CX(qubits[0], a)
63+
if self._measure_ancilla:
64+
yield cirq.measure(a)
10065

10166

10267
class Composite(cirq.Gate):
@@ -115,12 +80,23 @@ def test_measurements():
11580

11681
def test_decompose():
11782
args = DummySimulationState()
118-
assert (
119-
simulation_state.strat_act_on_from_apply_decompose(Composite(), args, [cirq.LineQubit(0)])
120-
is NotImplemented
83+
assert simulation_state.strat_act_on_from_apply_decompose(
84+
Composite(), args, [cirq.LineQubit(0)]
12185
)
12286

12387

88+
def test_decompose_for_gate_allocating_qubits_raises():
89+
class Composite(cirq.testing.SingleQubitGate):
90+
def _decompose_(self, qubits):
91+
anc = cirq.NamedQubit("anc")
92+
yield cirq.CNOT(*qubits, anc)
93+
94+
args = DummySimulationState()
95+
96+
with pytest.raises(TypeError, match="add_qubits but not remove_qubits"):
97+
simulation_state.strat_act_on_from_apply_decompose(Composite(), args, [cirq.LineQubit(0)])
98+
99+
124100
def test_mapping():
125101
args = DummySimulationState()
126102
assert list(iter(args)) == cirq.LineQubit.range(2)
@@ -162,53 +138,35 @@ def test_field_getters():
162138
assert args.qubit_map == {q: i for i, q in enumerate(cirq.LineQubit.range(2))}
163139

164140

165-
@pytest.mark.parametrize('exp', [-3, -2, -1, 0, 1, 2, 3])
166-
def test_ancilla_z(exp):
167-
q = cirq.LineQubit(0)
168-
test_circuit = cirq.Circuit(AncillaZ(exp).on(q))
169-
170-
control_circuit = cirq.Circuit(cirq.ZPowGate(exponent=exp).on(q))
171-
172-
assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)
173-
174-
175-
@pytest.mark.parametrize('exp', [-3, -2, -1, 0, 1, 2, 3])
176-
def test_ancilla_y(exp):
141+
@pytest.mark.parametrize('exp', np.linspace(0, 2 * np.pi, 10))
142+
def test_delegating_gate_unitary(exp):
177143
q = cirq.LineQubit(0)
178-
test_circuit = cirq.Circuit(AncillaY(exp).on(q))
179-
180-
control_circuit = cirq.Circuit(cirq.Y(q))
181-
control_circuit.append(cirq.Y(q))
182-
control_circuit.append(cirq.XPowGate(exponent=exp).on(q))
183-
184-
assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)
185144

186-
187-
@pytest.mark.parametrize('exp', [-3, -2, -1, 0, 1, 2, 3])
188-
def test_borrowable_qubit(exp):
189-
q = cirq.LineQubit(0)
190145
test_circuit = cirq.Circuit()
191146
test_circuit.append(cirq.H(q))
192-
test_circuit.append(cirq.X(q))
193-
test_circuit.append(AncillaH(exp).on(q))
147+
test_circuit.append(DelegatingAncillaZ(exp).on(q))
194148

195149
control_circuit = cirq.Circuit(cirq.H(q))
150+
control_circuit.append(cirq.ZPowGate(exponent=exp).on(q))
196151

197-
assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)
152+
assert_test_circuit_for_dm_simulator(test_circuit, control_circuit)
153+
assert_test_circuit_for_sv_simulator(test_circuit, control_circuit)
198154

199155

200-
@pytest.mark.parametrize('exp', [-3, -2, -1, 0, 1, 2, 3])
201-
def test_delegating_gate_qubit(exp):
156+
@pytest.mark.parametrize('exp', np.linspace(0, 2 * np.pi, 10))
157+
def test_delegating_gate_channel(exp):
202158
q = cirq.LineQubit(0)
203159

204160
test_circuit = cirq.Circuit()
205161
test_circuit.append(cirq.H(q))
206-
test_circuit.append(DelegatingAncillaZ(exp).on(q))
162+
test_circuit.append(DelegatingAncillaZ(exp, True).on(q))
207163

208164
control_circuit = cirq.Circuit(cirq.H(q))
209165
control_circuit.append(cirq.ZPowGate(exponent=exp).on(q))
210166

211-
assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)
167+
with pytest.raises(TypeError, match="DensityMatrixSimulator doesn't support"):
168+
# TODO: This test should pass once we extend support to DensityMatrixSimulator.
169+
assert_test_circuit_for_dm_simulator(test_circuit, control_circuit)
212170

213171

214172
@pytest.mark.parametrize('num_ancilla', [1, 2, 3])
@@ -221,7 +179,8 @@ def test_phase_using_dirty_ancilla(num_ancilla: int):
221179
u.on(q, *anc), PhaseUsingDirtyAncilla(ancilla_bitsize=num_ancilla).on(q)
222180
)
223181
control_circuit = cirq.Circuit(u.on(q, *anc), cirq.Z(q))
224-
assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)
182+
assert_test_circuit_for_dm_simulator(test_circuit, control_circuit)
183+
assert_test_circuit_for_sv_simulator(test_circuit, control_circuit)
225184

226185

227186
@pytest.mark.parametrize('num_ancilla', [1, 2, 3])
@@ -233,25 +192,24 @@ def test_phase_using_clean_ancilla(num_ancilla: int, theta: float):
233192
u.on(q), PhaseUsingCleanAncilla(theta=theta, ancilla_bitsize=num_ancilla).on(q)
234193
)
235194
control_circuit = cirq.Circuit(u.on(q), cirq.ZPowGate(exponent=theta).on(q))
236-
assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)
237-
238-
239-
def test_add_qubits_raise_value_error(num_ancilla=1):
240-
q = cirq.LineQubit(0)
241-
args = cirq.StateVectorSimulationState(qubits=[q])
242-
243-
with pytest.raises(ValueError, match='should not already be tracked.'):
244-
args.add_qubits([q])
195+
assert_test_circuit_for_dm_simulator(test_circuit, control_circuit)
196+
assert_test_circuit_for_sv_simulator(test_circuit, control_circuit)
245197

246198

247-
def test_remove_qubits_not_implemented(num_ancilla=1):
248-
args = DummySimulationState()
249-
250-
assert args.remove_qubits([cirq.LineQubit(0)]) is NotImplemented
199+
def assert_test_circuit_for_dm_simulator(test_circuit, control_circuit) -> None:
200+
# Density Matrix Simulator: For unitary gates, this fallbacks to `cirq.apply_channel`
201+
# which recursively calls to `cirq.apply_unitary(decompose=True)`.
202+
for split_untangled_states in [True, False]:
203+
sim = cirq.DensityMatrixSimulator(split_untangled_states=split_untangled_states)
204+
control_sim = sim.simulate(control_circuit).final_density_matrix
205+
test_sim = sim.simulate(test_circuit).final_density_matrix
206+
assert np.allclose(test_sim, control_sim)
251207

252208

253-
def assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit) -> None:
254-
for test_simulator in ['cirq.final_state_vector', 'cirq.final_density_matrix']:
255-
test_sim = eval(test_simulator)(test_circuit)
256-
control_sim = eval(test_simulator)(control_circuit)
209+
def assert_test_circuit_for_sv_simulator(test_circuit, control_circuit) -> None:
210+
# State Vector Simulator.
211+
for split_untangled_states in [True, False]:
212+
sim = cirq.Simulator(split_untangled_states=split_untangled_states)
213+
control_sim = sim.simulate(control_circuit).final_state_vector
214+
test_sim = sim.simulate(test_circuit).final_state_vector
257215
assert np.allclose(test_sim, control_sim)

0 commit comments

Comments
 (0)