12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
"""Objects and methods for acting efficiently on a state tensor."""
15
- import abc
16
15
import copy
17
16
import inspect
17
+ import warnings
18
18
from typing import (
19
19
Any ,
20
20
cast ,
28
28
TYPE_CHECKING ,
29
29
Tuple ,
30
30
)
31
- import warnings
32
31
33
32
import numpy as np
34
33
@@ -59,6 +58,7 @@ def __init__(
59
58
log_of_measurement_results : Optional [Dict [str , List [int ]]] = None ,
60
59
ignore_measurement_results : bool = False ,
61
60
classical_data : Optional ['cirq.ClassicalDataStore' ] = None ,
61
+ state : Optional ['cirq.QuantumStateRepresentation' ] = None ,
62
62
):
63
63
"""Inits ActOnArgs.
64
64
@@ -76,6 +76,7 @@ def __init__(
76
76
simulators that can represent mixed states.
77
77
classical_data: The shared classical data container for this
78
78
simulation.
79
+ state: The underlying quantum state of the simulation.
79
80
"""
80
81
if prng is None :
81
82
prng = cast (np .random .RandomState , np .random )
@@ -90,6 +91,7 @@ def __init__(
90
91
}
91
92
)
92
93
self ._ignore_measurement_results = ignore_measurement_results
94
+ self ._state = state
93
95
94
96
@property
95
97
def prng (self ) -> np .random .RandomState :
@@ -148,10 +150,21 @@ def measure(self, qubits: Sequence['cirq.Qid'], key: str, invert_mask: Sequence[
148
150
def get_axes (self , qubits : Sequence ['cirq.Qid' ]) -> List [int ]:
149
151
return [self .qubit_map [q ] for q in qubits ]
150
152
151
- @abc .abstractmethod
152
153
def _perform_measurement (self , qubits : Sequence ['cirq.Qid' ]) -> List [int ]:
153
- """Child classes that perform measurements should implement this with
154
- the implementation."""
154
+ """Delegates the call to measure the density matrix."""
155
+ if self ._state is not None :
156
+ return self ._state .measure (self .get_axes (qubits ), self .prng )
157
+ raise NotImplementedError ()
158
+
159
+ def sample (
160
+ self ,
161
+ qubits : Sequence ['cirq.Qid' ],
162
+ repetitions : int = 1 ,
163
+ seed : 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None ,
164
+ ) -> np .ndarray :
165
+ if self ._state is not None :
166
+ return self ._state .sample (self .get_axes (qubits ), repetitions , seed )
167
+ raise NotImplementedError ()
155
168
156
169
def copy (self : TSelf , deep_copy_buffers : bool = True ) -> TSelf :
157
170
"""Creates a copy of the object.
@@ -165,6 +178,10 @@ def copy(self: TSelf, deep_copy_buffers: bool = True) -> TSelf:
165
178
A copied instance.
166
179
"""
167
180
args = copy .copy (self )
181
+ args ._classical_data = self ._classical_data .copy ()
182
+ if self ._state is not None :
183
+ args ._state = self ._state .copy (deep_copy_buffers = deep_copy_buffers )
184
+ return args
168
185
if 'deep_copy_buffers' in inspect .signature (self ._on_copy ).parameters :
169
186
self ._on_copy (args , deep_copy_buffers )
170
187
else :
@@ -176,7 +193,6 @@ def copy(self: TSelf, deep_copy_buffers: bool = True) -> TSelf:
176
193
DeprecationWarning ,
177
194
)
178
195
self ._on_copy (args )
179
- args ._classical_data = self ._classical_data .copy ()
180
196
return args
181
197
182
198
def _on_copy (self : TSelf , args : TSelf , deep_copy_buffers : bool = True ):
@@ -190,7 +206,10 @@ def create_merged_state(self: TSelf) -> TSelf:
190
206
def kronecker_product (self : TSelf , other : TSelf , * , inplace = False ) -> TSelf :
191
207
"""Joins two state spaces together."""
192
208
args = self if inplace else copy .copy (self )
193
- self ._on_kronecker_product (other , args )
209
+ if self ._state is not None and other ._state is not None :
210
+ args ._state = self ._state .kron (other ._state )
211
+ else :
212
+ self ._on_kronecker_product (other , args )
194
213
args ._set_qubits (self .qubits + other .qubits )
195
214
return args
196
215
@@ -225,15 +244,20 @@ def factor(
225
244
"""Splits two state spaces after a measurement or reset."""
226
245
extracted = copy .copy (self )
227
246
remainder = self if inplace else copy .copy (self )
228
- self ._on_factor (qubits , extracted , remainder , validate , atol )
247
+ if self ._state is not None :
248
+ e , r = self ._state .factor (self .get_axes (qubits ), validate = validate , atol = atol )
249
+ extracted ._state = e
250
+ remainder ._state = r
251
+ else :
252
+ self ._on_factor (qubits , extracted , remainder , validate , atol )
229
253
extracted ._set_qubits (qubits )
230
254
remainder ._set_qubits ([q for q in self .qubits if q not in qubits ])
231
255
return extracted , remainder
232
256
233
257
@property
234
258
def allows_factoring (self ):
235
259
"""Subclasses that allow factorization should override this."""
236
- return False
260
+ return self . _state . supports_factor if self . _state is not None else False
237
261
238
262
def _on_factor (
239
263
self : TSelf ,
@@ -265,7 +289,10 @@ def transpose_to_qubit_order(
265
289
if len (self .qubits ) != len (qubits ) or set (qubits ) != set (self .qubits ):
266
290
raise ValueError (f'Qubits do not match. Existing: { self .qubits } , provided: { qubits } ' )
267
291
args = self if inplace else copy .copy (self )
268
- self ._on_transpose_to_qubit_order (qubits , args )
292
+ if self ._state is not None :
293
+ args ._state = self ._state .reindex (self .get_axes (qubits ))
294
+ else :
295
+ self ._on_transpose_to_qubit_order (qubits , args )
269
296
args ._set_qubits (qubits )
270
297
return args
271
298
@@ -356,7 +383,7 @@ def __iter__(self) -> Iterator[Optional['cirq.Qid']]:
356
383
357
384
@property
358
385
def can_represent_mixed_states (self ) -> bool :
359
- return False
386
+ return self . _state . can_represent_mixed_states if self . _state is not None else False
360
387
361
388
362
389
def strat_act_on_from_apply_decompose (
0 commit comments