Skip to content

Commit 656ce9d

Browse files
Replace pure python loops with numpy where possible in channels.py. (#5839)
Boosts speed.
1 parent 0e62198 commit 656ce9d

File tree

1 file changed

+8
-10
lines changed

1 file changed

+8
-10
lines changed

Diff for: cirq-core/cirq/qis/channels.py

+8-10
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,9 @@ def kraus_to_choi(kraus_operators: Sequence[np.ndarray]) -> np.ndarray:
5050
Choi matrix of the channel specified by kraus_operators.
5151
"""
5252
d = np.prod(kraus_operators[0].shape, dtype=np.int64)
53-
c = np.zeros((d, d), dtype=np.complex128)
54-
for k in kraus_operators:
55-
v = np.reshape(k, d)
56-
c += np.outer(v, v.conj())
57-
return c
53+
choi_rank = len(kraus_operators)
54+
k = np.reshape(kraus_operators, (choi_rank, d))
55+
return np.einsum('bi,bj->ij', k, k.conj())
5856

5957

6058
def choi_to_kraus(choi: np.ndarray, atol: float = 1e-10) -> Sequence[np.ndarray]:
@@ -105,7 +103,8 @@ def choi_to_kraus(choi: np.ndarray, atol: float = 1e-10) -> Sequence[np.ndarray]
105103

106104
w = np.maximum(w, 0)
107105
u = np.sqrt(w) * v
108-
return [k.reshape(d, d) for k in u.T if np.linalg.norm(k) > atol]
106+
keep = np.linalg.norm(u.T, axis=-1) > atol
107+
return [k.reshape(d, d) for k, keep_i in zip(u.T, keep) if keep_i]
109108

110109

111110
def kraus_to_superoperator(kraus_operators: Sequence[np.ndarray]) -> np.ndarray:
@@ -140,10 +139,9 @@ def kraus_to_superoperator(kraus_operators: Sequence[np.ndarray]) -> np.ndarray:
140139
Superoperator matrix of the channel specified by kraus_operators.
141140
"""
142141
d_out, d_in = kraus_operators[0].shape
143-
m = np.zeros((d_out * d_out, d_in * d_in), dtype=np.complex128)
144-
for k in kraus_operators:
145-
m += np.kron(k, k.conj())
146-
return m
142+
ops_arr = np.asarray(kraus_operators)
143+
m = np.einsum('bij,bkl->ikjl', ops_arr, ops_arr.conj())
144+
return m.reshape((d_out * d_out, d_in * d_in))
147145

148146

149147
def superoperator_to_kraus(superoperator: np.ndarray, atol: float = 1e-10) -> Sequence[np.ndarray]:

0 commit comments

Comments
 (0)