Skip to content

Commit 3f325af

Browse files
committed
Allow specifying initial state vector in DensityMatrixSimulator
This changes how ActOnDensityMatrixArgs is constructed to allow specifying the initial state as a state vector or state tensor, or as a density matrix or density tensor. Fixes #3958
1 parent 51a2b6b commit 3f325af

File tree

2 files changed

+68
-1
lines changed

2 files changed

+68
-1
lines changed

Diff for: cirq-core/cirq/sim/act_on_density_matrix_args.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
"""Objects and methods for acting efficiently on a density matrix."""
1515

16+
import math
1617
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Type, Union
1718

1819
import numpy as np
@@ -79,7 +80,26 @@ def create(
7980
).reshape(qid_shape * 2)
8081
else:
8182
if qid_shape is not None:
82-
density_matrix = initial_state.reshape(qid_shape * 2)
83+
qid_size = math.prod(qid_shape)
84+
shape = initial_state.shape
85+
if shape == qid_shape or shape == (qid_size,):
86+
if len(shape) != 1:
87+
initial_state = initial_state.reshape((qid_size,))
88+
elif shape == qid_shape * 2 or shape == (qid_size, qid_size):
89+
if len(shape) != 2:
90+
initial_state = initial_state.reshape((qid_size, qid_size))
91+
if dtype and initial_state.dtype != dtype:
92+
# Convert type because to_valid_density_matrix does not convert dtype.
93+
initial_state = initial_state.astype(dtype)
94+
else:
95+
raise ValueError(
96+
f'Invalid initial state. Expected state vector of shape {(qid_size,)} '
97+
f'or density matrix of shape {(qid_size, qid_size)}; '
98+
f'got {initial_state.shape}.'
99+
)
100+
density_matrix = qis.to_valid_density_matrix(
101+
initial_state, len(qid_shape), qid_shape=qid_shape, dtype=dtype
102+
).reshape(qid_shape * 2)
83103
else:
84104
density_matrix = initial_state
85105
if np.may_share_memory(density_matrix, initial_state):

Diff for: cirq-core/cirq/sim/act_on_density_matrix_args_test.py

+47
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,50 @@ def test_with_qubits():
105105
def test_qid_shape_error():
106106
with pytest.raises(ValueError, match="qid_shape must be provided"):
107107
cirq.sim.act_on_density_matrix_args._BufferedDensityMatrix.create(initial_state=0)
108+
109+
110+
def test_initial_state_vector():
111+
qubits = cirq.LineQubit.range(3)
112+
args = cirq.ActOnDensityMatrixArgs(
113+
qubits=qubits, initial_state=np.full((8,), 1 / np.sqrt(8)), dtype=np.complex64
114+
)
115+
assert args.target_tensor.shape == (2, 2, 2, 2, 2, 2)
116+
117+
args2 = cirq.ActOnDensityMatrixArgs(
118+
qubits=qubits, initial_state=np.full((2, 2, 2), 1 / np.sqrt(8)), dtype=np.complex64
119+
)
120+
assert args2.target_tensor.shape == (2, 2, 2, 2, 2, 2)
121+
122+
123+
def test_initial_state_matrix():
124+
qubits = cirq.LineQubit.range(3)
125+
args = cirq.ActOnDensityMatrixArgs(
126+
qubits=qubits, initial_state=np.full((8, 8), 1 / 8), dtype=np.complex64
127+
)
128+
assert args.target_tensor.shape == (2, 2, 2, 2, 2, 2)
129+
130+
args2 = cirq.ActOnDensityMatrixArgs(
131+
qubits=qubits, initial_state=np.full((2, 2, 2, 2, 2, 2), 1 / 8), dtype=np.complex64
132+
)
133+
assert args2.target_tensor.shape == (2, 2, 2, 2, 2, 2)
134+
135+
136+
def test_initial_state_bad_shape():
137+
qubits = cirq.LineQubit.range(3)
138+
with pytest.raises(ValueError, match="Invalid initial state."):
139+
cirq.ActOnDensityMatrixArgs(
140+
qubits=qubits, initial_state=np.full((4,), 1 / 2), dtype=np.complex64
141+
)
142+
with pytest.raises(ValueError, match="Invalid initial state."):
143+
cirq.ActOnDensityMatrixArgs(
144+
qubits=qubits, initial_state=np.full((2, 2), 1 / 2), dtype=np.complex64
145+
)
146+
147+
with pytest.raises(ValueError, match="Invalid initial state."):
148+
cirq.ActOnDensityMatrixArgs(
149+
qubits=qubits, initial_state=np.full((4, 4), 1 / 4), dtype=np.complex64
150+
)
151+
with pytest.raises(ValueError, match="Invalid initial state."):
152+
cirq.ActOnDensityMatrixArgs(
153+
qubits=qubits, initial_state=np.full((2, 2, 2, 2), 1 / 4), dtype=np.complex64
154+
)

0 commit comments

Comments
 (0)