Skip to content

Commit 32a21a3

Browse files
authored
Base class for quantum states (#5065)
Creates a base class for all the quantum state classes created in #4979, and uses the inheritance to push the implementation of `ActOn<State>Args.kron`, `factor`, etc into the base class. Closes #4827 Resolves #3841 (comment) that's been bugging me for a year.
1 parent 7fe9671 commit 32a21a3

13 files changed

+235
-244
lines changed

cirq-core/cirq/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,7 @@
433433
operation_to_superoperator,
434434
QUANTUM_STATE_LIKE,
435435
QuantumState,
436+
QuantumStateRepresentation,
436437
quantum_state,
437438
STATE_VECTOR_LIKE,
438439
StabilizerState,

cirq-core/cirq/contrib/quimb/mps_simulator.py

+14-26
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import numpy as np
2525
import quimb.tensor as qtn
2626

27-
from cirq import devices, protocols, value
27+
from cirq import devices, protocols, qis, value
2828
from cirq._compat import deprecated
2929
from cirq.sim import simulator_base
3030
from cirq.sim.act_on_args import ActOnArgs
@@ -220,7 +220,7 @@ def _simulator_state(self):
220220

221221

222222
@value.value_equality
223-
class _MPSHandler:
223+
class _MPSHandler(qis.QuantumStateRepresentation):
224224
"""Quantum state of the MPS simulation."""
225225

226226
def __init__(
@@ -604,21 +604,24 @@ def __init__(
604604
Raises:
605605
ValueError: If the grouping does not cover the qubits.
606606
"""
607+
qubit_map = {q: i for i, q in enumerate(qubits)}
608+
final_grouping = qubit_map if grouping is None else grouping
609+
if final_grouping.keys() != qubit_map.keys():
610+
raise ValueError('Grouping must cover exactly the qubits.')
611+
state = _MPSHandler.create(
612+
initial_state=initial_state,
613+
qid_shape=tuple(q.dimension for q in qubits),
614+
simulation_options=simulation_options,
615+
grouping={qubit_map[k]: v for k, v in final_grouping.items()},
616+
)
607617
super().__init__(
618+
state=state,
608619
prng=prng,
609620
qubits=qubits,
610621
log_of_measurement_results=log_of_measurement_results,
611622
classical_data=classical_data,
612623
)
613-
final_grouping = self.qubit_map if grouping is None else grouping
614-
if final_grouping.keys() != self.qubit_map.keys():
615-
raise ValueError('Grouping must cover exactly the qubits.')
616-
self._state = _MPSHandler.create(
617-
initial_state=initial_state,
618-
qid_shape=tuple(q.dimension for q in qubits),
619-
simulation_options=simulation_options,
620-
grouping={self.qubit_map[k]: v for k, v in final_grouping.items()},
621-
)
624+
self._state: _MPSHandler = state
622625

623626
def i_str(self, i: int) -> str:
624627
# Returns the index name for the i'th qid.
@@ -636,9 +639,6 @@ def __str__(self) -> str:
636639
def _value_equality_values_(self) -> Any:
637640
return self.qubits, self._state
638641

639-
def _on_copy(self, target: 'MPSState', deep_copy_buffers: bool = True):
640-
target._state = self._state.copy(deep_copy_buffers)
641-
642642
def state_vector(self) -> np.ndarray:
643643
"""Returns the full state vector.
644644
@@ -709,15 +709,3 @@ def perform_measurement(
709709
tolerance specified in simulation options.
710710
"""
711711
return self._state._measure(self.get_axes(qubits), prng, collapse_state_vector)
712-
713-
def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]:
714-
"""Measures the axes specified by the simulator."""
715-
return self._state.measure(self.get_axes(qubits), self.prng)
716-
717-
def sample(
718-
self,
719-
qubits: Sequence['cirq.Qid'],
720-
repetitions: int = 1,
721-
seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
722-
) -> np.ndarray:
723-
return self._state.sample(self.get_axes(qubits), repetitions, seed)

cirq-core/cirq/qis/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
superoperator_to_kraus,
2626
)
2727

28-
from cirq.qis.clifford_tableau import CliffordTableau, StabilizerState
28+
from cirq.qis.clifford_tableau import CliffordTableau, QuantumStateRepresentation, StabilizerState
2929

3030
from cirq.qis.measures import (
3131
entanglement_fidelity,

cirq-core/cirq/qis/clifford_tableau.py

+89-4
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,97 @@
1313
# limitations under the License.
1414

1515
import abc
16-
from typing import Any, Dict, List, TYPE_CHECKING
16+
from typing import Any, Dict, List, Sequence, Tuple, TYPE_CHECKING, TypeVar
1717
import numpy as np
1818

19-
from cirq import protocols
19+
from cirq import protocols, value
2020
from cirq.value import big_endian_int_to_digits, linear_dict
2121

2222
if TYPE_CHECKING:
2323
import cirq
2424

25+
TSelf = TypeVar('TSelf', bound='QuantumStateRepresentation')
2526

26-
class StabilizerState(metaclass=abc.ABCMeta):
27+
28+
class QuantumStateRepresentation(metaclass=abc.ABCMeta):
29+
@abc.abstractmethod
30+
def copy(self: TSelf, deep_copy_buffers: bool = True) -> TSelf:
31+
"""Creates a copy of the object.
32+
Args:
33+
deep_copy_buffers: If True, buffers will also be deep-copied.
34+
Otherwise the copy will share a reference to the original object's
35+
buffers.
36+
Returns:
37+
A copied instance.
38+
"""
39+
40+
@abc.abstractmethod
41+
def measure(
42+
self, axes: Sequence[int], seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None
43+
) -> List[int]:
44+
"""Measures the state.
45+
46+
Args:
47+
axes: The axes to measure.
48+
seed: The random number seed to use.
49+
Returns:
50+
The measurements in order.
51+
"""
52+
53+
def sample(
54+
self,
55+
axes: Sequence[int],
56+
repetitions: int = 1,
57+
seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
58+
) -> np.ndarray:
59+
"""Samples the state. Subclasses can override with more performant method.
60+
61+
Args:
62+
axes: The axes to sample.
63+
repetitions: The number of samples to make.
64+
seed: The random number seed to use.
65+
Returns:
66+
The samples in order.
67+
"""
68+
prng = value.parse_random_state(seed)
69+
measurements = []
70+
for _ in range(repetitions):
71+
state = self.copy()
72+
measurements.append(state.measure(axes, prng))
73+
return np.array(measurements, dtype=bool)
74+
75+
def kron(self: TSelf, other: TSelf) -> TSelf:
76+
"""Joins two state spaces together."""
77+
raise NotImplementedError()
78+
79+
def factor(
80+
self: TSelf, axes: Sequence[int], *, validate=True, atol=1e-07
81+
) -> Tuple[TSelf, TSelf]:
82+
"""Splits two state spaces after a measurement or reset."""
83+
raise NotImplementedError()
84+
85+
def reindex(self: TSelf, axes: Sequence[int]) -> TSelf:
86+
"""Physically reindexes the state by the new basis.
87+
Args:
88+
axes: The desired axis order.
89+
Returns:
90+
The state with qubit order transposed and underlying representation
91+
updated.
92+
"""
93+
raise NotImplementedError()
94+
95+
@property
96+
def supports_factor(self) -> bool:
97+
"""Subclasses that allow factorization should override this."""
98+
return False
99+
100+
@property
101+
def can_represent_mixed_states(self) -> bool:
102+
"""Subclasses that can represent mixed states should override this."""
103+
return False
104+
105+
106+
class StabilizerState(QuantumStateRepresentation, metaclass=abc.ABCMeta):
27107
"""Interface for quantum stabilizer state representations.
28108
29109
This interface is used for CliffordTableau and StabilizerChForm quantum
@@ -222,7 +302,7 @@ def __eq__(self, other):
222302
def __copy__(self) -> 'CliffordTableau':
223303
return self.copy()
224304

225-
def copy(self) -> 'CliffordTableau':
305+
def copy(self, deep_copy_buffers: bool = True) -> 'CliffordTableau':
226306
state = CliffordTableau(self.n)
227307
state.rs = self.rs.copy()
228308
state.xs = self.xs.copy()
@@ -578,3 +658,8 @@ def apply_cx(
578658

579659
def apply_global_phase(self, coefficient: linear_dict.Scalar):
580660
pass
661+
662+
def measure(
663+
self, axes: Sequence[int], seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None
664+
) -> List[int]:
665+
return [self._measure(axis, seed) for axis in axes]

cirq-core/cirq/sim/act_on_args.py

+38-11
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""Objects and methods for acting efficiently on a state tensor."""
15-
import abc
1615
import copy
1716
import inspect
17+
import warnings
1818
from typing import (
1919
Any,
2020
cast,
@@ -28,7 +28,6 @@
2828
TYPE_CHECKING,
2929
Tuple,
3030
)
31-
import warnings
3231

3332
import numpy as np
3433

@@ -59,6 +58,7 @@ def __init__(
5958
log_of_measurement_results: Optional[Dict[str, List[int]]] = None,
6059
ignore_measurement_results: bool = False,
6160
classical_data: Optional['cirq.ClassicalDataStore'] = None,
61+
state: Optional['cirq.QuantumStateRepresentation'] = None,
6262
):
6363
"""Inits ActOnArgs.
6464
@@ -76,6 +76,7 @@ def __init__(
7676
simulators that can represent mixed states.
7777
classical_data: The shared classical data container for this
7878
simulation.
79+
state: The underlying quantum state of the simulation.
7980
"""
8081
if prng is None:
8182
prng = cast(np.random.RandomState, np.random)
@@ -90,6 +91,7 @@ def __init__(
9091
}
9192
)
9293
self._ignore_measurement_results = ignore_measurement_results
94+
self._state = state
9395

9496
@property
9597
def prng(self) -> np.random.RandomState:
@@ -148,10 +150,21 @@ def measure(self, qubits: Sequence['cirq.Qid'], key: str, invert_mask: Sequence[
148150
def get_axes(self, qubits: Sequence['cirq.Qid']) -> List[int]:
149151
return [self.qubit_map[q] for q in qubits]
150152

151-
@abc.abstractmethod
152153
def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]:
153-
"""Child classes that perform measurements should implement this with
154-
the implementation."""
154+
"""Delegates the call to measure the density matrix."""
155+
if self._state is not None:
156+
return self._state.measure(self.get_axes(qubits), self.prng)
157+
raise NotImplementedError()
158+
159+
def sample(
160+
self,
161+
qubits: Sequence['cirq.Qid'],
162+
repetitions: int = 1,
163+
seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
164+
) -> np.ndarray:
165+
if self._state is not None:
166+
return self._state.sample(self.get_axes(qubits), repetitions, seed)
167+
raise NotImplementedError()
155168

156169
def copy(self: TSelf, deep_copy_buffers: bool = True) -> TSelf:
157170
"""Creates a copy of the object.
@@ -165,6 +178,10 @@ def copy(self: TSelf, deep_copy_buffers: bool = True) -> TSelf:
165178
A copied instance.
166179
"""
167180
args = copy.copy(self)
181+
args._classical_data = self._classical_data.copy()
182+
if self._state is not None:
183+
args._state = self._state.copy(deep_copy_buffers=deep_copy_buffers)
184+
return args
168185
if 'deep_copy_buffers' in inspect.signature(self._on_copy).parameters:
169186
self._on_copy(args, deep_copy_buffers)
170187
else:
@@ -176,7 +193,6 @@ def copy(self: TSelf, deep_copy_buffers: bool = True) -> TSelf:
176193
DeprecationWarning,
177194
)
178195
self._on_copy(args)
179-
args._classical_data = self._classical_data.copy()
180196
return args
181197

182198
def _on_copy(self: TSelf, args: TSelf, deep_copy_buffers: bool = True):
@@ -190,7 +206,10 @@ def create_merged_state(self: TSelf) -> TSelf:
190206
def kronecker_product(self: TSelf, other: TSelf, *, inplace=False) -> TSelf:
191207
"""Joins two state spaces together."""
192208
args = self if inplace else copy.copy(self)
193-
self._on_kronecker_product(other, args)
209+
if self._state is not None and other._state is not None:
210+
args._state = self._state.kron(other._state)
211+
else:
212+
self._on_kronecker_product(other, args)
194213
args._set_qubits(self.qubits + other.qubits)
195214
return args
196215

@@ -225,15 +244,20 @@ def factor(
225244
"""Splits two state spaces after a measurement or reset."""
226245
extracted = copy.copy(self)
227246
remainder = self if inplace else copy.copy(self)
228-
self._on_factor(qubits, extracted, remainder, validate, atol)
247+
if self._state is not None:
248+
e, r = self._state.factor(self.get_axes(qubits), validate=validate, atol=atol)
249+
extracted._state = e
250+
remainder._state = r
251+
else:
252+
self._on_factor(qubits, extracted, remainder, validate, atol)
229253
extracted._set_qubits(qubits)
230254
remainder._set_qubits([q for q in self.qubits if q not in qubits])
231255
return extracted, remainder
232256

233257
@property
234258
def allows_factoring(self):
235259
"""Subclasses that allow factorization should override this."""
236-
return False
260+
return self._state.supports_factor if self._state is not None else False
237261

238262
def _on_factor(
239263
self: TSelf,
@@ -265,7 +289,10 @@ def transpose_to_qubit_order(
265289
if len(self.qubits) != len(qubits) or set(qubits) != set(self.qubits):
266290
raise ValueError(f'Qubits do not match. Existing: {self.qubits}, provided: {qubits}')
267291
args = self if inplace else copy.copy(self)
268-
self._on_transpose_to_qubit_order(qubits, args)
292+
if self._state is not None:
293+
args._state = self._state.reindex(self.get_axes(qubits))
294+
else:
295+
self._on_transpose_to_qubit_order(qubits, args)
269296
args._set_qubits(qubits)
270297
return args
271298

@@ -356,7 +383,7 @@ def __iter__(self) -> Iterator[Optional['cirq.Qid']]:
356383

357384
@property
358385
def can_represent_mixed_states(self) -> bool:
359-
return False
386+
return self._state.can_represent_mixed_states if self._state is not None else False
360387

361388

362389
def strat_act_on_from_apply_decompose(

cirq-core/cirq/sim/act_on_args_container_test.py

-15
Original file line numberDiff line numberDiff line change
@@ -41,25 +41,10 @@ def _act_on_fallback_(
4141
) -> bool:
4242
return True
4343

44-
def _on_copy(self, args):
45-
pass
46-
47-
def _on_kronecker_product(self, other, target):
48-
pass
49-
50-
def _on_transpose_to_qubit_order(self, qubits, target):
51-
pass
52-
53-
def _on_factor(self, qubits, extracted, remainder, validate=True, atol=1e-07):
54-
pass
55-
5644
@property
5745
def allows_factoring(self):
5846
return True
5947

60-
def sample(self, qubits, repetitions=1, seed=None):
61-
pass
62-
6348

6449
q0, q1, q2 = qs3 = cirq.LineQubit.range(3)
6550
qs2 = cirq.LineQubit.range(2)

0 commit comments

Comments
 (0)