Skip to content

Commit 956f963

Browse files
maffootonybruguier
authored andcommitted
Allow specifying initial state vector in DensityMatrixSimulator (quantumlib#5223)
This changes how ActOnDensityMatrixArgs is constructed to allow specifying the initial state as a state vector or state tensor, or as a density matrix or density tensor. Some of this could perhaps be moved into `cirq.to_valid_density_matrix` if people think that is a better place. Currently `to_valid_density_matrix` only handles 1D state vectors or 2D density matrices, not 2x2x..2 tensors in either case, but if we have the qid_shape we can tell handle these unambiguously. Fixes quantumlib#3958
1 parent 78caf84 commit 956f963

File tree

4 files changed

+85
-5
lines changed

4 files changed

+85
-5
lines changed

cirq-core/cirq/qis/states.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -968,9 +968,13 @@ def to_valid_density_matrix(
968968
ValueError if the density_matrix_rep is not valid.
969969
"""
970970
qid_shape = _qid_shape_from_args(num_qubits, qid_shape)
971-
if isinstance(density_matrix_rep, np.ndarray) and density_matrix_rep.ndim == 2:
972-
validate_density_matrix(density_matrix_rep, qid_shape=qid_shape, dtype=dtype, atol=atol)
973-
return density_matrix_rep
971+
if isinstance(density_matrix_rep, np.ndarray):
972+
N = np.prod(qid_shape, dtype=np.int64)
973+
if len(qid_shape) > 1 and density_matrix_rep.shape == qid_shape * 2:
974+
density_matrix_rep = density_matrix_rep.reshape((N, N))
975+
if density_matrix_rep.shape == (N, N):
976+
validate_density_matrix(density_matrix_rep, qid_shape=qid_shape, dtype=dtype, atol=atol)
977+
return density_matrix_rep
974978

975979
state_vector = to_valid_state_vector(
976980
density_matrix_rep, len(qid_shape), qid_shape=qid_shape, dtype=dtype

cirq-core/cirq/qis/states_test.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -607,14 +607,29 @@ def test_to_valid_density_matrix_from_density_matrix():
607607
assert_valid_density_matrix(np.diag([0.2, 0.8, 0, 0]), qid_shape=(4,))
608608

609609

610+
def test_to_valid_density_matrix_from_density_matrix_tensor():
611+
np.testing.assert_almost_equal(
612+
cirq.to_valid_density_matrix(
613+
cirq.one_hot(shape=(2, 2, 2, 2, 2, 2), dtype=np.complex64), num_qubits=3
614+
),
615+
cirq.one_hot(shape=(8, 8), dtype=np.complex64),
616+
)
617+
np.testing.assert_almost_equal(
618+
cirq.to_valid_density_matrix(
619+
cirq.one_hot(shape=(2, 3, 4, 2, 3, 4), dtype=np.complex64), qid_shape=(2, 3, 4)
620+
),
621+
cirq.one_hot(shape=(24, 24), dtype=np.complex64),
622+
)
623+
624+
610625
def test_to_valid_density_matrix_not_square():
611626
with pytest.raises(ValueError, match='shape'):
612627
cirq.to_valid_density_matrix(np.array([[1], [0]]), num_qubits=1)
613628

614629

615630
def test_to_valid_density_matrix_size_mismatch_num_qubits():
616631
with pytest.raises(ValueError, match='shape'):
617-
cirq.to_valid_density_matrix(np.array([[1, 0], [0, 0]]), num_qubits=2)
632+
cirq.to_valid_density_matrix(np.array([[[1, 0], [0, 0]], [[0, 0], [0, 0]]]), num_qubits=2)
618633
with pytest.raises(ValueError, match='shape'):
619634
cirq.to_valid_density_matrix(np.eye(4) / 4.0, num_qubits=1)
620635

@@ -690,6 +705,16 @@ def test_to_valid_density_matrix_from_state_vector():
690705
)
691706

692707

708+
def test_to_valid_density_matrix_from_state_vector_tensor():
709+
np.testing.assert_almost_equal(
710+
cirq.to_valid_density_matrix(
711+
density_matrix_rep=np.array(np.full((2, 2), 0.5), dtype=np.complex64),
712+
num_qubits=2,
713+
),
714+
0.25 * np.ones((4, 4)),
715+
)
716+
717+
693718
def test_to_valid_density_matrix_from_state_invalid_state():
694719
with pytest.raises(ValueError, match="Invalid quantum state"):
695720
cirq.to_valid_density_matrix(np.array([1, 0, 0]), num_qubits=2)

cirq-core/cirq/sim/act_on_density_matrix_args.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,11 @@ def create(
7979
).reshape(qid_shape * 2)
8080
else:
8181
if qid_shape is not None:
82-
density_matrix = initial_state.reshape(qid_shape * 2)
82+
if dtype and initial_state.dtype != dtype:
83+
initial_state = initial_state.astype(dtype)
84+
density_matrix = qis.to_valid_density_matrix(
85+
initial_state, len(qid_shape), qid_shape=qid_shape, dtype=dtype
86+
).reshape(qid_shape * 2)
8387
else:
8488
density_matrix = initial_state
8589
if np.may_share_memory(density_matrix, initial_state):

cirq-core/cirq/sim/act_on_density_matrix_args_test.py

+47
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,50 @@ def test_with_qubits():
9898
def test_qid_shape_error():
9999
with pytest.raises(ValueError, match="qid_shape must be provided"):
100100
cirq.sim.act_on_density_matrix_args._BufferedDensityMatrix.create(initial_state=0)
101+
102+
103+
def test_initial_state_vector():
104+
qubits = cirq.LineQubit.range(3)
105+
args = cirq.ActOnDensityMatrixArgs(
106+
qubits=qubits, initial_state=np.full((8,), 1 / np.sqrt(8)), dtype=np.complex64
107+
)
108+
assert args.target_tensor.shape == (2, 2, 2, 2, 2, 2)
109+
110+
args2 = cirq.ActOnDensityMatrixArgs(
111+
qubits=qubits, initial_state=np.full((2, 2, 2), 1 / np.sqrt(8)), dtype=np.complex64
112+
)
113+
assert args2.target_tensor.shape == (2, 2, 2, 2, 2, 2)
114+
115+
116+
def test_initial_state_matrix():
117+
qubits = cirq.LineQubit.range(3)
118+
args = cirq.ActOnDensityMatrixArgs(
119+
qubits=qubits, initial_state=np.full((8, 8), 1 / 8), dtype=np.complex64
120+
)
121+
assert args.target_tensor.shape == (2, 2, 2, 2, 2, 2)
122+
123+
args2 = cirq.ActOnDensityMatrixArgs(
124+
qubits=qubits, initial_state=np.full((2, 2, 2, 2, 2, 2), 1 / 8), dtype=np.complex64
125+
)
126+
assert args2.target_tensor.shape == (2, 2, 2, 2, 2, 2)
127+
128+
129+
def test_initial_state_bad_shape():
130+
qubits = cirq.LineQubit.range(3)
131+
with pytest.raises(ValueError, match="Invalid quantum state"):
132+
cirq.ActOnDensityMatrixArgs(
133+
qubits=qubits, initial_state=np.full((4,), 1 / 2), dtype=np.complex64
134+
)
135+
with pytest.raises(ValueError, match="Invalid quantum state"):
136+
cirq.ActOnDensityMatrixArgs(
137+
qubits=qubits, initial_state=np.full((2, 2), 1 / 2), dtype=np.complex64
138+
)
139+
140+
with pytest.raises(ValueError, match="Invalid quantum state"):
141+
cirq.ActOnDensityMatrixArgs(
142+
qubits=qubits, initial_state=np.full((4, 4), 1 / 4), dtype=np.complex64
143+
)
144+
with pytest.raises(ValueError, match="Invalid quantum state"):
145+
cirq.ActOnDensityMatrixArgs(
146+
qubits=qubits, initial_state=np.full((2, 2, 2, 2), 1 / 4), dtype=np.complex64
147+
)

0 commit comments

Comments
 (0)