Skip to content

Commit 56b5db2

Browse files
authored
add back add/remove_qubit to density matrix sim state (#6259)
* wip: add back add/remove_qubit to density matrix sim state * factorize inplace while removing qubit in density matrix simulation state
1 parent ed26d2f commit 56b5db2

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

cirq-core/cirq/sim/density_matrix_simulation_state.py

+16
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,22 @@ def __init__(
285285
)
286286
super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data)
287287

288+
def add_qubits(self, qubits: Sequence['cirq.Qid']):
289+
ret = super().add_qubits(qubits)
290+
return (
291+
self.kronecker_product(type(self)(qubits=qubits), inplace=True)
292+
if ret is NotImplemented
293+
else ret
294+
)
295+
296+
def remove_qubits(self, qubits: Sequence['cirq.Qid']):
297+
ret = super().remove_qubits(qubits)
298+
if ret is not NotImplemented:
299+
return ret
300+
extracted, remainder = self.factor(qubits, inplace=True)
301+
remainder._state._density_matrix *= extracted._state._density_matrix.reshape(-1)[0]
302+
return remainder
303+
288304
def _act_on_fallback_(
289305
self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True
290306
) -> bool:

cirq-core/cirq/sim/simulation_state_test.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -164,9 +164,8 @@ def test_delegating_gate_channel(exp):
164164
control_circuit = cirq.Circuit(cirq.H(q))
165165
control_circuit.append(cirq.ZPowGate(exponent=exp).on(q))
166166

167-
with pytest.raises(TypeError, match="DensityMatrixSimulator doesn't support"):
168-
# TODO: This test should pass once we extend support to DensityMatrixSimulator.
169-
assert_test_circuit_for_dm_simulator(test_circuit, control_circuit)
167+
assert_test_circuit_for_sv_simulator(test_circuit, control_circuit)
168+
assert_test_circuit_for_dm_simulator(test_circuit, control_circuit)
170169

171170

172171
@pytest.mark.parametrize('num_ancilla', [1, 2, 3])

0 commit comments

Comments
 (0)