@@ -105,3 +105,50 @@ def test_with_qubits():
105
105
def test_qid_shape_error ():
106
106
with pytest .raises (ValueError , match = "qid_shape must be provided" ):
107
107
cirq .sim .act_on_density_matrix_args ._BufferedDensityMatrix .create (initial_state = 0 )
108
+
109
+
110
+ def test_initial_state_vector ():
111
+ qubits = cirq .LineQubit .range (3 )
112
+ args = cirq .ActOnDensityMatrixArgs (
113
+ qubits = qubits , initial_state = np .full ((8 ,), 1 / np .sqrt (8 )), dtype = np .complex64
114
+ )
115
+ assert args .target_tensor .shape == (2 , 2 , 2 , 2 , 2 , 2 )
116
+
117
+ args2 = cirq .ActOnDensityMatrixArgs (
118
+ qubits = qubits , initial_state = np .full ((2 , 2 , 2 ), 1 / np .sqrt (8 )), dtype = np .complex64
119
+ )
120
+ assert args2 .target_tensor .shape == (2 , 2 , 2 , 2 , 2 , 2 )
121
+
122
+
123
+ def test_initial_state_matrix ():
124
+ qubits = cirq .LineQubit .range (3 )
125
+ args = cirq .ActOnDensityMatrixArgs (
126
+ qubits = qubits , initial_state = np .full ((8 , 8 ), 1 / 8 ), dtype = np .complex64
127
+ )
128
+ assert args .target_tensor .shape == (2 , 2 , 2 , 2 , 2 , 2 )
129
+
130
+ args2 = cirq .ActOnDensityMatrixArgs (
131
+ qubits = qubits , initial_state = np .full ((2 , 2 , 2 , 2 , 2 , 2 ), 1 / 8 ), dtype = np .complex64
132
+ )
133
+ assert args2 .target_tensor .shape == (2 , 2 , 2 , 2 , 2 , 2 )
134
+
135
+
136
+ def test_initial_state_bad_shape ():
137
+ qubits = cirq .LineQubit .range (3 )
138
+ with pytest .raises (ValueError , match = "Invalid initial state." ):
139
+ cirq .ActOnDensityMatrixArgs (
140
+ qubits = qubits , initial_state = np .full ((4 ,), 1 / 2 ), dtype = np .complex64
141
+ )
142
+ with pytest .raises (ValueError , match = "Invalid initial state." ):
143
+ cirq .ActOnDensityMatrixArgs (
144
+ qubits = qubits , initial_state = np .full ((2 , 2 ), 1 / 2 ), dtype = np .complex64
145
+ )
146
+
147
+ with pytest .raises (ValueError , match = "Invalid initial state." ):
148
+ cirq .ActOnDensityMatrixArgs (
149
+ qubits = qubits , initial_state = np .full ((4 , 4 ), 1 / 4 ), dtype = np .complex64
150
+ )
151
+ with pytest .raises (ValueError , match = "Invalid initial state." ):
152
+ cirq .ActOnDensityMatrixArgs (
153
+ qubits = qubits , initial_state = np .full ((2 , 2 , 2 , 2 ), 1 / 4 ), dtype = np .complex64
154
+ )
0 commit comments