Skip to content

Commit 9a38c2d

Browse files
authored
Projectors for TensorFlow quantum (#4331)
TensorFlow Quantum would like to be able to specify projectors in a memory efficient way. This PR attempts to provide such a facility.
1 parent 2467e39 commit 9a38c2d

File tree

7 files changed

+438
-0
lines changed

7 files changed

+438
-0
lines changed

cirq-core/cirq/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,7 @@
250250
PhasedXPowGate,
251251
PhasedXZGate,
252252
PhaseFlipChannel,
253+
ProjectorString,
253254
RandomGateChannel,
254255
qft,
255256
Qid,

cirq-core/cirq/json_resolver_cache.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def two_qubit_matrix_gate(matrix):
117117
'PhasedISwapPowGate': cirq.PhasedISwapPowGate,
118118
'PhasedXPowGate': cirq.PhasedXPowGate,
119119
'PhasedXZGate': cirq.PhasedXZGate,
120+
'ProjectorString': cirq.ProjectorString,
120121
'RandomGateChannel': cirq.RandomGateChannel,
121122
'QuantumFourierTransformGate': cirq.QuantumFourierTransformGate,
122123
'RepetitionsStoppingCriteria': cirq.work.RepetitionsStoppingCriteria,

cirq-core/cirq/ops/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,10 @@
131131
ParallelGateOperation,
132132
)
133133

134+
from cirq.ops.projector import (
135+
ProjectorString,
136+
)
137+
134138
from cirq.ops.controlled_operation import (
135139
ControlledOperation,
136140
)

cirq-core/cirq/ops/projector.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
import itertools
2+
from typing import (
3+
Any,
4+
Dict,
5+
Iterable,
6+
List,
7+
Mapping,
8+
Optional,
9+
Union,
10+
)
11+
12+
import numpy as np
13+
from scipy.sparse import csr_matrix
14+
15+
from cirq import value
16+
from cirq.ops import raw_types
17+
18+
19+
def _check_qids_dimension(qids):
20+
"""A utility to check that we only have Qubits."""
21+
for qid in qids:
22+
if qid.dimension != 2:
23+
raise ValueError(f"Only qubits are supported, but {qid} has dimension {qid.dimension}")
24+
25+
26+
@value.value_equality
27+
class ProjectorString:
28+
def __init__(
29+
self,
30+
projector_dict: Dict[raw_types.Qid, int],
31+
coefficient: Union[int, float, complex] = 1,
32+
):
33+
"""Contructor for ProjectorString
34+
35+
Args:
36+
projector_dict: A python dictionary mapping from cirq.Qid to integers. A key value pair
37+
represents the desired computational basis state for that qubit.
38+
coefficient: Initial scalar coefficient. Defaults to 1.
39+
"""
40+
_check_qids_dimension(projector_dict.keys())
41+
self._projector_dict = projector_dict
42+
self._coefficient = complex(coefficient)
43+
44+
@property
45+
def projector_dict(self) -> Dict[raw_types.Qid, int]:
46+
return self._projector_dict
47+
48+
@property
49+
def coefficient(self) -> complex:
50+
return self._coefficient
51+
52+
def matrix(self, projector_qids: Optional[Iterable[raw_types.Qid]] = None) -> csr_matrix:
53+
"""Returns the matrix of self in computational basis of qubits.
54+
55+
Args:
56+
projector_qids: Ordered collection of qubits that determine the subspace
57+
in which the matrix representation of the ProjectorString is to
58+
be computed. Qbits absent from self.qubits are acted on by
59+
the identity. Defaults to the qubits of the projector_dict.
60+
61+
Returns:
62+
A sparse matrix that is the projection in the specified basis.
63+
"""
64+
projector_qids = self._projector_dict.keys() if projector_qids is None else projector_qids
65+
_check_qids_dimension(projector_qids)
66+
idx_to_keep = [
67+
[self._projector_dict[qid]] if qid in self._projector_dict else [0, 1]
68+
for qid in projector_qids
69+
]
70+
71+
total_d = np.prod([qid.dimension for qid in projector_qids])
72+
73+
ones_idx = []
74+
for idx in itertools.product(*idx_to_keep):
75+
d = total_d
76+
kron_idx = 0
77+
for i, qid in zip(idx, projector_qids):
78+
d //= qid.dimension
79+
kron_idx += i * d
80+
ones_idx.append(kron_idx)
81+
82+
return csr_matrix(
83+
([self._coefficient] * len(ones_idx), (ones_idx, ones_idx)), shape=(total_d, total_d)
84+
)
85+
86+
def _get_idx_to_keep(self, qid_map: Mapping[raw_types.Qid, int]):
87+
num_qubits = len(qid_map)
88+
idx_to_keep: List[Any] = [slice(0, 2)] * num_qubits
89+
for q in self.projector_dict.keys():
90+
idx_to_keep[qid_map[q]] = self.projector_dict[q]
91+
return tuple(idx_to_keep)
92+
93+
def expectation_from_state_vector(
94+
self,
95+
state_vector: np.ndarray,
96+
qid_map: Mapping[raw_types.Qid, int],
97+
) -> complex:
98+
"""Expectation of the projection from a state vector.
99+
100+
Computes the expectation value of this ProjectorString on the provided state vector.
101+
102+
Args:
103+
state_vector: An array representing a valid state vector.
104+
qubit_map: A map from all qubits used in this ProjectorString to the
105+
indices of the qubits that `state_vector` is defined over.
106+
Returns:
107+
The expectation value of the input state.
108+
"""
109+
_check_qids_dimension(qid_map.keys())
110+
num_qubits = len(qid_map)
111+
index = self._get_idx_to_keep(qid_map)
112+
return self._coefficient * np.sum(
113+
np.abs(np.reshape(state_vector, (2,) * num_qubits)[index]) ** 2
114+
)
115+
116+
def expectation_from_density_matrix(
117+
self,
118+
state: np.ndarray,
119+
qid_map: Mapping[raw_types.Qid, int],
120+
) -> complex:
121+
"""Expectation of the projection from a density matrix.
122+
123+
Computes the expectation value of this ProjectorString on the provided state.
124+
125+
Args:
126+
state: An array representing a valid density matrix.
127+
qubit_map: A map from all qubits used in this ProjectorString to the
128+
indices of the qubits that `state_vector` is defined over.
129+
Returns:
130+
The expectation value of the input state.
131+
"""
132+
_check_qids_dimension(qid_map.keys())
133+
num_qubits = len(qid_map)
134+
index = self._get_idx_to_keep(qid_map) * 2
135+
result = np.reshape(state, (2,) * (2 * num_qubits))[index]
136+
while any(result.shape):
137+
result = np.trace(result, axis1=0, axis2=len(result.shape) // 2)
138+
return self._coefficient * result
139+
140+
def __repr__(self) -> str:
141+
return (
142+
f"cirq.ProjectorString(projector_dict={self._projector_dict},"
143+
+ f"coefficient={self._coefficient})"
144+
)
145+
146+
def _json_dict_(self) -> Dict[str, Any]:
147+
return {
148+
'cirq_type': self.__class__.__name__,
149+
'projector_dict': list(self._projector_dict.items()),
150+
'coefficient': self._coefficient,
151+
}
152+
153+
@classmethod
154+
def _from_json_dict_(cls, projector_dict, coefficient, **kwargs):
155+
return cls(projector_dict=dict(projector_dict), coefficient=coefficient)
156+
157+
def _value_equality_values_(self) -> Any:
158+
projector_dict = sorted(self._projector_dict.items())
159+
return (tuple(projector_dict), self._coefficient)

0 commit comments

Comments
 (0)