Skip to content

Commit 6d79f4d

Browse files
Add is_cptp predicate (quantumlib#4365)
As requested in quantumlib#4194. Can be used for quantumlib#2271. This predicate is meant to be invoked when constructing a channel to verify that the provided Kraus operators actually describe a valid quantum channel. Recommendations for cleaner `is_cptp` behavior or additional test cases are welcome.
1 parent 9ee5533 commit 6d79f4d

File tree

4 files changed

+64
-0
lines changed

4 files changed

+64
-0
lines changed

cirq/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@
130130
dot,
131131
expand_matrix_in_orthogonal_basis,
132132
hilbert_schmidt_inner_product,
133+
is_cptp,
133134
is_diagonal,
134135
is_hermitian,
135136
is_normal,

cirq/linalg/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161

6262
from cirq.linalg.predicates import (
6363
allclose_up_to_global_phase,
64+
is_cptp,
6465
is_diagonal,
6566
is_hermitian,
6667
is_normal,

cirq/linalg/predicates.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,21 @@ def is_normal(matrix: np.ndarray, *, rtol: float = 1e-5, atol: float = 1e-8) ->
149149
return matrix_commutes(matrix, matrix.T.conj(), rtol=rtol, atol=atol)
150150

151151

152+
def is_cptp(*, kraus_ops: Sequence[np.ndarray], rtol: float = 1e-5, atol: float = 1e-8):
153+
"""Determines if a channel is completely positive trace preserving (CPTP).
154+
155+
A channel composed of Kraus operators K[0:n] is a CPTP map if the sum of
156+
the products `adjoint(K[i]) * K[i])` is equal to 1.
157+
158+
Args:
159+
kraus_ops: The Kraus operators of the channel to check.
160+
rtol: The relative tolerance on equality.
161+
atol: The absolute tolerance on equality.
162+
"""
163+
sum_ndarray = cast(np.ndarray, sum(matrix.T.conj() @ matrix for matrix in kraus_ops))
164+
return np.allclose(sum_ndarray, np.eye(*sum_ndarray.shape), rtol=rtol, atol=atol)
165+
166+
152167
def matrix_commutes(
153168
m1: np.ndarray, m2: np.ndarray, *, rtol: float = 1e-5, atol: float = 1e-8
154169
) -> bool:

cirq/linalg/predicates_test.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,53 @@ def test_is_normal_tolerance():
292292
assert not cirq.is_normal(np.array([[0, 0.5, 0], [0, 0, 0.6], [0, 0, 0]]), atol=atol)
293293

294294

295+
def test_is_cptp():
296+
rt2 = np.sqrt(0.5)
297+
# Amplitude damping with gamma=0.5.
298+
assert cirq.is_cptp(kraus_ops=[np.array([[1, 0], [0, rt2]]), np.array([[0, rt2], [0, 0]])])
299+
# Depolarizing channel with p=0.75.
300+
assert cirq.is_cptp(
301+
kraus_ops=[
302+
np.array([[1, 0], [0, 1]]) * 0.5,
303+
np.array([[0, 1], [1, 0]]) * 0.5,
304+
np.array([[0, -1j], [1j, 0]]) * 0.5,
305+
np.array([[1, 0], [0, -1]]) * 0.5,
306+
]
307+
)
308+
309+
assert not cirq.is_cptp(kraus_ops=[np.array([[1, 0], [0, 1]]), np.array([[0, 1], [0, 0]])])
310+
assert not cirq.is_cptp(
311+
kraus_ops=[
312+
np.array([[1, 0], [0, 1]]),
313+
np.array([[0, 1], [1, 0]]),
314+
np.array([[0, -1j], [1j, 0]]),
315+
np.array([[1, 0], [0, -1]]),
316+
]
317+
)
318+
319+
# Makes 4 2x2 kraus ops.
320+
one_qubit_u = cirq.testing.random_unitary(8)
321+
one_qubit_kraus = np.reshape(one_qubit_u[:, :2], (-1, 2, 2))
322+
assert cirq.is_cptp(kraus_ops=one_qubit_kraus)
323+
324+
# Makes 16 4x4 kraus ops.
325+
two_qubit_u = cirq.testing.random_unitary(64)
326+
two_qubit_kraus = np.reshape(two_qubit_u[:, :4], (-1, 4, 4))
327+
assert cirq.is_cptp(kraus_ops=two_qubit_kraus)
328+
329+
330+
def test_is_cptp_tolerance():
331+
rt2_ish = np.sqrt(0.5) - 0.01
332+
atol = 0.25
333+
# Moderately-incorrect amplitude damping with gamma=0.5.
334+
assert cirq.is_cptp(
335+
kraus_ops=[np.array([[1, 0], [0, rt2_ish]]), np.array([[0, rt2_ish], [0, 0]])], atol=atol
336+
)
337+
assert not cirq.is_cptp(
338+
kraus_ops=[np.array([[1, 0], [0, rt2_ish]]), np.array([[0, rt2_ish], [0, 0]])], atol=1e-8
339+
)
340+
341+
295342
def test_commutes():
296343
assert matrix_commutes(np.empty((0, 0)), np.empty((0, 0)))
297344
assert not matrix_commutes(np.empty((1, 0)), np.empty((0, 1)))

0 commit comments

Comments
 (0)